From b4ad394a8a4890ee222f2f5939d74cc941b1e0f6 Mon Sep 17 00:00:00 2001 From: "jiangle.jl" Date: Thu, 26 Dec 2024 11:53:24 +0800 Subject: [PATCH] init process group for param sync. --- chatlearn/models/base_module.py | 54 ++++++++++-- .../models/vllm/hooks/ray_gpu_executor.py | 5 ++ chatlearn/models/vllm_module_v2.py | 84 +++++++++---------- chatlearn/runtime/dist_actor.py | 8 +- chatlearn/schedule/model_manager.py | 5 +- chatlearn/synchronizer/parameter_sync.py | 13 ++- .../configs/llama2/vllm_param_sync.yaml | 1 + .../megatron/configs/llama2/vllm_rlhf.yaml | 1 + .../megatron/models/vllm_policy_inference.py | 12 +-- .../tests/test_unbalanced_param_sync.sh | 2 +- 10 files changed, 116 insertions(+), 69 deletions(-) diff --git a/chatlearn/models/base_module.py b/chatlearn/models/base_module.py index b09df943..225fee3c 100644 --- a/chatlearn/models/base_module.py +++ b/chatlearn/models/base_module.py @@ -39,6 +39,22 @@ from chatlearn.launcher import dlc_utils +from pynvml import nvmlInit, nvmlDeviceGetCount, nvmlDeviceGetHandleByIndex, nvmlDeviceGetComputeRunningProcesses, nvmlShutdown + +def get_gpu_usage(): + nvmlInit() + device_count = nvmlDeviceGetCount() + gpu_usage = {} + for i in range(device_count): + handle = nvmlDeviceGetHandleByIndex(i) + processes = nvmlDeviceGetComputeRunningProcesses(handle) + for process in processes: + if process.pid == os.getpid(): + gpu_usage[i] = process.pid + nvmlShutdown() + return gpu_usage + + class BaseModule: """BaseModule is the base class for Base models. @@ -130,6 +146,7 @@ def __init__(self, name, args=None, replica_id=0): self._sync_buffer = defaultdict(list) self._expert_sync_buffer = {} self._synchronizer = None + self.col_groups = {} def get_sync_buffer(self): return self._sync_buffer @@ -485,7 +502,8 @@ def setup_collective_group(self, rank, world_size, backend, group_name): """ self._group_names.append(group_name) self._world_size = world_size - col.init_collective_group( + # breakpoint() + self.col_groups[group_name] = col.init_collective_group( world_size, rank, backend=backend, group_name=group_name) def _destroy_collective_group(self, group_name): @@ -617,6 +635,7 @@ def _set_sync_parameters(self, trainable_param_names, pipe_stage=0, parameters_t if self._synchronizer is not None: params_to_sync_list = self._synchronizer.transform_parameters(params_to_sync_list) parameters_to_sync[pipe_stage] = params_to_sync_list + print(f"{self} params_to_sync_list: {len(params_to_sync_list)} {[name for name,_ in params_to_sync_list]}") return parameters_to_sync def set_sync_parameters(self, trainable_param_names, pipe_stage=0, parameters_to_sync=None): @@ -675,7 +694,8 @@ def get_parameter_shape(self, pipe_stage=0, parameters_to_sync=None): parameters_to_sync = self._parameters_to_sync parameters_shape = [] for name, param in parameters_to_sync[pipe_stage]: - if self._expert_sync_buffer and name in self._expert_sync_buffer and self._synchronizer.is_parameter_changed: + if self._expert_sync_buffer and name in self._expert_sync_buffer and \ + self._synchronizer and self._synchronizer.is_parameter_changed: parameters_shape.append((name, self._expert_sync_buffer[name].shape)) else: parameters_shape.append((name, param.shape)) @@ -693,12 +713,13 @@ def get_parameter_to_sync(self, name, pipe_stage, to_cpu=False, regroup=False): assert pipe_stage in self._parameters_to_sync and len(self._parameters_to_sync[pipe_stage]) > 0 for name0, param in self._parameters_to_sync[pipe_stage]: if name0 == name: - if name in self._expert_sync_buffer and self._synchronizer.is_parameter_changed: + if name in self._expert_sync_buffer and self._synchronizer and \ + self._synchronizer.is_parameter_changed: param = self._expert_sync_buffer[name] regroup_routed_experts = True else: regroup_routed_experts = False - if regroup: + if regroup and self._synchronizer: param = self._synchronizer.regroup_params_to_sync( name, param.data, @@ -740,6 +761,7 @@ def send_recv_parameter(self, rank, group_name, func, pipe_stage=0): func(param, rank, group_name) def alltoall_routed_expert_parameter(self, pipe_stage=0): + assert self._synchronizer is not None for name, param in self._parameters_to_sync[pipe_stage]: param, state = self._synchronizer.alltoall_routed_experts( name, @@ -751,6 +773,7 @@ def alltoall_routed_expert_parameter(self, pipe_stage=0): self._expert_sync_buffer[name] = param def allgather_routed_expert_parameter(self, group_name, pipe_stage=0): + assert self._synchronizer is not None for name, param in self._parameters_to_sync[pipe_stage]: param, state = self._synchronizer.allgather_routed_experts( name, @@ -767,18 +790,39 @@ def broadcast_parameter(self, rank, src_rank, group_name, pipe_stage=0): :meta private: """ tensors = [] + # breakpoint() + print(f"debug from {src_rank} to {rank} self._parameters_to_sync[{pipe_stage}]: {len(self._parameters_to_sync[pipe_stage])} {self._parameters_to_sync[pipe_stage][0][0]}", flush=True) + print(f"{self} get_visible_gpus: {self.get_visible_gpus()}") for name, param in self._parameters_to_sync[pipe_stage]: - if self._expert_sync_buffer and name in self._expert_sync_buffer and self._synchronizer.is_parameter_changed: + if self._expert_sync_buffer and name in self._expert_sync_buffer and \ + (self._synchronizer and self._synchronizer.is_parameter_changed): tensors.append(self._expert_sync_buffer[name]) else: tensors.append(param.data) + # if src_rank != rank: + # print(f"self.worker.device: {self.worker.device} current_device: {torch.cuda.current_device()} tensors[0].device: {tensors[0].device}") + + # breakpoint() + print(f"debug from {src_rank} to {rank} tensors: {len(tensors)} {[tensors[0].shape]}", flush=True) + print(f"debug get_gpu_usage {self}: {get_gpu_usage()}", flush=True) + tensor_device = tensors[0].device + print(f"debug tensor is on device {self}: {tensor_device}") + assert len(tensors) > 0 dense_buckets, sparse_bucket = bucket_tensors(tensors, bucket_size_mb=self.runtime_args.coalesced_buffer_mb) debug_rank_0(f"{self.name} Got dense_buckets {len(dense_buckets)}, spase_bucket {len(sparse_bucket)}", self._logger) tensor_changed = rank != src_rank + print(f"debug from {src_rank} to {rank} dense_buckets: {len(dense_buckets)} sparse_bucket: {len(sparse_bucket)}", flush=True) + # breakpoint() + print(f"self.col_groups: {self.col_groups}") for bucket in dense_buckets: + # for ele in bucket: + # breakpoint() + # col.broadcast(ele, src_rank, group_name) + # coalesced_comm_dense([ele], col.broadcast, extra_args=(src_rank, group_name), tensor_changed=tensor_changed) + coalesced_comm_dense(bucket, col.broadcast, extra_args=(src_rank, group_name), tensor_changed=tensor_changed) for param in sparse_bucket: diff --git a/chatlearn/models/vllm/hooks/ray_gpu_executor.py b/chatlearn/models/vllm/hooks/ray_gpu_executor.py index 539d7b01..1c3a0307 100644 --- a/chatlearn/models/vllm/hooks/ray_gpu_executor.py +++ b/chatlearn/models/vllm/hooks/ray_gpu_executor.py @@ -200,6 +200,11 @@ def sort_by_driver_then_worker_ip(worker): print(f"debug 1111 self.workers 4: {[id(ele) for ele in self.workers]} {self.workers}") self._run_workers("init_device") + from chatlearn.utils import future + refs = [self.workers[rank].init_device.remote() for rank in range(len(self.workers))] + future.wait(refs) + + self._run_workers("load_model", # use_dummy_driver=True, max_concurrent_workers=self.parallel_config. diff --git a/chatlearn/models/vllm_module_v2.py b/chatlearn/models/vllm_module_v2.py index f1393cd0..284cd7fa 100644 --- a/chatlearn/models/vllm_module_v2.py +++ b/chatlearn/models/vllm_module_v2.py @@ -29,6 +29,8 @@ from chatlearn.utils.vllm_import_helper import get_pipeline_model_parallel_rank from chatlearn.utils.vllm_import_helper import TextTokensPrompt from .torch_module import TorchModule +from vllm.worker.worker import Worker +from chatlearn.utils.vllm_utils import initialize_vllm class VLLMModuleV2(TorchModule, RayWorkerWrapper): @@ -45,24 +47,49 @@ def __init__(self, *args, **kwargs): f"Expected only '__init__' as common method for TorchModule and RayWorkerWrapper, but got {common_methods}" # TorchModule.__init__(self, *args) self.local_rank = 0 - os.environ['LOCAL_RANK'] = '0' + # os.environ['LOCAL_RANK'] = '0' if 'worker_module_name' in kwargs and 'worker_class_name' in kwargs: RayWorkerWrapper.__init__(self, **kwargs) # pylint: disable=non-parent-init-called os.environ['VLLM_HOST_IP'] = self.get_address() - # def __init__(self, *args, **kwargs): - # super().__init__(*args, **kwargs) - # self.local_rank = 0 - # os.environ['LOCAL_RANK'] = '0' - # if 'worker_module_name' in kwargs and 'worker_class_name' in kwargs: - # RayWorkerWrapper.__init__(self, **kwargs) # pylint: disable=non-parent-init-called - # os.environ['VLLM_HOST_IP'] = self.get_address() - # self.llm_engine = None self.tokenizer = None self._tp_rank = None self._pp_rank = None self._model = None + def add_extra_args(self, parser): + """ + Add extra arguments for vllm. + + Args + ---- + parser : ArgumentParser + Add extra arguments. + """ + group = parser.add_argument_group(title='vLLM extra arguments') + group.add_argument('--distributed-backend', default='nccl', + choices=['nccl', 'gloo'], + help='Which backend to use for distributed training.') + group.add_argument('--distributed-timeout-minutes', type=int, default=10, + help='Timeout minutes for torch.distributed.') + return parser + def init(self): + """ + :meta private: + """ + parallel_state.set_custom_all_reduce(False) + initialize_vllm(extra_args_provider=self.add_extra_args, + ignore_unknown_args=True, + args_dict=self.model_args) + def init_device(self): + return + self.worker.device = torch.device(f"cuda:{torch.cuda.current_device()}") + torch.cuda.set_device(self.device) + init_worker_distributed_environment(self.worker.parallel_config, self.worker.rank, + self.worker.distributed_init_method, + self.worker.local_rank) + # return self.worker.init_device() + def setup(self): """Set up model and load checkpoint""" super().setup() @@ -72,7 +99,7 @@ def setup(self): def setup_vllm(self, workers): # setup vllm engine in rank 0 - # os.environ['VLLM_HOST_IP'] = self.get_address() + os.environ['VLLM_HOST_IP'] = self.get_address() print(f"debug 1111 workers: {[id(ele) for ele in workers]} {workers}") set_vllm_actors(workers) @@ -168,10 +195,7 @@ def _convert_v1_inputs(self, prompts, prompt_token_ids): return inputs - async def generate_all(self, prompts, sampling_params): - pass - - async def generate_vllm(self, query, is_eval): + def generate_vllm(self, query, is_eval): print(f"debug aaaa tensor_parallel_rank: {id(self)} {self}") print(f"tensor_parallel_rank: {self.tensor_parallel_rank()}", flush=True) print(f"debug aaaa pipeline_parallel_rank: {id(self)} {self} {self.pipeline_parallel_rank()}", flush=True) @@ -182,7 +206,6 @@ async def generate_vllm(self, query, is_eval): prompts_token_ids = query[input_ids_key] seq_len = self.model_args.get("seq_length") final_outputs = [] - tasks = [] parsed_prompts = [] sampling_params = [] for i, prompt in enumerate(prompts): @@ -224,32 +247,6 @@ def num_layers(self): """ return self.llm.llm_engine.model_config.hf_config.num_hidden_layers - -# class VLLMWokerWrapper(TorchModule, RayWorkerWrapper): -# """VLLMWokerWrapper is the class for vLLM workers. - -# Args -# ---- -# name : str -# model name -# """ - - # def __init__(self, *args, **kwargs): - # # avoid overwrite methods - # methods_class1 = {method[0] for method in inspect.getmembers(TorchModule, predicate=inspect.isfunction)} - # methods_class2 = {method[0] for method in inspect.getmembers(RayWorkerWrapper, predicate=inspect.isfunction)} - # common_methods = methods_class1.intersection(methods_class2) - # # common method is '__init__' - # assert common_methods == {'__init__'}, \ - # f"Expected only '__init__' as common method for TorchModule and RayWorkerWrapper, but got {common_methods}" - # TorchModule.__init__(self, *args) - # self.local_rank = 0 - # os.environ['LOCAL_RANK'] = '0' - # if 'worker_module_name' in kwargs and 'worker_class_name' in kwargs: - # RayWorkerWrapper.__init__(self, **kwargs) # pylint: disable=non-parent-init-called - # os.environ['VLLM_HOST_IP'] = self.get_address() - # self.llm_engine = None - def peak_memory(self): """ :meta private: @@ -274,10 +271,9 @@ def data_parallel_rank(self): @property def model(self): if self._model is None: - # breakpoint() - # if self.worker is not None: + assert self.worker is not None, \ + f"please set env variables `VLLM_USE_RAY_SPMD_WORKER` and `VLLM_USE_RAY_COMPILED_DAG` first." self._model = self.worker.model_runner.model - # print(f"debug 1111 self.llm: {self.llm}") return self._model def set_tp_pp_ranks(self, tp_rank, pp_rank): """ diff --git a/chatlearn/runtime/dist_actor.py b/chatlearn/runtime/dist_actor.py index 32c362a4..e355a9e2 100644 --- a/chatlearn/runtime/dist_actor.py +++ b/chatlearn/runtime/dist_actor.py @@ -233,7 +233,7 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.vllm_engine = None - def create_actor(self, num_gpus, placement_group, group_index, use_ray_vllm_worker=False): + def create_actor(self, num_gpus, placement_group, group_index): kwargs = { "worker_module_name": "vllm.worker.worker", "worker_class_name": "Worker", @@ -245,12 +245,8 @@ def create_actor(self, num_gpus, placement_group, group_index, use_ray_vllm_work # self.model.engine = self.vllm_engine # else: # self._create_actor(self.model.__class__, num_gpus, placement_group, group_index, **kwargs) - # if use_ray_vllm_worker: - # from vllm.executor.ray_utils import RayWorkerWrapper - # self.all_actors.append(RayWorkerWrapper(**kwargs)) - # else: + self._create_actor(self.model.__class__, num_gpus, placement_group, group_index, **kwargs) - # return self._create_actor(self.model.__class__, num_gpus, placement_group, group_index, **kwargs) def create_engine_actor(self, num_gpus, placement_group, group_index): self.vllm_engine = self._create_actor(self.model.__class__, num_gpus, placement_group, group_index) self.model.engine = self.vllm_engine diff --git a/chatlearn/schedule/model_manager.py b/chatlearn/schedule/model_manager.py index 4f930cef..2babbebb 100644 --- a/chatlearn/schedule/model_manager.py +++ b/chatlearn/schedule/model_manager.py @@ -354,10 +354,7 @@ def _get_model_replica_from_pack(gpu_index, model_pack): replica.create_engine_actor(num_gpus, placement_group, group) # we do not want to add engine actor to all_actors replica.all_actors.pop() - if isinstance(replica.model, VLLMModuleV2): - replica.create_actor(num_gpus, placement_group, group, use_ray_vllm_worker=not replica.all_actors) - else: - replica.create_actor(num_gpus, placement_group, group) + replica.create_actor(num_gpus, placement_group, group) models_to_revert = self._find_param_recv_models(gpu_models) for model in gpu_models: diff --git a/chatlearn/synchronizer/parameter_sync.py b/chatlearn/synchronizer/parameter_sync.py index d4a4203b..4bb2ae02 100644 --- a/chatlearn/synchronizer/parameter_sync.py +++ b/chatlearn/synchronizer/parameter_sync.py @@ -69,6 +69,7 @@ def __init__(self, src_model, dst_model, group_name, frequency, error_signal): if self.num_src_tensor_parallel % 2 == 1 and self.num_dst_tensor_parallel % 2 == 1: logger.warning("Only support PARAM_SYNC_COMM_TYPE.BROADCAST when TP SIZE is even number, use P2P instead") self._comm_type = PARAM_SYNC_COMM_TYPE.P2P + print(f"self._comm_type: {self._comm_type} while expect {get_args().runtime_args.param_sync_comm_type}") self.concurrent_comm = get_args().runtime_args.concurrent_comm self._enable_lora = self.src_model.module_args.lora.enable_lora @@ -117,8 +118,10 @@ def inner_func(*args, **kwargs): def is_same_gpu(self, src_actor, dst_actor): src_gpu = self.get_or_cache(src_actor, "get_visible_gpus") dst_gpu = self.get_or_cache(dst_actor, "get_visible_gpus") + print(f"src_actor {src_actor} src_gpu: {src_gpu}, dst_actor: {dst_actor} dst_gpu: {dst_gpu}") src_address = self.get_or_cache(src_actor, "get_address") dst_address = self.get_or_cache(dst_actor, "get_address") + print(f"src_address: {src_address} dst_address: {dst_address}") return src_gpu == dst_gpu and src_address == dst_address @property @@ -1088,10 +1091,14 @@ def _multi_thread_sync_for_tp_num_mapping_eq_1( if True: for send_actor in sorted_send_actors: recv_actors = actor_mappings[send_actor] + if self._comm_type == PARAM_SYNC_COMM_TYPE.BROADCAST: # breakpoint() - print(f"send from {self.actor2rank[send_actor]} to {[self.actor2rank[ele] for ele in recv_actors]}") - actor_groups, finalized_group_name = self.create_broadcast_group(send_actor, recv_actors, param_group=param_group) - self.sync_broadcast(actor_groups, finalized_group_name, requires_grad, filter_fn=filter_fn, param_group=param_group) + print(f"send from {self.actor2rank[send_actor]} to {[self.actor2rank[ele] for ele in recv_actors]}") + actor_groups, finalized_group_name = self.create_broadcast_group(send_actor, recv_actors, param_group=param_group) + self.sync_broadcast(actor_groups, finalized_group_name, requires_grad, filter_fn=filter_fn, param_group=param_group) + else: + for recv_actor in recv_actors: + self.sync_send_recv(send_actor, recv_actor, requires_grad, filter_fn=filter_fn, param_group=param_group) def _single_thread_sync(self, actor_mappings_list:List, requires_grad=None, filter_fn=None, param_group="default"): assert len(actor_mappings_list) == 1 diff --git a/examples/megatron/configs/llama2/vllm_param_sync.yaml b/examples/megatron/configs/llama2/vllm_param_sync.yaml index 4be5bc65..5e99630e 100644 --- a/examples/megatron/configs/llama2/vllm_param_sync.yaml +++ b/examples/megatron/configs/llama2/vllm_param_sync.yaml @@ -50,3 +50,4 @@ runtime: exp_name: ${exp_name:chatlearn} debug: ${debug:False} validate_param_sync: ${validate_param_sync:False} + param_sync_comm_type: ${param_sync_comm_type:broadcast} diff --git a/examples/megatron/configs/llama2/vllm_rlhf.yaml b/examples/megatron/configs/llama2/vllm_rlhf.yaml index 3a616281..b4258ce4 100644 --- a/examples/megatron/configs/llama2/vllm_rlhf.yaml +++ b/examples/megatron/configs/llama2/vllm_rlhf.yaml @@ -83,3 +83,4 @@ runtime: exp_name: ${exp_name:chatlearn} debug: ${debug:False} validate_param_sync: ${validate_param_sync:False} + param_sync_comm_type: ${param_sync_comm_type:broadcast} diff --git a/examples/megatron/models/vllm_policy_inference.py b/examples/megatron/models/vllm_policy_inference.py index 0672215d..ba4a9e55 100644 --- a/examples/megatron/models/vllm_policy_inference.py +++ b/examples/megatron/models/vllm_policy_inference.py @@ -146,14 +146,14 @@ def decode_internal(self, batched_outputs): class VLLMPolicyInferenceAsync(VLLMPolicyInference): """VLLMPolicyInferenceAsync is the model for VLLMModuleV2, which uses async generate API""" - async def eval_forward(self, data, iteration=0): # pylint: disable=invalid-overridden-method - return await self._forward_step(data, iteration, True) + def eval_forward(self, data, iteration=0): # pylint: disable=invalid-overridden-method + return self._forward_step(data, iteration, True) - async def _forward_step(self, data, iteration, is_eval): # pylint: disable=unused-argument,invalid-overridden-method - outputs = await self.generate_vllm(data, is_eval) + def _forward_step(self, data, iteration, is_eval): # pylint: disable=unused-argument,invalid-overridden-method + outputs = self.generate_vllm(data, is_eval) if outputs is not None: rets = self.decode_internal(outputs) return rets - async def forward_step(self, data, iteration=0): # pylint: disable=invalid-overridden-method - return await self._forward_step(data, iteration, False) + def forward_step(self, data, iteration=0): # pylint: disable=invalid-overridden-method + return self._forward_step(data, iteration, False) diff --git a/examples/megatron/tests/test_unbalanced_param_sync.sh b/examples/megatron/tests/test_unbalanced_param_sync.sh index 01fce7f6..979d6894 100644 --- a/examples/megatron/tests/test_unbalanced_param_sync.sh +++ b/examples/megatron/tests/test_unbalanced_param_sync.sh @@ -38,7 +38,7 @@ if [[ "$model_size" == "llama2-7B" ]]; then fi export train_micro_batch_size=16 export max_num_batched_tokens=65536 - export gpu_memory_utilization=0.8 + export gpu_memory_utilization=0.5 export num_gpu_policy=4 export num_gpu_ppo_policy=4