
Commit

Merge branch 'master' into yizhou/support_fdpt
YizhouZ authored Jan 16, 2025
2 parents 79af547 + 05eaf3d commit 16a016f
Showing 45 changed files with 494 additions and 343 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/hpu-gaudi2-nightly.yml
@@ -21,7 +21,7 @@ jobs:
# The type of runner that the job will run on
runs-on: [self-hosted, intel, gaudi2]
container:
-image: vault.habana.ai/gaudi-docker/1.18.0/ubuntu22.04/habanalabs/pytorch-installer-2.4.0:latest
+image: vault.habana.ai/gaudi-docker/1.19.0/ubuntu22.04/habanalabs/pytorch-installer-2.5.1:latest
ports:
- 80
options: --runtime=habana -e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none --cap-add=sys_nice
2 changes: 1 addition & 1 deletion .github/workflows/hpu-gaudi2.yml
@@ -39,7 +39,7 @@ jobs:
# The type of runner that the job will run on
runs-on: [self-hosted, intel, gaudi2]
container:
-image: vault.habana.ai/gaudi-docker/1.18.0/ubuntu22.04/habanalabs/pytorch-installer-2.4.0:latest
+image: vault.habana.ai/gaudi-docker/1.19.0/ubuntu22.04/habanalabs/pytorch-installer-2.5.1:latest
ports:
- 80
options: --runtime=habana -e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none --cap-add=sys_nice
3 changes: 2 additions & 1 deletion .github/workflows/nv-a6000.yml
@@ -40,8 +40,9 @@ jobs:
python -c "import torch; print('CUDA available:', torch.cuda.is_available())"
- name: Install transformers
run: |
-git clone --depth=1 https://github.com/huggingface/transformers
+git clone https://github.com/huggingface/transformers
cd transformers
+git checkout v4.47.1
git rev-parse --short HEAD
python -m pip install .
- name: Install deepspeed
2 changes: 1 addition & 1 deletion .github/workflows/nv-ds-chat.yml
@@ -43,7 +43,7 @@ jobs:
- name: Install deepspeed
run: |
-pip install transformers==4.45.2
+pip install transformers
pip install .[dev]
ds_report
4 changes: 4 additions & 0 deletions SECURITY.md
@@ -39,3 +39,7 @@ We prefer all communications to be in English.
Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://www.microsoft.com/en-us/msrc/cvd).

<!-- END MICROSOFT SECURITY.MD BLOCK -->

---

Please see [PyTorch's Security Policy](https://github.com/pytorch/pytorch/blob/main/SECURITY.md) for more information and recommendations on how to securely interact with models.
2 changes: 1 addition & 1 deletion accelerator/real_accelerator.py
@@ -178,7 +178,7 @@ def get_accelerator():
if accelerator_name is None:
# borrow this log from PR#5084
if accel_logger is not None:
-accel_logger.warn(
+accel_logger.warning(
"Setting accelerator to CPU. If you have GPU or other accelerator, we were unable to detect it.")
# cpu added as catch-all when accelerator detection fails
accelerator_name = "cpu"
2 changes: 1 addition & 1 deletion blogs/windows/08-2024/README.md
@@ -48,7 +48,7 @@ Regardless of the installation choice, you can check that the installation was s
We use an image classification model, CIFAR10, and a language model, BERT, to demonstrate pretraining on Windows with DeepSpeed.

## Pretraining CIFAR10
-The scripts and codes required for CIFAR10 pretraining example are available in the following path: DeepSpeedExamples\training\cifar. You can launch the CIFAR10 pretraining experiment using the following command: `deepspeed cifar10_deepspeed.py deepspeed`. The final output should look something like this:
+The scripts and codes required for CIFAR10 pretraining example are available in the following path: DeepSpeedExamples\training\cifar. You can launch the CIFAR10 pretraining experiment using the following command: `deepspeed cifar10_deepspeed.py --deepspeed`. The final output should look something like this:
<div align="center">
<img src="./media/cifar10_training.png" style="width:6.5in;height:3.42153in" />
</div>
9 changes: 9 additions & 0 deletions deepspeed/inference/config.py
@@ -174,6 +174,15 @@ class DeepSpeedInferenceConfig(DeepSpeedConfigModel):
values for :any:`DeepSpeedMoEConfig`.
"""

keep_module_on_host: bool = False
"""
When loading checkpoints to model parameters, they are moved to the device. In very large models
this might fill the device and cause OOM. Setting this flag to true will keep checkpoints on the
host and not move them directly to the device (giving the option, for example, to quantize checkpoint
data before moving it to the device).
Set only for models with injection policies and auto TP.
"""

quant: QuantizationConfig = {}
"""
NOTE: only works for int8 dtype.
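The new `keep_module_on_host` flag is an ordinary `DeepSpeedInferenceConfig` field, so it can be passed through `deepspeed.init_inference` alongside the tensor-parallel settings. A minimal usage sketch (the model name, dtype, and `tp_size` are illustrative placeholders, not part of this commit):

```python
import torch
import deepspeed
from transformers import AutoModelForCausalLM

# Illustrative checkpoint; any AutoTP-compatible causal LM works the same way.
model = AutoModelForCausalLM.from_pretrained("facebook/opt-1.3b", torch_dtype=torch.float16)

engine = deepspeed.init_inference(
    model,
    tensor_parallel={"tp_size": 2},      # AutoTP sharding
    dtype=torch.float16,
    replace_with_kernel_inject=False,    # flag applies to the injection-policy / AutoTP path
    keep_module_on_host=True,            # shard on the host; do not push weights to the device
)
# engine.module now holds the sharded weights on the host, leaving room to
# post-process them (e.g. quantize) before moving them to the accelerator.
```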
26 changes: 1 addition & 25 deletions deepspeed/inference/engine.py
@@ -80,7 +80,6 @@ def __init__(self, model, config):
self.mp_group = config.tensor_parallel.tp_group
self.mpu = config.tensor_parallel.mpu

-#self._validate_args(self.mpu, config.replace_with_kernel_inject)
self.quantize_merge_count = 1
self.quantization_scales = None

@@ -170,7 +169,7 @@ def __init__(self, model, config):
is_meta_device = hasattr(self.module, "device") and self.module.device.type == 'meta'
if is_meta_device:
self.module.to_empty(device=device)
-else:
+elif not config.keep_module_on_host:
self.module.to(device)

if config.tensor_parallel.tp_size > 1:
@@ -300,29 +299,6 @@ def _init_quantization_setting(self, quantization_setting):
f"mlp_extra_grouping = {self.mlp_extra_grouping}, "
f"quantize_groups = {self.quantize_groups}", [0])

-# TODO: remove this function and add this functionality to pydantic config checking
-def _validate_args(self, mpu, replace_with_kernel_inject):
-# TODO: to support SD pipeline we need to avoid this check for now
-if replace_with_kernel_inject and not isinstance(self.module, Module):
-raise ValueError(f"model must be a torch.nn.Module, got {type(self.module)}")
-if not isinstance(self._config.tensor_parallel.tp_size, int) or self._config.tensor_parallel.tp_size < 1:
-raise ValueError(f"mp_size must be an int >= 1, got {self._config.tensor_parallel.tp_size}")
-
-if mpu:
-methods = ["get_model_parallel_group", "get_data_parallel_group"]
-for method in methods:
-if not hasattr(mpu, method):
-raise ValueError(f"mpu is missing {method}")
-if self._config.checkpoint is not None and not isinstance(self._config.checkpoint, (str, dict)):
-raise ValueError(f"checkpoint must be None, str or dict, got {type(self._config.checkpoint)}")
-
-supported_dtypes = [None, torch.half, torch.int8, torch.float, torch.bfloat16]
-if self._config.dtype not in supported_dtypes:
-raise ValueError(f"{self._config.dtype} not supported, valid dtype: {supported_dtypes}")
-
-if self.injection_dict is not None and not isinstance(self.injection_dict, dict):
-raise ValueError(f"injection_dict must be None or a dict, got: {self.injection_dict}")

def load_model_with_checkpoint(self, r_module):
self.mp_replace = ReplaceWithTensorSlicing(
mp_group=self.mp_group, mp_size=self._config.tensor_parallel.tp_size) #, out_dim=0, in_dim=1)
48 changes: 32 additions & 16 deletions deepspeed/module_inject/auto_tp.py
@@ -17,14 +17,14 @@
from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list


-def move(tensor, device):
+def move(tensor, device, copy=True):
if tensor.is_meta:
return torch.empty_like(tensor, device=device)
else:
# Using new tensors help in freeing memory (after split for example) was done before by calling clone().
# Using copy=True instead of clone() will help in case of cpu --> cpu.
# Otherwise to() will not create a new copy for the view of the full tensor, and it will not be de-referenced.
-return tensor.to(device, copy=True)
+return tensor.to(device, copy=copy)


class ReplaceWithTensorSlicing:
@@ -134,7 +134,8 @@ def is_load_module(module):
load_layer_names = [
"LPLayerNorm", "SharedEmbedding", "OPTLearnedPositionalEmbedding", "LlamaRMSNorm", "FalconLinear",
"MistralRMSNorm", "T5LayerNorm", "MixtralRMSNorm", "Phi3RotaryEmbedding", "Phi3SuScaledRotaryEmbedding",
"Phi3RMSNorm", "YuanRMSNorm", "YuanRotaryEmbedding", "Phi3LongRoPEScaledRotaryEmbedding", "Qwen2RMSNorm"
"Phi3RMSNorm", "YuanRMSNorm", "YuanRotaryEmbedding", "Phi3LongRoPEScaledRotaryEmbedding", "Qwen2RMSNorm",
"DeepseekV2RMSNorm", "DeepseekV2YarnRotaryEmbedding", "MoEGate"
]
return module.__class__ in load_layers or module._get_name() in load_layer_names

@@ -188,7 +189,14 @@ def load(module, state_dict, prefix, mp_group=None):

class AutoTP():

-def __init__(self, module, all_reduce_linears, prefix, state_dict, linear_layer_setting, orig_layer_impl):
+def __init__(self,
+module,
+all_reduce_linears,
+prefix,
+state_dict,
+linear_layer_setting,
+orig_layer_impl,
+keep_module_on_host=False):
self.module = module
self.all_reduce_linears = all_reduce_linears
self.prefix = prefix
@@ -200,6 +208,7 @@ def __init__(self, module, all_reduce_linears, prefix, state_dict, linear_layer_
self.orig_layer_impl = orig_layer_impl
self.linear_policies = None
self.conv_linear_layer = False
+self.keep_module_on_host = keep_module_on_host

def in_module_list(module, module_list):
for item in module_list:
@@ -330,11 +339,15 @@ def set_tensor_parallel_config(self, mp_size, mp_group):
def _replace(self, child, name, conv_linear_layer):
if getattr(child, "replaced", False) == True:
return
+device_name = 'cpu' if self.keep_module_on_host else get_accelerator().current_device_name()
+# keep_module_on_host is used to keep the module on the host. Checkpoints are loaded to the host first (in some
+# cases it can be done from the disk even to prevent filling host's memory), thus no need to create a new copy.
+return_new_copy = not self.keep_module_on_host
weight_shape = child.weight.shape
mp_replace = ReplaceWithTensorSlicing(mp_group=self.mp_group)
-# For mixtral-7x8b, need to skip MoE gate linear replace.
-if name == "block_sparse_moe.gate" or (('mlp.shared_expert_gate' == name or 'mlp.gate' == name)
-and 'qwen2_moe' in str(type(self.module))):
+# For TP layer skip, e.g., MoE gate, deepseek low rank layer skip
+if "q_a_proj" in name or "kv_a_proj_with_mqa" in name or name == "block_sparse_moe.gate" or (
+('mlp.shared_expert_gate' == name or 'mlp.gate' == name) and 'qwen2_moe' in str(type(self.module))):
return child
# For Yuan model
if 'Yuan' in str(self.module):
@@ -350,11 +363,15 @@ def _replace(self, child, name, conv_linear_layer):
arctic_w2_all_reduce_linear = False
if 'Arctic' in str(self.module) and 'w2' in name:
arctic_w2_all_reduce_linear = True
+# For MoE MLP model, e.g., deepseek and jamba
+down_proj = False
+if 'down_proj' in name:
+down_proj = True
# For MLP including chunk layer.
if 'gate_up_proj' in name or ('dense_h_to_4h' in name and 'GLM' in str(self.module)):
weight, bias = shard_chunk_mlp(child.weight.data, child.bias, dist.get_rank(), dist.get_world_size())
return LinearLayer(weight=weight, bias=bias)
-if name in self.all_reduce_linears or arctic_w2_all_reduce_linear:
+if name in self.all_reduce_linears or arctic_w2_all_reduce_linear or down_proj:
# if conv_linear_layer [weight_shape[1], weight_shape[0] // mp_size]
# else [weight_shape[0], weight_shape[1] // mp_size]

@@ -363,18 +380,17 @@ def _replace(self, child, name, conv_linear_layer):
data = child.weight.data.split(get_shard_size_list(
weight_shape[0] if self.conv_linear_layer else weight_shape[1], self.mp_size, name),
dim=1)
-data_dc = move(data[mp_replace.gpu_index], get_accelerator().current_device_name()).detach()
+data_dc = move(data[mp_replace.gpu_index], device_name, return_new_copy).detach()
del data

setattr(child, "replaced", True)
if name == "lm_head" or name == 'embed_out':
return LmHeadLinearAllreduce(
torch.nn.parameter.Parameter(data_dc, requires_grad=False), dist.get_rank(), dist.get_world_size(),
child.bias if child.bias is None else torch.nn.parameter.Parameter(
-move(child.bias,
-get_accelerator().current_device_name())), self.mp_group)
+move(child.bias, device_name, return_new_copy)), self.mp_group)
return LinearAllreduce(torch.nn.parameter.Parameter(data_dc, requires_grad=False), child.bias if child.bias is None else \
-torch.nn.parameter.Parameter(move(child.bias, get_accelerator().current_device_name())), self.mp_group)
+torch.nn.parameter.Parameter(move(child.bias, device_name, return_new_copy)), self.mp_group)
else:

# if conv_linear_layer [weight_shape[1], weight_shape[0] // mp_size]
@@ -387,22 +403,22 @@ def _replace(self, child, name, conv_linear_layer):
#The copy is a regular copy, The shape of dst and src is the same
data_dc = move(
prepare_tp_fused_qkvw(self.module, child.weight.data, self.mp_size, mp_replace.gpu_index),
-get_accelerator().current_device_name())
+device_name, return_new_copy)

bias_data_dc = None if child.bias is None else move(
prepare_tp_fused_qkvw(self.module, child.bias.data, self.mp_size, mp_replace.gpu_index),
-get_accelerator().current_device_name())
+device_name, return_new_copy)
else:
data = child.weight.data.split(get_shard_size_list(weight_shape[0], self.mp_size, name),
dim=1 if self.conv_linear_layer else 0)
-data_dc = move(data[mp_replace.gpu_index], get_accelerator().current_device_name()).detach()
+data_dc = move(data[mp_replace.gpu_index], device_name, return_new_copy).detach()
del data

if child.bias is not None:
bias_data = child.bias.data.split(get_shard_size_list(
weight_shape[1] if self.conv_linear_layer else weight_shape[0], self.mp_size, name),
dim=0)
-bias_data = move(bias_data[mp_replace.gpu_index], get_accelerator().current_device_name())
+bias_data = move(bias_data[mp_replace.gpu_index], device_name, return_new_copy)
bias_data_dc = torch.nn.parameter.Parameter(bias_data, requires_grad=False)
del bias_data
else:
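The `copy` argument that `move()` gains here matters because `Tensor.to()` with `copy=False` returns the original tensor when no device or dtype change is needed, while `copy=True` always materializes new storage, letting the large parent tensor of a split view be freed. A small standalone illustration of that difference (not code from the repo):

```python
import torch

full = torch.randn(4, 8)
shard = full.split(4, dim=1)[0]       # a view into `full`, as after a TP split

same = shard.to("cpu", copy=False)    # no copy: identical storage, keeps `full` alive
fresh = shard.to("cpu", copy=True)    # new storage: `full` can now be garbage collected

print(same.data_ptr() == shard.data_ptr())   # True
print(fresh.data_ptr() == shard.data_ptr())  # False
```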
12 changes: 12 additions & 0 deletions deepspeed/module_inject/containers/bloom.py
@@ -19,6 +19,18 @@
class DS_BloomContainer(MetaTensorContainer, HybridEngineContainer, BaseTransformerContainer):

def __init__(self, **kwargs):
# Check transformers version, error if > 4.43.4 (breaks at 4.44.0)
from importlib.metadata import version
v_transformers = version('transformers')
vers = v_transformers.split('.')
major = int(vers[0])
minor = int(vers[1])
if major > 4 or (major == 4 and minor > 43):
import sys
sys.exit(
f"Transformers version {v_transformers} exceeds version 4.43.4! After transformers version 4.43.4, BLOOM inference with DeepSpeed is no longer supported."
)

super().__init__(**kwargs)

# All model specific things should be defined here instead of the base class.
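The guard added above parses the version string by hand; an equivalent check could also be written with `packaging.version`, which additionally copes with pre-release suffixes such as `4.44.0.dev0`. An alternative sketch only, not what the commit adds:

```python
from importlib.metadata import version
from packaging.version import Version

v_transformers = version("transformers")
if Version(v_transformers) > Version("4.43.4"):
    raise RuntimeError(
        f"Transformers version {v_transformers} exceeds 4.43.4; BLOOM inference "
        "with DeepSpeed is no longer supported after transformers 4.43.4.")
```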
2 changes: 1 addition & 1 deletion deepspeed/module_inject/layers.py
@@ -191,7 +191,7 @@ def __init__(self, weight_shape=None, weight=None, bias=None):
self.offset = 2
super().__init__(weight_shape, weight=weight)

-def forward(self, attention_mask: torch.LongTensor, past_key_values_length: int = 0):
+def forward(self, attention_mask: torch.LongTensor, past_key_values_length: int = 0, position_ids: int = 0):
"""`input_ids_shape` is expected to be [bsz x seqlen]."""
attention_mask = attention_mask.long()

17 changes: 9 additions & 8 deletions deepspeed/module_inject/replace_module.py
@@ -268,7 +268,8 @@ def replace_wo_policy(module, all_reduce_linears, prefix="", state_dict=None):
#mp_replace = ReplaceWithTensorSlicing(mp_group=config.tensor_parallel.tp_group)

# 1. Create AutoTP object
-_autotp = AutoTP(module, all_reduce_linears, prefix, state_dict, linear_layer_setting, orig_layer_impl)
+_autotp = AutoTP(module, all_reduce_linears, prefix, state_dict, linear_layer_setting, orig_layer_impl,
+config.keep_module_on_host)

# 2. Set the tensor parallelism config
_autotp.set_tensor_parallel_config(config.tensor_parallel.tp_size, config.tensor_parallel.tp_group)
@@ -342,13 +343,11 @@ def set_lm_head(module):
module.lm_head, "weight") and module.lm_head.weight.is_meta:
module.lm_head.weight = embedding_weight
# enable tensor parallel for the last linear
if hasattr(module, "lm_head") and hasattr(module.lm_head,
"weight") and not module.lm_head.weight.is_meta and isinstance(
module.lm_head, torch.nn.Linear):
if hasattr(module, "lm_head") and hasattr(module.lm_head, "weight") and isinstance(
module.lm_head, torch.nn.Linear):
module = replace_wo_policy(module, ("lm_head", ), 0, "lm_head")
elif hasattr(module, "embed_out") and hasattr(module.embed_out,
"weight") and not module.embed_out.weight.is_meta and isinstance(
module.embed_out, torch.nn.Linear):
elif hasattr(module, "embed_out") and hasattr(module.embed_out, "weight") and isinstance(
module.embed_out, torch.nn.Linear):
module = replace_wo_policy(module, ("embed_out", ), 0, "embed_out")
elif hasattr(module, "language_model") and hasattr(module.language_model, "lm_head"):
module = replace_wo_policy(module.language_model, ("lm_head", ), 0, "lm_head")
@@ -389,7 +388,6 @@ def conv2d_parallel_shard_weights(model, rank, world_size):
checkpoint=checkpoint_file)
pbar.update(1)
gc.collect()
-replaced_module = set_lm_head(replaced_module)
# conv2d tp module replace
# Now is for yuan model. Add model list and conv policy to decide whether to replace conv.
if 'Yuan' in str(replaced_module):
@@ -399,6 +397,9 @@ def conv2d_parallel_shard_weights(model, rank, world_size):
orig_class=orig_layer_impl,
replace_fn=replace_fn,
_replace_policy=config.injection_policy_tuple)
+# AutoTP default set lm_head tp
+if not config.replace_with_kernel_inject:
+replaced_module = set_lm_head(replaced_module)

quantizer = GroupQuantizer(q_int8=quantize)
world_size = dist.get_world_size() if dist.is_initialized() else 1
7 changes: 6 additions & 1 deletion deepspeed/module_inject/tp_shard.py
@@ -42,11 +42,16 @@ def get_num_attention_heads():
def get_shard_size(total_size, mp_size, name=None, rank=None):
global num_kv_heads
last_linear = ["lm_head", "embed_out"]
+# MoE MLP layer use near even division will get better perf.
+moe_mlp_layer = ["gate_proj", "up_proj", "down_proj", "w1", "w2", "w3"]
+not_moe_mlp_layer = True
+if name != None and any(s in str(name) for s in moe_mlp_layer):
+not_moe_mlp_layer = False
# When we have num_kv_heads defined, uneven division is possible, otherwise enforce near even division
if rank == None:
rank = dist.get_rank()
if num_kv_heads != None and total_size % num_kv_heads == 0 and "mlp" not in str(name) and str(
-name) not in last_linear:
+name) not in last_linear and not_moe_mlp_layer:
my_slices = (num_kv_heads // mp_size) + (1 if rank < (num_kv_heads % mp_size) else 0)
return total_size * my_slices // num_kv_heads
else:
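The intent of this change: when `num_kv_heads` is known, attention projections are sharded in whole KV-head units (which can be uneven across ranks), but MoE MLP weights such as `gate_proj`/`up_proj`/`down_proj` now fall back to near-even division. A simplified sketch of the two policies, not the repository function itself:

```python
def kv_head_aligned_shard(total_size, num_kv_heads, mp_size, rank):
    # whole KV heads per rank; the remainder goes to the lowest ranks
    my_heads = num_kv_heads // mp_size + (1 if rank < num_kv_heads % mp_size else 0)
    return total_size * my_heads // num_kv_heads

def near_even_shard(total_size, mp_size, rank):
    # plain near-even split, which the commit now forces for MoE MLP layers
    return total_size // mp_size + (1 if rank < total_size % mp_size else 0)

# e.g. hidden size 5120, 40 KV heads, tp_size 3
print([kv_head_aligned_shard(5120, 40, 3, r) for r in range(3)])  # [1792, 1664, 1664]
print([near_even_shard(5120, 3, r) for r in range(3)])            # [1707, 1707, 1706]
```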