The `perceiver-io` library provides three levels of interfaces for model construction and training:
- PyTorch backend model API: defines concrete Perceiver IO and Perceiver AR model and configuration classes. These are subclasses of `PerceiverIO`/`PerceiverIOConfig` and `PerceiverAR`/`PerceiverARConfig`, respectively. Model implementation details are described in building blocks.
- PyTorch Lightning model API: defines wrappers for backend models to support training with the PyTorch Lightning `Trainer`.
- PyTorch Lightning model CLI: binds the PyTorch Lightning model API to the command line via the Lightning CLI.
The following subsections provide examples of Perceiver IO and Perceiver AR model construction and training. The same pattern applies to all other models in this library.
The 🤗 interface of the `perceiver-io` library is not covered here. See library design and the inference examples notebook for how to use the 🤗 interface for inference.
This section demonstrates the construction of the Perceiver IO language model specified in Section 4 (Table 1) and Appendix F (Table 11) of the Perceiver IO paper (UTF-8 bytes tokenization, vocabulary size of 262, 201M parameters).
This language model is configured with the classes `MaskedLanguageModelConfig`, `TextEncoderConfig` and `TextDecoderConfig`, and constructed with the `MaskedLanguageModel` class. `TextEncoderConfig` covers the configuration of the generic encoder and its task-specific input adapter. `TextDecoderConfig` covers the configuration of the generic decoder and its task-specific output adapter (see also perceiver/model/text/mlm/backend.py).
```python
from perceiver.model.text.mlm import MaskedLanguageModel, MaskedLanguageModelConfig, TextEncoderConfig, TextDecoderConfig

vocab_size = 262  # E
max_seq_len = 2048  # M, O
num_latents = 256  # N
num_latent_channels = 1280  # D
num_input_channels = 768  # C, F (weight tying)
num_qk_channels = 256

encoder_config = TextEncoderConfig(
    vocab_size=vocab_size,
    max_seq_len=max_seq_len,
    num_input_channels=num_input_channels,
    num_cross_attention_qk_channels=num_qk_channels,
    num_cross_attention_v_channels=num_latent_channels,
    num_cross_attention_heads=8,
    num_self_attention_qk_channels=num_qk_channels,
    num_self_attention_v_channels=num_latent_channels,
    num_self_attention_heads=8,
    num_self_attention_layers_per_block=26,
    num_self_attention_blocks=1,
    dropout=0.1,
)

decoder_config = TextDecoderConfig(
    vocab_size=vocab_size,
    max_seq_len=max_seq_len,
    num_cross_attention_qk_channels=num_qk_channels,
    num_cross_attention_v_channels=num_input_channels,
    num_cross_attention_heads=8,
    cross_attention_residual=False,
    dropout=0.1,
)

config = MaskedLanguageModelConfig(
    encoder_config,
    decoder_config,
    num_latents=num_latents,
    num_latent_channels=num_latent_channels,
)

# PyTorch model
backend_model = MaskedLanguageModel(config)
```
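A quick sanity check of the constructed backend model is to count its parameters, which should roughly match the 201M quoted above (standard PyTorch only, no library-specific API assumed):

```python
num_params = sum(p.numel() for p in backend_model.parameters())
print(f"{num_params / 1e6:.1f}M parameters")  # expected to be roughly 201M
```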
The same language model wrapped into a PyTorch Lightning module can be created with the `LitMaskedLanguageModel` class and the `config` object defined previously.
```python
from perceiver.model.text.mlm import LitMaskedLanguageModel

config = ...

# PyTorch Lightning model
lit_model = LitMaskedLanguageModel.create(config)

# Wrapped backend model
backend_model = lit_model.backend_model
```
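Since `lit_model` is a regular PyTorch Lightning module, it can also be trained programmatically with a `Trainer` instead of the CLI shown below. The following is a minimal sketch; the `WikipediaDataModule` constructor arguments mirror the `--data.*` options used later and are assumptions, not a verified signature:

```python
import pytorch_lightning as pl
from perceiver.data.text import WikipediaDataModule

# Constructor arguments are assumed to match the --data.* CLI options shown below.
dm = WikipediaDataModule(
    tokenizer="deepmind/language-perceiver",
    task="mlm",
    max_seq_len=2048,
    batch_size=128,
)

trainer = pl.Trainer(accelerator="gpu", devices=-1, max_steps=50_000)
trainer.fit(lit_model, datamodule=dm)
```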
`LitMaskedLanguageModel` and `MaskedLanguageModelConfig` are designed for command-line binding with the Lightning CLI. A training script for `LitMaskedLanguageModel` can be implemented as follows (see perceiver/scripts/text/mlm.py for further details):
```python
# File mlm.py

from pytorch_lightning.cli import (
    LightningArgumentParser,
    LightningCLI,
)

# Data modules must be imported in order
# to be configurable on the command line.
from perceiver.data.text import WikipediaDataModule
from perceiver.model.text.mlm import LitMaskedLanguageModel


class CLI(LightningCLI):
    def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
        super().add_arguments_to_parser(parser)

        # Ensure that the data module and model share the same values for vocab_size and max_seq_len.
        parser.link_arguments("data.vocab_size", "model.encoder.vocab_size", apply_on="instantiate")
        parser.link_arguments("data.vocab_size", "model.decoder.vocab_size", apply_on="instantiate")
        parser.link_arguments("data.max_seq_len", "model.encoder.max_seq_len", apply_on="instantiate")
        parser.link_arguments("data.max_seq_len", "model.decoder.max_seq_len", apply_on="instantiate")

        # Define model configuration defaults
        # (can be overridden on the command line)
        parser.set_defaults(
            {
                "model.num_latents": 256,
                "model.num_latent_channels": 1280,
                "model.encoder.dropout": 0.1,
                "model.decoder.dropout": 0.1,
                # further model configuration defaults ...
            }
        )


if __name__ == "__main__":
    CLI(model_class=LitMaskedLanguageModel)
```
Training a `LitMaskedLanguageModel` from scratch on the Wikipedia dataset can then be started e.g. with:
```shell
python mlm.py fit \
  --model.encoder.dropout=0.0 \
  --model.decoder.dropout=0.0 \
  --data=WikipediaDataModule \
  --data.tokenizer=deepmind/language-perceiver \
  --data.task=mlm \
  --data.max_seq_len=2048 \
  --data.batch_size=128 \
  --optimizer=Lamb \
  --optimizer.lr=1e-3 \
  --trainer.accelerator=gpu \
  --trainer.devices=-1 \
  --trainer.logger=TensorBoardLogger \
  --trainer.logger.save_dir=logs \
  --trainer.logger.name=mlm
```
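The `--data.tokenizer` option refers to a 🤗 tokenizer. A short sketch for inspecting the byte-level vocabulary used above (assumes the `transformers` package is installed):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("deepmind/language-perceiver")
print(tokenizer.vocab_size)                 # expected: 262
print(tokenizer("Perceiver IO").input_ids)  # byte-level token ids
```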
If you additionally use the `--model.params=krasserm/perceiver-io-mlm` command line option, masked language modeling starts from the official pretrained model instead of a randomly initialized model. In this case you should use another dataset, because the official model has already been pretrained on Wikipedia (and other datasets).
The structure of the `--model.*` command line options is determined by the structure of the configuration classes `MaskedLanguageModelConfig`, `TextEncoderConfig` and `TextDecoderConfig`. Defaults defined in perceiver/scripts/text/mlm.py can be overridden on the command line.
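To make this correspondence concrete, each dotted option path maps onto a nested config field. In the sketch below, the `encoder`/`decoder` keyword names are inferred from the `model.encoder.*` and `model.decoder.*` option paths used in mlm.py (not verified against the config class signature), and all remaining fields keep their defaults:

```python
config = MaskedLanguageModelConfig(
    encoder=TextEncoderConfig(vocab_size=262, max_seq_len=2048, dropout=0.1),  # --model.encoder.*
    decoder=TextDecoderConfig(vocab_size=262, max_seq_len=2048, dropout=0.1),  # --model.decoder.*
    num_latents=256,           # --model.num_latents
    num_latent_channels=1280,  # --model.num_latent_channels
)
```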
The following subsections demonstrate the construction of a small Perceiver AR language model (UTF-8 bytes tokenization, vocabulary size of 262, 30.6M parameters).
`CausalLanguageModel` inherits from `PerceiverAR` and is configured with `CausalLanguageModelConfig`. See perceiver/model/text/clm/backend.py for further details.
```python
from perceiver.model.text.clm import CausalLanguageModel, CausalLanguageModelConfig

config = CausalLanguageModelConfig(
    vocab_size=262,
    max_seq_len=4096,
    max_latents=512,
    num_channels=512,
    num_self_attention_layers=8,
    cross_attention_dropout=0.5,
)

# PyTorch model
backend_model = CausalLanguageModel(config)
```
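As with the Perceiver IO model above, the parameter count of the backend model can be checked against the figure quoted earlier (roughly 30.6M):

```python
print(f"{sum(p.numel() for p in backend_model.parameters()) / 1e6:.1f}M parameters")
```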
The same language model wrapped into a PyTorch Lightning module can be created with the `LitCausalLanguageModel` class and the `config` object defined previously.
```python
from perceiver.model.text.clm import LitCausalLanguageModel

config = ...

# PyTorch Lightning model
lit_model = LitCausalLanguageModel.create(config)

# Wrapped backend model
backend_model = lit_model.backend_model
```
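After training (see below), the Lightning wrapper can be restored with the standard `load_from_checkpoint` method of PyTorch Lightning modules. The checkpoint path is illustrative and depends on the logger and checkpointing settings used during training:

```python
# Illustrative path; actual checkpoints are written by the Lightning Trainer.
# Assumes hyperparameters were saved with the checkpoint.
lit_model = LitCausalLanguageModel.load_from_checkpoint("logs/clm/version_0/checkpoints/last.ckpt")
backend_model = lit_model.backend_model
```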
`LitCausalLanguageModel` is designed for command-line binding with the Lightning CLI. A training script for `LitCausalLanguageModel` can be implemented as follows (see perceiver/scripts/text/clm.py for further details):
```python
# File clm.py

from pytorch_lightning.cli import (
    LightningArgumentParser,
    LightningCLI,
)

# Data modules must be imported in order
# to be configurable on the command line.
from perceiver.data.text import WikiTextDataModule
from perceiver.model.text.clm import LitCausalLanguageModel


class CLI(LightningCLI):
    def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
        super().add_arguments_to_parser(parser)
        parser.link_arguments("data.max_seq_len", "model.max_seq_len", apply_on="instantiate")
        parser.link_arguments("data.vocab_size", "model.vocab_size", apply_on="instantiate")
        parser.set_defaults(
            {
                "model.max_latents": 512,
                "model.num_channels": 512,
                "model.num_self_attention_layers": 8,
                "model.cross_attention_dropout": 0.5,
                "model.post_attention_dropout": 0.0,
            }
        )


if __name__ == "__main__":
    CLI(LitCausalLanguageModel)
```
Training a `LitCausalLanguageModel` from scratch on the WikiText-103-raw dataset can then be started e.g. with:
```shell
python clm.py fit \
  --model.cross_attention_dropout=0.6 \
  --data=WikiTextDataModule \
  --data.tokenizer=deepmind/language-perceiver \
  --data.padding_side=left \
  --data.max_seq_len=4096 \
  --data.batch_size=24 \
  --data.task=clm \
  --optimizer=Adam \
  --optimizer.lr=2e-4 \
  --trainer.accelerator=gpu \
  --trainer.devices=-1 \
  --trainer.max_epochs=5 \
  --trainer.logger=TensorBoardLogger \
  --trainer.logger.save_dir=logs \
  --trainer.logger.name=clm
```
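Training progress of both examples can be monitored with TensorBoard by pointing it at the `save_dir` configured for the logger:

```shell
tensorboard --logdir logs
```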