Haliax Cheatsheet
This is a cheatsheet for converting common functions and tasks from JAX/NumPy to Haliax. Please open an issue on GitHub if you find any errors or omissions. We're happy to add more examples if you have any that you think would be useful.
Preamble
Throughout, we assume the following:

```python
import jax
import jax.numpy as jnp
import haliax as hax

Batch = hax.Axis("batch", 32)
Embed = hax.Axis("embed", 64)
H = hax.Axis("h", 16)
W = hax.Axis("w", 16)
C = hax.Axis("c", 3)
Step = hax.Axis("step", 2)
Mini = hax.Axis("mini", 16)

# for jax (positional axes)
x = jnp.zeros((32, 64))
y = jnp.zeros((32, 64))
z = jnp.zeros((32,))
w = jnp.zeros((64,))
ind = jnp.arange(8, dtype=jnp.int32)  # jnp.arange takes a scalar stop, not a shape tuple
im = jnp.zeros((32, 16, 16, 3))
w2 = jnp.zeros((3, 64))

# for haliax (named axes; the same variable names are reused)
x = hax.zeros((Batch, Embed))
y = hax.zeros((Batch, Embed))
z = hax.zeros((Batch,))
w = hax.zeros((Embed,))
ind = hax.arange(hax.Axis("Index", 8), dtype=jnp.int32)
im = hax.zeros((Batch, H, W, C))
w2 = hax.zeros((C, Embed))
```
Array Creation
Combining Arrays
Array Manipulation
Shape Manipulation
Einops-style Rearrange
See also the section on Rearrange.
Broadcasting
See also the section on Broadcasting.
| JAX | Haliax |
|---|---|
| `jnp.broadcast_to(z.reshape(-1, 1), (32, 64))` | `hax.broadcast_axis(z, Embed)` |
| Outer product: `z.reshape(-1, 1) * w.reshape(1, -1)` | `z * w.broadcast_axis(Batch)` |
Indexing and Slicing
See also the section on Indexing and Slicing.
| JAX | Haliax |
|---|---|
| `x[0]` | `x["batch", 0]` |
| `x[:, 0]` | `x["embed", 0]` |
| `x[0, 1]` | `x["batch", 0, "embed", 1]` |
| `x[0:10]` | `x["batch", 0:10]` |
| `x[0:10:2]` | `x["batch", 0:10:2]` |
| `x[0, 1:10:2]` | `x["batch", 0, "embed", 1:10:2]` |
| `x[0, [1, 2, 3]]` | `x["batch", 0, "embed", [1, 2, 3]]` |
| `x[0, ind]` | `x["batch", 0, "embed", ind]` |
| `jnp.take(x, ind, axis=1)` | `hax.take(x, "embed", ind)` |
| `jax.lax.dynamic_slice_in_dim(x, 4, 10)` | `hax.slice(x, "batch", start=4, length=10)` |
| `jax.lax.dynamic_slice_in_dim(x, 4, 10)` | `x["batch", hax.ds(4, 10)]` |
Operations
Elementwise Operations
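Almost all elementwise operations have the same names and semantics as in JAX (for example, `hax.exp` mirrors `jnp.exp`), except that they accept either `haliax.NamedArray` or `jax.numpy.ndarray` inputs. A NamedArray input keeps its axes.
Any elementwise operation in `jax.nn` should have a counterpart in `haliax.nn` (for example, `hax.nn.relu`).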
Binary Operations
Binary operations likewise mirror JAX, except that they operate on `haliax.NamedArray` objects and broadcast their operands by axis name rather than by position.
Reductions
Reductions are similar to JAX, except that the `axis` argument takes an axis name (or an `Axis` object) instead of a positional index.