Add new model: simple_gpt_tp_manual (#1969)
Summary:
Similar to simple_gpt, but instead of using the DTensor API to apply Tensor Parallelism (TP), we use a manual weight-sharding implementation and call functional collectives directly (a rough sketch of the approach follows the list below). Two main reasons it is beneficial to add this model:
1. DTensor + compile is not ready yet
2. DTensor has CPU overhead, and adding this lower-overhead model will help us track improvements and regressions
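
For illustration only (the new tp.py module is not reproduced on this page), here is a minimal sketch of what manual column-/row-wise weight sharding combined with an explicit functional-collectives all-reduce can look like. The helper names, the bias-free linears, and the use of a plain rank list as the collective group are assumptions made for the sketch, not code from this commit:

import torch
import torch.nn as nn
import torch.distributed._functional_collectives as funcol


def shard_colwise(linear: nn.Linear, rank: int, world_size: int) -> None:
    # Keep only this rank's slice of the output dimension (rows of the weight).
    out_per_rank = linear.out_features // world_size
    start = rank * out_per_rank
    linear.weight = nn.Parameter(
        linear.weight[start : start + out_per_rank].detach().clone()
    )
    linear.out_features = out_per_rank


def shard_rowwise(linear: nn.Linear, rank: int, world_size: int) -> None:
    # Keep only this rank's slice of the input dimension (columns of the weight).
    in_per_rank = linear.in_features // world_size
    start = rank * in_per_rank
    linear.weight = nn.Parameter(
        linear.weight[:, start : start + in_per_rank].detach().clone()
    )
    linear.in_features = in_per_rank


def rowwise_forward(linear: nn.Linear, x: torch.Tensor, world_size: int) -> torch.Tensor:
    # Each rank computes a partial matmul on its input shard, then the partial
    # results are summed across ranks. funcol.all_reduce is out-of-place and
    # returns a new tensor, which is what lets it trace under torch.compile.
    partial = torch.nn.functional.linear(x, linear.weight)
    return funcol.all_reduce(partial, "sum", list(range(world_size)))

In this path the collectives are written out explicitly rather than inserted by DTensor's sharding propagation, which is what removes the per-op CPU overhead mentioned above.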

Tests:

in benchmark/
python test.py -k "test_simple_gpt_manual_tp_"

in pytorch/
PYTHONPATH=benchmark/ python pytorch/benchmarks/dynamo/torchbench.py --float16 -dcuda --inference --backend=inductor --multiprocess --performance --only simple_gpt_tp_manual

Pull Request resolved: #1969

Reviewed By: xuzhao9

Differential Revision: D50130401

Pulled By: xmfan

fbshipit-source-id: cd4b5e543919024ff6c42c6fccfc0b12367d9bb2
xmfan authored and facebook-github-bot committed Oct 10, 2023
1 parent fd7f14e commit cdd87f0
Showing 6 changed files with 555 additions and 1 deletion.
2 changes: 1 addition & 1 deletion torchbenchmark/models/simple_gpt/__init__.py
@@ -66,7 +66,7 @@ def __init__(self, test, device, batch_size=None, extra_args=[]):
tp_mesh_dim=0,
)

-        max_batch_size = self.DEFAULT_EVAL_BSIZE
+        max_batch_size = self.batch_size
self.model.setup_caches(
max_batch_size=max_batch_size, max_seq_length=self.model.config.block_size
)
81 changes: 81 additions & 0 deletions torchbenchmark/models/simple_gpt_tp_manual/__init__.py
@@ -0,0 +1,81 @@
import os

import torch
from torch.distributed._tensor import DeviceMesh
from torch.distributed.tensor.parallel import parallelize_module
from torch.distributed.tensor.parallel.style import ColwiseParallel, RowwiseParallel
from torchbenchmark.tasks import NLP

from ...util.model import BenchmarkModel
from .model import LLaMA
from .tp import apply_tp


class Model(BenchmarkModel):
task = NLP.GENERATION
DEFAULT_EVAL_BSIZE = 1

def validate_environment(self):
if not torch.cuda.is_available() or "cuda" not in self.device:
return NotImplementedError("Model requires CUDA")

if not torch.cuda.is_bf16_supported():
return NotImplementedError("Model requires BF16")

if not hasattr(self, "_world_size"):
return NotImplementedError("Model needs to be run via dynamo torchbench and be provided distributed parameters")

if self._world_size != torch.cuda.device_count():
return NotImplementedError(
f"DTensor and all local GPUs to be within the device mesh. {torch.cuda.device_count()} local GPUs, but only world size is only {self._world_size}"
)

return None

def __init__(self, test, device, batch_size=None, extra_args=[]):
super().__init__(
test=test,
device=device,
batch_size=batch_size,
extra_args=extra_args,
)

error = self.validate_environment()
if error:
raise error

# temporary workarounds
torch._inductor.config.allow_buffer_reuse = False
torch._inductor.config.inplace_buffers = False

model = LLaMA.from_name("7B")

print("Applying tensor parallel to model ...")
apply_tp(model, self._rank, self._world_size)

max_batch_size = self.batch_size
with torch.device(device):
model.setup_caches(
max_batch_size=max_batch_size, max_seq_length=model.config.block_size
)

self.model = model.to(device=device, dtype=torch.bfloat16)

prompt_size = 10
idx = torch.randint(
self.model.config.vocab_size,
(max_batch_size, prompt_size),
dtype=torch.int32,
device=device,
)
input_pos = torch.arange(prompt_size, device=device)
self.example_inputs = [idx, input_pos]

def get_module(self):
return self.model, self.example_inputs

def train(self):
raise NotImplementedError("Training not supported for this model")

def eval(self):
raise NotImplementedError("Model needs to be run via dynamo torchbench and be provided distributed parameters")
5 changes: 5 additions & 0 deletions torchbenchmark/models/simple_gpt_tp_manual/metadata.yaml
@@ -0,0 +1,5 @@
eval_benchmark: false
eval_deterministic: false
eval_nograd: true
train_benchmark: false
train_deterministic: false
