
Commit

fix
charles9304 committed Dec 26, 2024
1 parent 68e2218 commit 9b2e36d
Showing 5 changed files with 20 additions and 48 deletions.
2 changes: 1 addition & 1 deletion chatlearn/models/vllm/hooks/async_llm_engine.py
@@ -16,7 +16,7 @@

from typing import Dict, Optional

# pylint: disable=unused-import,wildcard-import,unused-argument
# pylint: disable=unused-import,wildcard-import,unused-argument,not-callable
from vllm.config import EngineConfig
from vllm.engine import async_llm_engine
from vllm.engine.arg_utils import AsyncEngineArgs
4 changes: 1 addition & 3 deletions chatlearn/models/vllm/hooks/ray_gpu_executor.py
@@ -18,7 +18,6 @@
from typing import Dict, List, Optional

from vllm import envs
from vllm.distributed import parallel_state
from vllm.executor.ray_gpu_executor import RayGPUExecutor
from vllm.executor.ray_utils import RayWorkerWrapper, ray

@@ -182,9 +181,8 @@ def sort_by_driver_then_worker_ip(worker):
) for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids)
]
self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs)

self._run_workers("init_device")

self._run_workers("init_device")
self._run_workers("load_model",
max_concurrent_workers=self.parallel_config.
max_parallel_loading_workers)
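A simplified sketch (a hypothetical helper, not the actual RayGPUExecutor code) of the call order the patched method follows after this change: init_worker, then init_device, then load_model, each fanned out through _run_workers.

def _init_workers_sketch(run_workers, init_worker_all_kwargs, parallel_config):
    # Construct every Ray worker, set up its device, then load weights,
    # bounding concurrent loaders the way the real executor does.
    run_workers("init_worker", all_kwargs=init_worker_all_kwargs)
    run_workers("init_device")
    run_workers("load_model",
                max_concurrent_workers=parallel_config.max_parallel_loading_workers)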
10 changes: 0 additions & 10 deletions chatlearn/models/vllm/hooks/worker_base.py
@@ -14,20 +14,10 @@
# ==============================================================================
"""Hooks of vllm-0.6.3 worker_base to update execute_method."""


import inspect

# pylint: disable=unused-import,wildcard-import
from vllm.distributed import parallel_state
from vllm.worker import worker
from vllm.worker import worker_base
from vllm.worker.worker_base import logger

def get_tp_and_pp_rank(self):
return parallel_state.get_tensor_model_parallel_rank(), \
parallel_state.get_pp_group().rank_in_group

worker.Worker.get_tp_and_pp_rank = get_tp_and_pp_rank

def execute_method(self, method, *args, **kwargs):
try:
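This hunk deletes the Worker.get_tp_and_pp_rank monkey patch; a minimal sketch of the equivalent lookup, using the same vllm.distributed.parallel_state calls the removed helper made (valid only inside an initialized distributed worker):

from vllm.distributed import parallel_state

def get_tp_and_pp_rank():
    # Same queries the deleted hook performed, without patching Worker.
    tp_rank = parallel_state.get_tensor_model_parallel_rank()
    pp_rank = parallel_state.get_pp_group().rank_in_group
    return tp_rank, pp_rank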
11 changes: 3 additions & 8 deletions chatlearn/models/vllm_module_v2.py
@@ -28,9 +28,8 @@
from chatlearn.utils.vllm_import_helper import parallel_state
from chatlearn.utils.vllm_import_helper import get_pipeline_model_parallel_rank
from chatlearn.utils.vllm_import_helper import TextTokensPrompt
from .torch_module import TorchModule
from vllm.worker.worker import Worker
from chatlearn.utils.vllm_utils import initialize_vllm
from .torch_module import TorchModule


class VLLMModuleV2(TorchModule, RayWorkerWrapper):
@@ -51,8 +50,6 @@ def __init__(self, *args, **kwargs):
os.environ['VLLM_HOST_IP'] = self.get_address()

self.tokenizer = None
self._tp_rank = None
self._pp_rank = None
self._model = None

def add_extra_args(self, parser):
@@ -191,7 +188,6 @@ def generate_vllm(self, query, is_eval):
prompts = query[prompt_key]
prompts_token_ids = query[input_ids_key]
seq_len = self.model_args.get("seq_length")
final_outputs = []
parsed_prompts = []
sampling_params = []
for i, prompt in enumerate(prompts):
@@ -219,8 +215,7 @@
sampling_params,
use_tqdm=True,
)
final_outputs = sorted(outputs, key=lambda x: int(x.request_id))
return final_outputs
return outputs

def is_last_rank(self):
return True
@@ -256,7 +251,7 @@ def data_parallel_rank(self):
def model(self):
if self._model is None:
assert self.worker is not None, \
f"please set env variables `VLLM_USE_RAY_SPMD_WORKER` and `VLLM_USE_RAY_COMPILED_DAG` first."
"please set env variables `VLLM_USE_RAY_SPMD_WORKER=1` and `VLLM_USE_RAY_COMPILED_DAG=1` first."
self._model = self.worker.model_runner.model
return self._model

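The reworded assert in VLLMModuleV2.model expects the Ray SPMD worker to exist, which requires these vLLM environment variables to be set before the engine is created. A minimal usage sketch; the value "1" mirrors the new message and is an assumption:

import os

os.environ["VLLM_USE_RAY_SPMD_WORKER"] = "1"
os.environ["VLLM_USE_RAY_COMPILED_DAG"] = "1"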
41 changes: 15 additions & 26 deletions chatlearn/synchronizer/parameter_sync.py
@@ -635,7 +635,6 @@ def sync_broadcast_two_stage(self, actors, group_name, requires_grad=None, stage
def sync_broadcast(self, actors, group_name, requires_grad=None, filter_fn=None, param_group="default"):
send_actor = actors[0]
for recv_actor in actors[1:]:
is_the_same_gpu = self.is_same_gpu(send_actor, recv_actor)
self.set_sync_param_names(send_actor, recv_actor, requires_grad, filter_fn, param_group)
pipe_stage = self.get_actor_pipe_rank(send_actor)
refs = []
@@ -703,8 +702,7 @@ def inner_func():

def get_actor_tp_rank(self, actor):
def inner_func():
rank = future.get(actor.tensor_parallel_rank.remote())
return rank
return future.get(actor.tensor_parallel_rank.remote())
return utils.get_or_cache(self._actor2tp, actor, inner_func)

def get_actor_ep_rank(self, actor):
@@ -1053,35 +1051,26 @@ def _multi_thread_sync_for_tp_num_mapping_eq_1(
sorted_send_actors = self.sort_send_actors(actor_mappings, send_actors)
max_workers = self._calculate_max_workers(sorted_send_actors, actor_mappings)

# with ThreadPoolExecutor(max_workers=max_workers) as executor:
# futures = []
# for send_actor in sorted_send_actors:
# recv_actors = actor_mappings[send_actor]
# if self._comm_type == PARAM_SYNC_COMM_TYPE.BROADCAST:
# actor_groups, finalized_group_name = self.create_broadcast_group(send_actor, recv_actors, param_group=param_group)
# futures.append(executor.submit(
# self.sync_broadcast, actor_groups, finalized_group_name, requires_grad, filter_fn=filter_fn, param_group=param_group
# ))
# else:
# for recv_actor in recv_actors:
# futures.append(executor.submit(
# self.sync_send_recv, send_actor, recv_actor, requires_grad, filter_fn=filter_fn, param_group=param_group
# ))
# for _future in concurrent.futures.as_completed(futures):
# try:
# _future.result()
# except Exception as e:
# raise RuntimeError(f"Parameter sync thread generated an exception: {e}") # pylint: disable=raise-missing-from
# concurrent.futures.wait(futures)
if True:
with ThreadPoolExecutor(max_workers=max_workers) as executor:
futures = []
for send_actor in sorted_send_actors:
recv_actors = actor_mappings[send_actor]
if self._comm_type == PARAM_SYNC_COMM_TYPE.BROADCAST:
actor_groups, finalized_group_name = self.create_broadcast_group(send_actor, recv_actors, param_group=param_group)
self.sync_broadcast(actor_groups, finalized_group_name, requires_grad, filter_fn=filter_fn, param_group=param_group)
futures.append(executor.submit(
self.sync_broadcast, actor_groups, finalized_group_name, requires_grad, filter_fn=filter_fn, param_group=param_group
))
else:
for recv_actor in recv_actors:
self.sync_send_recv(send_actor, recv_actor, requires_grad, filter_fn=filter_fn, param_group=param_group)
futures.append(executor.submit(
self.sync_send_recv, send_actor, recv_actor, requires_grad, filter_fn=filter_fn, param_group=param_group
))
for _future in concurrent.futures.as_completed(futures):
try:
_future.result()
except Exception as e:
raise RuntimeError(f"Parameter sync thread generated an exception: {e}") # pylint: disable=raise-missing-from
concurrent.futures.wait(futures)

def _single_thread_sync(self, actor_mappings_list:List, requires_grad=None, filter_fn=None, param_group="default"):
assert len(actor_mappings_list) == 1
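The re-enabled threaded path in _multi_thread_sync_for_tp_num_mapping_eq_1 follows the standard submit/as_completed pattern. A self-contained sketch with a placeholder sync_task standing in for sync_broadcast / sync_send_recv (not ChatLearn APIs):

import concurrent.futures
from concurrent.futures import ThreadPoolExecutor

def sync_task(send_actor, recv_actor):
    # Placeholder for self.sync_broadcast / self.sync_send_recv.
    return f"synced {send_actor} -> {recv_actor}"

pairs = [("trainer_0", "policy_0"), ("trainer_1", "policy_1")]
with ThreadPoolExecutor(max_workers=2) as executor:
    futures = [executor.submit(sync_task, s, r) for s, r in pairs]
    for _future in concurrent.futures.as_completed(futures):
        try:
            _future.result()  # surface any exception raised in a worker thread
        except Exception as e:
            raise RuntimeError(f"Parameter sync thread generated an exception: {e}") from e
    concurrent.futures.wait(futures)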
