Improve JSA library README, describing the basic principles of the library.

Mini-example of a JAX training loop using `autoscale`.
balancap committed Nov 29, 2023
1 parent 4c656ac · commit 35f58bc
Showing 1 changed file with 16 additions and 3 deletions.
README.md (19 changes: 16 additions & 3 deletions)
@@ -1,17 +1,30 @@
 # JAX Scaled Arithmetics
 
-JAX Scaled Arithmetics is a thin library implementing numerically stable scaled arithmetics, allowing easy training and inference of
+**JAX Scaled Arithmetics** is a thin library implementing numerically stable scaled arithmetics, allowing easy training and inference of
 deep neural networks in low precision (BF16, FP16, FP8).
 
-* [Draft JSA design document](docs/design.md);
 
 
 ## Installation
 
-Local git repository install:
+The JSA library can easily be installed in a Python virtual environment:
 ```bash
 git clone git@github.com:graphcore-research/jax-scaled-arithmetics.git
 pip install -e ./
 ```
+The main dependencies are the `numpy`, `jax` and `chex` libraries.
+
+**Note:** JSA is compatible with [experimental JAX on IPU](https://github.com/graphcore-research/jax-experimental), which can be installed in a Graphcore Poplar Python environment:
+```bash
+pip install jax==0.3.16+ipu jaxlib==0.3.15+ipu.sdk320 -f https://graphcore-research.github.io/jax-experimental/wheels.html
+```
+
+## Documentation
+
+* [Draft Scaled Arithmetics design document](docs/design.md);
+
 ## Development
 
 Running `pre-commit` and `pytest`:
 ```bash
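The README description above motivates the library by numerical stability in low precision. As a concrete, self-contained illustration of the underlying principle (plain JAX/NumPy only, not the JSA API; see the linked design document for the library's actual representation): FP16 underflows on small values, while representing a tensor as a payload times a separate scale keeps the payload in a representable range.

```python
import jax.numpy as jnp

# FP16 dynamic range is tiny: squaring a small value underflows to zero,
# since 1e-8 is below the FP16 subnormal minimum (~6e-8).
g = jnp.float16(1e-4)
print(g * g)  # -> 0.0

# Scaled representation: value = data * scale, with `data` kept near 1.
data, scale = jnp.float16(1.0), 1e-4  # together they represent 1e-4
# Operate on payloads and scales separately; the payload stays representable.
out_data, out_scale = data * data, scale * scale
print(out_data, out_scale)  # -> 1.0, 1e-8 (the exact value, recoverable)
```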

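The commit message also mentions a mini-example of a JAX training loop using `autoscale`; that part of the diff is truncated above. As a hedged sketch only: the snippet below assumes `jsa.autoscale` is a function decorator that interprets the computation with scale propagation, and `jsa.as_scaled_array` converts a regular array into a scaled one. Both names come from the commit message and the library's import path; the model, loss, and exact signatures are illustrative assumptions, not the README's actual example.

```python
import jax
import jax.numpy as jnp
import jax_scaled_arithmetics as jsa  # assumed import name, per the repo URL

def loss_fn(params, x, y):
    # Toy linear model with a least-squares loss.
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)

@jsa.autoscale  # assumed API: run ops on scaled arrays, propagating scales
def train_step(params, x, y):
    grads = jax.grad(loss_fn)(params, x, y)
    return jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)

params = {"w": jnp.zeros((4, 1), jnp.float16), "b": jnp.zeros((1,), jnp.float16)}
x = jnp.ones((8, 4), jnp.float16)
y = jnp.ones((8, 1), jnp.float16)

# Assumed helper: wrap inputs as (data, scale) pairs before the scaled step.
params = jax.tree_util.tree_map(jsa.as_scaled_array, params)
x, y = jsa.as_scaled_array(x), jsa.as_scaled_array(y)
params = train_step(params, x, y)
```

Everything here runs as ordinary JAX if the two `jsa` calls are replaced by identities; the point is only where the scaled-array conversion and the `autoscale` transform would sit in a training loop.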