From ba4abb437d40bc348817fade56a89ce2f00ba8a6 Mon Sep 17 00:00:00 2001
From: Christoph Stumpf
Date: Tue, 19 Dec 2023 18:38:23 +0100
Subject: [PATCH] Fix activation checkpointing

- Create new function `activation_checkpoint_wrapper` to convert module
  outputs to be compatible with the fairscale activation checkpoint wrapper
- Use `static_graph` (see
  https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html)
  in img_clf training to allow training with activation checkpointing, which
  otherwise fails with an error
---
 examples/training/img_clf/train.py |  3 +--
 examples/training/img_clf/train.sh |  1 +
 perceiver/model/core/modules.py    | 35 ++++++++++++++++++++++++++----
 3 files changed, 33 insertions(+), 6 deletions(-)
 mode change 100644 => 100755 examples/training/img_clf/train.sh

diff --git a/examples/training/img_clf/train.py b/examples/training/img_clf/train.py
index 507cd64..5b0ab8f 100644
--- a/examples/training/img_clf/train.py
+++ b/examples/training/img_clf/train.py
@@ -47,7 +47,6 @@ def configure_optimizers(self):
     num_latent_channels=128,
 )
 
-
 if __name__ == "__main__":
     lit_model = LitImageClassifier.create(config)
 
@@ -55,7 +54,7 @@ def configure_optimizers(self):
         accelerator="gpu",
         devices=2,
         max_epochs=30,
-        strategy=DDPStrategy(find_unused_parameters=False),
+        strategy=DDPStrategy(find_unused_parameters=False, static_graph=True),
         logger=TensorBoardLogger(save_dir="logs", name="img_clf"),
     )
 
diff --git a/examples/training/img_clf/train.sh b/examples/training/img_clf/train.sh
old mode 100644
new mode 100755
index 74d2a97..ac8ffba
--- a/examples/training/img_clf/train.sh
+++ b/examples/training/img_clf/train.sh
@@ -19,6 +19,7 @@ python -m perceiver.scripts.vision.image_classifier fit \
   --trainer.accelerator=gpu \
   --trainer.devices=2 \
   --trainer.max_epochs=30 \
+  --trainer.strategy=ddp_static_graph \
   --trainer.logger=TensorBoardLogger \
   --trainer.logger.save_dir=logs \
   --trainer.logger.name=img_clf
diff --git a/perceiver/model/core/modules.py b/perceiver/model/core/modules.py
index 431546f..f2062e0 100644
--- a/perceiver/model/core/modules.py
+++ b/perceiver/model/core/modules.py
@@ -406,7 +406,7 @@ def __init__(
         ]
 
         if activation_checkpointing:
-            layers = [checkpoint_wrapper(layer, offload_to_cpu=activation_offloading) for layer in layers]
+            layers = [activation_checkpoint_wrapper(layer, offload_to_cpu=activation_offloading) for layer in layers]
 
         self.num_rotary_layers = num_rotary_layers
         super().__init__(*layers)
@@ -543,7 +543,8 @@ def cross_attn():
                 residual_dropout=residual_dropout,
             )
             return (
-                checkpoint_wrapper(layer, offload_to_cpu=activation_offloading) if activation_checkpointing else layer
+                activation_checkpoint_wrapper(layer, offload_to_cpu=activation_offloading)
+                if activation_checkpointing else layer
             )
 
         def self_attn():
@@ -659,7 +660,7 @@ def __init__(
         )
 
         if activation_checkpointing:
-            cross_attn = checkpoint_wrapper(cross_attn, offload_to_cpu=activation_offloading)
+            cross_attn = activation_checkpoint_wrapper(cross_attn, offload_to_cpu=activation_offloading)
 
         self.cross_attn = cross_attn
         self._init_parameters(init_scale)
@@ -738,7 +739,8 @@ def cross_attn():
                 mlp_bias=False,
             )
             return (
-                checkpoint_wrapper(layer, offload_to_cpu=activation_offloading) if activation_checkpointing else layer
+                activation_checkpoint_wrapper(layer, offload_to_cpu=activation_offloading)
+                if activation_checkpointing else layer
             )
 
         def self_attn():
@@ -926,3 +928,28 @@ def forward(
         output.logits = self.output_adapter(output.last_hidden_state,
                                             txt_embedding=self.input_adapter.txt_embedding)
         return output
+
+
+def activation_checkpoint_wrapper(module: AbstractAttentionLayer, offload_to_cpu: bool = False):
+    abstract_attention_layer_original_forward = AbstractAttentionLayer.forward
+
+    def _abstract_attention_layer_patched_forward(self, *args, **kwargs):
+        output = abstract_attention_layer_original_forward(self, *args, **kwargs)
+        if hasattr(self, "_activation_checkpointing_enabled") and self.training and isinstance(output, ModuleOutput):
+            return output.last_hidden_state
+        return output
+
+    AbstractAttentionLayer.forward = _abstract_attention_layer_patched_forward
+
+    module = checkpoint_wrapper(module, offload_to_cpu=offload_to_cpu)
+    module._activation_checkpointing_enabled = True
+    module_original_forward = module.forward
+
+    def _module_patched_forward(*args, **kwargs):
+        output = module_original_forward(*args, **kwargs)
+        if isinstance(output, ModuleOutput):
+            return output
+        return ModuleOutput(last_hidden_state=output, kv_cache=None)
+
+    module.forward = _module_patched_forward
+    return module
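
The conversion idea behind `activation_checkpoint_wrapper` can be tried outside the perceiver codebase. Below is a minimal sketch of the same technique, assuming fairscale is installed and exposes `checkpoint_wrapper` as it is imported in modules.py. `ToyOutput`, `ToyLayer`, `toy_checkpoint_wrapper` and `_ckpt_enabled` are hypothetical stand-ins for `ModuleOutput`, `AbstractAttentionLayer`, `activation_checkpoint_wrapper` and `_activation_checkpointing_enabled`, not names from the patch. The class-level `forward` is patched (rather than the instance) because fairscale's wrapper appears to re-run the class-level `forward` during recomputation, which is presumably why the patch overrides `AbstractAttentionLayer.forward` as well.

# Sketch only, not part of the patch: convert dataclass outputs to tensors
# for activation checkpointing, then convert them back for callers.
from dataclasses import dataclass

import torch
import torch.nn as nn
from fairscale.nn import checkpoint_wrapper


@dataclass
class ToyOutput:
    last_hidden_state: torch.Tensor


class ToyLayer(nn.Module):
    def __init__(self, dim: int = 16):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        return ToyOutput(last_hidden_state=self.proj(x))


def toy_checkpoint_wrapper(module: ToyLayer, offload_to_cpu: bool = False):
    original_forward = ToyLayer.forward

    def tensor_only_forward(self, *args, **kwargs):
        # During training of a wrapped instance, return only the tensor so the
        # checkpointing machinery sees a tensor output; otherwise pass through.
        out = original_forward(self, *args, **kwargs)
        if getattr(self, "_ckpt_enabled", False) and self.training and isinstance(out, ToyOutput):
            return out.last_hidden_state
        return out

    # Patch the class-level forward, mirroring what the patch does for
    # AbstractAttentionLayer.forward.
    ToyLayer.forward = tensor_only_forward

    module = checkpoint_wrapper(module, offload_to_cpu=offload_to_cpu)
    module._ckpt_enabled = True
    checkpointed_forward = module.forward

    def rewrapping_forward(*args, **kwargs):
        # Wrap the checkpointed tensor back into the dataclass callers expect.
        out = checkpointed_forward(*args, **kwargs)
        return out if isinstance(out, ToyOutput) else ToyOutput(last_hidden_state=out)

    module.forward = rewrapping_forward
    return module


if __name__ == "__main__":
    layer = toy_checkpoint_wrapper(ToyLayer()).train()
    x = torch.randn(2, 16, requires_grad=True)
    layer(x).last_hidden_state.sum().backward()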