Skip to content

Haliax Cheatsheet¤

This is a cheatsheet for converting common functions and tasks from JAX/NumPy to Haliax. Please open an issue on GitHub if you find any errors or omissions. We're happy to add more examples if you have any that you think would be useful.

Preamble¤

Throughout we assume the following:

import jax.numpy as jnp

import haliax as hax

# Named axes used throughout the examples below.
Batch = hax.Axis("batch", 32)
Embed = hax.Axis("embed", 64)
H = hax.Axis("h", 16)
W = hax.Axis("w", 16)
C = hax.Axis("c", 3)


# Step * Mini == Batch.size (2 * 16 == 32); these are used to split the
# "batch" axis in the reshape / unflatten examples.
Step = hax.Axis("step", 2)
Mini = hax.Axis("mini", 16)

# for jax: positional arrays whose dimensions mirror the axes above
x = jnp.zeros((32, 64))
y = jnp.zeros((32, 64))
z = jnp.zeros((32,))
w = jnp.zeros((64,))
# jnp.arange takes a scalar stop value, not a shape tuple.
ind = jnp.arange(8, dtype=jnp.int32)
im = jnp.zeros((32, 16, 16, 3))
w2 = jnp.zeros((3, 64))

# for haliax: the same arrays with named axes (these rebind the names above)
x = hax.zeros((Batch, Embed))
y = hax.zeros((Batch, Embed))
z = hax.zeros((Batch,))
w = hax.zeros((Embed,))
ind = hax.arange(hax.Axis("Index", 8), dtype=jnp.int32)
im = hax.zeros((Batch, H, W, C))
w2 = hax.zeros((C, Embed))

Array Creation¤

JAX Haliax
jnp.zeros((32, 64)) hax.zeros((Batch, Embed))
jnp.ones((32, 64)) hax.ones((Batch, Embed))
jnp.zeros_like(x) hax.zeros_like(x)
jnp.ones_like(x) hax.ones_like(x)
jnp.eye(32)
jnp.arange(32) hax.arange(Batch)
jnp.linspace(0, 1, 32) hax.linspace(Batch, 0, 1)
jnp.logspace(0, 1, 32) hax.logspace(Batch, 0, 1)
jnp.geomspace(1, 10, 32) hax.geomspace(Batch, 1, 10)

Combining Arrays¤

JAX Haliax
jnp.concatenate([x, y]) hax.concatenate("batch", [x, y])
jnp.stack([x, y]) hax.stack("foo", [x, y])
jnp.hstack([x, y]) hax.concatenate("embed", [x, y])
jnp.vstack([x, y]) hax.concatenate("batch", [x, y])

Array Manipulation¤

JAX Haliax
jnp.reshape(x, (2, 16, 64)) hax.unflatten_axis(x, "batch", (Step, Mini))
jnp.reshape(x, (-1,)) hax.flatten_axes(x, ("batch", "embed"), "foo")
jnp.transpose(x, (1, 0)) hax.rearrange(x, ("embed", "batch"))

Shape Manipulation¤

JAX Haliax
x.transpose((1, 0)) x.rearrange(("embed", "batch"))
x.reshape((2, 16, 64)) x.unflatten_axis("batch", (Step, Mini))
x.reshape((-1,)) x.flatten_axes(("batch", "embed"), "foo")
jnp.ravel(x) hax.ravel(x, "foo")
jnp.ravel(x) hax.flatten(x, "foo")

Einops-style Rearrange¤

See also the section on Rearrange.

JAX (with einops) Haliax
einops.rearrange(x, "batch embed -> embed batch") hax.rearrange(x, ("embed", "batch"))
einops.rearrange(x, "batch embed -> embed batch") hax.rearrange(x, "b e -> e b")
einops.rearrange(im, "b h w c -> b (h w) c") hax.flatten_axes(im, ("h", "w"), "hw")
einops.rearrange(im, "... h w c -> ... (h w c)") hax.rearrange(im, "{h w c} -> ... (embed: h w c)")
einops.rearrange(x, "b (h w) -> b h w", h=8) hax.rearrange(x, "b (h w) -> b h w", h=8)
einops.rearrange(x, "b (h w) -> b h w", h=8) hax.rearrange(x, "{(embed: h w)} -> ... h w", h=8)

Broadcasting¤

See also the section on Broadcasting.

JAX Haliax
jnp.broadcast_to(z.reshape(-1, 1), (32, 64)) hax.broadcast_axis(z, Embed)
Outer product: z.reshape(-1, 1) * w.reshape(1, -1) z * w.broadcast_axis(Batch)

Indexing and Slicing¤

See also the section on Indexing and Slicing.

JAX Haliax
x[0] x["batch", 0]
x[:, 0] x["embed", 0]
x[0, 1] x["batch", 0, "embed", 1]
x[0:10] x["batch", 0:10]
x[0:10:2] x["batch", 0:10:2]
x[0, 1:10:2] x["batch", 0, "embed", 1:10:2]
x[0, [1, 2, 3]] x["batch", 0, "embed", [1, 2, 3]]
x[0, ind] x["batch", 0, "embed", ind]
jnp.take(x, ind, axis=1) hax.take(x, "embed", ind)
jax.lax.dynamic_slice_in_dim(x, 4, 10) hax.slice(x, "batch", start=4, length=10)
jax.lax.dynamic_slice_in_dim(x, 4, 10) x["batch", hax.ds(4, 10)]

Operations¤

Elementwise Operations¤

Almost all elementwise operations work the same as in JAX, except that they accept either haliax.NamedArray or jax.numpy.ndarray objects.

Any elementwise operation in jax.nn should be in haliax.nn.

Binary Operations¤

Similarly, binary operations work the same as in JAX, except that they operate on haliax.NamedArray objects.

Reductions¤

Reductions are similar to JAX, except that they use an axis name instead of an axis index.

JAX Haliax
jnp.sum(x, axis=0) hax.sum(x, axis=Batch)
jnp.mean(x, axis=(0, 1)) hax.mean(x, axis=(Batch, Embed))
jnp.max(x) hax.max(x)
jnp.sum(x, where=x > 0) hax.sum(x, where=x > 0)
jnp.argmax(x, axis=0) hax.argmax(x, axis=Batch)

Matrix Multiplication¤

JAX Haliax
jnp.dot(z, x) hax.dot(z, x, axis="batch")
jnp.matmul(z, x) hax.dot(z, x, axis="batch")
jnp.dot(w, x.T) hax.dot(w, x, axis="embed")
jnp.einsum("ij,j -> i", x, w) hax.dot(x, w, axis="embed")
jnp.einsum("i,ij,ij,j -> i", z, x, y, w) hax.dot(z, x, y, w, axis="embed")
jnp.einsum("ij,j -> ji", x, w) hax.dot(x, w, axis=(), out_axes=("embed", "batch"))
jnp.einsum("bhwc,ce -> bhwe", im, w2) hax.einsum("b h w c, c e -> b h w e", im, w2)
jnp.einsum("...c,ce -> ...e", im, w2) hax.einsum("... c, c e -> ... e", im, w2)
jnp.einsum("bhwc,ce -> bhwe", im, w2) hax.einsum("{c embed} -> embed", im, w2)
jnp.einsum("bhwc,ce -> bhwe", im, w2) hax.einsum("-> batch h w embed", im, w2)
jnp.einsum("bhwc,ce -> bhwce", im, w2) hax.einsum("{...} -> ...", im, w2)
jnp.einsum("bhwc,ce -> ", im, w2) hax.einsum("{...} -> ", im, w2)
jnp.einsum("bhwc,ce -> bhwce", im, w2) hax.dot(im, w2, axis=())
jnp.einsum("bhwc,ce -> ", im, w2) hax.dot(im, w2, axis=None)