Commit

Merge branch 'main' into fix-broadcast-gpu-oom

haolin-nju committed Jan 6, 2025
2 parents 4526338 + c93c285 commit b4d9ee1
Showing 6 changed files with 97 additions and 9 deletions.
15 changes: 15 additions & 0 deletions chatlearn/models/megatron/__init__.py
@@ -0,0 +1,15 @@
# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""megatron"""
5 changes: 4 additions & 1 deletion chatlearn/models/vllm/hooks/loader.py
@@ -23,11 +23,12 @@
from vllm.model_executor.model_loader.weight_utils import initialize_dummy_weights
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
from vllm.model_executor.models import llama
from vllm.model_executor.models import qwen2
from vllm.model_executor.models import qwen2, qwen2_moe

from chatlearn.utils.vllm_import_helper import LlamaForCausalLM
from chatlearn.utils.vllm_import_helper import QWenLMHeadModel
from chatlearn.utils.vllm_import_helper import Qwen2ForCausalLM
from chatlearn.utils.vllm_import_helper import Qwen2MoeForCausalLM
from chatlearn.utils.vllm_import_helper import get_model_architecture
from chatlearn.utils.utils import get_use_legacy_models

@@ -89,6 +90,8 @@ def load_model(self, *, model_config,
self.load_config.model_loader_extra_config["load"] is not None:
qwen2.Qwen2ForCausalLM.load_state_dict = load_state_dict
qwen2.Qwen2ForCausalLM.load_weights = load_weights
qwen2_moe.Qwen2MoeForCausalLM.load_state_dict = load_state_dict
qwen2_moe.Qwen2MoeForCausalLM.load_weights = load_weights
llama.LlamaForCausalLM.load_state_dict = load_state_dict
llama.LlamaForCausalLM.load_weights = load_weights
model.load_weights(self.load_config.model_loader_extra_config)
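
A note on the hunk above: it extends the existing loader hook so that Qwen2-MoE gets the same treatment as Qwen2 and Llama, i.e. the weight-loading methods are replaced on the model class itself before load_weights is called. The snippet below is a minimal, self-contained sketch of that class-level patching pattern; the class and function names are illustrative stand-ins, not the real vLLM or ChatLearn APIs.

class ModelStub:
    """Stand-in for a vLLM model class such as Qwen2MoeForCausalLM."""
    def load_weights(self, weights):
        return "default weight loading"

def custom_load_weights(self, weights):
    # Replacement loader, standing in for ChatLearn's load_weights hook.
    return f"custom loading of {len(weights)} tensors"

# Patch the class, not an instance, so every model constructed afterwards uses it.
ModelStub.load_weights = custom_load_weights

model = ModelStub()
print(model.load_weights({"layers.0.weight": None}))  # -> custom loading of 1 tensors
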
53 changes: 51 additions & 2 deletions chatlearn/models/vllm_module_v2.py
@@ -51,6 +51,7 @@ def __init__(self, *args, **kwargs):

self.tokenizer = None
self._model = None
self.set_vllm_pp_layer_partition()

def add_extra_args(self, parser):
"""
@@ -127,12 +128,60 @@ def setup_vllm(self, workers):
disable_log_requests=self.model_args.get("disable_log_requests", True),
disable_log_stats=self.model_args.get("disable_log_stats", True),
trust_remote_code=True,
# TODO(jiangle.jl): support non-eager mode.
enforce_eager=True,
enforce_eager=self.model_args.get("enforce_eager", False),
disable_custom_all_reduce=True,
distributed_executor_backend="ray")
self.tokenizer = self.llm.llm_engine.tokenizer

def set_vllm_pp_layer_partition(self):
pipeline_world_size = self.module_args.pipeline_model_parallel_size
num_layers = self.model_args.get("num_layers")
remainder = num_layers % pipeline_world_size

if not self.model_args.get("allow_padding_num_layers", None):
assert remainder == 0, \
f"expect num_layers % pipeline_model_size == 0 when VLLM_PP_LAYER_PARTITION is not set. \
while num_layers = {num_layers} pipeline_model_size = {pipeline_world_size}"
return

if remainder > 0:
assert not self.model_args.get("standalone_embedding_stage", False), \
"not support standalone embedding stage if allow_padding_num_layers is true"
# pad num_layers to make num_layers % pipeline_model_parallel_size == 0
num_layers_with_padding = num_layers - remainder + pipeline_world_size
else:
num_layers_with_padding = num_layers
num_layers_without_padding = num_layers
num_layers = num_layers_with_padding
num_layers_per_stage_with_padding = (
num_layers // pipeline_world_size)

# Each stage gets a contiguous set of layers.
if self.model_args.get("pipeline_layers", None) is not None:
rank_sizes = self.model_args.get("pipeline_layers", None)
assert isinstance(rank_sizes, list) and all(isinstance(ele, int) for ele in rank_sizes), \
f"pipeline_layers expected to be list, and num layer of each stage to be integer, while {rank_sizes}."
else:
rank_sizes = [num_layers_per_stage_with_padding] * pipeline_world_size
num_padding = num_layers - num_layers_without_padding
if num_padding > 0:
assert num_padding == 2, \
"Support num_padding_lsyers == 2 when applies inbalanced pp. Please set `args.pipeline_layers` for VLLMModule."

for _index in range(-1, num_padding - 1):
rank_sizes[_index] -= 1
assert len(rank_sizes) == pipeline_world_size

# set env variable VLLM_PP_LAYER_PARTITION
vllm_pp_layer_partition = ",".join([str(ele) for ele in rank_sizes])
if os.getenv("VLLM_PP_LAYER_PARTITION", None) is not None:
env_vllm_pp_layer_partition = os.getenv("VLLM_PP_LAYER_PARTITION", None)
if vllm_pp_layer_partition != env_vllm_pp_layer_partition:
self._logger.warning(
f"expect VLLM_PP_LAYER_PARTITION to be {vllm_pp_layer_partition}, while {env_vllm_pp_layer_partition}")
os.environ["VLLM_PP_LAYER_PARTITION"] = vllm_pp_layer_partition
self._logger.info(f"Set VLLM_PP_LAYER_PARTITION={vllm_pp_layer_partition}")

def _get_sampling_params(self, is_eval):
temperature = 0.0
if not self.model_args.get("use_beam_search"):
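
The partition logic added above is easier to follow as a standalone sketch. The helper below reproduces the same arithmetic with plain arguments (simplified: the pipeline_layers override, configuration access, and the standalone_embedding_stage check are omitted; the function name is ours, not part of VLLMModuleV2).

def pp_layer_partition(num_layers, pp_size, allow_padding=True):
    remainder = num_layers % pp_size
    if not allow_padding:
        assert remainder == 0, "num_layers must divide evenly across pipeline stages"
        return ",".join([str(num_layers // pp_size)] * pp_size)
    # Pad num_layers up to a multiple of pp_size, then split evenly.
    padded = num_layers if remainder == 0 else num_layers - remainder + pp_size
    sizes = [padded // pp_size] * pp_size
    num_padding = padded - num_layers
    if num_padding > 0:
        assert num_padding == 2, "only two padded layers are handled automatically"
        # Drop one padded layer from the last stage and one from the first stage.
        for index in range(-1, num_padding - 1):
            sizes[index] -= 1
    return ",".join(str(s) for s in sizes)

print(pp_layer_partition(30, 4))  # 30 layers on 4 stages -> padded to 32, trimmed to "7,8,8,7"
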
16 changes: 15 additions & 1 deletion chatlearn/runtime/dist_actor.py
@@ -255,13 +255,27 @@ def call_vllm_engine_remote_funcs(self, func_name, *args, **kwargs):
results.append(res)
return results

def call_vllm_engine_and_workers_remote_funcs(self, func_name, *args, **kwargs):
"""
Call remote functions for vllm_engine + workers.
"""
results = []
for actor in self.all_actors:
res = self.call_actor_remote_func(actor, func_name, *args, **kwargs)
results.append(res)
res = self.call_actor_remote_func(self.vllm_engine, func_name, *args, **kwargs)
results.append(res)
return results

def add_remote_func(self):
for func_name, _ in inspect.getmembers(self.master):
# ray.actor.ActorMethod
if func_name.startswith('_') or func_name in ["peak_memory"]:
continue
if func_name in ["timer_summary", "model_setup"]:
if func_name in ["timer_summary"]:
dist_call = partial(self.call_vllm_engine_remote_funcs, func_name)
elif func_name in ["model_setup"]:
dist_call = partial(self.call_vllm_engine_and_workers_remote_funcs, func_name)
else: # other call funcs may need their own dispatch handling.
dist_call = partial(self.call_remote_funcs, func_name)
setattr(self, func_name, dist_call)
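
For context on the dispatch change above: model_setup is now fanned out to the vLLM engine actor and all worker actors, while timer_summary still goes to the engine only. Below is a minimal sketch of that routing, using plain objects in place of Ray actors and illustrative names.

from functools import partial

class FakeActor:
    def __init__(self, name):
        self.name = name
    def model_setup(self):
        return f"{self.name} set up"
    def timer_summary(self):
        return f"{self.name} timers"

class DistActorSketch:
    def __init__(self, engine, workers):
        self.vllm_engine = engine
        self.all_actors = workers
        # Mirror add_remote_func: pick a dispatcher per function name.
        for func_name in ("timer_summary", "model_setup"):
            if func_name == "timer_summary":
                dist_call = partial(self.call_engine, func_name)
            else:  # model_setup fans out to engine + workers
                dist_call = partial(self.call_engine_and_workers, func_name)
            setattr(self, func_name, dist_call)

    def call_engine(self, func_name, *args, **kwargs):
        return [getattr(self.vllm_engine, func_name)(*args, **kwargs)]

    def call_engine_and_workers(self, func_name, *args, **kwargs):
        results = [getattr(a, func_name)(*args, **kwargs) for a in self.all_actors]
        results.append(getattr(self.vllm_engine, func_name)(*args, **kwargs))
        return results

sketch = DistActorSketch(FakeActor("engine"), [FakeActor("worker0"), FakeActor("worker1")])
print(sketch.model_setup())    # engine + both workers
print(sketch.timer_summary())  # engine only
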
6 changes: 4 additions & 2 deletions chatlearn/runtime/engine.py
@@ -296,8 +296,10 @@ def learn(self):
self.timers("sync_parameters").start()
self.model_manager.sync_parameters(requires_grad=False, validate=self.runtime_args.validate_param_sync)
self.timers("sync_parameters").stop()
logger.info(f"{LOG_START} {self._name} sync_parameters summary {self.timers.log(names=['sync_parameters'])}")
logger.info(f"{LOG_START} " + get_full_proc_memory_info('After first param sync'))
logger.info(
f"{LOG_START} {self._name} sync_parameters summary {self.timers.log(names=['sync_parameters'])} " \
+ get_full_proc_memory_info('After first param sync')
)
self._data_loader = data_loader
for episode_id in range(self._start_episode, self.runtime_args.num_episode):
if self.runtime_args.nsys:
11 changes: 8 additions & 3 deletions chatlearn/synchronizer/parameter_sync.py
@@ -994,14 +994,18 @@ def validate_sync_results_parallel(self, actor_mappings_list:List, requires_grad
else:
execute_in_parallel(self.validate_sync_results, args)

def _calculate_max_workers(self, sorted_send_actors, actor_mapping):
def _calculate_max_workers(self, sorted_send_actors, actor_mappings=None):
max_workers = get_args().runtime_args.param_sync_max_workers
if max_workers is None:
max_workers = max(self.src_model.total_gpu // 8, 1)
if max_workers == -1:
if self._comm_type == PARAM_SYNC_COMM_TYPE.BROADCAST:
max_workers = len(sorted_send_actors)
else:
assert actor_mappings is not None, (
"actor_mappings should not be None when max_workers is -1 and "
"communication type for parameter synchronization is not broadcast."
)
max_workers = len(sorted_send_actors) * len(actor_mappings[sorted_send_actors[0]])
return max_workers

@@ -1387,19 +1391,20 @@ def _synchronize_all_moe_parameters(self, requires_grad=None, validate=False):
if self.concurrent_comm:
assert self.dst_model.use_vllm_backend

max_workers = self._calculate_max_workers(self.send_actors_to_regroup_routed_experts)
if self._comm_type_to_regroup_routed_experts == ROUTED_EXPERT_REGROUPING_COMM_TYPE.ALLGATHER:
# allgather routed experts only
self.sync_allgather_multi_threads(
[self.send_actors_to_regroup_routed_experts],
max_workers=1,
max_workers=max_workers,
requires_grad=requires_grad,
group_name=self.group_name + "_allgather",
filter_fn=self.routed_experts_filter)
elif self._comm_type_to_regroup_routed_experts == ROUTED_EXPERT_REGROUPING_COMM_TYPE.ALLTOALL:
# alltoall routed experts only
self.sync_alltoall_multi_threads(
[self.send_actors_to_regroup_routed_experts],
max_workers=1,
max_workers=max_workers,
requires_grad=requires_grad,
filter_fn=self.routed_experts_filter)

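
The two parameter_sync.py hunks above work together: _calculate_max_workers now tolerates a missing actor mapping when the communication type is broadcast, and the MoE routed-expert regrouping path uses the computed worker count instead of a hard-coded 1. The sketch below reproduces the worker-count decision with plain arguments (configuration access and actor types are simplified; the names are illustrative).

def calculate_max_workers(configured, total_gpu, comm_type, send_actors, actor_mappings=None):
    max_workers = configured
    if max_workers is None:
        # Default: one worker per 8 source GPUs, at least one.
        max_workers = max(total_gpu // 8, 1)
    if max_workers == -1:
        if comm_type == "broadcast":
            # Broadcast needs no receive mapping: one worker per send actor.
            max_workers = len(send_actors)
        else:
            assert actor_mappings is not None, "actor_mappings required for non-broadcast sync"
            max_workers = len(send_actors) * len(actor_mappings[send_actors[0]])
    return max_workers

print(calculate_max_workers(None, total_gpu=16, comm_type="broadcast", send_actors=["s0", "s1"]))  # 2
print(calculate_max_workers(-1, total_gpu=16, comm_type="alltoall", send_actors=["s0", "s1"],
                            actor_mappings={"s0": ["r0", "r1"], "s1": ["r2", "r3"]}))              # 4
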
