[not for land] TE experiments
Summary:

Test Plan:

```
with-proxy CUDA_VISIBLE_DEVICES=4,5,6,7 NGPU=4 CONFIG_FILE="./train_configs/debug_model.toml" ./run_llama_train.sh --training.use_te
```

Reviewers:

Subscribers:

Tasks:

Tags:
vkuzo committed Oct 29, 2024
1 parent 7310abe commit bb7aeec
Showing 4 changed files with 132 additions and 6 deletions.
32 changes: 32 additions & 0 deletions test/test_te.py
@@ -0,0 +1,32 @@
import torch
import torch.nn as nn

# path hack, TODO remove
import sys
# sys.path.insert(0, '/home/vasiliy/local/torchtitan/torchtitan')
import torchtitan.te_utils as te_utils

import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Format, DelayedScaling

fp8_format = Format.HYBRID
fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max")
maybe_te_float8_ctx = te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe)

def test():
    # for now, single GPU smoke test of TE fp8

    x = torch.randn(32, 32, device='cuda')

    m = nn.Sequential(nn.Linear(32, 32)).cuda()
    te_utils.swap_linear_to_te_linear(m)
    print(m)

    with maybe_te_float8_ctx:
        y = m(x)
    y.sum().backward()

    print('done')

if __name__ == '__main__':
    test()
14 changes: 14 additions & 0 deletions torchtitan/config_manager.py
@@ -363,6 +363,20 @@ def __init__(self):
            action="store_true",
            help="Whether to compile the model",
        )
        self.parser.add_argument(
            "--training.use_te",
            action="store_true",
            help="""
                If true, uses TransformerEngine (not for land)
            """,
        )
        self.parser.add_argument(
            "--training.use_te_float8",
            action="store_true",
            help="""
                If true, enables TransformerEngine's float8 integration (not for land)
            """,
        )
        self.parser.add_argument(
            "--training.gc_freq",
            type=int,
69 changes: 69 additions & 0 deletions torchtitan/te_utils.py
@@ -0,0 +1,69 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
Utilities for testing TransformerEngine
Note: I attempted to hack in DTensor-based TP/SP to te.Linear in the
link below, and gave up for now as it seemed to be a lot of remaining work.
We can power through that if needed later.
* https://gist.github.com/vkuzo/64d5362b63dd6c76410464e020d9a35f
Note: I looked into using te.LayerNormLinear, and that would require changing
how Attention and FFN are defined in torchtitan to use a single gemm for
attn.kqv and ffn.w1_w3. Punting for now but we can do this later if needed.
Note: PyTorch's checkpointing does not work with TE float8, fails with
* https://gist.github.com/vkuzo/54c76c16d6a38610a1d78f4de07a71e7
TE does have a `transformer_engine.pytorch.checkpoint` function, but
unclear where the code for that lives. For now, we have to use
`--activation_checkpoint.mode none`.
Note: using `--activation_checkpoint.mode none` leads to poor TE performance as
the memory usage is close to my GPU limits, the
`WARNING - 164 CUDA memory allocation retries` from the logs seems relevant.
Full logs: https://gist.github.com/vkuzo/0d6ebac2df3f7c90464da1e16d75d24c
Need to decrease memory usage (either by using a smaller model or decreasing
sequence_length) to train with TE without issues.
"""

import contextlib
import os

# required for current build to work with fp8 on devgpu003.cco3
# context: https://github.com/NVIDIA/TransformerEngine/pull/575
# error stack trace if not enabled: https://gist.github.com/vkuzo/8e78282f4a986961753fba25249fdf77
# os.environ["NVTE_UNFUSED_FP8_UPDATE"] = "1"

import torch

# import transformer_engine as te
import transformer_engine.pytorch as te

from transformer_engine.common.recipe import Format, DelayedScaling
te_fp8_format = Format.HYBRID
te_fp8_recipe = DelayedScaling(fp8_format=te_fp8_format, amax_history_len=16, amax_compute_algo="max")

def swap_linear_to_te_linear(model, fqn=''):
    for name, child in model.named_children():
        new_fqn = f"{fqn}.{name}"
        if isinstance(child, torch.nn.Linear):
            te_linear = te.Linear(child.in_features, child.out_features, bias=child.bias is not None)
            te_linear.weight = child.weight
            te_linear.bias = child.bias
            setattr(model, name, te_linear)
        else:
            swap_linear_to_te_linear(child, new_fqn)

def get_maybe_fp8_autocast(job_config):
    # not for land - set up TransformerEngine fp8 autocast
    # Note: te.fp8_autocast has to be created at every training iteration.
    # If we try to create it once and reuse, we get this error:
    # https://gist.github.com/vkuzo/d9840328c8bdc2901b8d04aa570ecb5b
    maybe_te_float8_ctx = contextlib.nullcontext()
    if job_config.training.use_te and job_config.training.use_te_float8:
        maybe_te_float8_ctx = te.fp8_autocast(enabled=True, fp8_recipe=te_fp8_recipe)
    return maybe_te_float8_ctx
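
For reference, a minimal sketch (not part of this commit) of how the two utilities above are meant to compose in a training loop, mirroring the train.py changes further down: swap Linears once at setup time, then create the fp8 autocast context fresh on every iteration. The `job_config` object here is a hypothetical `SimpleNamespace` stand-in for torchtitan's parsed JobConfig, carrying only the two flags added in config_manager.py; the toy model and loop are purely illustrative.

```
# usage sketch, assuming TransformerEngine and a CUDA device are available
import types

import torch
import torch.nn as nn

import torchtitan.te_utils as te_utils

# hypothetical stand-in for the parsed JobConfig, with only the new flags
job_config = types.SimpleNamespace(
    training=types.SimpleNamespace(use_te=True, use_te_float8=True)
)

# swap nn.Linear -> te.Linear once, at model setup time
m = nn.Sequential(nn.Linear(32, 32), nn.Linear(32, 32)).cuda()
if job_config.training.use_te:
    te_utils.swap_linear_to_te_linear(m)

optimizer = torch.optim.SGD(m.parameters(), lr=1e-3)

for _ in range(3):  # toy training loop
    x = torch.randn(32, 32, device="cuda")
    # the fp8 autocast context is re-created every iteration (see note above)
    maybe_te_float8_ctx = te_utils.get_maybe_fp8_autocast(job_config)
    with maybe_te_float8_ctx:
        y = m(x)
    y.sum().backward()
    optimizer.step()
    optimizer.zero_grad()
```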
23 changes: 17 additions & 6 deletions train.py
@@ -18,6 +18,7 @@
from torchtitan.datasets import build_hf_data_loader, build_tokenizer
from torchtitan.float8 import Float8Handler
from torchtitan.logging import init_logger, logger
import torchtitan.te_utils as te_utils
from torchtitan.metrics import build_gpu_memory_monitor, build_metric_logger
from torchtitan.models import model_name_to_cls, model_name_to_tokenizer, models_config
from torchtitan.optimizer import build_lr_schedulers, build_optimizers
@@ -117,6 +118,11 @@ def main(job_config: JobConfig):
    # swap to Float8Linear based on float8 configs
    float8_handler.convert_to_float8_training(model)

    # not for land - set up TransformerEngine
    if job_config.training.use_te:
        te_utils.swap_linear_to_te_linear(model)
        print(model)

    # log model size
    model_param_count = utils.get_num_params(model)
    num_flop_per_token = utils.get_num_flop_per_token(
@@ -284,7 +290,11 @@ def loss_fn(pred, labels):
            else None
        )

        # not for land - set up TransformerEngine fp8 autocast
        maybe_te_float8_ctx = te_utils.get_maybe_fp8_autocast(job_config)

        if parallel_dims.pp_enabled:
            assert not job_config.training.use_te, "unsupported"
            # Pipeline Parallel forward / backward inside step() call
            is_last_stage = pp_mesh.get_local_rank() == pp_mesh.size() - 1

@@ -306,12 +316,13 @@ def loss_fn(pred, labels):
        else:
            # Non-PP forward / backward
            with train_context(optional_context_parallel_ctx):
                pred = model(input_ids)
                loss = loss_fn(pred, labels)
                # pred.shape=(bs, seq_len, vocab_size)
                # need to free before bwd to avoid peak memory
                del pred
                loss.backward()
                with maybe_te_float8_ctx:
                    pred = model(input_ids)
                    loss = loss_fn(pred, labels)
                    # pred.shape=(bs, seq_len, vocab_size)
                    # need to free before bwd to avoid peak memory
                    del pred
                    loss.backward()

# clip gradients
for m in model_parts:
