By Oleksiy Grechnyev, CV/ML engineer @It-Jim

Historically, there have been many Deep Learning (DL) frameworks, like Theano, CNTK, Caffe2, and MXNet. Nowadays, they appear to be dead or dying, as just two frameworks heavily dominate the DL scene: Google TensorFlow (TF), which includes Keras, and PyTorch from Meta (aka Facebook). However, there is no reason to believe such a duopoly will persist forever. New DL frameworks are proposed all the time, and we have no idea which DL framework will be popular in, say, ten years.

One of the more serious contenders in the “DL framework Junior League” is Google JAX. In this article, we examine JAX and look at its positive and negative sides. We will address questions like “When to use JAX?” and “Does JAX have any chance of success?”. But first, why was JAX created? We don’t know exactly, but apparently, AI folks at Google got fed up with TensorFlow and wanted a new toy to fool around with.

To understand JAX, note that Google has not one, but at least two (perhaps more) competing AI teams: Google Brain and DeepMind. They seem to never agree on anything. Even in the TensorFlow era, DeepMind used their own layer API called Sonnet (instead of the usual Keras). Probably nobody outside of DeepMind has ever heard of it. Now history repeats itself with JAX. 

The JAX ecosystem consists of the following packages (each is a separate pip package):

  • JAX: Low-level API (like torch without torch.nn or TF without tf.keras)
  • FLAX (FLexible JAX): Layer API from Google (excluding DeepMind)
  • Haiku: Another layer API, from DeepMind, inspired by Sonnet (TF)
  • OPTAX: Optimizers and loss functions for JAX
  • Numerous more specialized packages: Trax, Objax, Stax, Elegy, RLax, Coax, Chex, Jraph, Oryx, and more. See the JAX ecosystem article.

Note that currently, JAX has no dataset/dataloader API, nor standard datasets like MNIST. Thus you will have to use either TF or PyTorch for these tasks or implement everything yourself.

JAX is open source and has pretty good documentation and tutorials. We also recommend the AI Epiphany lectures on JAX and Flax. We assume that the reader has basic DL and Python knowledge and some experience with either TF or PyTorch.

JAX: Basics, Pytrees, Random Numbers & Neural Networks

JAX Basics and Functional Programming

JAX (the low-level API) has two predecessors:

  • autograd: Numpy-like library with gradients (backprop)
  • Google XLA (Accelerated Linear Algebra): fast matrix operations for CPU, Nvidia GPU, and TPU. It compiles computations into efficient machine code. It is optional in TensorFlow, but required by JAX.

You can view JAX as “numpy with backprop, XLA JIT, and GPU+TPU support”. You write numpy-style code, but with the prefix jnp. (jax.numpy) instead of np. Then your code can run on CPU, GPU, or TPU with no changes. At least, this is the theory. Practice can be a bit harder. GPU installation requires precise versions of CUDA and cuDNN, just like for TensorFlow, and is only really practical in Docker. However, unlike TF, JAX has no official Docker images yet. And unless you work for Google, you will probably never see a TPU anywhere outside Google Colab.
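As a small illustration, here is the same (hypothetical) root-mean-square function written once with numpy and once with jax.numpy; only the prefix changes:

```python
import numpy as np
import jax.numpy as jnp


def rms_np(x):
    # Root-mean-square with plain numpy (CPU only)
    return np.sqrt(np.mean(x ** 2))


def rms_jnp(x):
    # Identical code with the jnp. prefix; runs on JAX's default device (CPU, GPU, or TPU)
    return jnp.sqrt(jnp.mean(x ** 2))


x = np.array([3.0, 4.0])
print(rms_np(x), rms_jnp(jnp.asarray(x)))  # both about 3.5355
```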

Apart from the numpy-like API, JAX includes the following main operations:

  • Calculate gradients with jax.grad()
  • Compile python code to XLA with jax.jit()
  • Add batch dimension to a function using jax.vmap() or jax.pmap()

The biggest difference between numpy and JAX is that JAX is heavily into functional programming; thus, JAX arrays (aka tensors) are always immutable. What does “functional programming” mean? It means that Python functions must be “pure”, i.e., behave like mathematical functions. In particular, a function f(x, y, z) is PURE if:

  • It receives input data ONLY through the arguments x, y, z
  • It outputs results ONLY through the return value(s)
  • It does NOT modify objects x, y, z
  • It does NOT access any global variables
  • It does NOT print anything, and does not access the screen, keyboard, any files or devices, or the OS API
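A minimal sketch contrasting the two kinds of functions (the names impure_scale and pure_scale are made up for illustration):

```python
import jax.numpy as jnp

scale = 2.0  # a global variable


def impure_scale(x):
    # NOT pure: reads a global variable and prints (side effects)
    print("scaling", x)
    return x * scale


def pure_scale(x, scale):
    # Pure: input only via arguments, output only via the return value
    return x * scale


x = jnp.array([1.0, 2.0])
y = pure_scale(x, 2.0)  # [2. 4.], and x itself is untouched
```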

A function that breaks these rules is not pure, and we say that it has “side effects”. Such functions are not allowed in functional programming.

But what about classes and objects? Vanilla functional programming does not allow any classes. However, that would be highly impractical in Python, where we need classes for objects like multidimensional arrays. Thus, JAX makes a compromise: classes and objects are allowed as long as they are strictly immutable, i.e., created once and never changed. What does this mean for Deep Learning? It means that any DL object, such as a model (neural net) or an optimizer, must be split into an immutable object (containing the code) and separate parameters and state, which get replaced with new values as training progresses. In particular, the following data objects (usually Python dictionaries) are kept separate from the main immutable objects (containing the code):

  • Neural network parameters (which are trained)
  • Neural network state (which is not trained, e.g., BatchNorm state)
  • Optimizer state
  • Random number generator (RNG) state.

Note that all this is very different from, e.g., PyTorch, where a model, optimizer, and RNG are all mutable objects under the hood, containing their own states and parameters. Moreover, the default PyTorch RNG is global. Functional programming in JAX makes things clearer for an experienced DL engineer, as you don’t have to worry about the many ways objects can be modified: all modifications are always explicit! On the other hand, it can make JAX harder to understand for beginners compared to PyTorch or TF.

How do you work with immutable JAX arrays? Typical numpy code such as

import numpy as np

a = np.arange(5.)

a[1:3] = [-1., -2.]

will not work in JAX, because it modifies the array a in place. Instead, you have to write the following:

import jax.numpy as jnp

a = jnp.arange(5.)

a = a.at[1:3].set(jnp.array([-1., -2.]))

Here, the object a is not modified but replaced by a new python object.
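A few more index-update operations follow the same pattern. This sketch assumes a reasonably recent JAX version, where .at[] supports set(), add(), and multiply(), and where in-place assignment raises a TypeError:

```python
import jax.numpy as jnp

a = jnp.arange(5.)                           # [0. 1. 2. 3. 4.]
b = a.at[1:3].set(jnp.array([-1., -2.]))     # new array: [ 0. -1. -2.  3.  4.]
c = a.at[0].add(10.)                         # new array: [10.  1.  2.  3.  4.]
d = a.at[::2].multiply(2.)                   # new array: [0. 1. 4. 3. 8.]

try:
    a[0] = 100.                              # in-place assignment is not supported
except TypeError as err:
    print("TypeError:", err)

print(a)                                     # the original is unchanged: [0. 1. 2. 3. 4.]
```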

jax.jit(): Make a Python Function Run Much Faster

Suppose you have a Python function my_function. According to the JAX tutorials, you can make it faster by JIT-compiling it with jax.jit() (under the hood, the function is compiled with XLA). Sounds too good to be true? It is.

Of course, magically accelerating any arbitrary python function will be impossible (unless you port it to C++). What’s the catch? Let’s see how exactly jax.jit() works. What happens when you type

fast_function = jax.jit(my_function)


  • my_function is compiled from Python to XLA. This is achieved by tracing, similar to TorchScript tracing in PyTorch. To be precise, the tracing happens when fast_function is called for the first time.
  • Tracing and compilation take significant time, so such “optimization” only makes sense if we are going to call fast_function many times without recompiling.
  • The compiled XLA code is specialized to particular types of input arguments and particular shapes and dtypes of input jnp arrays. If an input shape or type changes, fast_function is automatically recompiled, which takes time.
  • Python control flow (if, for, while) may not depend on the values of ordinary (traced) arguments, only on arguments declared static (via the static_argnums or static_argnames arguments of jax.jit()) or on array shapes. If the value of a static argument changes, the function is recompiled.
  • Function my_function is supposed to be pure. In reality, if a side effect like print() is present, it runs only at the tracing stage, NOT when fast_function is later called without recompiling.
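The tracing behavior is easy to observe. In this sketch, the hypothetical function double deliberately contains side effects (a global counter and a print()) so that we can see when tracing happens:

```python
import jax
import jax.numpy as jnp

n_traces = 0  # deliberately impure counter, only to observe tracing


@jax.jit
def double(x):
    global n_traces
    n_traces += 1                        # Python side effect: runs only while tracing
    print("tracing with shape", x.shape)
    return 2.0 * x


double(jnp.ones(3))   # first call: trace + compile, prints
double(jnp.ones(3))   # same shape and dtype: cached XLA code, no print
double(jnp.ones(4))   # new shape: traced and compiled again, prints
print(n_traces)       # 2
```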

Despite all these limitations, JIT can accelerate JAX code significantly when used correctly and is routinely used in most JAX code.

Note that jax.jit() is often used as a decorator:

@jax.jit
def my_function(x):
         ...  # function body (pure JAX code)

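Static arguments are usually declared with functools.partial. The sketch below (the function names are illustrative) shows a Python loop over a static argument, the error you get when branching on a traced value, and jnp.where() as one common replacement:

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnums=1)
def power(x, n):
    # n is static: this Python loop is unrolled at trace time,
    # and a new value of n triggers recompilation
    result = jnp.ones_like(x)
    for _ in range(n):
        result = result * x
    return result


@jax.jit
def relu_bad(x):
    # Python 'if' on a traced value: raises an error when called
    if x > 0:
        return x
    return 0.0 * x


@jax.jit
def relu_ok(x):
    # Branching on traced values must use jnp.where (or jax.lax.cond)
    return jnp.where(x > 0, x, 0.0)


print(power(jnp.array(2.0), 3))         # 8.0
print(relu_ok(jnp.array([-1.0, 2.0])))  # [0. 2.]
try:
    relu_bad(jnp.array(1.0))
except Exception as err:                # a tracer-related error
    print(type(err).__name__)
```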

jax.grad(): Gradients of a Scalar Function

Probably the most important JAX function is jax.grad(), which computes gradients (via backprop), a must for neural network training. A minimal example:

def f(x):

         return jnp.sum(x ** 2)

gf = jax.grad(f)

x = jnp.array([1., 2., 3.])

print(f(x), gf(x)) # Prints 14.0 [2. 4. 6.]

Function f() must return a scalar. If it has multiple arguments, jax.grad() differentiates with respect to the first one by default (the argnums argument selects others). jax.grad() also uses tracing, but unlike jax.jit(), it allows Python if and for statements that depend on argument values. A useful variation, jax.value_and_grad(f), creates a function that returns a tuple (f(x), grad(f)(x)).
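To make the argument selection and the control-flow remark concrete, here is a sketch with a made-up two-argument function f(x, y):

```python
import jax


def f(x, y):
    # Python control flow on argument values is fine under jax.grad (no jit here)
    if x > 0:
        return x ** 2 * y
    return -x * y


dfdx = jax.grad(f)                    # df/dx (first argument, the default)
dfdy = jax.grad(f, argnums=1)         # df/dy
dfdxy = jax.grad(f, argnums=(0, 1))   # tuple with both derivatives
val, g = jax.value_and_grad(f)(3.0, 2.0)

print(dfdx(3.0, 2.0))   # 2*x*y = 12.0
print(dfdy(3.0, 2.0))   # x**2  = 9.0
print(val, g)           # 18.0 12.0
```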

jax.vmap() and jax.pmap(): Vectorize a Function along the Batch Dimension

Sometimes you have a function that works on, say, a vector, but you want to make it accept batches (one extra dimension). For example, let’s define a function:

def f(x):                     

         w = jnp.array([[0., 0., 1.], [0., 1., 0.], [1., 0., 0.]])

         return jnp.dot(w, x)

It works on the shape (3,) only, but not (B, 3), where B is the batch size. You can then transform function f with jax.vmap() to make it batch-compatible:

vf = jax.vmap(f)         # Add a batch dimension to function

x = jnp.array([[1., 2., 3.], [4., 5., 6.]])   # shape (2, 3)

print(f(x))          # Error! Wrong shape of x

print(vf(x))         # Success: [[3. 2. 1.] [6. 5. 4.]]

Note: this function is similar to the numpy-derived function jnp.vectorize(), but the two differ in details.
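When some arguments should not be batched, jax.vmap() takes an in_axes argument. In this sketch, apply_matrix is a hypothetical variant of f above that receives w as an argument, and only x gets the batch dimension:

```python
import jax
import jax.numpy as jnp


def apply_matrix(w, x):
    # Works on a single vector x of shape (3,)
    return jnp.dot(w, x)


w = jnp.array([[0., 0., 1.], [0., 1., 0.], [1., 0., 0.]])
xb = jnp.array([[1., 2., 3.], [4., 5., 6.]])   # batch of 2 vectors

# in_axes=(None, 0): do not batch w, batch x along its axis 0
batched = jax.vmap(apply_matrix, in_axes=(None, 0))
print(batched(w, xb))   # [[3. 2. 1.]
                        #  [6. 5. 4.]]
```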

There is a parallel version called jax.pmap(), which distributes the computation across multiple XLA devices (GPUs, TPUs, or CPU cores). Note that while CPU cores can act as separate devices, GPU cores are not separate devices. There is also a rudimentary API for communication between devices: jax.lax.psum(), jax.lax.pmean(), jax.lax.pmax(). Unfortunately, jax.pmap() strictly requires that the batch size B be smaller than or equal to the number of XLA devices. It is too stupid to distribute the work otherwise! Note that if you are running on a CPU, by default JAX sees it as a single device. To make JAX expose 8 CPU devices, set this environment variable before importing jax:
import os
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'

To check the available devices, write:
print('n_devices=', jax.local_device_count())

print('devices=', jax.devices())
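Here is a minimal jax.pmap() sketch with a collective operation. It makes no assumption about the hardware: with a single device the mapped axis has size 1, and with the XLA_FLAGS setting above it has size 8. The axis name 'dev' and the function normalize are arbitrary choices:

```python
import jax
import jax.numpy as jnp

n_dev = jax.local_device_count()


def normalize(x):
    # Runs once per device; psum sums x across all devices on the named axis
    return x / jax.lax.psum(x, axis_name='dev')


p_normalize = jax.pmap(normalize, axis_name='dev')

x = jnp.arange(1., n_dev + 1.)   # leading axis size must be <= number of devices
y = p_normalize(x)
print(y, y.sum())                # fractions that add up to 1
```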

JAX Pytrees

As we mentioned, various parameters and states must be kept separate from the immutable model objects. They are typically kept in nested structures of Python dicts, lists, tuples, or similar containers. Such objects are called pytrees in JAX. Their leaves (lowest-level nodes) are typically JAX arrays.

Functions like jax.grad() support pytrees. For example, if the first argument p of a function f(p, x) is a pytree, the gradient with respect to p is a pytree of the same structure as p, consisting of the gradients with respect to each leaf in p. This is routine for neural nets: if f(p, x) is a JAX model, then the first argument p is typically a pytree of the network parameters (which we train).
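For example, here is a gradient with respect to a dictionary of parameters (the function f and the keys 'w' and 'b' are illustrative):

```python
import jax
import jax.numpy as jnp


def f(p, x):
    # p is a pytree (a dict) of parameters, x is the input
    return jnp.sum(p['w'] * x + p['b'])


p = {'w': jnp.array([1., 2.]), 'b': jnp.array(0.5)}
x = jnp.array([3., 4.])

grads = jax.grad(f)(p, x)   # a dict with the same keys and shapes as p
print(grads)                # grads['w'] equals x; grads['b'] is 2.0 (b is added twice)
```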

There are a couple of useful functions to work with pytrees. jax.tree_map() applies a unary function to each leaf of a pytree, generating a new pytree of the same structure. For example, to print the pytree of the shapes of all parameters in the pytree t, type:

print(jax.tree_map(lambda x: x.shape, t))

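A similar function jax.tree_multimap() applies a binary operation to two pytrees of the same structure. For example, the “sum of two trees t1 and t2” is given by:

jax.tree_multimap(lambda x, y: x+y, t1, t2)

Note that recent JAX versions have removed jax.tree_multimap() and deprecated jax.tree_map(); use jax.tree_util.tree_map(), which accepts one or several trees, instead.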
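Here is a sketch of these pytree utilities via jax.tree_util, which exists in both older and newer JAX versions (the tree t1 is a made-up example):

```python
import jax
import jax.numpy as jnp

t1 = {'layer1': {'w': jnp.ones((2, 3)), 'b': jnp.zeros(3)}, 'scale': jnp.array(2.)}
t2 = jax.tree_util.tree_map(lambda x: 0.5 * x, t1)             # same structure as t1

shapes = jax.tree_util.tree_map(lambda x: x.shape, t1)          # pytree of shapes
total = jax.tree_util.tree_map(lambda x, y: x + y, t1, t2)      # "sum of two trees"
n_params = sum(x.size for x in jax.tree_util.tree_leaves(t1))   # count all numbers

print(shapes)
print(n_params)   # 6 + 3 + 1 = 10
```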

JAX Random Numbers

Random numbers in JAX can confuse people who are used to numpy or PyTorch. Because of the functional programming paradigm, stateful or global random number generators (RNGs) are not allowed in JAX. How do you implement an RNG without a mutable state? Here we see the first example of how code and state are separated in JAX. Essentially the same logic applies to other objects like models and optimizers.

The function jax.random.normal() requires a “key”, which is the RNG state. You can create the key from a random seed like this (RNG initialization):

key = jax.random.PRNGKey(2022)

Then you can create a random array like this:

print(jax.random.normal(key, (2,)))

Everything works, right? Not really! If you run the same statement again, you will get exactly the same result:

print(jax.random.normal(key, (2,)))

This is the blessing and the curse of functional programming. Everything is explicit, predictable, clear, and immutable. In this case, the same state key results in the same random array. 

How do you generate random numbers properly? For that, you need the function jax.random.split(), which generates two (or more) new keys from the input key. Each key must be used only once (strictly!). Every time you need a random number, you write the following code:

key1, key = jax.random.split(key)            # Split key1, update key
print(jax.random.normal(key1, (2,)))          # Use key1 only once !

Don’t forget to split a new key (key1) every time you generate a new random number, because you can use each key only once! Alternatively, you can split multiple keys at once:

key1, key2, key3, key = jax.random.split(key, 4)   # Generate 3 keys, update key

or even a list of keys:

*keys, key = jax.random.split(key, 10 + 1)   # Generate a list of 10 keys, update key
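To see these rules in action, here is a short sketch: reusing a key repeats the numbers, while split keys give new ones (the seed and sizes are arbitrary):

```python
import jax

key = jax.random.PRNGKey(2022)

a = jax.random.normal(key, (2,))
b = jax.random.normal(key, (2,))            # same key -> identical numbers

key1, key = jax.random.split(key)           # fresh single-use subkey, updated key
c = jax.random.normal(key1, (2,))           # new numbers

*keys, key = jax.random.split(key, 3 + 1)   # three more single-use subkeys
samples = [jax.random.uniform(k) for k in keys]
print(a, b, c, samples)
```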

One more thing: if a neural net requires an RNG for inference (e.g., it has dropout), the RNG key must be supplied explicitly at the inference time. You will see examples of this below.

Side note: of course, nothing stops you from using the numpy RNG and converting the result to JAX, but this is considered bad style among JAX developers.

A Minimal Neural Net in Pure JAX

Let’s use our knowledge to code a trivial neural net in pure JAX. We want to implement a linear regression. Let’s create our data, a linear function plus some random noise:

n = 101

xx = jnp.linspace(-1, 1, n)

key_noise, key = jax.random.split(key)

yy = 3 * xx - 1 + 0.2 * jax.random.normal(key_noise, (n,))

Now we create a linear model with two parameters:

def model(params, x):

         return params[0] * x + params[1]

Note how the first argument (which we usually take a gradient over) is a pytree of the optimizable network parameters (just a size-2 list in this case).

Next, we define a loss function, compile it with jax.jit(), and calculate its gradient (with respect to params):


@jax.jit
def loss_fun(params, x, y):

         pred = model(params, x)

         return jnp.mean((y - pred) ** 2)

vgl = jax.value_and_grad(loss_fun)

Finally, we initialize the parameters and perform the training loop:

params = [1., 1.]

lr = 0.1

for i in range(100):

         loss, grad = vgl(params, xx, yy)

         params = jax.tree_multimap(lambda p, g: p - lr*g, params, grad)

         print(i, loss)

Note how we use jax.tree_multimap() to update our parameters (the vanilla SGD optimizer). The result looks like this:

OPTAX: Optimizers

This is going to be the shortest chapter of this article. In the previous example, we implemented vanilla SGD using jax.tree_multimap(). But more advanced optimizers like Adam or SGD with momentum usually work better. Here, OPTAX comes to the rescue. Let’s see how we can modify the previous example using OPTAX. Since this is JAX, we must create an optimizer object (immutable, code-only) and then a separate state for it:

params = [1., 1.]

lr = 0.1

optimizer = optax.adam(learning_rate=lr)   # Create the optimizer

opt_state = optimizer.init(params)        # Init optimizer state

Next, we rewrite the training loop by using the optimizer:

for i in range(100):

         loss, grad = vgl(params, xx, yy)

         upd, opt_state = optimizer.update(grad, opt_state)  # Optimizer step

         params = optax.apply_updates(params, upd)       # Basically params + upd

         print(i, loss)

First, the method optimizer.update() calculates the updates plus the new optimizer state. Second, we add the updates to the params using optax.apply_updates(), which essentially adds two pytrees leaf by leaf.

Note that we use the same OPTAX optimizers regardless of whether our model is written in pure JAX, FLAX, or Haiku. Apart from optimizers, OPTAX also contains several loss functions and schedulers.


FLAX: Basics

FLAX (FLexible JAX) is a layer API for JAX created by Google (DeepMind excluded). It plays a role similar to Keras in TF or torch.nn in PyTorch. We are going to use the modern flax.linen API, typically imported as nn (import flax.linen as nn); the old API flax.nn has been deprecated and removed!

Let’s create a FLAX model of a single linear (aka FC aka Dense) layer:

model = nn.Dense(features=3)

This creates an (immutable) model object, but we also have to create and initialize the model parameters. For that, you need two things: a single-use random key key_init, and a sample input x (to specify input shape):

x = jnp.ones((4, 2))                            #  Sample input: batch_size=4, dim=2

key_init, key = jax.random.split(key)

params = model.init(key_init, x)     

Now, let’s print the parameter shapes:

print('params:', jax.tree_map(lambda p: p.shape, params))

The output looks like this:

params: FrozenDict({

         params: {

                  bias: (3,),

                  kernel: (2, 3),

         }, })

What on earth is a “FrozenDict”? Basically, it’s an immutable python dict, defined in FLAX and registered with the JAX pytree ecosystem (which allows registering custom collection types). FLAX models prefer FrozenDict, but they can take python dict as well.

If the dictionary params consists of model parameters, why does it have a subdictionary named “params”? We’ll see in a moment.

To run model inference (or a training forward pass) on an input x, you type:

y = model.apply(params, x)

Note that you cannot simply call model(x)! If the model requires a random key (e.g., for dropout), you’ll have to supply it as well:
y = model.apply(params, x, rngs={'dropout': key_do})

FLAX Models

But how can we create a FLAX model of more than one layer? There are several options. First, we can use sequential models (since FLAX 0.4.1):
model = nn.Sequential([
    nn.Dense(5),    # layer sizes here are illustrative
    nn.relu,
    nn.Dense(3),
])

Note that nn.Dense() is a FLAX model object, while nn.relu is a function (with no parameters); nn.Sequential() supports both.

For more serious models, we have to inherit from the nn.Module class. Note that this class is a Python dataclass (read about dataclasses if you haven’t already; they are fun):

class GoblinModel(nn.Module):

         feat1: int

         feat2: int


         def setup(self):                                  # Called lazily when the module is bound (e.g., by init())

                  self.d1 = nn.Dense(self.feat1)       # A submodule is registered

                  self.d2 = nn.Dense(self.feat2)


         def __call__(self, x):

                  x = jax.nn.relu(self.d1(x))             # Note: no apply(), no params!

                  return self.d2(x)

A dataclass declares several type-annotated fields (feat1, feat2 in our case) and automatically creates a constructor for them. That is why we should not define an explicit constructor (__init__) in a FLAX module; the method setup() is used instead. It is called lazily when the module is bound to its variables, e.g., when you call the init() or apply() method of the model. Now we can create a model instance as usual, followed by the initialization:
model = GoblinModel(5, 3)  # (feat1, feat2)

x = jnp.ones((4, 7))               #  Sample input: batch_size=4, dim=7

key_init, key = jax.random.split(key)

params = model.init(key_init, x)

A lot of magic happens under the hood when we call init(). It calls setup() and all submodules defined in setup() are registered, i.e. they are added to the parameter dictionary and initialized. Also init() handles random number keys under the hood and automatically generates a new single-use key for initializing each layer.

Note how in the __call__() method the submodules are called directly without giving them any parameters. When we run the inference, we actually call apply() and not __call__(), and the former method handles all parameters and passes all submodule parameters to the respective submodules.

However, the syntax with setup() is somewhat cumbersome, so people often use the decorator @nn.compact (a lot of black magic happens there) to define the submodules directly in __call__():

class OrcModel(nn.Module):

         feat1: int

         feat2: int


         @nn.compact
         def __call__(self, x):

                  x = nn.Dense(self.feat1)(x) # Layer    

                  x = jax.nn.relu(x)                # Function

                  return nn.Dense(self.feat2)(x)


A CNN example (CNNs are popular in computer vision) is not much harder:

class DwarfCNN(nn.Module):

         @nn.compact
         def __call__(self, x):

                  x = nn.Conv(features=32, kernel_size=(3, 3))(x)    # Layer

                  x = nn.relu(x)                                                         # Function

                  x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))

                  x = nn.Conv(features=64, kernel_size=(3, 3))(x)

                  x = nn.relu(x)

                  x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))

                  return x

FLAX models (custom layers)

So far, we created models out of standard FLAX layers. But how do we create a custom one? Here is an example. We use the method param() to define parameters:
class ElfLinear(nn.Module):

         feat: int

         w_init: typing.Callable = nn.initializers.lecun_normal()

         b_init: typing.Callable = nn.initializers.zeros


         @nn.compact
         def __call__(self, x):

                  w = self.param('w', self.w_init, (x.shape[-1], self.feat))

                  b = self.param('b', self.b_init, (self.feat,))

                  return x @ w + b

Thanks to the @nn.compact magic, we can declare the parameters directly in __call__() using the param() method. We need to supply a name, an initializer, and the shape. The latter is (in our example) derived from the input x, i.e., the sample input provided at initialization.

The Problem of State

But what if our model has a state (variables that are NOT trainable parameters)? For example, BatchNorm has a state (its running mean and variance), and so do all models containing a BatchNorm as a submodule. This is where FLAX gets a bit awkward, in our opinion. We initialize the model as usual:

vars0 = model.init(key_init, x)

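However, now the FrozenDict vars0 contains not only the parameters (in the 'params' collection), but also state variables in other collections (e.g., 'batch_stats' for BatchNorm). We have to separate the two by hand (FrozenDict.pop() returns the remaining dictionary and the popped value):

state, params = vars0.pop('params')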

It is very important that the optimizer optimizes only parameters (params) and not vars0! In other words, we use params when initializing the optimizer and performing the optimizer updates. 

Before running apply(), we recombine parameters and state back into the dictionary vars, and use the form of apply() which updates the state. If we don’t want to update the state (a frozen BatchNorm at testing stage), we use the regular apply() instead:

vars = {‘params’: params, **state}

pred, state = model.apply(vars, x, mutable=state.keys())

Once again:

  • We must separate all model variables into state and params.
  • params (network parameters) are optimized by the optimizer (via backprop).
  • state (network state, e.g., BatchNorm statistics) is updated in apply() during training, but is frozen during testing.

Do you think handling all the different parameters and states is awkward? We agree. However, for common training scenarios, FLAX provides a higher-level helper, TrainState (flax.training.train_state.TrainState), which bundles the model’s apply function, the parameters, and the optimizer together. You can try it if you want. It is FLAX’s step toward higher-level APIs like fit() in Keras or PyTorch Lightning, although you still write the training loop yourself.


Haiku basics

Haiku is another layer API from DeepMind. Compared to FLAX, it is even purer functional programming. In FLAX, a lot of magic was buried in the nn.Module class and nn.compact() function. In contrast, in Haiku a model class does not matter. There is a class hk.Module, but it’s a thin submodule container that does almost nothing, and you don’t have to use it at all if you don’t want to. All elven magic happens in the function hk.transform() and its variations. Let’s see how it all works.

A neural net is always defined as a function (similar to the functional definition in Keras), for example:

def forward(x):

         return hk.Linear(3)(x)

Note that you define a Haiku module hk.Linear inside the function and do not provide any initialization data like the input shape or an RNG key. Such a function will not work directly! Instead, you must transform it like this:

model = hk.transform(forward)

This step is similar to creating a functional Keras model (tf.keras.Model) in TensorFlow.

Now we get a transformed model object model. It works sort of like the FLAX module, but we never create such objects explicitly, only via hk.transform(). We still have to initialize it:

params = model.init(key_init, jnp.zeros((5, 2)))

Initialization is pretty much identical to the one in FLAX.

To run the model, we call apply():

y = model.apply(params, key_apply, x)

Note that in Haiku you must pass an RNG key argument to apply() whether or not your model actually needs it (e.g., has dropout layers); if it does not, you can pass None. If you want to get rid of this argument altogether, use an extra transformation:

model = hk.without_apply_rng(hk.transform(forward))

params = model.init(key_init, jnp.zeros((5, 2)))

y = model.apply(params, x)

A special version of hk.transform() is used when your model has a state (e.g. BatchNorm state):
model = hk.transform_with_state(forward)

params, state = model.init(key_init, x)                    # Init params + state

y, state = model.apply(params, state, key_apply, x)  # Apply model, update state

Only params are optimized. Note how Haiku separates params and state automatically, while in FLAX you had to do it by hand.

Haiku models

How do you define a model in Haiku? First, you can inherit from hk.Module:

class GoblinMLP(hk.Module):

         def __init__(self, name='goblin_mlp'):

                  super(GoblinMLP, self).__init__(name=name)

                  self.l1 = hk.Linear(5)

                  self.l2 = hk.Linear(3)


         def __call__(self, x):

                  x = jax.nn.relu(self.l1(x))

                  return self.l2(x)

You’ll still have to define a forward function (or lambda):

def forward(x):

         return GoblinMLP()(x)

But it’s Haiku, so you don’t have to use the model class if you don’t want to. How about this:

def forward_goblin(x):

         x = hk.Linear(5)(x)

         x = jax.nn.relu(x)

         return hk.Linear(3)(x)

It works: you can actually register layers directly in the forward function, without using the hk.Module class at all!

If you are writing a custom module, use hk.get_parameter() to register network parameters:

class GoblinLinear(hk.Module):

         def __init__(self, osize, name='goblin_linear'):

                  super(GoblinLinear, self).__init__(name=name)

                  self.osize = osize


         def __call__(self, x):

                  n_in, n_out = x.shape[-1], self.osize

                  w_init = hk.initializers.TruncatedNormal(1. / np.sqrt(n_in))

                  w = hk.get_parameter('w', shape=[n_in, n_out], dtype=x.dtype, init=w_init)

                  b = hk.get_parameter('b', shape=[n_out], dtype=x.dtype, init=jnp.ones)

                  return jnp.dot(x, w) + b

You can even call hk.get_parameter() directly in forward(), without any hk.Module.

Haiku has a couple of further nice things. You can define an MLP (multi-layer dense network) concisely (it uses ReLU by default; this can be changed):
hk.nets.MLP([20, 20, 1])

There is also a number of standard architectures (but alas no pre-trained weights for them):


hk.nets.ResNet18, hk.nets.ResNet34, hk.nets.ResNet50, hk.nets.ResNet101, hk.nets.ResNet152, hk.nets.ResNet200

So, Can JAX Succeed?

We don’t know. But here are the good and bad sides of JAX, in our opinion (compared to PyTorch and TensorFlow):

The good:

  • JAX is very TPU friendly and has built-in support for multiple devices.
  • Functional programming makes things a bit cleaner (but only for pros).
  • The weight of Google behind it should matter.

The bad:

  • It is still in the 0.x versions; the API might change.
  • Functional programming can be annoying for beginners.
  • Apart from the TPU, there are few real advantages over PyTorch (or TF).
  • There are very few deployment options. Currently, there is only the experimental jax2tf converter; there is no ONNX, TFLite, or TensorRT export.
  • There is no dataset/dataloader API.
  • There are still very few existing or pre-trained models (but see flaxmodels).

Have fun and enjoy JAX! And if you are more into video content, we have a lecture on JAX on our YouTube channel.

P.S. If the author were Google, he would create some nice deployment system for JAX, like tflite, but based on XLA, with versions for C++, Android, iOS, embedded and Web browser.

JAX: Can It Beat PyTorch and TensorFlow?