Quantized Training
Warning
FP8 and Int8 training in Haliax is currently experimental and may change in the future.
Haliax supports training with FP8 and Int8. This is useful on hardware that is optimized for FP8 or Int8, such as the H100 (FP8), the A100 (Int8), and TPU v5 and newer (Int8).
TL;DR
Using FP8 with Haliax is actually pretty straightforward. To enable FP8, do this:
```python
import equinox as eqx
import haliax.quantization as haxq

# module, opt (an optax optimizer), loss_fn and data are assumed to be defined already.

# setup
module = haxq.quantize_linear_layers(module, haxq.QuantizationConfig(fp8=True))

# if using optax: initialize the optimizer state on the trainable part only.
# This saves a tiny amount of memory, so you can skip it if you want.
_, trainable_module = haxq.partition_for_grad_overwrite(module)
opt_state = opt.init(trainable_module)

# train step
grads = eqx.filter_grad(loss_fn)(module, data)
overwrite, grads = haxq.partition_for_grad_overwrite(grads)
updates, opt_state = opt.update(grads, opt_state, params=module)  # or however you update your optimizer
module = haxq.apply_updates(module, updates, overwrite)
```
And train your model like normal.
Similarly, you can use Int8 by setting int8=True in the QuantizationConfig object.
What is FP8?
FP8 refers to 8-bit floating point numbers. FP8 has massively reduced precision compared to the 32-bit or 16-bit floating point numbers that are typically used in deep learning: there are only 256 possible bit patterns in FP8, compared to (almost) 2^32 in 32-bit and 2^16 in 16-bit. However, FP8 is still useful for training deep learning models, and it can massively accelerate training on hardware that is optimized for it: the H100 has 2x as many FP8 FLOPS as FP16 FLOPS and almost 60x(!) as many as FP32 FLOPS.
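"FP8" is really a family of formats. JAX exposes the two common ones as `jnp.float8_e4m3fn` (4 exponent bits, 3 mantissa bits: more precision) and `jnp.float8_e5m2` (5 exponent bits, 2 mantissa bits: more range); a common recipe uses E4M3 for forward values and E5M2 for gradients. The sketch below decodes all 256 bit patterns of each format and reports how many are finite and how large the largest one is. The helper `decode_all_bit_patterns` is hypothetical and only for illustration.

```python
import jax.numpy as jnp
import numpy as np


def decode_all_bit_patterns(dtype):
    """Reinterpret each of the 256 possible bytes as `dtype` and return them as float32."""
    bits = np.arange(256, dtype=np.uint8)
    return bits.view(np.dtype(dtype)).astype(np.float32)


for dtype in (jnp.float8_e4m3fn, jnp.float8_e5m2):
    values = decode_all_bit_patterns(dtype)
    n_finite = int(np.isfinite(values).sum())  # NaN (and, for E5M2, inf) patterns are excluded
    print(f"{np.dtype(dtype).name}: {n_finite} finite values, max {float(jnp.finfo(dtype).max)}")
# float8_e4m3fn: 254 finite values, max 448.0
# float8_e5m2: 248 finite values, max 57344.0
```

The tiny maximum of E4M3 (448) is why FP8 training always pairs the cast with a scaling factor.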
Haliax's FP8 support is currently designed to maximize throughput on FP8-capable devices (currently the H100) rather than to save memory. In particular, it is not designed to quantize a model to FP8 for deployment, though adding that should not be hard, especially for models trained with Haliax's FP8 support. We would be happy to accept contributions that add this functionality, and are happy to work with you on them.
See this FP8 Primer for more information on FP8.
What is Int8?
Int8 refers to 8-bit integers. Int8 has the same number of bits as FP8, but the interpretation is different: instead of exponentially spaced numbers, Int8 has linearly spaced numbers.
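To make "exponentially spaced" versus "linearly spaced" concrete, the sketch below lists the positive values representable in E4M3 FP8 and the positive Int8 codes (mapped to reals by a single scale, assumed here to be 1.0), then looks at the gaps between neighbouring values. The helper names are hypothetical.

```python
import jax.numpy as jnp
import numpy as np


def positive_fp8_values(dtype=jnp.float8_e4m3fn):
    """All distinct positive finite values of an FP8 dtype, sorted ascending."""
    values = np.arange(256, dtype=np.uint8).view(np.dtype(dtype)).astype(np.float32)
    return np.unique(values[np.isfinite(values) & (values > 0)])


def positive_int8_values(scale=1.0):
    """Positive Int8 codes 1..127, mapped back to real values by one shared scale."""
    return np.arange(1, 128, dtype=np.float32) * scale


fp8_gaps = np.diff(positive_fp8_values())
int8_gaps = np.diff(positive_int8_values())
print("FP8 (E4M3) gaps:", fp8_gaps.min(), "to", fp8_gaps.max())  # gaps grow with magnitude
print("Int8 gaps:", np.unique(int8_gaps))  # a single constant step
```

FP8 spends its codes finely near zero and coarsely at large magnitudes, while Int8 spreads them evenly across its range.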
In Haliax, we support Int8 training through Google's AQT library. AQT ("Accurate Quantized Training") is a library for training models with quantization-aware training (QAT).
How to use FP8 or Int8 in Haliax
To use quantized training in Haliax, you need to do two things:
- Enable FP8 (or Int8) for the layers you want
- Modify your training step to be compatible

Each of these is just a couple of lines of code.
```python
import equinox as eqx
import haliax as hax
import haliax.quantization as haxq
import jax
import optax
from haliax.quantization import QuantizationConfig

Batch = hax.Axis("Batch", 8)
In = hax.Axis("In", 32)
Mid = hax.Axis("Mid", 128)
Out = hax.Axis("Out", 16)


class MyModule(eqx.Module):
    up_proj: hax.nn.Linear
    down_proj: hax.nn.Linear

    @staticmethod
    def init(*, key):
        k_up, k_down = jax.random.split(key)
        return MyModule(
            up_proj=hax.nn.Linear.init(In, Mid, key=k_up),
            down_proj=hax.nn.Linear.init(Mid, Out, key=k_down),
        )

    def __call__(self, x):
        x = self.up_proj(x)
        x = hax.nn.relu(x)
        x = self.down_proj(x)
        return x


module = MyModule.init(key=jax.random.PRNGKey(0))

# Enable FP8 for all linear layers...
module = haxq.quantize_linear_layers(module, QuantizationConfig(fp8=True))
# ...or, instead (on the unquantized module), enable FP8 only for specific layers:
# config = QuantizationConfig(targets=["up_proj"], fp8=True)
# module = haxq.quantize_linear_layers(module, config)

# A toy loss, data and optimizer so the train step has something to run on
def loss_fn(module, data):
    out = module(data)
    return hax.mean(out * out).array

data = hax.random.normal(jax.random.PRNGKey(1), (Batch, In))
opt = optax.adam(1e-3)
_, trainable_module = haxq.partition_for_grad_overwrite(module)
opt_state = opt.init(trainable_module)

# Train step
grads = eqx.filter_grad(loss_fn)(module, data)
overwrite, grads = haxq.partition_for_grad_overwrite(grads)
updates, opt_state = opt.update(grads, opt_state, params=module)  # or however you update your optimizer
module = haxq.apply_updates(module, updates, overwrite)
```
That's it! Just a few lines of code to enable FP8. The quantize_linear_layers function transforms your module to use quantization-aware training for its linear layers (or a subset of them), and together haliax.quantization.partition_for_grad_overwrite and haliax.quantization.apply_updates apply the updates to the module in a way that is compatible with FP8.
How FP8 works
For an overview of FP8, see the FP8 Primer. You don't need to understand it, though: Haliax's FP8 integration is more or less plug and play, as shown above. The implementation of FP8 in Haliax is more or less a straightforward port (including some copy and paste) of the FP8 implementation in Flax.
FP8 in JAX (as well as Int8) is typically implemented using "dot_general injection", where you pass a custom implementation of dot_general to functions and modules like haliax.dot and haliax.nn.Linear.
The dot_general for FP8 is implemented by scaling the inputs, projecting them to FP8, performing the computation in FP8, and then scaling the result back to the original precision.
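A stateless sketch of such a dot_general follows. It makes two simplifying assumptions: the scale is computed from the current call's maximum absolute value (Haliax's Fp8DotGeneralOp instead keeps an amax history), and it uses a quantize-then-dequantize ("fake quantization") pattern so it runs on any backend. Because dot_general is bilinear, undoing the scale before the product gives the same result as scaling the output afterwards; on FP8-capable GPUs, XLA is expected to rewrite this pattern into a true FP8 matmul. `quantize_dequantize` and `fp8_dot_general_sketch` are hypothetical names, not Haliax APIs.

```python
import jax
import jax.numpy as jnp

FP8 = jnp.float8_e4m3fn
FP8_MAX = float(jnp.finfo(FP8).max)  # 448.0 for E4M3


def quantize_dequantize(x, scale):
    """Scale x into the FP8 range, round it to FP8, then undo the scaling."""
    scaled = jnp.clip(x / scale, -FP8_MAX, FP8_MAX)  # clip: E4M3 overflows to NaN
    return scaled.astype(FP8).astype(x.dtype) * scale


def fp8_dot_general_sketch(lhs, rhs, dimension_numbers, precision=None, preferred_element_type=None):
    # Hypothetical helper: "current" scaling from this call's own max |value|,
    # not the delayed, history-based scaling that Haliax uses.
    lhs_scale = jnp.maximum(jnp.max(jnp.abs(lhs)), 1e-12) / FP8_MAX
    rhs_scale = jnp.maximum(jnp.max(jnp.abs(rhs)), 1e-12) / FP8_MAX
    return jax.lax.dot_general(
        quantize_dequantize(lhs, lhs_scale),
        quantize_dequantize(rhs, rhs_scale),
        dimension_numbers,
        precision=precision,
        preferred_element_type=preferred_element_type,
    )


x = jax.random.normal(jax.random.PRNGKey(0), (4, 8))
w = jax.random.normal(jax.random.PRNGKey(1), (8, 3))
contract = (((1,), (0,)), ((), ()))  # contract x's axis 1 with w's axis 0
y = fp8_dot_general_sketch(x, w, contract)
```

The result tracks the full-precision `x @ w` up to FP8 rounding error (E4M3 keeps only 3 mantissa bits).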
The subtle part of FP8 is that the scaling factor is state that is updated based on a history of the maximum absolute values of the inputs to the layer
(as well as of the gradients coming in from the backward pass). This means that the FP8 dot_general needs to maintain state.
In Equinox, this means that the dot_general is actually a Module that packages together the state and the
computation. (Unlike equinox.nn.StatefulLayer which returns a state object you pass back into the module, the FP8 dot_general
module hijacks the gradient computation to update its state. This is necessary because the FP8 scaling factors
depend on the gradients.)
This works by "hijacking" the gradient computation. When you call eqx.filter_grad(loss_fn)(module, data),
you get the gradients as normal, but you also get the updated state of the FP8 dot_general module, stored in the gradient.
This updated state needs to directly replace the state in the module (rather than be used for a gradient step), which is
why you need to use haliax.quantization.partition_for_grad_overwrite and haliax.quantization.apply_updates.
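The core trick can be sketched with plain `jax.custom_vjp`: a function that is the identity on its input but, in the backward pass, returns the *new value* of a state array in place of that array's gradient. This is a minimal sketch under that assumption, not Haliax's actual code; `record_grad_amax` and the toy `loss` are hypothetical.

```python
import jax
import jax.numpy as jnp


@jax.custom_vjp
def record_grad_amax(x, amax_history):
    """Identity on x; amax_history rides along so backward can "return" a new value for it."""
    return x


def _record_fwd(x, amax_history):
    return x, amax_history  # keep the history as the residual for the backward pass


def _record_bwd(amax_history, g):
    # The "gradient" for amax_history is really its next value: shift the history
    # by one slot and put the max |g| seen in this backward pass at the front.
    new_history = jnp.roll(amax_history, 1).at[0].set(jnp.max(jnp.abs(g)))
    return g, new_history


record_grad_amax.defvjp(_record_fwd, _record_bwd)


def loss(x, amax_history):
    return jnp.sum(3.0 * record_grad_amax(x, amax_history))


x = jnp.ones(4)
history = jnp.zeros(5)
grad_x, new_history = jax.grad(loss, argnums=(0, 1))(x, history)
# grad_x is an ordinary gradient; new_history must *replace* history, not feed an optimizer step.
history = new_history
```

Subtracting a learning-rate multiple of `new_history` from `history` would be meaningless, which is why such leaves must be split off and overwritten rather than handed to the optimizer.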
The FP8 dot_general module is implemented in haliax.quantization.Fp8DotGeneralOp. It's actually not that complicated:

1) It holds a scaling factor and a history of maximum values for each of (lhs, rhs, output) and updates them based on the gradients.
2) When invoked, it scales the inputs, projects them to FP8, performs the computation, and scales the result back to the original precision. It remembers the maximum absolute value for each of the inputs.
3) For the gradients, it scales the gradients, projects them to FP8, does the backward computation, and scales the gradients back to the original precision. It remembers the maximum absolute value for the incoming gradient and stores it in the gradient.
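Point 1 describes "delayed scaling": the scale used at a given step comes from the maxima recorded on previous steps, and the current maximum is only appended afterwards. The sketch below is modeled loosely on that recipe (as in Flax's FP8 code, which Haliax ports); the function names are hypothetical, and the real implementation may differ in details such as safety margins or update order.

```python
import jax.numpy as jnp

E4M3_MAX = float(jnp.finfo(jnp.float8_e4m3fn).max)


def update_amax_history(x, amax_history):
    """Shift the history one slot and record the current max |x| at index 0."""
    return jnp.roll(amax_history, 1).at[0].set(jnp.max(jnp.abs(x)))


def scale_from_history(amax_history, fp8_max=E4M3_MAX):
    """Pick a scale so the largest recently seen value maps to the edge of the FP8 range."""
    amax = jnp.max(amax_history)
    return jnp.where(amax > 0, amax / fp8_max, 1.0)  # fall back to 1.0 before any history exists


history = jnp.zeros(4)  # amax_history_length defaults to 1024 in the config; 4 keeps this small
for step_input in (jnp.array([1.0, -2.0]), jnp.array([0.5])):
    scale = scale_from_history(history)  # this step's scale comes from *past* steps
    history = update_amax_history(step_input, history)
```

Dividing by this scale before casting (as in `x / scale`) maps the recent maximum to 448, so values rarely overflow while small values keep as much precision as possible.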
How Int8 works
Int8 works the same way in principle, though the details differ. AQT is a much more flexible library than the FP8 implementation, but it can be a bit more finicky. We use AQT directly, and we recommend looking at the AQT documentation for more information on how it works.
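For intuition only, here is a plain-JAX sketch of an Int8 dot_general using symmetric per-tensor quantization. It does **not** use AQT and does not reproduce AQT's API or its calibration options; `quantize_int8` and `int8_dot_general_sketch` are hypothetical names chosen for illustration.

```python
import jax
import jax.numpy as jnp


def quantize_int8(x):
    """Symmetric per-tensor quantization: max |x| maps to 127, zero maps to 0."""
    scale = jnp.maximum(jnp.max(jnp.abs(x)), 1e-12) / 127.0
    q = jnp.clip(jnp.round(x / scale), -127, 127).astype(jnp.int8)
    return q, scale


def int8_dot_general_sketch(lhs, rhs, dimension_numbers):
    lhs_q, lhs_scale = quantize_int8(lhs)
    rhs_q, rhs_scale = quantize_int8(rhs)
    # Integer matmul, accumulating int8 x int8 products in int32 to avoid overflow.
    acc = jax.lax.dot_general(lhs_q, rhs_q, dimension_numbers, preferred_element_type=jnp.int32)
    # The scales factor out of the bilinear product, so apply them once at the end.
    return acc.astype(jnp.float32) * (lhs_scale * rhs_scale)


x = jax.random.normal(jax.random.PRNGKey(0), (4, 8))
w = jax.random.normal(jax.random.PRNGKey(1), (8, 3))
y = int8_dot_general_sketch(x, w, (((1,), (0,)), ((), ())))
```

Because Int8 codes are evenly spaced, a single per-tensor scale works well when values are spread fairly uniformly; real libraries such as AQT add finer-grained scales and other options on top of this idea.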
API Reference
Functions
quantize_linear_layers(tree: T, config: QuantizationConfig) -> T
Converts a module tree to use FP8/INT8 quantization.
partition_for_grad_overwrite(grad: T) -> tuple[T, T]
This function partitions the state of a module into two parts: one that will be overwritten by the gradient and one that will be updated by the gradient. haliax.quantization.apply_updates uses this split to determine which state should be updated and which should be overwritten. The usual pattern is something like:
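```python
import equinox as eqx
from haliax.quantization import apply_updates, partition_for_grad_overwrite

grads = eqx.filter_grad(loss_fn)(model)
overwrites, grads = partition_for_grad_overwrite(grads)
updates, opt_state = optimizer.update(grads, opt_state, params=model)
model = apply_updates(model, updates, overwrites)
```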
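The leaf-wise rule that apply_updates follows can be mimicked on a plain dict to see what happens to each kind of leaf. This is a sketch of the described semantics, not Haliax's source: `apply_updates_sketch` is hypothetical, and the real partition selects overwrite leaves via `OverwriteWithGradient` markers rather than dict keys.

```python
import jax
import jax.numpy as jnp


def apply_updates_sketch(tree, updates, overwrites):
    """Leaf by leaf: take the overwrite if present, else add the update, else keep the leaf."""
    def apply_leaf(leaf, update, overwrite):
        if overwrite is not None:
            return overwrite
        if update is None:
            return leaf
        return leaf + update

    # None entries in updates/overwrites are passed through to apply_leaf as-is,
    # since `tree` (the first argument) defines the leaf structure.
    return jax.tree_util.tree_map(apply_leaf, tree, updates, overwrites)


params = {"weight": jnp.array([1.0, 2.0]), "amax_history": jnp.zeros(3)}
# After partitioning: the optimizer produced an update for "weight" only,
# while the FP8 state came back from the backward pass as an overwrite.
updates = {"weight": jnp.array([-0.1, -0.1]), "amax_history": None}
overwrites = {"weight": None, "amax_history": jnp.array([5.0, 0.0, 0.0])}
new_params = apply_updates_sketch(params, updates, overwrites)
```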
apply_updates(tree, updates, overwrites)
A jax.tree_util.tree_map-broadcasted version of
```python
if overwrite is not None:
    return overwrite
if update is None:
    return model
else:
    return model + update
```
Interfaces
DotGeneralOp

OverwriteWithGradient
Bases: Module
Sometimes there is state that must be computed in the backward pass which we want to persist for subsequent passes. Typically, we see this with quantization, particularly FP8. This module is a marker that indicates to haliax.quantization.apply_updates that the gradient should be used to overwrite the state rather than be added to it.
Typically this is used in conjunction with haliax.quantization.partition_for_grad_overwrite and the types are kinds of DotGeneralOp.
Modules
DefaultDotGeneralOp
Bases: Module
The default dot_general function that is used by the Linear module. This is the
standard JAX jax.lax.dot_general function.
Notes
We could have used jax.lax.dot_general directly, but we use this class so that we don't
unnecessarily have functions as leaves in the module tree.
Methods:
- init

init() (staticmethod)

Fp8DotGeneralOp
Bases: OverwriteWithGradient
Methods:
- init

Attributes:
- input_scale (ndarray)
- output_grad_scale (ndarray)
- kernel_scale (ndarray)
- input_amax_history (ndarray)
- output_grad_amax_history (ndarray)
- kernel_amax_history (ndarray)
- compute_dtype (DTypeLike | None)

input_scale: jnp.ndarray (instance-attribute)
output_grad_scale: jnp.ndarray (instance-attribute)
kernel_scale: jnp.ndarray (instance-attribute)
input_amax_history: jnp.ndarray (instance-attribute)
output_grad_amax_history: jnp.ndarray (instance-attribute)
kernel_amax_history: jnp.ndarray (instance-attribute)
compute_dtype: DTypeLike | None = eqx.field(static=True) (class-attribute, instance-attribute)

init(amax_history_length: int = 1024, compute_dtype: DTypeLike = None) (classmethod)
Int8DotGeneralOp

Configuration
QuantizationConfig(targets: list[str] | str | None = None, amax_history_length: int = 1024, compute_dtype: DTypeLike = None, fp8: bool = False, int8: bool = False) (dataclass)
Attributes:
- targets (list[str] | str | None) – If provided, only modules with names in this list will be quantized. If a single string is given, it is treated as a regex.
- amax_history_length (int)
- compute_dtype (DTypeLike)
- fp8 (bool)
- int8 (bool)

targets: list[str] | str | None = dataclasses.field(default=None) (class-attribute, instance-attribute)

If provided, only modules with names in this list will be quantized. If a single string is given, it is treated as a regex.