Enable async llm engine for qwen. (#180)
charles9304 authored Dec 17, 2024
1 parent ab65fcb commit fd55040
Showing 2 changed files with 54 additions and 4 deletions.
6 changes: 5 additions & 1 deletion chatlearn/models/vllm/hooks/__init__.py
@@ -18,13 +18,17 @@
import os
from .. import is_vllm_v2


if is_vllm_v2():
    if importlib.util.find_spec("vllm"):
        from . import ray_gpu_executor
        from chatlearn.utils.constant import CURRENT_VLLM_VERSION, VLLMVersion
        if CURRENT_VLLM_VERSION == VLLMVersion.v_0_6_3:
            from chatlearn.models.vllm.hooks import input_preprocess
else:
    if importlib.util.find_spec("vllm"):
        import vllm
-        from chatlearn.utils.constant import CURRENT_VLLM_VERSION, VLLMVersion
+        from chatlearn.utils.constant import CURRENT_VLLM_VERSION, VLLMVersion  # pylint: disable=ungrouped-imports
        if CURRENT_VLLM_VERSION == VLLMVersion.v_0_3_0:
            from chatlearn.models.vllm.hooks import sampler
        elif CURRENT_VLLM_VERSION in [VLLMVersion.v_0_5_1, VLLMVersion.v_0_6_3]:
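The hook wiring above chooses which patches to import based on whether vLLM is installed and which version is detected. A minimal standalone sketch of that guard pattern, using only the standard library; the my_hooks.* module names are hypothetical and not part of ChatLearn:

import importlib.metadata
import importlib.util

def load_generation_hooks():
    # Illustrative only: pick a hook module based on the installed vLLM version.
    if importlib.util.find_spec("vllm") is None:
        return None  # vLLM not installed, nothing to patch
    version = importlib.metadata.version("vllm")
    if version.startswith("0.6.3"):
        return importlib.import_module("my_hooks.input_preprocess")  # hypothetical module
    if version.startswith("0.3.0"):
        return importlib.import_module("my_hooks.sampler")  # hypothetical module
    return None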
52 changes: 49 additions & 3 deletions chatlearn/models/vllm_module_v2.py
@@ -31,6 +31,7 @@
from vllm.utils import FlexibleArgumentParser

from chatlearn.utils.global_vars import set_vllm_actors
+from chatlearn.utils.vllm_import_helper import TextTokensPrompt
from .torch_module import TorchModule


@@ -54,6 +55,21 @@ def setup(self):
        tokenizer.tokenizer = tokenizer
        self.tokenizer = tokenizer

+    def _init_args(self, args):
+        # scheduler config
+        args.max_num_seqs = self.module_args.generation_batch_size
+        args.max_num_batched_tokens = self.model_args.get("max_num_batched_tokens")
+        args.num_scheduler_steps = self.model_args.get("num_scheduler_steps", 1)
+
+        # model config
+        args.max_seq_len = self.model_args.get("seq_length")
+
+        # logger config
+        args.disable_log_requests = True
+
+        # engine config
+        args.enforce_eager = self.model_args.get("enforce_eager", False)
+
    def setup_vllm(self, workers):
        # setup vllm engine in rank 0
        os.environ['VLLM_HOST_IP'] = self.get_address()
Expand All @@ -73,6 +89,7 @@ def setup_vllm(self, workers):
"--disable_custom_all_reduce"]
sys.argv = vllm_sys_argv
args = parser.parse_args()
self._init_args(args)
engine_args = AsyncEngineArgs.from_cli_args(args)
self.engine = self.from_engine_args(engine_args)

@@ -149,25 +166,54 @@ def _get_sampling_params(self, is_eval):
        sampling_params.use_beam_search = self.model_args.get("use_beam_search")
        return sampling_params

+    def convert_v1_inputs(self, prompts, prompt_token_ids):
+        num_requests = len(prompts)
+        assert num_requests == len(prompt_token_ids), \
+            ("The lengths of prompts and prompt_token_ids must be the same.")
+
+        inputs = []
+        for i in range(num_requests):
+            if prompts[i] is None:
+                assert isinstance(prompt_token_ids[i], List[int]), \
+                    f"Expect prompt_token_ids[{i}] is List[int] when prompt is None, while {prompt_token_ids[i]}."
+            if prompt_token_ids[i] is None:
+                assert isinstance(prompts[i], str), \
+                    f"Expect prompts[{i}] is a string when prompt_token_ids is None, while {prompts[i]}."
+            item = TextTokensPrompt(
+                prompt=prompts[i],
+                prompt_token_ids=prompt_token_ids[i])
+            inputs.append(item)
+
+        return inputs
+
    async def generate_vllm(self, query, is_eval):
-        prompts = query['prompt']
+        prompt_key = self.model_args.get("vllm_prompt_key", "prompt")
+        input_ids_key = self.model_args.get("vllm_input_ids_key", "input_ids")
+
+        prompts = query[prompt_key]
+        prompts_token_ids = query[input_ids_key]
        seq_len = self.model_args.get("seq_length")
        final_outputs = []
        tasks = []
        for i, prompt in enumerate(prompts):
            request_id = i
+            prompt_token_ids = prompts_token_ids[i]
            if 'sampling_param' in query:
                sampling_param = query['sampling_param'][i]
            else:
                sampling_param = self._get_sampling_params(is_eval)
                if not self.model_args.get("new_token_limit", False):
-                    prompt_token_ids = query['input_ids'][i]
                    max_tokens = seq_len - len(prompt_token_ids)
                else:
                    max_tokens = self.model_args.get("max_new_tokens")
                    assert max_tokens < seq_len, "max_new_tokens must less than seq length."
                sampling_param.max_tokens = max_tokens
-            task = asyncio.create_task(self.generate_one_sample(prompt, sampling_param, request_id))
+            inputs = self.convert_v1_inputs(
+                prompts=[prompt],
+                prompt_token_ids=[prompt_token_ids],
+            )[0]
+
+            task = asyncio.create_task(self.generate_one_sample(inputs, sampling_param, request_id))
            tasks.append(task)
        outputs = await asyncio.gather(*tasks)
        final_outputs = sorted(outputs, key=lambda x: int(x.request_id))
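The new generate_vllm turns every prompt into a TextTokensPrompt-style input and a per-request asyncio task, so requests stream through the async engine concurrently. A hedged sketch of how such a per-request loop over AsyncLLMEngine.generate typically looks; generate_one_sample in this commit is assumed to wrap something similar, and the names and sampling values below are illustrative:

# Sketch only: per-request coroutines over AsyncLLMEngine.generate (vLLM ~0.6.x API).
import asyncio
from vllm import SamplingParams

async def generate_one(engine, inputs, sampling_params, request_id):
    final = None
    # engine.generate yields RequestOutput updates; the last one holds the finished text.
    async for output in engine.generate(inputs, sampling_params, str(request_id)):
        final = output
    return final

async def generate_all(engine, prompts, prompt_token_ids, seq_length=2048):
    tasks = []
    for i, (prompt, token_ids) in enumerate(zip(prompts, prompt_token_ids)):
        # Same shape as convert_v1_inputs: raw text plus pre-tokenized ids.
        inputs = {"prompt": prompt, "prompt_token_ids": token_ids}
        params = SamplingParams(temperature=0.8, max_tokens=seq_length - len(token_ids))
        tasks.append(asyncio.create_task(generate_one(engine, inputs, params, i)))
    outputs = await asyncio.gather(*tasks)
    return sorted(outputs, key=lambda o: int(o.request_id))

Each task consumes its own async generator, so the engine batches requests internally while asyncio.gather preserves one result per prompt; the final sort restores input order.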
