Merge pull request #36 from krasserm/wip-fsdp
Perceiver AR enhancements
krasserm authored Feb 21, 2023
2 parents 106a709 + 9d6a82b commit abacca4
Showing 35 changed files with 9,786 additions and 8,439 deletions.
4 changes: 2 additions & 2 deletions Dockerfile
@@ -1,9 +1,9 @@
-FROM pytorch/pytorch:1.12.1-cuda11.3-cudnn8-runtime
+FROM pytorch/pytorch:1.13.1-cuda11.6-cudnn8-runtime

WORKDIR /app

RUN apt-get update
-RUN apt-get install -y --no-install-recommends curl
+RUN apt-get install -y --no-install-recommends curl build-essential

RUN curl -sSL https://install.python-poetry.org | POETRY_HOME=/opt/poetry python3 -

185 changes: 2 additions & 183 deletions README.md
@@ -80,11 +80,11 @@ See [Docker image](docs/docker-image.md) for details.

## Documentation

-- [Getting started](#getting-started)
+- [Getting started](docs/getting-started.md)
- [Model construction](docs/model-construction.md)
- [Pretrained models](docs/pretrained-models.md)
- [Training examples](docs/training-examples.md)
-- [Inference examples](examples/inference.ipynb) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/krasserm/perceiver-io/blob/0.7.0/examples/inference.ipynb)
+- [Inference examples](examples/inference.ipynb) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/krasserm/perceiver-io/blob/0.8.0/examples/inference.ipynb)
- [Building blocks](docs/building-blocks.md)

## Articles
@@ -94,187 +94,6 @@ Articles referencing this repository:
- [Training compute-optimal Perceiver AR language models](https://krasserm.github.io/2023/01/23/scaling-perceiver-ar/)
- [A gentle introduction to Rotary Position Embedding](https://krasserm.github.io/2022/12/13/rotary-position-embedding/)

## Getting started

Here's a minimal example of autoregressive language modeling with Perceiver AR. A small language model (30.7M parameters)
is trained on the WikiText-103-raw dataset and then used to generate text from a prompt. Input text is tokenized into
raw UTF-8 bytes, and the model also generates raw UTF-8 bytes.

The PyTorch model class (`CausalLanguageModel`) and the corresponding PyTorch Lightning wrapper class
(`LitCausalLanguageModel`) are defined in [perceiver/model/text/clm.py](perceiver/model/text/clm.py) (see also
[model construction](docs/model-construction.md) for further details). The PyTorch Lightning data module
(`WikiTextDataModule`) is defined in [perceiver/data/text/wikitext.py](perceiver/data/text/wikitext.py).

### Training

#### Command line

The script for training a `CausalLanguageModel` on the command line is [perceiver/scripts/text/clm.py](perceiver/scripts/text/clm.py).
The constructor signatures of `LitCausalLanguageModel` and `WikiTextDataModule` determine the available `--model.*` and
`--data.*` command line options. Command line options `--optimizer.*`, `--lr_scheduler.*` and `--trainer.*` configure
the optimizer, learning rate scheduler and the PyTorch Lightning [Trainer](https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html),
respectively.

```shell
python -m perceiver.scripts.text.clm fit \
  --model.num_latents=512 \
  --model.num_channels=512 \
  --model.num_self_attention_layers=8 \
  --model.cross_attention_dropout=0.5 \
  --data=WikiTextDataModule \
  --data.tokenizer=deepmind/language-perceiver \
  --data.add_special_tokens=false \
  --data.max_seq_len=4096 \
  --data.task=clm \
  --data.batch_size=16 \
  --optimizer=Adam \
  --optimizer.lr=2e-4 \
  --lr_scheduler.warmup_steps=200 \
  --trainer.accelerator=gpu \
  --trainer.devices=1 \
  --trainer.max_epochs=5 \
  --trainer.accumulate_grad_batches=4
```

Supported optimizers are those packaged with PyTorch and [pytorch-optimizer](https://github.com/jettify/pytorch-optimizer).
The `--data.task=clm` option configures the data module to produce data compatible with causal language modeling (other
possible values are `mlm` for masked language modeling and `clf` for sequence classification). When running this command
for the first time, the WikiText dataset is downloaded and preprocessed. To download and preprocess the dataset prior
to training, run

```shell
python -m perceiver.scripts.text.preproc wikitext \
  --tokenizer=deepmind/language-perceiver \
  --add_special_tokens=false \
  --max_seq_len=4096 \
  --task=clm
```

which is usually faster.

#### Python code

Training on the command line uses the PyTorch Lightning `Trainer` under the hood. To run the `Trainer` directly from
a Python script, dynamically add a `configure_optimizers` method to `LitCausalLanguageModel`, create instances of
`LitCausalLanguageModel` and `WikiTextDataModule` and then call `trainer.fit()` with the model and data module as
arguments:

```python
from torch.optim import Adam

from perceiver.data.text import WikiTextDataModule, Task
from perceiver.model.text.clm import LitCausalLanguageModel, CausalLanguageModelConfig
from perceiver.scripts.lrs import ConstantWithWarmupLR

import pytorch_lightning as pl


def configure_optimizers(self):
    optimizer = Adam(self.parameters(), lr=2e-4)
    scheduler = ConstantWithWarmupLR(optimizer, warmup_steps=200)
    return {
        "optimizer": optimizer,
        "lr_scheduler": {"scheduler": scheduler, "interval": "step", "frequency": 1},
    }


# Add configure_optimizers method to LitCausalLanguageModel (not hard-coded there)
setattr(LitCausalLanguageModel, "configure_optimizers", configure_optimizers)


if __name__ == '__main__':
    data = WikiTextDataModule(
        tokenizer="deepmind/language-perceiver",
        add_special_tokens=False,
        max_seq_len=4096,
        task=Task.clm,
        batch_size=16,
    )

    config = CausalLanguageModelConfig(
        vocab_size=data.vocab_size,
        max_seq_len=data.max_seq_len,
        num_latents=512,
        num_channels=512,
        num_self_attention_layers=8,
        cross_attention_dropout=0.5,
    )

    # Create Lightning module of CausalLanguageModel from configuration object
    lit_model = LitCausalLanguageModel.create(config)

    # Instantiate PyTorch Lightning Trainer
    trainer = pl.Trainer(accelerator="gpu", devices=1, max_epochs=5, accumulate_grad_batches=4)

    # Train model (will also preprocess dataset if not already done yet)
    trainer.fit(lit_model, datamodule=data)
```

The trained PyTorch model can be accessed with `lit_model.model`. If you prefer to use a custom training loop without
using the PyTorch Lightning Trainer, create a plain PyTorch model with `CausalLanguageModel.create(config=...)` and
train it directly as shown in the following simple example:

```python
from perceiver.model.text.clm import CausalLanguageModel

import torch
import torch.nn.functional as F
from torch.optim import Adam

data = ...
data.prepare_data()
data.setup()

model_config = ...
model = CausalLanguageModel(config=model_config)
model.train()

optim = Adam(model.parameters(), lr=2e-4)

# Simplified training loop compared to previous
# examples (no gradient accumulation, ...)
for epoch in range(5):
    for labels_ids, input_ids, _ in data.train_dataloader():
        logits = model(input_ids)
        loss = F.cross_entropy(logits.permute(0, 2, 1), labels_ids[:, -model_config.num_latents:])
        loss.backward()
        optim.step()
        optim.zero_grad()

# Save trained model
torch.save(model.state_dict(), "/path/to/model.pt")
```

### Inference

For generating text from a prompt via top-k sampling, `CausalLanguageModel` provides a `generate()` method. The following
example first loads a trained model from a checkpoint and then generates text from a short sample prompt. An interactive
demo is also available in the [Colab notebook](https://colab.research.google.com/github/krasserm/perceiver-io/blob/0.7.0/examples/inference.ipynb).

```python
from perceiver.data.text import TextPreprocessor
from perceiver.model.text.clm import LitCausalLanguageModel

# Load model from a checkpoint that has been written by the PyTorch Lightning Trainer
model = LitCausalLanguageModel.load_from_checkpoint("/path/to/checkpoint").model.eval()

# Alternatively, load the model's state_dict directly (requires `import torch` and
# `from perceiver.model.text.clm import CausalLanguageModel` in addition to the imports above)
# model = CausalLanguageModel(config=model_config).eval()
# model.load_state_dict(torch.load("/path/to/model.pt"))

# Create a text preprocessor
preproc = TextPreprocessor(tokenizer="deepmind/language-perceiver", max_seq_len=4096, add_special_tokens=False)

# Convert text to model input
prompt, _ = preproc.preprocess("A man was reading a book on a sunny day until he sudden")

# Continue prompt via top-k sampling where k = f(vocab_size, threshold)
generated = model.generate(num=512, prompt=prompt[None, ...], threshold=0.9)

# Decode model output using preprocessor's tokenizer
generated_text = preproc.tokenizer.decode(generated[0])
```
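
Conceptually, `generate()` performs autoregressive top-k sampling. The following plain-PyTorch sketch only illustrates the idea; it is not the library's implementation, the helper name `top_k_generate` and the direct choice of `k` are made up for this example, and how `generate()` derives k from the vocabulary size and `threshold` is not shown.

```python
import torch
import torch.nn.functional as F


@torch.no_grad()
def top_k_generate(model, prompt: torch.Tensor, num: int, k: int) -> torch.Tensor:
    # prompt: (batch, seq_len) token ids, returns (batch, num) generated token ids.
    # For brevity this ignores max_seq_len handling as the sequence grows.
    seq = prompt
    generated = []
    for _ in range(num):
        logits = model(seq)[:, -1]            # next-token logits, shape (batch, vocab_size)
        top_logits, top_idx = logits.topk(k)  # keep the k most likely candidates
        probs = F.softmax(top_logits, dim=-1)
        sampled = torch.multinomial(probs, num_samples=1)
        next_token = top_idx.gather(-1, sampled)
        generated.append(next_token)
        seq = torch.cat([seq, next_token], dim=-1)
    return torch.cat(generated, dim=-1)
```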

## Other implementations

- [Perceiver](https://paperswithcode.com/paper/perceiver-general-perception-with-iterative#code)
Expand Down
5 changes: 3 additions & 2 deletions docs/building-blocks.md
@@ -53,5 +53,6 @@ Perceiver IO architecture (1 cross-attention layer, `L` self-attention blocks wi

[Perceiver AR](https://arxiv.org/abs/2202.07765) models are constructed from a generic `PerceiverAR` class and
task-specific `InputAdapter` and `OutputAdapter` subclasses. The implementation of Perceiver AR is similar to
-that of a Perceiver IO encoder. Perceiver AR additionally uses [rotary position embeddings](https://arxiv.org/abs/2104.09864)
-and causal cross- and self- attention masks. The current Perceiver AR implementation is still experimental.
+that of a Perceiver IO encoder. In addition to absolute position embedding, Perceiver AR also uses [rotary position
+embedding](https://krasserm.github.io/2022/12/13/rotary-position-embedding/) to encode relative position information.
+It also uses causal cross- and self-attention masks.
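
As an illustration of the causal masking described above, here is a minimal PyTorch sketch of lower-triangular self-attention and cross-attention masks. It is a sketch under assumptions, not the `PerceiverAR` implementation: function names, shapes, and the latent/prefix layout are chosen for this example only.

```python
import torch


def causal_self_attention_mask(num_latents: int) -> torch.Tensor:
    # Latent i may attend to latents 0..i (lower-triangular boolean mask).
    return torch.tril(torch.ones(num_latents, num_latents, dtype=torch.bool))


def causal_cross_attention_mask(num_latents: int, prefix_len: int) -> torch.Tensor:
    # Assume the last num_latents positions of a length (prefix_len + num_latents)
    # input are the queries: latent i may attend to input positions 0..prefix_len + i.
    seq_len = prefix_len + num_latents
    input_pos = torch.arange(seq_len)
    latent_pos = prefix_len + torch.arange(num_latents)
    return input_pos[None, :] <= latent_pos[:, None]  # shape (num_latents, seq_len)


print(causal_cross_attention_mask(num_latents=3, prefix_len=4).int())
```
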
10 changes: 8 additions & 2 deletions docs/dataset-preproc.md
@@ -2,8 +2,10 @@

Datasets used for model training are 🤗 [Datasets](https://huggingface.co/docs/datasets) wrapped into PyTorch Lightning
data modules (see [data](../perceiver/data) package). Datasets are automatically downloaded, preprocessed and cached
-when their corresponding Lightning data module is loaded during training. For larger datasets however, like Wikipedia,
-BookCorpus or ImageNet, for example, it is recommended to do this prior to training as described here.
+when their corresponding Lightning data module is loaded during training. For larger datasets, like [Wikipedia](../perceiver/data/text/wikipedia.py)
+or [BookCorpus](../perceiver/data/text/bookcorpus.py), it is recommended to do this prior to training as described in
+the [next section](#text-datasets). The [C4](../perceiver/data/text/c4.py) dataset is streamed directly and doesn't need
+preprocessing.
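
As a rough sketch of the same preprocessing step from Python (assuming the `WikiTextDataModule` arguments used in the training examples elsewhere in this repository), calling the Lightning hooks `prepare_data()` and `setup()` downloads and preprocesses the dataset once, after which the result is cached:

```python
from perceiver.data.text import Task, WikiTextDataModule

# Arguments mirror the training example; they should match those used later for training.
data = WikiTextDataModule(
    tokenizer="deepmind/language-perceiver",
    add_special_tokens=False,
    max_seq_len=4096,
    task=Task.clm,
    batch_size=16,
)

data.prepare_data()  # download and preprocess (cached afterwards)
data.setup()
```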

## Text datasets

@@ -71,6 +73,10 @@ to whatever you need for model training.
--add_special_tokens=false
```

+- [C4](https://huggingface.co/datasets/c4) (`c4`), used in [training examples](training-examples.md):
+
+  Streaming dataset, no preprocessing needed.

## Image datasets

- [imagenet](https://huggingface.co/datasets/imagenet-1k):