
Main API Reference

In general, the API is designed to be similar to JAX's version of NumPy's API, with the main difference being that we use names (either strings or haliax.Axis objects) to specify axes instead of integers. This shows up for creating arrays (see haliax.zeros and haliax.ones) as well as things like reductions (see haliax.sum and haliax.mean).

PyTree Helpers

PyTrees are the lingua franca for composing state in JAX ecosystems. Haliax provides drop-in replacements for the jax.tree helpers that are aware of NamedArray semantics. They preserve axis metadata across transformations while interoperating with standard JAX containers, so you can use them anywhere you would have reached for JAX's versions.

Use these helpers whenever you need to map, flatten, or rebuild PyTrees that might include NamedArray instances.

Every helper except unflatten accepts the same is_leaf hook you might already use with JAX's utilities. They should be the first tools you reach for when you need tree transforms that understand named axes.

map(fn: Callable[..., T], tree: Any, *rest: Any, is_leaf: Callable[[Any], bool] | None = None) -> Any

Alias for haliax.tree_util.tree_map, matching jax.tree.map.

scan_aware_map(fn: Callable[..., T], tree: Any, *rest: Any, is_leaf: Callable[[Any], bool] | None = None) -> Any

Alias for haliax.tree_util.scan_aware_tree_map with jax.tree-style naming.

flatten(tree: Any, *, is_leaf: Callable[[Any], bool] | None = None) -> tuple[Sequence[Any], Any]

Alias for haliax.tree_util.tree_flatten, matching jax.tree.flatten.

unflatten(treedef: Any, leaves: Iterable[Any]) -> Any

Alias for haliax.tree_util.tree_unflatten, matching jax.tree.unflatten.

leaves(tree: Any, *, is_leaf: Callable[[Any], bool] | None = None) -> Sequence[Any]

Alias for haliax.tree_util.tree_leaves, matching jax.tree.leaves.

structure(tree: Any, *, is_leaf: Callable[[Any], bool] | None = None) -> Any

Alias for haliax.tree_util.tree_structure, matching jax.tree.structure.

Axis Types

If you already speak NumPy or jax.numpy, think of Haliax as swapping positional axes (axis=0) for named axes (axis="batch"). The type hints in this section describe the different ways those named axes can be provided to the API. They appear throughout the documentation and in signatures so that you can quickly tell which forms are accepted.

| Name | Accepts | When to use it | Example |
|---|---|---|---|
| Axis | Axis(name: str, size: int) | Define a named dimension with an explicit size | Batch = Axis("batch", 32) |
| AxisSelector | Axis or str | Refer to an existing axis whose size can be inferred from the arrays you pass in | x.sum(axis="batch") |
| AxisSpec | dict[str, int], Axis, or a sequence of Axis objects | Create or reshape arrays when the axis sizes must be provided | hax.zeros((Batch, Feature)) |
| AxisSelection | dict[str, int or None], AxisSpec, or a sequence of AxisSelector values | Work with one or more existing axes (reductions, indexing helpers, flattening, …) | x.sum(axis=("batch", Feature)) |

Axis

An Axis is the fundamental building block: it is a tiny dataclass that stores a name and a size. You can construct one directly or use haliax.make_axes to generate several at a time.

import haliax as hax
from haliax import Axis

Batch = Axis("batch", 32)
Feature = Axis("feature", 128)
x = hax.ones((Batch, Feature))
print(Batch.name, Batch.size)

Using Axis objects keeps array creation explicit and gives reusable handles you can share between different tensors. Equality compares both the name and size so you get guardrails when wiring pieces together.

AxisSelector

An AxisSelector accepts either an Axis object or just the axis name as a string. It is used whenever a function can read the axis size from one of its arguments. This mirrors how NumPy lets you pass axis=0 when reducing an array:

total = x.sum(axis=Batch)       # using the Axis handle
same_total = x.sum(axis="batch")  # using only the name

Strings are convenient when you only care about the name, but Axis objects still work so you can keep using the handles you created earlier. If an axis with that name is missing, Haliax raises a ValueError.

AxisSpec

An AxisSpec is used when Haliax needs full size information to create or reshape an array. You can provide a shape dictionary (sometimes called a "shape dict") that maps names to sizes, or a sequence of Axis objects:

shape = {"batch": 32, "feature": 128}
y = hax.zeros(shape)                 # using a shape dict
z = hax.zeros((Batch, Feature))      # using the Axis objects directly

Both forms describe the same layout. Python dictionaries preserve insertion order, so the ordering in a shape dict matches the order that axes appear in the array. Sequences must contain Axis objects (not plain strings) because Haliax cannot otherwise know the axis sizes.

AxisSelection

AxisSelection generalizes the previous aliases so you can talk about several axes at once. It shows up in reductions, indexing helpers, axis-mapping utilities, and anywhere you might have written axis=(0, 1) in NumPy. You may supply:

  • a sequence mixing Axis objects and strings, e.g. ("batch", Feature) when reducing two axes,
  • an AxisSpec, which is handy when you already have a tuple of Axis objects, or
  • a "partial shape dict" where the values are either sizes or None to indicate "any size". Dictionaries are useful when you only care about a subset of axes or want to assert a particular size.
# Reduce over two axes using a tuple of selectors.
scalar = x.sum(axis=("batch", Feature))

# Ask for the axes by name and optionally pin sizes.
x.resolve_axis({"batch": None, "feature": None})  # returns {"batch": 32, "feature": 128}

from haliax.axis import selects_axis
assert selects_axis((Batch, "feature"), {"batch": None, "feature": 128})

Occasionally, an axis size can be inferred in some circumstances but not others. When this happens we still use AxisSelector (or AxisSelection for multiple axes) but document the behavior in the docstring. A RuntimeError will be raised if the size cannot be inferred.

Axis(name: str, size: int) dataclass

Axis is a dataclass that represents an axis of a NamedArray. It has a name and a size.

Attributes:

name: str instance-attribute
size: int instance-attribute

Methods:

alias(new_name: str)

Returns an Axis with the name new_name and the same size.

resize(size) -> Axis

Returns an Axis with the same name and the given size.

AxisSelector = Axis | str module-attribute

AxisSelector is a type that can be used to select a single axis from an array: either a str or an Axis.

AxisSpec = Axis | Sequence[Axis] | ShapeDict module-attribute

AxisSpec is a type that can be used to specify the axes of an array, usually for creation or for adding a new axis whose size can't be determined another way: an Axis, a sequence of Axis objects, or a shape dict.

AxisSelection = AxisSelector | Sequence[AxisSelector] | PartialShapeDict module-attribute

AxisSelection is a type that can be used to select multiple axes from an array: a str, an Axis, a sequence mixing str and Axis, or a partial shape dict.

Axis Manipulation

make_axes(**kwargs: int) -> tuple[Axis, ...]

Convenience function for creating a tuple of Axis objects.

Example:

X, Y = hax.make_axes(X=10, Y=20)

axis_name(ax: AxisSelection) -> str | tuple[str, ...]

axis_name(ax: AxisSelector) -> str
axis_name(ax: Sequence[AxisSelector]) -> tuple[str, ...]

Returns the name of the axis. If ax is a string, returns ax. If ax is an Axis, returns ax.name. If ax is a sequence, returns a tuple with the name of each element.

concat_axes(a1, a2)

concat_axes(a1: ShapeDict, a2: AxisSpec) -> ShapeDict
concat_axes(a1: Sequence[Axis], a2: AxisSpec) -> tuple[Axis, ...]
concat_axes(a1: AxisSpec, a2: AxisSpec) -> AxisSpec
concat_axes(a1: AxisSelection, a2: AxisSelection) -> AxisSelection

Concatenates two AxisSpecs. Raises ValueError if any axis is present in both specs.

union_axes(a1: AxisSelection, a2: AxisSelection) -> AxisSelection

union_axes(a1: ShapeDict, a2: AxisSpec) -> ShapeDict
union_axes(a1: AxisSpec, a2: ShapeDict) -> ShapeDict
union_axes(a1: AxisSpec, a2: AxisSpec) -> AxisSpec
union_axes(a1: AxisSelection, a2: AxisSelection) -> AxisSelection

Similar to concat_axes, but allows an axis to appear in both specs. The result is ordered as if the two specs were concatenated, with duplicate axes removed.

Raises an error if any axis is present in both specs with different sizes.

intersect_axes(ax1: AxisSelection, ax2: AxisSelection) -> AxisSelection

intersect_axes(ax1: ShapeDict, ax2: AxisSelection) -> ShapeDict
intersect_axes(ax1: tuple[AxisSelector, ...], ax2: AxisSpec) -> tuple[Axis, ...]
intersect_axes(ax1: tuple[AxisSelector, ...], ax2: AxisSelection) -> tuple[AxisSelector, ...]
intersect_axes(ax1: AxisSpec, ax2: AxisSelection) -> AxisSpec

Returns the axes that are present in both ax1 and ax2, in the same order as in ax1.

eliminate_axes(axis_spec: AxisSelection, to_remove: AxisSelection) -> AxisSelection

eliminate_axes(axis_spec: Axis | Sequence[Axis], axes: AxisSelection) -> tuple[Axis, ...]
eliminate_axes(axis_spec: ShapeDict, axes: AxisSelection) -> ShapeDict
eliminate_axes(axis_spec: AxisSelection, axes: AxisSelection) -> AxisSelection
eliminate_axes(axis_spec: PartialShapeDict, axes: AxisSelection) -> PartialShapeDict

Returns a new axis spec that is the same as the original, but without any of the axes in to_remove. Raises if any axis in to_remove is not present in axis_spec.

without_axes(axis_spec: AxisSelection, to_remove: AxisSelection, allow_mismatched_sizes=False) -> AxisSelection

without_axes(axis_spec: ShapeDict, to_remove: AxisSelection, allow_mismatched_sizes=False) -> ShapeDict
without_axes(axis_spec: Sequence[Axis], to_remove: AxisSelection, allow_mismatched_sizes=False) -> tuple[Axis, ...]
without_axes(axis_spec: AxisSpec, to_remove: AxisSelection, allow_mismatched_sizes=False) -> AxisSpec
without_axes(axis_spec: AxisSelection, to_remove: AxisSelection, allow_mismatched_sizes=False) -> AxisSelection
without_axes(axis_spec: Sequence[AxisSelector], to_remove: AxisSelection, allow_mismatched_sizes=False) -> tuple[AxisSpec, ...]
without_axes(axis_spec: PartialShapeDict, to_remove: AxisSelection, allow_mismatched_sizes=False) -> PartialShapeDict

As eliminate_axes, but does not raise if an axis in to_remove is not present in axis_spec.

It does, however, raise if an axis in to_remove is present in axis_spec with a different size.

selects_axis(selector: AxisSelection, selected: AxisSelection) -> bool

Returns True if selector contains every axis in selected and, where sizes are given, the sizes match.

is_axis_compatible(ax1: AxisSelector, ax2: AxisSelector)

Returns True if the two axes are compatible, meaning they have the same name and, if both are Axis objects, the same size.

Array Creation

named(a, axis: AxisSelection) -> NamedArray

Creates a NamedArray from a NumPy or JAX array and a sequence of axes.

zeros(shape: AxisSpec, dtype: DTypeLike | None = None) -> NamedArray

Creates a NamedArray with all elements set to 0.

ones(shape: AxisSpec, dtype: DTypeLike | None = None) -> NamedArray

Creates a NamedArray with all elements set to 1.

full(shape: AxisSpec, fill_value: T, dtype: DTypeLike | None = None) -> NamedArray

Creates a NamedArray with all elements set to fill_value.

zeros_like(a: NamedArray, dtype=None) -> NamedArray

Creates a NamedArray with the same axes as a, with all elements set to 0.

ones_like(a: NamedArray, dtype=None) -> NamedArray

Creates a NamedArray with the same axes as a, with all elements set to 1.

full_like(a: NamedArray, fill_value: T, dtype: DTypeLike | None = None) -> NamedArray

Creates a NamedArray with the same axes as a, with all elements set to fill_value.

arange(axis: AxisSpec, *, start=0, step=1, dtype: DTypeLike | None = None) -> NamedArray

Version of jnp.arange that returns a NamedArray.

This version differs from jnp.arange (beyond the obvious NamedArray) in two ways:

1) It can work with a start that is a tracer (i.e. a JAX expression), whereas jnp.arange cannot handle tracers.
2) axis can be more than one axis, in which case it's equivalent to an arange over the product of the sizes, followed by a reshape.

Examples

X, Y = hax.make_axes(X=3, Y=4)
# Create a NamedArray along a single axis
arr = hax.arange(X)  # equivalent to jnp.arange(0, 3, 1)
# 2D
arr = hax.arange((X, Y))  # equivalent to jnp.arange(0, 12, 1).reshape(3, 4)

linspace(axis: AxisSelector, *, start: float, stop: float, endpoint: bool = True, dtype: DTypeLike | None = None) -> NamedArray

Version of jnp.linspace that returns a NamedArray. If axis is a string, the default number of samples (50, per numpy) will be used.

logspace(axis: AxisSelector, *, start: float, stop: float, endpoint: bool = True, base: float = 10.0, dtype: DTypeLike | None = None) -> NamedArray

Version of jnp.logspace that returns a NamedArray. If axis is a string, the default number of samples (50, per numpy) will be used.

geomspace(axis: AxisSelector, *, start: float, stop: float, endpoint: bool = True, dtype: DTypeLike | None = None) -> NamedArray

Version of jnp.geomspace that returns a NamedArray. If axis is a string, the default number of samples (50, per numpy) will be used.

Combining Arrays

concatenate(axis: AxisSelector, arrays: Sequence[NamedArray]) -> NamedArray

Version of jax.numpy.concatenate that returns a NamedArray. The returned array has the same axis names, in the same order, as the first array; the selected axis has a size equal to the sum of that axis's sizes across the concatenated arrays.

stack(axis: AxisSelector, arrays: Sequence[NamedArray]) -> NamedArray

Version of jax.numpy.stack that returns a NamedArray.

(We don't include hstack or vstack because they are subsumed by stack.)

Array Manipulation

Broadcasting

See also the section on Broadcasting.

broadcast_axis(a: NamedArray, axis: AxisSpec) -> NamedArray

Broadcasts a, ensuring that it has all the axes in axis. broadcast_axis is an alias for broadcast_to(a, axis, enforce_no_extra_axes=False, ensure_order=True).

You typically use this function when you want to broadcast an array to a common set of axes.

broadcast_to(a: NamedOrNumeric, axes: AxisSpec, ensure_order: bool = True, enforce_no_extra_axes: bool = True) -> NamedArray

Broadcast a so that it has the given axes.

If ensure_order is True (default) then the returned array's axes are arranged in the same order as axes. Otherwise existing axes may remain in their current order, though they may still be moved to the front if new axes are added.

If enforce_no_extra_axes is True and a has axes that are not in axes then a ValueError is raised.

Slicing

See also the section on Indexing and Slicing.

index(array: NamedArray, slices: Mapping[AxisSelector, NamedIndex]) -> NamedArray

Selects elements from an array along an axis via index or another named array.

This function is typically invoked using array[...] syntax. For instance, you might use array[{"batch": slice(0, 10)}] or array["batch", 0:10] to select the first 10 elements of the 'batch' axis.

When indexing with a dslice, the slice is gathered starting at the given start for size elements. Values read past the end of the array are filled with the fill_value (defaults to 0), and writes outside the bounds are dropped.

Returns:

  • NamedArray or jnp.ndarray

    A NamedArray is returned if there are any axes remaining after selection; otherwise, if all axes are indexed out, a scalar (0-dimensional) jnp.ndarray is returned.

slice(array: NamedArray, *args, **kwargs) -> NamedArray

slice(array: NamedArray, axis: AxisSelector, new_axis: AxisSelector | None = None, start: int = 0, length: int | None = None) -> NamedArray
slice(array: NamedArray, start: Mapping[AxisSelector, IntScalar], length: Mapping[AxisSelector, int] | None = None) -> NamedArray

Slices the array along the specified axis or axes, replacing them with new axes (or a shortened version of the old one).

This method has two signatures:

  • slice(array, axis, new_axis=None, start=0, length=None)
  • slice(array, start: Mapping[AxisSelector, IntScalar], length: Mapping[AxisSelector, int])

They both do similar things. The former slices an array along a single axis, replacing it with a new axis. The latter slices an array along multiple axes, replacing them with new axes.

take(array: NamedArray, axis: AxisSelector, index: int | NamedArray) -> NamedArray

Selects elements from an array along an axis, by an index or by another named array.

If index is a NamedArray, then its axes are added to the output array.

updated_slice(array: NamedArray, start: Mapping[AxisSelector, int | jnp.ndarray | NamedArray], update: NamedArray) -> NamedArray

Updates a slice of an array with another array.

Returns:

  • NamedArray ( NamedArray ) –

    The updated array.

Mutable References

See also the section on Mutable References.

JAX provides jax.Ref, a mutable array reference that can be read or written in place while remaining compatible with transformations such as jax.jit or jax.grad. Haliax mirrors that API with haliax.NamedRef, which carries axis metadata so you can keep using named indexing when plumbing state through your programs.

You introduce a new reference with haliax.new_ref. The returned object behaves much like a NamedArray for indexing purposes: ref[{"batch": 0}] reads a slice, and assignments like ref[{"token": slice(1, 3)}] = update perform in-place updates on the underlying buffer. If you need to stage part of a reference for repeated use, call NamedRef.slice to create a slice ref. Slice refs remember a partial indexing expression so you only supply the remaining axes during reads or writes:

import haliax as hax

Layer = hax.Axis("layers", 24)
Head = hax.Axis("head", 8)
cache = hax.zeros((Layer, Head))
cache_ref = hax.new_ref(cache)

# Stage a view of layers 4-7; indices into layer_ref are relative to that view.
layer_ref = cache_ref.slice({"layers": slice(4, 8)})
layer_ref[{"layers": 0, "head": 3}] = 1.0  # updates layer 4, head 3 in the original cache

When you are done mutating a reference, call haliax.freeze to invalidate it and recover a final NamedArray snapshot. You can also perform atomic-style updates with haliax.swap (or the functional helpers under haliax.ref).

NamedRef(_ref: jax.Ref, _axes: tuple[Axis, ...], _prefix: tuple[Any, ...]) dataclass

Named wrapper around jax.Ref that preserves axis metadata.

Methods: value, unsliced, resolve_axis, slice, unsafe_buffer_pointer.

Attributes:

dtype property

Return the dtype of the underlying reference.

axes: tuple[Axis, ...] property

Axes visible from this view after applying staged selectors.

shape: Mapping[str, int] property

Mapping from axis name to size for the current view.

named_shape: Mapping[str, int] property

ndim: int property

Number of axes in the current view.

value() -> NamedArray

Materialize this reference view as a NamedArray.

unsliced() -> 'NamedRef'

Return a view of the original reference without staged selectors.

resolve_axis(axis: AxisSelector) -> Axis

Resolve an axis selector to the corresponding axis in the current view.

slice(selector: Mapping[AxisSelector, Any]) -> 'NamedRef'

Return a new view with the provided selector staged for future operations.

unsafe_buffer_pointer()

new_ref(value: NamedArray) -> NamedRef

Construct a NamedRef from a NamedArray.

freeze(ref: NamedRef) -> NamedArray

Freeze the reference and return its current contents.

swap(ref: NamedRef, idx: SliceSpec | EllipsisType, value: NamedOrNumeric) -> NamedArray

Swap the value at idx, returning the previous contents as a NamedArray.

Dynamic Slicing

dslice(start: int, length: int | Axis)

Bases: Module

Dynamic slice, comprising a (start, length) pair. Also aliased as ds.

NumPy-style slices like a[i:i+16] don't work inside jax.jit when i is traced, because JAX requires slice bounds to be static. dslice works around this by separating the dynamic start from the static size so that you can write a[dslice(i, 16)] or simply a[ds(i, 16)].

When used in indexing or in .at[...] updates, dslice behaves like a gather of size elements starting at start. Reads beyond the end of the array are filled with a value (0 by default) and writes outside the array bounds are dropped, matching JAX's default scatter/gather semantics.

This class's name is taken from jax.experimental.pallas.

Methods: to_slice, block.

Attributes:

start: int instance-attribute
size: int instance-attribute

to_slice() -> slice

block(idx: int, size: int) -> dslice staticmethod

Returns a dslice that selects a single block of length size starting at idx.

dblock(idx: int, size: int) -> dslice

Returns a dslice that selects a single block of length size starting at idx.

ds: typing.TypeAlias = dslice module-attribute

Shape Manipulation

flatten(array: NamedArray, new_axis_name: AxisSelector) -> NamedArray

Returns a flattened view of the array, with all axes merged into one. Alias for haliax.ravel.

flatten_axes(*args, **kwargs) -> NamedArray | AxisSpec

flatten_axes(axis: Axis, old_axes: Axis, new_axis: AxisSelector) -> Axis
flatten_axes(axis: AxisSpec, new_axis: AxisSelector) -> Axis
flatten_axes(axis: AxisSpec, old_axes: AxisSelection, new_axis: AxisSelector) -> AxisSpec
flatten_axes(array: NamedArray, old_axes: AxisSelection, new_axis: AxisSelector) -> NamedArray

Merge a sequence of axes into a single axis. The new axis must have the same size as the product of the old axes.

The new axis is always inserted starting at the index of the first old axis in the underlying array.

This function can be used in two ways:

  • flatten_axes(array, old_axes, new_axis): merge the old axes of the array into a new axis
  • flatten_axes(axes, old_axes, new_axis): merge the old axes into a new axis

rearrange(array: NamedArray, *args, **kwargs) -> NamedArray

rearrange(array: NamedArray, axes: Sequence[AxisSelector | EllipsisType]) -> NamedArray
rearrange(array: NamedArray, expression: str, **bindings: AxisSelector | int) -> NamedArray

Rearrange a tensor according to an einops-style haliax rearrangement string or a sequence of axes. See the Rearrange page for full documentation.

The sequence form of rearrange rearranges an array so that its underlying storage conforms to axes. axes may include at most one ellipsis, indicating where the remaining axes should go; they keep the same relative order as in the array.

For example, if array has axes (a, b, c, d, e, f) and axes is (e, f, a, ..., c), then the output array will have axes (e, f, a, b, d, c).

The string form of rearrange works similarly to einops.rearrange, but also supports named axes and unordered matching. The string form comes in two variants:

  • Ordered strings are like einops strings, the only significant difference being that flattened axes are named with a colon, e.g. B H W C -> B (E: H W C).
  • Unordered strings match axes by name and are marked with curly braces, e.g. {H W C} -> ... C H W or {H W C} -> ... (E: H W C)

As with einops, you can provide axis sizes to unflatten axes. For instance, hax.rearrange(x, "{ B (H: h1 H) (W: w1 W) } -> (B: B h1 w1) H W ...", H=32, W=32) turns a batch of images into a batch of image patches. Bindings can also be haliax.Axis objects, or strings that will be used as the actual name of the resulting axis.

Examples:

>>> import haliax as hax
>>> import jax.random as jrandom
>>> B, H, W, C = hax.Axis("B", 8), hax.Axis("H", 32), hax.Axis("W", 32), hax.Axis("C", 3)
>>> x = hax.random.normal(jrandom.PRNGKey(0), (B, H, W, C))
>>> # Sequence-based rearrange
>>> hax.rearrange(x, (C, B, H, W))
>>> hax.rearrange(x, (C, ...)) # ellipsis means "keep the rest of the axes in the same order"
>>> # String-based rearrange
>>> # permute the axes
>>> hax.rearrange(x, "B H W C -> C B H W")
>>> # flatten the image (note the assignment of a new name to the flattened axis)
>>> hax.rearrange(x, "B H W C -> B (E: H W C)")
>>> # turn the image into patches
>>> hax.rearrange(x, "{ B (H: h1 H) (W: w1 W) C } -> (B: B h1 w1) (E: H W C) ...", H=2, W=2)
>>> # names can be longer than one character
>>> hax.rearrange(x, "{ B (H: h1 H) (W: w1 W) C } -> (B: B h1 w1) (embed: H W C) ...", H=2, W=2)

unbind(array: NamedArray, axis: AxisSelector) -> list[NamedArray]

Unbind an array along an axis, returning a list of NamedArrays, one for each position on that axis. Analogous to torch.unbind.

unflatten_axis(array: NamedArray, axis: AxisSelector, new_axes: AxisSpec) -> NamedArray

Split an axis into a sequence of axes. The old axis must have the same size as the product of the new axes.

split(a: NamedArray, axis: AxisSelector, new_axes: Sequence[Axis]) -> Sequence[NamedArray]

Splits an array along an axis into multiple arrays, one for each element of new_axes.

Parameters:

  • a
    (NamedArray) –

    the array to split

  • axis
    (AxisSelector) –

    the axis to split along

  • new_axes
    (Sequence[Axis]) –

    the axes to split into. Must have the same total length as the axis being split.

ravel(array: NamedArray, new_axis_name: AxisSelector) -> NamedArray

Returns a flattened view of the array, with all axes merged into one.

Operations

Binary and unary operations mirror JAX's NumPy API almost directly. The only difference is that they operate on named arrays.

Matrix Multiplication

See also the page on Matrix Multiplication as well as the cheat sheet section.

dot(*arrays, precision: PrecisionLike = None, preferred_element_type: DTypeLike | None = None, out_axes: PartialAxisSpec | None = None, dot_general=jax.lax.dot_general, **kwargs) -> NamedArray

dot(axis: AxisSelection | None, *arrays: NamedArray, precision: PrecisionLike = None, preferred_element_type: DTypeLike | None = None, out_axes: PartialAxisSpec | None = ..., dot_general=jax.lax.dot_general) -> NamedArray
dot(*arrays: NamedArray, axis: AxisSelection | None, precision: PrecisionLike = None, preferred_element_type: DTypeLike | None = None, out_axes: PartialAxisSpec | None = ..., dot_general=jax.lax.dot_general) -> NamedArray

Returns the tensor product of the given NamedArrays. The axes in axis are contracted over, and any other axes that are shared between the arrays are batched over. Non-contracted axes that appear in only one of the arrays are preserved.

Note that if axis is None, the result will be a scalar, not a NamedArray. The semantics of axis=None are similar to e.g. how sum and other reduction functions work in numpy. If axis=(), then the result will be an "outer product" of the arrays, i.e. a tensor with shape equal to the concatenation of the shapes of the arrays.

By default, the order of output axes is determined by the order of the input axes, such that each output axis occurs in the same order as it first occurs in the concatenation of the input axes.

If out_axes is provided, the output will be transposed to match the provided axes. out_axes may be a partial specification of the output axes (using ellipses), in which case the output will be rearranged to be consistent with the partial specification. For example, if out_axes=(..., Height, Width) and the output axes are (Width, Height, Depth), the output will be transposed to (Depth, Height, Width). Multiple ellipses are supported, in which case axes will be inserted according to a greedy heuristic that prefers to place unconstrained axes as soon as all prior axes in the "natural" order are covered.

Parameters:

  • *arrays
    (NamedArray, default: () ) –

    The arrays to contract.

  • axis
    (AxisSelection) –

    The axes to contract over.

  • precision
    (PrecisionLike, default: None ) –

    The precision to use. Defaults to None. This argument is passed to jax.numpy.einsum, which in turn passes it to jax.lax.dot_general.

  • preferred_element_type
    (DTypeLike, default: None ) –

    The preferred element type of the result. Defaults to None. This argument is passed to jax.numpy.einsum.

  • out_axes
    (PartialAxisSpec | None, default: None ) –

    A potentially partial specification of the output axes. If provided, the output will be transposed to match the provided axes. Defaults to None.

Returns:

  • NamedArray ( NamedArray ) –

    The result of the contraction.

einsum(equation: str, *arrays: NamedArray, precision: PrecisionLike = None, preferred_element_type: DTypeLike | None = None, _dot_general: DotGeneralOp = jax.lax.dot_general, **axis_aliases: AxisSelector) -> NamedArray

Compute the tensor contraction of the input arrays according to Haliax's named variant of the Einstein summation convention.

Examples:

>>> # normal einsum
>>> import haliax as hax
>>> H = hax.Axis("H", 32)
>>> W = hax.Axis("W", 32)
>>> D = hax.Axis("D", 64)
>>> a = hax.zeros((H, W, D))
>>> b = hax.zeros((D, W, H))
>>> hax.einsum("h w d, d w h -> h w", a, b)
>>> # named einsum
>>> hax.einsum("{H W D} -> H W", a, b)
>>> hax.einsum("{D} -> ", a, b)  # same as the previous example
>>> hax.einsum("-> H W", a, b)  # same as the first example
>>> # axis aliases, useful for generic code
>>> hax.einsum("{x y} -> y", a, b, x=H, y=W)

Parameters:

  • equation
    (str) –

    The einsum equation.

  • arrays
    (NamedArray, default: () ) –

    The input arrays.

  • precision
    (PrecisionLike, default: None ) –

    The precision of the computation.

  • preferred_element_type
    (DTypeLike | None, default: None ) –

    The preferred element type of the computation.

  • _dot_general
    (DotGeneralOp, default: dot_general ) –

    The dot_general function to use.

  • axis_aliases
    (AxisSelector, default: {} ) –

    The axis aliases to use.

Reductions

Reduction operations are things like haliax.sum and haliax.mean that reduce an array along one or more axes. Except for haliax.argmin and haliax.argmax, they all have the form:

def sum(array, axis: AxisSelection | None = None, *, where: NamedArray | None = None) -> haliax.NamedArray:
    ...

with the behavior closely following that of JAX's NumPy API. The axis argument can be a single axis (or axis name), a tuple of axes, or None to reduce all axes. The where argument is a boolean array that specifies which elements to include in the reduction. It must be broadcastable to the input array, using Haliax's broadcasting rules.

The result of a reduction operation is always haliax.NamedArray with the reduced axes removed. If you reduce all axes, the result is a NamedArray with 0 axes, i.e. a scalar. You can convert it to a jax.numpy.ndarray with haliax.NamedArray.scalar, or just haliax.NamedArray.array.

all(array: NamedArray, axis: AxisSelection | None = None, *, where: NamedArray | None = None) -> NamedArray

Named version of jax.numpy.all.

amax(array: NamedArray, axis: AxisSelection | None = None, *, where: NamedArray | None = None) -> NamedArray

Alias for max. See max for details.

amin(array: NamedArray, axis: AxisSelection | None = None, *, where: NamedArray | None = None) -> NamedArray

Alias for min. See min for details.

any(array: NamedArray, axis: AxisSelection | None = None, *, where: NamedArray | None = None) -> NamedArray

True if any elements along a given axis or axes are True. If axis is None, the reduction covers all axes, so the result is True if any element of the array is True.

argmax(array: NamedArray, axis: AxisSelector | None) -> NamedArray

argmin(array: NamedArray, axis: AxisSelector | None) -> NamedArray

max(array: NamedArray, axis: AxisSelection | None = None, *, where: NamedArray | None = None) -> NamedArray

mean(array: NamedArray, axis: AxisSelection | None = None, *, where: NamedArray | None = None, dtype: DTypeLike | None = None) -> NamedArray

min(array: NamedArray, axis: AxisSelection | None = None, *, where: NamedArray | None = None) -> NamedArray

nanargmax(array: NamedArray, axis: AxisSelector | None = None) -> NamedArray

nanargmin(array: NamedArray, axis: AxisSelector | None = None) -> NamedArray

nanmax(array: NamedArray, axis: AxisSelection | None = None, *, where: NamedArray | None = None) -> NamedArray

nanmean(array: NamedArray, axis: AxisSelection | None = None, *, where: NamedArray | None = None, dtype: DTypeLike | None = None) -> NamedArray

nanmin(array: NamedArray, axis: AxisSelection | None = None, *, where: NamedArray | None = None) -> NamedArray

nanprod(array: NamedArray, axis: AxisSelection | None = None, *, where: NamedArray | None = None, dtype: DTypeLike | None = None) -> NamedArray

nanstd(array: NamedArray, axis: AxisSelection | None = None, *, where: NamedArray | None = None, ddof: int = 0, dtype: DTypeLike | None = None) -> NamedArray

nansum(array: NamedArray, axis: AxisSelection | None = None, *, where: NamedArray | None = None, dtype: DTypeLike | None = None) -> NamedArray

nanvar(array: NamedArray, axis: AxisSelection | None = None, *, where: NamedArray | None = None, ddof: int = 0, dtype: DTypeLike | None = None) -> NamedArray

prod(array: NamedArray, axis: AxisSelection | None = None, *, where: NamedArray | None = None, dtype: DTypeLike | None = None) -> NamedArray

ptp(array: NamedArray, axis: AxisSelection | None = None, *, where: NamedArray | None = None) -> NamedArray

std(array: NamedArray, axis: AxisSelection | None = None, *, where: NamedArray | None = None, ddof: int = 0, dtype: DTypeLike | None = None) -> NamedArray

sum(array: NamedArray, axis: AxisSelection | None = None, *, where: NamedArray | None = None, dtype: DTypeLike | None = None) -> NamedArray

var(array: NamedArray, axis: AxisSelection | None = None, *, where: NamedArray | None = None, ddof: int = 0, dtype: DTypeLike | None = None) -> NamedArray

Axis-wise Operations¤

Axis-wise operations are things like haliax.cumsum and haliax.sort that operate on a single axis of an array but don't reduce it.

cumsum(a: NamedArray, axis: AxisSelector, *, dtype: DTypeLike | None = None) -> NamedArray ¤

Named version of jax.numpy.cumsum

cumprod(a: NamedArray, axis: AxisSelector, dtype: DTypeLike | None = None) -> NamedArray ¤

Named version of jax.numpy.cumprod

nancumprod(a: NamedArray, axis: AxisSelector, dtype: DTypeLike | None = None) -> NamedArray ¤

Named version of jax.numpy.nancumprod

nancumsum(a: NamedArray, axis: AxisSelector, *, dtype: DTypeLike | None = None) -> NamedArray ¤

Named version of jax.numpy.nancumsum

sort(a: NamedArray, axis: AxisSelector) -> NamedArray ¤

Named version of jax.numpy.sort

argsort(a: NamedArray, axis: AxisSelector | None, *, stable: bool = False) -> NamedArray ¤

Named version of jax.numpy.argsort.

If axis is None, the returned array will be a 1D array of indices that would sort the flattened array, identical to jax.numpy.argsort(a.array, axis=None).

Parameters:

  • stable ¤
    (bool, default: False ) –

    If True, ensures that the indices of equal elements preserve their relative order.

Unary Operations¤

In these signatures, A stands for a haliax.NamedArray, a scalar, or a jax.numpy.ndarray. These functions all follow their counterparts in JAX's NumPy API more or less directly.

abs(a: A) -> A ¤

absolute(a: A) -> A ¤

angle(a: A) -> A ¤

arccos(a: A) -> A ¤

arccosh(a: A) -> A ¤

arcsin(a: A) -> A ¤

arcsinh(a: A) -> A ¤

arctan(a: A) -> A ¤

arctanh(a: A) -> A ¤

around(a: A) -> A ¤

bitwise_count(a: A) -> A ¤

bitwise_invert(a: A) -> A ¤

bitwise_not(a: A) -> A ¤

cbrt(a: A) -> A ¤

ceil(a: A) -> A ¤

conj(a: A) -> A ¤

conjugate(a: A) -> A ¤

copy(a: A) -> A ¤

cos(a: A) -> A ¤

cosh(a: A) -> A ¤

deg2rad(a: A) -> A ¤

degrees(a: A) -> A ¤

exp(a: A) -> A ¤

exp2(a: A) -> A ¤

expm1(a: A) -> A ¤

fabs(a: A) -> A ¤

fix(a: A) -> A ¤

floor(a: A) -> A ¤

frexp(a: A) -> A ¤

i0(a: A) -> A ¤

imag(a: A) -> A ¤

invert(a: A) -> A ¤

iscomplex(a: A) -> A ¤

isfinite(a: A) -> A ¤

isinf(a: A) -> A ¤

isnan(a: A) -> A ¤

isneginf(a: A) -> A ¤

isposinf(a: A) -> A ¤

isreal(a: A) -> A ¤

log(a: A) -> A ¤

log10(a: A) -> A ¤

log1p(a: A) -> A ¤

log2(a: A) -> A ¤

logical_not(a: A) -> A ¤

ndim(a: A) -> A ¤

negative(a: A) -> A ¤

positive(a: A) -> A ¤

rad2deg(a: A) -> A ¤

radians(a: A) -> A ¤

real(a: A) -> A ¤

reciprocal(a: A) -> A ¤

rint(a: A) -> A ¤

round(a: A, decimals: int = 0) -> A ¤

rsqrt(a: A) -> A ¤

sign(a: A) -> A ¤

signbit(a: A) -> A ¤

sin(a: A) -> A ¤

sinc(a: A) -> A ¤

sinh(a: A) -> A ¤

square(a: A) -> A ¤

sqrt(a: A) -> A ¤

tan(a: A) -> A ¤

tanh(a: A) -> A ¤

trunc(a: A) -> A ¤

Binary Operations¤

add(x1: NamedOrNumeric, x2: NamedOrNumeric) -> NamedOrNumeric ¤

Named version of jax.numpy.add

arctan2(x1: NamedOrNumeric, x2: NamedOrNumeric) -> NamedOrNumeric ¤

Named version of jax.numpy.arctan2

bitwise_and(x1: NamedOrNumeric, x2: NamedOrNumeric) -> NamedOrNumeric ¤

Named version of jax.numpy.bitwise_and

bitwise_left_shift(x1: NamedOrNumeric, x2: NamedOrNumeric) -> NamedOrNumeric ¤

Named version of jax.numpy.bitwise_left_shift

bitwise_or(x1: NamedOrNumeric, x2: NamedOrNumeric) -> NamedOrNumeric ¤

Named version of jax.numpy.bitwise_or

bitwise_right_shift(x1: NamedOrNumeric, x2: NamedOrNumeric) -> NamedOrNumeric ¤

bitwise_xor(x1: NamedOrNumeric, x2: NamedOrNumeric) -> NamedOrNumeric ¤

Named version of jax.numpy.bitwise_xor

divide(x1: NamedOrNumeric, x2: NamedOrNumeric) -> NamedOrNumeric ¤

Named version of jax.numpy.divide

divmod(x1: NamedOrNumeric, x2: NamedOrNumeric) -> NamedOrNumeric ¤

Named version of jax.numpy.divmod

equal(x1: NamedOrNumeric, x2: NamedOrNumeric) -> NamedOrNumeric ¤

Named version of jax.numpy.equal

float_power(x1: NamedOrNumeric, x2: NamedOrNumeric) -> NamedOrNumeric ¤

Named version of jax.numpy.float_power

floor_divide(x1: NamedOrNumeric, x2: NamedOrNumeric) -> NamedOrNumeric ¤

Named version of jax.numpy.floor_divide

fmax(x1: NamedOrNumeric, x2: NamedOrNumeric) -> NamedOrNumeric ¤

Named version of jax.numpy.fmax

fmin(x1: NamedOrNumeric, x2: NamedOrNumeric) -> NamedOrNumeric ¤

Named version of jax.numpy.fmin

fmod(x1: NamedOrNumeric, x2: NamedOrNumeric) -> NamedOrNumeric ¤

Named version of jax.numpy.fmod

greater(x1: NamedOrNumeric, x2: NamedOrNumeric) -> NamedOrNumeric ¤

Named version of jax.numpy.greater

greater_equal(x1: NamedOrNumeric, x2: NamedOrNumeric) -> NamedOrNumeric ¤

Named version of jax.numpy.greater_equal

hypot(x1: NamedOrNumeric, x2: NamedOrNumeric) -> NamedOrNumeric ¤

Named version of jax.numpy.hypot

left_shift(x1: NamedOrNumeric, x2: NamedOrNumeric) -> NamedOrNumeric ¤

Named version of jax.numpy.left_shift

less(x1: NamedOrNumeric, x2: NamedOrNumeric) -> NamedOrNumeric ¤

Named version of jax.numpy.less

less_equal(x1: NamedOrNumeric, x2: NamedOrNumeric) -> NamedOrNumeric ¤

Named version of jax.numpy.less_equal

logaddexp(x1: NamedOrNumeric, x2: NamedOrNumeric) -> NamedOrNumeric ¤

Named version of jax.numpy.logaddexp

logaddexp2(x1: NamedOrNumeric, x2: NamedOrNumeric) -> NamedOrNumeric ¤

logical_and(x1: NamedOrNumeric, x2: NamedOrNumeric) -> NamedOrNumeric ¤

logical_or(x1: NamedOrNumeric, x2: NamedOrNumeric) -> NamedOrNumeric ¤

logical_xor(x1: NamedOrNumeric, x2: NamedOrNumeric) -> NamedOrNumeric ¤

maximum(x1: NamedOrNumeric, x2: NamedOrNumeric) -> NamedOrNumeric ¤

minimum(x1: NamedOrNumeric, x2: NamedOrNumeric) -> NamedOrNumeric ¤

mod(x1: NamedOrNumeric, x2: NamedOrNumeric) -> NamedOrNumeric ¤

multiply(x1: NamedOrNumeric, x2: NamedOrNumeric) -> NamedOrNumeric ¤

nextafter(x1: NamedOrNumeric, x2: NamedOrNumeric) -> NamedOrNumeric ¤

not_equal(x1: NamedOrNumeric, x2: NamedOrNumeric) -> NamedOrNumeric ¤

power(x1: NamedOrNumeric, x2: NamedOrNumeric) -> NamedOrNumeric ¤

remainder(x1: NamedOrNumeric, x2: NamedOrNumeric) -> NamedOrNumeric ¤

right_shift(x1: NamedOrNumeric, x2: NamedOrNumeric) -> NamedOrNumeric ¤

subtract(x1: NamedOrNumeric, x2: NamedOrNumeric) -> NamedOrNumeric ¤

true_divide(x1: NamedOrNumeric, x2: NamedOrNumeric) -> NamedOrNumeric ¤

Polynomial Operations¤

poly ¤

Polynomial helpers for :mod:haliax.

This module provides NamedArray-aware wrappers around :mod:jax.numpy's polynomial utilities.

Attributes:

DEFAULT_POLY_AXIS_NAME = 'degree' module-attribute ¤

Default name used for polynomial coefficient axes when none is provided.

poly(seq_of_zeros: NamedArray | ArrayLike) -> NamedArray ¤

Named version of jax.numpy.poly.

If seq_of_zeros is not a haliax.NamedArray, the returned coefficient axis is named degree.

polyadd(p1: NamedArray | ArrayLike, p2: NamedArray | ArrayLike) -> NamedArray ¤

Named version of jax.numpy.polyadd.

If neither input is a haliax.NamedArray, the coefficient axis is named degree; otherwise the axis from the NamedArray input is reused (resized as needed).

polysub(p1: NamedArray | ArrayLike, p2: NamedArray | ArrayLike) -> NamedArray ¤

Named version of jax.numpy.polysub.

If neither input is a haliax.NamedArray, the coefficient axis is named degree; otherwise the axis from the NamedArray input is reused (resized as needed).

polymul(p1: NamedArray | ArrayLike, p2: NamedArray | ArrayLike) -> NamedArray ¤

Named version of jax.numpy.polymul.

If neither input is a haliax.NamedArray, the coefficient axis is named degree; otherwise the axis from the NamedArray input is reused (resized as needed).

polydiv(p1: NamedArray | ArrayLike, p2: NamedArray | ArrayLike) -> tuple[NamedArray, NamedArray] ¤

Named version of jax.numpy.polydiv.

The quotient and remainder reuse the coefficient axis from the NamedArray input when available; otherwise their coefficient axes are named degree.

polyint(p: NamedArray | ArrayLike, m: int = 1, k: ArrayLike | NamedArray | None = None) -> NamedArray ¤

Named version of jax.numpy.polyint.

If p is not a haliax.NamedArray, the integrated polynomial uses a coefficient axis named degree.

polyder(p: NamedArray | ArrayLike, m: int = 1) -> NamedArray ¤

Named version of jax.numpy.polyder.

If p is not a haliax.NamedArray, the differentiated polynomial uses a coefficient axis named degree.

polyval(p: NamedArray | ArrayLike, x: NamedOrNumeric) -> NamedOrNumeric ¤

Named version of jax.numpy.polyval.

When x is a haliax.NamedArray, the returned array reuses x's axes. Otherwise a regular :mod:jax.numpy array is returned.

polyfit(x: NamedArray | ArrayLike, y: NamedArray | ArrayLike, deg: int, rcond: ArrayLike | None = None, full: bool = False, w: NamedArray | ArrayLike | None = None, cov: bool = False) -> NamedArray | tuple ¤
polyfit(x: NamedArray | ArrayLike, y: NamedArray | ArrayLike, deg: int, rcond: ArrayLike | None = ..., full: Literal[False] = ..., w: NamedArray | ArrayLike | None = ..., cov: Literal[False] = ...) -> NamedArray
polyfit(x: NamedArray | ArrayLike, y: NamedArray | ArrayLike, deg: int, rcond: ArrayLike | None = ..., full: Literal[True] = ..., w: NamedArray | ArrayLike | None = ..., cov: Literal[False] = ...) -> tuple[NamedArray, Array, Array, Array, Array]
polyfit(x: NamedArray | ArrayLike, y: NamedArray | ArrayLike, deg: int, rcond: ArrayLike | None = ..., full: Literal[False] = ..., w: NamedArray | ArrayLike | None = ..., cov: Literal[True] = ...) -> tuple[NamedArray, NamedArray]
polyfit(x: NamedArray | ArrayLike, y: NamedArray | ArrayLike, deg: int, rcond: ArrayLike | None = ..., full: Literal[True] = ..., w: NamedArray | ArrayLike | None = ..., cov: Literal[True] = ...) -> tuple[NamedArray, Array, Array, Array, Array]

Named version of jax.numpy.polyfit.

If neither x nor y is a haliax.NamedArray, the fitted coefficients use a coefficient axis named degree; otherwise the axis from the NamedArray input is reused. When cov is True, the returned covariance matrix is wrapped in a haliax.NamedArray whose row axis matches the coefficient axis and whose column axis uses the same name with a "_cov" suffix.

roots(p: NamedArray | ArrayLike) -> NamedArray ¤

Named version of jax.numpy.roots.

If p is not a haliax.NamedArray, the root axis is named degree.

trim_zeros(f: NamedArray | ArrayLike, trim: str = 'fb') -> NamedArray ¤

Named version of jax.numpy.trim_zeros.

If f is not a haliax.NamedArray, the trimmed coefficient axis is named degree.

vander(x: NamedArray, degree: AxisSelector) -> NamedArray ¤

Named version of jax.numpy.vander.

Parameters:

  • x ¤
    (NamedArray) –

    Input array of shape (n,).

  • degree ¤
    (AxisSelector) –

    Axis for the polynomial degree in the output. If a string is provided, an axis with that name and size n is created.

Returns:

  • NamedArray

    Vandermonde matrix with row axis from x and the provided degree axis.

Other Operations¤

bincount(x: NamedArray, Counts: Axis, *, weights: NamedArray | ArrayLike | None = None, minlength: int = 0) -> NamedArray ¤

Named version of jax.numpy.bincount.

The output axis is specified by Counts.

clip(array: NamedOrNumeric, a_min: NamedOrNumeric, a_max: NamedOrNumeric) -> NamedArray ¤

Like jnp.clip, but with named axes. This version currently only accepts the three-argument form.

packbits(a: NamedArray, axis: AxisSelector, *, bitorder: str = 'big') -> NamedArray ¤

Named version of jax.numpy.packbits.

unpackbits(a: NamedArray, axis: AxisSelector, *, count: int | None = None, bitorder: str = 'big') -> NamedArray ¤

Named version of jax.numpy.unpackbits.

isclose(a: NamedArray, b: NamedArray, rtol=1e-05, atol=1e-08, equal_nan=False) -> NamedArray ¤

Returns a boolean array that is True where the two arrays are element-wise equal within a tolerance.

allclose(a: NamedArray, b: NamedArray, rtol=1e-05, atol=1e-08, equal_nan=False) -> bool ¤

Returns True if two arrays are element-wise equal within a tolerance.

array_equal(a: NamedArray, b: NamedArray) -> bool ¤

Returns True if two arrays have the same shape and elements.

array_equiv(a: NamedArray, b: NamedArray) -> bool ¤

Returns True if two arrays are shape-consistent and equal.

pad(array: NamedArray, pad_width: Mapping[AxisSelector, tuple[int, int]], *, mode: str = 'constant', constant_values: NamedOrNumeric = 0, **kwargs) -> NamedArray ¤

Version of jax.numpy.pad that works with NamedArray.

pad_width should be a mapping from axis (or axis name) to a (before, after) tuple specifying how much padding to add on each side of that axis. Any axis not present in pad_width will not be padded.

searchsorted(a: NamedArray, v: NamedArray | ArrayLike, *, side: str = 'left', sorter: NamedArray | ArrayLike | None = None, method: str = 'scan') -> NamedArray ¤

Named version of jax.numpy.searchsorted.

a and sorter (if provided) must be one-dimensional. The returned array has the same axes as v.

top_k(arr: NamedArray, axis: AxisSelector, k: int, new_axis: AxisSelector | None = None) -> tuple[NamedArray, NamedArray] ¤

Select the top k elements along the given axis.

Parameters:

  • arr ¤
    (NamedArray) –

    Array to select from.

  • axis ¤
    (AxisSelector) –

    Axis to select along.

  • k ¤
    (int) –

    Number of elements to select.

  • new_axis ¤
    (AxisSelector | None, default: None ) –

    Name of the new axis. If None, the original axis is resized to k.

Returns:

  • NamedArray

    Array with the top k elements along the given axis.

  • NamedArray

    Array with the indices of the top k elements along the given axis.

nonzero(array: NamedArray, *, size: Axis, fill_value: int = 0) -> tuple[NamedArray, ...] ¤

Like :func:jax.numpy.nonzero, but with named axes.

Parameters:

  • array ¤
    (NamedArray) –

    The input array to test for nonzero values. Must be a :class:NamedArray.

  • size ¤
    (Axis) –

    Axis specifying the size of the output axis. This is required because JAX must know the size of the result at trace time.

  • fill_value ¤
    (int, default: 0 ) –

    Value used to fill the output when fewer than size elements are nonzero. Defaults to 0.

Returns:

  • tuple[NamedArray, ...]

    A tuple of :class:NamedArray objects, one for each axis of array. Each returned array has size as its only axis and contains the indices of the nonzero elements along the corresponding input axis.

trace(array: NamedArray, axis1: AxisSelector, axis2: AxisSelector, offset=0, dtype=None) -> NamedArray ¤

Compute the trace of an array along two named axes.

tril(array: NamedArray, axis1: Axis, axis2: Axis, k=0) -> NamedArray ¤

Compute the lower triangular part of an array along two named axes.

triu(array: NamedArray, axis1: Axis, axis2: Axis, k=0) -> NamedArray ¤

Compute the upper triangular part of an array along two named axes.

where(condition: NamedOrNumeric | bool, x: NamedOrNumeric | None = None, y: NamedOrNumeric | None = None, fill_value: int | None = None, new_axis: Axis | None = None) -> NamedArray | tuple[NamedArray, ...] ¤

where(condition: NamedOrNumeric | bool, x: NamedOrNumeric, y: NamedOrNumeric) -> NamedArray
where(condition: NamedArray, *, fill_value: int, new_axis: Axis) -> tuple[NamedArray, ...]

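Like jnp.where, but with named axes. With x and y, selects elements from x where condition is True and from y elsewhere. With only condition, it behaves like nonzero: fill_value and new_axis must be given, and a tuple of index arrays is returned.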

FFT¤

All FFT transform functions accept an axis argument, which may be a single axis, its name, a sequence of axes, or an ordered mapping from axes to output sizes. Passing a mapping dispatches to the n-dimensional variants in :mod:jax.numpy.fft.

For example:

import jax.numpy as jnp
import haliax as hax

T = hax.Axis("time", 8)
signal = hax.arange(T, dtype=jnp.float32)

# operate along a single axis specified by name
hax.fft(signal, axis="time")

# resize by passing an Axis object
hax.fft(signal, axis=hax.Axis("time", 16))

X, Y = hax.make_axes(X=4, Y=6)
image = hax.arange((X, Y), dtype=jnp.float32)

# transform across several axes in order by passing a sequence
hax.fft(image, axis=("X", "Y"))

# selectively resize axes by providing a mapping
hax.fft(image, axis={"X": None, "Y": hax.Axis("Y", 10)})

# mappings can cover just a subset of axes when only partial resizing is needed
hax.fft(image, axis={"Y": 10})

fft ¤

Named wrappers around :mod:jax.numpy.fft.

These functions mirror the behaviour of their :mod:jax.numpy.fft counterparts while accepting named axes. Instead of separate fftn/fft2 variants, we provide a single fft family of functions whose axis argument controls which axes are transformed.

The axis parameter can be one of:

  • None – operate on the last axis.
  • str – name of an existing axis in the input.
  • :class:haliax.Axis – specifies both the axis to transform (by name) and the desired FFT length. The output axis is replaced by the provided Axis.
  • sequence – several axes (names or Axis objects), transformed in the given order.
  • dict – mapping from axis selectors (names or Axis objects) to optional sizes. Each value may be an int, an Axis, or None; None uses the existing axis length. The mapping order determines the order of transforms and dispatches to the n-dimensional variants in :mod:jax.numpy.fft.
Example¤
import haliax as hax

X, Y = hax.make_axes(X=4, Y=6)
arr = hax.arange((X, Y))

# 1D transform along Y
hax.fft(arr, axis="Y")

# 2D transform across both axes
hax.fft(arr, axis={"X": None, "Y": None})

# Resize the Y axis to length 8 before transforming
hax.fft(arr, axis={"Y": hax.Axis("Y", 8)})

Functions:

  • fft

    Named version of :func:jax.numpy.fft.fft.

  • ifft

    Named version of :func:jax.numpy.fft.ifft.

  • rfft

    Named version of :func:jax.numpy.fft.rfft.

  • irfft

    Named version of :func:jax.numpy.fft.irfft.

  • hfft

    Named version of :func:jax.numpy.fft.hfft.

  • ihfft

    Named version of :func:jax.numpy.fft.ihfft.

  • fftshift

    Named version of :func:jax.numpy.fft.fftshift.

  • ifftshift

    Named version of :func:jax.numpy.fft.ifftshift.

  • fftfreq

    Named version of :func:jax.numpy.fft.fftfreq.

  • rfftfreq

    Named version of :func:jax.numpy.fft.rfftfreq.

AxisSizeLike = int | Axis | None module-attribute ¤
AxisMapping = Mapping[AxisSelector, AxisSizeLike] module-attribute ¤
fft(a: NamedArray, axis: AxisSelector | Sequence[AxisSelector] | AxisMapping | None = None, norm: str | None = None) -> NamedArray ¤

Named version of :func:jax.numpy.fft.fft.

See the module-level documentation above for how the axis argument is interpreted.

ifft(a: NamedArray, axis: AxisSelector | Sequence[AxisSelector] | AxisMapping | None = None, norm: str | None = None) -> NamedArray ¤

Named version of :func:jax.numpy.fft.ifft.

rfft(a: NamedArray, axis: AxisSelector | Sequence[AxisSelector] | AxisMapping | None = None, norm: str | None = None) -> NamedArray ¤

Named version of :func:jax.numpy.fft.rfft.

irfft(a: NamedArray, axis: AxisSelector | Sequence[AxisSelector] | AxisMapping | None = None, norm: str | None = None) -> NamedArray ¤

Named version of :func:jax.numpy.fft.irfft.

hfft(a: NamedArray, axis: AxisSelector | Sequence[AxisSelector] | AxisMapping | None = None, norm: str | None = None) -> NamedArray ¤

Named version of :func:jax.numpy.fft.hfft.

Only a single axis is supported; passing a dictionary with more than one entry will raise an error.

ihfft(a: NamedArray, axis: AxisSelector | Sequence[AxisSelector] | AxisMapping | None = None, norm: str | None = None) -> NamedArray ¤

Named version of :func:jax.numpy.fft.ihfft.

Only a single axis is supported; passing a dictionary with more than one entry will raise an error.

fftshift(x: NamedArray, axes: AxisSelection | None = None) -> NamedArray ¤

Named version of :func:jax.numpy.fft.fftshift.

ifftshift(x: NamedArray, axes: AxisSelection | None = None) -> NamedArray ¤

Named version of :func:jax.numpy.fft.ifftshift.

fftfreq(axis: Axis, d: float = 1.0) -> NamedArray ¤

Named version of :func:jax.numpy.fft.fftfreq.

rfftfreq(axis: Axis, d: float = 1.0) -> NamedArray ¤

Named version of :func:jax.numpy.fft.rfftfreq.

Named Array Reference¤

Most methods on haliax.NamedArray simply call the corresponding haliax function with the array as the first argument, as with NumPy. The exceptions are documented here:

NamedArray(array: jnp.ndarray, axes: AxisSpec) dataclass ¤

Attributes:

  • array (ndarray) –
  • axes (tuple[Axis, ...]) –
  • shape (dict[str, int]) –
  • dtype

    The dtype of the underlying array

  • ndim

    The number of dimensions of the underlying array

  • size

    The number of elements in the underlying array

  • nbytes

    The number of bytes in the underlying array

  • at ('_NamedIndexUpdateHelper') –

    Named analog of JAX's at method (https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html).

  • imag ('NamedArray') –
  • real ('NamedArray') –
array: jnp.ndarray instance-attribute ¤
axes: tuple[Axis, ...] instance-attribute ¤
shape: dict[str, int] cached property ¤
dtype = property(lambda self: self.array.dtype) class-attribute instance-attribute ¤

The dtype of the underlying array

ndim = property(lambda self: self.array.ndim) class-attribute instance-attribute ¤

The number of dimensions of the underlying array

size = property(lambda self: self.array.size) class-attribute instance-attribute ¤

The number of elements in the underlying array

nbytes = property(lambda self: self.array.nbytes) class-attribute instance-attribute ¤

The number of bytes in the underlying array

at: '_NamedIndexUpdateHelper' property ¤

Named analog of JAX's at method (https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html). The at[...].set(...) pattern is a functional way to update elements of an array.

This is the named version: it uses the same named-axis indexing syntax as slicing, and otherwise follows JAX's at API.

From the JAX documentation:

The at property provides a functionally pure equivalent of in-place array modifications.

In particular, each alternate syntax on the left is equivalent to the in-place expression on the right:

  • x = x.at[idx].set(y) – x[idx] = y
  • x = x.at[idx].add(y) – x[idx] += y
  • x = x.at[idx].multiply(y) – x[idx] *= y
  • x = x.at[idx].divide(y) – x[idx] /= y
  • x = x.at[idx].power(y) – x[idx] **= y
  • x = x.at[idx].min(y) – x[idx] = minimum(x[idx], y)
  • x = x.at[idx].max(y) – x[idx] = maximum(x[idx], y)
  • x = x.at[idx].apply(ufunc) – ufunc.at(x, idx)
  • x = x.at[idx].get() – x = x[idx]

None of the x.at expressions modify the original x; instead they return a modified copy of x. However, inside a jit() compiled function, expressions like x = x.at[idx].set(y) are guaranteed to be applied in-place.

Unlike NumPy in-place operations such as x[idx] += y, if multiple indices refer to the same location, all updates will be applied (NumPy would only apply the last update, rather than applying all updates). The order in which conflicting updates are applied is implementation-defined and may be nondeterministic (e.g., due to concurrency on some hardware platforms).

By default, JAX assumes that all indices are in-bounds. Alternative out-of-bound index semantics can be specified via the mode parameter (see the JAX documentation for at).
imag: 'NamedArray' property ¤
real: 'NamedArray' property ¤
item() ¤

Returns the value of this NamedArray as a Python scalar.

scalar() -> jnp.ndarray ¤

Returns a scalar array corresponding to the value of this NamedArray. Raises an error if the NamedArray is not scalar.

We sometimes use this to convert a NamedArray to a scalar, e.g. when returning a loss. Losses have to be jnp.ndarrays, not NamedArrays, so we need to convert them. item doesn't work inside jitted functions because it returns a concrete Python scalar.

You could just access array instead, but that is less explicit and doesn't assert that the array is a scalar.

tree_flatten() -> Any ¤
tree_unflatten(aux, tree: Any) -> Any classmethod ¤
has_axis(axis: AxisSelection) -> bool ¤

Returns True if the given axis is present in this NamedArray.

matches_axes(spec: NamedArrayAxesSpec) -> bool ¤

Check whether this NamedArray conforms to the given NamedArray type.

Parameters:

  • spec ¤
    (NamedArrayAxesSpec) –

    The specification to check against. It can be produced via the NamedArray[...] syntax or passed directly as a string or sequence of axis names.

axis_size(axis: AxisSelection) -> int | tuple[int, ...] ¤
axis_size(axis: AxisSelector) -> int
axis_size(axis: Sequence[AxisSelector]) -> tuple[int, ...]

Returns the size of the given axis, or a tuple of sizes if given multiple axes.

resolve_axis(axes: AxisSelection) -> AxisSpec ¤
resolve_axis(axis: AxisSelector) -> Axis
resolve_axis(axis: tuple[AxisSelector, ...]) -> tuple[Axis, ...]
resolve_axis(axis: PartialShapeDict) -> ShapeDict
resolve_axis(axes: AxisSelection) -> AxisSpec

Returns the axes corresponding to the given axis selection. That is, it returns the haliax.Axis values themselves, not just their names.

Raises a ValueError if any of the axes are not found.

axis_indices(axis: AxisSelection) -> int | None | tuple[int | None, ...]
axis_indices(axis: AxisSelector) -> int | None
axis_indices(axis: Sequence[AxisSelector]) -> tuple[int | None, ...]
axis_indices(axis: AxisSelection) -> tuple[int | None, ...]

For a single axis, returns an int corresponding to the index of the axis. For multiple axes, returns a tuple of ints corresponding to the indices of the axes.

If an axis is not present, returns None in its position.

rearrange(*args, **kwargs) -> 'NamedArray'
rearrange(axes: Sequence[AxisSelector | EllipsisType]) -> 'NamedArray'
rearrange(expression: str, **bindings: AxisSelector | int) -> 'NamedArray'

See haliax.rearrange for details.

broadcast_to(axes: AxisSpec) -> 'NamedArray'
broadcast_axis(axis: AxisSpec) -> 'NamedArray'
split(axis: AxisSelector, new_axes: Sequence[Axis]) -> Sequence['NamedArray']
flatten_axes(old_axes: AxisSelection, new_axis: AxisSelector) -> 'NamedArray'
unflatten_axis(axis: AxisSelector, new_axes: AxisSpec) -> 'NamedArray'
ravel(new_axis_name: AxisSelector) -> 'NamedArray'
flatten(new_axis_name: AxisSelector) -> 'NamedArray'
unbind(axis: AxisSelector) -> Sequence['NamedArray']
rename(renames: Mapping[AxisSelector, AxisSelector]) -> 'NamedArray'
slice(*args, **kwargs) -> 'NamedArray'
slice(axis: AxisSelector, new_axis: AxisSelector | None = None, start: int = 0, length: int | None = None) -> 'NamedArray'
slice(start: Mapping[AxisSelector, int], length: Mapping[AxisSelector, int | Axis]) -> 'NamedArray'
updated_slice(start: Mapping[AxisSelector, int | 'NamedArray'], update: 'NamedArray') -> 'NamedArray'
take(axis: AxisSelector, index: int | 'NamedArray') -> 'NamedArray'
all(axis: AxisSelection | None = None, *, where: 'NamedArray' | None = None) -> 'NamedArray'
any(axis: AxisSelection | None = None, *, where: 'NamedArray' | None = None) -> 'NamedArray'
argmax(axis: AxisSelector | None = None) -> 'NamedArray'
argmin(axis: AxisSelector | None) -> 'NamedArray'
argsort(axis: AxisSelector | None, *, stable: bool = False) -> 'NamedArray'
astype(dtype) -> 'NamedArray'
clip(a_min=None, a_max=None) -> Any
conj() -> 'NamedArray'
conjugate() -> 'NamedArray'
copy() -> 'NamedArray'
cumprod(axis: AxisSelector, *, dtype=None) -> 'NamedArray'
cumsum(axis: AxisSelector, *, dtype=None) -> 'NamedArray'
dot(*args, **kwargs) -> 'NamedArray'
dot(axis: AxisSelection | None, *b, precision: PrecisionLike = None, dot_general=jax.lax.dot_general) -> 'NamedArray'
dot(*args: 'NamedArray', axis: AxisSelection | None, precision: PrecisionLike = None, dot_general=jax.lax.dot_general) -> 'NamedArray'
max(axis: AxisSelection | None = None, *, where=None) -> 'NamedArray'
mean(axis: AxisSelection | None = None, *, dtype=None, where: 'NamedArray' | None = None) -> 'NamedArray'
min(axis: AxisSelection | None = None, *, where: 'NamedArray' | None = None) -> 'NamedArray'
prod(axis: AxisSelection | None = None, *, dtype=None, where: 'NamedArray' | None = None) -> 'NamedArray'
product(axis: AxisSelection | None = None, *, dtype=None, where: 'NamedArray' | None = None) -> 'NamedArray'
ptp(axis: AxisSelection | None = None) -> 'NamedArray'
round(decimals=0) -> 'NamedArray'
sort(axis: AxisSelector) -> Any
std(axis: AxisSelection | None = None, *, dtype=None, ddof=0, where: 'NamedArray' | None = None) -> 'NamedArray'
sum(axis: AxisSelection | None = None, *, dtype=None, where: 'NamedArray' | None = None) -> 'NamedArray'
tobytes(order='C') -> Any
tolist() -> Any
trace(axis1: AxisSelector, axis2: AxisSelector, offset=0, dtype=None) -> 'NamedArray'
var(axis: AxisSelection | None = None, dtype=None, ddof=0, *, where: 'NamedArray' | None = None) -> 'NamedArray'

Partitioning API

See also the section on Partitioning.

axis_mapping(mapping: ResourceMapping, *, merge: bool = False, **kwargs)

Context manager for setting the global resource mapping.

shard(x: T, mapping: ResourceMapping | None = None, mesh: Mesh | None = None) -> T

Shard a PyTree using the provided axis mapping. NamedArrays in the PyTree are sharded using the axis mapping. Other arrays (i.e. plain JAX arrays) are left alone.

This is basically a fancy wrapper around with_sharding_constraint that uses the axis mapping to determine the sharding.

named_jit(fn: Callable[Args, R] | None = None, axis_resources: ResourceMapping | None = None, *, in_axis_resources: ResourceMapping | None = None, out_axis_resources: ResourceMapping | None = None, donate_args: PyTree | None = None, donate_kwargs: PyTree | None = None, **pjit_args) -> WrappedCallable[Args, R] | typing.Callable[[Callable[Args, R]], WrappedCallable[Args, R]]

named_jit(fn: Callable[Args, R], axis_resources: ResourceMapping | None = None, *, in_axis_resources: ResourceMapping | None = None, out_axis_resources: ResourceMapping | None = None, donate_args: PyTree | None = None, donate_kwargs: PyTree | None = None, keep_unused: bool = False, backend: str | None = None, inline: bool | None = None) -> WrappedCallable[Args, R]
named_jit(*, axis_resources: ResourceMapping | None = None, in_axis_resources: ResourceMapping | None = None, out_axis_resources: ResourceMapping | None = None, donate_args: PyTree | None = None, donate_kwargs: PyTree | None = None, keep_unused: bool = False, backend: str | None = None, inline: bool | None = None) -> typing.Callable[[Callable[Args, R]], WrappedCallable[Args, R]]

A version of pjit that uses NamedArrays and the provided resource mapping to infer resource partitions for sharded computation.

axis_resources will be used for a context-specific resource mapping when the function is invoked. In addition, if in_axis_resources is not provided, the arguments' own (pre-existing) shardings will be used as the in_axis_resources. If out_axis_resources is not provided, axis_resources will be used as the out_axis_resources.

If no resource mapping is provided, this function attempts to use the context resource mapping.

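Functionally this is very similar to something like:

import jax
import haliax as hax

# Rough, illustrative equivalent of named_jit(fn, ...)(arg)
def named_jit_equivalent(fn, arg, axis_resources, in_axis_resources, out_axis_resources, **pjit_args):
    def wrapped_fn(arg):
        result = fn(arg)
        # outputs are resharded according to out_axis_resources
        return hax.shard(result, out_axis_resources)

    arg = hax.shard(arg, in_axis_resources)
    with hax.axis_mapping(axis_resources):
        result = jax.jit(wrapped_fn, **pjit_args)(arg)
    return result

This function can be used as a decorator or as a function.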

Parameters:

  • fn
    (Callable, default: None) –

    The function to be jit'd.

  • axis_resources
    (ResourceMapping, default: None) –

    A mapping from logical axis names to physical axis names, used for the context-specific resource mapping.

  • in_axis_resources
    (ResourceMapping, default: None) –

    A mapping from logical axis names to physical axis names for the arguments. If not passed, the arguments' own shardings are used.

  • out_axis_resources
    (ResourceMapping, default: None) –

    A mapping from logical axis names to physical axis names for the result. Defaults to axis_resources.

  • donate_args
    (PyTree, default: None) –

    A PyTree of booleans, or a function leaf -> bool, indicating whether the arguments should be donated to the computation.

  • donate_kwargs
    (PyTree, default: None) –

    A PyTree of booleans, or a function leaf -> bool, indicating whether the keyword arguments should be donated to the computation.

Returns:

  • WrappedCallable[Args, R] | Callable[[Callable[Args, R]], WrappedCallable[Args, R]]

    A jit'd version of the function.

fsdp(*args, **kwargs)

fsdp(fn: F, parameter_mapping: ResourceMapping, compute_mapping: ResourceMapping) -> F
fsdp(parameter_mapping: ResourceMapping, compute_mapping: ResourceMapping) -> typing.Callable[[F], F]

A convenience wrapper around named_jit / pjit to encode the FSDP pattern. It's basically equivalent to this:

from haliax import named_jit
def fsdp_equivalent(fn, parameter_mapping, compute_mapping):
    @named_jit(in_axis_resources=parameter_mapping, out_axis_resources=parameter_mapping, axis_resources=compute_mapping)
    def f(*args, **kwargs):
        return fn(*args, **kwargs)
    return f

This function can be used as a decorator or as a function.

round_axis_for_partitioning(axis: Axis, mapping: ResourceMapping | None = None) -> Axis

Round an axis so that its size is divisible by the size of the partition it's on.
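The rounding itself is simple arithmetic. The sketch below assumes the size is rounded up to the next multiple of the partition size; round_up_to_partition is a hypothetical helper, not part of Haliax, and it takes the partition size directly instead of looking it up from a mapping and mesh.

```python
def round_up_to_partition(size: int, partition_size: int) -> int:
    """Hypothetical helper: smallest multiple of partition_size that is >= size."""
    if partition_size <= 0:
        raise ValueError("partition_size must be positive")
    # ceil(size / partition_size) * partition_size using integer arithmetic
    return -(-size // partition_size) * partition_size


# Under the round-up assumption, 50257 split 8 ways becomes 50264
print(round_up_to_partition(50257, 8))
# Already divisible: unchanged
print(round_up_to_partition(64, 8))
```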

physical_axis_name(axis: AxisSelector, mapping: ResourceMapping | None = None) -> PhysicalAxisSpec | None

Get the physical axis name for a logical axis from the mapping. Returns None if the axis is not mapped.

physical_axis_size(axis: AxisSelector, mapping: ResourceMapping | None = None) -> int | None

Get the physical axis size for a logical axis. This is the product of the sizes of all physical axes that this logical axis is mapped to.
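The product rule can be sketched in plain Python. sketch_physical_axis_size is a hypothetical stand-in, not Haliax's implementation: it takes an explicit dictionary of mesh axis sizes in place of a real mesh, and it assumes an unmapped axis yields None, mirroring physical_axis_name.

```python
from math import prod
from typing import Mapping, Optional, Sequence, Union

PhysicalAxisSpec = Union[str, Sequence[str]]


def sketch_physical_axis_size(
    logical_axis: str,
    mapping: Mapping[str, PhysicalAxisSpec],
    mesh_shape: Mapping[str, int],
) -> Optional[int]:
    """Hypothetical re-implementation of the documented rule."""
    physical = mapping.get(logical_axis)
    if physical is None:
        # assumed: unmapped logical axes have no physical size
        return None
    if isinstance(physical, str):
        physical = (physical,)
    # a logical axis mapped to several mesh axes is split across all of them
    return prod(mesh_shape[name] for name in physical)


mesh_shape = {"data": 2, "model": 4}  # stand-in for a 2 x 4 device mesh
mapping = {"batch": "data", "embed": ("data", "model")}

print(sketch_physical_axis_size("batch", mapping, mesh_shape))  # 2
print(sketch_physical_axis_size("embed", mapping, mesh_shape))  # 8
print(sketch_physical_axis_size("heads", mapping, mesh_shape))  # None
```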

sharding_for_axis(axis: AxisSelection, mapping: ResourceMapping | None = None, mesh: MeshLike | None = None) -> NamedSharding

Get the sharding for a single axis.

Gradient Checkpointing

Haliax mainly just defers to JAX and equinox.filter_checkpoint for gradient checkpointing. However, we provide a few utilities to make it easier to use.

See also haliax.nn.ScanCheckpointPolicy.

tree_checkpoint_name(x: T, name: str) -> T

Checkpoint a tree of arrays with a given name. This is useful for gradient checkpointing. This is equivalent to calling jax.ad_checkpoint.checkpoint_name, except that it works for any PyTree, not just arrays.

See also jax.ad_checkpoint.checkpoint_name and haliax.nn.ScanCheckpointPolicy.

Old API

These functions are being deprecated and will be removed in a future release.

shard_with_axis_mapping(x: T, mapping: ResourceMapping, mesh: Mesh | None = None) -> T

auto_sharded(x: T, mesh: Optional[Mesh] = None) -> T

Shard a PyTree using the global axis mapping. NamedArrays in the PyTree are sharded using the axis mapping and the names in the tree.

If there is no global axis mapping, this function is a no-op.