[MoE][PoC] Expert Parallel: tp and tp2ep
[ghstack-poisoned]
tianyu-l committed Dec 12, 2024
1 parent 327b72b commit 50faa5a
Showing 4 changed files with 414 additions and 1 deletion.
10 changes: 10 additions & 0 deletions torchtitan/config_manager.py
@@ -364,6 +364,16 @@ def __init__(self):
default=1,
help="Context parallelism degree. 1 means disabled.",
)
self.parser.add_argument(
"--experimental.expert_parallel_mode",
type=str,
default="none",
choices=["none", "tp", "tp2ep"],
help="""
Expert Parallel mode.
'tp2ep' would use the entire TP mesh to shard non-shared experts on the num_experts dimension.
""",
)
self.parser.add_argument(
"--training.mixed_precision_param",
type=str,
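For context, a minimal sketch of how the new flag might be consumed when building the parallelization plan. Only the flag name, its choices, and the two ParallelStyles defined in expert_parallel.py below come from this commit; the job_config access path and the dispatch itself are illustrative.

from torchtitan.parallelisms.expert_parallel import ExpertParallel, TensorParallel

# Hypothetical dispatch on the new flag; job_config is assumed to be the parsed JobConfig.
ep_mode = job_config.experimental.expert_parallel_mode
if ep_mode == "tp":
    # shard each expert's weights column-/row-wise across the TP mesh
    expert_parallel_style = TensorParallel()
elif ep_mode == "tp2ep":
    # reuse the TP mesh to shard whole experts along the num_experts dimension
    expert_parallel_style = ExpertParallel()
else:  # "none"
    expert_parallel_style = None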
327 changes: 327 additions & 0 deletions torchtitan/parallelisms/expert_parallel.py
@@ -0,0 +1,327 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
from functools import partial
from typing import Any, Dict, Optional, Tuple, Union

import torch
import torch.nn as nn
from torch.distributed.tensor import (
DeviceMesh,
distribute_module,
distribute_tensor,
DTensor,
Partial,
Replicate,
Shard,
)
from torch.distributed.tensor.parallel import ParallelStyle
from torch.distributed.tensor.placement_types import Placement


# This is similar to PrepareModuleInput and PrepareModuleOutput,
# but applies them simultaneously.
class PrepareModuleInputOutput(ParallelStyle):
def __init__(
self,
*,
input_layouts: Optional[Union[Placement, Tuple[Optional[Placement]]]] = None,
desired_input_layouts: Optional[
Union[Placement, Tuple[Optional[Placement]]]
] = None,
input_kwarg_layouts: Optional[Dict[str, Placement]] = None,
desired_input_kwarg_layouts: Optional[Dict[str, Placement]] = None,
use_local_input: bool = True,
output_layouts: Union[Placement, Tuple[Placement]],
desired_output_layouts: Union[Placement, Tuple[Placement]],
use_local_output: bool = True,
):
# for input
self.input_layouts = (
(input_layouts,) if isinstance(input_layouts, Placement) else input_layouts
)
self.desired_input_layouts = (
(desired_input_layouts,)
if isinstance(desired_input_layouts, Placement)
else desired_input_layouts
)
self.use_local_input = use_local_input
if self.input_layouts is not None:
assert (
self.desired_input_layouts is not None
), "desired module inputs should not be None!"
assert len(self.input_layouts) == len(
self.desired_input_layouts
), "input_layouts and desired_input_layouts should have same length!"
self.with_kwargs = input_kwarg_layouts is not None
self.input_kwarg_layouts = input_kwarg_layouts or {}
self.desired_input_kwarg_layouts = desired_input_kwarg_layouts or {}
if self.with_kwargs:
assert len(self.input_kwarg_layouts) == len(
self.desired_input_kwarg_layouts
), "input_kwarg_layouts and desired_input_kwarg_layouts should have same length!"

# for output
self.output_layouts = (
(output_layouts,)
if isinstance(output_layouts, Placement)
else output_layouts
)
self.desired_output_layouts = (
(desired_output_layouts,)
if isinstance(desired_output_layouts, Placement)
else desired_output_layouts
)
self.use_local_output = use_local_output
assert len(self.output_layouts) == len(
self.desired_output_layouts
), "output_layouts and desired_output_layouts should have same length!"

def _prepare_input_arg(
self,
input: Any,
mesh: DeviceMesh,
input_layout: Optional[Placement],
desired_layout: Optional[Placement],
):
if input_layout is not None:
if isinstance(input, DTensor):
# TODO: re-enable the check once we fix the compile path
# assert inp.placements[0] == input_layout
dt_inp = input
else:
assert isinstance(
input, torch.Tensor
), "expecting input to be a torch.Tensor!"
dt_inp = DTensor.from_local(
input, mesh, (input_layout,), run_check=False
)

if desired_layout is not None and input_layout != desired_layout:
dt_inp = dt_inp.redistribute(placements=(desired_layout,))

return dt_inp.to_local() if self.use_local_input else dt_inp
else:
return input

def _prepare_input_fn(self, inputs, device_mesh):
if self.input_layouts is None:
return inputs
prepared_inputs = []
if not isinstance(inputs, tuple):
inputs = (inputs,)
if len(inputs) != len(self.input_layouts):
raise ValueError("module inputs and input_layouts should have same length!")

assert (
self.desired_input_layouts is not None
), "desired module inputs should not be None!"
for inp, input_layout, desired_layout in zip(
inputs, self.input_layouts, self.desired_input_layouts
):
prepared_inputs.append(
self._prepare_input_arg(inp, device_mesh, input_layout, desired_layout)
)
return tuple(prepared_inputs)

def _prepare_input_kwarg_fn(self, inputs, kwarg_inputs, device_mesh):
prepared_arg_inputs = self._prepare_input_fn(inputs, device_mesh)
prepared_kwarg_inputs = {}
for kwarg_key in kwarg_inputs.keys():
kwarg_val = kwarg_inputs[kwarg_key]
input_layout = self.input_kwarg_layouts.get(kwarg_key)
desired_input_layout = self.desired_input_kwarg_layouts.get(kwarg_key)

prepared_kwarg_inputs[kwarg_key] = self._prepare_input_arg(
kwarg_val, device_mesh, input_layout, desired_input_layout
)

return (prepared_arg_inputs, prepared_kwarg_inputs)

def _prepare_out_fn(self, outputs, device_mesh):
prepared_outputs = []
if not isinstance(outputs, tuple):
outputs = (outputs,)
if len(outputs) != len(self.output_layouts):
raise ValueError(
"module outputs and output_layouts should have same length!"
)
for out, out_layout, desired_out_layout in zip(
outputs, self.output_layouts, self.desired_output_layouts
):
if out_layout is not None:
if isinstance(out, DTensor):
# TODO: re-enable the check once we fix the compile path
# assert out.placements[0] == out_layout
dt_out = out
else:
dt_out = DTensor.from_local(
out, device_mesh, (out_layout,), run_check=False
)

if out_layout != desired_out_layout:
dt_out = dt_out.redistribute(placements=(desired_out_layout,))
prepared_outputs.append(
dt_out.to_local() if self.use_local_output else dt_out
)
else:
prepared_outputs.append(out)
if len(prepared_outputs) == 1:
return prepared_outputs[0]
else:
return tuple(prepared_outputs)

def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
# for input
if self.with_kwargs:
module.register_forward_pre_hook(
lambda _, inputs, kwargs: self._prepare_input_kwarg_fn(
inputs, kwargs, device_mesh
),
with_kwargs=True,
) # type: ignore[misc]
else:
module.register_forward_pre_hook(
lambda _, inputs: self._prepare_input_fn(inputs, device_mesh)
) # type: ignore[misc, call-arg]

# for output
module.register_forward_hook(
lambda _, inputs, outputs: self._prepare_out_fn(outputs, device_mesh)
) # type: ignore[misc, call-arg]

return module
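
# Example (illustrative placements): gather a Shard(0) module input to Replicate
# on the way in, and reduce the Partial module output back to Replicate on the
# way out:
#
#     PrepareModuleInputOutput(
#         input_layouts=Shard(0),
#         desired_input_layouts=Replicate(),
#         output_layouts=Partial(),
#         desired_output_layouts=Replicate(),
#     )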


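# Shards the grouped experts "within" each expert: gate_proj and up_proj are
# sharded column-wise and down_proj row-wise across the TP mesh, so every rank
# holds a slice of every expert. Intended for the 'tp' value of
# --experimental.expert_parallel_mode.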
class TensorParallel(ParallelStyle):
def __init__(
self,
*,
input_layouts: Optional[Placement] = None,
output_layouts: Optional[Placement] = None,
use_local_output: bool = True,
):
super().__init__()
self.input_layouts = (input_layouts or Replicate(),)
self.output_layouts = (output_layouts or Partial(),)
self.desired_input_layouts = (Replicate(),)
self.use_local_output = use_local_output

@staticmethod
def _prepare_input_fn(
input_layouts, desired_input_layouts, mod, inputs, device_mesh
):
# TODO: figure out dynamo support for instance method and switch this to instance method

# annotate module input placements/sharding with input_layouts
input_tensor = inputs[0]
if not isinstance(input_tensor, DTensor):
input_tensor = DTensor.from_local(
input_tensor, device_mesh, input_layouts, run_check=False
)

if input_layouts != desired_input_layouts:
input_tensor = input_tensor.redistribute(
placements=desired_input_layouts, async_op=True
)
return input_tensor

def _partition_fn(self, name, module, device_mesh):
module.register_parameter(
"gate_proj",
nn.Parameter(distribute_tensor(module.gate_proj, device_mesh, [Shard(2)])),
) # Column-wise sharding
module.register_parameter(
"down_proj",
nn.Parameter(distribute_tensor(module.down_proj, device_mesh, [Shard(1)])),
) # Row-wise sharding
module.register_parameter(
"up_proj",
nn.Parameter(distribute_tensor(module.up_proj, device_mesh, [Shard(2)])),
) # Column-wise sharding

@staticmethod
def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
if outputs.placements != output_layouts:
outputs = outputs.redistribute(placements=output_layouts, async_op=True)
# back to local tensor
return outputs.to_local() if use_local_output else outputs

def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
return distribute_module(
module,
device_mesh,
self._partition_fn,
partial(
self._prepare_input_fn, self.input_layouts, self.desired_input_layouts
),
partial(
self._prepare_output_fn, self.output_layouts, self.use_local_output
),
)


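# Shards "across" experts: every parameter of the grouped experts module is
# sharded on dim 0 (the num_experts dimension), so each rank owns a subset of
# whole experts. This implements the 'tp2ep' mode described in config_manager.py.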
class ExpertParallel(ParallelStyle):
def __init__(
self,
*,
input_layouts: Optional[Placement] = None,
output_layouts: Optional[Placement] = None,
use_local_output: bool = True,
):
super().__init__()
self.input_layouts = (input_layouts or Shard(0),)
self.output_layouts = (output_layouts or Shard(0),)
self.desired_input_layouts = (Shard(0),)
self.use_local_output = use_local_output

@staticmethod
def _prepare_input_fn(
input_layouts, desired_input_layouts, mod, inputs, device_mesh
):
# TODO: figure out dynamo support for instance method and switch this to instance method

# annotate module input placements/sharding with input_layouts
input_tensor = inputs[0]
if not isinstance(input_tensor, DTensor):
input_tensor = DTensor.from_local(
input_tensor, device_mesh, input_layouts, run_check=False
)

if input_layouts != desired_input_layouts:
input_tensor = input_tensor.redistribute(
placements=desired_input_layouts, async_op=True
)
return input_tensor

def _partition_fn(self, name, module, device_mesh):
# shard on the expert dimension
for name, param in module.named_parameters(recurse=False):
dist_param = nn.Parameter(distribute_tensor(param, device_mesh, [Shard(0)]))
module.register_parameter(name, dist_param)

@staticmethod
def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
# outputs is a DTensor sharded on the last dimension, i.e. Shard(-1)
if outputs.placements != output_layouts:
outputs = outputs.redistribute(placements=output_layouts, async_op=True)
# back to local tensor
return outputs.to_local() if use_local_output else outputs

def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
return distribute_module(
module,
device_mesh,
self._partition_fn,
partial(
self._prepare_input_fn, self.input_layouts, self.desired_input_layouts
),
partial(
self._prepare_output_fn, self.output_layouts, self.use_local_output
),
)
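For illustration, a minimal sketch of applying these styles with parallelize_module. The torch.distributed APIs (init_device_mesh, parallelize_module) are standard; the moe.experts module name, the mesh size, and the assumption that the experts hold stacked gate_proj/up_proj/down_proj parameters with num_experts as dim 0 are hypothetical.

from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import parallelize_module

from torchtitan.parallelisms.expert_parallel import ExpertParallel, TensorParallel

# Assumed: moe.experts holds the stacked (non-shared) expert weights.
tp_mesh = init_device_mesh("cuda", (8,), mesh_dim_names=("tp",))

# 'tp': shard within each expert (column-/row-wise over the TP mesh).
parallelize_module(moe.experts, tp_mesh, TensorParallel())

# 'tp2ep' (alternative to the above, not combined with it): shard across experts on dim 0.
# parallelize_module(moe.experts, tp_mesh, ExpertParallel())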