Upgrade to PyTorch Lightning 1.7.x and PyTorch 1.12.x
- closes #19
krasserm committed Aug 31, 2022
1 parent a780397 commit ebed3e8
Showing 18 changed files with 552 additions and 614 deletions.
Dockerfile (2 changes: 1 addition & 1 deletion)
@@ -1,4 +1,4 @@
-FROM pytorch/pytorch:1.11.0-cuda11.3-cudnn8-runtime
+FROM pytorch/pytorch:1.12.1-cuda11.3-cudnn8-runtime

WORKDIR /app

README.md (23 changes: 13 additions & 10 deletions)
@@ -7,8 +7,8 @@ This library is a PyTorch and PyTorch Lightning implementation of

An introduction to the model interfaces provided by the library is given in [Interfaces](docs/interfaces.md). Further
implementation details are described in [Architecture](docs/architecture.md). The codebase was designed for easy
-extension to new tasks and datasets. The integration with [PyTorch Lightning](https://www.pytorchlightning.ai/)
-supports model training at any scale. The command line interface is implemented with the [Lightning CLI](https://pytorch-lightning.readthedocs.io/en/1.6.5/common/lightning_cli.html).
+extension to new tasks and datasets. The integration with [PyTorch Lightning](https://pytorch-lightning.readthedocs.io/en/stable/)
+supports model training at any scale. The command line interface is implemented with the [Lightning CLI](https://pytorch-lightning.readthedocs.io/en/stable/cli/lightning_cli.html).

Datasets used for model training are 🤗 [Datasets](https://huggingface.co/docs/datasets) wrapped into PyTorch Lightning
data modules (see [data](perceiver/data) package). Datasets are automatically downloaded, preprocessed and cached
@@ -114,7 +114,7 @@ python -m perceiver.scripts.image.classifier fit \

Here are some command line examples how to train Perceiver IO models with this library. If a model must be initialized
with parameters from a previous run, it references a checkpoint from that run with the `--model.params` option. You can
-download these checkpoints [here](https://martin-krasser.com/perceiver/logs-update-6.zip) (2.3 GB) or create your own
+download these checkpoints [here](https://martin-krasser.com/perceiver/logs-update-7.zip) (2.3 GB) or create your own
checkpoints by running the examples yourself. Training results are used in [Inference examples](notebooks/inference_examples.ipynb)
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/krasserm/perceiver-io/blob/main/notebooks/inference_examples.ipynb)

@@ -169,8 +169,8 @@ python -m perceiver.scripts.text.lm fit \
--trainer.devices=2 \
--trainer.strategy=ddp_sharded \
--trainer.log_every_n_steps=20 \
---trainer.logger.save_dir=logs \
--trainer.logger=TensorBoardLogger \
+--trainer.logger.save_dir=logs \
--trainer.logger.name=mlm
```
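
A note on the option ordering used here: with the Lightning CLI, init arguments of a class-type option such as `--trainer.logger.save_dir` are generally only accepted after the class itself has been selected with `--trainer.logger=TensorBoardLogger`, which is why the logger class is passed first.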

@@ -240,6 +240,7 @@ The validation accuracy of these two runs can be obtained with
python -m perceiver.scripts.text.classifier validate \
--config=logs/txt_clf_dec/version_1/config.yaml \
--model.encoder.params=null \
+--trainer.devices=1 \
--ckpt_path="logs/txt_clf_dec/version_1/checkpoints/epoch=010-val_loss=0.212.ckpt"
```

@@ -248,7 +249,7 @@ python -m perceiver.scripts.text.classifier validate \
Validate metric DataLoader 0
──────────────────────────────────────────────────
val_acc 0.9162399768829346
-val_loss 0.21216852962970734
+val_loss 0.2121591567993164
──────────────────────────────────────────────────
```

@@ -258,15 +259,16 @@ and
python -m perceiver.scripts.text.classifier validate \
--config=logs/txt_clf_all/version_0/config.yaml \
--model.params=null \
+--trainer.devices=1 \
--ckpt_path="logs/txt_clf_all/version_0/checkpoints/epoch=002-val_loss=0.156.ckpt"
```

```
──────────────────────────────────────────────────
Validate metric DataLoader 0
──────────────────────────────────────────────────
-val_acc 0.9444400072097778
-val_loss 0.15595446527004242
+val_acc 0.9444000124931335
+val_loss 0.15592406690120697
──────────────────────────────────────────────────
```

@@ -314,8 +316,8 @@ python -m perceiver.scripts.text.lm fit \
--trainer.accumulate_grad_batches=2 \
--trainer.val_check_interval=0.5 \
--trainer.log_every_n_steps=20 \
---trainer.logger.save_dir=logs \
--trainer.logger=TensorBoardLogger \
+--trainer.logger.save_dir=logs \
--trainer.logger.name=mlm_pre
```

@@ -356,14 +358,15 @@ The validation accuracy is 98.1%:
```shell
python -m perceiver.scripts.image.classifier validate \
--config=logs/img_clf/version_0/config.yaml \
+--trainer.devices=1 \
--ckpt_path="logs/img_clf/version_0/checkpoints/epoch=015-val_loss=0.068.ckpt"
```

```
──────────────────────────────────────────────────
Validate metric DataLoader 0
──────────────────────────────────────────────────
-val_acc 0.9807000160217285
-val_loss 0.06775263696908951
+val_acc 0.9805999994277954
+val_loss 0.06774937361478806
──────────────────────────────────────────────────
```
docs/docker.md (4 changes: 2 additions & 2 deletions)
@@ -1,7 +1,7 @@
# Docker

-A [Docker image for the latest release](https://github.com/krasserm/perceiver-io/pkgs/container/perceiver-io) is
-available on GitHub Container registry. Training runs can be started with:
+A [Docker image](https://github.com/krasserm/perceiver-io/pkgs/container/perceiver-io) with the `perceiver-io` library
+installed is available on the GitHub Container registry. Training runs can be started with:

```shell
sudo docker run \
docs/interfaces.md (12 changes: 5 additions & 7 deletions)
@@ -8,7 +8,7 @@ This library provides three model interfaces:
- *PyTorch Lightning model API*: defines wrappers for PyTorch models to support training with the
[PyTorch Lightning Trainer](https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html).
- *PyTorch Lightning model CLI*: binds the PyTorch Lightning model API to the command line via the
-[Lightning CLI](https://pytorch-lightning.readthedocs.io/en/1.6.5/common/lightning_cli.html).
+[Lightning CLI](https://pytorch-lightning.readthedocs.io/en/stable/cli/lightning_cli.html).

The following subsections demonstrate the construction of the Perceiver IO language model (UTF-8 bytes tokenization,
vocabulary size of 262, 201M parameters) specified in Section 4 (Table 1) and Appendix F (Table 11) of the
@@ -101,7 +101,7 @@ model = lit_model.model

## PyTorch Lightning model CLI

-`LitLanguageModel` and `PerceiverConfig` are designed for command-line binding with the [Lightning CLI](https://pytorch-lightning.readthedocs.io/en/1.6.5/common/lightning_cli.html).
+`LitLanguageModel` and `PerceiverConfig` are designed for command-line binding with the [Lightning CLI](https://pytorch-lightning.readthedocs.io/en/stable/cli/lightning_cli.html).
A training script for `LitLanguageModel` can be implemented as follows (see [lm.py](../perceiver/scripts/text/lm.py) for
further details):

@@ -113,14 +113,12 @@ from pytorch_lightning.utilities.cli import (
LightningArgumentParser,
LightningCLI
)

+# Data modules must be imported in order
+# to be configurable on the command line.
from perceiver.data.text import WikipediaDataModule
from perceiver.model.text.language import LitLanguageModel

-# Register Wikipedia data module so that
-# it can be referenced on the command line
-DATAMODULE_REGISTRY(WikipediaDataModule)
-# Register further data modules if needed
-# ...

class CLI(LightningCLI):
def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
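
For orientation, a minimal runnable sketch of a training script along these lines is shown below; the `__main__` entry point, the `vocab_size` argument link, and passing `LitLanguageModel` directly to the CLI are illustrative assumptions rather than lines taken from the diff:

```python
# Minimal sketch of a Lightning CLI training script in the style shown above.
# The vocab_size link and the __main__ entry point are illustrative assumptions.
from pytorch_lightning.utilities.cli import (
    LightningArgumentParser,
    LightningCLI,
)

# Data modules must be imported in order
# to be configurable on the command line.
from perceiver.data.text import WikipediaDataModule  # noqa: F401
from perceiver.model.text.language import LitLanguageModel


class CLI(LightningCLI):
    def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
        # Hypothetical example: share the tokenizer vocabulary size between
        # the data module and the model so it only needs to be set once.
        parser.link_arguments("data.vocab_size", "model.vocab_size", apply_on="instantiate")


if __name__ == "__main__":
    # Subcommands (fit, validate, ...) are provided by the Lightning CLI itself.
    CLI(LitLanguageModel)
```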
environment.yml (4 changes: 2 additions & 2 deletions)
@@ -5,6 +5,6 @@ channels:
dependencies:
- python=3.7
- cudatoolkit=11.3
-- pytorch=1.11.0
-- torchvision=0.12.0
+- pytorch=1.12
+- torchvision=0.13
- pip>=22
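
For reference, an environment based on this file is typically created with `conda env create -f environment.yml` and refreshed after edits with `conda env update -f environment.yml --prune`.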