Tensor parallel distributed strategy without using deepspeed (#280) (#299)

* TP reference - IBM foundation-model-stack

* Code cleanup - removed unused code

---------

Co-authored-by: Kalyan <kkumar@habana.ai>
kalyanjk and kalyanjkk authored Jul 15, 2024
1 parent f06b27a commit 32c86d3
Showing 9 changed files with 1,105 additions and 4 deletions.
13 changes: 13 additions & 0 deletions examples/text-generation/run_generation.py
@@ -323,6 +323,19 @@ def __call__(self, parser, namespace, values, option_string=None):
action="store_true",
help="Whether or not to allow for custom models defined on the Hub in their own modeling files.",
)
parser.add_argument(
"--run_partial_dataset",
action="store_true",
help="Run the inference with dataset for specified --n_iterations(default:5)",
)
parser.add_argument(
"--distributed_strategy",
type=str,
choices=["tp", "none"], # Add other strategies as needed
default="none",
help="Run multi card with the specified distributed strategy. Choices are 'tp' for Tensor Parallel Strategy or 'none'.",
)

args = parser.parse_args()

if args.torch_compile:
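For reference, a minimal parsing sketch (not part of the diff) showing how the two new flags behave in isolation; only the arguments added above are defined here, while the real parser in run_generation.py declares many more options:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--run_partial_dataset", action="store_true")
parser.add_argument(
    "--distributed_strategy",
    type=str,
    choices=["tp", "none"],
    default="none",
)

# "--distributed_strategy tp" selects the tensor-parallel path added in utils.py;
# omitting the flag keeps the default "none".
args = parser.parse_args(["--distributed_strategy", "tp", "--run_partial_dataset"])
print(args.distributed_strategy)  # -> tp
print(args.run_partial_dataset)   # -> True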
99 changes: 98 additions & 1 deletion examples/text-generation/utils.py
@@ -285,6 +285,102 @@ def setup_model(args, model_dtype, model_kwargs, logger):
# assistant_model = get_torch_compiled_model(assistant_model)
return model, assistant_model

def setup_distributed_model_tp(args, model_dtype, model_kwargs, logger):
    from typing import Any, MutableMapping

    from torch import nn

    from optimum.habana.distributed import serialization, tp_wrapping
    from optimum.habana.distributed.strategy import DistributedStrategy

    class TensorParallelStrategy(DistributedStrategy):
        def __init__(self, group=None, from_meta=False):
            super().__init__(from_meta)
            assert torch.distributed.is_initialized(), "must initialize a process group"
            self.group = group if group is not None else torch.distributed.GroupMember.WORLD

        def distribute_module(self, module: nn.Module, final_layers: bool = False) -> nn.Module:
            return tp_wrapping.apply_tp(module, self.group)

        def distribute_layer(self, block: nn.Module, layer: int) -> nn.Module:
            return tp_wrapping.apply_tp(block, layer, self.group)

        def __getstate__(self):
            state = self.__dict__.copy()
            state["group"] = None  # Remove the ProcessGroup from the state
            return state

        def __setstate__(self, state):
            self.__dict__.update(state)
            self.group = None  # Restore to the default state or reinitialize

    logger.info("Multi-device run.")

    assert args.assistant_model is None, "Assistant model must be None"

    from torch import distributed as dist

    if args.device == "hpu":
        import habana_frameworks.torch.distributed.hccl

        dist.init_process_group(backend="hccl")
    else:
        dist.init_process_group()

    torch._C._distributed_c10d._register_process_group("default", dist.group.WORLD)
    logger.info("Creating Model")
    config = AutoConfig.from_pretrained(args.model_name_or_path, torch_dtype=model_dtype, **model_kwargs)
    model_kwargs = {}
    model_kwargs["distributed_strategy"] = TensorParallelStrategy()
    model = AutoModelForCausalLM.from_config(config, torch_dtype=model_dtype, **model_kwargs)

    initial_device = torch.device("cpu")
    source = "hf"
    checkpoint_sharding = None
    lazy_sd: MutableMapping[str, Any] = {}
    logger.info("Loading Checkpoints")
    lazy_sd = serialization.load_state_dict(
        args.model_name_or_path,
        source=source,
        distributed_strategy=args.distributed_strategy,
        checkpoint_sharding=None,
        initial_device=initial_device,
        rank=args.global_rank,
        world_size=args.world_size,
    )
    architecture = "llama"
    if len(lazy_sd):
        serialization.load_state_dict_into_model(
            model,
            lazy_sd,
            architecture,
            source,
            args.distributed_strategy,
            checkpoint_sharding,
            initial_device,
            args.local_rank,
            args.world_size,
        )

    if args.quant_config:
        model = setup_quantization(model, args)

    model = model.eval().to(args.device)

    if args.use_hpu_graphs:
        from habana_frameworks.torch.hpu import wrap_in_hpu_graph

        if check_habana_frameworks_version("1.13.0") and model.config.model_type == "falcon":
            model = wrap_in_hpu_graph(model, hash_with_views=False)
        else:
            model = wrap_in_hpu_graph(model)

    if args.torch_compile and model.config.model_type == "llama":
        model = get_torch_compiled_model(model)

    return model, args.assistant_model


def setup_distributed_model(args, model_dtype, model_kwargs, logger):
import deepspeed
@@ -548,7 +644,8 @@ def initialize_model(args, logger):
    model, assistant_model = (
        setup_model(args, model_dtype, model_kwargs, logger)
        if not use_deepspeed
-       else setup_distributed_model(args, model_dtype, model_kwargs, logger)
+       else setup_distributed_model(args, model_dtype, model_kwargs, logger) if not args.distributed_strategy == "tp"
+       else setup_distributed_model_tp(args, model_dtype, model_kwargs, logger)
    )
    tokenizer, model, assistant_model = setup_tokenizer(args, model, assistant_model)
    generation_config = setup_generation_config(args, model, assistant_model, tokenizer)
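The chained conditional in initialize_model is hard to read in diff form, so here is the equivalent dispatch spelled out as if/elif (an illustrative sketch, not part of the commit; it mirrors the expression exactly as written in the hunk above):

# Illustrative only: the real code calls the setup_* functions directly.
def pick_setup_fn(use_deepspeed: bool, distributed_strategy: str) -> str:
    if not use_deepspeed:
        return "setup_model"  # single-card / non-DeepSpeed path
    if distributed_strategy != "tp":
        return "setup_distributed_model"  # DeepSpeed path
    return "setup_distributed_model_tp"  # tensor-parallel path added by this commit


assert pick_setup_fn(False, "none") == "setup_model"
assert pick_setup_fn(True, "none") == "setup_distributed_model"
assert pick_setup_fn(True, "tp") == "setup_distributed_model_tp"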
26 changes: 26 additions & 0 deletions optimum/habana/distributed/__init__.py
@@ -1,2 +1,28 @@
from .distributed_runner import DistributedRunner
from .fast_ddp import all_reduce_gradients
import os
import torch

def rank_and_world(group=None):
    """
    Returns (rank, world_size) from the optionally-specified group, otherwise
    from the default group, or if non-distributed just returns (0, 1).
    """
    if torch.distributed.is_initialized() and group is None:
        group = torch.distributed.GroupMember.WORLD

    if group is None:
        world_size = 1
        rank = 0
    else:
        world_size = group.size()
        rank = group.rank()

    return rank, world_size


_LOCAL_RANK = int(os.getenv("LOCAL_RANK", 0))


def local_rank():
    return _LOCAL_RANK
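A brief usage sketch for the new helpers (assuming the package layout added above): in a non-distributed process rank_and_world() falls back to (0, 1), and local_rank() returns the LOCAL_RANK environment variable read at import time (0 if unset).

from optimum.habana.distributed import local_rank, rank_and_world

rank, world_size = rank_and_world()  # (0, 1) when torch.distributed is not initialized
print(f"rank={rank}, world_size={world_size}, local_rank={local_rank()}")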