JAX Basics
Sources:
- JAX Quickstart
- Learning JAX as a PyTorch developer
Introduction to JAX
Source: Tutorial 2 (JAX): Introduction to JAX+Flax
Why should you learn JAX, if there are already so many other deep learning frameworks like PyTorch and TensorFlow? The short answer: because it can be extremely fast. For instance, a small GoogleNet on CIFAR10, which we discuss in detail in Tutorial 5, can be trained in JAX 3x faster than in PyTorch with a similar setup.
However, everything comes with a price. In order to efficiently compile programs just-in-time, JAX requires the functions to be written under certain constraints. Firstly, the functions are not allowed to have side effects, meaning that they may not affect any variable outside of their namespace. For instance, in-place operations affect a variable even outside of the function. Moreover, stochastic operations such as `torch.rand(...)` change the global state of the pseudo-random number generator, which is not allowed in functional JAX (we will see later how JAX handles random number generation). Secondly, JAX compiles the functions based on the anticipated shapes of all arrays/tensors in the function. This becomes problematic if the shapes or the program flow within the function depend on the values of the tensors. For instance, in the operation `y = x[x > 3]`, the shape of `y` depends on how many values of `x` are greater than 3. We will discuss more of these constraints in this notebook. Still, in most common cases of training neural networks, it is straightforward to write functions within these constraints.
JAX arrays
- JAX provides a NumPy-inspired interface for convenience.
- Through duck-typing, JAX arrays can often be used as drop-in replacements of NumPy arrays.
- Unlike NumPy arrays, JAX arrays are always immutable.
In NumPy, arrays are mutable:
```python
import numpy as np

size = 10
index, value = 0, 23   # example index and value

x = np.arange(size)
print(x)
x[index] = value       # in-place update: allowed for NumPy arrays
```
Output:
```
[0 1 2 3 4 5 6 7 8 9]
```
However, JAX arrays are immutable:
```python
import jax.numpy as jnp

x = jnp.arange(size)
print(x)
# x[index] = value   # raises a TypeError: JAX arrays do not support in-place item assignment
```
Output:
```
[0 1 2 3 4 5 6 7 8 9]
```
The solution is the `.at[...]` syntax:
```python
y = x.at[index].set(value)
```
This will create and return a new array with the modified value; the original `x` is left unchanged. You can verify this via
```python
print(x is y)
```
which prints `False`.
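Besides `.set()`, the `.at[...]` syntax supports other out-of-place updates; a small sketch (the indices and values are chosen only for illustration):
```python
import jax.numpy as jnp

x = jnp.arange(5)
print(x.at[1].add(10))       # [0 11 2 3 4]
print(x.at[2:].multiply(2))  # [0 1 4 6 8]
print(x)                     # [0 1 2 3 4] -- the original array is untouched
```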
Random numbers
JAX uses an explicit approach to generating random numbers: you have to provide a PRNG key. So instead of `torch.randn(shape)`, you have `jax.random.normal(key, shape)`.
This is actually a JAX superpower! Explicitly threading the random state like this means you get trivially reproducible behaviour.
NumPy's PRNG is stateful:
```python
# NumPy - PRNG is stateful!
print(np.random.uniform())
rng_state = np.random.get_state()   # the hidden global state just advanced
```
Output:
```
0.6027633760716439
```
The global `rng_state` changes each time a random number is drawn.
However, JAX's random functions cannot modify any global PRNG state:
```python
from jax import random

seed = 0
key = random.PRNGKey(seed)
print(key)   # a key is just a pair of uint32 values
```
Output:
```
[0 0]
```
Instead of a mutable global state, JAX uses explicit PRNG keys (`key = random.PRNGKey(seed)`) to control randomness.
So we need to split the key every time we need a new pseudorandom number:
```python
print("old key", key)
new_key, subkey = random.split(key)
sample = random.normal(subkey)   # consume the subkey for sampling, carry new_key forward
```
Output:
```
old key [0 0]
```
JAX requires explicit PRNG keys for generating random numbers to ensure reproducibility and to fit its functional programming model. A key idea in JAX's PRNG system is that every random draw consumes a PRNG key; to get fresh randomness you split a key into new subkeys rather than relying on an updated hidden state. This design avoids hidden state and side effects, adhering to functional programming principles.
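In practice, the typical pattern is to split the key you carry around and only consume the subkeys; a minimal sketch (the loop body is just illustrative):
```python
from jax import random

key = random.PRNGKey(0)
samples = []
for _ in range(3):
    key, subkey = random.split(key)        # keep `key`, consume `subkey`
    samples.append(random.normal(subkey))  # never reuse a consumed subkey
print(samples)
```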
Another difference from NumPy: NumPy guarantees sequential equivalence, i.e. you get the same numbers whether you draw them one at a time or all at once:
```python
# NumPy
np.random.seed(0)
print("individually:", np.stack([np.random.uniform() for _ in range(3)]))

np.random.seed(0)
print("all at once: ", np.random.uniform(size=3))
```
Output:
```
individually: [0.5488135 0.71518937 0.60276338]
all at once:  [0.5488135 0.71518937 0.60276338]
```
In JAX, by contrast, generating three numbers "individually" from three split subkeys gives a different result from generating all three "at once" from a single key: JAX does not guarantee sequential equivalence, because that guarantee would get in the way of vectorisation.
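For comparison, here is a sketch of the JAX counterpart (the exact numbers depend on the seed and JAX version, so they are omitted; the two printed lines generally differ):
```python
import jax.numpy as jnp
from jax import random

key = random.PRNGKey(0)
subkeys = random.split(key, 3)
print("individually:", jnp.stack([random.normal(k) for k in subkeys]))
print("all at once: ", random.normal(key, shape=(3,)))
```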
Device agnostic
JAX is AI accelerator agnostic. Same code runs everywhere!
In JAX, many operations are executed asynchronously.
Arrays created with `numpy` live in host (CPU) memory, while arrays created with `jax.numpy` are placed on the GPU (or another available accelerator):
```python
import jax.numpy as jnp
```
A typical `%timeit` measurement on such accelerator-backed arrays looks like:
```
1.42 ms ± 255 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
```
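A sketch (for a notebook) of the kind of benchmark such a timing comes from; the matrix size and the use of a matrix product are assumptions, and `block_until_ready()` is needed because dispatch is asynchronous:
```python
import numpy as np
import jax.numpy as jnp
from jax import random

size = 3000
x_jnp = random.normal(random.PRNGKey(0), (size, size), dtype=jnp.float32)  # on the accelerator
x_np = np.random.normal(size=(size, size)).astype(np.float32)              # on the CPU

%timeit jnp.dot(x_jnp, x_jnp.T).block_until_ready()   # wait for the async result before stopping the timer
%timeit np.dot(x_np, x_np.T)
```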
jit()
- By default JAX executes operations one at a time, in sequence.
- Using a just-in-time (JIT) compilation decorator, sequences of operations can be optimized together and run at once.
- How to activate jit: use `@jit`, or use it within functions like `lax.fori_loop`, `lax.scan`, `lax.cond`, etc.
- Not all JAX code can be JIT-compiled, as it requires array shapes and types to be static and known at compile time.
```python
# Define a function
from jax import jit

def selu(x, alpha=1.67, lmbda=1.05):
    return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

x = random.normal(key, (1_000_000,))

print('non-jit version:')
%timeit selu(x).block_until_ready()

selu_jit = jit(selu)
print('jit version:')
%timeit selu_jit(x).block_until_ready()
```
Output:
```
non-jit version:
```
Another usage:
```python
def norm(X):
    X = X - X.mean(0)
    return X / X.std(0)

norm_compiled = jit(norm)

X = random.normal(key, (10000, 10), dtype=jnp.float32)
%timeit norm(X).block_until_ready()
%timeit norm_compiled(X).block_until_ready()
```
Output:
```
584 µs ± 48 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
```
Array shapes inside jitted functions must be static:
```python
# Example of a failure: array shapes must be static
def get_negatives(x):
    return x[x < 0]   # the output shape depends on the *values* of x

x = random.normal(key, (10,), dtype=jnp.float32)
print(get_negatives(x))   # fine outside of jit
```
Output:
```
[-0.3721109 -0.18252768 -0.7368197 -0.44030377 -0.1521442 -0.67135346
```
Now try to jit it:
```python
print(jit(get_negatives)(x))
```
You'll get a failure:
```
NonConcreteBooleanIndexError: Array boolean indices must be concrete; got ShapedArray(bool[10])
```
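A common workaround (my own addition, not part of the original failure example) is to keep the output shape static, e.g. with the three-argument `jnp.where` instead of boolean indexing:
```python
import jax
import jax.numpy as jnp
from jax import random

@jax.jit
def zero_out_non_negatives(x):
    # Carries the same reduction-friendly information as x[x < 0], but the
    # output shape equals the input shape, so it can be traced under jit.
    return jnp.where(x < 0, x, 0.0)

x = random.normal(random.PRNGKey(0), (10,))
print(zero_out_non_negatives(x))
```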
grad()
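`jax.grad` transforms a scalar-valued function into a function that returns its gradient with respect to the first argument; a minimal sketch (the `tanh` example is just an illustrative choice):
```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.tanh(x)               # scalar in, scalar out

df = jax.grad(f)                     # df(x) == 1 - tanh(x)**2
print(df(1.0))                       # ~0.4199743
print(jax.grad(jax.grad(f))(1.0))    # transforms compose: second derivative
```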
vmap()
`vmap()` makes the batch space transparent.
Given:
```python
W = random.normal(key, (150, 100))         # e.g. weights of a linear NN layer
batched_x = random.normal(key, (10, 100))  # a batch of 10 input vectors

def apply_matrix(x):
    return jnp.dot(W, x)                   # works on a single example
```
We want to apply the matrix to the whole batch of inputs in parallel.
Naively batched
The easiest way is to use a for loop to iterate over the batched input and stack the results into one batch:
```python
def naively_batched_apply_matrix(batched_x):
    return jnp.stack([apply_matrix(x) for x in batched_x])

print('Naively batched')
%timeit naively_batched_apply_matrix(batched_x).block_until_ready()
```
Output:
```
Naively batched
```
Manually batched
Alternatively, we can express the computation as a single matrix product, with no Python-level iteration:
```python
@jit
def batched_apply_matrix(batched_x):
    return jnp.dot(batched_x, W.T)   # one matmul over the whole batch

print('Manually batched')
%timeit batched_apply_matrix(batched_x).block_until_ready()
```
Output:
```
Manually batched
```
Auto-vectorized with vmap
JAX provides `vmap()` to make the batch space transparent:
```python
# Note: we can arbitrarily compose JAX transforms! Here jit + vmap.
from jax import vmap

@jit
def vmap_batched_apply_matrix(batched_x):
    return vmap(apply_matrix)(batched_x)

print('Auto-vectorized with vmap')
%timeit vmap_batched_apply_matrix(batched_x).block_until_ready()
```
jax.lax
- `jax.numpy` is a high-level wrapper that provides a familiar interface.
- `jax.lax` is a lower-level API that is stricter and often more powerful.
- All JAX operations are implemented in terms of operations in XLA – the Accelerated Linear Algebra compiler.

`jax.lax` is stricter than `jax.numpy`:
```python
print(jnp.add(1, 1.0))   # jax.numpy API implicitly promotes mixed types -> 2.0

from jax import lax
lax.add(1.0, 1.0)        # jax.lax requires explicit type promotion:
# lax.add(1, 1.0)        # ...this mixed-type call would raise a TypeError
```
JIT mechanics: tracing and static variables
- JIT and other JAX transforms work by tracing a function to determine its effect on inputs of a specific shape and type.
- Variables that you don’t want to be traced can be marked as static.
- Tracing records the computation symbolically, without executing any actual FLOPs.
Traced values
JAX transformations such as `jax.jit` convert the values flowing through a function into traced values. These tracer objects are what `jax.jit` uses to extract the sequence of operations specified by the function. Basic tracers are stand-ins that encode the shape and dtype of the arrays, but are agnostic to the values. The recorded sequence of computations can then be efficiently applied within XLA to new inputs with the same shape and dtype, without having to re-execute the Python code.
To use `jax.jit` effectively, it is useful to understand how it works. Let’s put a few `print()` statements within a JIT-compiled function and then call the function:
```python
@jit
def f(x, y):
    print("Running f():")
    print(f"  x = {x}")
    print(f"  y = {y}")
    result = jnp.dot(x + 1, y + 1)
    print(f"  result = {result}")
    return result

x = np.random.randn(3, 4)
y = np.random.randn(4)
print(f(x, y))
```
Output:
```
Running f():
  x = Traced<ShapedArray(float32[3,4])>with<DynamicJaxprTrace(level=1/0)>
  y = Traced<ShapedArray(float32[4])>with<DynamicJaxprTrace(level=1/0)>
  result = Traced<ShapedArray(float32[3])>with<DynamicJaxprTrace(level=1/0)>
```
Notice that the print statements execute, but rather than printing the data we passed to the function, they print tracer objects that stand in for it.
When we call the compiled function again on inputs with the same shape and dtype, no re-compilation is required and nothing is printed, because the result is computed in compiled XLA rather than in Python:
```python
x2 = np.random.randn(3, 4)
y2 = np.random.randn(4)
print(f(x2, y2))   # no "Running f():" this time
```
For this reason, the side effect (the printing) only happens once, during tracing!
JAX expressions
The extracted sequence of operations is encoded in a JAX expression, or jaxpr for short. You can view the jaxpr using the `jax.make_jaxpr` transformation:
```python
from jax import make_jaxpr

def f(x, y):
    return jnp.dot(x + 1, y + 1)

print(make_jaxpr(f)(x, y))
```
Note one consequence of this: because JIT compilation is done without information on the content of the array, control flow statements in the function cannot depend on traced values. For example, this fails:
```python
@jit
def f(x, neg):
    return -x if neg else x   # Python control flow on a traced value

f(1, True)
```
Output:
```
TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].
```
Static values
If there are variables that you would not like to be traced, they can be marked as static for the purposes of JIT compilation:
```python
from functools import partial

@partial(jit, static_argnums=(1,))   # treat `neg` (positional argument 1) as static
def f(x, neg):
    return -x if neg else x

f(1, True)
```
Note that calling a JIT-compiled function with a different static argument results in re-compilation, so the function still works as expected:
```python
f(1, False)
```
Understanding which values and operations will be static and which will be traced is a key part of using `jax.jit` effectively.
Static vs Traced Operations
- Just as values can be either static or traced, operations can be static or traced.
- Static operations are evaluated at compile-time in Python; traced operations are compiled & evaluated at run-time in XLA.
- Use `numpy` for operations that you want to be static; use `jax.numpy` for operations that you want to be traced.
This distinction between static and traced values makes it important to think about how to keep a static value static. Consider this function:
```python
import jax.numpy as jnp
from jax import jit

@jit
def f(x):
    return x.reshape(jnp.array(x.shape).prod())

x = jnp.ones((2, 3))
f(x)
```
Output:
```
TypeError: Shapes must be 1D sequences of concrete values of integer type, got [Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=1/0)>].
```
This fails with an error specifying that a tracer was found instead of a 1D sequence of concrete values of integer type. Let’s add some print statements to the function to understand why this is happening:
```python
@jit
def f(x):
    print(f"x = {x}")
    print(f"x.shape = {x.shape}")
    print(f"jnp.array(x.shape).prod() = {jnp.array(x.shape).prod()}")
    # the failing line is commented out so the prints can run:
    # return x.reshape(jnp.array(x.shape).prod())

f(x)
```
Output:
```
x = Traced<ShapedArray(float32[2,3])>with<DynamicJaxprTrace(level=1/0)>
x.shape = (2, 3)
jnp.array(x.shape).prod() = Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=1/0)>
```
Notice that although `x` is traced, `x.shape` is a static value. However, when we use `jnp.array` and `jnp.prod` on this static value, it becomes a traced value, at which point it cannot be used in a function like `reshape()` that requires a static input (recall: array shapes must be static).
A useful pattern is to use `numpy` for operations that should be static (i.e. done at compile-time), and use `jax.numpy` for operations that should be traced (i.e. compiled and executed at run-time). For this function, it might look like this:
```python
from jax import jit
import numpy as np

@jit
def f(x):
    # np.prod runs on the static shape tuple at trace time, so reshape sees a concrete size
    return x.reshape((np.prod(x.shape),))

f(x)
```
For this reason, a standard convention in JAX programs is to `import numpy as np` and `import jax.numpy as jnp`, so that both interfaces are available for finer control over whether operations are performed in a static manner (with `numpy`, once at compile-time) or a traced manner (with `jax.numpy`, optimized at run-time).
Pure functions
Side effects in jitted functions only happen once, during tracing.
```python
# So how does it work in the background? -> tracing on different levels of abstraction
@jit
def f(x, y):
    print("Running f():")      # side effect: only visible while tracing
    return jnp.dot(x + 1, y + 1)

print(f(np.random.randn(3, 4), np.random.randn(4)))   # prints "Running f():" once
print(f(np.random.randn(3, 4), np.random.randn(4)))   # cached trace: no print
```
Output:
```
Running f():
```
JAX's JIT compilation process traces the function to generate a compiled version. During this tracing, the function is executed on abstract values that represent the shapes and types of the inputs, not their actual values. This means that conditionals based on the values of inputs (like `neg` in this case) cannot be resolved during compilation, because the tracer doesn't know the value of `neg` ahead of time.
In JAX, inputs to JIT-compiled functions are considered "dynamic" by default, meaning their values can change between calls, and JAX expects their shapes and types to determine how to compile the function. However, when a function's behavior depends on a value (like a flag that changes its logic path), you need to tell JAX to treat this input as "static" so that it can compile different versions of the function for different values of these static inputs.
A function whose control flow depends on such a value can't be jitted until that argument is marked as static (as shown with `static_argnums` above).
JAX re-runs the Python function when the type or shape of the argument changes, and that re-trace ends up reading the latest value of any global the function uses:
```python
# Example 2
g = 0.

@jit
def impure_uses_globals(x):
    return x + g   # the value of g is baked in when the function is traced

print(impure_uses_globals(4.))               # traced with g = 0.  ->  4.0
g = 10.
print(impure_uses_globals(5.))               # same shape/dtype: cached trace still uses g = 0.  ->  5.0
print(impure_uses_globals(jnp.array([4.])))  # new shape: re-traced, now reads g = 10.  ->  [14.]
```
For loop
However, iterators in Python are stateful objects. The progression of an iterator (e.g., via `next(iterator)`) cannot be statically determined by JAX's tracing mechanism. When JAX tries to compile the function, it cannot predict or analyze the behavior of `next(iterator)`, because it involves mutable state that changes as the iterator progresses.
As a result, JAX's static analysis tools fail to incorporate the iterator's behavior correctly into the compiled code. Depending on the context and how JAX's tracer interacts with the iterator, this can lead to unexpected results, such as the operation seemingly not being executed at all, hence the "unexpected result 0".
```python
iterator = iter(range(10))
print(lax.fori_loop(0, 10, lambda i, x: x + next(iterator), 0))   # unexpected result 0
```
Output:
```
0
```
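For contrast, reading from an array inside the loop body works as expected (a small sketch using the same `lax.fori_loop` pattern):
```python
import jax.numpy as jnp
from jax import lax

array = jnp.arange(10)
print(lax.fori_loop(0, 10, lambda i, x: x + array[i], 0))   # 45, as expected
```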
Out-of-Bounds Indexing
Due to JAX's accelerator-agnostic approach, JAX had to adopt a non-error behaviour for out-of-bounds indexing (similar to how invalid floating-point arithmetic results in NaN rather than an exception).
```python
# NumPy behavior
np.arange(10)[11]   # IndexError: index 11 is out of bounds for axis 0 with size 10
```
This will throw an error.
But JAX arrays won't throw an exception; the out-of-bounds index is clamped instead:
```python
# JAX behavior
print(jnp.arange(10)[11])   # returns 9, the last element (the index is clamped)
```
Non-array inputs
This will throw an error:
```python
try:
    jnp.sum([1, 2, 3])   # passing a Python list instead of an array
except TypeError as e:
    print(f"TypeError: {e}")
```
Output:
```
TypeError: sum requires ndarray or scalar arguments, got <class 'list'> at position 0.
```
Use `jnp.array()` to convert the list first:
```python
def permissive_sum(x):
    return jnp.sum(jnp.array(x))   # explicitly convert the input to a JAX array

print(permissive_sum([1, 2, 3]))   # 6
```
PyTree
When you do something like this:
```python
x = [array1, array2]
```
then JAX knows to unpack the `x` argument (which is a `list`) in order to find the arrays `array1` and `array2`. This is how it can replace arrays with tracers when JIT compiling – and this unpacking is also how JAX can find the arrays to create gradients for when using `jax.grad`, or the arrays to vectorise when using `jax.vmap`, and so on.
The objects that JAX knows how to unpack are lists, tuples, dictionaries, user-registered custom nodes (we’ll come back to these), and any arbitrarily-nested collection of these. Overall, these are called “PyTrees”. It’s typical to represent a model as some PyTree of its parameters.
It’s almost always a mistake to use raw Python classes with JAX. These aren’t PyTrees, so JAX can’t know how you want to handle them unless you tell it how to. As such, we generally register classes as custom pytree nodes. (One way to do this is by subclassing `equinox.Module`.)
This is directly analogous to `torch.nn.Module`, and how you must use `self.foo = torch.nn.ModuleList(...)` rather than `self.foo = [...]`. You can’t use raw Python classes with PyTorch either. The only important difference is that PyTorch Modules are treated as directed acyclic graphs (=the same Module can appear in multiple places), whilst JAX PyTrees are treated as, well, trees (=multiple appearances of the same object are treated as independent copies).
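As a sketch of what registering a class as a custom pytree node can look like (using `jax.tree_util.register_pytree_node_class`; the `Linear` container here is made up for illustration):
```python
import jax
import jax.numpy as jnp

@jax.tree_util.register_pytree_node_class
class Linear:
    def __init__(self, weight, bias):
        self.weight = weight
        self.bias = bias

    def tree_flatten(self):
        # children = arrays JAX should trace/differentiate; aux_data = static extras
        return (self.weight, self.bias), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)

layer = Linear(jnp.ones((3, 2)), jnp.zeros((3,)))
print(jax.tree_util.tree_leaves(layer))   # [weight, bias] -- JAX can now "see inside" the class
```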
```python
# A contrived example for pedagogical purposes
import jax

example_trees = [
    [1, 'a', object()],
    (1, (2, 3), ()),
    {'a': 2, 'b': (2, 3)},
]

# Count the leaves of each pytree:
for pytree in example_trees:
    leaves = jax.tree_util.tree_leaves(pytree)
    print(f"{repr(pytree)} has {len(leaves)} leaves: {leaves}")
```
Output:
```
[1, 'a', <object object at 0x7f9cb0f032e0>] has 3 leaves: [1, 'a', <object object at 0x7f9cb0f032e0>]
```
How do we manipulate PyTrees?
Use `jax.tree_map()`. It iterates over the leaves and applies the given function to each one.
```python
list_of_lists = [
    {'a': 3},
    [1, 2, 3],
    [1, 2],
    [1, 2, 3, 4],
]

print(jax.tree_map(lambda x: x * 2, list_of_lists))
```
Output:
```
[{'a': 6}, [2, 4, 6], [2, 4], [2, 4, 6, 8]]
```
Another example, mapping a two-argument function over two pytrees with the same structure:
```python
another_list_of_lists = list_of_lists
print(jax.tree_map(lambda x, y: x + y, list_of_lists, another_list_of_lists))
```
Output:
```
[{'a': 6}, [2, 4, 6], [2, 4], [2, 4, 6, 8]]
```
PyTrees need to have the same structure if we want to map a function over several of them at once (older JAX versions called this `tree_multimap`). So the following code will fail:
```python
from copy import deepcopy

another_list_of_lists = deepcopy(list_of_lists)
another_list_of_lists.append([23])   # now the two structures differ
jax.tree_map(lambda x, y: x + y, list_of_lists, another_list_of_lists)
```
Output:
```
ValueError: List arity mismatch: 5 != 4; list: [{'a': 3}, [1, 2, 3], [1, 2], [1, 2, 3, 4], [23]].
```
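Since a model is typically just a PyTree of its parameters, `tree_map` is also how a training step updates every parameter at once; a minimal sketch (the parameter shapes, gradients, and learning rate are made up for illustration):
```python
import jax
import jax.numpy as jnp

params = {"w": jnp.ones((2, 3)), "b": jnp.zeros((3,))}
grads  = {"w": jnp.full((2, 3), 0.1), "b": jnp.full((3,), 0.1)}
lr = 0.01

new_params = jax.tree_map(lambda p, g: p - lr * g, params, grads)
print(jax.tree_util.tree_structure(new_params))   # same structure as params
```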
Weird things
NaN behavior
The default non-error behavior is to simply return a NaN (as usual for floating-point arithmetic):
```python
jnp.divide(0., 0.)   # returns nan rather than raising an error
```
State
```python
# 1) We've seen in the last notebook/video that impure functions are problematic.
g = 0.

def impure_uses_globals(x):
    return x + g   # reads a global: impure

# JAX captures the value of the global during the first (tracing) run
print("First call: ", jit(impure_uses_globals)(4.))
```
Output:
```
First call: 4.0
```
Stateful --> Stateless
In summary, we use the following rule to convert a stateful class:
```python
class StatefulClass:
    state: State   # State and Output are placeholder types

    def stateful_method(*args, **kwargs) -> Output:
        ...
```
into a class of the form:
```python
class StatelessClass:
    def stateless_method(state: State, *args, **kwargs) -> (Output, State):
        ...
```
The member functions of the `Counter` class below are stateful. For instance, `count(self)` reads and updates `self.n`, which is not an explicit argument or return value of the method; from JAX's point of view it is hidden state, so calling `count` has a side effect.
Example
```python
# Let's now explicitly address and understand the problem of state!
class Counter:
    """A simple stateful counter."""

    def __init__(self):
        self.n = 0

    def count(self) -> int:
        """Increments the counter and returns the new value."""
        self.n += 1
        return self.n

    def reset(self):
        """Resets the counter to zero."""
        self.n = 0

counter = Counter()
for _ in range(3):
    print(counter.count())
```
Output:
```
1
2
3
```
JIT-compiling the stateful method breaks it, because the update to `self.n` is a side effect that only happens during tracing:
```python
counter.reset()
fast_count = jax.jit(counter.count)

for _ in range(3):
    print(fast_count())
```
Output:
```
1
1
1
```
The solution is to make the state explicit: the method takes the current state as an input and returns the new state as an output.
```python
# Solution: pass the state in and out explicitly
CounterState = int

class CounterV2:
    def count(self, n: CounterState) -> tuple[int, CounterState]:
        # Return the output and the new state separately
        return n + 1, n + 1

    def reset(self) -> CounterState:
        return 0

counter = CounterV2()
state = counter.reset()

fast_count = jax.jit(counter.count)
for _ in range(3):
    value, state = fast_count(state)
    print(value)
```
Output:
```
1
2
3
```
Gotcha: vmap(jax.lax.cond) evaluates both branches
JAX has `jax.numpy.where(pred, a, b)` to do an `if` statement between two arrays `a` and `b`. This is like NumPy, and both arrays `a` and `b` need to have been evaluated.
What if you don’t want to evaluate both branches? (Maybe they’re expensive to compute.) For this there is `jax.lax.cond(pred, if_fn, else_fn)`, which is the runtime equivalent of a Python `if` statement.
There is one “gotcha” here. If your computation is batched due to a `jax.vmap`, then both `if_fn` and `else_fn` will be evaluated. Under the hood it is rewritten into a `jax.numpy.where(batch_pred, if_fn(...), else_fn(...))`. After all, some batch elements might need one branch, and some batch elements might need the other.
So this makes sense for JAX’s programming model, but it can also be a bit surprising. E.g. if `if_fn` produces an infinite loop for some inputs, and the `jax.lax.cond` is there to guard against that, then that infinite loop will still be triggered when vmap’d! (In this case you could fix it by using another `jax.lax.cond` to replace the bad input with a dummy-but-safe input before the loop.)
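A small sketch of this behaviour (the branch functions here are illustrative, not from the original notes):
```python
import jax
import jax.numpy as jnp

def risky_branch(x):
    return jnp.log(x)            # produces nan for x <= 0

def safe_branch(x):
    return jnp.zeros_like(x)

def f(x):
    # Without vmap, only the selected branch runs.
    return jax.lax.cond(x > 0, risky_branch, safe_branch, x)

xs = jnp.array([-1.0, 2.0, 3.0])
# Under vmap, both branches are evaluated for every element and combined with a select,
# so risky_branch still sees x = -1.0 (its nan is simply masked out here).
print(jax.vmap(f)(xs))           # [0.        0.6931472 1.0986123]
```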