init process group for param sync.
charles9304 committed Dec 26, 2024
1 parent 5ba5dc3 commit b4ad394
Showing 10 changed files with 116 additions and 69 deletions.
54 changes: 49 additions & 5 deletions chatlearn/models/base_module.py
@@ -39,6 +39,22 @@
from chatlearn.launcher import dlc_utils


from pynvml import nvmlInit, nvmlDeviceGetCount, nvmlDeviceGetHandleByIndex, nvmlDeviceGetComputeRunningProcesses, nvmlShutdown

def get_gpu_usage():
nvmlInit()
device_count = nvmlDeviceGetCount()
gpu_usage = {}
for i in range(device_count):
handle = nvmlDeviceGetHandleByIndex(i)
processes = nvmlDeviceGetComputeRunningProcesses(handle)
for process in processes:
if process.pid == os.getpid():
gpu_usage[i] = process.pid
nvmlShutdown()
return gpu_usage
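
Note: the helper above returns a mapping from GPU index to this process's PID for every device on which the current process holds a compute context. A minimal usage sketch (not part of the commit; assumes pynvml is installed and a CUDA context is already open):

import os
import torch
from chatlearn.models.base_module import get_gpu_usage

torch.ones(1, device="cuda")   # open a CUDA context so NVML sees this PID
usage = get_gpu_usage()        # e.g. {0: 12345}
print(f"PID {os.getpid()} holds compute contexts on GPUs: {sorted(usage)}")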


class BaseModule:
"""BaseModule is the base class for Base models.
@@ -130,6 +146,7 @@ def __init__(self, name, args=None, replica_id=0):
self._sync_buffer = defaultdict(list)
self._expert_sync_buffer = {}
self._synchronizer = None
self.col_groups = {}

def get_sync_buffer(self):
return self._sync_buffer
Expand Down Expand Up @@ -485,7 +502,8 @@ def setup_collective_group(self, rank, world_size, backend, group_name):
"""
self._group_names.append(group_name)
self._world_size = world_size
col.init_collective_group(
# breakpoint()
self.col_groups[group_name] = col.init_collective_group(
world_size, rank, backend=backend, group_name=group_name)

def _destroy_collective_group(self, group_name):
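
For reference, `col` here is `ray.util.collective`, and the change keeps the return value of `init_collective_group` in `self.col_groups` keyed by group name (mainly useful for the debug prints added later in this commit). A minimal standalone sketch of the same call pattern, with placeholder world size, ranks, and group name:

import ray
import ray.util.collective as col
import torch

ray.init()

@ray.remote(num_gpus=1)
class SyncWorker:
    def setup(self, world_size, rank, group_name):
        # every participant must call init_collective_group with the same group_name
        col.init_collective_group(world_size, rank, backend="nccl", group_name=group_name)

    def run_broadcast(self, src_rank, group_name):
        t = torch.ones(4, device="cuda")
        col.broadcast(t, src_rank, group_name=group_name)  # in-place broadcast from src_rank
        return t.cpu()

workers = [SyncWorker.remote() for _ in range(2)]
ray.get([w.setup.remote(2, i, "param_sync_group") for i, w in enumerate(workers)])
print(ray.get([w.run_broadcast.remote(0, "param_sync_group") for w in workers]))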
@@ -617,6 +635,7 @@ def _set_sync_parameters(self, trainable_param_names, pipe_stage=0, parameters_t
if self._synchronizer is not None:
params_to_sync_list = self._synchronizer.transform_parameters(params_to_sync_list)
parameters_to_sync[pipe_stage] = params_to_sync_list
print(f"{self} params_to_sync_list: {len(params_to_sync_list)} {[name for name,_ in params_to_sync_list]}")
return parameters_to_sync

def set_sync_parameters(self, trainable_param_names, pipe_stage=0, parameters_to_sync=None):
@@ -675,7 +694,8 @@ def get_parameter_shape(self, pipe_stage=0, parameters_to_sync=None):
parameters_to_sync = self._parameters_to_sync
parameters_shape = []
for name, param in parameters_to_sync[pipe_stage]:
if self._expert_sync_buffer and name in self._expert_sync_buffer and self._synchronizer.is_parameter_changed:
if self._expert_sync_buffer and name in self._expert_sync_buffer and \
self._synchronizer and self._synchronizer.is_parameter_changed:
parameters_shape.append((name, self._expert_sync_buffer[name].shape))
else:
parameters_shape.append((name, param.shape))
@@ -693,12 +713,13 @@ def get_parameter_to_sync(self, name, pipe_stage, to_cpu=False, regroup=False):
assert pipe_stage in self._parameters_to_sync and len(self._parameters_to_sync[pipe_stage]) > 0
for name0, param in self._parameters_to_sync[pipe_stage]:
if name0 == name:
if name in self._expert_sync_buffer and self._synchronizer.is_parameter_changed:
if name in self._expert_sync_buffer and self._synchronizer and \
self._synchronizer.is_parameter_changed:
param = self._expert_sync_buffer[name]
regroup_routed_experts = True
else:
regroup_routed_experts = False
if regroup:
if regroup and self._synchronizer:
param = self._synchronizer.regroup_params_to_sync(
name,
param.data,
@@ -740,6 +761,7 @@ def send_recv_parameter(self, rank, group_name, func, pipe_stage=0):
func(param, rank, group_name)

def alltoall_routed_expert_parameter(self, pipe_stage=0):
assert self._synchronizer is not None
for name, param in self._parameters_to_sync[pipe_stage]:
param, state = self._synchronizer.alltoall_routed_experts(
name,
@@ -751,6 +773,7 @@ def allgather_routed_expert_parameter(self, group_name, pipe_stage=0):
self._expert_sync_buffer[name] = param

def allgather_routed_expert_parameter(self, group_name, pipe_stage=0):
assert self._synchronizer is not None
for name, param in self._parameters_to_sync[pipe_stage]:
param, state = self._synchronizer.allgather_routed_experts(
name,
@@ -767,18 +790,39 @@ def broadcast_parameter(self, rank, src_rank, group_name, pipe_stage=0):
:meta private:
"""
tensors = []
# breakpoint()
print(f"debug from {src_rank} to {rank} self._parameters_to_sync[{pipe_stage}]: {len(self._parameters_to_sync[pipe_stage])} {self._parameters_to_sync[pipe_stage][0][0]}", flush=True)
print(f"{self} get_visible_gpus: {self.get_visible_gpus()}")
for name, param in self._parameters_to_sync[pipe_stage]:
if self._expert_sync_buffer and name in self._expert_sync_buffer and self._synchronizer.is_parameter_changed:
if self._expert_sync_buffer and name in self._expert_sync_buffer and \
(self._synchronizer and self._synchronizer.is_parameter_changed):
tensors.append(self._expert_sync_buffer[name])
else:
tensors.append(param.data)
# if src_rank != rank:
# print(f"self.worker.device: {self.worker.device} current_device: {torch.cuda.current_device()} tensors[0].device: {tensors[0].device}")

# breakpoint()
print(f"debug from {src_rank} to {rank} tensors: {len(tensors)} {[tensors[0].shape]}", flush=True)
print(f"debug get_gpu_usage {self}: {get_gpu_usage()}", flush=True)
tensor_device = tensors[0].device
print(f"debug tensor is on device {self}: {tensor_device}")


assert len(tensors) > 0
dense_buckets, sparse_bucket = bucket_tensors(tensors, bucket_size_mb=self.runtime_args.coalesced_buffer_mb)
debug_rank_0(f"{self.name} Got dense_buckets {len(dense_buckets)}, spase_bucket {len(sparse_bucket)}", self._logger)
tensor_changed = rank != src_rank

print(f"debug from {src_rank} to {rank} dense_buckets: {len(dense_buckets)} sparse_bucket: {len(sparse_bucket)}", flush=True)
# breakpoint()
print(f"self.col_groups: {self.col_groups}")
for bucket in dense_buckets:
# for ele in bucket:
# breakpoint()
# col.broadcast(ele, src_rank, group_name)
# coalesced_comm_dense([ele], col.broadcast, extra_args=(src_rank, group_name), tensor_changed=tensor_changed)

coalesced_comm_dense(bucket, col.broadcast, extra_args=(src_rank, group_name), tensor_changed=tensor_changed)

for param in sparse_bucket:
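`bucket_tensors` and `coalesced_comm_dense` are chatlearn helpers; the underlying idea is the standard flatten, broadcast, unflatten trick used to amortize per-tensor collective overhead. A hedged sketch of that pattern (function name and structure below are illustrative, not chatlearn's actual implementation):

import ray.util.collective as col
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

def broadcast_dense_bucket(bucket, src_rank, group_name, tensor_changed):
    # pack the whole bucket into one contiguous buffer -> a single collective call
    flat = _flatten_dense_tensors(bucket)
    col.broadcast(flat, src_rank, group_name=group_name)
    if tensor_changed:
        # receivers (rank != src_rank) copy the broadcast result back into the originals
        for dst, src in zip(bucket, _unflatten_dense_tensors(flat, bucket)):
            dst.copy_(src)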
5 changes: 5 additions & 0 deletions chatlearn/models/vllm/hooks/ray_gpu_executor.py
@@ -200,6 +200,11 @@ def sort_by_driver_then_worker_ip(worker):
print(f"debug 1111 self.workers 4: {[id(ele) for ele in self.workers]} {self.workers}")

self._run_workers("init_device")
from chatlearn.utils import future
refs = [self.workers[rank].init_device.remote() for rank in range(len(self.workers))]
future.wait(refs)


self._run_workers("load_model",
# use_dummy_driver=True,
max_concurrent_workers=self.parallel_config.
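The added lines bypass `_run_workers` and dispatch `init_device` on every Ray worker handle directly, then block on the object refs with chatlearn's `future.wait`. Assuming `future.wait` is essentially a `ray.get` over the refs, the plain-Ray equivalent is:

import ray

def init_all_devices(workers):
    # schedule init_device on every worker handle, then wait for all of them
    refs = [worker.init_device.remote() for worker in workers]
    return ray.get(refs)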
84 changes: 40 additions & 44 deletions chatlearn/models/vllm_module_v2.py
@@ -29,6 +29,8 @@
from chatlearn.utils.vllm_import_helper import get_pipeline_model_parallel_rank
from chatlearn.utils.vllm_import_helper import TextTokensPrompt
from .torch_module import TorchModule
from vllm.worker.worker import Worker
from chatlearn.utils.vllm_utils import initialize_vllm


class VLLMModuleV2(TorchModule, RayWorkerWrapper):
@@ -45,24 +47,49 @@ def __init__(self, *args, **kwargs):
f"Expected only '__init__' as common method for TorchModule and RayWorkerWrapper, but got {common_methods}"
# TorchModule.__init__(self, *args)
self.local_rank = 0
os.environ['LOCAL_RANK'] = '0'
# os.environ['LOCAL_RANK'] = '0'
if 'worker_module_name' in kwargs and 'worker_class_name' in kwargs:
RayWorkerWrapper.__init__(self, **kwargs) # pylint: disable=non-parent-init-called
os.environ['VLLM_HOST_IP'] = self.get_address()

# def __init__(self, *args, **kwargs):
# super().__init__(*args, **kwargs)
# self.local_rank = 0
# os.environ['LOCAL_RANK'] = '0'
# if 'worker_module_name' in kwargs and 'worker_class_name' in kwargs:
# RayWorkerWrapper.__init__(self, **kwargs) # pylint: disable=non-parent-init-called
# os.environ['VLLM_HOST_IP'] = self.get_address()
# self.llm_engine = None
self.tokenizer = None
self._tp_rank = None
self._pp_rank = None
self._model = None

def add_extra_args(self, parser):
"""
Add extra arguments for vllm.
Args
----
parser : ArgumentParser
Add extra arguments.
"""
group = parser.add_argument_group(title='vLLM extra arguments')
group.add_argument('--distributed-backend', default='nccl',
choices=['nccl', 'gloo'],
help='Which backend to use for distributed training.')
group.add_argument('--distributed-timeout-minutes', type=int, default=10,
help='Timeout minutes for torch.distributed.')
return parser
def init(self):
"""
:meta private:
"""
parallel_state.set_custom_all_reduce(False)
initialize_vllm(extra_args_provider=self.add_extra_args,
ignore_unknown_args=True,
args_dict=self.model_args)
def init_device(self):
return
self.worker.device = torch.device(f"cuda:{torch.cuda.current_device()}")
torch.cuda.set_device(self.device)
init_worker_distributed_environment(self.worker.parallel_config, self.worker.rank,
self.worker.distributed_init_method,
self.worker.local_rank)
# return self.worker.init_device()

def setup(self):
"""Set up model and load checkpoint"""
super().setup()
@@ -72,7 +99,7 @@ def setup(self):

def setup_vllm(self, workers):
# setup vllm engine in rank 0
# os.environ['VLLM_HOST_IP'] = self.get_address()
os.environ['VLLM_HOST_IP'] = self.get_address()
print(f"debug 1111 workers: {[id(ele) for ele in workers]} {workers}")
set_vllm_actors(workers)

@@ -168,10 +195,7 @@ def _convert_v1_inputs(self, prompts, prompt_token_ids):

return inputs

async def generate_all(self, prompts, sampling_params):
pass

async def generate_vllm(self, query, is_eval):
def generate_vllm(self, query, is_eval):
print(f"debug aaaa tensor_parallel_rank: {id(self)} {self}")
print(f"tensor_parallel_rank: {self.tensor_parallel_rank()}", flush=True)
print(f"debug aaaa pipeline_parallel_rank: {id(self)} {self} {self.pipeline_parallel_rank()}", flush=True)
@@ -182,7 +206,6 @@ async def generate_vllm(self, query, is_eval):
prompts_token_ids = query[input_ids_key]
seq_len = self.model_args.get("seq_length")
final_outputs = []
tasks = []
parsed_prompts = []
sampling_params = []
for i, prompt in enumerate(prompts):
@@ -224,32 +247,6 @@ def num_layers(self):
"""
return self.llm.llm_engine.model_config.hf_config.num_hidden_layers


# class VLLMWokerWrapper(TorchModule, RayWorkerWrapper):
# """VLLMWokerWrapper is the class for vLLM workers.

# Args
# ----
# name : str
# model name
# """

# def __init__(self, *args, **kwargs):
# # avoid overwrite methods
# methods_class1 = {method[0] for method in inspect.getmembers(TorchModule, predicate=inspect.isfunction)}
# methods_class2 = {method[0] for method in inspect.getmembers(RayWorkerWrapper, predicate=inspect.isfunction)}
# common_methods = methods_class1.intersection(methods_class2)
# # common method is '__init__'
# assert common_methods == {'__init__'}, \
# f"Expected only '__init__' as common method for TorchModule and RayWorkerWrapper, but got {common_methods}"
# TorchModule.__init__(self, *args)
# self.local_rank = 0
# os.environ['LOCAL_RANK'] = '0'
# if 'worker_module_name' in kwargs and 'worker_class_name' in kwargs:
# RayWorkerWrapper.__init__(self, **kwargs) # pylint: disable=non-parent-init-called
# os.environ['VLLM_HOST_IP'] = self.get_address()
# self.llm_engine = None

def peak_memory(self):
"""
:meta private:
@@ -274,10 +271,9 @@ def data_parallel_rank(self):
@property
def model(self):
if self._model is None:
# breakpoint()
# if self.worker is not None:
assert self.worker is not None, \
f"please set env variables `VLLM_USE_RAY_SPMD_WORKER` and `VLLM_USE_RAY_COMPILED_DAG` first."
self._model = self.worker.model_runner.model
# print(f"debug 1111 self.llm: {self.llm}")
return self._model
def set_tp_pp_ranks(self, tp_rank, pp_rank):
"""
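The new `add_extra_args`/`init` pair follows the usual extra-args-provider convention: the framework builds the base parser and hands it to the module, which appends its own argument group before parsing. A standalone sketch of that convention (parser wiring is illustrative; `initialize_vllm` itself is not reproduced here):

import argparse

def add_extra_args(parser):
    group = parser.add_argument_group(title="vLLM extra arguments")
    group.add_argument("--distributed-backend", default="nccl", choices=["nccl", "gloo"],
                       help="Which backend to use for distributed training.")
    group.add_argument("--distributed-timeout-minutes", type=int, default=10,
                       help="Timeout minutes for torch.distributed.")
    return parser

parser = add_extra_args(argparse.ArgumentParser())
# ignore_unknown_args=True in initialize_vllm roughly corresponds to parse_known_args
args, _unknown = parser.parse_known_args([])
print(args.distributed_backend, args.distributed_timeout_minutes)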
8 changes: 2 additions & 6 deletions chatlearn/runtime/dist_actor.py
@@ -233,7 +233,7 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.vllm_engine = None

def create_actor(self, num_gpus, placement_group, group_index, use_ray_vllm_worker=False):
def create_actor(self, num_gpus, placement_group, group_index):
kwargs = {
"worker_module_name": "vllm.worker.worker",
"worker_class_name": "Worker",
@@ -245,12 +245,8 @@ def create_actor(self, num_gpus, placement_group, group_index, use_ray_vllm_work
# self.model.engine = self.vllm_engine
# else:
# self._create_actor(self.model.__class__, num_gpus, placement_group, group_index, **kwargs)
# if use_ray_vllm_worker:
# from vllm.executor.ray_utils import RayWorkerWrapper
# self.all_actors.append(RayWorkerWrapper(**kwargs))
# else:

self._create_actor(self.model.__class__, num_gpus, placement_group, group_index, **kwargs)
# return self._create_actor(self.model.__class__, num_gpus, placement_group, group_index, **kwargs)
def create_engine_actor(self, num_gpus, placement_group, group_index):
self.vllm_engine = self._create_actor(self.model.__class__, num_gpus, placement_group, group_index)
self.model.engine = self.vllm_engine
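`_create_actor` receives a placement group and a group index; presumably it pins each worker actor to its bundle. A hedged sketch of that Ray pattern (an assumption about what `_create_actor` does internally, not chatlearn's actual code):

import ray
from ray.util.placement_group import placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

@ray.remote
class Worker:
    def ping(self):
        return "ok"

pg = placement_group([{"GPU": 1, "CPU": 1}] * 2, strategy="PACK")
ray.get(pg.ready())  # wait until the bundles are reserved
actor = Worker.options(
    num_gpus=1,
    scheduling_strategy=PlacementGroupSchedulingStrategy(
        placement_group=pg, placement_group_bundle_index=0),
).remote()
print(ray.get(actor.ping.remote()))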
5 changes: 1 addition & 4 deletions chatlearn/schedule/model_manager.py
@@ -354,10 +354,7 @@ def _get_model_replica_from_pack(gpu_index, model_pack):
replica.create_engine_actor(num_gpus, placement_group, group)
# we do not want to add engine actor to all_actors
replica.all_actors.pop()
if isinstance(replica.model, VLLMModuleV2):
replica.create_actor(num_gpus, placement_group, group, use_ray_vllm_worker=not replica.all_actors)
else:
replica.create_actor(num_gpus, placement_group, group)
replica.create_actor(num_gpus, placement_group, group)

models_to_revert = self._find_param_recv_models(gpu_models)
for model in gpu_models:
13 changes: 10 additions & 3 deletions chatlearn/synchronizer/parameter_sync.py
@@ -69,6 +69,7 @@ def __init__(self, src_model, dst_model, group_name, frequency, error_signal):
if self.num_src_tensor_parallel % 2 == 1 and self.num_dst_tensor_parallel % 2 == 1:
logger.warning("Only support PARAM_SYNC_COMM_TYPE.BROADCAST when TP SIZE is even number, use P2P instead")
self._comm_type = PARAM_SYNC_COMM_TYPE.P2P
print(f"self._comm_type: {self._comm_type} while expect {get_args().runtime_args.param_sync_comm_type}")

self.concurrent_comm = get_args().runtime_args.concurrent_comm
self._enable_lora = self.src_model.module_args.lora.enable_lora
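
For reference, the fallback that the warning above describes amounts to: broadcast-based sync is only kept when at least one of the two tensor-parallel sizes is even; if both are odd it silently degrades to P2P, which is what the added debug print surfaces when the resolved type differs from the configured one. A compact restatement under that assumption, with a placeholder enum:

from enum import Enum

class ParamSyncCommType(Enum):   # placeholder for chatlearn's PARAM_SYNC_COMM_TYPE
    BROADCAST = "broadcast"
    P2P = "p2p"

def resolve_comm_type(requested, num_src_tp, num_dst_tp):
    if requested is ParamSyncCommType.BROADCAST and num_src_tp % 2 == 1 and num_dst_tp % 2 == 1:
        return ParamSyncCommType.P2P   # both TP sizes odd -> fall back to P2P
    return requested

assert resolve_comm_type(ParamSyncCommType.BROADCAST, 3, 5) is ParamSyncCommType.P2P
assert resolve_comm_type(ParamSyncCommType.BROADCAST, 4, 3) is ParamSyncCommType.BROADCAST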
@@ -117,8 +118,10 @@ def inner_func(*args, **kwargs):
def is_same_gpu(self, src_actor, dst_actor):
src_gpu = self.get_or_cache(src_actor, "get_visible_gpus")
dst_gpu = self.get_or_cache(dst_actor, "get_visible_gpus")
print(f"src_actor {src_actor} src_gpu: {src_gpu}, dst_actor: {dst_actor} dst_gpu: {dst_gpu}")
src_address = self.get_or_cache(src_actor, "get_address")
dst_address = self.get_or_cache(dst_actor, "get_address")
print(f"src_address: {src_address} dst_address: {dst_address}")
return src_gpu == dst_gpu and src_address == dst_address

@property
@@ -1088,10 +1091,14 @@ def _multi_thread_sync_for_tp_num_mapping_eq_1(
if True:
for send_actor in sorted_send_actors:
recv_actors = actor_mappings[send_actor]
if self._comm_type == PARAM_SYNC_COMM_TYPE.BROADCAST:
# breakpoint()
print(f"send from {self.actor2rank[send_actor]} to {[self.actor2rank[ele] for ele in recv_actors]}")
actor_groups, finalized_group_name = self.create_broadcast_group(send_actor, recv_actors, param_group=param_group)
self.sync_broadcast(actor_groups, finalized_group_name, requires_grad, filter_fn=filter_fn, param_group=param_group)
print(f"send from {self.actor2rank[send_actor]} to {[self.actor2rank[ele] for ele in recv_actors]}")
actor_groups, finalized_group_name = self.create_broadcast_group(send_actor, recv_actors, param_group=param_group)
self.sync_broadcast(actor_groups, finalized_group_name, requires_grad, filter_fn=filter_fn, param_group=param_group)
else:
for recv_actor in recv_actors:
self.sync_send_recv(send_actor, recv_actor, requires_grad, filter_fn=filter_fn, param_group=param_group)

def _single_thread_sync(self, actor_mappings_list:List, requires_grad=None, filter_fn=None, param_group="default"):
assert len(actor_mappings_list) == 1
1 change: 1 addition & 0 deletions examples/megatron/configs/llama2/vllm_param_sync.yaml
@@ -50,3 +50,4 @@ runtime:
exp_name: ${exp_name:chatlearn}
debug: ${debug:False}
validate_param_sync: ${validate_param_sync:False}
param_sync_comm_type: ${param_sync_comm_type:broadcast}
1 change: 1 addition & 0 deletions examples/megatron/configs/llama2/vllm_rlhf.yaml
@@ -83,3 +83,4 @@ runtime:
exp_name: ${exp_name:chatlearn}
debug: ${debug:False}
validate_param_sync: ${validate_param_sync:False}
param_sync_comm_type: ${param_sync_comm_type:broadcast}
12 changes: 6 additions & 6 deletions examples/megatron/models/vllm_policy_inference.py
@@ -146,14 +146,14 @@ def decode_internal(self, batched_outputs):
class VLLMPolicyInferenceAsync(VLLMPolicyInference):
"""VLLMPolicyInferenceAsync is the model for VLLMModuleV2, which uses async generate API"""

async def eval_forward(self, data, iteration=0): # pylint: disable=invalid-overridden-method
return await self._forward_step(data, iteration, True)
def eval_forward(self, data, iteration=0): # pylint: disable=invalid-overridden-method
return self._forward_step(data, iteration, True)

async def _forward_step(self, data, iteration, is_eval): # pylint: disable=unused-argument,invalid-overridden-method
outputs = await self.generate_vllm(data, is_eval)
def _forward_step(self, data, iteration, is_eval): # pylint: disable=unused-argument,invalid-overridden-method
outputs = self.generate_vllm(data, is_eval)
if outputs is not None:
rets = self.decode_internal(outputs)
return rets

async def forward_step(self, data, iteration=0): # pylint: disable=invalid-overridden-method
return await self._forward_step(data, iteration, False)
def forward_step(self, data, iteration=0): # pylint: disable=invalid-overridden-method
return self._forward_step(data, iteration, False)
2 changes: 1 addition & 1 deletion examples/megatron/tests/test_unbalanced_param_sync.sh
@@ -38,7 +38,7 @@ if [[ "$model_size" == "llama2-7B" ]]; then
fi
export train_micro_batch_size=16
export max_num_batched_tokens=65536
export gpu_memory_utilization=0.8
export gpu_memory_utilization=0.5

export num_gpu_policy=4
export num_gpu_ppo_policy=4
