# JAX

JAX is a Python package that enables running NumPy-like code on accelerators such as GPUs and TPUs, with APIs for automatic differentiation, Just-In-Time (JIT) compilation, and straightforward vectorization of code.

## Transformations

### grad

The `grad` function in JAX takes in a function \(f\) that returns a scalar, and returns a function \(df\) that takes in the same arguments as \(f\) and evaluates the gradient of \(f\) at its input values.

```python
from jax import grad
import jax.numpy as jnp

sq = lambda x: jnp.sum(x**2)
dx_sq = grad(sq)
print(dx_sq(jnp.arange(3.)))
```

We can specify which of the arguments to differentiate with respect to using the `argnums` argument.

```python
from jax import grad
import jax.numpy as jnp

poly_3d = lambda x, y, z: jnp.sum(x**2 + y**3 + z**4)

# Derivative with respect to x only
# (note argnums=0 is the default)
dx_poly_3d = grad(poly_3d, argnums=0)
print("dpoly_3d/dx is", dx_poly_3d(jnp.arange(3.), jnp.arange(3.), jnp.arange(3.)))

# Derivatives with respect to x, y, and z
dxdydz_poly_3d = grad(poly_3d, argnums=(0, 1, 2))
print("dpoly_3d/dxdydz is", dxdydz_poly_3d(jnp.arange(3.), jnp.arange(3.), jnp.arange(3.)))

# Or, more clearly
dx, dy, dz = dxdydz_poly_3d(jnp.arange(3.), jnp.arange(3.), jnp.arange(3.))
print("dpoly_3d/dx is", dx, "dpoly_3d/dy is", dy, "dpoly_3d/dz is", dz)
```

Below is an example where we use `grad` to estimate the parameters of a simple linear regression model. It is common practice to use a single variable (in the code below, `params`) for the trainable parameters.


```python
from jax import grad
from jax.tree_util import tree_map
import jax.numpy as jnp
import jax.random as jrandom

# Batch size for SGD
batch_size = 10
# Total number of samples in the dataset
samples = batch_size * 10

# Generate a dummy dataset
key = jrandom.PRNGKey(1)
x = jnp.linspace(-2, 2, samples)
y = 5*x + 2 + 0.001*jrandom.normal(key, shape=x.shape)

# Linear model
f = lambda params, x: params[0]*x + params[1]
# Loss function
loss = lambda params, x, y: jnp.mean((y - f(params, x))**2)
# Gradient of the loss with respect to its first argument (params),
# which is what we want
grad_loss = grad(loss)

lr = 0.1
# Apply one SGD iteration and return the updated parameters
# (one could also decompose the tree_map and the lambda inside it
# into named functions)
update = lambda params, xs, ys: tree_map(lambda p, g: p - lr*g,
                                         params,
                                         grad_loss(params, xs, ys))

# Initial guess
params = jnp.array([0., 1.])
# Run for 20 iterations, each time choosing a batch of
# 10 samples randomly from the dataset
for _ in range(20):
    key, _ = jrandom.split(key)
    batch_indices = jrandom.choice(key, jnp.arange(samples, dtype=int), (batch_size,))
    xbatch, ybatch = x[batch_indices], y[batch_indices]
    params = update(params, xbatch, ybatch)
print(params)
```

In a more involved training loop, one might want the loss value and the gradient at once; for this we can use `jax.value_and_grad(loss_func)`, which returns a function that takes the same type of arguments as `loss_func` and returns a tuple of the loss value and the gradients.
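As a minimal sketch (the linear model, loss, and data below are made up for illustration, in the same style as the SGD example above):

```python
from jax import value_and_grad
import jax.numpy as jnp

# Toy linear model and mean-squared-error loss
f = lambda params, x: params[0]*x + params[1]
loss = lambda params, x, y: jnp.mean((y - f(params, x))**2)

# Returns both the loss value and the gradients in one call
loss_and_grad = value_and_grad(loss)

params = jnp.array([0., 1.])
x = jnp.linspace(-2, 2, 10)
y = 5*x + 2
val, grads = loss_and_grad(params, x, y)
print(val, grads)
```

Inside a training loop this avoids evaluating the loss and its gradient in two separate passes.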

### vmap

Often we want to compute the loss function on a batch of multiple data points. (Historically, the term batch refers to the whole training dataset, and a mini-batch to a subset of it.) In JAX, we can write the loss function for a single data point and use `vmap` to get a function that does the batching for us.

```python
from jax import vmap
import jax.numpy as jnp

# Dummy prediction function
fhat = lambda params, x: params[0]*x + params[1]
# Squared error loss for a single data point
loss = lambda params, x, y: (fhat(params, x) - y)**2

vmap_loss = vmap(loss, in_axes=(None, 0, 1))

samples = 5
xs = jnp.linspace(0, 2, samples)
ys = 5*xs + 2
loss_vals = vmap_loss(jnp.array([5., 2.]),
                      xs.reshape((samples, 1)),
                      ys.reshape((1, samples)))
print(loss_vals)  # should be all 0s since params are the true values
```

In the code above, `in_axes=(None, 0, 1)` means the first argument of `loss` is not batched, while the batch dimensions of the second and third arguments are 0 and 1 respectively. Note, however, that batch dimensions are most often 0; the example above is for illustrative purposes. The `out_axes` argument of `vmap` (0 by default) denotes where to put the mapped-over axis in the output.
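For instance (a toy function, chosen only to show where the batch axis lands):

```python
from jax import vmap
import jax.numpy as jnp

# Maps a scalar to a length-2 vector
f = lambda x: jnp.array([x, x**2])

xs = jnp.arange(3.)
out_first = vmap(f)(xs)               # batch axis 0 (the default)
out_second = vmap(f, out_axes=1)(xs)  # batch axis 1
print(out_first.shape, out_second.shape)  # (3, 2) (2, 3)
```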

Another example of using `vmap` is when plotting the gradients of a 1D function. If we have a function `f` that returns scalars, then, as mentioned above, `grad(f)` is a function that takes in the same type of argument(s) and returns the change in `f` with respect to the change in those argument(s).

```python
from jax import grad, vmap
import jax.numpy as jnp
import matplotlib.pyplot as plt

f = lambda x: jnp.sin(x)
dx_f = grad(f)

x = jnp.linspace(0, 2*jnp.pi, 100)
y = f(x)
dy = vmap(dx_f)(x)

plt.plot(x, y, x, dy)
plt.legend(["sine(x)", "d/dx(sine)(x) = cosine(x)"])
plt.show()
```

Note that we cannot use `vmap` for functions in which internally created arrays have data-dependent shapes. For example, if we create an array whose shape changes depending on the values of the function's input, then `vmap` will not work.

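A small sketch of this failure mode, using boolean masking (whose output size depends on the data):

```python
from jax import vmap
import jax.numpy as jnp

# The output of boolean masking has a data-dependent size, so its
# shape can differ across batch elements and vmap cannot trace it
keep_positive = lambda x: x[x > 0]

try:
    vmap(keep_positive)(jnp.array([[1., -1.], [2., 3.]]))
    shape_error = None
except Exception as err:
    shape_error = type(err).__name__
print("vmap failed with", shape_error)
```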
Another useful application of `vmap` is to replace double for loops in computations such as building matrices of pairwise distances or Gram matrices. Below are two cases of this, one from the Gaussian process regression example in the main JAX repository and one from Matthew Johnson's talk at the Fields Institute.

```python
from jax import vmap
import jax.numpy as jnp

# From the JAX repo
def apply_double_arg_func(func, xs, xs2=None):
    if xs2 is None:
        return vmap(lambda x: vmap(lambda y: func(x, y))(xs))(xs)
    else:
        return vmap(lambda x: vmap(lambda y: func(x, y))(xs))(xs2).T

# From Matthew Johnson
def apply_double_arg_func2(func, xs, xs2=None):
    if xs2 is None:
        return vmap(vmap(func, (0, None)), (None, 0))(xs, xs)
    else:
        return vmap(vmap(func, (0, None)), (None, 0))(xs, xs2).T

def l1dist(x, y):
    return jnp.sum(jnp.abs(x - y))

x1 = jnp.linspace(1, 10, 100)
y1 = jnp.linspace(1, 5, 20)
print(jnp.allclose(apply_double_arg_func(l1dist, x1, y1),
                   apply_double_arg_func2(l1dist, x1, y1)))
```

### jit

Sometimes we may have a function that we run for many iterations, e.g. the training loop itself. `jit` allows us to compile such functions by performing a Just-In-Time (JIT) compilation at run time. Applying `jit` to a function `f` returns the JIT version `jit(f)` of the function, which is essentially the same as `f` except that it is compiled the first time it is run; subsequent calls with the same argument shapes and dtypes reuse the compiled version. This works subject to certain constraints on `f`: notably, it should be a pure function (no side effects), and its Python-level control flow must not depend on the runtime values of its arguments.
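As a small sketch (the function here is arbitrary, chosen only to show the pattern):

```python
from jax import jit
import jax.numpy as jnp

def f(x):
    # A pure function: output depends only on the input
    return jnp.sum(x**2 + 2*x)

f_jit = jit(f)
x = jnp.arange(1000.)
# The first call traces and compiles f; later calls with the same
# input shapes and dtypes reuse the compiled version
print(f_jit(x))
print(jnp.allclose(f_jit(x), f(x)))  # True
```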

## Low-level

### lax

The `lax` library in JAX provides the more primitive operations that underlie `jax.numpy`.

The `lax` API also has useful functions for managing control flow. One such function is `scan`, which is useful for sequential computations where the previous state is required at each iteration.
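As a minimal sketch, a cumulative sum written with `lax.scan`, where the step function threads the running total through as the carry:

```python
from jax import lax
import jax.numpy as jnp

# carry is the running sum; scan returns the final carry
# and the stacked per-step outputs
def step(carry, x):
    new_carry = carry + x
    return new_carry, new_carry

total, cumsums = lax.scan(step, 0.0, jnp.arange(1., 5.))
print(total)    # 10.0
print(cumsums)  # [ 1.  3.  6. 10.]
```

Because `scan` expresses the loop as a single primitive, it also compiles much faster under `jit` than an unrolled Python loop.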