Improve JSA library README, describing the basic principles of the library. (#41)

Mini-example of a JAX training loop using `autoscale`.
balancap authored Nov 29, 2023
1 parent 4c656ac commit 9483f5b
Showing 1 changed file with 52 additions and 3 deletions: README.md

# JAX Scaled Arithmetics

**JAX Scaled Arithmetics** is a thin library implementing numerically stable scaled arithmetics, allowing easy training and inference of
deep neural networks in low precision (BF16, FP16, FP8).

Loss scaling, tensor scaling and block scaling have been widely used in the deep learning literature to unlock training and inference at lower precision. These works have usually focused on ad-hoc approaches built around the scaling of matmuls (and sometimes reduction operations). The JSA library takes a more systematic approach by transforming the full computational graph into a `ScaledArray` graph, i.e. a graph in which every operation takes `ScaledArray` inputs and returns `ScaledArray` outputs, where `ScaledArray` is a simple datastructure:
```python
@dataclass
class ScaledArray:
    data: Array
    scale: Array

    def to_array(self) -> Array:
        return self.data * self.scale
```
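
For intuition, here is a minimal (hypothetical) usage of the `ScaledArray` structure above, with made-up values and assuming the standard `dataclasses` and `jax.numpy` imports:
```python
import jax.numpy as jnp

# Hypothetical values: a tensor with magnitudes around 1e-4 is stored as
# roughly unit-magnitude `data` together with a single `scale` factor.
a = ScaledArray(data=jnp.array([1.0, -2.0, 0.5]), scale=jnp.float32(1e-4))

# `to_array` recovers the usual dense representation, i.e. data * scale.
print(a.to_array())  # ~ [1e-04, -2e-04, 5e-05]
```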

A typical JAX training loop requires just a few modifications to take advantage of `autoscale`:
```python
from jax import grad, jit

import jax_scaled_arithmetics as jsa

params = jsa.as_scaled_array(params)

@jit
@jsa.autoscale
def update(params, batch):
    grads = grad(loss)(params, batch)
    return opt_update(params, grads)

for batch in batches:
    batch = jsa.as_scaled_array(batch)
    params = update(params, batch)
```
In other words: the model parameters and each micro-batch are converted to `ScaledArray` objects, and the `jsa.autoscale` decorator transforms the computational graph into a scaled arithmetics graph (see the [MNIST examples](./experiments/mnist/) for more details).
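
The `loss`, `opt_update` and `batches` objects referenced above are ordinary JAX/user code rather than part of JSA; a minimal hypothetical version, using a least-squares loss and a plain SGD update, could look like this:
```python
import jax
import jax.numpy as jnp

def loss(params, batch):
    # Hypothetical least-squares regression loss on (inputs, targets) batches.
    inputs, targets = batch
    preds = inputs @ params["w"] + params["b"]
    return jnp.mean((preds - targets) ** 2)

def opt_update(params, grads, lr=1e-3):
    # Plain SGD step, applied leaf-by-leaf over the parameter pytree.
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
```
Note that this model and optimizer code contains no JSA-specific changes; only the `jsa.as_scaled_array` conversions and the `jsa.autoscale` decorator are added in the loop above.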

There are multiple benefits to this systematic approach:

* The model definition is unchanged (e.g. compared to unit scaling);
* The dynamic rescaling logic can be moved to the optimizer update phase, simplifying the model definition and state;
* Clean implementation as a JAX interpreter, similar to `grad`, `vmap`, ...
* Generalizes to different quantization paradigms: `int8` per-channel, `MX` block scaling, per-tensor scaling;
* Potentially more stable FP16 training;
* Potential out-of-the-box FP8 support.


## Installation

The JSA library can be installed in a Python virtual environment from a local clone of the git repository:
```bash
git clone git@github.com:graphcore-research/jax-scaled-arithmetics.git
pip install -e ./
```
The main dependencies are the `numpy`, `jax` and `chex` libraries.

**Note:** it is compatible with [experimental JAX on IPU](https://github.com/graphcore-research/jax-experimental), which can be installed in a Graphcore Poplar Python environment:
```bash
pip install jax==0.3.16+ipu jaxlib==0.3.15+ipu.sdk320 -f https://graphcore-research.github.io/jax-experimental/wheels.html
```

## Documentation

* [Draft Scaled Arithmetics design document](docs/design.md);

## Development

Running `pre-commit` and `pytest`:
```bash
pre-commit run --all-files
pytest
```
