Tensor parallel distributed strategy without using deepspeed #280

Merged · 2 commits · Jul 15, 2024
8 changes: 8 additions & 0 deletions examples/text-generation/run_generation.py
@@ -328,6 +328,14 @@ def __call__(self, parser, namespace, values, option_string=None):
        action="store_true",
        help="Run the inference with dataset for specified --n_iterations(default:5)",
    )
    parser.add_argument(
        "--distributed_strategy",
        type=str,
        choices=["tp", "none"],  # Add other strategies as needed
        default="none",
        help="Run multi-card with the specified distributed strategy. Choices are 'tp' for Tensor Parallel Strategy or 'none'.",
    )

    args = parser.parse_args()

    if args.torch_compile:
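With the flag in place, a tensor-parallel run is selected at launch time. A hypothetical invocation (launcher, world size, and model name are illustrative assumptions, not taken from this PR):

    python ../gaudi_spawn.py --world_size 8 run_generation.py --model_name_or_path meta-llama/Llama-2-7b-hf --bf16 --distributed_strategy tp
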
99 changes: 98 additions & 1 deletion examples/text-generation/utils.py
@@ -285,6 +285,102 @@ def setup_model(args, model_dtype, model_kwargs, logger):
        # assistant_model = get_torch_compiled_model(assistant_model)
    return model, assistant_model

def setup_distributed_model_tp(args, model_dtype, model_kwargs, logger):
    from typing import Any, MutableMapping

    from torch import nn

    from optimum.habana.distributed import serialization, tp_wrapping
    from optimum.habana.distributed.strategy import DistributedStrategy

    class TensorParallelStrategy(DistributedStrategy):
        def __init__(self, group=None, from_meta=False):
            super().__init__(from_meta)
            assert torch.distributed.is_initialized(), "must initialize a process group"
            self.group = group if group is not None else torch.distributed.GroupMember.WORLD

        def distribute_module(self, module: nn.Module, final_layers: bool = False) -> nn.Module:
            return tp_wrapping.apply_tp(module, self.group)

        def distribute_layer(self, block: nn.Module, layer: int) -> nn.Module:
            return tp_wrapping.apply_tp(block, layer, self.group)

        def __getstate__(self):
            # ProcessGroup handles are not picklable, so drop the group from the pickled state.
            state = self.__dict__.copy()
            state["group"] = None
            return state

        def __setstate__(self, state):
            self.__dict__.update(state)
            self.group = None  # Restore to the default state or reinitialize after unpickling

    logger.info("Multi-device run.")

    assert args.assistant_model is None, "Assistant model must be None"

    from torch import distributed as dist

    if args.device == "hpu":
        import habana_frameworks.torch.distributed.hccl

        dist.init_process_group(backend="hccl")
    else:
        dist.init_process_group()

    # Register the world group under the name "default" so it can be looked up by c10d.
    torch._C._distributed_c10d._register_process_group("default", dist.group.WORLD)

    logger.info("Creating Model")
    config = AutoConfig.from_pretrained(args.model_name_or_path, torch_dtype=model_dtype, **model_kwargs)
    # Only the distributed strategy is forwarded to from_config; the original kwargs were consumed by AutoConfig.
    model_kwargs = {}
    model_kwargs["distributed_strategy"] = TensorParallelStrategy()
    model = AutoModelForCausalLM.from_config(config, torch_dtype=model_dtype, **model_kwargs)

    initial_device = torch.device("cpu")
    source = "hf"
    checkpoint_sharding = None
    lazy_sd: MutableMapping[str, Any] = {}
    logger.info("Loading Checkpoints")
    lazy_sd = serialization.load_state_dict(
        args.model_name_or_path,
        source=source,
        distributed_strategy=args.distributed_strategy,
        checkpoint_sharding=None,
        initial_device=initial_device,
        rank=args.global_rank,
        world_size=args.world_size,
    )
    architecture = "llama"
    if len(lazy_sd):
        serialization.load_state_dict_into_model(
            model,
            lazy_sd,
            architecture,
            source,
            args.distributed_strategy,
            checkpoint_sharding,
            initial_device,
            args.local_rank,
            args.world_size,
        )

    if args.quant_config:
        model = setup_quantization(model, args)

    model = model.eval().to(args.device)

    if args.use_hpu_graphs:
        from habana_frameworks.torch.hpu import wrap_in_hpu_graph

        if check_habana_frameworks_version("1.13.0") and model.config.model_type == "falcon":
            model = wrap_in_hpu_graph(model, hash_with_views=False)
        else:
            model = wrap_in_hpu_graph(model)

    if args.torch_compile and model.config.model_type == "llama":
        model = get_torch_compiled_model(model)

    return model, args.assistant_model
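The `__getstate__`/`__setstate__` pair exists because torch ProcessGroup handles cannot be pickled; only the remaining instance state survives a round trip, and the group must be re-established afterwards. A minimal sketch (illustrative only; `strategy` stands in for an already-constructed TensorParallelStrategy):

    import pickle

    blob = pickle.dumps(strategy)    # __getstate__ drops the ProcessGroup (state["group"] = None)
    restored = pickle.loads(blob)    # __setstate__ leaves restored.group as None until reinitialized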


def setup_distributed_model(args, model_dtype, model_kwargs, logger):
    import deepspeed
@@ -548,7 +644,8 @@ def initialize_model(args, logger):
    model, assistant_model = (
        setup_model(args, model_dtype, model_kwargs, logger)
        if not use_deepspeed
        else setup_distributed_model(args, model_dtype, model_kwargs, logger)
        if args.distributed_strategy != "tp"
        else setup_distributed_model_tp(args, model_dtype, model_kwargs, logger)
    )
    tokenizer, model, assistant_model = setup_tokenizer(args, model, assistant_model)
    generation_config = setup_generation_config(args, model, assistant_model, tokenizer)
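Spelled out as an if/elif chain, the dispatch above is equivalent to the following (a readability sketch, not part of the diff):

    if not use_deepspeed:
        model, assistant_model = setup_model(args, model_dtype, model_kwargs, logger)
    elif args.distributed_strategy != "tp":
        model, assistant_model = setup_distributed_model(args, model_dtype, model_kwargs, logger)  # DeepSpeed path
    else:
        model, assistant_model = setup_distributed_model_tp(args, model_dtype, model_kwargs, logger)  # DeepSpeed-free tensor parallelism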
26 changes: 26 additions & 0 deletions optimum/habana/distributed/__init__.py
@@ -1,2 +1,28 @@
from .distributed_runner import DistributedRunner
from .fast_ddp import all_reduce_gradients
import os
import torch

def rank_and_world(group=None):
    """
    Returns (rank, world_size) from the optionally-specified group, otherwise
    from the default group, or if non-distributed just returns (0, 1).
    """
    if torch.distributed.is_initialized() and group is None:
        group = torch.distributed.GroupMember.WORLD

    if group is None:
        world_size = 1
        rank = 0
    else:
        world_size = group.size()
        rank = group.rank()

    return rank, world_size


_LOCAL_RANK = int(os.getenv("LOCAL_RANK", 0))


def local_rank():
    return _LOCAL_RANK
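A brief usage sketch of the new helpers (assumes the package is importable as optimum.habana.distributed; values shown are for a non-distributed run):

    from optimum.habana.distributed import local_rank, rank_and_world

    rank, world_size = rank_and_world()  # (0, 1) when torch.distributed is not initialized
    print(local_rank())                  # value of the LOCAL_RANK env var, 0 by default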