
Commit

Format
jmaksymczuk committed Jan 23, 2025
1 parent 253da91 commit f4a12ef
Showing 5 changed files with 39 additions and 32 deletions.
9 changes: 5 additions & 4 deletions vllm/distributed/parallel_state.py
@@ -39,15 +39,16 @@
import vllm.envs as envs
from vllm.distributed.utils import StatelessProcessGroup
from vllm.logger import init_logger
from vllm.utils import direct_register_custom_op, supports_custom_op
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op, supports_custom_op

if TYPE_CHECKING:
from vllm.config import VllmConfig

if current_platform.is_hpu():
import habana_frameworks.torch as htorch


@dataclass
class GraphCaptureContext:
stream: torch.cuda.Stream
@@ -524,7 +525,7 @@ def send_object(self, obj: Any, dst: int) -> None:
size_tensor = torch.tensor([object_tensor.numel()],
dtype=torch.long,
device="cpu")

# Send object size
htorch.core.mark_step()
torch.hpu.synchronize()
@@ -708,7 +709,7 @@ def send_tensor_dict(
if (all_gather_group is not None
and tensor.numel() % all_gather_size == 0):
tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank]

if tensor.is_cpu:
# use metadata_group for CPU tensors
htorch.core.mark_step()
@@ -788,7 +789,7 @@ def recv_tensor_dict(
src=self.ranks[src],
group=group)
htorch.core.mark_step()

if use_all_gather:
# do the allgather
tensor = all_gather_group.all_gather( # type: ignore
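
For context, the send_object hunk above inserts htorch.core.mark_step() and torch.hpu.synchronize() before the blocking CPU-side send, so lazy-mode work is flushed and the device is idle before the object leaves the rank. The sketch below is a minimal illustration of that shape only, not the vLLM implementation; is_hpu() and send_fn are illustrative stand-ins.

# Minimal sketch (not the vLLM implementation) of the pattern the send_object
# hunk above uses: on HPU, flush lazy-mode work and synchronize before a
# blocking CPU-side send. is_hpu() and send_fn are illustrative stand-ins.
import pickle
from typing import Any, Callable

import torch


def is_hpu() -> bool:
    # Stand-in for a platform check such as current_platform.is_hpu().
    return hasattr(torch, "hpu") and torch.hpu.is_available()


def send_object(obj: Any, dst: int,
                send_fn: Callable[[torch.Tensor, int], None]) -> None:
    payload = torch.frombuffer(bytearray(pickle.dumps(obj)), dtype=torch.uint8)
    size_tensor = torch.tensor([payload.numel()],
                               dtype=torch.long,
                               device="cpu")
    if is_hpu():
        import habana_frameworks.torch as htorch
        htorch.core.mark_step()  # flush ops accumulated in lazy mode
        torch.hpu.synchronize()  # wait for the device before the blocking send
    send_fn(size_tensor, dst)  # send the object size first
    send_fn(payload, dst)      # then the serialized object itself
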
3 changes: 2 additions & 1 deletion vllm/sequence.py
@@ -1132,10 +1132,11 @@ def __eq__(self, other: object):

def __repr__(self) -> str:
return f"IntermediateTensors(tensors={self.tensors})"

def __hash__(self) -> int:
return hash(tuple(self.tensors.values()))


class PoolerOutput(
msgspec.Struct,
omit_defaults=True, # type: ignore[call-arg]
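
Side note on the __hash__ hunk above: Python implicitly sets __hash__ to None for a class that defines __eq__ without also defining __hash__, so hashability has to be restored explicitly. The sketch below is an illustrative stand-in (TensorBundle is hypothetical, not the vLLM class) showing the same shape.

# Illustrative sketch only (not the vLLM class): defining __eq__ disables the
# default __hash__, so a matching __hash__ is supplied explicitly.
# torch.Tensor hashes by object identity, so the hash tracks which tensor
# objects the container currently holds.
from typing import Dict

import torch


class TensorBundle:

    def __init__(self, tensors: Dict[str, torch.Tensor]):
        self.tensors = tensors

    def __eq__(self, other: object) -> bool:
        return (isinstance(other, TensorBundle)
                and self.tensors.keys() == other.tensors.keys()
                and all(self.tensors[k] is other.tensors[k]
                        for k in self.tensors))

    def __hash__(self) -> int:
        # Same shape as the hunk above: hash the tuple of tensor values.
        return hash(tuple(self.tensors.values()))
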
47 changes: 26 additions & 21 deletions vllm/worker/hpu_model_runner.py
@@ -29,8 +29,7 @@

from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.config import DeviceConfig, VllmConfig
from vllm.distributed import broadcast_tensor_dict
from vllm.distributed import get_pp_group
from vllm.distributed import broadcast_tensor_dict, get_pp_group
from vllm.distributed.parallel_state import get_world_group
from vllm.forward_context import set_forward_context
from vllm.inputs import INPUT_REGISTRY, InputRegistry
@@ -179,6 +178,7 @@ def forward_hook(module, args, output):
else:
modify_model_layers(child_module, suffix_list, n, counter)


def get_path_to_rope(model: torch.nn.Module):
"""Dynamically get the path to the RotaryEmbedding layer in the model.
This function will recursively search through the module hierarchy to find
@@ -197,7 +197,7 @@ def find_rope_layer(parent, path):
for child_name, child_module in parent.named_children():
# If the current child is of type RotaryEmbedding,
# return the full path
if child_module.__class__.__name__.endswith("RotaryEmbedding"):
if child_module.__class__.__name__.endswith("RotaryEmbedding"):
return path + [child_name]
# Otherwise, recurse into this child to check its children
result = find_rope_layer(child_module, path + [child_name])
@@ -211,6 +211,7 @@ def find_rope_layer(parent, path):
# Return the result if found, otherwise None
return path_to_rope


class HpuModelAdapter:

def __init__(self, model, vllm_config, layer_names):
@@ -409,11 +410,9 @@ def forward(self, *args, **kwargs):
if not get_pp_group().is_last_rank:
pass
else:
hidden_states = hidden_states.view(
-1, hidden_states.shape[-1])
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
hidden_states = hidden_states.index_select(
0, selected_token_indices)

0, selected_token_indices)
return hidden_states

def compute_logits(self, *args, **kwargs):
@@ -424,7 +423,7 @@ def sample(self, *args, **kwargs):

def make_empty_intermediate_tensors(self, *args, **kwargs):
return self.model.make_empty_intermediate_tensors(*args, **kwargs)

def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.embed_tokens(input_ids)

@@ -1493,9 +1492,13 @@ def create_dummy_seq_group_metadata(self,
def profile_run(self) -> None:
num_layers = self.model_config.get_num_layers(self.parallel_config)
kv_caches = [None] * num_layers
bind_kv_caches = [
[None] * num_layers
for _ in range(self.parallel_config.pipeline_parallel_size)
]
bind_kv_cache(
self.vllm_config.compilation_config.static_forward_context,
[[None] * num_layers for _ in range(self.parallel_config.pipeline_parallel_size)])
bind_kv_caches)
_, max_seq_len = self.bucketing_ctx.get_max_prompt_shape()
max_batch_size = min(self.max_num_seqs,
self.max_num_batched_tokens // max_seq_len)
@@ -1576,14 +1579,17 @@ def warmup_scenario(self,
self.vllm_config.scheduler_config.num_scheduler_steps == 1
if is_single_step:
intermediate_tensors = None

if not get_pp_group().is_first_rank:
intermediate_tensors = self.model.make_empty_intermediate_tensors(
batch_size=batch_size,
context_size=seq_len if is_prompt else 1,
dtype=self.model_config.dtype,
device=self.device)
self.execute_model(inputs, kv_caches, intermediate_tensors=intermediate_tensors, warmup_mode=True)
intermediate_tensors = \
self.model.make_empty_intermediate_tensors(
batch_size=batch_size,
context_size=seq_len if is_prompt else 1,
dtype=self.model_config.dtype,
device=self.device)
self.execute_model(inputs,
kv_caches,
intermediate_tensors=intermediate_tensors,
warmup_mode=True)
else: # decode with multi-step
inputs = dataclasses.replace(inputs,
is_first_multi_step=True,
@@ -2143,7 +2149,7 @@ def execute_model(
assert model_input.lora_ids is not None
lora_mask, lora_logits_mask = self.create_lora_mask(
input_tokens, model_input.lora_ids,
attn_metadata.is_prompt)
attn_metadata.is_prompt)
execute_model_kwargs = {
"input_ids": input_tokens,
"positions": input_positions,
@@ -2212,17 +2218,16 @@ def try_revert_dummy_output_tokens():
LoraMask.setLoraMask(
lora_logits_mask.index_select(
0, sampling_metadata.selected_token_indices))

if not get_pp_group().is_last_rank:
return hidden_states

# Compute the logits.
with self.profiler.record_event(
'internal',
('compute_logits_'
f'{"prompt" if is_prompt else "decode"}_bs'
f'{batch_size}_'
f'seq{seq_len}')):
f'{"prompt" if is_prompt else "decode"}_bs'
f'{batch_size}_'
f'seq{seq_len}')):
if num_steps == 1:
sampling_metadata.selected_token_indices = None
logits = self.model.compute_logits(hidden_states,
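
The hpu_model_runner.py hunks above thread pipeline parallelism through warmup and execution: ranks other than the first consume empty intermediate tensors, and ranks other than the last return hidden states instead of computing logits. Below is a minimal stand-alone sketch of that branching; PPGroup, run_stage, and the model methods are hypothetical stand-ins, not vLLM APIs.

# Minimal sketch of the pipeline-parallel branching the hunks above add.
# PPGroup and the model interface are hypothetical stand-ins.
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class PPGroup:
    rank: int
    world_size: int

    @property
    def is_first_rank(self) -> bool:
        return self.rank == 0

    @property
    def is_last_rank(self) -> bool:
        return self.rank == self.world_size - 1


def run_stage(pp: PPGroup, model, inputs: torch.Tensor,
              intermediate: Optional[torch.Tensor]) -> torch.Tensor:
    # The first rank consumes real inputs; later ranks consume the previous
    # stage's hidden states (passed in directly here for simplicity).
    x = inputs if pp.is_first_rank else intermediate
    hidden_states = model(x)
    if not pp.is_last_rank:
        # Intermediate stages hand hidden states to the next stage and never
        # compute logits, mirroring the early return in execute_model above.
        return hidden_states
    # Only the last stage computes logits.
    return model.compute_logits(hidden_states)
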
10 changes: 5 additions & 5 deletions vllm/worker/hpu_worker.py
@@ -18,8 +18,8 @@

import vllm.envs as envs
from vllm.config import ParallelConfig, VllmConfig
from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment, get_pp_group)
from vllm.distributed import (ensure_model_parallel_initialized, get_pp_group,
init_distributed_environment)
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor import set_random_seed
@@ -513,9 +513,9 @@ def init_worker_distributed_environment(
parallel_config.pipeline_parallel_size)

if parallel_config.pipeline_parallel_size > 1:
# torch-ccl xpu need a collective API warm up
# before calling send/recv API
get_pp_group().all_reduce(torch.zeros(1).to('hpu'))
# torch-ccl xpu need a collective API warm up
# before calling send/recv API
get_pp_group().all_reduce(torch.zeros(1).to('hpu'))
if torch.distributed.is_initialized():
torch_world_size = torch.distributed.get_world_size()
if torch_world_size != parallel_config.world_size:
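
The init_worker_distributed_environment hunk above keeps the warm-up noted in its comment: issue one collective on the pipeline-parallel group before any send/recv, since the backend may only set up its communicators on the first collective. A minimal sketch of that pattern follows, assuming an already-initialized torch.distributed process group; warm_up_pp_group and the device argument are illustrative names, not vLLM APIs.

# Minimal sketch of the warm-up in the hunk above: one cheap collective on the
# pipeline-parallel group before any point-to-point send/recv. The diff uses an
# 'hpu' tensor; "cpu" is used here only so the sketch runs anywhere.
import torch
import torch.distributed as dist


def warm_up_pp_group(pp_group: dist.ProcessGroup, device: str = "cpu") -> None:
    if dist.get_world_size(group=pp_group) > 1:
        dummy = torch.zeros(1, device=device)
        dist.all_reduce(dummy, group=pp_group)  # one collective before send/recv
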
2 changes: 1 addition & 1 deletion vllm/worker/worker_base.py
@@ -376,7 +376,7 @@ def execute_model(
start_time = time.perf_counter()
inputs = self.prepare_input(execute_model_req)
if inputs is None:
return None
return None
model_input, worker_input, kwargs = inputs
num_steps = worker_input.num_steps

