JAX

JAX is a Python package for running NumPy-like code on GPUs and other accelerators, with APIs for automatic differentiation, just-in-time (JIT) compilation, and straightforward vectorization of code.

Transformations

grad

The grad function in JAX takes a function \(f\) that returns a scalar and returns a new function \(df\) that takes the same arguments as \(f\) and evaluates the gradient of \(f\) (by default with respect to its first argument) at those arguments.

from jax import grad
import jax.numpy as jnp

sq = lambda x: jnp.sum(x**2)
dx_sq = grad(sq)
print(dx_sq(jnp.arange(3.)))
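Because grad only accepts functions with a scalar output, it can help to see both cases side by side. The sketch below (assuming a recent JAX version) checks that the gradient of sq is 2x, and shows that differentiating the elementwise, vector-valued x**2 directly raises a TypeError.

```python
import jax.numpy as jnp
from jax import grad

sq = lambda x: jnp.sum(x**2)      # scalar output: grad is defined
x = jnp.arange(3.)
print(grad(sq)(x))                # gradient of sum(x**2) is 2*x, i.e. [0. 2. 4.]

elementwise_sq = lambda x: x**2   # vector output: grad is not defined
raised = False
try:
    grad(elementwise_sq)(x)
except TypeError as err:
    raised = True
    print("TypeError:", err)
```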

We can specify which of the arguments to differentiate with respect to using the argnums argument.

from jax import grad
import jax.numpy as jnp

poly_3d = lambda x, y, z: jnp.sum(x**2 + y**3 + z**4)
xs = jnp.arange(3.)

# Derivative with respect to x only
# (argnums=0 is the default)
dx_poly_3d = grad(poly_3d, argnums=0)
print("dpoly_3d/dx is ", dx_poly_3d(xs, xs, xs))

# Derivatives with respect to x, y and z, returned as a tuple
dxdydz_poly_3d = grad(poly_3d, argnums=(0, 1, 2))
print("(dpoly_3d/dx, dpoly_3d/dy, dpoly_3d/dz) is ", dxdydz_poly_3d(xs, xs, xs))

# or, more clearly, unpack the tuple
gx, gy, gz = dxdydz_poly_3d(xs, xs, xs)
print("dpoly_3d/dx is ", gx,
      "dpoly_3d/dy is ", gy,
      "dpoly_3d/dz is ", gz)

Below is an example where we use grad to estimate the parameters of a simple linear regression model with stochastic gradient descent (SGD). It is common practice to collect all trainable parameters in a single variable (params in the code below).

TODO Understand how the "a =- lr*da" style code ran.

from jax import grad
from jax.tree_util import tree_map
import jax.numpy as jnp
import jax.random as jrandom

# Batch size for SGD
batch_size = 10

# Total number of samples in the dataset
samples = batch_size*10

# Generate a dummy dataset: y = 5x + 2 plus a little noise
key = jrandom.PRNGKey(1)
x = jnp.linspace(-2, 2, samples)
y = 5*x + 2 + 0.001*jrandom.normal(key, shape=x.shape)

# Linear model
f = lambda params, x: params[0]*x + params[1]

# Mean squared error loss
loss = lambda params, x, y: jnp.mean((y - f(params, x))**2)

# grad differentiates with respect to the first argument (params)
# by default, which is what we want
grad_loss = grad(loss)

lr = 0.1

# Apply one SGD iteration and return the updated parameters.
# tree_map applies p - lr*g leaf by leaf, so this also works when
# params is a list, tuple or dict of arrays
update = lambda params, xs, ys: tree_map(lambda p, g: p - lr*g,
                                         params,
                                         grad_loss(params, xs, ys))

# Initial guess
params = jnp.array([0., 1.])

# Run for 20 iterations, each time drawing a batch of
# batch_size samples at random (with replacement) from the dataset
for _ in range(20):
    # Split off a fresh subkey so each batch uses new randomness
    key, subkey = jrandom.split(key)
    batch_indices = jrandom.choice(subkey, samples, (batch_size,))
    xbatch, ybatch = x[batch_indices], y[batch_indices]
    params = update(params, xbatch, ybatch)

print(params)
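The tree_map in update is what lets the same update rule work when params is a nested structure (a pytree) rather than a single array. A sketch of that variant follows, assuming a dict with hypothetical keys "w" and "b", and using full-batch gradient descent instead of SGD for brevity.

```python
import jax.numpy as jnp
import jax.random as jrandom
from jax import grad
from jax.tree_util import tree_map

key = jrandom.PRNGKey(1)
x = jnp.linspace(-2, 2, 100)
y = 5*x + 2 + 0.001*jrandom.normal(key, shape=x.shape)

# Parameters stored as a dict (a pytree) instead of a single array
params = {"w": jnp.array(0.), "b": jnp.array(1.)}

f = lambda params, x: params["w"]*x + params["b"]
loss = lambda params, x, y: jnp.mean((y - f(params, x))**2)

lr = 0.1

def update(params, xs, ys):
    grads = grad(loss)(params, xs, ys)  # a dict with the same keys as params
    return tree_map(lambda p, g: p - lr*g, params, grads)

for _ in range(100):
    params = update(params, x, y)  # full-batch step over the whole dataset

print(params)
```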

In a more involved training loop, one might want the loss value and the gradient at the same time. For this we can use jax.value_and_grad(loss_func), which returns a function that takes the same arguments as loss_func and returns a tuple (loss value, gradients).
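A minimal sketch of a loop built on jax.value_and_grad, reusing the linear model from above on a noise-free dataset and plain full-batch gradient descent (both simplifying assumptions), might look like this:

```python
import jax
import jax.numpy as jnp

f = lambda params, x: params[0]*x + params[1]
loss = lambda params, x, y: jnp.mean((y - f(params, x))**2)

x = jnp.linspace(-2, 2, 100)
y = 5*x + 2

# Differentiates with respect to params (argnums=0 by default)
loss_and_grad = jax.value_and_grad(loss)

params = jnp.array([0., 1.])
lr = 0.1
for step in range(100):
    # One call gives both the current loss and its gradient
    loss_val, grads = loss_and_grad(params, x, y)
    params = params - lr*grads
    if step % 25 == 0:
        print(step, loss_val)

print(params)
```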

vmap

Often we want to compute the loss function on a batch of multiple data points. (Historically, the term batch referred to the whole training dataset and a mini-batch to a subset of it.) In JAX, we can write the loss function for a single data point and use vmap to get a function that handles the batching for us.

from jax import vmap
import jax.numpy as jnp

# Dummy prediction function
fhat = lambda params, x: params[0]*x+params[1]
# Squared error loss function
loss = lambda params, x, y: (fhat(params, x)-y)**2

vmap_loss = vmap(loss, in_axes=(None, 0, 1))

samples=5
xs = jnp.linspace(0,2,samples)
ys = 5*xs + 2
loss_vals = vmap_loss(jnp.array([5.,2.]),
                      xs.reshape((samples,1)),
                      ys.reshape((1, samples)))
print(loss_vals) # should be all 0s since params are true values

In the code above, in_axes=(None, 0, 1) means that the first argument of loss is not batched, while the batch dimensions of the second and third arguments are 0 and 1 respectively. In practice the batch dimension is usually 0; axis 1 is used here only for illustration. The out_axes argument of vmap (0 by default) specifies where the mapped axis is placed in the output.
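To make out_axes concrete, the sketch below maps a hypothetical per-example function that returns three values for each scalar input. With the default out_axes=0 the batch axis comes first; with out_axes=1 it is placed second, so the result is the transpose.

```python
import jax.numpy as jnp
from jax import vmap

# Hypothetical per-example function: 3 outputs for one scalar input
def features(params, x):
    pred = params[0]*x + params[1]
    return jnp.stack([pred, x, x**2])

params = jnp.array([5., 2.])
xs = jnp.linspace(0, 2, 5)

# Default out_axes=0: the mapped (batch) axis comes first
out0 = vmap(features, in_axes=(None, 0))(params, xs)
# out_axes=1: the mapped axis is placed second
out1 = vmap(features, in_axes=(None, 0), out_axes=1)(params, xs)

print(out0.shape, out1.shape)  # (5, 3) (3, 5)
```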

Another use of vmap is plotting the derivative of a 1D function. If f returns scalars, then, as mentioned above, grad(f) is a function that takes the same argument(s) as f and returns the derivative of f with respect to them. Because grad requires a scalar output, grad(f) cannot be applied to a whole array of x values at once; vmap(grad(f)) evaluates the derivative at each point of the array instead.

from jax import grad, vmap
import jax.numpy as jnp
import matplotlib.pyplot as plt

f = lambda x: jnp.sin(x)
dx_f = grad(f)

x = jnp.linspace(0,2*jnp.pi,100)
y = f(x)
dy = vmap(dx_f)(x)

plt.plot(x, y, x, dy)
plt.legend(["sine(x)", "d/dx(sine)(x) = cosine(x)"])
plt.show()

Note that we cannot use vmap on functions that create arrays whose shapes depend on the values of their inputs (see here). For example, if the function builds an array whose length is given by one of its inputs, vmap will fail, because vmap traces the function once for the whole batch, so array shapes may depend only on input shapes, not on input values.

Another useful application of vmap is replacing double for loops in computations such as matrices of pairwise distances or Gram matrices. Below are two ways of doing this, one from the Gaussian process regression example in the main JAX repository and one from Matthew Johnson's talk at the Fields Institute.

from jax import vmap
import jax.numpy as jnp

# From the JAX repo
def apply_double_arg_func(func, xs, xs2=None):
    if xs2 is None:
        return vmap(lambda x: vmap(lambda y: func(x, y))(xs))(xs)
    else:
        return vmap(lambda x: vmap(lambda y: func(x, y))(xs))(xs2).T

# From Matthew Johnson
def apply_double_arg_func2(func, xs, xs2=None):
    if xs2 is None:
        return vmap(vmap(func, (0, None)), (None, 0))(xs, xs)
    else:
        return vmap(vmap(func, (0, None)), (None, 0))(xs, xs2).T

def l1dist(x, y):
    return jnp.sum(jnp.abs(x - y))

x1 = jnp.linspace(1, 10, 100)
y1 = jnp.linspace(1, 5, 20)

# Both return the 100 x 20 matrix of pairwise distances between x1 and y1
print(jnp.allclose(apply_double_arg_func(l1dist, x1, y1),
                   apply_double_arg_func2(l1dist, x1, y1)))
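To illustrate the shape restriction noted above, the sketch below maps over a function whose output length depends on the value of its input, which fails under vmap. It then shows one common workaround (an assumption here, not the only option): pad every output to a fixed, hypothetical max_len and mask the unused positions.

```python
import jax.numpy as jnp
from jax import vmap

def first_n(n):
    # Output shape depends on the *value* of n, which vmap cannot allow
    return jnp.arange(n)

try:
    vmap(first_n)(jnp.array([2, 3]))
    failed = False
except Exception as err:  # JAX raises a concretization (tracer) error here
    failed = True
    print(type(err).__name__)

max_len = 4  # hypothetical upper bound on n

def first_n_padded(n):
    idx = jnp.arange(max_len)          # fixed shape, independent of n
    return jnp.where(idx < n, idx, 0)  # zero out positions >= n

padded = vmap(first_n_padded)(jnp.array([2, 3]))
print(padded)
```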

jit

Sometimes we have a function that is run many times, e.g. the update step of a training loop. jit compiles such functions by performing just-in-time (JIT) compilation at run time. Applying jit to a function f returns a compiled version jit(f), which computes the same result as f, except that f is traced and compiled the first time it is called with a given combination of argument shapes and dtypes, and the compiled code is reused on later calls with matching shapes and dtypes. This works subject to certain constraints on f; namely, it should be a pure function, i.e. its output depends only on its inputs and it has no side effects.
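The purity requirement matters because the Python body of a jitted function runs only while JAX traces it, not on calls that reuse the compiled code. The sketch below makes this visible with a global counter; the counter is a deliberate side effect used only for illustration, and is exactly the kind of thing a real jitted function should avoid.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # Python-side counter, only for illustration

@jax.jit
def mse(params, x, y):
    global trace_count
    trace_count += 1  # side effect: runs only while JAX traces the function
    pred = params[0]*x + params[1]
    return jnp.mean((y - pred)**2)

params = jnp.array([5., 2.])
x = jnp.linspace(0, 2, 5)
mse(params, x, 5*x + 2)    # first call: traced and compiled
mse(params, x, 5*x + 2)    # same shapes and dtypes: compiled code is reused
x6 = jnp.linspace(0, 2, 6)
mse(params, x6, 5*x6 + 2)  # new input shape: traced and compiled again
print(trace_count)         # 2
```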

Low-level

lax

The jax.lax library provides the lower-level primitive operations that underlie jax.numpy.

The lax API also has useful functions for control flow. One such function is scan, which is useful for sequential computations in which each step needs the state produced by the previous step.
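lax.scan(f, init, xs) calls f(carry, x) once per element of xs, threading the carry from one step to the next, and returns the final carry together with the stacked per-step outputs. A minimal sketch computing running sums:

```python
import jax.numpy as jnp
from jax import lax

def step(carry, x):
    # carry: running total from the previous step; x: current element
    new_carry = carry + x
    return new_carry, new_carry  # (state for the next step, output for this step)

xs = jnp.array([1., 2., 3., 4.])
# Start from a float32 zero so the carry dtype matches what step returns
final, partial_sums = lax.scan(step, jnp.float32(0.), xs)
print(final)         # total of all elements
print(partial_sums)  # running totals after each step
```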
