
Partitioning

Partitioning refers to the process of splitting arrays and computation across multiple devices. Haliax provides a number of functions to do this.

Tutorial

An introduction to using Haliax's partitioning functions to scale a transformer can be found here: Distributed Training in Haliax.

This page is designed to be more of a reference than a tutorial, and we assume you've read the tutorial before reading this page.

Device Meshes in JAX

See also JAX's tutorial Distributed Arrays and Automatic Parallelization for more details.

One of the main ways JAX provides distributed parallelism is via the jax.sharding.Mesh. A mesh is a logical n-dimensional array of devices. Meshes in JAX are represented as a NumPy ndarray (note: not jax.numpy) of devices and a tuple of axis names. For example, a 2D mesh of 16 devices (8 along data, 2 along model) might look like this:

import jax
import jax.numpy as jnp
import numpy as np

from jax.sharding import Mesh

devices = jax.devices()
# the device grid must be a NumPy array: jax.numpy arrays cannot hold Device objects
mesh = Mesh(np.array(devices).reshape((-1, 2)), ("data", "model"))

[Figure: 2D device mesh of 16 devices]

The mesh above has two axes, data and model. In JAX's mesh parallelism, arrays are distributed by overlaying axes of the array on top of the axes of the mesh. For example, if we have a batch of 32 sequences we might do something like this:

from jax.sharding import NamedSharding, PartitionSpec

batch_size = 32
seqlen = 512

batch = jnp.zeros((batch_size, seqlen), dtype=jnp.float32)
batch = jax.device_put(batch, NamedSharding(mesh, PartitionSpec("data", None)))

This specifies that the first axis of batch should be distributed across the data axis of the mesh. The None in the PartitionSpec indicates that the second axis of batch is not distributed. Because the model axis of the mesh does not appear in the PartitionSpec, the array is replicated along it. Each slice of the mesh along the model axis holds one complete copy of batch, partitioned across the data axis.

[Figure: the 16-device mesh with batch partitioned across the data axis]
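One way to check this layout is to inspect the array's addressable shards. This is a minimal sketch with one assumption: it uses a mesh whose model axis has size 1, so it runs on any number of devices (including a single CPU), provided that number divides 32. With 16 devices you would use reshape((-1, 2)) as above.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Assumption: a "model" axis of size 1 so the sketch runs on any device count
# that divides 32 (including a single CPU).
mesh = Mesh(np.array(jax.devices()).reshape((-1, 1)), ("data", "model"))
batch = jax.device_put(
    jnp.zeros((32, 512), dtype=jnp.float32),
    NamedSharding(mesh, PartitionSpec("data", None)),
)

# Each addressable shard holds a slice of the batch axis and all of seqlen.
shard_shapes = {shard.data.shape for shard in batch.addressable_shards}
num_data = mesh.shape["data"]
print(shard_shapes)  # on a single device: {(32, 512)}
```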

What's nice about this approach is that JAX will automatically schedule computations so that operations are distributed in the way you would expect. You don't have to explicitly manage communication between devices.

However, JAX can't always infer how you want intermediate arrays to be partitioned. For these cases, JAX provides jax.lax.with_sharding_constraint, which lets you explicitly specify the sharding of an array. This function is meant to be used inside jit. Outside of jit, you place arrays with jax.device_put, as above.
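A minimal sketch of with_sharding_constraint follows. The function normalize, the input data, and the size-1 model axis are illustrative assumptions, and the device count must divide 32. The constraint pins an intermediate value inside jit to the same data sharding as the input.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(np.array(jax.devices()).reshape((-1, 1)), ("data", "model"))
data_sharding = NamedSharding(mesh, PartitionSpec("data", None))

@jax.jit
def normalize(x):
    y = x - x.mean(axis=1, keepdims=True)
    # Tell XLA that this intermediate should stay sharded along "data".
    y = jax.lax.with_sharding_constraint(y, data_sharding)
    return y / (y.std(axis=1, keepdims=True) + 1e-6)

x = jnp.arange(32 * 512, dtype=jnp.float32).reshape(32, 512)
batch = jax.device_put(x, data_sharding)
out = normalize(batch)  # each row now has mean ~0 and std ~1
```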

Haliax Partitioning in a nutshell

As you might imagine, it gets tedious and error-prone to specify the partitioning of every array you create. Haliax instead lets you specify partitioning once, as a mapping from the names of axes in haliax.NamedArrays to mesh axes, and then applies that mapping automatically.

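import haliax as hax
import jax.numpy as jnp

Batch = hax.Axis("batch", 32)
SeqLen = hax.Axis("seqlen", 512)

# map the logical "batch" axis onto the "data" axis of the mesh
axis_mapping = {"batch": "data"}

batch = hax.zeros((Batch, SeqLen), dtype=jnp.float32)
# uses the mesh from the current context; pass mesh=... to choose one explicitly
batch = hax.shard(batch, axis_mapping)

# axis mappings can also be set with a context manager, in which case shard
# (and auto_sharded) pick up the mapping implicitly:
with hax.axis_mapping({"batch": "data"}):
    batch = hax.zeros((Batch, SeqLen), dtype=jnp.float32)
    batch = hax.shard(batch)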

Unlike JAX, which has separate APIs for partitioning arrays inside and outside of jit, Haliax has a single API. haliax.shard works both inside and outside of jit, and Haliax automatically chooses which JAX function to use based on context.

Axis Mappings

The core data structure we use to represent partitioning is haliax.partitioning.ResourceMapping. It is just an alias for Mapping[str, PhysicalAxisSpec], where a PhysicalAxisSpec is either a str or a Sequence[str]. The keys in this mapping are the names of "logical" Axes in NamedArrays, and the values are the names of axes in the mesh. (In theory you can partition a single Axis across multiple axes in the mesh, but we don't use this functionality.)

ResourceMapping: TypeAlias = Mapping[str, PhysicalAxisSpec]
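The following sketch makes the shape of a ResourceMapping concrete by turning one into a JAX PartitionSpec for an array with named axes. spec_for_axes is a hypothetical helper for illustration, not Haliax's implementation. It assumes unmapped axes are left unsharded (None), as in the PartitionSpec example earlier.

```python
from typing import Mapping, Sequence, Union

from jax.sharding import PartitionSpec

# Mirrors the documented alias: a physical spec is one mesh axis name or several.
PhysicalAxisSpec = Union[str, Sequence[str]]
ResourceMapping = Mapping[str, PhysicalAxisSpec]


def spec_for_axes(axis_names: Sequence[str], mapping: ResourceMapping) -> PartitionSpec:
    """Hypothetical helper: build a PartitionSpec for an array whose axes have
    the given logical names. Unmapped axes become None (not sharded)."""
    entries = []
    for name in axis_names:
        physical = mapping.get(name)
        if isinstance(physical, (list, tuple)):
            # one logical axis split over several mesh axes
            physical = tuple(physical)
        entries.append(physical)
    return PartitionSpec(*entries)


mapping: ResourceMapping = {"batch": "data", "embed": "model"}
print(spec_for_axes(["batch", "seqlen"], mapping))
```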

A context manager can be used to specify an axis mapping for the current thread for the duration of the context:

with hax.axis_mapping({"batch": "data"}):
    batch = hax.zeros((Batch, SeqLen), dtype=jnp.float32)
    batch = hax.auto_sharded(batch)

axis_mapping(mapping: ResourceMapping, *, merge: bool = False, **kwargs)

Context manager for setting the global resource mapping (scoped to the current thread).
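This context manager can be pictured as a thread-local stack of mappings. The sketch below is hypothetical: toy_axis_mapping and current_mapping are not Haliax functions. It also assumes that merge=True layers the new entries over the enclosing mapping, which is a guess at the semantics rather than documented behavior.

```python
import contextlib
import threading
from typing import Iterator, Mapping, Optional

_state = threading.local()


def current_mapping() -> Optional[Mapping[str, str]]:
    """Return the mapping active on this thread, or None if there is none."""
    return getattr(_state, "mapping", None)


@contextlib.contextmanager
def toy_axis_mapping(mapping: Mapping[str, str], *, merge: bool = False) -> Iterator[Mapping[str, str]]:
    old = current_mapping()
    # Assumption: merge=True layers the new entries over the enclosing mapping.
    new = {**(old or {}), **mapping} if merge else dict(mapping)
    _state.mapping = new
    try:
        yield new
    finally:
        # Restore the enclosing mapping when the block exits.
        _state.mapping = old


with toy_axis_mapping({"batch": "data"}):
    with toy_axis_mapping({"embed": "model"}, merge=True):
        inner = current_mapping()
    outer = current_mapping()
after = current_mapping()
```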

Partitioning Functions

Sharding Arrays and PyTrees

These functions are used to shard arrays and PyTrees of arrays, e.g. Modules. This is the main function you will use to shard arrays:

shard(x: T, mapping: ResourceMapping | None = None, mesh: Mesh | None = None) -> T

Shard a PyTree using the provided axis mapping. NamedArrays in the PyTree are sharded using the axis mapping. Other arrays (i.e. plain JAX arrays) are left alone.

This is basically a fancy wrapper around with_sharding_constraint that uses the axis mapping to determine the sharding. If mapping is not provided, the context axis mapping is used.
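The hypothetical helper shard_array below roughly sketches the with_sharding_constraint part. It applies a logical-to-physical mapping to a plain JAX array whose axis names are passed separately. Unlike hax.shard, it does not walk a PyTree or read names from NamedArrays. The mesh shape, names, and mapping are assumptions, and the device count must divide 32.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec


def shard_array(x, axis_names, mapping, mesh):
    """Hypothetical helper: constrain a plain JAX array whose axes carry the
    given logical names, using a logical -> physical axis mapping."""
    if mapping is None:
        return x  # nothing to do without a mapping
    # Unmapped names give None, i.e. that axis is not sharded.
    spec = PartitionSpec(*(mapping.get(name) for name in axis_names))
    return jax.lax.with_sharding_constraint(x, NamedSharding(mesh, spec))


mesh = Mesh(np.array(jax.devices()).reshape((-1, 1)), ("data", "model"))
mapping = {"batch": "data"}


@jax.jit
def double(x):
    return shard_array(x * 2, ("batch", "seqlen"), mapping, mesh)


out = double(jnp.ones((32, 512)))
```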

auto_sharded is like shard but does not issue a warning if there is no context axis mapping. It's useful for library code where there may or may not be a context mapping:

auto_sharded(x: T, mesh: Optional[Mesh] = None) -> T

Shard a PyTree using the global axis mapping. NamedArrays in the PyTree are sharded using the axis mapping and the names in the tree.

If there is no global axis mapping, this function is a no-op.

shard_with_axis_mapping is an older function that is being deprecated in favor of shard, to which it is functionally equivalent:

shard_with_axis_mapping(x: T, mapping: ResourceMapping, mesh: Mesh | None = None) -> T

named_jit and friends

named_jit(fn: Callable[Args, R] | None = None, axis_resources: ResourceMapping | None = None, *, in_axis_resources: ResourceMapping | None = None, out_axis_resources: ResourceMapping | None = None, donate_args: PyTree | None = None, donate_kwargs: PyTree | None = None, **pjit_args) -> WrappedCallable[Args, R] | typing.Callable[[Callable[Args, R]], WrappedCallable[Args, R]]

named_jit(fn: Callable[Args, R], axis_resources: ResourceMapping | None = None, *, in_axis_resources: ResourceMapping | None = None, out_axis_resources: ResourceMapping | None = None, donate_args: PyTree | None = None, donate_kwargs: PyTree | None = None, keep_unused: bool = False, backend: str | None = None, inline: bool | None = None) -> WrappedCallable[Args, R]
named_jit(*, axis_resources: ResourceMapping | None = None, in_axis_resources: ResourceMapping | None = None, out_axis_resources: ResourceMapping | None = None, donate_args: PyTree | None = None, donate_kwargs: PyTree | None = None, keep_unused: bool = False, backend: str | None = None, inline: bool | None = None) -> typing.Callable[[Callable[Args, R]], WrappedCallable[Args, R]]

A version of pjit that uses NamedArrays and the provided resource mapping to infer resource partitions for sharded computation.

axis_resources will be used for a context-specific resource mapping when the function is invoked. In addition, if in_axis_resources is not provided, the arguments' own (pre-existing) shardings will be used as the in_axis_resources. If out_axis_resources is not provided, axis_resources will be used as the out_axis_resources.

If no resource mapping is provided, this function attempts to use the context resource mapping.

Functionally, this is very similar to something like:

def wrapped_fn(arg):
    result = fn(arg)
    return hax.shard(result, out_axis_resources)

# arguments keep their own shardings unless in_axis_resources is given
if in_axis_resources is not None:
    arg = hax.shard(arg, in_axis_resources)
with hax.axis_mapping(axis_resources):
    result = jax.jit(wrapped_fn, **pjit_args)(arg)

This function can be used as a decorator or as a function.

Parameters:

  • fn
    (Callable, default: None) –

    The function to be jit'd.

  • axis_resources
    (ResourceMapping, default: None) –

    A mapping from logical axis names to physical axis names, used for the context-specific resource mapping.

  • in_axis_resources
    (ResourceMapping, default: None) –

    A mapping from logical axis names to physical axis names for arguments. If not passed, the arguments' own shardings are used.

  • out_axis_resources
    (ResourceMapping, default: None) –

    A mapping from logical axis names to physical axis names for the result. Defaults to axis_resources.

  • donate_args
    (PyTree, default: None) –

    A PyTree of booleans or function leaf->bool, indicating if the arguments should be donated to the computation.

  • donate_kwargs
    (PyTree, default: None) –

    A PyTree of booleans or function leaf->bool, indicating if the keyword arguments should be donated to the computation.

Returns:

  • WrappedCallable[Args, R] | Callable[[Callable[Args, R]], WrappedCallable[Args, R]]

    A jit'd version of the function.
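For comparison, plain JAX requires you to spell out every input and output sharding by hand, whereas named_jit derives them from axis names and a ResourceMapping. The sketch below shows the plain-JAX version. The scale function and the size-1 model axis are illustrative assumptions, and the device count must divide 32.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(np.array(jax.devices()).reshape((-1, 1)), ("data", "model"))
# Shard the batch axis over "data"; the second axis is not sharded.
batch_sharding = NamedSharding(mesh, PartitionSpec("data", None))


def scale(x):
    return x * 0.5


# Every input and output sharding is written out explicitly.
scale_jit = jax.jit(scale, in_shardings=(batch_sharding,), out_shardings=batch_sharding)

x = jax.device_put(jnp.ones((32, 512)), batch_sharding)
y = scale_jit(x)
```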

fsdp(*args, **kwargs)

fsdp(fn: F, parameter_mapping: ResourceMapping, compute_mapping: ResourceMapping) -> F
fsdp(parameter_mapping: ResourceMapping, compute_mapping: ResourceMapping) -> typing.Callable[[F], F]

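A convenience wrapper around named_jit / pjit to encode the FSDP (fully sharded data parallel) pattern. Parameters are stored partitioned according to parameter_mapping, while computation is partitioned according to compute_mapping. It's basically equivalent to this: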

@named_jit(in_axis_resources=parameter_mapping, out_axis_resources=parameter_mapping, axis_resources=compute_mapping)
def f(*args, **kwargs):
    return fn(*args, **kwargs)

This function can be used as a decorator or as a function.
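The sketch below shows, in plain JAX terms, the layout fsdp is meant to produce. For illustration only, it assumes a parameter_mapping that shards a weight's embed axis over data and a compute_mapping that shards the batch axis over data. forward and the array shapes are hypothetical, and the device count must divide both 64 and 32.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# A 1-D mesh: every device participates in the "data" axis.
mesh = Mesh(np.array(jax.devices()), ("data",))
param_sharding = NamedSharding(mesh, PartitionSpec("data", None))  # w[embed, hidden]
batch_sharding = NamedSharding(mesh, PartitionSpec("data", None))  # x[batch, embed]


def forward(w, x):
    # XLA's partitioner inserts whatever communication (e.g. gathering the
    # shards of w) this layout requires.
    return x @ w


forward_jit = jax.jit(
    forward,
    in_shardings=(param_sharding, batch_sharding),
    out_shardings=batch_sharding,
)

w = jax.device_put(jnp.ones((64, 16)), param_sharding)
x = jax.device_put(jnp.ones((32, 64)), batch_sharding)
y = forward_jit(w, x)
```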

Querying the Mesh and Axis Mapping

round_axis_for_partitioning(axis: Axis, mapping: ResourceMapping | None = None) -> Axis

Round an axis so that it's divisible by the size of the partition it's on.
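The rounding is easiest to see with plain integers. This minimal sketch assumes the axis is rounded up to the next multiple of the partition size. round_up_to_multiple is a hypothetical helper, not part of Haliax, and the vocabulary size is only an example.

```python
def round_up_to_multiple(size: int, partition_size: int) -> int:
    """Hypothetical helper: smallest multiple of partition_size that is >= size."""
    # Ceiling division via negated floor division, then scale back up.
    return -(-size // partition_size) * partition_size


# e.g. a 50257-entry vocab axis mapped onto a mesh axis of size 8
print(round_up_to_multiple(50257, 8))  # 50264
```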

physical_axis_name(axis: AxisSelector, mapping: ResourceMapping | None = None) -> PhysicalAxisSpec | None

Get the physical axis name for a logical axis from the mapping. Returns None if the axis is not mapped.

physical_axis_size(axis: AxisSelector, mapping: ResourceMapping | None = None) -> int | None

Get the physical axis size for a logical axis. This is the product of the size of all physical axes that this logical axis is mapped to.
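The product rule can be sketched in plain Python. Here logical_axis_size is a hypothetical helper. mesh_shape stands in for the axis-name-to-size mapping that a JAX Mesh exposes as Mesh.shape, and the axis names and sizes are illustrative.

```python
import math
from typing import Mapping, Optional, Sequence, Union

PhysicalAxisSpec = Union[str, Sequence[str]]


def logical_axis_size(
    name: str,
    mapping: Mapping[str, PhysicalAxisSpec],
    mesh_shape: Mapping[str, int],
) -> Optional[int]:
    """Hypothetical helper: product of the sizes of all mesh axes that the
    logical axis is mapped to, or None if it is not mapped."""
    physical = mapping.get(name)
    if physical is None:
        return None
    if isinstance(physical, str):
        physical = (physical,)
    return math.prod(mesh_shape[p] for p in physical)


mesh_shape = {"replica": 2, "data": 4, "model": 2}
mapping = {"batch": ("replica", "data"), "embed": "model"}
print(logical_axis_size("batch", mapping, mesh_shape))   # 8
print(logical_axis_size("embed", mapping, mesh_shape))   # 2
print(logical_axis_size("seqlen", mapping, mesh_shape))  # None
```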

sharding_for_axis(axis: AxisSelection, mapping: ResourceMapping | None = None, mesh: MeshLike | None = None) -> NamedSharding

Get the sharding for a single axis.