Main API Reference¤
In general, the API is designed to be similar to JAX's version of NumPy's API, with the main difference being that we use names (either strings or haliax.Axis objects) to specify axes instead of integers. This shows up for creating arrays (see haliax.zeros and haliax.ones) as well as things like reductions (see haliax.sum and haliax.mean).
PyTree Helpers¤
PyTrees are the lingua franca for composing state in JAX ecosystems. Haliax provides drop-in replacements for the
jax.tree helpers that are aware of NamedArray semantics. They preserve axis metadata across
transformations while interoperating with standard JAX containers, so you can use them anywhere you would have reached
for JAX's versions.
Use these helpers whenever you need to map, flatten, or rebuild PyTrees that might include NamedArray instances:
- haliax.tree.map mirrors jax.tree.map but forwards to Haliax's haliax.tree_util.tree_map so axis names remain intact.
- haliax.tree.scan_aware_map descends into haliax.nn.Stacked modules so that each layer is transformed individually, effectively treating them as if they were unrolled when applying haliax.tree_util.scan_aware_tree_map.
- haliax.tree.flatten / haliax.tree.unflatten match the familiar flattening API while handling NamedArray payloads safely.
- haliax.tree.leaves and haliax.tree.structure provide direct access to the leaves and PyTree structure.
All of these helpers accept the same is_leaf hook you might already use with JAX's utilities. They should be the first
tools you reach for when you need deterministic tree transforms that understand named axes.
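To make the role of the is_leaf hook concrete, here is a minimal plain-Python sketch of a tree map. It is not Haliax's or JAX's implementation: `ToyNamed`, `toy_tree_map`, and `kind` are hypothetical names invented for this illustration, and only dicts, lists, and tuples are treated as containers.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyNamed:
    """Hypothetical stand-in for a NamedArray: values plus axis names."""
    values: tuple
    axes: tuple


def toy_tree_map(fn, tree, is_leaf=None):
    """Apply fn to every leaf of nested dicts, lists, and tuples."""
    # The hook is checked first, so it can stop the descent at any node.
    if is_leaf is not None and is_leaf(tree):
        return fn(tree)
    if isinstance(tree, dict):
        return {key: toy_tree_map(fn, value, is_leaf) for key, value in tree.items()}
    if isinstance(tree, (list, tuple)):
        return type(tree)(toy_tree_map(fn, value, is_leaf) for value in tree)
    # Anything else (numbers, ToyNamed, ...) is treated as a leaf.
    return fn(tree)


def kind(node):
    """Report the type name of the node that fn actually received."""
    return type(node).__name__


state = {"w": ToyNamed((1.0, 2.0), ("batch",)), "window": [3, 4]}
per_element = toy_tree_map(kind, state)
per_list = toy_tree_map(kind, state, is_leaf=lambda node: isinstance(node, list))
```

Without the hook the function visits each integer inside the list; with the hook the list is handed to the function as a single leaf. The real helpers apply the same idea to whatever PyTree node types are registered with JAX.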
map(fn: Callable[..., T], tree: Any, *rest: Any, is_leaf: Callable[[Any], bool] | None = None) -> Any
¤
Alias for haliax.tree_util.tree_map, matching jax.tree.map.
scan_aware_map(fn: Callable[..., T], tree: Any, *rest: Any, is_leaf: Callable[[Any], bool] | None = None) -> Any
¤
Alias for haliax.tree_util.scan_aware_tree_map with jax.tree-style naming.
flatten(tree: Any, *, is_leaf: Callable[[Any], bool] | None = None) -> tuple[Sequence[Any], Any]
¤
Alias for haliax.tree_util.tree_flatten, matching jax.tree.flatten.
unflatten(treedef: Any, leaves: Iterable[Any]) -> Any
¤
Alias for haliax.tree_util.tree_unflatten, matching jax.tree.unflatten.
leaves(tree: Any, *, is_leaf: Callable[[Any], bool] | None = None) -> Sequence[Any]
¤
Alias for haliax.tree_util.tree_leaves, matching jax.tree.leaves.
structure(tree: Any, *, is_leaf: Callable[[Any], bool] | None = None) -> Any
¤
Alias for haliax.tree_util.tree_structure, matching jax.tree.structure.
Axis Types¤
If you already speak NumPy or jax.numpy, think of Haliax as swapping positional axes (axis=0) for named axes
(axis="batch"). The type hints in this section describe the different ways those named axes can be provided to the API.
They appear throughout the documentation and in signatures so that you can quickly tell which forms are accepted.
| Name | Accepts | When to use it | Example |
|---|---|---|---|
| Axis | Axis(name: str, size: int) | Define a named dimension with an explicit size | Batch = Axis("batch", 32) |
| AxisSelector | Axis or str | Refer to an existing axis whose size can be inferred from the arrays you pass in | x.sum(axis="batch") |
| AxisSpec | dict[str, int], Axis, or a sequence of Axis objects | Create or reshape arrays when the axis sizes must be provided | hax.zeros((Batch, Feature)) |
| AxisSelection | dict[str, int \| None], AxisSpec, or a sequence of AxisSelector values | Work with one or more existing axes (reductions, indexing helpers, flattening, …) | x.sum(axis=("batch", Feature)) |
Axis¤
An Axis is the fundamental building block: it is a tiny dataclass that stores a name and a size. You can
construct one directly or use haliax.make_axes to generate several at a time.
import haliax as hax
from haliax import Axis
Batch = Axis("batch", 32)
Feature = Axis("feature", 128)
x = hax.ones((Batch, Feature))
print(Batch.name, Batch.size)
Using Axis objects keeps array creation explicit and gives reusable handles you can share between different tensors.
Equality compares both the name and size so you get guardrails when wiring pieces together.
AxisSelector¤
An AxisSelector accepts either an Axis object or just the axis name as a string. It is used
whenever a function can read the axis size from one of its arguments. This mirrors how NumPy lets you pass axis=0 when
reducing an array:
total = x.sum(axis=Batch) # using the Axis handle
same_total = x.sum(axis="batch") # using only the name
Strings are convenient when you only care about the name, but Axis objects still work so you can keep using the handles
you created earlier. If an axis with that name is missing, Haliax raises a ValueError.
AxisSpec¤
An AxisSpec is used when Haliax needs full size information to create or reshape an array. You can
provide a shape dictionary (sometimes called a "shape dict") that maps names to sizes, or a sequence of Axis objects:
shape = {"batch": 32, "feature": 128}
y = hax.zeros(shape) # using a shape dict
z = hax.zeros((Batch, Feature)) # using the Axis objects directly
Both forms describe the same layout. Python dictionaries preserve insertion order, so the ordering in a shape dict matches
the order that axes appear in the array. Sequences must contain Axis objects (not plain strings) because Haliax cannot
otherwise know the axis sizes.
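The equivalence of the two forms can be sketched in plain Python. This is an illustrative model only, not Haliax code: `ToyAxis` and `to_shape_dict` are hypothetical names, and the error type raised for a plain string is an assumption of the sketch.

```python
from typing import NamedTuple


class ToyAxis(NamedTuple):
    """Hypothetical stand-in for haliax.Axis: a name plus a size."""
    name: str
    size: int


def to_shape_dict(spec):
    """Normalize a shape dict, a single axis, or a sequence of axes to name -> size."""
    if isinstance(spec, dict):
        return dict(spec)  # insertion order is preserved
    if isinstance(spec, ToyAxis):
        return {spec.name: spec.size}
    shape = {}
    for axis in spec:
        if not isinstance(axis, ToyAxis):
            # A bare string carries no size, so the layout cannot be built from it.
            raise TypeError(f"cannot infer a size for {axis!r}")
        shape[axis.name] = axis.size
    return shape


Batch = ToyAxis("batch", 32)
Feature = ToyAxis("feature", 128)
from_dict = to_shape_dict({"batch": 32, "feature": 128})
from_axes = to_shape_dict((Batch, Feature))
```

Both calls produce the same ordered mapping, which is the sense in which the two forms describe one layout.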
AxisSelection¤
AxisSelection generalizes the previous aliases so you can talk about several axes at once. It
shows up in reductions, indexing helpers, axis-mapping utilities, and anywhere you might have written axis=(0, 1) in
NumPy. You may supply:
- a sequence mixing Axis objects and strings, e.g. ("batch", Feature) when reducing two axes,
- an AxisSpec, which is handy when you already have a tuple of Axis objects, or
- a "partial shape dict" where the values are either sizes or None to indicate "any size". Dictionaries are useful when you only care about a subset of axes or want to assert a particular size.
# Reduce over two axes using a tuple of selectors.
scalar = x.sum(axis=("batch", Feature))
# Ask for the axes by name and optionally pin sizes.
x.resolve_axis({"batch": None, "feature": None}) # returns {"batch": 32, "feature": 128}
from haliax.axis import selects_axis
assert selects_axis((Batch, "feature"), {"batch": None, "feature": 128})
Occasionally, an axis size can be inferred in some circumstances but not others. When this happens we still use
AxisSelector (or AxisSelection for multiple axes) but document the behavior in the docstring. A RuntimeError will be
raised if the size cannot be inferred.
Axis(name: str, size: int)
dataclass
¤
AxisSelector = Axis | str
module-attribute
¤
AxisSelector is a type that can be used to select a single axis from an array. str or Axis
AxisSpec = Axis | Sequence[Axis] | ShapeDict
module-attribute
¤
AxisSpec is a type that can be used to specify the axes of an array, usually for creation or adding a new axis whose size can't be determined another way. Accepts an Axis, a sequence of Axis, or a ShapeDict.
AxisSelection = AxisSelector | Sequence[AxisSelector] | PartialShapeDict
module-attribute
¤
AxisSelection is a type that can be used to select multiple axes from an array. Accepts a str, an Axis, a sequence of mixed str and Axis, or a PartialShapeDict.
Axis Manipulation¤
make_axes(**kwargs: int) -> tuple[Axis, ...]
¤
Convenience function for creating a tuple of Axis objects.
Example:
X, Y = make_axes(X=10, Y=20)
axis_name(ax: AxisSelection) -> str | tuple[str, ...]
¤
axis_name(ax: AxisSelector) -> str
axis_name(ax: Sequence[AxisSelector]) -> tuple[str, ...]
Returns the name of the axis. If ax is a string, returns ax. If ax is an Axis, returns ax.name
concat_axes(a1, a2)
¤
concat_axes(a1: ShapeDict, a2: AxisSpec) -> ShapeDict
concat_axes(a1: Sequence[Axis], a2: AxisSpec) -> tuple[Axis, ...]
concat_axes(a1: AxisSpec, a2: AxisSpec) -> AxisSpec
concat_axes(a1: AxisSelection, a2: AxisSelection) -> AxisSelection
Concatenates two AxisSpecs. Raises ValueError if any axis is present in both specs
union_axes(a1: AxisSelection, a2: AxisSelection) -> AxisSelection
¤
union_axes(a1: ShapeDict, a2: AxisSpec) -> ShapeDict
union_axes(a1: AxisSpec, a2: ShapeDict) -> ShapeDict
union_axes(a1: AxisSpec, a2: AxisSpec) -> AxisSpec
union_axes(a1: AxisSelection, a2: AxisSelection) -> AxisSelection
Similar to concat_axes, but allows axes to be specified multiple times. The resulting AxisSpec will have the order of just concatenating each axis spec, but with any duplicate axes removed.
Raises if any axis is present in both specs with different sizes
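The difference between concat_axes and union_axes can be modelled on plain shape dicts. The functions below are a hedged sketch of the documented behaviour, not the library's code: `toy_concat_axes` and `toy_union_axes` are hypothetical names, and only the dict form of an axis spec is handled.

```python
def toy_concat_axes(a1: dict, a2: dict) -> dict:
    """Join two shape dicts; any shared axis name is an error."""
    overlap = [name for name in a2 if name in a1]
    if overlap:
        raise ValueError(f"axes present in both specs: {overlap}")
    return {**a1, **a2}


def toy_union_axes(a1: dict, a2: dict) -> dict:
    """Join two shape dicts, dropping duplicates but rejecting size conflicts."""
    merged = dict(a1)
    for name, size in a2.items():
        if name in merged and merged[name] != size:
            raise ValueError(f"axis {name!r} has sizes {merged[name]} and {size}")
        # Keep the position of the first occurrence, as in a plain concatenation.
        merged.setdefault(name, size)
    return merged


left = {"batch": 32, "feature": 128}
right = {"feature": 128, "head": 8}
union = toy_union_axes(left, right)
disjoint = toy_concat_axes(left, {"head": 8})
```

Here the union keeps batch, feature, head in that order, while concatenating `left` and `right` directly would raise because feature appears in both.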
intersect_axes(ax1: AxisSelection, ax2: AxisSelection) -> AxisSelection
¤
intersect_axes(ax1: ShapeDict, ax2: AxisSelection) -> ShapeDict
intersect_axes(ax1: tuple[AxisSelector, ...], ax2: AxisSpec) -> tuple[Axis, ...]
intersect_axes(ax1: tuple[AxisSelector, ...], ax2: AxisSelection) -> tuple[AxisSelector, ...]
intersect_axes(ax1: AxisSpec, ax2: AxisSelection) -> AxisSpec
Returns a tuple of axes that are present in both ax1 and ax2. The returned order is the same as ax1.
eliminate_axes(axis_spec: AxisSelection, to_remove: AxisSelection) -> AxisSelection
¤
eliminate_axes(axis_spec: Axis | Sequence[Axis], axes: AxisSelection) -> tuple[Axis, ...]
eliminate_axes(axis_spec: ShapeDict, axes: AxisSelection) -> ShapeDict
eliminate_axes(axis_spec: AxisSelection, axes: AxisSelection) -> AxisSelection
eliminate_axes(axis_spec: PartialShapeDict, axes: AxisSelection) -> PartialShapeDict
Returns a new axis spec that is the same as the original, but without any axes in to_remove. Raises if any axis in to_remove is not present in axis_spec.
without_axes(axis_spec: AxisSelection, to_remove: AxisSelection, allow_mismatched_sizes=False) -> AxisSelection
¤
without_axes(axis_spec: ShapeDict, to_remove: AxisSelection, allow_mismatched_sizes=False) -> ShapeDict
without_axes(axis_spec: Sequence[Axis], to_remove: AxisSelection, allow_mismatched_sizes=False) -> tuple[Axis, ...]
without_axes(axis_spec: AxisSpec, to_remove: AxisSelection, allow_mismatched_sizes=False) -> AxisSpec
without_axes(axis_spec: AxisSelection, to_remove: AxisSelection, allow_mismatched_sizes=False) -> AxisSelection
without_axes(axis_spec: Sequence[AxisSelector], to_remove: AxisSelection, allow_mismatched_sizes=False) -> tuple[AxisSpec, ...]
without_axes(axis_spec: PartialShapeDict, to_remove: AxisSelection, allow_mismatched_sizes=False) -> PartialShapeDict
As eliminate_axes, but does not raise if any axis in to_remove is not present in axis_spec.
However, this does raise if any axis in to_remove is present in axis_spec with a different size.
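As a rough illustration of how eliminate_axes and without_axes differ, here is a plain-Python model over shape dicts. The names `toy_eliminate_axes` and `toy_without_axes` are hypothetical, to_remove is assumed to be a partial shape dict (name to size or None), and the allow_mismatched_sizes flag is deliberately not modelled.

```python
def toy_without_axes(axis_spec: dict, to_remove: dict) -> dict:
    """Drop the named axes; absent names are ignored, size conflicts raise."""
    for name, size in to_remove.items():
        if name in axis_spec and size is not None and axis_spec[name] != size:
            raise ValueError(
                f"axis {name!r} has size {axis_spec[name]}, expected {size}"
            )
    return {name: size for name, size in axis_spec.items() if name not in to_remove}


def toy_eliminate_axes(axis_spec: dict, to_remove: dict) -> dict:
    """Like toy_without_axes, but every axis to remove must be present."""
    missing = [name for name in to_remove if name not in axis_spec]
    if missing:
        raise ValueError(f"axes not present in axis_spec: {missing}")
    return toy_without_axes(axis_spec, to_remove)


shape = {"batch": 32, "feature": 128}
lenient = toy_without_axes(shape, {"feature": None, "head": None})
strict = toy_eliminate_axes(shape, {"feature": 128})
```

Both calls leave only the batch axis; passing the head axis to the strict variant would raise instead, because head is not part of the spec.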
selects_axis(selector: AxisSelection, selected: AxisSelection) -> bool
¤
Returns true if the selector has every axis in selected and, if dims are given, that they match
is_axis_compatible(ax1: AxisSelector, ax2: AxisSelector)
¤
Returns true if the two axes are compatible, meaning they have the same name and, if both are Axis, the same size
Array Creation¤
named(a, axis: AxisSelection) -> NamedArray
¤
Creates a NamedArray from a numpy array and a list of axes.
zeros(shape: AxisSpec, dtype: DTypeLike | None = None) -> NamedArray
¤
Creates a NamedArray with all elements set to 0
ones(shape: AxisSpec, dtype: DTypeLike | None = None) -> NamedArray
¤
Creates a NamedArray with all elements set to 1
full(shape: AxisSpec, fill_value: T, dtype: DTypeLike | None = None) -> NamedArray
¤
Creates a NamedArray with all elements set to fill_value
zeros_like(a: NamedArray, dtype=None) -> NamedArray
¤
Creates a NamedArray with all elements set to 0
ones_like(a: NamedArray, dtype=None) -> NamedArray
¤
Creates a NamedArray with all elements set to 1
full_like(a: NamedArray, fill_value: T, dtype: DTypeLike | None = None) -> NamedArray
¤
Creates a NamedArray with all elements set to fill_value
arange(axis: AxisSpec, *, start=0, step=1, dtype: DTypeLike | None = None) -> NamedArray
¤
Version of jnp.arange that returns a NamedArray.
This version differs from jnp.arange (beyond the obvious NamedArray) in two ways:
1) It can work with a start that is a tracer (i.e. a JAX expression), whereas jax arange is not able to handle tracers.
2) Axis can be more than one axis, in which case it's equivalent to arange of the product of sizes, followed by reshape.
Examples
X, Y = hax.make_axes(X=3, Y=4)
# Create a NamedArray along a single axis
arr = hax.arange(X) # equivalent to jnp.arange(0, 3, 1)
# 2D
arr = hax.arange((X, Y)) # equivalent to jnp.arange(0, 12, 1).reshape(3, 4)
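The multi-axis case can be mimicked without JAX to show what "arange of the product of sizes, followed by reshape" means. This is a plain-Python sketch using nested lists; `toy_arange` is a hypothetical helper, it assumes row-major order and at least one axis, and it does not model tracers.

```python
from math import prod


def toy_arange(shape: dict, start=0, step=1):
    """Fill nested lists with start, start + step, ... in row-major order."""
    sizes = list(shape.values())
    flat = [start + step * i for i in range(prod(sizes))]

    def reshape(values, dims):
        # Split the flat run into dims[0] equal chunks and recurse on each one.
        if len(dims) == 1:
            return list(values)
        chunk = len(values) // dims[0]
        return [
            reshape(values[i * chunk:(i + 1) * chunk], dims[1:])
            for i in range(dims[0])
        ]

    return reshape(flat, sizes)


grid = toy_arange({"X": 3, "Y": 4})
offset = toy_arange({"X": 3}, start=10, step=2)
```

`grid` holds three rows of four consecutive integers, mirroring the reshape(3, 4) in the example above.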
linspace(axis: AxisSelector, *, start: float, stop: float, endpoint: bool = True, dtype: DTypeLike | None = None) -> NamedArray
¤
Version of jnp.linspace that returns a NamedArray.
If axis is a string, the default number of samples (50, per numpy) will be used.
logspace(axis: AxisSelector, *, start: float, stop: float, endpoint: bool = True, base: float = 10.0, dtype: DTypeLike | None = None) -> NamedArray
¤
Version of jnp.logspace that returns a NamedArray.
If axis is a string, the default number of samples (50, per numpy) will be used.
geomspace(axis: AxisSelector, *, start: float, stop: float, endpoint: bool = True, dtype: DTypeLike | None = None) -> NamedArray
¤
Version of jnp.geomspace that returns a NamedArray.
If axis is a string, the default number of samples (50, per numpy) will be used.
Combining Arrays¤
concatenate(axis: AxisSelector, arrays: Sequence[NamedArray]) -> NamedArray
¤
Version of jax.numpy.concatenate that returns a NamedArray. The returned array will have the same axis names in the same order as the first array, with the size of the selected axis equal to the sum of that axis's sizes across the concatenated arrays.
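The shape rule for concatenate can be written down as a small plain-Python function over shape dicts. It is only a model of the resulting axes, under the assumption that every input has the selected axis; `toy_concatenated_axes` is a hypothetical name and no array data is involved.

```python
def toy_concatenated_axes(axis: str, shapes: list) -> dict:
    """Return the name -> size mapping of the concatenated result."""
    if not shapes:
        raise ValueError("need at least one array shape")
    for shape in shapes:
        if axis not in shape:
            raise ValueError(f"axis {axis!r} missing from {shape}")
    # Axis names and order come from the first array.
    result = dict(shapes[0])
    # The selected axis grows to the total of its sizes across all inputs.
    result[axis] = sum(shape[axis] for shape in shapes)
    return result


result = toy_concatenated_axes(
    "batch",
    [{"batch": 2, "feature": 4}, {"feature": 4, "batch": 3}],
)
```

The second input lists its axes in a different order, yet the result follows the first input: batch (now size 5) and then feature.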
stack(axis: AxisSelector, arrays: Sequence[NamedArray]) -> NamedArray
¤
Version of jax.numpy.stack that returns a NamedArray
(We don't include hstack or vstack because they are subsumed by stack.)
Array Manipulation¤
Broadcasting¤
See also the section on Broadcasting.
broadcast_axis(a: NamedArray, axis: AxisSpec) -> NamedArray
¤
Broadcasts a, ensuring that it has all the axes in axis.
broadcast_axis is an alias for broadcast_to(a, axis, enforce_no_extra_axes=False, ensure_order=True)
You typically use this function when you want to broadcast an array to a common set of axes.
broadcast_to(a: NamedOrNumeric, axes: AxisSpec, ensure_order: bool = True, enforce_no_extra_axes: bool = True) -> NamedArray
¤
Broadcast a so that it has the given axes.
If ensure_order is True (default) then the returned array's axes are
arranged in the same order as axes. Otherwise existing axes may remain in
their current order, though they may still be moved to the front if new axes
are added.
If enforce_no_extra_axes is True and a has axes that are not in
axes then a ValueError is raised.
Slicing¤
See also the section on Indexing and Slicing.
index(array: NamedArray, slices: Mapping[AxisSelector, NamedIndex]) -> NamedArray
¤
Selects elements from an array along an axis via index or another named array.
This function is typically invoked using array[...] syntax. For instance,
you might use array[{"batch": slice(0, 10)}] or array["batch", 0:10] to select the first 10 elements
of the 'batch' axis.
When indexing with a dslice the slice is gathered starting at the given
start for size elements. Values read past the end of the array are
filled with the fill_value (defaults to 0), and writes outside the
bounds are dropped.
See Also
- haliax.NamedArray.at for a functional equivalent of in-place array modifications.
Returns:
- NamedArray or jnp.ndarray – A NamedArray is returned if there are any axes remaining after selection; otherwise, when all axes are indexed out, a scalar (0-dimensional) jnp.ndarray is returned.
slice(array: NamedArray, *args, **kwargs) -> NamedArray
¤
slice(array: NamedArray, axis: AxisSelector, new_axis: AxisSelector | None = None, start: int = 0, length: int | None = None) -> NamedArray
slice(array: NamedArray, start: Mapping[AxisSelector, IntScalar], length: Mapping[AxisSelector, int] | None = None) -> NamedArray
Slices the array along the specified axis or axes, replacing them with new axes (or a shortened version of the old one)
This method has two signatures:
- slice(array, axis, new_axis=None, start=0, length=None)
- slice(array, start: Mapping[AxisSelector, IntScalar], length: Mapping[AxisSelector, int])
They both do similar things. The former slices an array along a single axis, replacing it with a new axis. The latter slices an array along multiple axes, replacing them with new axes.
take(array: NamedArray, axis: AxisSelector, index: int | NamedArray) -> NamedArray
¤
Selects elements from an array along an axis, by an index or by another named array.
If index is a NamedArray, then its axes are added to the output array.
updated_slice(array: NamedArray, start: Mapping[AxisSelector, int | jnp.ndarray | NamedArray], update: NamedArray) -> NamedArray
¤
Updates a slice of an array with another array.
Parameters:
- array (NamedArray) – The array to update.
- start (Mapping[AxisSelector, int | ndarray]) – The starting index of each axis to update.
- update (NamedArray) – The array to update with.
Returns:
- NamedArray – The updated array.
Mutable References¤
See also the section on Mutable References.
JAX provides jax.Ref, a mutable array reference that can be read or written in place while remaining compatible
with transformations such as jax.jit or jax.grad. Haliax mirrors that API with haliax.NamedRef,
which carries axis metadata so you can keep using named indexing when plumbing state through your programs.
You introduce a new reference with haliax.new_ref. The returned object behaves much like a
NamedArray for indexing purposes: ref[{"batch": 0}] reads a slice, and assignments like
ref[{"token": slice(1, 3)}] = update perform in-place updates on the underlying buffer. If you need to stage part of a
reference for repeated use, call NamedRef.slice to create a slice ref. Slice refs remember a
partial indexing expression so you only supply the remaining axes during reads or writes:
Cache = hax.Axis("layers", 24)
Head = hax.Axis("head", 8)
cache = hax.zeros((Cache, Head))
cache_ref = hax.new_ref(cache)
# Stage a slice of the layer axis once; later indices along "layers" are relative to that slice.
layer_ref = cache_ref.slice({"layers": slice(4, 8)})
layer_ref[{"layers": 0, "head": 3}] = 1.0 # updates layer 4, head 3 in the original cache
When you are done mutating a reference, call haliax.freeze to invalidate it and recover a final
NamedArray snapshot. You can also perform atomic-style updates with haliax.swap (or the functional
helpers under haliax.ref).
NamedRef(_ref: jax.Ref, _axes: tuple[Axis, ...], _prefix: tuple[Any, ...])
dataclass
¤
Named wrapper around jax.Ref that preserves axis metadata.
Methods:
- value – Materialize this reference view as a NamedArray.
- unsliced – Return a view of the original reference without staged selectors.
- resolve_axis – Resolve an axis selector to the corresponding axis in the current view.
- slice – Return a new view with the provided selector staged for future operations.
- unsafe_buffer_pointer
Attributes:
- dtype – Return the dtype of the underlying reference.
- axes (tuple[Axis, ...]) – Axes visible from this view after applying staged selectors.
- shape (Mapping[str, int]) – Mapping from axis name to size for the current view.
- named_shape (Mapping[str, int])
- ndim (int) – Number of axes in the current view.
dtype
property
¤
Return the dtype of the underlying reference.
axes: tuple[Axis, ...]
property
¤
Axes visible from this view after applying staged selectors.
shape: Mapping[str, int]
property
¤
Mapping from axis name to size for the current view.
named_shape: Mapping[str, int]
property
¤
ndim: int
property
¤
Number of axes in the current view.
value() -> NamedArray
¤
Materialize this reference view as a NamedArray.
unsliced() -> 'NamedRef'
¤
Return a view of the original reference without staged selectors.
resolve_axis(axis: AxisSelector) -> Axis
¤
Resolve an axis selector to the corresponding axis in the current view.
slice(selector: Mapping[AxisSelector, Any]) -> 'NamedRef'
¤
Return a new view with the provided selector staged for future operations.
unsafe_buffer_pointer()
¤
new_ref(value: NamedArray) -> NamedRef
¤
Construct a NamedRef from a NamedArray.
freeze(ref: NamedRef) -> NamedArray
¤
Freeze the reference and return its current contents.
swap(ref: NamedRef, idx: SliceSpec | EllipsisType, value: NamedOrNumeric) -> NamedArray
¤
Swap the value at idx, returning the previous contents as a NamedArray.
Dynamic Slicing¤
dslice(start: int, length: int | Axis)
¤
Bases: Module
Dynamic slice, comprising a (start, length) pair. Also aliased as ds.
NumPy-style slices like a[i:i+16] don't work inside jax.jit, because
JAX requires slice bounds to be static. dslice works around this by
separating the dynamic start from the static size so that you can
write a[dslice(i, 16)] or simply a[ds(i, 16)].
When used in indexing or at updates, dslice behaves like a gather of
size elements starting at start. Reads beyond the end of the array are
filled with a value (0 by default) and writes outside the array bounds are
dropped, matching JAX's default scatter/gather semantics.
This class's name is taken from jax.experimental.pallas.
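The out-of-bounds rules for dslice reads and writes can be illustrated on a plain Python list. This is a sketch of the described semantics only, not of how JAX or Haliax implement them: `toy_dslice_read` and `toy_dslice_write` are hypothetical helpers, and treating negative positions as out of bounds is an assumption of the sketch.

```python
def toy_dslice_read(values, start, size, fill_value=0):
    """Gather `size` elements from `start`, padding past the end with fill_value."""
    return [
        values[i] if 0 <= i < len(values) else fill_value
        for i in range(start, start + size)
    ]


def toy_dslice_write(values, start, update):
    """Return a copy with `update` written at `start`; out-of-bounds writes are dropped."""
    result = list(values)
    for offset, value in enumerate(update):
        position = start + offset
        if 0 <= position < len(result):
            result[position] = value
    return result


data = [1, 2, 3, 4, 5]
window = toy_dslice_read(data, start=3, size=4)
written = toy_dslice_write(data, start=3, update=[9, 9, 9, 9])
```

The read always returns exactly `size` elements, which is what lets the length stay static while the start varies; the write silently ignores the two positions that fall past the end.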
dblock(idx: int, size: int) -> dslice
¤
Returns a dslice that selects a single block of size size starting at idx
ds: typing.TypeAlias = dslice
module-attribute
¤
Shape Manipulation¤
flatten(array: NamedArray, new_axis_name: AxisSelector) -> NamedArray
¤
Returns a flattened view of the array, with all axes merged into one. Alias for haliax.ravel
flatten_axes(*args, **kwargs) -> NamedArray | AxisSpec
¤
flatten_axes(axis: Axis, old_axes: Axis, new_axis: AxisSelector) -> Axis
flatten_axes(axis: AxisSpec, new_axis: AxisSelector) -> Axis
flatten_axes(axis: AxisSpec, old_axes: AxisSelection, new_axis: AxisSelector) -> AxisSpec
flatten_axes(array: NamedArray, old_axes: AxisSelection, new_axis: AxisSelector) -> NamedArray
Merge a sequence of axes into a single axis. The new axis must have the same size as the product of the old axes.
The new axis is always inserted starting at the index of the first old axis in the underlying array.
This function can be used in two ways:
- flatten_axes(array, old_axes, new_axis): merge the old axes of the array into a new axis
- flatten_axes(axes, old_axes, new_axis): merge the old axes into a new axis
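The size and placement rules for the merged axis can be sketched on a shape dict. This is a plain-Python model rather than the library routine: `toy_flatten_axes` is a hypothetical name, and it assumes that "first old axis" means the first one encountered in the array's own axis order.

```python
from math import prod


def toy_flatten_axes(shape: dict, old_axes, new_axis: str) -> dict:
    """Merge old_axes into new_axis, placed where the first old axis sat."""
    missing = [name for name in old_axes if name not in shape]
    if missing:
        raise ValueError(f"axes not found: {missing}")
    # The merged axis is as large as the product of the axes it replaces.
    merged_size = prod(shape[name] for name in old_axes)
    result = {}
    inserted = False
    for name, size in shape.items():
        if name in old_axes:
            if not inserted:
                result[new_axis] = merged_size
                inserted = True
        else:
            result[name] = size
    return result


image = {"B": 8, "H": 32, "W": 32, "C": 3}
flat = toy_flatten_axes(image, ("H", "W", "C"), "E")
```

Merging H, W, and C leaves B in front and yields a single E axis of size 32 * 32 * 3.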
rearrange(array: NamedArray, *args, **kwargs) -> NamedArray
¤
rearrange(array: NamedArray, axes: Sequence[AxisSelector | EllipsisType]) -> NamedArray
rearrange(array: NamedArray, expression: str, **bindings: AxisSelector | int) -> NamedArray
Rearrange a tensor according to an einops-style haliax rearrangement string or a sequence of axes. See full documentation here: Rearrange
The sequence form of rearrange rearranges an array so that its underlying storage conforms to axes.
axes may include up to 1 ellipsis, indicating that the remaining axes should be
permuted in the same order as the array's axes.
For example, if array has axes (a, b, c, d, e, f) and axes is (e, f, a, ..., c), then the ellipsis stands for the remaining axes b and d, and the output array will have axes (e, f, a, b, d, c).
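How the ellipsis expands in the sequence form can be shown with axis names alone. The function below is a hedged plain-Python model, not Haliax's implementation: `toy_rearrange_axes` is a hypothetical name, and the specific errors it raises are assumptions.

```python
def toy_rearrange_axes(array_axes, spec):
    """Return the output axis order for a sequence spec with at most one ellipsis."""
    if sum(1 for item in spec if item is Ellipsis) > 1:
        raise ValueError("at most one ellipsis is allowed")
    named = [item for item in spec if item is not Ellipsis]
    unknown = [name for name in named if name not in array_axes]
    if unknown:
        raise ValueError(f"axes not in array: {unknown}")
    # Axes not mentioned explicitly keep their original relative order.
    rest = [name for name in array_axes if name not in named]
    if rest and all(item is not Ellipsis for item in spec):
        raise ValueError(f"axes not covered by the spec: {rest}")
    output = []
    for item in spec:
        if item is Ellipsis:
            output.extend(rest)
        else:
            output.append(item)
    return tuple(output)


channels_first = toy_rearrange_axes(("B", "H", "W", "C"), ("C", ...))
batch_last = toy_rearrange_axes(("B", "H", "W", "C"), ("W", ..., "B"))
```

In the second call the ellipsis sits in the middle, so the unmentioned axes H and C land between W and B.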
The string form of rearrange works similarly to einops.rearrange, but also supports named axes and unordered matching.
The string form of rearrange comes in two forms:
- Ordered strings are like einops strings, with the only significant difference that flattened axes are named with a colon, e.g. B H W C -> B (E: H W C).
- Unordered strings match axes by name and are marked with the addition of curly braces, e.g. {H W C} -> ... C H W or {H W C} -> ... (E: H W C)
As with einops, you can provide axis sizes to unflatten axes. For instance, hax.rearrange(x, '{ B (H: h1 H) (W: w1 W)} -> (B: B h1 w1) H W ...', H=32, W=32) will turn a batch of images into a batch of image patches. Bindings can also be haliax.Axis objects, or strings that will be used as the actual name of the resulting axis.
Examples:
>>> import haliax as hax
>>> import jax.random as jrandom
>>> B, H, W, C = hax.Axis("B", 8), hax.Axis("H", 32), hax.Axis("W", 32), hax.Axis("C", 3)
>>> x = hax.random.normal(jrandom.PRNGKey(0), (B, H, W, C))
>>> # Sequence-based rearrange
>>> hax.rearrange(x, (C, B, H, W))
>>> hax.rearrange(x, (C, ...)) # ellipsis means "keep the rest of the axes in the same order"
>>> # String-based rearrange
>>> # permute the axes
>>> hax.rearrange(x, "B H W C -> C B H W")
>>> # flatten the image (note the assignment of a new name to the flattened axis)
>>> hax.rearrange(x, "B H W C -> B (E: H W C)")
>>> # turn the image into patches
>>> hax.rearrange(x, "{ B (H: h1 H) (W: w1 W) C } -> (B: B h1 w1) (E: H W C) ...", H=2, W=2)
>>> # names can be longer than one character
>>> hax.rearrange(x, "{ B (H: h1 H) (W: w1 W) C } -> (B: B h1 w1) (embed: H W C) ...", H=2, W=2)
unbind(array: NamedArray, axis: AxisSelector) -> list[NamedArray]
¤
Unbind an array along an axis, returning a list of NamedArrays, one for each position on that axis. Analogous to torch.unbind or np.rollaxis
unflatten_axis(array: NamedArray, axis: AxisSelector, new_axes: AxisSpec) -> NamedArray
¤
Split an axis into a sequence of axes. The old axis must have the same size as the product of the new axes.
split(a: NamedArray, axis: AxisSelector, new_axes: Sequence[Axis]) -> Sequence[NamedArray]
¤
Splits an array along an axis into multiple arrays, one for each element of new_axes.
Parameters:
- a (NamedArray) – the array to split
- axis (AxisSelector) – the axis to split along
- new_axes (Sequence[Axis]) – the axes to split into. Must have the same total length as the axis being split.
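The length constraint on new_axes is easiest to see on a one-dimensional list. This is a plain-Python sketch of the idea, with `toy_split` as a hypothetical helper; it returns a dict keyed by the new axis names purely for readability, which is not the library's return type.

```python
def toy_split(values, new_axes: dict) -> dict:
    """Cut `values` into consecutive pieces, one per entry of new_axes."""
    total = sum(new_axes.values())
    if total != len(values):
        raise ValueError(
            f"new axes cover {total} elements, but the axis has {len(values)}"
        )
    pieces = {}
    offset = 0
    for name, size in new_axes.items():
        # Each piece starts where the previous one ended.
        pieces[name] = values[offset:offset + size]
        offset += size
    return pieces


pieces = toy_split([0, 1, 2, 3, 4], {"head": 2, "tail": 3})
```

The sizes 2 and 3 add up to the length of the axis being split, so the pieces cover it exactly with no overlap.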
ravel(array: NamedArray, new_axis_name: AxisSelector) -> NamedArray
¤
Returns a flattened view of the array, with all axes merged into one
Operations¤
Binary and unary operations are all more or less directly from JAX's NumPy API. The only difference is they operate on named arrays instead.
Matrix Multiplication¤
See also the page on Matrix Multiplication as well as the cheat sheet section.
dot(*arrays, precision: PrecisionLike = None, preferred_element_type: DTypeLike | None = None, out_axes: PartialAxisSpec | None = None, dot_general=jax.lax.dot_general, **kwargs) -> NamedArray
¤
dot(axis: AxisSelection | None, *arrays: NamedArray, precision: PrecisionLike = None, preferred_element_type: DTypeLike | None = None, out_axes: PartialAxisSpec | None = ..., dot_general=jax.lax.dot_general) -> NamedArray
dot(*arrays: NamedArray, axis: AxisSelection | None, precision: PrecisionLike = None, preferred_element_type: DTypeLike | None = None, out_axes: PartialAxisSpec | None = ..., dot_general=jax.lax.dot_general) -> NamedArray
Returns the tensor product of two NamedArrays. The axes axis are contracted over,
and any other axes that are shared between the arrays are batched over. Non-contracted Axes in one
that are not in the other are preserved.
Note that if axis is None, the result will be a scalar, not a NamedArray. The semantics of axis=None are
similar to e.g. how sum and other reduction functions work in numpy. If axis=(), then the result will be
an "outer product" of the arrays, i.e. a tensor with shape equal to the concatenation of the shapes of the arrays.
By default, the order of output axes is determined by the order of the input axes, such that each output axis occurs in the same order as it first occurs in the concatenation of the input axes.
If out_axes is provided, the output will be transposed to match the provided axes. out_axes may be a partial
specification of the output axes (using ellipses), in which case the output will be rearranged to be consistent
with the partial specification. For example, if out_axes=(..., Height, Width) and the output axes are
(Width, Height, Depth), the output will be transposed to (Depth, Height, Width). Multiple ellipses
are supported, in which case axes will be inserted according to a greedy heuristic that prefers to place
unconstrained axes as soon as all prior axes in the "natural" order are covered.
Parameters:
- *arrays (NamedArray, default: ()) – The arrays to contract.
- axis (AxisSelection) – The axes to contract over.
- precision (PrecisionLike, default: None) – The precision to use. Defaults to None. This argument is passed to jax.numpy.einsum, which in turn passes it to jax.lax.dot_general.
- preferred_element_type (DTypeLike, default: None) – The preferred element type of the result. Defaults to None. This argument is passed to jax.numpy.einsum.
- out_axes (PartialAxisSpec | None, default: None) – A potentially partial specification of the output axes. If provided, the output will be transposed to match the provided axes. Defaults to None.
Returns:
- NamedArray – The result of the contraction.
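The default ordering of output axes for dot (when out_axes is not given) can be modelled with axis names only. This plain-Python sketch is an interpretation of the rule stated above, not library code: `toy_dot_output_axes` is a hypothetical name, and out_axes handling is not modelled.

```python
def toy_dot_output_axes(contract, *array_axes):
    """Order of output axes when no out_axes is given."""
    if contract is None:
        # Contracting over everything leaves no axes at all.
        return ()
    output = []
    for axes in array_axes:
        for name in axes:
            # Keep each surviving axis at its first appearance across the inputs.
            if name not in contract and name not in output:
                output.append(name)
    return tuple(output)


matmul = toy_dot_output_axes(("D",), ("H", "W", "D"), ("D", "W", "C"))
outer = toy_dot_output_axes((), ("A",), ("B",))
everything = toy_dot_output_axes(None, ("H", "D"), ("D", "H"))
```

In the first call D is contracted, W is shared and therefore batched over, and the remaining axes appear in first-seen order.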
einsum(equation: str, *arrays: NamedArray, precision: PrecisionLike = None, preferred_element_type: DTypeLike | None = None, _dot_general: DotGeneralOp = jax.lax.dot_general, **axis_aliases: AxisSelector) -> NamedArray
¤
Compute the tensor contraction of the input arrays according to Haliax's named variant of the Einstein summation convention.
Examples:
>>> # normal einsum
>>> import haliax as hax
>>> H = hax.Axis("H", 32)
>>> W = hax.Axis("W", 32)
>>> D = hax.Axis("D", 64)
>>> a = hax.zeros((H, W, D))
>>> b = hax.zeros((D, W, H))
>>> hax.einsum("h w d, d w h -> h w", a, b)
>>> # named einsum
>>> hax.einsum("{H W D} -> H W", a, b)
>>> hax.einsum("{D} -> ", a, b) # same as the previous example
>>> hax.einsum("-> H W", a, b) # same as the first example
>>> # axis aliases, useful for generic code
>>> hax.einsum("{x y} -> y", a, b, x=H, y=W)
Parameters:
- equation (str) – The einsum equation.
- arrays (NamedArray, default: ()) – The input arrays.
- precision (PrecisionLike, default: None) – The precision of the computation.
- preferred_element_type (DTypeLike | None, default: None) – The preferred element type of the computation.
- _dot_general (DotGeneralOp, default: dot_general) – The dot_general function to use.
- axis_aliases (AxisSelector, default: {}) – The axis aliases to use.
Returns:
- NamedArray – The result of the einsum.
Reductions¤
Reduction operations are things like haliax.sum and haliax.mean that reduce an array along one or more axes. Except for haliax.argmin and haliax.argmax, they all have the form:
def sum(x, axis: Optional[AxisSelection] = None, where: Optional[NamedArray] = None) -> haliax.NamedArray:
...
with the behavior closely following that of JAX's NumPy API. The axis argument can
be a single axis (or axis name), a tuple of axes, or None to reduce all axes. The where argument is a boolean array
that specifies which elements to include in the reduction. It must be broadcastable to the input array, using
Haliax's broadcasting rules.
The result of a reduction operation is always a haliax.NamedArray with the reduced axes removed. If you reduce all axes, the result is a NamedArray with 0 axes, i.e. a scalar. You can convert it to a jax.numpy.ndarray with haliax.NamedArray.scalar, or just haliax.NamedArray.array.
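To show how a named reduction removes axes, here is a small plain-Python model of a sum. It is not how Haliax computes anything: `toy_named_sum` is a hypothetical helper, the array is represented as a dict from coordinate tuples to values, and the where argument is not modelled.

```python
def toy_named_sum(axes: dict, data: dict, axis=None):
    """Sum `data` over the named axes; returns (remaining axes, reduced data)."""
    names = list(axes)
    if axis is None:
        reduced = set(names)  # None reduces over every axis
    elif isinstance(axis, str):
        reduced = {axis}
    else:
        reduced = set(axis)
    unknown = reduced - set(names)
    if unknown:
        raise ValueError(f"axes not found: {sorted(unknown)}")
    keep = [name for name in names if name not in reduced]
    result = {}
    for coords, value in data.items():
        # Drop the coordinates of reduced axes and accumulate on what is left.
        key = tuple(c for name, c in zip(names, coords) if name in keep)
        result[key] = result.get(key, 0) + value
    return {name: axes[name] for name in keep}, result


axes = {"batch": 2, "feature": 3}
data = {(b, f): 10 * b + f for b in range(2) for f in range(3)}
per_batch = toy_named_sum(axes, data, axis="feature")
total = toy_named_sum(axes, data)
```

Reducing over feature leaves only the batch axis, and reducing with axis=None leaves an empty axis mapping with a single value, the analogue of a 0-axis NamedArray.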
all(array: NamedArray, axis: AxisSelection | None = None, *, where: NamedArray | None = None) -> NamedArray
¤
Named version of jax.numpy.all.
amax(array: NamedArray, axis: AxisSelection | None = None, *, where: NamedArray | None = None) -> NamedArray
¤
Alias for max. See max for details.
amin(array: NamedArray, axis: AxisSelection | None = None, *, where: NamedArray | None = None) -> NamedArray
¤
Alias for min. See min for details.
any(array: NamedArray, axis: AxisSelection | None = None, *, where: NamedArray | None = None) -> NamedArray
¤
True if any elements along a given axis or axes are True. If axis is None, tests whether any element in the whole array is True.
argmax(array: NamedArray, axis: AxisSelector | None) -> NamedArray
¤
argmin(array: NamedArray, axis: AxisSelector | None) -> NamedArray
¤
max(array: NamedArray, axis: AxisSelection | None = None, *, where: NamedArray | None = None) -> NamedArray
¤
mean(array: NamedArray, axis: AxisSelection | None = None, *, where: NamedArray | None = None, dtype: DTypeLike | None = None) -> NamedArray
¤
min(array: NamedArray, axis: AxisSelection | None = None, *, where: NamedArray | None = None) -> NamedArray
¤
nanargmax(array: NamedArray, axis: AxisSelector | None = None) -> NamedArray
¤
nanargmin(array: NamedArray, axis: AxisSelector | None = None) -> NamedArray
¤
nanmax(array: NamedArray, axis: AxisSelection | None = None, *, where: NamedArray | None = None) -> NamedArray
¤
nanmean(array: NamedArray, axis: AxisSelection | None = None, *, where: NamedArray | None = None, dtype: DTypeLike | None = None) -> NamedArray
¤
nanmin(array: NamedArray, axis: AxisSelection | None = None, *, where: NamedArray | None = None) -> NamedArray
¤
nanprod(array: NamedArray, axis: AxisSelection | None = None, *, where: NamedArray | None = None, dtype: DTypeLike | None = None) -> NamedArray
¤
nanstd(array: NamedArray, axis: AxisSelection | None = None, *, where: NamedArray | None = None, ddof: int = 0, dtype: DTypeLike | None = None) -> NamedArray
¤
nansum(array: NamedArray, axis: AxisSelection | None = None, *, where: NamedArray | None = None, dtype: DTypeLike | None = None) -> NamedArray
¤
nanvar(array: NamedArray, axis: AxisSelection | None = None, *, where: NamedArray | None = None, ddof: int = 0, dtype: DTypeLike | None = None) -> NamedArray
¤
prod(array: NamedArray, axis: AxisSelection | None = None, *, where: NamedArray | None = None, dtype: DTypeLike | None = None) -> NamedArray
¤
ptp(array: NamedArray, axis: AxisSelection | None = None, *, where: NamedArray | None = None) -> NamedArray
¤
std(array: NamedArray, axis: AxisSelection | None = None, *, where: NamedArray | None = None, ddof: int = 0, dtype: DTypeLike | None = None) -> NamedArray
¤
sum(array: NamedArray, axis: AxisSelection | None = None, *, where: NamedArray | None = None, dtype: DTypeLike | None = None) -> NamedArray
¤
var(array: NamedArray, axis: AxisSelection | None = None, *, where: NamedArray | None = None, ddof: int = 0, dtype: DTypeLike | None = None) -> NamedArray
¤
Axis-wise Operations¤
Axis-wise operations are things like haliax.cumsum and haliax.sort that operate on a single axis of an array but don't reduce it.
cumsum(a: NamedArray, axis: AxisSelector, *, dtype: DTypeLike | None = None) -> NamedArray
¤
Named version of jax.numpy.cumsum
cumprod(a: NamedArray, axis: AxisSelector, dtype: DTypeLike | None = None) -> NamedArray
¤
Named version of jax.numpy.cumprod
nancumprod(a: NamedArray, axis: AxisSelector, dtype: DTypeLike | None = None) -> NamedArray
¤
Named version of jax.numpy.nancumprod
nancumsum(a: NamedArray, axis: AxisSelector, *, dtype: DTypeLike | None = None) -> NamedArray
¤
Named version of jax.numpy.nancumsum
sort(a: NamedArray, axis: AxisSelector) -> NamedArray
¤
Named version of jax.numpy.sort
argsort(a: NamedArray, axis: AxisSelector | None, *, stable: bool = False) -> NamedArray
¤
Named version of jax.numpy.argsort.
If axis is None, the returned array will be a 1D array of indices that would sort the flattened array,
identical to jax.numpy.argsort(a.array).
Unary Operations¤
The A in these operations means haliax.NamedArray, a Scalar, or jax.numpy.ndarray.
These are all more or less directly from JAX's NumPy API.
abs(a: A) -> A
¤
absolute(a: A) -> A
¤
angle(a: A) -> A
¤
arccos(a: A) -> A
¤
arccosh(a: A) -> A
¤
arcsin(a: A) -> A
¤
arcsinh(a: A) -> A
¤
arctan(a: A) -> A
¤
arctanh(a: A) -> A
¤
around(a: A) -> A
¤
bitwise_count(a: A) -> A
¤
bitwise_invert(a: A) -> A
¤
bitwise_not(a: A) -> A
¤
cbrt(a: A) -> A
¤
ceil(a: A) -> A
¤
conj(a: A) -> A
¤
conjugate(a: A) -> A
¤
copy(a: A) -> A
¤
cos(a: A) -> A
¤
cosh(a: A) -> A
¤
deg2rad(a: A) -> A
¤
degrees(a: A) -> A
¤
exp(a: A) -> A
¤
exp2(a: A) -> A
¤
expm1(a: A) -> A
¤
fabs(a: A) -> A
¤
fix(a: A) -> A
¤
floor(a: A) -> A
¤
frexp(a: A) -> A
¤
i0(a: A) -> A
¤
imag(a: A) -> A
¤
invert(a: A) -> A
¤
iscomplex(a: A) -> A
¤
isfinite(a: A) -> A
¤
isinf(a: A) -> A
¤
isnan(a: A) -> A
¤
isneginf(a: A) -> A
¤
isposinf(a: A) -> A
¤
isreal(a: A) -> A
¤
log(a: A) -> A
¤
log10(a: A) -> A
¤
log1p(a: A) -> A
¤
log2(a: A) -> A
¤
logical_not(a: A) -> A
¤
ndim(a: A) -> A
¤
negative(a: A) -> A
¤
positive(a: A) -> A
¤
rad2deg(a: A) -> A
¤
radians(a: A) -> A
¤
real(a: A) -> A
¤
reciprocal(a: A) -> A
¤
rint(a: A) -> A
¤
round(a: A, decimals: int = 0) -> A
¤
rsqrt(a: A) -> A
¤
sign(a: A) -> A
¤
signbit(a: A) -> A
¤
sin(a: A) -> A
¤
sinc(a: A) -> A
¤
sinh(a: A) -> A
¤
square(a: A) -> A
¤
sqrt(a: A) -> A
¤
tan(a: A) -> A
¤
tanh(a: A) -> A
¤
trunc(a: A) -> A
¤
Binary Operations¤
add(x1: NamedOrNumeric, x2: NamedOrNumeric) -> NamedOrNumeric
¤
Named version of jax.numpy.add
arctan2(x1: NamedOrNumeric, x2: NamedOrNumeric) -> NamedOrNumeric
¤
Named version of jax.numpy.arctan2
bitwise_and(x1: NamedOrNumeric, x2: NamedOrNumeric) -> NamedOrNumeric
¤
Named version of jax.numpy.bitwise_and
bitwise_left_shift(x1: NamedOrNumeric, x2: NamedOrNumeric) -> NamedOrNumeric
¤
Named version of jax.numpy.bitwise_left_shift
bitwise_or(x1: NamedOrNumeric, x2: NamedOrNumeric) -> NamedOrNumeric
¤
Named version of jax.numpy.bitwise_or
bitwise_right_shift(x1: NamedOrNumeric, x2: NamedOrNumeric) -> NamedOrNumeric
¤
Named version of jax.numpy.bitwise_right_shift
bitwise_xor(x1: NamedOrNumeric, x2: NamedOrNumeric) -> NamedOrNumeric
¤
Named version of jax.numpy.bitwise_xor
divide(x1: NamedOrNumeric, x2: NamedOrNumeric) -> NamedOrNumeric
¤
Named version of jax.numpy.divide
divmod(x1: NamedOrNumeric, x2: NamedOrNumeric) -> NamedOrNumeric
¤
Named version of jax.numpy.divmod
equal(x1: NamedOrNumeric, x2: NamedOrNumeric) -> NamedOrNumeric
¤
Named version of jax.numpy.equal
float_power(x1: NamedOrNumeric, x2: NamedOrNumeric) -> NamedOrNumeric
¤
Named version of jax.numpy.float_power
floor_divide(x1: NamedOrNumeric, x2: NamedOrNumeric) -> NamedOrNumeric
¤
Named version of jax.numpy.floor_divide
fmax(x1: NamedOrNumeric, x2: NamedOrNumeric) -> NamedOrNumeric
¤
Named version of jax.numpy.fmax
fmin(x1: NamedOrNumeric, x2: NamedOrNumeric) -> NamedOrNumeric
¤
Named version of jax.numpy.fmin
fmod(x1: NamedOrNumeric, x2: NamedOrNumeric) -> NamedOrNumeric
¤
Named version of jax.numpy.fmod
greater(x1: NamedOrNumeric, x2: NamedOrNumeric) -> NamedOrNumeric
¤
Named version of jax.numpy.greater
greater_equal(x1: NamedOrNumeric, x2: NamedOrNumeric) -> NamedOrNumeric
¤
Named version of jax.numpy.greater_equal
hypot(x1: NamedOrNumeric, x2: NamedOrNumeric) -> NamedOrNumeric
¤
Named version of jax.numpy.hypot
left_shift(x1: NamedOrNumeric, x2: NamedOrNumeric) -> NamedOrNumeric
¤
Named version of jax.numpy.left_shift
less(x1: NamedOrNumeric, x2: NamedOrNumeric) -> NamedOrNumeric
¤
Named version of jax.numpy.less
less_equal(x1: NamedOrNumeric, x2: NamedOrNumeric) -> NamedOrNumeric
¤
Named version of jax.numpy.less_equal
logaddexp(x1: NamedOrNumeric, x2: NamedOrNumeric) -> NamedOrNumeric
¤
Named version of jax.numpy.logaddexp
logaddexp2(x1: NamedOrNumeric, x2: NamedOrNumeric) -> NamedOrNumeric
¤
logical_and(x1: NamedOrNumeric, x2: NamedOrNumeric) -> NamedOrNumeric
¤
logical_or(x1: NamedOrNumeric, x2: NamedOrNumeric) -> NamedOrNumeric
¤
logical_xor(x1: NamedOrNumeric, x2: NamedOrNumeric) -> NamedOrNumeric
¤
maximum(x1: NamedOrNumeric, x2: NamedOrNumeric) -> NamedOrNumeric
¤
minimum(x1: NamedOrNumeric, x2: NamedOrNumeric) -> NamedOrNumeric
¤
mod(x1: NamedOrNumeric, x2: NamedOrNumeric) -> NamedOrNumeric
¤
multiply(x1: NamedOrNumeric, x2: NamedOrNumeric) -> NamedOrNumeric
¤
nextafter(x1: NamedOrNumeric, x2: NamedOrNumeric) -> NamedOrNumeric
¤
not_equal(x1: NamedOrNumeric, x2: NamedOrNumeric) -> NamedOrNumeric
¤
power(x1: NamedOrNumeric, x2: NamedOrNumeric) -> NamedOrNumeric
¤
remainder(x1: NamedOrNumeric, x2: NamedOrNumeric) -> NamedOrNumeric
¤
right_shift(x1: NamedOrNumeric, x2: NamedOrNumeric) -> NamedOrNumeric
¤
subtract(x1: NamedOrNumeric, x2: NamedOrNumeric) -> NamedOrNumeric
¤
true_divide(x1: NamedOrNumeric, x2: NamedOrNumeric) -> NamedOrNumeric
¤
Polynomial Operations¤
poly
¤
Polynomial helpers for :mod:haliax.
This module provides NamedArray-aware wrappers around :mod:jax.numpy's
polynomial utilities.
Functions:
-
poly–Named version of jax.numpy.poly.
-
polyadd–Named version of jax.numpy.polyadd.
-
polysub–Named version of jax.numpy.polysub.
-
polymul–Named version of jax.numpy.polymul.
-
polydiv–Named version of jax.numpy.polydiv.
-
polyint–Named version of jax.numpy.polyint.
-
polyder–Named version of jax.numpy.polyder.
-
polyval–Named version of jax.numpy.polyval.
-
polyfit–Named version of jax.numpy.polyfit.
-
roots–Named version of jax.numpy.roots.
-
trim_zeros–Named version of jax.numpy.trim_zeros.
-
vander–Named version of jax.numpy.vander.
Attributes:
-
DEFAULT_POLY_AXIS_NAME–Default name used for polynomial coefficient axes when none is provided.
DEFAULT_POLY_AXIS_NAME = 'degree'
module-attribute
¤
Default name used for polynomial coefficient axes when none is provided.
poly(seq_of_zeros: NamedArray | ArrayLike) -> NamedArray
¤
Named version of jax.numpy.poly.
If seq_of_zeros is not a haliax.NamedArray, the returned coefficient axis
is named degree.
polyadd(p1: NamedArray | ArrayLike, p2: NamedArray | ArrayLike) -> NamedArray
¤
Named version of jax.numpy.polyadd.
If neither input is a haliax.NamedArray, the coefficient axis is named
degree; otherwise the axis from the NamedArray input is reused (resized as
needed).
polysub(p1: NamedArray | ArrayLike, p2: NamedArray | ArrayLike) -> NamedArray
¤
Named version of jax.numpy.polysub.
If neither input is a haliax.NamedArray, the coefficient axis is named
degree; otherwise the axis from the NamedArray input is reused (resized as
needed).
polymul(p1: NamedArray | ArrayLike, p2: NamedArray | ArrayLike) -> NamedArray
¤
Named version of jax.numpy.polymul.
If neither input is a haliax.NamedArray, the coefficient axis is named
degree; otherwise the axis from the NamedArray input is reused (resized as
needed).
polydiv(p1: NamedArray | ArrayLike, p2: NamedArray | ArrayLike) -> tuple[NamedArray, NamedArray]
¤
Named version of jax.numpy.polydiv.
The quotient and remainder reuse the coefficient axis from the NamedArray input
when available; otherwise their coefficient axes are named degree.
polyint(p: NamedArray | ArrayLike, m: int = 1, k: ArrayLike | NamedArray | None = None) -> NamedArray
¤
Named version of jax.numpy.polyint.
If p is not a haliax.NamedArray, the integrated polynomial uses a
coefficient axis named degree.
polyder(p: NamedArray | ArrayLike, m: int = 1) -> NamedArray
¤
Named version of jax.numpy.polyder.
If p is not a haliax.NamedArray, the differentiated polynomial uses a
coefficient axis named degree.
polyval(p: NamedArray | ArrayLike, x: NamedOrNumeric) -> NamedOrNumeric
¤
Named version of jax.numpy.polyval.
When x is a haliax.NamedArray, the returned array reuses x's axes.
Otherwise a regular :mod:jax.numpy array is returned.
polyfit(x: NamedArray | ArrayLike, y: NamedArray | ArrayLike, deg: int, rcond: ArrayLike | None = None, full: bool = False, w: NamedArray | ArrayLike | None = None, cov: bool = False) -> NamedArray | tuple
¤
polyfit(x: NamedArray | ArrayLike, y: NamedArray | ArrayLike, deg: int, rcond: ArrayLike | None = ..., full: Literal[False] = ..., w: NamedArray | ArrayLike | None = ..., cov: Literal[False] = ...) -> NamedArray
polyfit(x: NamedArray | ArrayLike, y: NamedArray | ArrayLike, deg: int, rcond: ArrayLike | None = ..., full: Literal[True] = ..., w: NamedArray | ArrayLike | None = ..., cov: Literal[False] = ...) -> tuple[NamedArray, Array, Array, Array, Array]
polyfit(x: NamedArray | ArrayLike, y: NamedArray | ArrayLike, deg: int, rcond: ArrayLike | None = ..., full: Literal[False] = ..., w: NamedArray | ArrayLike | None = ..., cov: Literal[True] = ...) -> tuple[NamedArray, NamedArray]
polyfit(x: NamedArray | ArrayLike, y: NamedArray | ArrayLike, deg: int, rcond: ArrayLike | None = ..., full: Literal[True] = ..., w: NamedArray | ArrayLike | None = ..., cov: Literal[True] = ...) -> tuple[NamedArray, Array, Array, Array, Array]
Named version of jax.numpy.polyfit.
If neither x nor y is a haliax.NamedArray, the fitted coefficients
use a coefficient axis named degree; otherwise the axis from the NamedArray
input is reused. When cov is True, the returned covariance matrix is
wrapped in a haliax.NamedArray whose row axis matches the coefficient
axis and whose column axis uses the same name with a "_cov" suffix.
roots(p: NamedArray | ArrayLike) -> NamedArray
¤
Named version of jax.numpy.roots.
If p is not a haliax.NamedArray, the root axis is named degree.
trim_zeros(f: NamedArray | ArrayLike, trim: str = 'fb') -> NamedArray
¤
Named version of jax.numpy.trim_zeros.
If f is not a haliax.NamedArray, the trimmed coefficient axis is named
degree.
vander(x: NamedArray, degree: AxisSelector) -> NamedArray
¤
Named version of jax.numpy.vander.
Parameters:
-
x(NamedArray) –Input array of shape (n,).
-
degree(AxisSelector) –Axis for the polynomial degree in the output. If a string is provided, an axis with that name and size n is created.
Returns:
-
NamedArray–Vandermonde matrix with row axis from x and the provided degree axis.
Other Operations¤
bincount(x: NamedArray, Counts: Axis, *, weights: NamedArray | ArrayLike | None = None, minlength: int = 0) -> NamedArray
¤
Named version of jax.numpy.bincount.
The output axis is specified by Counts.
clip(array: NamedOrNumeric, a_min: NamedOrNumeric, a_max: NamedOrNumeric) -> NamedArray
¤
Like jnp.clip, but with named axes. This version currently only accepts the three argument form.
packbits(a: NamedArray, axis: AxisSelector, *, bitorder: str = 'big') -> NamedArray
¤
Named version of jax.numpy.packbits.
unpackbits(a: NamedArray, axis: AxisSelector, *, count: int | None = None, bitorder: str = 'big') -> NamedArray
¤
Named version of jax.numpy.unpackbits.
isclose(a: NamedArray, b: NamedArray, rtol=1e-05, atol=1e-08, equal_nan=False) -> NamedArray
¤
Returns a boolean array where two arrays are element-wise equal within a tolerance.
allclose(a: NamedArray, b: NamedArray, rtol=1e-05, atol=1e-08, equal_nan=False) -> bool
¤
Returns True if two arrays are element-wise equal within a tolerance.
array_equal(a: NamedArray, b: NamedArray) -> bool
¤
Returns True if two arrays have the same shape and elements.
array_equiv(a: NamedArray, b: NamedArray) -> bool
¤
Returns True if two arrays are shape-consistent and equal.
pad(array: NamedArray, pad_width: Mapping[AxisSelector, tuple[int, int]], *, mode: str = 'constant', constant_values: NamedOrNumeric = 0, **kwargs) -> NamedArray
¤
Version of jax.numpy.pad that works with NamedArray.
pad_width should be a mapping from axis (or axis name) to a (before, after)
tuple specifying how much padding to add on each side of that axis. Any axis
not present in pad_width will not be padded.
searchsorted(a: NamedArray, v: NamedArray | ArrayLike, *, side: str = 'left', sorter: NamedArray | ArrayLike | None = None, method: str = 'scan') -> NamedArray
¤
Named version of jax.numpy.searchsorted.
a and sorter (if provided) must be one-dimensional.
The returned array has the same axes as v.
top_k(arr: NamedArray, axis: AxisSelector, k: int, new_axis: AxisSelector | None = None) -> tuple[NamedArray, NamedArray]
¤
Select the top k elements along the given axis.
Parameters:
-
arr(NamedArray) –array to select from
-
axis(AxisSelector) –axis to select from
-
k(int) –number of elements to select
-
new_axis(AxisSelector | None) –new axis name; if None, the original axis will be resized to k
Returns:
-
NamedArray–array with the top k elements along the given axis
-
NamedArray–array with the top k elements' indices along the given axis
nonzero(array: NamedArray, *, size: Axis, fill_value: int = 0) -> tuple[NamedArray, ...]
¤
Like :func:jax.numpy.nonzero, but with named axes.
Parameters:
-
array(NamedArray) –The input array to test for nonzero values. Must be a :class:NamedArray.
-
size(Axis) –Axis specifying the size of the output axis. This is required because JAX requires the size of the result at tracing time.
-
fill_value(int, default:0) –Value used to fill the output when fewer than size elements are nonzero. Defaults to 0.
Returns:
-
tuple[NamedArray, ...]–A tuple of :class:NamedArray objects, one for each axis of array. Each returned array has size as its only axis and contains the indices of the nonzero elements along the corresponding input axis.
trace(array: NamedArray, axis1: AxisSelector, axis2: AxisSelector, offset=0, dtype=None) -> NamedArray
¤
Compute the trace of an array along two named axes.
tril(array: NamedArray, axis1: Axis, axis2: Axis, k=0) -> NamedArray
¤
Compute the lower triangular part of an array along two named axes.
triu(array: NamedArray, axis1: Axis, axis2: Axis, k=0) -> NamedArray
¤
Compute the upper triangular part of an array along two named axes.
where(condition: NamedOrNumeric | bool, x: NamedOrNumeric | None = None, y: NamedOrNumeric | None = None, fill_value: int | None = None, new_axis: Axis | None = None) -> NamedArray | tuple[NamedArray, ...]
¤
where(condition: NamedOrNumeric | bool, x: NamedOrNumeric, y: NamedOrNumeric) -> NamedArray
where(condition: NamedArray, *, fill_value: int, new_axis: Axis) -> tuple[NamedArray, ...]
Like jnp.where, but with named axes.
FFT¤
All FFT helpers accept an axis argument which may be a single axis, its
name, or an ordered mapping from axes to output sizes. Passing a mapping
dispatches to the n‑dimensional variants in :mod:jax.numpy.fft.
For example::
import jax.numpy as jnp
import haliax as hax
T = hax.Axis("time", 8)
signal = hax.arange(T, dtype=jnp.float32)
# operate along a single axis specified by name
hax.fft(signal, axis="time")
# resize by passing an Axis object
hax.fft(signal, axis=hax.Axis("time", 16))
X, Y = hax.make_axes(X=4, Y=6)
image = hax.arange((X, Y), dtype=jnp.float32)
# transform across several axes in order by passing a sequence
hax.fft(image, axis=("X", "Y"))
# selectively resize axes by providing a mapping
hax.fft(image, axis={"X": None, "Y": hax.Axis("Y", 10)})
# mappings can cover just a subset of axes when only partial resizing is needed
hax.fft(image, axis={"Y": 10})
fft
¤
Named wrappers around :mod:jax.numpy.fft.
These functions mirror the behaviour of their :mod:jax.numpy.fft counterparts
while accepting named axes. Instead of separate fftn/fft2 variants we
provide a single fft family of functions whose axis argument controls
which axes are transformed.
The axis parameter can be one of:
-
None– operate on the last axis.
-
str– name of an existing axis in the input.
-
:class:~haliax.Axis– specifies both the axis to transform (by name) and the desired FFT length. The output axis is replaced by the provided Axis.
-
dict– mapping from axis selectors (names or Axis objects) to optional sizes. A value of None uses the existing axis length. The mapping order determines the order of transforms and dispatches to the n-dimensional variants in :mod:jax.numpy.fft.
Example¤
X, Y = hax.make_axes(X=4, Y=6)
arr = hax.arange((X, Y))
# 1D transform along ``Y``
hax.fft(arr, axis="Y")
# 2D transform across both axes
hax.fft(arr, axis={"X": None, "Y": None})
# Resize the ``Y`` axis before transforming
hax.fft(arr, axis={"Y": hax.Axis("Y", 8)})
Functions:
-
fft–Named version of :func:jax.numpy.fft.fft.
-
ifft–Named version of :func:jax.numpy.fft.ifft.
-
rfft–Named version of :func:jax.numpy.fft.rfft.
-
irfft–Named version of :func:jax.numpy.fft.irfft.
-
hfft–Named version of :func:jax.numpy.fft.hfft.
-
ihfft–Named version of :func:jax.numpy.fft.ihfft.
-
fftshift–Named version of :func:jax.numpy.fft.fftshift.
-
ifftshift–Named version of :func:jax.numpy.fft.ifftshift.
-
fftfreq–Named version of :func:jax.numpy.fft.fftfreq.
-
rfftfreq–Named version of :func:jax.numpy.fft.rfftfreq.
AxisSizeLike = int | Axis | None
module-attribute
¤
AxisMapping = Mapping[AxisSelector, AxisSizeLike]
module-attribute
¤
fft(a: NamedArray, axis: AxisSelector | Sequence[AxisSelector] | AxisMapping | None = None, norm: str | None = None) -> NamedArray
¤
Named version of :func:jax.numpy.fft.fft.
See module level documentation for the behaviour of the axis argument.
ifft(a: NamedArray, axis: AxisSelector | Sequence[AxisSelector] | AxisMapping | None = None, norm: str | None = None) -> NamedArray
¤
Named version of :func:jax.numpy.fft.ifft.
rfft(a: NamedArray, axis: AxisSelector | Sequence[AxisSelector] | AxisMapping | None = None, norm: str | None = None) -> NamedArray
¤
Named version of :func:jax.numpy.fft.rfft.
irfft(a: NamedArray, axis: AxisSelector | Sequence[AxisSelector] | AxisMapping | None = None, norm: str | None = None) -> NamedArray
¤
Named version of :func:jax.numpy.fft.irfft.
hfft(a: NamedArray, axis: AxisSelector | Sequence[AxisSelector] | AxisMapping | None = None, norm: str | None = None) -> NamedArray
¤
Named version of :func:jax.numpy.fft.hfft.
Only a single axis is supported; passing a dictionary with more than one entry will raise an error.
ihfft(a: NamedArray, axis: AxisSelector | Sequence[AxisSelector] | AxisMapping | None = None, norm: str | None = None) -> NamedArray
¤
Named version of :func:jax.numpy.fft.ihfft.
Only a single axis is supported; passing a dictionary with more than one entry will raise an error.
fftshift(x: NamedArray, axes: AxisSelection | None = None) -> NamedArray
¤
Named version of :func:jax.numpy.fft.fftshift.
ifftshift(x: NamedArray, axes: AxisSelection | None = None) -> NamedArray
¤
Named version of :func:jax.numpy.fft.ifftshift.
fftfreq(axis: Axis, d: float = 1.0) -> NamedArray
¤
Named version of :func:jax.numpy.fft.fftfreq.
rfftfreq(axis: Axis, d: float = 1.0) -> NamedArray
¤
Named version of :func:jax.numpy.fft.rfftfreq.
Named Array Reference¤
Most methods on haliax.NamedArray just call the corresponding haliax function with the array as the first argument,
just as with NumPy.
The exceptions are documented here:
NamedArray(array: jnp.ndarray, axes: AxisSpec)
dataclass
¤
Methods:
-
item–Returns the value of this NamedArray as a python scalar.
-
scalar–Returns a scalar array corresponding to the value of this NamedArray.
-
tree_flatten–
-
tree_unflatten–
-
has_axis–Returns true if the given axis is present in this NamedArray.
-
matches_axes–Check whether this NamedArray conforms to the given NamedArray type.
-
axis_size–Returns the size of the given axis, or a tuple of sizes if given multiple axes.
-
resolve_axis–Returns the axes corresponding to the given axis selection.
-
axis_indices–For a single axis, returns an int corresponding to the index of the axis.
-
rearrange–See haliax.rearrange for details.
-
broadcast_to–
-
broadcast_axis–
-
split–
-
flatten_axes–
-
unflatten_axis–
-
ravel–
-
flatten–
-
unbind–
-
rename–
-
slice–
-
updated_slice–
-
take–
-
all–
-
any–
-
argmax–
-
argmin–
-
argsort–
-
astype–
-
clip–
-
conj–
-
conjugate–
-
copy–
-
cumprod–
-
cumsum–
-
dot–
-
max–
-
mean–
-
min–
-
prod–
-
product–
-
ptp–
-
round–
-
sort–
-
std–
-
sum–
-
tobytes–
-
tolist–
-
trace–
-
var–
Attributes:
-
array(ndarray) –
-
axes(tuple[Axis, ...]) –
-
shape(dict[str, int]) –
-
dtype–The dtype of the underlying array
-
ndim–The number of dimensions of the underlying array
-
size–The number of elements in the underlying array
-
nbytes–The number of bytes in the underlying array
-
at('_NamedIndexUpdateHelper') –Named analog of [jax's at method](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html).
-
imag('NamedArray') –
-
real('NamedArray') –
array: jnp.ndarray
instance-attribute
¤
axes: tuple[Axis, ...]
instance-attribute
¤
shape: dict[str, int]
cached
property
¤
dtype = property(lambda self: self.array.dtype)
class-attribute
instance-attribute
¤
The dtype of the underlying array
ndim = property(lambda self: self.array.ndim)
class-attribute
instance-attribute
¤
The number of dimensions of the underlying array
size = property(lambda self: self.array.size)
class-attribute
instance-attribute
¤
The number of elements in the underlying array
nbytes = property(lambda self: self.array.nbytes)
class-attribute
instance-attribute
¤
The number of bytes in the underlying array
at: '_NamedIndexUpdateHelper'
property
¤
Named analog of [jax's at method](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html).
The at[...].set(...) pattern is a functional way to update elements of an array.
This is the named version: it combines the indexing syntax Haliax uses for slicing with
JAX's at methods (set, add, and so on).
From the JAX documentation:
The at property provides a functionally pure equivalent of in-place array modifications.
In particular:
| Alternate syntax | Equivalent In-place expression |
|---|---|
| x = x.at[idx].set(y) | x[idx] = y |
| x = x.at[idx].add(y) | x[idx] += y |
| x = x.at[idx].multiply(y) | x[idx] *= y |
| x = x.at[idx].divide(y) | x[idx] /= y |
| x = x.at[idx].power(y) | x[idx] **= y |
| x = x.at[idx].min(y) | x[idx] = minimum(x[idx], y) |
| x = x.at[idx].max(y) | x[idx] = maximum(x[idx], y) |
| x = x.at[idx].apply(ufunc) | ufunc.at(x, idx) |
| x = x.at[idx].get() | x = x[idx] |
None of the x.at expressions modify the original x; instead they return a modified copy of x. However, inside a jit() compiled function, expressions like x = x.at[idx].set(y) are guaranteed to be applied in-place.
Unlike NumPy in-place operations such as x[idx] += y, if multiple indices refer to the same location, all updates will be applied (NumPy would only apply the last update, rather than applying all updates.) The order in which conflicting updates are applied is implementation-defined and may be nondeterministic (e.g., due to concurrency on some hardware platforms).
By default, JAX assumes that all indices are in-bounds. Alternative out-of-bound index semantics can be specified via the mode parameter (see below).
Returns:
imag: 'NamedArray'
property
¤
real: 'NamedArray'
property
¤
item()
¤
Returns the value of this NamedArray as a python scalar.
scalar() -> jnp.ndarray
¤
Returns a scalar array corresponding to the value of this NamedArray. Raises an error if the NamedArray is not scalar.
We sometimes use this to convert a NamedArray to a scalar for returning a loss or similar. Losses have to be jnp.ndarrays, not NamedArrays, so we need to convert them. item doesn't work inside jitted functions because it returns a python scalar.
You could just call array, but that's not as clear and doesn't assert.
tree_flatten() -> Any
¤
tree_unflatten(aux, tree: Any) -> Any
classmethod
¤
has_axis(axis: AxisSelection) -> bool
¤
Returns true if the given axis is present in this NamedArray.
matches_axes(spec: NamedArrayAxesSpec) -> bool
¤
Check whether this NamedArray conforms to the given NamedArray type.
Parameters¤
- spec (NamedArrayAxesSpec) – The specification to check against. It can be produced via the NamedArray[...] syntax or passed directly as a string or sequence of axis names.
axis_size(axis: AxisSelection) -> int | tuple[int, ...]
¤
axis_size(axis: AxisSelector) -> int
axis_size(axis: Sequence[AxisSelector]) -> tuple[int, ...]
Returns the size of the given axis, or a tuple of sizes if given multiple axes.
resolve_axis(axes: AxisSelection) -> AxisSpec
¤
resolve_axis(axis: AxisSelector) -> Axis
resolve_axis(axis: tuple[AxisSelector, ...]) -> tuple[Axis, ...]
resolve_axis(axis: PartialShapeDict) -> ShapeDict
resolve_axis(axes: AxisSelection) -> AxisSpec
Returns the axes corresponding to the given axis selection. That is, it returns the haliax.Axis values themselves, not just their names.
Raises a ValueError if any of the axes are not found.
axis_indices(axis: AxisSelection) -> int | None | tuple[int | None, ...]
¤
axis_indices(axis: AxisSelector) -> int | None
axis_indices(axis: Sequence[AxisSelector]) -> tuple[int | None, ...]
axis_indices(axis: AxisSelection) -> tuple[int | None, ...]
For a single axis, returns an int corresponding to the index of the axis. For multiple axes, returns a tuple of ints corresponding to the indices of the axes.
If an axis is not present, returns None for that position.
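The lookup rule can be pictured with a small stand-in that works on a plain tuple of axis names. The helper axis_indices_sketch is hypothetical and is not Haliax's implementation; it only mirrors the documented return values.

```python
from typing import Optional, Sequence, Union


def axis_indices_sketch(
    axis_names: Sequence[str],
    axis: Union[str, Sequence[str]],
) -> Union[Optional[int], tuple]:
    """Hypothetical stand-in mirroring the documented return values."""

    def lookup(name: str) -> Optional[int]:
        # Position of the name among the array's axes, or None if absent.
        return axis_names.index(name) if name in axis_names else None

    if isinstance(axis, str):
        return lookup(axis)
    return tuple(lookup(name) for name in axis)


names = ("batch", "embed")
single = axis_indices_sketch(names, "embed")
several = axis_indices_sketch(names, ("embed", "heads", "batch"))
```

In this sketch single is 1, and several is (1, None, 0) because "heads" is not one of the axes.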
rearrange(*args, **kwargs) -> 'NamedArray'
¤
rearrange(axes: Sequence[AxisSelector | EllipsisType]) -> 'NamedArray'
rearrange(expression: str, **bindings: AxisSelector | int) -> 'NamedArray'
See haliax.rearrange for details.
broadcast_to(axes: AxisSpec) -> 'NamedArray'
¤
broadcast_axis(axis: AxisSpec) -> 'NamedArray'
¤
split(axis: AxisSelector, new_axes: Sequence[Axis]) -> Sequence['NamedArray']
¤
flatten_axes(old_axes: AxisSelection, new_axis: AxisSelector) -> 'NamedArray'
¤
unflatten_axis(axis: AxisSelector, new_axes: AxisSpec) -> 'NamedArray'
¤
ravel(new_axis_name: AxisSelector) -> 'NamedArray'
¤
flatten(new_axis_name: AxisSelector) -> 'NamedArray'
¤
unbind(axis: AxisSelector) -> Sequence['NamedArray']
¤
rename(renames: Mapping[AxisSelector, AxisSelector]) -> 'NamedArray'
¤
slice(*args, **kwargs) -> 'NamedArray'
¤
slice(axis: AxisSelector, new_axis: AxisSelector | None = None, start: int = 0, length: int | None = None) -> 'NamedArray'
slice(start: Mapping[AxisSelector, int], length: Mapping[AxisSelector, int | Axis]) -> 'NamedArray'
updated_slice(start: Mapping[AxisSelector, int | 'NamedArray'], update: 'NamedArray') -> 'NamedArray'
¤
take(axis: AxisSelector, index: int | 'NamedArray') -> 'NamedArray'
¤
all(axis: AxisSelection | None = None, *, where: 'NamedArray' | None = None) -> 'NamedArray'
¤
any(axis: AxisSelection | None = None, *, where: 'NamedArray' | None = None) -> 'NamedArray'
¤
argmax(axis: AxisSelector | None = None) -> 'NamedArray'
¤
argmin(axis: AxisSelector | None) -> 'NamedArray'
¤
argsort(axis: AxisSelector | None, *, stable: bool = False) -> 'NamedArray'
¤
astype(dtype) -> 'NamedArray'
¤
clip(a_min=None, a_max=None) -> Any
¤
conj() -> 'NamedArray'
¤
conjugate() -> 'NamedArray'
¤
copy() -> 'NamedArray'
¤
cumprod(axis: AxisSelector, *, dtype=None) -> 'NamedArray'
¤
cumsum(axis: AxisSelector, *, dtype=None) -> 'NamedArray'
¤
dot(*args, **kwargs) -> 'NamedArray'
¤
dot(axis: AxisSelection | None, *b, precision: PrecisionLike = None, dot_general=jax.lax.dot_general) -> 'NamedArray'
dot(*args: 'NamedArray', axis: AxisSelection | None, precision: PrecisionLike = None, dot_general=jax.lax.dot_general) -> 'NamedArray'
max(axis: AxisSelection | None = None, *, where=None) -> 'NamedArray'
¤
mean(axis: AxisSelection | None = None, *, dtype=None, where: 'NamedArray' | None = None) -> 'NamedArray'
¤
min(axis: AxisSelection | None = None, *, where: 'NamedArray' | None = None) -> 'NamedArray'
¤
prod(axis: AxisSelection | None = None, *, dtype=None, where: 'NamedArray' | None = None) -> 'NamedArray'
¤
product(axis: AxisSelection | None = None, *, dtype=None, where: 'NamedArray' | None = None) -> 'NamedArray'
¤
ptp(axis: AxisSelection | None = None) -> 'NamedArray'
¤
round(decimals=0) -> 'NamedArray'
¤
sort(axis: AxisSelector) -> Any
¤
std(axis: AxisSelection | None = None, *, dtype=None, ddof=0, where: 'NamedArray' | None = None) -> 'NamedArray'
¤
sum(axis: AxisSelection | None = None, *, dtype=None, where: 'NamedArray' | None = None) -> 'NamedArray'
¤
tobytes(order='C') -> Any
¤
tolist() -> Any
¤
trace(axis1: AxisSelector, axis2: AxisSelector, offset=0, dtype=None) -> 'NamedArray'
¤
var(axis: AxisSelection | None = None, dtype=None, ddof=0, *, where: 'NamedArray' | None = None) -> 'NamedArray'
¤
Partitioning API¤
See also the section on Partitioning.
axis_mapping(mapping: ResourceMapping, *, merge: bool = False, **kwargs)
¤
Context manager for setting the global resource mapping.
shard(x: T, mapping: ResourceMapping | None = None, mesh: Mesh | None = None) -> T
¤
Shard a PyTree using the provided axis mapping. NamedArrays in the PyTree are sharded using the axis mapping. Other arrays (i.e. plain JAX arrays) are left alone.
This is basically a fancy wrapper around with_sharding_constraint that uses the axis mapping to determine the sharding.
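The idea of deriving a sharding from axis names and a mapping can be sketched in plain JAX. This is not Haliax's implementation: the mapping, the axis names, and the one-device mesh are assumptions chosen so the snippet runs on any machine.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# A one-device mesh keeps the sketch runnable anywhere.
mesh = Mesh(np.array(jax.devices()[:1]), ("data",))

# Hypothetical logical-to-physical mapping and axis names for one array.
mapping = {"batch": "data"}
axis_names = ("batch", "embed")

# Unmapped logical axes become None, i.e. they are not partitioned.
spec = PartitionSpec(*(mapping.get(name) for name in axis_names))
sharding = NamedSharding(mesh, spec)


@jax.jit
def constrain(a):
    # Ask the compiler to lay out `a` according to the derived sharding.
    return jax.lax.with_sharding_constraint(a, sharding)


x = jnp.ones((4, 8))
y = constrain(x)
```

With a NamedArray the axis names travel with the array, which is what lets shard work out the partition spec without it being written by hand.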
named_jit(fn: Callable[Args, R] | None = None, axis_resources: ResourceMapping | None = None, *, in_axis_resources: ResourceMapping | None = None, out_axis_resources: ResourceMapping | None = None, donate_args: PyTree | None = None, donate_kwargs: PyTree | None = None, **pjit_args) -> WrappedCallable[Args, R] | typing.Callable[[Callable[Args, R]], WrappedCallable[Args, R]]
¤
named_jit(fn: Callable[Args, R], axis_resources: ResourceMapping | None = None, *, in_axis_resources: ResourceMapping | None = None, out_axis_resources: ResourceMapping | None = None, donate_args: PyTree | None = None, donate_kwargs: PyTree | None = None, keep_unused: bool = False, backend: str | None = None, inline: bool | None = None) -> WrappedCallable[Args, R]
named_jit(*, axis_resources: ResourceMapping | None = None, in_axis_resources: ResourceMapping | None = None, out_axis_resources: ResourceMapping | None = None, donate_args: PyTree | None = None, donate_kwargs: PyTree | None = None, keep_unused: bool = False, backend: str | None = None, inline: bool | None = None) -> typing.Callable[[Callable[Args, R]], WrappedCallable[Args, R]]
A version of pjit that uses NamedArrays and the provided resource mapping to infer resource partitions for sharded computation.
axis_resources will be used for a context-specific resource mapping when the function is invoked.
In addition, if in_axis_resources is not provided, the arguments' own (pre-existing) shardings will be used as the in_axis_resources.
If out_axis_resources is not provided, axis_resources will be used as the out_axis_resources.
If no resource mapping is provided, this function attempts to use the context resource mapping.
This function can be used as a decorator or as a function.
Functionally this is very similar to something like:
    def wrapped_fn(arg):
        result = fn(arg)
        return hax.shard(result, out_axis_resources)

    arg = hax.shard(arg, in_axis_resources)
    with hax.axis_mapping(axis_resources):
        result = jax.jit(wrapped_fn, **pjit_args)(arg)
    return result
Parameters:
- fn (Callable, default: None) – The function to be jit'd.
- axis_resources (ResourceMapping, default: None) – A mapping from logical axis names to physical axis names, used for the context-specific resource mapping.
- in_axis_resources (ResourceMapping, default: None) – A mapping from logical axis names to physical axis names for arguments. If not passed, the arguments' own shardings are used.
- out_axis_resources (ResourceMapping, default: None) – A mapping from logical axis names to physical axis names for the result. Defaults to axis_resources.
- donate_args (PyTree, default: None) – A PyTree of booleans or functions leaf -> bool, indicating whether the arguments should be donated to the computation.
- donate_kwargs (PyTree, default: None) – A PyTree of booleans or functions leaf -> bool, indicating whether the keyword arguments should be donated to the computation.
fsdp(*args, **kwargs)
¤
fsdp(fn: F, parameter_mapping: ResourceMapping, compute_mapping: ResourceMapping) -> F
fsdp(parameter_mapping: ResourceMapping, compute_mapping: ResourceMapping) -> typing.Callable[[F], F]
A convenience wrapper around named_jit / pjit to encode the FSDP pattern. It's basically equivalent to this:
    @named_jit(in_axis_resources=parameter_mapping, out_axis_resources=parameter_mapping, axis_resources=compute_mapping)
    def f(*args, **kwargs):
        return fn(*args, **kwargs)
This function can be used as a decorator or as a function.
round_axis_for_partitioning(axis: Axis, mapping: ResourceMapping | None = None) -> Axis
¤
Round an axis so that its size is divisible by the size of the partition it's on.
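The arithmetic behind such rounding can be sketched as follows, assuming the size is rounded up to the next multiple (the usual choice, since the axis can then be padded). The helper round_up_to_multiple is hypothetical and takes the partition size directly as an integer instead of looking it up from a mapping.

```python
def round_up_to_multiple(size: int, partition_size: int) -> int:
    """Smallest multiple of partition_size that is >= size."""
    return (size + partition_size - 1) // partition_size * partition_size


# A vocabulary axis of 50257 entries split across 8 devices.
rounded = round_up_to_multiple(50257, 8)

# A size that is already divisible is returned unchanged.
unchanged = round_up_to_multiple(1024, 8)
```

In this example rounded is 50264, so each of the 8 partitions would hold 6283 entries.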
physical_axis_name(axis: AxisSelector, mapping: ResourceMapping | None = None) -> PhysicalAxisSpec | None
¤
Get the physical axis name for a logical axis from the mapping. Returns None if the axis is not mapped.
physical_axis_size(axis: AxisSelector, mapping: ResourceMapping | None = None) -> int | None
¤
Get the physical axis size for a logical axis. This is the product of the size of all physical axes that this logical axis is mapped to.
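As a worked example of that product, the sketch below uses a hypothetical mesh shape and mapping held in plain dictionaries. The helper physical_size_sketch is illustrative only and does not consult a real device mesh.

```python
import math

# Hypothetical mesh axis sizes and logical-to-physical mapping.
mesh_shape = {"data": 4, "model": 2}
mapping = {"batch": ("data", "model"), "embed": "model"}


def physical_size_sketch(logical_axis: str):
    """Product of the sizes of every physical axis the logical axis maps to."""
    physical = mapping.get(logical_axis)
    if physical is None:
        return None  # the logical axis is not mapped
    if isinstance(physical, str):
        physical = (physical,)
    return math.prod(mesh_shape[name] for name in physical)


batch_size = physical_size_sketch("batch")
embed_size = physical_size_sketch("embed")
unmapped = physical_size_sketch("position")
```

Here "batch" is spread over both mesh axes, so its physical size is 4 * 2 = 8, while "embed" maps to "model" alone and has physical size 2.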
sharding_for_axis(axis: AxisSelection, mapping: ResourceMapping | None = None, mesh: MeshLike | None = None) -> NamedSharding
¤
Get the sharding for a single axis.
Gradient Checkpointing¤
Haliax mainly just defers to JAX and equinox.filter_checkpoint for gradient checkpointing. However, we provide a few utilities to make it easier to use.
See also haliax.nn.ScanCheckpointPolicy.
tree_checkpoint_name(x: T, name: str) -> T
¤
Checkpoint a tree of arrays with a given name. This is useful for gradient checkpointing. It is equivalent to calling jax.ad_checkpoint.checkpoint_name, except that it works for any PyTree, not just arrays.
See Also
- jax.ad_checkpoint.checkpoint_name
- haliax.nn.ScanCheckpointPolicy
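The relationship to checkpoint_name can be sketched in plain JAX by mapping it over every leaf. The function tree_checkpoint_name_sketch and the example tree are hypothetical; this is an illustration of the idea and not necessarily how Haliax implements it.

```python
import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name


def tree_checkpoint_name_sketch(tree, name):
    """Tag every array leaf of a PyTree with the same checkpoint name."""
    return jax.tree_util.tree_map(lambda leaf: checkpoint_name(leaf, name), tree)


# A nested PyTree of arrays standing in for a block's activations.
activations = {"attn": jnp.ones(3), "mlp": (jnp.zeros(2), jnp.arange(2.0))}
tagged = tree_checkpoint_name_sketch(activations, "block_out")
```

The values are unchanged by tagging. The name only matters to a checkpoint policy such as jax.checkpoint_policies.save_only_these_names, which can then refer to "block_out" when deciding what to save.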
Old API¤
These functions are being deprecated and will be removed in a future release.
shard_with_axis_mapping(x: T, mapping: ResourceMapping, mesh: Mesh | None = None) -> T
¤
auto_sharded(x: T, mesh: Optional[Mesh] = None) -> T
¤
Shard a PyTree using the global axis mapping. NamedArrays in the PyTree are sharded using the axis mapping and the names in the tree.
If there is no global axis mapping, this function is a no-op.