JAX
JAX is a Python package for running NumPy-like code on accelerators such as GPUs and TPUs, with APIs for automatic differentiation, just-in-time (JIT) compilation, and straightforward vectorization of code.
Transformations
grad
The grad
function in JAX takes in a function \(f\) that returns a scalar, and
returns a function \(df\) that takes in the same arguments as \(f\) and
evaluates the gradient of \(f\) at its input values.
from jax import grad
import jax.numpy as jnp

sq = lambda x: jnp.sum(x**2)
dx_sq = grad(sq)
print(dx_sq(jnp.arange(3.)))
We can specify which of the arguments to differentiate with respect to
using the argnums
argument.
from jax import grad
import jax.numpy as jnp

poly_3d = lambda x, y, z: jnp.sum(x**2 + y**3 + z**4)

# output derivative with respect to x only
# note argnums=0 by default
dx_poly_3d = grad(poly_3d, argnums=0)
print("dpoly_3d/dx is ", dx_poly_3d(jnp.arange(3.), jnp.arange(3.), jnp.arange(3.)))

# output derivative with respect to x, y, z
dxdydz_poly_3d = grad(poly_3d, argnums=(0, 1, 2))
print("dpoly_3d/dxdydz is ", dxdydz_poly_3d(jnp.arange(3.), jnp.arange(3.), jnp.arange(3.)))

# or more clearly
dx_poly_3d, dy_poly_3d, dz_poly_3d = dxdydz_poly_3d(jnp.arange(3.), jnp.arange(3.), jnp.arange(3.))
print("dpoly_3d/dx is ", dx_poly_3d,
      "dpoly_3d/dy is ", dy_poly_3d,
      "dpoly_3d/dz is ", dz_poly_3d)
Below is an example where we use grad
to estimate the parameters for a
simple linear regression model. It is common practice to use a single
variable (in the code below params
) for the trainable parameters.
from jax import grad
from jax.tree_util import tree_map
import jax.numpy as jnp
import jax.random as jrandom

# Batch size for SGD
batch_size = 10
# Total number of samples in dataset
samples = batch_size * 10

# Generate dummy dataset
key = jrandom.PRNGKey(1)
x = jnp.linspace(-2, 2, samples)
y = 5*x + 2 + 0.001*jrandom.normal(key, shape=x.shape)

# Linear model
f = lambda params, x: params[0]*x + params[1]
# Loss function
loss = lambda params, x, y: jnp.mean((y - f(params, x))**2)

# Gradients of loss function with respect to first argument,
# which is what we want
grad_loss = grad(loss)

lr = 0.1
# Apply one SGD iteration and return the updated parameters
# Another option is to decompose the tree_map and the lambda function
# inside it below
update = lambda params, xs, ys: tree_map(lambda p, g: p - lr*g, params,
                                         grad_loss(params, xs, ys))

# Initial guess
params = jnp.array([0., 1.])

# Run for 20 iterations, each time choosing a batch of
# 10 samples randomly from the dataset
for _ in range(20):
    key, _ = jrandom.split(key)
    batch_indices = jrandom.choice(key, jnp.arange(samples, dtype=int), (batch_size,))
    xbatch, ybatch = x[batch_indices], y[batch_indices]
    params = update(params, xbatch, ybatch)
print(params)
In a more involved training loop, one might want the loss value and
the gradient at once. For this we can use
jax.value_and_grad(loss_func)
, which returns a function that takes the
same type of arguments as loss_func
and returns a tuple of the loss
value and the gradients.
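As a minimal sketch (reusing the linear model and loss from the regression example above; this snippet is illustrative and not part of the original code):

from jax import value_and_grad
import jax.numpy as jnp

# Same linear model and loss as in the regression example above
f = lambda params, x: params[0]*x + params[1]
loss = lambda params, x, y: jnp.mean((y - f(params, x))**2)

loss_and_grad = value_and_grad(loss)

params = jnp.array([0., 1.])
xs = jnp.linspace(-2, 2, 10)
ys = 5*xs + 2
loss_val, grads = loss_and_grad(params, xs, ys)
print(loss_val, grads)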
vmap
Often we want to compute the loss function on a batch of multiple data points. (Historically, the term batch refers to the whole training dataset and a mini-batch to a subset of it.) In JAX, we can write the loss function for one data point and use vmap
to get a function which does the batching for us.
from jax import vmap
import jax.numpy as jnp

# Dummy prediction function
fhat = lambda params, x: params[0]*x + params[1]
# Squared error loss function
loss = lambda params, x, y: (fhat(params, x) - y)**2

vmap_loss = vmap(loss, in_axes=(None, 0, 1))

samples = 5
xs = jnp.linspace(0, 2, samples)
ys = 5*xs + 2
loss_vals = vmap_loss(jnp.array([5., 2.]),
                      xs.reshape((samples, 1)),
                      ys.reshape((1, samples)))
print(loss_vals)  # should be all 0s since params are true values
In the code above, in_axes=(None, 0, 1)
means the first argument of
loss
is not batched, while the batch dimensions of the second and third
arguments are 0 and 1 respectively. Note, however, that batch dimensions are
most often 0; the example above uses axis 1 purely for illustration. The out_axes
argument (0 by default) can be used in vmap
to specify where to place the
axis which is mapped over in the output.
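As a small illustrative sketch of out_axes (this example is not from the original notes):

from jax import vmap
import jax.numpy as jnp

double = lambda x: 2*x

xs = jnp.arange(6.).reshape((3, 2))
# Map over axis 0 of the input, but place the mapped axis at position 1
# of the output
doubled = vmap(double, in_axes=0, out_axes=1)(xs)
print(doubled.shape)  # (2, 3): the mapped axis was moved from position 0 to 1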
Another example of using vmap
is when plotting the gradient of a 1D
function. If we have a function f
that returns scalars, then as
mentioned above grad(f)
is a function that takes the same type of
argument(s) and returns the derivative of f
with respect to those argument(s).
from jax import grad, vmap
import jax.numpy as jnp
import matplotlib.pyplot as plt

f = lambda x: jnp.sin(x)
dx_f = grad(f)

x = jnp.linspace(0, 2*jnp.pi, 100)
y = f(x)
dy = vmap(dx_f)(x)

plt.plot(x, y, x, dy)
plt.legend(["sine(x)", "d/dx(sine)(x) = cosine(x)"])
plt.show()
Note that we cannot use vmap
for functions where the shapes of the variables created
within them depend on the input values. For example, if we create
an array whose shape changes depending on the function input, then
vmap
will not work.
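As an illustrative sketch of this limitation (the keep_positive function below is a made-up example, not from the original notes): boolean masking produces an output whose shape depends on the data, so vmap raises an error.

from jax import vmap
import jax.numpy as jnp

# The output shape depends on how many entries of x are positive
keep_positive = lambda x: x[x > 0]

xs = jnp.array([[-1., 2., 3.],
                [4., -5., 6.]])
try:
    vmap(keep_positive)(xs)
except Exception as e:
    # Fails because the output shape is data-dependent under vmap
    print(type(e).__name__)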
Another useful application of vmap
is to replace double for loops in
computations such as building matrices of pairwise distances or Gram
matrices. Below are two examples of this, one from the Gaussian process
regression example in the main JAX repository and one from Matthew
Johnson's talk at the Fields Institute.
from jax import vmap
import jax.numpy as jnp

# From JAX repo
def apply_double_arg_func(func, xs, xs2=None):
    if xs2 is None:
        return vmap(lambda x: vmap(lambda y: func(x, y))(xs))(xs)
    else:
        return vmap(lambda x: vmap(lambda y: func(x, y))(xs))(xs2).T

# From Matthew Johnson
def apply_double_arg_func2(func, xs, xs2=None):
    if xs2 is None:
        return vmap(vmap(func, (0, None)), (None, 0))(xs, xs)
    else:
        return vmap(vmap(func, (0, None)), (None, 0))(xs, xs2).T

def l1dist(x, y):
    return jnp.sum(jnp.abs(x - y))

x1 = jnp.linspace(1, 10, 100)
y1 = jnp.linspace(1, 5, 20)
print(jnp.allclose(apply_double_arg_func(l1dist, x1, y1),
                   apply_double_arg_func2(l1dist, x1, y1)))
jit
Sometimes we may have a function that we run for many iterations,
e.g. the training loop itself. jit
allows us to compile such functions
by performing a Just-In-Time (JIT) compilation of the function at run
time. Applying jit
to a function f
returns the JIT version jit(f)
of the
function, which is essentially the same as f
except that it
is compiled the first time it is run. This works subject to certain
constraints on f
: namely, it should be a pure function (no side effects),
and its control flow and array shapes should not depend on the values of
its traced inputs.
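As a minimal sketch (the sum_of_squares function below is just an illustrative pure function, not from the original notes):

from jax import jit
import jax.numpy as jnp

# A pure function: its output depends only on its inputs
def sum_of_squares(x):
    return jnp.sum(x**2)

fast_sum_of_squares = jit(sum_of_squares)

x = jnp.arange(1000.)
# The first call traces and compiles the function; subsequent calls with
# arguments of the same shape and dtype reuse the compiled version
print(fast_sum_of_squares(x))
print(fast_sum_of_squares(x + 1.))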
Low-level
lax
The lax
library in JAX provides more primitive operations that underlie
jax.numpy
.
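For example, a small sketch of the difference in strictness (based on the jax.numpy versus jax.lax comparison in the JAX documentation): jax.numpy promotes mixed types implicitly, while lax requires the operands to already have matching dtypes.

from jax import lax
import jax.numpy as jnp

# jax.numpy promotes the integer to a float automatically
print(jnp.add(1, 1.0))

# lax is stricter: the operand types must already match
print(lax.add(jnp.float32(1), 1.0))
# lax.add(1, 1.0) would raise a TypeError because of the mixed int/float types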
The lax
API has useful functions for managing control flow. One such
function is scan
, which is useful for sequential computations where
the previous state is required at each step of an iterative computation.
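As a minimal sketch of scan, here is a running sum where the carry holds the total so far (an illustrative example, not from the original notes):

from jax import lax
import jax.numpy as jnp

# step takes the carry (running total) and the current element, and
# returns the new carry plus the value to store in the stacked output
def step(carry, x):
    total = carry + x
    return total, total

xs = jnp.arange(1., 6.)
final_total, running_totals = lax.scan(step, jnp.array(0.), xs)
print(final_total)      # 15.0
print(running_totals)   # [ 1.  3.  6. 10. 15.]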