Official Implementation of Latent Attention for Linear Time Transformers
Starting from scratch (create a virtual environment, install dependencies, activate the venv, and set path variables):
$ source ./scripts/bootstrap
If the venv is already created (just activate the venv and set path variables):
$ source ./scripts/init
Download data:
$ scripts/download_data
To run an experiment, just run the associated bash file:
bash experiments/bash/run_lm.sh
You can also run it from any directory:
bash run_lm.sh
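For example, a launch from the script's own folder (one possible directory; this assumes the path variables from scripts/init are set):
cd experiments/bash
bash run_lm.sh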
The code assumes a local directory named data in the repository root; input datasets are read from it and outputs are saved there.
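For example, a minimal sketch of that layout (the logs_latte subfolder is only required for the enwik8 run described below):
$ mkdir -p data             # input datasets live here
$ mkdir -p data/logs_latte  # logs/checkpoints for the enwik8 run are written here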
Some datasets, like OpenWebText and WikiText-103, are downloaded automatically, while others need to be downloaded manually:
$ ./scripts/download_lra
This is a quick guide to running enwik8. To set up the software:
- Run:
source scripts/bootstrap
source scripts/init
- Install requirements. We did not properly set up requirements.txt, since the installation of jax and torch depends on your GPU. For now (a consolidated copy-paste sequence follows these steps):
a. Install jax (see https://jax.readthedocs.io/en/latest/installation.html):
pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
b. Install torch: pip install torch
c. Reinstall jax: pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
- This is a bit buggy: jax installs its own CUDA kernels, which is why the jax install needs to be re-run after installing torch.
d. Install all other libraries: pip install transformers datasets wandb tqdm flax
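For convenience, the install steps above collected into a single sequence (a sketch assuming a CUDA 12 machine; pick the jax wheel that matches your setup per the jax installation page linked above):
pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
pip install torch
# reinstall jax so its CUDA libraries take precedence over the ones pulled in by torch
pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
pip install transformers datasets wandb tqdm flax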
- To run enwik8 (a consolidated command sequence follows this list):
a. Create a folder called data in the repository root.
b. Create a subfolder named data/logs_latte.
c. Download the data by running: bash scripts/download_chr_data.sh
d. Run: bash experiments/bash/run_lm_enwik.sh
e. You can configure the number of steps, batch size, etc. in experiments/config/lm_scale_enwik.yaml
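For convenience, steps a-d as one sequence, run from the repository root (a sketch; it assumes the environment from scripts/init is active):
mkdir -p data/logs_latte              # creates both data/ and the log subfolder (steps a and b)
bash scripts/download_chr_data.sh     # step c: download the enwik8 data
bash experiments/bash/run_lm_enwik.sh # step d: launch training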