diff --git a/test/test_te.py b/test/test_te.py
new file mode 100644
index 00000000..0feae2dc
--- /dev/null
+++ b/test/test_te.py
@@ -0,0 +1,32 @@
+import torch
+import torch.nn as nn
+
+# path hack, TODO remove
+import sys
+# sys.path.insert(0, '/home/vasiliy/local/torchtitan/torchtitan')
+import torchtitan.te_utils as te_utils
+
+import transformer_engine.pytorch as te
+from transformer_engine.common.recipe import Format, DelayedScaling
+
+fp8_format = Format.HYBRID
+fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max")
+maybe_te_float8_ctx = te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe)
+
+def test():
+    # for now, single GPU smoke test of TE fp8
+
+    x = torch.randn(32, 32, device='cuda')
+
+    m = nn.Sequential(nn.Linear(32, 32)).cuda()
+    te_utils.swap_linear_to_te_linear(m)
+    print(m)
+
+    with maybe_te_float8_ctx:
+        y = m(x)
+        y.sum().backward()
+
+    print('done')
+
+if __name__ == '__main__':
+    test()
diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py
index defc010e..00bbdaeb 100644
--- a/torchtitan/config_manager.py
+++ b/torchtitan/config_manager.py
@@ -363,6 +363,20 @@ def __init__(self):
             action="store_true",
             help="Whether to compile the model",
         )
+        self.parser.add_argument(
+            "--training.use_te",
+            action="store_true",
+            help="""
+            If true, uses TransformerEngine (not for land)
+            """,
+        )
+        self.parser.add_argument(
+            "--training.use_te_float8",
+            action="store_true",
+            help="""
+            If true, enables TransformerEngine's float8 integration (not for land)
+            """,
+        )
         self.parser.add_argument(
             "--training.gc_freq",
             type=int,
diff --git a/torchtitan/te_utils.py b/torchtitan/te_utils.py
new file mode 100644
index 00000000..cb281bc3
--- /dev/null
+++ b/torchtitan/te_utils.py
@@ -0,0 +1,69 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Utilities for testing TransformerEngine
+
+Note: I attempted to hack in DTensor-based TP/SP to te.Linear in the
+link below, and gave up for now as it seemed to be a lot of remaining work.
+We can power through that if needed later.
+* https://gist.github.com/vkuzo/64d5362b63dd6c76410464e020d9a35f
+
+Note: I looked into using te.LayerNormLinear, and that would require changing
+how Attention and FFN are defined in torchtitan to use a single gemm for
+attn.kqv and ffn.w1_w3. Punting for now but we can do this later if needed.
+
+Note: PyTorch's checkpointing does not work with TE float8, fails with
+* https://gist.github.com/vkuzo/54c76c16d6a38610a1d78f4de07a71e7
+TE does have a `transformer_engine.pytorch.checkpoint` function, but
+unclear where the code for that lives. For now, we have to use
+`--activation_checkpoint.mode none`.
+
+Note: using `--activation_checkpoint.mode none` leads to poor TE performance as
+the memory usage is close to my GPU limits, the
+`WARNING - 164 CUDA memory allocation retries` from the logs seems relevant.
+Full logs: https://gist.github.com/vkuzo/0d6ebac2df3f7c90464da1e16d75d24c
+Need to decrease memory usage (either by using a smaller model or decreasing
+sequence_length) to train with TE without issues.
+""" + +import contextlib +import os + +# required for current build to work with fp8 on devgpu003.cco3 +# context: https://github.com/NVIDIA/TransformerEngine/pull/575 +# error stack trace if not enabled: https://gist.github.com/vkuzo/8e78282f4a986961753fba25249fdf77 +# os.environ["NVTE_UNFUSED_FP8_UPDATE"] = "1" + +import torch + +# import transformer_engine as te +import transformer_engine.pytorch as te + +from transformer_engine.common.recipe import Format, DelayedScaling +te_fp8_format = Format.HYBRID +te_fp8_recipe = DelayedScaling(fp8_format=te_fp8_format, amax_history_len=16, amax_compute_algo="max") + +def swap_linear_to_te_linear(model, fqn=''): + for name, child in model.named_children(): + new_fqn = f"{fqn}.{name}" + if isinstance(child, torch.nn.Linear): + te_linear = te.Linear(child.in_features, child.out_features, bias=child.bias is not None) + te_linear.weight = child.weight + te_linear.bias = child.bias + setattr(model, name, te_linear) + else: + swap_linear_to_te_linear(child, new_fqn) + +def get_maybe_fp8_autocast(job_config): + # not for land - set up TransformerEngine fp8 autocast + # Note: te.fp8_autocast has to be created at every training iteration. + # If we try to create it once and reuse, we get this error: + # https://gist.github.com/vkuzo/d9840328c8bdc2901b8d04aa570ecb5b + maybe_te_float8_ctx = contextlib.nullcontext() + if job_config.training.use_te and job_config.training.use_te_float8: + maybe_te_float8_ctx = te.fp8_autocast(enabled=True, fp8_recipe=te_fp8_recipe) + return maybe_te_float8_ctx diff --git a/train.py b/train.py index bc04dad0..c5fd0bc8 100644 --- a/train.py +++ b/train.py @@ -18,6 +18,7 @@ from torchtitan.datasets import build_hf_data_loader, build_tokenizer from torchtitan.float8 import Float8Handler from torchtitan.logging import init_logger, logger +import torchtitan.te_utils as te_utils from torchtitan.metrics import build_gpu_memory_monitor, build_metric_logger from torchtitan.models import model_name_to_cls, model_name_to_tokenizer, models_config from torchtitan.optimizer import build_lr_schedulers, build_optimizers @@ -117,6 +118,11 @@ def main(job_config: JobConfig): # swap to Float8Linear based on float8 configs float8_handler.convert_to_float8_training(model) + # not for land - set up TransformerEngine + if job_config.training.use_te: + te_utils.swap_linear_to_te_linear(model) + print(model) + # log model size model_param_count = utils.get_num_params(model) num_flop_per_token = utils.get_num_flop_per_token( @@ -284,7 +290,11 @@ def loss_fn(pred, labels): else None ) + # not for land - set up TransformerEngine fp8 autocast + maybe_te_float8_ctx = te_utils.get_maybe_fp8_autocast(job_config) + if parallel_dims.pp_enabled: + assert not job_config.training.use_te, "unsupported" # Pipeline Parallel forward / backward inside step() call is_last_stage = pp_mesh.get_local_rank() == pp_mesh.size() - 1 @@ -306,12 +316,13 @@ def loss_fn(pred, labels): else: # Non-PP forward / backward with train_context(optional_context_parallel_ctx): - pred = model(input_ids) - loss = loss_fn(pred, labels) - # pred.shape=(bs, seq_len, vocab_size) - # need to free to before bwd to avoid peaking memory - del pred - loss.backward() + with maybe_te_float8_ctx: + pred = model(input_ids) + loss = loss_fn(pred, labels) + # pred.shape=(bs, seq_len, vocab_size) + # need to free to before bwd to avoid peaking memory + del pred + loss.backward() # clip gradients for m in model_parts: