Neural Networks¤

Modules¤

Haliax provides a small number of neural network modules that are compatible with Equinox, though they naturally all use haliax.NamedArray. (We welcome PRs for more modules! Nothing too exotic though.)

The most interesting of these modules is haliax.nn.Stacked, which lets you create homogeneous "stacks" of the same module (e.g. transformer blocks), a common pattern in deep learning.

Linear¤

Embedding ¤

Bases: Module, ReparamEnabled

Methods:

  • init

    Initialize an Embedding module.

  • embed

    Look up the embeddings for the given token IDs.

  • unembed

    Unembed the input embeddings back to the vocabulary space.

  • resize_embeddings

    Resize the embedding layer to a new size.

Attributes:

weight: NamedArray instance-attribute ¤
Vocab: Axis = eqx.field(static=True) class-attribute instance-attribute ¤
Embed: AxisSpec = eqx.field(static=True) class-attribute instance-attribute ¤
reparam: AbstractEmbeddingReparam property ¤
init(Vocab: Axis, Embed: AxisSpec, *, init_scale: float = 1, key, initializer_range: float | None = None, reparam_cls: type[AbstractEmbeddingReparam] = EmbeddingStandardParam) staticmethod ¤

Initialize an Embedding module.

An embedding module is a simple lookup table that maps integer indices to vectors or tensors. Weights are initialized with a truncated normal distribution with a standard deviation of init_scale / output_size.

Parameters:

  • Vocab ¤
    (Axis) –

    Size of the vocabulary

  • Embed ¤
    (AxisSpec) –

    Shape of the embedding vectors. May be a single axis or a full AxisSpec

  • init_scale ¤
    (float, default: 1 ) –

    Scale of the initialization

  • key ¤

    PRNG key

  • initializer_range ¤
    (float | None, default: None ) –

    Deprecated. Use init_scale instead.

embed(input_ids: NamedArray) ¤

Parameters:

  • input_ids ¤
    (NamedArray) –

    Integer token IDs, each indexing into the Vocab axis

unembed(input_embeds: NamedArray) ¤

Unembed the input embeddings back to the vocabulary space.

Equivalent to input_embeds.dot(self.weight, axis=self.Embed).
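
As a rough illustration of the lookup and of the projection back to vocabulary space, the sketch below uses plain Python lists in place of NamedArrays. The names `weight`, `embed`, and `unembed` are stand-ins chosen for this example; nothing here calls Haliax.

```python
# Toy embedding table: 3 vocabulary entries, 2 embedding dimensions.
# Plain lists stand in for a NamedArray with axes (Vocab, Embed).
weight = [
    [1.0, 0.0],
    [0.0, 2.0],
    [1.0, 1.0],
]


def embed(input_ids):
    """Look up one row of the table per token ID."""
    return [weight[i] for i in input_ids]


def unembed(input_embeds):
    """Contract the Embed dimension against the table: one score per vocabulary entry."""
    return [
        [sum(e * w for e, w in zip(vec, row)) for row in weight]
        for vec in input_embeds
    ]


vectors = embed([2, 0])
scores = unembed(vectors)
```

Each row of `scores` has one entry per vocabulary item, which is why unembedding is a dot product over the embedding axis.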

resize_embeddings(new_size: int, key: PRNGKeyArray | None = None) ¤

Resize the embedding layer to a new size.

Parameters:

  • new_size (int) –

    New size of the vocabulary

  • key (PRNGKeyArray | None, default: None ) –

    PRNG key for initialization of any new weights

Returns:

  • Embedding

    Resized embedding layer

Linear ¤

Bases: ModuleWithStateDictSerialization, ReparamEnabled

A named Linear layer. This module allows you to specify multiple named axes for both input and output, which is occasionally useful.

Methods:

Attributes:

weight: NamedArray instance-attribute ¤
bias: NamedArray | None instance-attribute ¤
In: AxisSpec = eqx.field(static=True) class-attribute instance-attribute ¤
Out: AxisSpec = eqx.field(static=True) class-attribute instance-attribute ¤
dot_general: DotGeneralOp = eqx.field(default_factory=(DotGeneralOp.default)) class-attribute instance-attribute ¤
reparam: AbstractLinearReparam property ¤
init(In: AxisSpec, Out: AxisSpec, *, key: PRNGKey, use_bias: bool = True, out_first: bool = True, dot_general: DotGeneralOp | None = None, init_scale: float = 1.0, reparam_cls: type[AbstractLinearReparam] = LinearStandardParam) -> Linear staticmethod ¤

Parameters:

  • In ¤
    (AxisSpec) –

    The input axis spec

  • Out ¤
    (AxisSpec) –

    The output axis spec

  • key ¤
    (PRNGKey) –

    The PRNG key to use for initialization

  • use_bias ¤
    (bool, default: True ) –

    Whether to use a bias term

  • out_first ¤
    (bool, default: True ) –

    Whether to put output axes first in the weight matrix. out_first is how PyTorch does it.

  • dot_general ¤
    (DotGeneralOp | None, default: None ) –

    The dot_general function to use. Defaults to jax.lax.dot_general.

  • init_scale ¤
    (float, default: 1.0 ) –

    The scale to use for initialization. Initial weights are scaled by init_scale / sqrt(input size), where the input size is the total size of the In axes.
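
To make the init_scale rule concrete, here is a minimal plain-Python sketch of a weight initializer scaled by init_scale / sqrt(input size). The helper `init_weight` is hypothetical, and the base distribution (standard normal truncated to [-3, 3]) is an assumption made for this example rather than something stated in this reference.

```python
import math
import random


def init_weight(in_size, out_size, init_scale=1.0, seed=0):
    """Sample an (out, in) weight matrix scaled by init_scale / sqrt(in_size).

    Assumption for this sketch: base samples are standard normal draws
    truncated to [-3, 3] by rejection; the library's distribution may differ.
    """
    rng = random.Random(seed)
    scale = init_scale / math.sqrt(in_size)

    def truncated_normal():
        while True:
            x = rng.gauss(0.0, 1.0)
            if -3.0 <= x <= 3.0:
                return x

    return [[truncated_normal() * scale for _ in range(in_size)] for _ in range(out_size)]


w = init_weight(in_size=16, out_size=4, init_scale=1.0)
bound = 3.0 / math.sqrt(16)  # largest magnitude any entry can take in this sketch
```

The (out, in) layout used here corresponds to out_first=True. Dividing by the square root of the input size keeps the variance of each output roughly independent of how wide the input is.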

flatten_for_export() -> Mod ¤
unflatten_from_export(template: Mod) -> Mod ¤
to_state_dict(prefix: Optional[str] = None) -> StateDict ¤
from_state_dict(state_dict: StateDict, prefix: Optional[str] = None) -> Mod ¤
input_reparam(use_mup: bool = True) -> type[AbstractLinearReparam] staticmethod ¤

Return the reparameterization class for an input linear layer.

hidden_reparam(use_mup: bool = True) -> type[AbstractLinearReparam] staticmethod ¤

Return the reparameterization class for a hidden linear layer.

output_reparam(use_mup: bool = True) -> type[AbstractLinearReparam] staticmethod ¤

Return the reparameterization class for an output linear layer.

Dropout¤

Dropout(pdrop: float = 0.5, broadcast_axes: AxisSpec | None = None, inference: bool = False) ¤

Bases: Module

Applies dropout.

Attributes:

  • pdrop (float) –

    The fraction of entries to set to zero.

  • broadcast_axes (AxisSpec | None) –

    The dimensions to broadcast the dropout mask over. If set, these axes will share the same mask.

pdrop: float = pdrop class-attribute instance-attribute ¤
broadcast_axes: AxisSpec | None = broadcast_axes class-attribute instance-attribute ¤
inference: bool = inference class-attribute instance-attribute ¤
is_active property ¤

Returns True if dropout is active (and therefore needs a key), False otherwise.
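
The effect of sharing a mask across an axis can be sketched in plain Python. The function below is hypothetical and works on nested lists, not NamedArrays. It also assumes the conventional "inverted dropout" rescaling of kept entries by 1 / (1 - pdrop), which this reference does not state explicitly.

```python
import random


def dropout(x, pdrop, broadcast_rows=False, inference=False, seed=0):
    """Apply inverted dropout to a 2-D list `x`.

    If broadcast_rows is True, one mask is drawn per column and shared by
    every row, mimicking a mask that is broadcast over the row axis.
    """
    if inference or pdrop == 0.0:
        return [row[:] for row in x]  # dropout is inactive: return a copy
    rng = random.Random(seed)
    keep = 1.0 - pdrop
    n_rows, n_cols = len(x), len(x[0])
    if broadcast_rows:
        col_mask = [rng.random() < keep for _ in range(n_cols)]
        mask = [col_mask for _ in range(n_rows)]
    else:
        mask = [[rng.random() < keep for _ in range(n_cols)] for _ in range(n_rows)]
    # Kept entries are rescaled by 1 / keep so the expected value is unchanged.
    return [
        [v / keep if m else 0.0 for v, m in zip(row, mrow)]
        for row, mrow in zip(x, mask)
    ]


x = [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]]
shared = dropout(x, pdrop=0.5, broadcast_rows=True)
zero_pattern = [[v == 0.0 for v in row] for row in shared]
```

With `broadcast_rows=True`, both rows of `zero_pattern` are identical, which is what sharing a mask over an axis means.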

Normalization¤

LayerNormBase ¤

Bases: ModuleWithStateDictSerialization

Methods:

Attributes:

axis: AxisSpec = eqx.field(static=True) class-attribute instance-attribute ¤
weight: NamedArray | None instance-attribute ¤
bias: NamedArray | None instance-attribute ¤
eps: float = eqx.field(default=1e-05, static=True) class-attribute instance-attribute ¤
dtype: jnp.dtype | None = eqx.field(default=None, static=True) class-attribute instance-attribute ¤
to_state_dict(prefix: str | None = None) -> StateDict ¤
from_state_dict(state_dict: StateDict, prefix: str | None = None) -> Mod ¤
init(axis: AxisSpec, eps: float = 1e-05, *, use_weight: bool = True, use_bias: bool = True, dtype: jnp.dtype | None = None) classmethod ¤
flatten_for_export() -> Mod ¤
unflatten_from_export(template: Mod) -> Mod ¤

LayerNorm ¤

Bases: LayerNormBase

Normalises the input along the specified axis (or axes), using the mean and variance of the input along that axis.

Methods:

Attributes:

axis: AxisSpec = eqx.field(static=True) class-attribute instance-attribute ¤
weight: NamedArray | None instance-attribute ¤
bias: NamedArray | None instance-attribute ¤
eps: float = eqx.field(default=1e-05, static=True) class-attribute instance-attribute ¤
dtype: jnp.dtype | None = eqx.field(default=None, static=True) class-attribute instance-attribute ¤
to_state_dict(prefix: str | None = None) -> StateDict ¤
from_state_dict(state_dict: StateDict, prefix: str | None = None) -> Mod ¤
flatten_for_export() -> Mod ¤
unflatten_from_export(template: Mod) -> Mod ¤
init(axis: AxisSpec, eps: float = 1e-05, *, use_weight: bool = True, use_bias: bool = True, dtype: jnp.dtype | None = None) classmethod ¤

RmsNorm ¤

Bases: LayerNormBase

Implements RMS normalization, which normalizes the input by dividing by the root mean square of the input.

Methods:

Attributes:

axis: AxisSpec = eqx.field(static=True) class-attribute instance-attribute ¤
weight: NamedArray | None instance-attribute ¤
bias: NamedArray | None instance-attribute ¤
eps: float = eqx.field(default=1e-05, static=True) class-attribute instance-attribute ¤
dtype: jnp.dtype | None = eqx.field(default=None, static=True) class-attribute instance-attribute ¤
to_state_dict(prefix: str | None = None) -> StateDict ¤
from_state_dict(state_dict: StateDict, prefix: str | None = None) -> Mod ¤
flatten_for_export() -> Mod ¤
unflatten_from_export(template: Mod) -> Mod ¤
init(axis: AxisSpec, eps: float = 1e-05, *, use_weight: bool = True, use_bias: bool = True, dtype: jnp.dtype | None = None) classmethod ¤
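
The difference between the two normalizations can be shown on a single vector. This is a plain-Python sketch that omits the learned weight and bias; placing eps inside the square root is an assumption of the example, not a statement about the library's internals.

```python
import math


def layer_norm(x, eps=1e-5):
    """Subtract the mean, then divide by sqrt(variance + eps)."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]


def rms_norm(x, eps=1e-5):
    """Divide by sqrt(mean of squares + eps); the mean is not subtracted."""
    mean_sq = sum(v * v for v in x) / len(x)
    return [v / math.sqrt(mean_sq + eps) for v in x]


x = [1.0, 2.0, 3.0, 4.0]
ln = layer_norm(x)
rn = rms_norm(x)
```

LayerNorm output is centered at zero, whereas RmsNorm only rescales the input, so an all-positive input stays all-positive.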

Meta¤

MLP ¤

Bases: Module

A multilayer perceptron (MLP) / feed-forward neural network (FFNN).

MLPs stack linear layers whose hidden dimensions often carry no semantic meaning, so they are not a particular strength of Haliax's design philosophy. Nonetheless, they are a useful tool for many tasks, and so we provide this module.

In Haliax, all axes must have names, and names must be unique within an array. We considered a few strategies for naming the axes of an MLP, and settled on the following: By default, we alternate hidden names between "mlp" and "mlp2". Input and output names must be specified, and are not repeated. This naming scheme is not perfect, but does mean that model parallelism works reasonably well.

NB: unlike Equinox's MLP, this MLP uses a static field for activation. If you want a learnable activation, you likely want a unique activation per layer, which neither version provides. Instead, you should use a haliax.nn.Stacked with a suitable block.

Methods:

Attributes:

activation: Callable = eqx.field(static=True) class-attribute instance-attribute ¤
layers: Sequence[Linear] instance-attribute ¤
In: AxisSpec property ¤
Out: AxisSpec property ¤
init(Input: AxisSpec, Output: AxisSpec, width: int | Axis, depth: int, activation: Callable = relu, *, out_first: bool = True, use_bias: bool = True, use_final_bias: bool = True, key: PRNGKeyArray, dot_general: DotGeneralOp | None = None, init_scale: float = 1.0) staticmethod ¤
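
The alternating naming scheme described above can be sketched as a small helper that lists the (input axis, output axis) names of each linear layer. The function `mlp_axis_names` is hypothetical, and it assumes that depth counts hidden layers, so that an MLP of depth d has d + 1 linear layers.

```python
def mlp_axis_names(input_name, output_name, depth, hidden="mlp"):
    """Return (in_axis, out_axis) name pairs for each linear layer.

    Assumption for this sketch: `depth` counts hidden layers, so the MLP
    has depth + 1 linear layers, and depth 0 is a single Input -> Output layer.
    """
    if depth == 0:
        return [(input_name, output_name)]
    names = [hidden, hidden + "2"]
    pairs = [(input_name, names[0])]
    for i in range(1, depth):
        # Alternate between the two hidden names so no layer maps an axis to itself.
        pairs.append((names[(i - 1) % 2], names[i % 2]))
    pairs.append((names[(depth - 1) % 2], output_name))
    return pairs


layers = mlp_axis_names("in", "out", depth=3)
```

Alternating two names is the smallest scheme that keeps axis names unique within every weight matrix, since each hidden-to-hidden layer needs distinct input and output names.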

Stacked¤

See the full documentation of Stacked.

Convolution¤

Unlike other frameworks, Haliax doesn't distinguish between 1D, 2D, 3D, and general convolutions. Instead, we have a single haliax.nn.Conv module that can be used for all of these, depending on the number of axes provided. Similarly, for transposed convolutions, we have haliax.nn.ConvTranspose.

Conv ¤

Bases: _ConvBase

General N-dimensional convolution.

Methods:

Attributes:

Spatial: tuple[str | Axis, ...] = eqx.field(static=True) class-attribute instance-attribute ¤
In: Axis = eqx.field(static=True) class-attribute instance-attribute ¤
Out: Axis = eqx.field(static=True) class-attribute instance-attribute ¤
weight: NamedArray = eqx.field() class-attribute instance-attribute ¤
bias: NamedArray | None = eqx.field() class-attribute instance-attribute ¤
kernel_size: tuple[int, ...] = eqx.field(static=True) class-attribute instance-attribute ¤
stride: tuple[int, ...] = eqx.field(static=True) class-attribute instance-attribute ¤
padding: tuple[tuple[int, int], ...] = eqx.field(static=True) class-attribute instance-attribute ¤
dilation: tuple[int, ...] = eqx.field(static=True) class-attribute instance-attribute ¤
groups: int = eqx.field(static=True) class-attribute instance-attribute ¤
init(Spatial: AxisSelection, In: Axis, Out: Axis, kernel_size: int | Sequence[int], *, stride: int | Sequence[int] = 1, padding: int | Sequence[int] | Sequence[tuple[int, int]] = 0, dilation: int | Sequence[int] = 1, groups: int = 1, use_bias: bool = True, key: PRNGKeyArray) staticmethod ¤

Parameters:

  • Spatial ¤
    (AxisSelection) –

    names of spatial dimensions

  • In ¤
    (Axis) –

    Axis of input channels

  • Out ¤
    (Axis) –

    Axis of output channels

  • kernel_size ¤
    (int | Sequence[int]) –

    The size of the convolutional kernel.

  • stride ¤
    (int | Sequence[int], default: 1 ) –

    The stride of the convolution. Can be a single number or a tuple

  • padding ¤
    (int | Sequence[int] | Sequence[tuple[int, int]], default: 0 ) –

    The amount of padding to apply before and after each spatial dimension.

  • dilation ¤
    (int | Sequence[int], default: 1 ) –

    The dilation of the convolution.

  • groups ¤
    (int, default: 1 ) –

    The number of groups to split the input channels into. Each group is convolved separately with its own kernel.

  • use_bias ¤
    (bool, default: True ) –

    Whether to add a bias after the convolution.

  • key ¤
    (PRNGKeyArray) –

    Random key

ConvTranspose ¤

Bases: _ConvBase

General N-dimensional transposed convolution.

Based on Equinox's ConvTranspose class.

Methods:

Attributes:

Spatial: tuple[str | Axis, ...] = eqx.field(static=True) class-attribute instance-attribute ¤
In: Axis = eqx.field(static=True) class-attribute instance-attribute ¤
Out: Axis = eqx.field(static=True) class-attribute instance-attribute ¤
weight: NamedArray = eqx.field() class-attribute instance-attribute ¤
bias: NamedArray | None = eqx.field() class-attribute instance-attribute ¤
kernel_size: tuple[int, ...] = eqx.field(static=True) class-attribute instance-attribute ¤
stride: tuple[int, ...] = eqx.field(static=True) class-attribute instance-attribute ¤
padding: tuple[tuple[int, int], ...] = eqx.field(static=True) class-attribute instance-attribute ¤
output_padding: tuple[int, ...] = eqx.field(static=True) class-attribute instance-attribute ¤
dilation: tuple[int, ...] = eqx.field(static=True) class-attribute instance-attribute ¤
groups: int = eqx.field(static=True) class-attribute instance-attribute ¤
init(Spatial: AxisSelection, In: Axis, Out: Axis, kernel_size: int | Sequence[int], *, stride: int | Sequence[int] = 1, padding: int | Sequence[int] | Sequence[tuple[int, int]] = 0, output_padding: int | Sequence[int] = 0, dilation: int | Sequence[int] = 1, groups: int = 1, use_bias: bool = True, key: PRNGKeyArray) staticmethod ¤

Parameters:

  • Spatial ¤
    (AxisSelection) –

    Spatial dimensions

  • In ¤
    (Axis) –

    Input channels

  • Out ¤
    (Axis) –

    Output channels

  • kernel_size ¤
    (int | Sequence[int]) –

    Kernel size, can be a single number or a tuple

  • stride ¤
    (int | Sequence[int], default: 1 ) –

    Stride, can be a single number or a tuple

  • padding ¤
    (int | Sequence[int] | Sequence[tuple[int, int]], default: 0 ) –

    Padding, can be a single number or a tuple

  • output_padding ¤
    (int | Sequence[int], default: 0 ) –

    Output padding, can be a single number or a tuple

  • dilation ¤
    (int | Sequence[int], default: 1 ) –

    Dilation, can be a single number or a tuple

  • groups ¤
    (int, default: 1 ) –

    Number of groups to split the input channels into

  • use_bias ¤
    (bool, default: True ) –

    Whether to add on a bias after the convolution

  • key ¤
    (PRNGKeyArray) –

    Random key

Notes

Output padding is the amount of extra padding to add to the output. It is needed because, for transposed convolutions, the output size is not uniquely determined by the input size.
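
The ambiguity is easiest to see with the size arithmetic for one spatial dimension. The sketch below uses the formulas commonly quoted for convolutions with symmetric padding; that Haliax follows exactly this convention is an assumption, and both helper functions are hypothetical.

```python
def conv_out_size(n, kernel, stride=1, padding=0, dilation=1):
    """Output length of a 1-D convolution with symmetric padding."""
    return (n + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


def conv_transpose_out_size(n, kernel, stride=1, padding=0, output_padding=0, dilation=1):
    """Output length of a 1-D transposed convolution with symmetric padding."""
    return (n - 1) * stride - 2 * padding + dilation * (kernel - 1) + output_padding + 1


# With stride 2, inputs of length 7 and 8 both convolve down to length 3 ...
down_7 = conv_out_size(7, kernel=3, stride=2)
down_8 = conv_out_size(8, kernel=3, stride=2)

# ... so the transposed convolution needs output_padding to choose between them.
up_default = conv_transpose_out_size(3, kernel=3, stride=2)
up_padded = conv_transpose_out_size(3, kernel=3, stride=2, output_padding=1)
```

Here `up_default` recovers length 7 and `up_padded` recovers length 8, so output_padding selects which of the possible original sizes the transposed convolution reproduces.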

Pooling¤

As with convolutions, we don't distinguish between 1D, 2D, and 3D pooling, and instead have a single pooling operation for each of the kinds of reductions:

max_pool(Window: AxisSpec, inputs: NamedArray, stride: int | tuple[int, ...] | None = None, padding: Padding = DEFAULT_PADDING, use_ceil: bool = False) -> NamedArray ¤

Max pooling.

Parameters:

  • Window ¤
    (AxisSpec) –

    the size of the window to pool over

  • inputs ¤
    (NamedArray) –

    input data with dimensions (batch, window dims..., features).

  • stride ¤
    (int | tuple[int, ...] | None, default: None ) –

    a sequence of n integers, representing the inter-window stride (default: (1, ..., 1)).

  • padding ¤
    (Padding, default: DEFAULT_PADDING ) –

    either the string 'SAME', the string 'VALID', or a sequence of n (low, high) integer pairs that give the padding to apply before and after each spatial dimension.

Returns: The maximum value in each window slice.

mean_pool(Window: AxisSpec, inputs: NamedArray, stride: int | tuple[int, ...] | None = None, padding: Padding = DEFAULT_PADDING, *, use_ceil: bool = False, count_include_pad: bool = False) -> NamedArray ¤

Mean pooling.

Parameters:

  • Window ¤
    (AxisSpec) –

    the size of the window to pool over

  • inputs ¤
    (NamedArray) –

    input data with dimensions (batch, window dims..., features).

  • stride ¤
    (int | tuple[int, ...] | None, default: None ) –

    a sequence of n integers, representing the inter-window stride (default: (1, ..., 1)).

  • padding ¤
    (Padding, default: DEFAULT_PADDING ) –

    either the string 'SAME', the string 'VALID', or a sequence of n (low, high) integer pairs that give the padding to apply before and after each spatial dimension.

Returns: The mean value in each window slice.

min_pool(Window: AxisSpec, inputs: NamedArray, stride: int | tuple[int, ...] | None = None, padding: Padding = DEFAULT_PADDING, use_ceil: bool = False) -> NamedArray ¤

Min pooling.

Parameters:

  • Window ¤
    (AxisSpec) –

    the size of the window to pool over

  • inputs ¤
    (NamedArray) –

    input data with dimensions (batch, window dims..., features).

  • stride ¤
    (int | tuple[int, ...] | None, default: None ) –

    a sequence of n integers, representing the inter-window stride (default: (1, ..., 1)).

  • padding ¤
    (Padding, default: DEFAULT_PADDING ) –

    either the string 'SAME', the string 'VALID', or a sequence of n (low, high) integer pairs that give the padding to apply before and after each spatial dimension.

  • use_ceil ¤
    (bool, default: False ) –

    if True, will use ceil instead of floor to compute the output shape

Returns: The minimum value in each window slice.
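
The effect of use_ceil on the output shape can be sketched with the usual window-counting arithmetic for one spatial dimension. The helper `pool_out_size` is hypothetical, and the formula is an assumption based on the common convention; how the library fills a partial final window is not covered here.

```python
import math


def pool_out_size(n, window, stride=1, padding=(0, 0), use_ceil=False):
    """Number of windows along one spatial dimension of length n."""
    padded = n + padding[0] + padding[1]
    steps = (padded - window) / stride
    return (math.ceil(steps) if use_ceil else math.floor(steps)) + 1


floor_size = pool_out_size(10, window=3, stride=2)
ceil_size = pool_out_size(10, window=3, stride=2, use_ceil=True)
```

With floor, trailing elements that do not fill a whole window are dropped; with ceil, one extra (partial) window is counted.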

Attention¤

We don't provide an explicit attention module, but we do provide an attention function and several related functions:

dot_product_attention(KPos: AxisSelection, Key: AxisSelector, query: NamedArray, key: NamedArray, value: NamedArray, mask: NamedArray | None = None, bias: NamedArray | None = None, attention_dtype: jnp.dtype | None = None, precision: PrecisionLike = None) -> NamedArray ¤

NamedArray version of dot product attention. This can be multi-headed or not.

Parameters:

  • KPos (AxisSelection) –

    Axis of key sequence length

  • Key (AxisSelector) –

    Axis of head dimension

  • query (NamedArray) –

    NamedArray of shape {..., QPos, KeySize}

  • key (NamedArray) –

    NamedArray of shape {..., KPos, KeySize}

  • value (NamedArray) –

    NamedArray of shape {..., KPos, KeySize}

  • mask (NamedArray | None, default: None ) –

    Broadcast compatible with (KeySize, QPos, KPos). Should be boolean

  • bias (NamedArray | None, default: None ) –

    Broadcast compatible with (KeySize, QPos, KPos). Should be float

  • attention_dtype (jnp.dtype | None, default: None ) –

    Optional dtype to use for attention

  • precision (PrecisionLike, default: None ) –

    PrecisionLike for dot product. See precision argument to jax.lax.dot_general

Returns: NamedArray of shape (QPos, KeySize)

Mask and bias are given as separate arguments because they are often computed separately and have different shapes. For example, mask is frequently just a boolean array of shape (QPos, KPos), while bias is frequently a float array of shape (KeySize, QPos, KPos) or (KeySize, KPos).

dot_product_attention_weights(Key: AxisSelector, KPos: AxisSelection, query: NamedArray, key: NamedArray, mask: NamedArray | None = None, bias: NamedArray | None = None, attention_dtype: jnp.dtype | None = None, precision: PrecisionLike = None, scaling_factor: float | None = None) -> NamedArray ¤

NamedArray version of dot product attention. Computes the logits for the attention weights. Note that the "Pos" axis in query must be distinct from the "Pos" axis in key.

Parameters:

  • Key (AxisSelector) –

    Axis of head dimension

  • KPos (AxisSelection) –

    Axis or axes that are attended to

  • query (NamedArray) –

    NamedArray of shape (QPos, KeySize)

  • key (NamedArray) –

    NamedArray of shape (KPos, KeySize)

  • mask (NamedArray | None, default: None ) –

    Broadcast compatible with (KeySize, QPos, KPos). Should be boolean

  • bias (NamedArray | None, default: None ) –

    Broadcast compatible with (KeySize, QPos, KPos). Should be float

  • attention_dtype (jnp.dtype | None, default: None ) –

    Optional dtype to use for attention

  • precision (PrecisionLike, default: None ) –

    PrecisionLike for dot product. See precision argument to jax.lax.dot_general

  • scaling_factor (float | None, default: None ) –

    Optional scaling factor for the attention scores. Defaults to 1/sqrt(D), where D is the size of the Key axis

Returns: NamedArray of shape (QPos, KPos)
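
For orientation, here is a plain-Python sketch of single-head scaled dot-product attention on nested lists. It assumes the conventional formulation (scaled scores, optional additive bias, masking with a large negative value, then a softmax over key positions); the function names are chosen for the example and do not call Haliax.

```python
import math


def softmax(xs):
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def attention_weights(query, key, mask=None, bias=None, scaling_factor=None):
    """query: [QPos][D], key: [KPos][D]; returns weights of shape [QPos][KPos]."""
    d = len(query[0])
    scale = 1.0 / math.sqrt(d) if scaling_factor is None else scaling_factor
    weights = []
    for qi, q in enumerate(query):
        row = []
        for ki, k in enumerate(key):
            score = scale * sum(a * b for a, b in zip(q, k))
            if bias is not None:
                score += bias[qi][ki]
            if mask is not None and not mask[qi][ki]:
                score = -1e9  # masked positions get a large negative score
            row.append(score)
        weights.append(softmax(row))
    return weights


def attention(query, key, value, mask=None, bias=None):
    """Weighted sum of value rows, using the weights above."""
    w = attention_weights(query, key, mask, bias)
    n_out = len(value[0])
    return [
        [sum(wk * vrow[j] for wk, vrow in zip(row, value)) for j in range(n_out)]
        for row in w
    ]


q = [[1.0, 0.0], [0.0, 1.0]]
k = [[1.0, 0.0], [0.0, 1.0]]
v = [[10.0, 0.0], [0.0, 10.0]]
causal = [[True, False], [True, True]]
w = attention_weights(q, k, mask=causal)
out = attention(q, k, v, mask=causal)
```

The first query can only see the first key, so its weight row collapses onto that key and its output is the first value row.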

Masks¤

causal_mask(QPos: Axis, KPos: Axis, q_start: int | NamedArray = 0, k_start: int | NamedArray = 0) -> NamedArray ¤

Creates a materialized causal mask for attention.

Parameters:

  • QPos (Axis) –

    Axis of query sequence length

  • KPos (Axis) –

    Axis of key sequence length

Returns: NamedArray of shape (QPos, KPos)

prefix_lm_mask(QSeqLen: Axis, KSeqLen: Axis, prefix_len: int, q_start: int = 0, k_start: int = 0) -> NamedArray ¤

Mask for the PrefixLM objective: fully connected before prefix_len, then causal after.
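
Both mask shapes can be sketched with nested lists of booleans, where `mask[q][k]` says whether query position q may attend to key position k. The functions below are plain-Python stand-ins, and treating q_start and k_start as offsets added to the positions is an assumption of this sketch.

```python
def causal_mask(q_len, k_len, q_start=0, k_start=0):
    """mask[q][k] is True when key position k is not after query position q."""
    return [
        [(k + k_start) <= (q + q_start) for k in range(k_len)]
        for q in range(q_len)
    ]


def prefix_lm_mask(q_len, k_len, prefix_len, q_start=0, k_start=0):
    """Keys inside the prefix are always visible; everything else is causal."""
    causal = causal_mask(q_len, k_len, q_start, k_start)
    return [
        [causal[q][k] or (k + k_start) < prefix_len for k in range(k_len)]
        for q in range(q_len)
    ]


c = causal_mask(4, 4)
p = prefix_lm_mask(4, 4, prefix_len=2)
```

The only difference between `c` and `p` is in the top-left prefix block: with prefix_len=2, query 0 may also attend to key 1.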

combine_masks_and(mask1: NamedArray | None, mask2: NamedArray | None) -> NamedArray | None ¤

combine_masks_or(mask1: NamedArray | None, mask2: NamedArray | None) -> NamedArray | None ¤

forgetful_causal_mask(KPos: Axis, mask_prob: float, sample_prob: bool = True, *, key: PRNGKeyArray) -> NamedArray ¤

Forgetful Context Masking a la https://arxiv.org/abs/2210.13432. Randomly drops out positions from the key sequence. Reportedly better than normal attention dropout. Almost certainly faster.

You're always allowed to attend to the 0th position. (They say BOS token, but we don't always start with bos)

Parameters:

  • KPos (Axis) –

    Axis of key sequence length

  • mask_prob (float) –

    Probability of masking a position

  • sample_prob (bool, default: True ) –

    If True, sample the probability between 0 and mask_prob (this is what the paper does)
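
A minimal sketch of the idea, using Python's `random` module instead of a JAX PRNG key. The function `forgetful_mask` is hypothetical, and drawing the drop probability uniformly from [0, mask_prob] is an assumption about what "sample the prob" means.

```python
import random


def forgetful_mask(k_len, mask_prob, sample_prob=True, seed=0):
    """Return a keep-mask over key positions; True means the position stays visible."""
    rng = random.Random(seed)
    if sample_prob:
        # Draw the actual drop probability uniformly from [0, mask_prob].
        mask_prob = rng.uniform(0.0, mask_prob)
    keep = [rng.random() >= mask_prob for _ in range(k_len)]
    keep[0] = True  # position 0 is always attendable
    return keep


m = forgetful_mask(8, mask_prob=0.5)
all_dropped = forgetful_mask(8, mask_prob=1.0, sample_prob=False)
```

Even with mask_prob=1.0 and no sampling, position 0 stays visible, so every query always has at least one key to attend to.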

Biases¤

mask_to_bias(mask: NamedArray, mask_value: float = -1000000000.0) -> NamedArray ¤

alibi_attention_bias(Heads: Axis, KPos: Axis, bias_max: float = 8, dtype=jnp.float32) -> NamedArray ¤

Creates an attention bias for alibi attention.

Parameters:

  • Heads (Axis) –

    Axis of heads

  • KPos (Axis) –

    Axis of (key) sequence length

Returns: NamedArray of shape (Heads, KPos)

Functions¤

These functions wrap the equivalent in jax.nn:

relu(a: A) -> A ¤

relu6(a: A) -> A ¤

sigmoid(a: A) -> A ¤

softplus(a: A) -> A ¤

soft_sign(a: A) -> A ¤

silu(a: A) -> A ¤

swish(a: A) -> A ¤

log_sigmoid(a: A) -> A ¤

leaky_relu(a: A) -> A ¤

hard_sigmoid(a: A) -> A ¤

hard_silu(a: A) -> A ¤

hard_swish(a: A) -> A ¤

hard_tanh(a: A) -> A ¤

elu(a: A) -> A ¤

celu(a: A) -> A ¤

selu(a: A) -> A ¤

gelu(a: A, approximate: bool = True) -> A ¤

quick_gelu(x) ¤

glu(x: NamedArray, axis: Axis) -> NamedArray ¤

logsumexp(a: A, axis: AxisSelection | None = None) -> A ¤

log_softmax(a: A, axis: AxisSelection | None = None) -> A ¤

softmax(a: A, axis: AxisSelection | None = None) -> A ¤

standardize(x: NamedArray, axis: AxisSpec, *, mean: NamedArray | None = None, variance: NamedArray | None = None, epsilon: float = 1e-05, where: NamedArray | None = None) -> NamedArray ¤

Analogous to jax.nn.standardize, but with support for NamedArrays.

one_hot(x: NamedArray | int, class_axis: Axis, *, dtype=None) -> NamedArray ¤

Convert an integer to a one-hot vector. This is basically a generalization of jax.nn.one_hot for NamedArrays.

Parameters:

  • x ¤
    (NamedArray | int) –

    the integer or NamedArray of integers to convert

  • class_axis ¤
    (Axis) –

    the axis to convert to one-hot

  • dtype ¤

    the dtype of the result. If None, it will default to jax's default (currently float_)

Returns: a NamedArray with the same axes as x plus class_axis, with 1s in the appropriate places
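
The "same axes as x plus class_axis" behavior can be sketched with nested lists, where the class axis becomes a new innermost dimension. The function below is a plain-Python stand-in, not the Haliax implementation.

```python
def one_hot(x, num_classes):
    """Turn an int, or a nested list of ints, into one-hot vectors along a new last axis."""
    if isinstance(x, int):
        return [1.0 if i == x else 0.0 for i in range(num_classes)]
    return [one_hot(v, num_classes) for v in x]


single = one_hot(2, 4)
batch = one_hot([0, 3], 4)
```

A bare integer yields a single vector of length `num_classes`, while a list of integers yields one such vector per element.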

Loss Functions¤

cross_entropy_loss(logits: NamedArray, Label: AxisSelector, targets: NamedArray, reduction: ReductionFunction | None | Unspecified = UNSPECIFIED, where: NamedArray | None = None, reduction_axis: AxisSelection | None = None) -> jnp.ndarray | NamedArray ¤

cross_entropy_loss(logits: NamedArray, Label: AxisSelector, targets: NamedArray, reduction: ReductionFunction | None | Unspecified = UNSPECIFIED, where: NamedArray | None = None, reduction_axis: None = None) -> jnp.ndarray | NamedArray
cross_entropy_loss(logits: NamedArray, Label: AxisSelector, targets: NamedArray, reduction: ReductionFunction | None | Unspecified = UNSPECIFIED, where: NamedArray | None = None, reduction_axis: AxisSelection = ...) -> NamedArray

cross_entropy_loss_and_log_normalizers(pred_y: NamedArray, Label: AxisSelector, target_y: NamedArray) -> tuple[NamedArray, NamedArray] ¤

Compute the cross entropy loss and log normalizers for a batch of predictions and targets.

Parameters:

  • pred_y (NamedArray) –

    a NamedArray with the Label axis (and possibly others, e.g. batch and seq) containing the logits

  • Label (AxisSelector) –

    the Label axis

  • target_y (NamedArray) –

    a NamedArray with the Label axis (and possibly others) containing the targets

Returns: tuple of two named arrays, with "per position" losses and log normalizers
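
For a single position, the two returned quantities can be sketched in plain Python. This assumes that the targets form a distribution over the Label axis (for example a one-hot vector) and that the log normalizer is the log-sum-exp of the logits; the function name is chosen for the example.

```python
import math


def logsumexp(xs):
    m = max(xs)  # subtract the max for numerical stability
    return m + math.log(sum(math.exp(x - m) for x in xs))


def cross_entropy_and_log_normalizer(logits, targets):
    """logits and targets are lists over the Label axis for a single position."""
    log_z = logsumexp(logits)
    # loss = -sum_i target_i * log_softmax(logits)_i
    loss = sum(t * (log_z - l) for t, l in zip(targets, logits))
    return loss, log_z


loss, log_z = cross_entropy_and_log_normalizer(
    [0.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]
)
```

With uniform logits over four labels, both the loss and the log normalizer equal log(4). Returning the log normalizer separately is convenient for auxiliary terms such as a z-loss penalty.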