
Commit

Merge branch 'master' into loadams/fix-torch-issues
loadams authored Jan 16, 2025
2 parents b4066f5 + 05eaf3d commit 95b453e
Showing 10 changed files with 59 additions and 21 deletions.
2 changes: 1 addition & 1 deletion accelerator/real_accelerator.py
@@ -178,7 +178,7 @@ def get_accelerator():
if accelerator_name is None:
# borrow this log from PR#5084
if accel_logger is not None:
accel_logger.warn(
accel_logger.warning(
"Setting accelerator to CPU. If you have GPU or other accelerator, we were unable to detect it.")
# cpu added as catch-all when accelerator detection fails
accelerator_name = "cpu"
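The warn → warning rename here (and in the runtime files further down in this diff) tracks the standard-library deprecation of Logger.warn. A minimal sketch of the supported spelling, assuming the DeepSpeed loggers are ordinary logging.Logger instances (the logger name below is illustrative):

import logging

# Illustrative stand-in for DeepSpeed's accel_logger.
accel_logger = logging.getLogger("deepspeed.accelerator")

# Logger.warn() is a deprecated alias that emits a DeprecationWarning on recent Python
# versions; Logger.warning() is the supported method, hence the rename in this commit.
accel_logger.warning("Setting accelerator to CPU. If you have GPU or other accelerator, we were unable to detect it.")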
9 changes: 9 additions & 0 deletions deepspeed/inference/config.py
@@ -174,6 +174,15 @@ class DeepSpeedInferenceConfig(DeepSpeedConfigModel):
values for :any:`DeepSpeedMoEConfig`.
"""

keep_module_on_host: bool = False
"""
When loading checkpoints into model parameters, they are moved to the device. For very large models
this might fill the device and cause OOM. Setting this flag to true will keep checkpoints on
the host and not move them directly to the device (giving the option, for example, to quantize
checkpoint data before moving it to the device).
Set this only for models using injection policies and auto TP.
"""

quant: QuantizationConfig = {}
"""
NOTE: only works for int8 dtype.
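The keep_module_on_host flag added above is consumed through deepspeed.init_inference, exactly as exercised by the test changes at the bottom of this diff. A hedged usage sketch (model name taken from the test matrix; launch details are illustrative):

import deepspeed
import torch
from transformers import pipeline

# Illustrative setup; run under the DeepSpeed launcher with as many ranks as mp_size,
# e.g. `deepspeed --num_gpus 2 this_script.py`.
pipe = pipeline("text-generation", model="Salesforce/codegen-350M-mono", framework="pt")

# keep_module_on_host=True leaves the sharded checkpoint weights on the host instead of
# moving them to the accelerator, trading device memory for host memory.
pipe.model = deepspeed.init_inference(pipe.model,
                                      mp_size=2,
                                      dtype=torch.float32,
                                      keep_module_on_host=True)

# Mirrors the assertion added in tests/unit/inference/test_inference.py:
# with the flag enabled, parameters stay on the CPU.
assert all(p.device == torch.device("cpu") for _, p in pipe.model.named_parameters())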
2 changes: 1 addition & 1 deletion deepspeed/inference/engine.py
@@ -169,7 +169,7 @@ def __init__(self, model, config):
is_meta_device = hasattr(self.module, "device") and self.module.device.type == 'meta'
if is_meta_device:
self.module.to_empty(device=device)
else:
elif not config.keep_module_on_host:
self.module.to(device)

if config.tensor_parallel.tp_size > 1:
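A condensed, standalone restatement of the placement decision this hunk introduces (a sketch only; in the real code this logic lives inline in InferenceEngine.__init__ and the device comes from the accelerator abstraction):

import torch

def place_module(module, device, keep_module_on_host=False):
    # Sketch of the three-way decision in deepspeed/inference/engine.py.
    is_meta_device = hasattr(module, "device") and module.device.type == "meta"
    if is_meta_device:
        # Meta tensors carry no storage; allocate empty storage directly on the target device.
        module.to_empty(device=device)
    elif not keep_module_on_host:
        # Default behaviour: move the whole module to the accelerator.
        module.to(device)
    # else: leave the module on the host; AutoTP then shards the weights on the CPU.
    return module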
33 changes: 22 additions & 11 deletions deepspeed/module_inject/auto_tp.py
@@ -17,14 +17,14 @@
from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list


def move(tensor, device):
def move(tensor, device, copy=True):
if tensor.is_meta:
return torch.empty_like(tensor, device=device)
else:
# Creating a new tensor helps free memory (for example after a split); previously this was done by calling clone().
# Using copy=True instead of clone() also covers the cpu --> cpu case:
# otherwise to() will not create a new copy of the view into the full tensor, and the full tensor will never be de-referenced.
return tensor.to(device, copy=True)
return tensor.to(device, copy=copy)


class ReplaceWithTensorSlicing:
@@ -189,7 +189,14 @@ def load(module, state_dict, prefix, mp_group=None):

class AutoTP():

def __init__(self, module, all_reduce_linears, prefix, state_dict, linear_layer_setting, orig_layer_impl):
def __init__(self,
module,
all_reduce_linears,
prefix,
state_dict,
linear_layer_setting,
orig_layer_impl,
keep_module_on_host=False):
self.module = module
self.all_reduce_linears = all_reduce_linears
self.prefix = prefix
@@ -201,6 +208,7 @@ def __init__(self, module, all_reduce_linears, prefix, state_dict, linear_layer_
self.orig_layer_impl = orig_layer_impl
self.linear_policies = None
self.conv_linear_layer = False
self.keep_module_on_host = keep_module_on_host

def in_module_list(module, module_list):
for item in module_list:
@@ -331,6 +339,10 @@ def set_tensor_parallel_config(self, mp_size, mp_group):
def _replace(self, child, name, conv_linear_layer):
if getattr(child, "replaced", False) == True:
return
device_name = 'cpu' if self.keep_module_on_host else get_accelerator().current_device_name()
# keep_module_on_host is used to keep the module on the host. Checkpoints are loaded to the host first (in some
# cases they can even be loaded directly from disk to avoid filling host memory), so there is no need to create a new copy.
return_new_copy = not self.keep_module_on_host
weight_shape = child.weight.shape
mp_replace = ReplaceWithTensorSlicing(mp_group=self.mp_group)
# For TP layer skip, e.g., MoE gate, deepseek low rank layer skip
@@ -368,18 +380,17 @@ def _replace(self, child, name, conv_linear_layer):
data = child.weight.data.split(get_shard_size_list(
weight_shape[0] if self.conv_linear_layer else weight_shape[1], self.mp_size, name),
dim=1)
data_dc = move(data[mp_replace.gpu_index], get_accelerator().current_device_name()).detach()
data_dc = move(data[mp_replace.gpu_index], device_name, return_new_copy).detach()
del data

setattr(child, "replaced", True)
if name == "lm_head" or name == 'embed_out':
return LmHeadLinearAllreduce(
torch.nn.parameter.Parameter(data_dc, requires_grad=False), dist.get_rank(), dist.get_world_size(),
child.bias if child.bias is None else torch.nn.parameter.Parameter(
move(child.bias,
get_accelerator().current_device_name())), self.mp_group)
move(child.bias, device_name, return_new_copy)), self.mp_group)
return LinearAllreduce(torch.nn.parameter.Parameter(data_dc, requires_grad=False), child.bias if child.bias is None else \
torch.nn.parameter.Parameter(move(child.bias, get_accelerator().current_device_name())), self.mp_group)
torch.nn.parameter.Parameter(move(child.bias, device_name, return_new_copy)), self.mp_group)
else:

# if conv_linear_layer [weight_shape[1], weight_shape[0] // mp_size]
@@ -392,22 +403,22 @@ def _replace(self, child, name, conv_linear_layer):
#The copy is a regular copy, The shape of dst and src is the same
data_dc = move(
prepare_tp_fused_qkvw(self.module, child.weight.data, self.mp_size, mp_replace.gpu_index),
get_accelerator().current_device_name())
device_name, return_new_copy)

bias_data_dc = None if child.bias is None else move(
prepare_tp_fused_qkvw(self.module, child.bias.data, self.mp_size, mp_replace.gpu_index),
get_accelerator().current_device_name())
device_name, return_new_copy)
else:
data = child.weight.data.split(get_shard_size_list(weight_shape[0], self.mp_size, name),
dim=1 if self.conv_linear_layer else 0)
data_dc = move(data[mp_replace.gpu_index], get_accelerator().current_device_name()).detach()
data_dc = move(data[mp_replace.gpu_index], device_name, return_new_copy).detach()
del data

if child.bias is not None:
bias_data = child.bias.data.split(get_shard_size_list(
weight_shape[1] if self.conv_linear_layer else weight_shape[0], self.mp_size, name),
dim=0)
bias_data = move(bias_data[mp_replace.gpu_index], get_accelerator().current_device_name())
bias_data = move(bias_data[mp_replace.gpu_index], device_name, return_new_copy)
bias_data_dc = torch.nn.parameter.Parameter(bias_data, requires_grad=False)
del bias_data
else:
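Taken together, _replace now picks the target device once and lets move() decide whether a fresh copy is needed. A simplified sketch of that interaction using a toy weight instead of the real module tree (shard sizes and device names are illustrative):

import torch

def move(tensor, device, copy=True):
    # Mirrors the updated helper: meta tensors get fresh storage; everything else is
    # materialised with .to(). copy=False skips the extra host-side copy when the
    # shard stays on the CPU anyway.
    if tensor.is_meta:
        return torch.empty_like(tensor, device=device)
    return tensor.to(device, copy=copy)

# Toy stand-ins for the values _replace derives (illustrative only).
keep_module_on_host = True
device_name = "cpu" if keep_module_on_host else "cuda:0"  # real code asks get_accelerator()
return_new_copy = not keep_module_on_host

weight = torch.randn(8, 4)            # pretend full (unsharded) weight
shards = weight.split([4, 4], dim=0)  # analogous to get_shard_size_list() + split()
my_shard = move(shards[0], device_name, return_new_copy).detach()

# With copy=False on a cpu --> cpu move, .to() returns the existing view, so no second
# host copy of the checkpoint shard is created.
assert my_shard.device == torch.device("cpu")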
3 changes: 2 additions & 1 deletion deepspeed/module_inject/replace_module.py
@@ -268,7 +268,8 @@ def replace_wo_policy(module, all_reduce_linears, prefix="", state_dict=None):
#mp_replace = ReplaceWithTensorSlicing(mp_group=config.tensor_parallel.tp_group)

# 1. Create AutoTP object
_autotp = AutoTP(module, all_reduce_linears, prefix, state_dict, linear_layer_setting, orig_layer_impl)
_autotp = AutoTP(module, all_reduce_linears, prefix, state_dict, linear_layer_setting, orig_layer_impl,
config.keep_module_on_host)

# 2. Set the tensor parallelism config
_autotp.set_tensor_parallel_config(config.tensor_parallel.tp_size, config.tensor_parallel.tp_group)
2 changes: 1 addition & 1 deletion deepspeed/runtime/base_optimizer.py
@@ -28,7 +28,7 @@ def load_hp_checkpoint_state_from_checkpoint_dir(self, lp_groups_name: str, chec

tp_rank = bwc_tensor_model_parallel_rank(mpu=self.mpu)
if self.mpu is None:
logger.warn("MPU is not provided, setting tp size to 1 in checkpoint loading.")
logger.warning("MPU is not provided, setting tp size to 1 in checkpoint loading.")
tp_world_size = 1
else:
tp_world_size = self.mpu.get_slice_parallel_world_size() if hasattr(self.mpu, "get_slice_parallel_world_size") \
4 changes: 2 additions & 2 deletions deepspeed/runtime/engine.py
@@ -3120,7 +3120,7 @@ def _get_all_zero_checkpoints(self, load_dir, tag):
if bf16_mode is not self.bfloat16_enabled():
checkpoint_bit16 = BFLOAT16 if bf16_mode else FP16
engine_bit16 = BFLOAT16 if self.bfloat16_enabled() else FP16
logger.warn(f'Loading {checkpoint_bit16} zero checkpoints into {engine_bit16} training engine')
logger.warning(f'Loading {checkpoint_bit16} zero checkpoints into {engine_bit16} training engine')
return self._get_all_zero_checkpoint_state_dicts(zero_ckpt_names)

return None
@@ -3276,7 +3276,7 @@ def _save_moe_checkpoint(self, save_dir, tag, client_state={}, exclude_frozen_pa

local_expert_id = None
if not m:
logger.warn(f'No expert found in key {key}.')
logger.warning(f'No expert found in key {key}.')
else:
local_expert_id = m.group(1)

2 changes: 1 addition & 1 deletion deepspeed/runtime/lr_schedules.py
@@ -508,7 +508,7 @@ def _initialize_lr(self, optimizer, cycle_min_lr, cycle_max_lr, decay_lr_rate, l
def _initialize_momentum(self, optimizer, cycle_min_mom, cycle_max_mom, decay_mom_rate, last_batch_iteration):
if 'betas' not in optimizer.defaults:
optimizer_name = type(optimizer).__name__
logger.warn(
logger.warning(
f"cycle_momentum is disabled because optimizer {optimizer_name} does not support momentum, no betas attribute in defaults"
)
self.cycle_momentum = False
2 changes: 1 addition & 1 deletion deepspeed/runtime/zero/stage_1_and_2.py
@@ -614,7 +614,7 @@ def _configure_moe_settings(self):
assert self.contiguous_gradients, "Contiguous Gradients in ZeRO Stage 2 must be set to True for MoE. Other code paths are not tested with MoE"
# NOTE: To run ZeRO stage 1 with MoE, we need to set self.contiguous_gradients to True or ignore the assertion
if not self.partition_gradients and not self.contiguous_gradients:
logger.warn(
logger.warning(
"ZeRO Stage 1 has not been thoroughly tested with MoE. This configuration is still experimental.")
assert self.reduce_scatter, "Reduce Scatter in ZeRO Stage 2 must be set to True for MoE. Other code paths are not tested with MoE"

21 changes: 19 additions & 2 deletions tests/unit/inference/test_inference.py
@@ -554,6 +554,7 @@ def test(self, model_w_task, injection_policy, query, inf_kwargs, assert_fn, dty


@pytest.mark.seq_inference
@pytest.mark.parametrize('keep_module_on_host', [True, False])
@pytest.mark.parametrize(
"model_w_task",
[("Helsinki-NLP/opus-mt-en-de", "translation"), ("Salesforce/codegen-350M-mono", "text-generation")],
@@ -570,6 +571,7 @@ def test(
inf_kwargs,
assert_fn,
dtype,
keep_module_on_host,
):
invalid_test_msg = validate_test(model_w_task, dtype, enable_cuda_graph=False, enable_triton=False)
if invalid_test_msg:
@@ -592,13 +594,20 @@ def test(
framework="pt")
bs_output = pipe(query, **inf_kwargs)

pipe.model = deepspeed.init_inference(pipe.model, mp_size=world_size, dtype=dtype)
pipe.model = deepspeed.init_inference(pipe.model,
mp_size=world_size,
dtype=dtype,
keep_module_on_host=keep_module_on_host)
ds_output = pipe(query, **inf_kwargs)

print(local_rank, "baseline", bs_output)
print(local_rank, "deepspeed", ds_output)
assert assert_fn(bs_output, ds_output)

if keep_module_on_host:
for name, param in model.named_parameters():
assert param.device == torch.device('cpu'), f"keep_module_on_host is on but param {name} is not on cpu"

@pytest.mark.world_size(3)
def test_odd_world_size(
self,
@@ -607,6 +616,7 @@ def test_odd_world_size(
inf_kwargs,
assert_fn,
dtype,
keep_module_on_host,
):
invalid_test_msg = validate_test(model_w_task, dtype, enable_cuda_graph=False, enable_triton=False)
if invalid_test_msg:
@@ -624,13 +634,20 @@ def test_odd_world_size(
framework="pt")
bs_output = pipe(query, **inf_kwargs)

pipe.model = deepspeed.init_inference(pipe.model, mp_size=world_size, dtype=dtype)
pipe.model = deepspeed.init_inference(pipe.model,
mp_size=world_size,
dtype=dtype,
keep_module_on_host=keep_module_on_host)
ds_output = pipe(query, **inf_kwargs)

print(local_rank, "baseline", bs_output)
print(local_rank, "deepspeed", ds_output)
assert assert_fn(bs_output, ds_output)

if keep_module_on_host:
for name, param in model.named_parameters():
assert param.device == torch.device('cpu'), f"keep_module_on_host is on but param {name} is not on cpu"


@pytest.mark.nightly
@pytest.mark.parametrize(
