# pipeline_llama.py (forked from pytorch/torchtitan)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# This file applies the PT-D pipeline parallelism to the Llama model.

import copy
from typing import Callable, Union

import torch
import torch.nn as nn
from torch.distributed import DeviceMesh
from torch.distributed.pipelining import PipelineStage

from torchtitan.config_manager import JobConfig
from torchtitan.logging import logger
from torchtitan.models.llama.model import ModelArgs
from torchtitan.parallelisms.parallel_dims import ParallelDims
from torchtitan.parallelisms.pipelining_utils import (
build_pipeline_schedule,
generate_split_points,
stage_ids_this_rank,
)

DeviceType = Union[int, str, torch.device]


def pipeline_llama(
model: nn.Module,
pp_mesh: DeviceMesh,
parallel_dims: ParallelDims,
job_config: JobConfig,
device: DeviceType,
model_config: ModelArgs,
loss_fn: Callable[..., torch.Tensor],
):
stages, models = pipeline_llama_manual_split(
model, pp_mesh, parallel_dims, job_config, device, model_config
)
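
    # build_pipeline_schedule selects the pipeline schedule (e.g. a GPipe- or
    # 1F1B-style schedule) from the job config and attaches loss_fn so the
    # last stage can compute the loss.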
    pp_schedule = build_pipeline_schedule(job_config, stages, loss_fn)

    return pp_schedule, models


def pipeline_llama_manual_split(
whole_model: nn.Module,
pp_mesh: DeviceMesh,
parallel_dims: ParallelDims,
job_config: JobConfig,
device: DeviceType,
model_config: ModelArgs,
):
"""
    This API extracts one torch.nn.Module object for each model chunk configured to run on this rank
    (more than one when a looped schedule assigns multiple stages per rank). It wraps each chunk in a
    PipelineStage object and returns both the stage and model objects. The stage objects are used to
    create a pipeline schedule, and the model chunks can be used for applying SPMD parallelism.
"""
pp_rank = pp_mesh.get_local_rank()
    pp_size = pp_mesh.size()

    microbatches = (
        job_config.experimental.pipeline_parallel_microbatches or parallel_dims.pp
    )

    splits = (
        job_config.experimental.pipeline_parallel_split_points
        or generate_split_points(job_config, parallel_dims.pp, model_config)
    )
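
    # Illustrative example (not from the original file): with pp=4 and
    # pipeline_parallel_split_points left unset, generate_split_points might
    # return ["layers.2", "layers.4", "layers.6"] for an 8-layer model, i.e.
    # four stages where stage i spans [splits[i-1], splits[i]) and layer names
    # are matched as "layers.<idx>".
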
def _build_stage(stage_idx, start_layer, stop_layer, is_first=False, is_last=False):
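        # Manual split strategy: deep-copy the whole model, then delete the
        # modules outside this stage's [start_layer, stop_layer) slice. Working
        # on a copy keeps `whole_model` intact for building any other stages
        # this rank owns. `num_stages` is captured from the enclosing scope;
        # it is assigned below, before this helper is first called.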
model = copy.deepcopy(whole_model)
if not is_first:
            model.tok_embeddings = None

        drop_layers = start_layer is not None
for name in list(model.layers.keys()):
# we keep layers in a contiguous region between start (inclusive) and stop (exclusive)
if f"layers.{name}" == start_layer:
drop_layers = False
if f"layers.{name}" == stop_layer:
drop_layers = True
if drop_layers:
                del model.layers[name]

        if not is_last:
model.norm = None
            model.output = None

        stage = PipelineStage(
model,
stage_idx,
num_stages,
device,
group=pp_mesh.get_group("pp"),
)
        return stage, model

    num_stages = len(splits) + 1

    stages = []
    models = []
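    # style="loop" assigns stage ids to this rank round-robin (rank r gets
    # stages r, r + pp_size, r + 2 * pp_size, ...), which is the layout that
    # looped / interleaved schedules expect.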
for stage_idx in stage_ids_this_rank(pp_rank, pp_size, num_stages, style="loop"):
start_layer = splits[stage_idx - 1] if stage_idx > 0 else None
stop_layer = splits[stage_idx] if stage_idx < num_stages - 1 else None
stage, model_chunk = _build_stage(
stage_idx,
start_layer,
stop_layer,
is_first=stage_idx == 0,
is_last=stage_idx == num_stages - 1,
)
logger.info(
f"PP rank {pp_rank} is building stage_idx {stage_idx}"
f" with start_layer {start_layer}, stop_layer {stop_layer}: model chunk \n{model_chunk}"
)
stages.append(stage)
        models.append(model_chunk)

    return stages, models
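

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module). Assuming
# distributed init has already happened and that `job_config`, `model`,
# `model_config`, `loss_fn`, and `device` are prepared the way torchtitan's
# trainer prepares them, the wiring looks roughly like:
#
#     world_mesh = parallel_dims.build_mesh(device_type="cuda")
#     pp_mesh = world_mesh["pp"]
#     pp_schedule, model_parts = pipeline_llama(
#         model, pp_mesh, parallel_dims, job_config, device, model_config, loss_fn
#     )
#
#     # Only the first stage consumes the batch inputs and only the last stage
#     # sees the labels; middle stages just relay activations.
#     pp_rank, pp_size = pp_mesh.get_local_rank(), pp_mesh.size()
#     if pp_rank == 0:
#         pp_schedule.step(input_ids)
#     elif pp_rank == pp_size - 1:
#         losses = []
#         pp_schedule.step(target=labels, losses=losses)
#     else:
#         pp_schedule.step()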