Vectorization
Vectorization with haliax.vmap
haliax.vmap is a NamedArray-aware wrapper around
jax.vmap. Instead of supplying positional axis numbers, you pass
the Axis (or axis name) you want to map over. Any
NamedArray containing that axis is mapped along it, and the axis is
reinserted in the output. Regular (unnamed) JAX arrays are mapped along their
leading axis by default; use the default argument to change that spec, or
args/kwargs for per-argument overrides.
Unlike vanilla jax.vmap, you may supply one or more axes. When multiple
axes are given, the function is vmapped over each axis in turn, from innermost to outermost.
If an axis isn't present on any input NamedArray, you must also specify its size,
either by passing an Axis object (Axis("batch", 4)) or a mapping such as
{"batch": 4}, so the new dimension can be inserted.
Basic Example
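```python
import haliax as hax

Batch = hax.Axis("batch", 4)

def double(x):
    # Inside the mapped function, x holds one element of the batch.
    return x * 2

x = hax.arange(Batch)
y = hax.vmap(double, Batch)(x)  # y carries the Batch axis, like x
```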
The result y has the same Batch axis as x, with double applied to every element
in a single vectorized call. With plain JAX you would write jax.vmap(double)(x.array)
and pass in_axes/out_axes yourself whenever the batch dimension is not the leading
one; Haliax finds the axis by name.
For applying many modules in parallel, see
Stacked.vmap, which builds on this
primitive.
vmap(fn, axis: AxisSelection, *, default: PyTree[UnnamedAxisSpec] = _zero_if_array_else_none, args: PyTree[UnnamedAxisSpec] = (), kwargs: PyTree[UnnamedAxisSpec] = None)
haliax.NamedArray-aware version of jax.vmap. Normal arrays are mapped according to the specs, as in equinox.filter_vmap.
Because of NamedArrays, vmap is typically less useful than in vanilla JAX, but it is sometimes useful for initializing modules that will be scanned over. See haliax.nn.Stacked for an example.
Parameters:
- fn (Callable) – function to vmap over.
- axis (Axis or Sequence[Axis]) – axis or axes to vmap over. If a sequence is provided, the function will be vmapped over each axis in turn, from innermost to outermost.
- default (PyTree[UnnamedAxisSpec], default: _zero_if_array_else_none) – how to handle (unnamed) arrays by default. Should be either an integer or None, a callable that takes a PyTree leaf and returns an integer or None, or a PyTree prefix of the same. If an integer, the array will be mapped over that axis. If None, the array will not be mapped over.
- args (PyTree[UnnamedAxisSpec], default: ()) – optional per-argument overrides for how to handle arrays. Should be a PyTree prefix of the same type as default.
- kwargs (PyTree[UnnamedAxisSpec], default: None) – optional per-keyword-argument overrides for how to handle arrays. Should be a PyTree prefix of the same type as default.