JAX Basics

Sources:

  1. JAX Quickstart
  2. Learning JAX as a PyTorch developer

Introduction to JAX

Source: Tutorial 2 (JAX): Introduction to JAX+Flax

Why should you learn JAX, if there are already so many other deep learning frameworks like PyTorch and TensorFlow? The short answer: because it can be extremely fast. For instance, a small GoogleNet on CIFAR10, which the source tutorial discusses in detail in Tutorial 5, can be trained in JAX 3x faster than in PyTorch with a similar setup.

However, everything comes at a price. In order to efficiently compile programs just-in-time, JAX requires functions to be written with certain constraints. Firstly, the functions are not allowed to have side effects, meaning that they may not affect any variable outside of their own scope. For instance, in-place operations modify a variable even outside of the function. Moreover, stochastic operations such as torch.rand(...) change the global state of the pseudo-random number generator, which is not allowed in functional JAX (we will see later how JAX handles random number generation). Secondly, JAX compiles functions based on the anticipated shapes of all arrays/tensors in the function. This becomes problematic if the shapes or the program flow within the function depend on the values of a tensor. For instance, in the operation y = x[x > 3], the shape of y depends on how many values of x are greater than 3. We will discuss more of these constraints in this notebook. Still, for most common cases of training neural networks, it is straightforward to write functions within these constraints.

JAX arrays

  • JAX provides a NumPy-inspired interface for convenience.
  • Through duck-typing, JAX arrays can often be used as drop-in replacements of NumPy arrays.
  • Unlike NumPy arrays, JAX arrays are always immutable.

In NumPy, arrays are mutable:

import numpy as np

size = 10
index = 0
value = 23

# In NumPy arrays are mutable
x = np.arange(size)
print(x)
x[index] = value
print(x)

Output:

[0 1 2 3 4 5 6 7 8 9]
[23 1 2 3 4 5 6 7 8 9]

However, JAX arrays are immutable.

import jax.numpy as jnp

x = jnp.arange(size)
print(x)
x[index] = value
print(x)

Output:

[0 1 2 3 4 5 6 7 8 9]
...
TypeError: '<class 'jaxlib.xla_extension.ArrayImpl'>' object does not support item assignment. JAX arrays are immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` or another

Solution is:

y = x.at[index].set(value)

This will create and return a new array with the modified values. You can test it via

print(x is y)

This will return False.
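Besides set(), the .at[...] syntax supports other out-of-place updates in the standard jax.numpy API (add, multiply, etc.); a small sketch:

z = x.at[index].add(10)      # returns a new array with value + 10 at that index
w = x.at[index].multiply(2)  # returns a new array with the value doubled at that index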

Random numbers

JAX uses an explicit approach to generating random numbers: you have to provide a PRNG key. So instead of torch.randn(shape), you have jax.random.normal(key, shape).

This is actually a JAX superpower! Explicitly threading the random state like this means you get trivially reproducible behaviour.

NumPy's PRNG is stateful:

# NumPy - PRNG is stateful!

seed = 0  # assumed here; consistent with the sampled values shown later in these notes

# Let's sample calling the same function twice
print(np.random.random())
print(np.random.random())

np.random.seed(seed)

rng_state = np.random.get_state()
print(rng_state[2:])

_ = np.random.uniform()
rng_state = np.random.get_state()
print(rng_state[2:])

_ = np.random.uniform()
rng_state = np.random.get_state()
print(rng_state[2:])

# NumPy's PRNG implementation (Mersenne Twister) is known to have a number of problems

Output:

0.6027633760716439
0.5448831829968969
(624, 0, 0.0)
(2, 0, 0.0)
(4, 0, 0.0)

The rng_state changes each time.

However, JAX's random functions can't modify the PRNG state!

from jax import random

key = random.PRNGKey(seed)
print(key) # key defines the state (2 unsigned int32s)

# Let's again sample calling the same function twice
print(random.normal(key, shape=(1,)))
print(key) # verify that the state hasn't changed

print(random.normal(key, shape=(1,))) # oops - same results?
print(key)

Output:

[0 0]
[-0.20584226]
[0 0]
[-0.20584226]
[0 0]

Instead of a mutable global state, JAX uses explicit PRNG keys (key = random.PRNGKey(seed)) to control randomness.

So we need to split the key every time we need a fresh pseudorandom number.

print("old key", key)
key, subkey = random.split(key)
normal_pseudorandom = random.normal(subkey, shape=(1,))
print(" \---SPLIT --> new key ", key)
print(" \--> new subkey", subkey, "--> normal", normal_pseudorandom)

# Note 1: you can also split into more than one subkey at a time
# Note 2: which output you call "key" and which "subkey" is just a convention

Output:

old key [0 0]
\---SPLIT --> new key [4146024105 967050713]
\--> new subkey [2718843009 1272950319] --> normal [-1.2515389]

JAX requires explicit PRNG keys for generating random numbers to ensure reproducibility and to fit its functional programming model. A key concept in JAX's PRNG system is that a key should only be consumed once: whenever you need more random numbers, you split the key into fresh subkeys instead of relying on an updated global state. This design avoids hidden state and side effects, adhering to functional programming principles.

Compare generating random numbers "individually" vs. "all at once" in NumPy and JAX:

# NumPy
np.random.seed(seed)
print("individually:", np.stack([np.random.uniform() for _ in range(3)]))

np.random.seed(seed)
print("all at once: ", np.random.uniform(size=3))

# JAX

# When generating random numbers "individually" by splitting the original key into subkeys and using each subkey to generate one random number,
# you're ensuring that each random number is generated with a unique, independent piece of the PRNG state.
key = random.PRNGKey(seed)
subkeys = random.split(key, 3)
sequence = np.stack([random.normal(subkey) for subkey in subkeys])
print("individually:", sequence)

# Generating random numbers "all at once" with a single call using the original key does not split the key into subkeys for each number.
# Instead, it generates a sequence of numbers in one go, based on the state encapsulated by that single key:
key = random.PRNGKey(seed)
print("all at once: ", random.normal(key, shape=(3,)))

# NumPy violates property 3) of the PRNG design requirements listed in the source tutorial

Output:

individually: [0.5488135  0.71518937 0.60276338]
all at once: [0.5488135 0.71518937 0.60276338]
individually: [1.1188384 0.5781488 0.8535516]
all at once: [ 1.8160863 -0.48262316 0.33988908]

As you can see, there is a difference in the JAX output when generating random numbers "individually" vs. "all at once".

Device agnostic

JAX is AI accelerator agnostic. Same code runs everywhere!

In JAX, many operations are executed asynchronously.

import jax.numpy as jnp
import numpy as np
from jax import random, device_put

size = 3000

# Data is automagically pushed to the AI accelerator! (DeviceArray structure)
# No more need for ".to(device)" (PyTorch syntax)
x_jnp = random.normal(key, (size, size), dtype=jnp.float32)
x_np = np.random.normal(size=(size, size)).astype(np.float32) # some diff in API exists!

%timeit jnp.dot(x_jnp, x_jnp.T).block_until_ready() # 1) on GPU - fast
%timeit np.dot(x_np, x_np.T) # 2) on CPU - slow (NumPy only works with CPUs)
%timeit jnp.dot(x_np, x_np.T).block_until_ready() # 3) on GPU with transfer overhead

x_np_device = device_put(x_np) # push NumPy explicitly to GPU
%timeit jnp.dot(x_np_device, x_np_device.T).block_until_ready() # same as 1)

# Note1: I'm using GPU as a synonym for AI accelerator.
# In reality, especially in Colab, this can also be a TPU, etc.

# Note2: block_until_ready() -> block the calling thread (i.e., pause the execution of your Python code) until the computation that produced the DeviceArray is complete

Arrays initialized via NumPy live on the CPU, while arrays initialized via JAX live on the GPU (or other accelerator). The timings above give:

1.42 ms ± 255 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
207 ms ± 20.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
81.5 ms ± 2.48 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
1.44 ms ± 5.29 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
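To inspect which devices JAX sees, or to control placement explicitly, something like the following sketch works (using jax.devices and jax.device_put):

import jax
import jax.numpy as jnp

print(jax.devices())                              # the accelerators JAX can see
x = jnp.ones((3, 3))                              # created on the default device
x_cpu = jax.device_put(x, jax.devices("cpu")[0])  # explicitly place a copy on the CPU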

jit()

  • By default JAX executes operations one at a time, in sequence.
  • Using a just-in-time (JIT) compilation decorator, sequences of operations can be optimized together and run at once.
    • How to activate jit: use the @jit decorator, or use functions that JIT internally, such as lax.fori_loop, lax.scan, lax.cond, etc.
  • Not all JAX code can be JIT compiled, as it requires array shapes and types to be static & known at compile time.
from jax import jit

# Define a function
def selu(x, alpha=1.67, lmbda=1.05): # note: SELU is an activation function
    return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

selu_jit = jit(selu) # let's jit it


# Benchmark non-jit vs jit version
data = random.normal(key, (1000000,))

print('non-jit version:')
%timeit selu(data).block_until_ready()
print('jit version:')
%timeit selu_jit(data).block_until_ready()

Output:

non-jit version:
1.21 ms ± 307 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
jit version:
243 µs ± 12.2 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

Another example:

def norm(X):
    X = X - X.mean(0)
    return X / X.std(0)

norm_compiled = jit(norm) # jit

X = random.normal(key, (10000, 100), dtype=jnp.float32)

assert np.allclose(norm(X), norm_compiled(X), atol=1E-6)

%timeit norm(X).block_until_ready()
%timeit norm_compiled(X).block_until_ready()

Output:

584 µs ± 48 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
267 µs ± 24.1 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

Array shapes inside jitted functions must be static:

# Example of a failure: array shapes must be static

def get_negatives(x):
    # x < 0 produces a boolean array (True where the corresponding element of x is negative).
    # x[x < 0] uses boolean indexing to keep only the negative elements,
    # so the shape of the result depends on the *values* of x.
    return x[x < 0]

x = random.normal(key, (10,), dtype=jnp.float32)
print(get_negatives(x))

Output:

[-0.3721109  -0.18252768 -0.7368197  -0.44030377 -0.1521442  -0.67135346
-0.5908641 ]

Now try to jit it:

print(jit(get_negatives)(x))

you'll get a failure:

NonConcreteBooleanIndexError: Array boolean indices must be concrete; got ShapedArray(bool[10])
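A common workaround (this sketch is not from the original notes) is to keep the output shape static by masking with jnp.where instead of filtering:

@jit
def mask_non_negatives(x):
    # Same shape in and out: negative entries are kept, the rest become 0.
    return jnp.where(x < 0, x, 0.0)

print(mask_non_negatives(x))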

grad()
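As a minimal sketch of jax.grad (the loss function below is just an illustrative example):

import jax
import jax.numpy as jnp

def loss(w):
    return jnp.sum(w ** 2)

grad_loss = jax.grad(loss)                    # grad() returns a new function computing d(loss)/dw
print(grad_loss(jnp.array([1.0, 2.0, 3.0])))  # gradient of sum(w^2) is 2*w -> [2. 4. 6.]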

vmap()

vmap() makes the batch dimension transparent.

Given:

W = random.normal(key, (150, 100))  # e.g. weights of a linear NN layer
batched_x = random.normal(key, (10, 100)) # e.g. a batch of 10 flattened images

def apply_matrix(x):
    return jnp.dot(W, x) # (150, 100) @ (100,) -> (150,)

We want to apply the dot product to the whole batch in parallel.

Naively batched

The easiest way is using a for loop to iterate over the batched input and stack the results into one batch:

def naively_batched_apply_matrix(batched_x):
    return jnp.stack([apply_matrix(x) for x in batched_x])

print('Naively batched')
%timeit naively_batched_apply_matrix(batched_x).block_until_ready()

Output:

Naively batched
2.71 ms ± 141 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

Manually batched

Alternatively, we can do the computation in matrix form without the need for iteration.

@jit
def batched_apply_matrix(batched_x):
    return jnp.dot(batched_x, W.T) # (10, 100) @ (100, 150) -> (10, 150)

print('Manually batched')
%timeit batched_apply_matrix(batched_x).block_until_ready()

Output:

Manually batched
256 µs ± 12 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

Auto-vectorized with vmap

JAX provides vmap() to make the batch dimension transparent.

from jax import vmap

@jit  # Note: we can arbitrarily compose JAX transforms! Here jit + vmap.
def vmap_batched_apply_matrix(batched_x):
    return vmap(apply_matrix)(batched_x) # The batch dimension is transparent!

print('Auto-vectorized with vmap')
%timeit vmap_batched_apply_matrix(batched_x).block_until_ready()

jax.lax

  • jax.numpy is a high-level wrapper that provides a familiar interface.
  • jax.lax is a lower-level API that is stricter and often more powerful.
  • All JAX operations are implemented in terms of operations in XLA – the Accelerated Linear Algebra compiler.

jax.lax is stricter than jax.numpy:

from jax import lax

print(jnp.add(1, 1.0))  # jax.numpy API implicitly promotes mixed types

# This will throw an error.
print(lax.add(1, 1.0)) # jax.lax API requires explicit type promotion.
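If you do want the lax version, a minimal fix (assuming the standard jnp.float32 constructor) is to promote the types yourself:

# Explicit promotion: make both operands float32 before calling lax.add.
print(lax.add(jnp.float32(1), 1.0))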

JIT mechanics: tracing and static variables

  • JIT and other JAX transforms work by tracing a function to determine its effect on inputs of a specific shape and type.
  • Variables that you don't want to be traced can be marked as static.
  • Tracing records the computation without executing any actual FLOPs.

Traced values

JAX transformations, such as jax.jit, convert values into traced values (tracers). These tracer objects are what jax.jit uses to extract the sequence of operations specified by the function. Basic tracers are stand-ins that encode the shape and dtype of the arrays, but are agnostic to the values. This recorded sequence of computations can then be efficiently applied within XLA to new inputs with the same shape and dtype, without having to re-execute the Python code.

To use jax.jit effectively, it is useful to understand how it works. Let's put a few print() statements within a JIT-compiled function and then call the function:

@jit
def f(x, y):
    print("Running f():")
    print(f"  x = {x}")
    print(f"  y = {y}")
    result = jnp.dot(x + 1, y + 1)
    print(f"  result = {result}")
    return result

x = np.random.randn(3, 4)
y = np.random.randn(4)
f(x, y)

Output:

Running f():
x = Traced<ShapedArray(float32[3,4])>with<DynamicJaxprTrace(level=1/0)>
y = Traced<ShapedArray(float32[4])>with<DynamicJaxprTrace(level=1/0)>
result = Traced<ShapedArray(float32[3])>with<DynamicJaxprTrace(level=1/0)>
Array([0.25773212, 5.3623195 , 5.403243 ], dtype=float32)

Notice that the print statements execute, but rather than printing the data we passed to the function, they print the tracer objects that stand in for it.

x2 = np.random.randn(3, 4)
y2 = np.random.randn(4)
f(x2, y2)
Array([1.4344584, 4.3004413, 7.9897013], dtype=float32)

When we call the compiled function again on inputs with matching shapes and dtypes, no re-compilation is required and nothing is printed, because the result is computed in compiled XLA rather than in Python. For this reason, the side effect (the print) happens only once, during tracing!

JAX expressions

The extracted sequence of operations is encoded in a JAX expression, or jaxpr for short. You can view the jaxpr using the jax.make_jaxpr transformation:

from jax import make_jaxpr

def f(x, y):
    return jnp.dot(x + 1, y + 1)

make_jaxpr(f)(x, y)
{ lambda ; a:f32[3,4] b:f32[4]. let
    c:f32[3,4] = add a 1.0
    d:f32[4] = add b 1.0
    e:f32[3] = dot_general[
      dimension_numbers=(([1], [0]), ([], []))
      preferred_element_type=float32
    ] c d
  in (e,) }

Note one consequence of this: because JIT compilation is done without information on the content of the array, control flow statements in the function cannot depend on traced values. For example, this fails:

@jit
def f(x, neg):
    return -x if neg else x

f(1, True)

Output:

TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[]..
The error occurred while tracing the function f at /tmp/ipykernel_5466/2422663986.py:1 for jit. This concrete value was not available in Python because it depends on the value of the argument neg.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError

Static values

If there are variables that you would not like to be traced, they can be marked as static for the purposes of JIT compilation:

from functools import partial

@partial(jit, static_argnums=(1,))
def f(x, neg):
    return -x if neg else x

f(1, True)
Array(-1, dtype=int32, weak_type=True)

Note that calling a JIT-compiled function with a different static argument results in re-compilation, so the function still works as expected:

f(1, False)
Array(1, dtype=int32, weak_type=True)

Understanding which values and operations will be static and which will be traced is a key part of using jax.jit effectively.

Static vs Traced Operations

  • Just as values can be either static or traced, operations can be static or traced.
  • Static operations are evaluated at compile-time in Python; traced operations are compiled & evaluated at run-time in XLA.
  • Use numpy for operations that you want to be static; use jax.numpy for operations that you want to be traced.

This distinction between static and traced values makes it important to think about how to keep a static value static. Consider this function:

import jax.numpy as jnp
from jax import jit

@jit
def f(x):
    return x.reshape(jnp.array(x.shape).prod())

x = jnp.ones((2, 3))
f(x)

Output:

TypeError: Shapes must be 1D sequences of concrete values of integer type, got [Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=1/0)>].
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
The error occurred while tracing the function f at /tmp/ipykernel_5466/1983583872.py:4 for jit. This value became a tracer due to JAX operations on these lines:

operation a:i32[2] = convert_element_type[new_dtype=int32 weak_type=False] b
from line /tmp/ipykernel_5466/1983583872.py:6 (f)

This fails with an error specifying that a tracer was found instead of a 1D sequence of concrete values of integer type. Let’s add some print statements to the function to understand why this is happening:

@jit
def f(x):
    print(f"x = {x}")
    print(f"x.shape = {x.shape}")
    print(f"jnp.array(x.shape).prod() = {jnp.array(x.shape).prod()}")
    # comment this out to avoid the error:
    # return x.reshape(jnp.array(x.shape).prod())

f(x)

Output:

x = Traced<ShapedArray(float32[2,3])>with<DynamicJaxprTrace(level=1/0)>
x.shape = (2, 3)
jnp.array(x.shape).prod() = Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=1/0)>

Notice that although x is traced, x.shape is a static value. However, when we use jnp.array and jnp.prod on this static value, it becomes a traced value, at which point it cannot be used in a function like reshape() that requires a static input (recall: array shapes must be static).

A useful pattern is to use numpy for operations that should be static (i.e. done at compile-time), and use jax.numpy for operations that should be traced (i.e. compiled and executed at run-time). For this function, it might look like this:

from jax import jit
import jax.numpy as jnp
import numpy as np

@jit
def f(x):
    return x.reshape((np.prod(x.shape),))

f(x)
Array([1., 1., 1., 1., 1., 1.], dtype=float32)

For this reason, a standard convention in JAX programs is to import numpy as np and import jax.numpy as jnp, so that both interfaces are available for finer control over whether operations are performed in a static manner (with numpy, once at compile-time) or a traced manner (with jax.numpy, optimized at run-time).

Pure functions

Side effects of jitted functions will only happen once.

# So how does it work in the background? -> tracing on different levels of abstraction

@jit
def f(x, y):
    print("Running f():")
    print(f"  x = {x}") # (3, 4)
    print(f"  y = {y}") # (4) --> (4, 1)
    result = jnp.dot(x + 1, y + 1)
    print(f"  result = {result}") # (3, 1) --> (3)
    return result

x = np.random.randn(3, 4)
y = np.random.randn(4)
print(f(x, y))

x2 = np.random.randn(3, 4)
y2 = np.random.randn(4)
print('Second call:')
print(f(x2, y2)) # Oops! Side effects (like print) are not compiled...

# Note: any time we get the same shapes and types we just call the compiled fn!

Output:

Running f():
x = Traced<ShapedArray(float32[3,4])>with<DynamicJaxprTrace(level=1/0)>
y = Traced<ShapedArray(float32[4])>with<DynamicJaxprTrace(level=1/0)>
result = Traced<ShapedArray(float32[3])>with<DynamicJaxprTrace(level=1/0)>
[3.1814232 4.1241536 7.560952 ]
Second call:
[-0.05702817 4.15776 8.624691 ]

JAX's JIT compilation process traces the function to generate a compiled version. During this tracing, the function is executed on abstract values that represent the shapes and types of the inputs, not their actual values. This means that conditionals based on the values of inputs (like neg in the example below) cannot be resolved during compilation, because the tracer doesn't know the value of neg ahead of time.

In JAX, inputs to JIT-compiled functions are considered "dynamic" by default, meaning their values can change between calls, and JAX expects their shapes and types to determine how to compile the function. However, when a function's behavior depends on a value (like a flag that changes its logic path), you need to tell JAX to treat this input as "static" so that it can compile different versions of the function for different values of these static inputs.

This can't be jitted:

@jit
def f(x, neg): # depends on the value - remember tracer cares about shapes and types!
    return -x if neg else x

f(1, True)

JAX re-runs the Python function when the type or shape of the argument changes. This will end up reading the latest value of the global.

# Example 2

g = 0.

def impure_uses_globals(x):
    return x + g # impure: reads global state outside the function

# JAX captures the value of the global during the first run
print ("First call: ", jit(impure_uses_globals)(4.))

# Let's update the global!
g = 10.

# Subsequent runs may silently use the cached value of the globals
print ("Second call: ", jit(impure_uses_globals)(5.))

# JAX re-runs the Python function when the type or shape of the argument changes
# This will end up reading the latest value of the global
print ("Third call, different type: ", jit(impure_uses_globals)(jnp.array([4.])))

For loop

Iterators in Python are stateful objects. The progression of an iterator (e.g., via next(iterator)) cannot be statically determined by JAX's tracing mechanism. When JAX tries to compile the function, it cannot predict or analyze the behavior of next(iterator), because it involves mutable state that changes as the iterator progresses.

As a result, JAX's static analysis tools fail to incorporate the iterator's behavior correctly into the compiled code. Depending on the context and how JAX's tracer interacts with the iterator, this can lead to unexpected results, such as the operation seemingly not being executed at all, hence the "unexpected result 0".

iterator = iter(range(10))
print(lax.fori_loop(0, 10, lambda i,x: x+next(iterator), 0)) # unexpected result 0

Output:

0
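One way to stay within JAX's constraints (a sketch, not from the original notes) is to index into a concrete array with the loop counter instead of consuming a Python iterator:

xs = jnp.arange(10)
print(lax.fori_loop(0, 10, lambda i, acc: acc + xs[i], 0))  # 45, as expected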

Out-of-Bounds Indexing

Due to JAX's accelerator-agnostic approach, JAX had to adopt non-error behaviour for out-of-bounds indexing (similarly to how invalid floating-point arithmetic results in NaNs rather than an exception).

# NumPy behavior

try:
    np.arange(10)[11]
except Exception as e:
    print("Exception {}".format(e))

NumPy raises an IndexError, which the try/except above catches and prints.

JAX arrays, however, won't throw an exception:

# JAX behavior
# 1) updates at out-of-bounds indices are skipped
# 2) retrievals result in index being clamped
# in general there are currently some bugs so just consider the behavior undefined!

print(jnp.arange(10).at[11].add(23)) # example of 1)
print(jnp.arange(10)[11]) # example of 2)
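If you need more control over this behavior, newer versions of jax.numpy expose an indexing mode on retrievals; treat the exact API below as an assumption and check your version's docs:

print(jnp.arange(10).at[11].get(mode="fill", fill_value=-1))  # return a sentinel instead of clamping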

Non-array inputs

This will throw an error:

try:
    jnp.sum([1, 2, 3])
except TypeError as e:
    print(f"TypeError: {e}")

Output:

TypeError: sum requires ndarray or scalar arguments, got <class 'list'> at position 0.

Use jnp.array() to convert the list to an array first:

def permissive_sum(x):
    return jnp.sum(jnp.array(x))

x = list(range(10))
print(make_jaxpr(permissive_sum)(x))

PyTree

When you do something like this:

x = [array1, array2]

@jax.jit
def f(x):
    ...

f(x)

then JAX knows to unpack the x argument (which is a list) in order to find the arrays array1 and array2. This is how it can replace arrays with tracers when JIT compiling – and this unpacking is also how JAX can find the arrays to create gradients for when using jax.grad, or the arrays to vectorise when using jax.vmap, and so on.

The objects that JAX knows how to unpack are lists, tuples, dictionaries, user-registered custom nodes (we’ll come back to these), and any arbitrarily-nested collection of these. Overall, these are called “PyTrees”. It’s typical to represent a model as some PyTree of its parameters.

It’s almost always a mistake to use raw Python classes with JAX. These aren’t PyTrees, so JAX can’t know how you want to handle these unless you tell it how to. As such, we generally register classes as custom pytree nodes. (One way to do this is by subclassing equinox.Module, see below.)

This is directly analogous to torch.nn.Module, and how you must use self.foo = torch.nn.ModuleList(...) rather than self.foo = [...]. You can’t use raw Python classes with PyTorch either. The only important difference is that PyTorch Modules are treated as directed acyclic graphs (=the same Module can appear in multiple places), whilst JAX PyTrees are treated as, well, trees (=multiple appearances of the same object are treated as independent copies).
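As a sketch of registering a plain Python class as a custom pytree node via jax.tree_util.register_pytree_node (the MyParams class below is a made-up example):

import jax
import jax.numpy as jnp
from jax import tree_util

class MyParams:
    def __init__(self, w, b):
        self.w = w
        self.b = b

def flatten_params(p):
    return (p.w, p.b), None            # (children/leaves, static auxiliary data)

def unflatten_params(aux, children):
    return MyParams(*children)

tree_util.register_pytree_node(MyParams, flatten_params, unflatten_params)

params = MyParams(jnp.ones(3), jnp.zeros(3))
print(jax.tree_util.tree_leaves(params))   # JAX now sees the two arrays as leaves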

import jax
import jax.numpy as jnp

# A contrived example for pedagogical purposes
# (if your mind needs to attach some semantics to parse this - treat it as model params)
pytree_example = [
    [1, 'a', object()],
    (1, (2, 3), ()),
    [1, {'k1': 2, 'k2': (3, 4)}, 5],
    {'a': 2, 'b': (2, 3)},
    jnp.array([1, 2, 3]),
]

# Let's see how many leaves they have:
for pytree in pytree_example:
    leaves = jax.tree.leaves(pytree) # handy little function
    print(f"{repr(pytree):<45} has {len(leaves)} leaves: {leaves}")

Output:

[1, 'a', <object object at 0x7f9cb0f032e0>]   has 3 leaves: [1, 'a', <object object at 0x7f9cb0f032e0>]
(1, (2, 3), ()) has 3 leaves: [1, 2, 3]
[1, {'k1': 2, 'k2': (3, 4)}, 5] has 5 leaves: [1, 2, 3, 4, 5]
{'a': 2, 'b': (2, 3)} has 3 leaves: [2, 2, 3]
Array([1, 2, 3], dtype=int32) has 1 leaves: [Array([1, 2, 3], dtype=int32)]

How do we manipulate PyTrees?

Use jax.tree_map(). It iterates over the leaves and applies the given function to each one.

list_of_lists = [{'a': 3}, [1, 2, 3], [1, 2], [1, 2, 3, 4]]  # the pytree used in the following examples
print(jax.tree_map(lambda x: x*2, list_of_lists))

Output:

[{'a': 6}, [2, 4, 6], [2, 4], [2, 4, 6, 8]]

Another example:

another_list_of_lists = list_of_lists
print(jax.tree_map(lambda x, y: x+y, list_of_lists, another_list_of_lists))

Output:

[{'a': 6}, [2, 4, 6], [2, 4], [2, 4, 6, 8]]

PyTrees need to have the same structure if we want to map a function over multiple trees with jax.tree_map. So the following code will fail:

from copy import deepcopy

another_list_of_lists = deepcopy(list_of_lists)
another_list_of_lists.append([23])
print(jax.tree_map(lambda x, y: x+y, list_of_lists, another_list_of_lists))

Output:

ValueError: List arity mismatch: 5 != 4; list: [{'a': 3}, [1, 2, 3], [1, 2], [1, 2, 3, 4], [23]].

Weird things

NaN behavior

The default non-error behavior will simply return a NaN (as usual).

jnp.divide(0., 0.)  # the default non-error behavior will simply return a NaN (as usual)

# If you want to debug where the NaNs are coming from, there are multiple ways
# to do that, here is one:
from jax import config
config.update("jax_debug_nans", True)

State

# 1) We've seen in the last notebook/video that impure functions are problematic.

g = 0. # state

# We're accessing some external state in this function which causes problems
def impure_uses_globals(x):
    return x + g

# JAX captures the value of the global/state during the first run
print ("First call: ", jit(impure_uses_globals)(4.))

# Let's update the global/state!
g = 10.

# Subsequent runs may silently use the cached value of the globals/state
print ("Second call: ", jit(impure_uses_globals)(5.))

Output:

First call:  4.0
Second call: 5.0

Stateful --> Stateless

In summary, we use the following rule to convert a stateful class:

class StatefulClass:

    state: State

    def stateful_method(*args, **kwargs) -> Output:

into a class of the form:

class StatelessClass:

    def stateless_method(state: State, *args, **kwargs) -> (Output, State):

The member functions of the Counter class below are stateful. For instance, count(self) reads and updates self.n, which lives on the instance and therefore persists outside any single call; from JAX's point of view it is hidden state.

Example

# Let's now explicitly address and understand the problem of state!
# Why?
# Well, NNs love statefulness: model params, optimizer params, BatchNorm, etc.
# and we've seen that JAX seems to have a problem with it.

class Counter:
    """A simple counter."""

    def __init__(self):
        self.n = 0

    def count(self) -> int:
        """Increments the counter and returns the new value."""
        self.n += 1
        return self.n

    def reset(self):
        """Resets the counter to zero."""
        self.n = 0


counter = Counter()

for _ in range(3): # works like a charm
    print(counter.count())

Output:

1
2
3

But jitting the stateful method breaks it:

counter.reset()
fast_count = jax.jit(counter.count)

for _ in range(3): # oops, it's not working as it's supposed to be
    print(fast_count())

Output:

1
1
1

The solution is:

# Solution:

from typing import Tuple

CounterState = int # our counter state is implemented as a simple integer

class CounterV2:

    def count(self, n: CounterState) -> Tuple[int, CounterState]:
        # You could just return n+1, but here we separate its role as
        # the output and as the counter state for didactic purposes
        # (the output may be some arbitrary function of the state in general).
        return n+1, n+1

    def reset(self) -> CounterState:
        return 0

counter = CounterV2()
state = counter.reset() # notice how reset() now returns the state (external vs. internal state)

for _ in range(3): # works like a charm pre-jit; the jitted version is shown below
    value, state = counter.count(state) # looks familiar?
    print(value)

Output:

1
2
3
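And now the jitted version behaves correctly, because all state flows explicitly through arguments and return values:

fast_count = jax.jit(counter.count)
state = counter.reset()

for _ in range(3):
    value, state = fast_count(state)
    print(value)  # 1, 2, 3 again, even when jitted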

Gotcha: vmap(jax.lax.cond) evaluates both branches

JAX has jax.numpy.where(pred, a, b) to do an if statement between two arrays a and b. This is like NumPy, and both arrays a and b need to have been evaluated.

What if you don’t want to evaluate both branches? (Maybe they’re expensive to compute.) For this there is jax.lax.cond(pred, if_fn, else_fn), which is the runtime equivalent of a Python if statement.

There is one “gotcha” here. If your computation is batched due to a jax.vmap, then both if_fn and else_fn will be evaluated. Under-the-hood it is rewritten into a jax.numpy.where(batch_pred, if_fn(...), else_fn(...)). After all, some batch elements might need one branch, and some batch elements might need the other.

So this makes sense for JAX’s programming model, but it can also be a bit surprising. E.g. if if_fn sometimes produces an infinite loop for some inputs, and the jax.lax.cond is to guard against that, then that infinite loop will still be caught when vmap’d! (In this case you could fix this by using another jax.lax.cond to replace the bad input with a dummy-but-safe input before the loop.)
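A small sketch of this gotcha (an illustrative example; the branch functions are made up here):

import jax
import jax.numpy as jnp
from jax import lax

def true_branch(x):
    return x ** 2

def false_branch(x):
    return -x

def f(x):
    return lax.cond(x > 0, true_branch, false_branch, x)

xs = jnp.array([-1.0, 2.0, 3.0])
print(f(2.0))           # unbatched: only the selected branch runs
print(jax.vmap(f)(xs))  # batched: both branches run for every element, results selected as with jnp.where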