Commit 253da91: Code cleanup.
jmaksymczuk committed Jan 23, 2025
1 parent 23b1624
Showing 5 changed files with 13 additions and 32 deletions.
4 changes: 0 additions & 4 deletions benchmarks/benchmark_throughput.py
@@ -498,10 +498,6 @@ def main(args: argparse.Namespace):
help="Path to the lora adapters to use. This can be an absolute path, "
"a relative path, or a Hugging Face model identifier.")

parser.add_argument("--pipeline_parallel_size",
type=int,
default=1,
help="Piepeline parallel size.")
parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args()
if args.tokenizer is None:
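Note on the benchmark change: the dropped --pipeline_parallel_size argument most likely duplicated an option that AsyncEngineArgs.add_cli_args (attached to the same parser just below) already registers, so the benchmark now takes the pipeline-parallel size from the engine arguments. A minimal sketch of the remaining wiring, assuming vLLM's AsyncEngineArgs CLI helper; the exact flag spelling may differ between versions:

import argparse

from vllm.engine.arg_utils import AsyncEngineArgs

# Sketch: the engine-arg helper registers --pipeline-parallel-size (among many
# other options), so the benchmark no longer needs to add it by hand.
parser = argparse.ArgumentParser(description="throughput benchmark (sketch)")
parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args(["--pipeline-parallel-size", "2"])
print(args.pipeline_parallel_size)  # -> 2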
4 changes: 0 additions & 4 deletions vllm/utils.py
@@ -2253,8 +2253,6 @@ def bind_kv_cache(
if ctx[layer_name].attn_type in (AttentionType.DECODER,
AttentionType.ENCODER_DECODER)
]
for layer_need_kv_cache_n in layer_need_kv_cache:
print(f'layer_need_kv_cache_n = {layer_need_kv_cache_n}')
layer_index_sorted = sorted(
set(
extract_layer_index(layer_name)
@@ -2263,8 +2261,6 @@
kv_cache_idx = layer_index_sorted.index(
extract_layer_index(layer_name))
forward_ctx = ctx[layer_name]
print(f'frw.kv_cache len = {len(forward_ctx.kv_cache)}, kv_cache len = {len(kv_cache)}')
#print(f'frw.kv_cache = {forward_ctx.kv_cache}, kv_cache = {kv_cache}')
assert len(forward_ctx.kv_cache) == len(kv_cache)
for ve, ve_kv_cache in enumerate(kv_cache):
forward_ctx.kv_cache[ve] = ve_kv_cache[kv_cache_idx]
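For orientation, a self-contained sketch of the index mapping bind_kv_cache performs after this cleanup, using hypothetical simplified types (the real function also filters layers by attention type and stores tensors rather than strings):

from dataclasses import dataclass, field
from typing import Any, Dict, List

@dataclass
class ForwardCtx:
    kv_cache: List[Any] = field(default_factory=list)

def extract_layer_index(layer_name: str) -> int:
    # e.g. "model.layers.7.self_attn" -> 7
    return next(int(part) for part in layer_name.split(".") if part.isdigit())

def bind_kv_cache_sketch(ctx: Dict[str, ForwardCtx],
                         kv_cache: List[List[Any]]) -> None:
    # Sort the distinct layer indices so each layer maps to a stable cache slot.
    layer_indices = sorted({extract_layer_index(name) for name in ctx})
    for name, fwd in ctx.items():
        idx = layer_indices.index(extract_layer_index(name))
        fwd.kv_cache = [None] * len(kv_cache)
        for ve, ve_kv_cache in enumerate(kv_cache):  # one slot per virtual engine
            fwd.kv_cache[ve] = ve_kv_cache[idx]

ctx = {"model.layers.0.attn": ForwardCtx(), "model.layers.1.attn": ForwardCtx()}
bind_kv_cache_sketch(ctx, [["cache_layer0", "cache_layer1"]])
print(ctx["model.layers.1.attn"].kv_cache)  # ['cache_layer1']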
31 changes: 11 additions & 20 deletions vllm/worker/hpu_model_runner.py
@@ -197,7 +197,7 @@ def find_rope_layer(parent, path):
for child_name, child_module in parent.named_children():
# If the current child is of type RotaryEmbedding,
# return the full path
if child_module.__class__.__name__.endswith("RotaryEmbedding"):
if child_module.__class__.__name__.endswith("RotaryEmbedding"):
return path + [child_name]
# Otherwise, recurse into this child to check its children
result = find_rope_layer(child_module, path + [child_name])
@@ -367,10 +367,8 @@ def _prepare_cos_sin(self, positions):
current_module = self.model # Start from the top level of the model

for layer in self.layer_names:
# If the current layer is a string, it's a name (for named_children)
if layer.isdigit(
): # Check if the layer is an index (numeric as string)
layer = int(layer) # Convert to integer if it's an index
if layer.isdigit(): # Check if the layer is an index
layer = int(layer)

# Check if the current layer is a name in a module
if isinstance(
@@ -383,18 +381,14 @@ def _prepare_cos_sin(self, positions):

# At the end, we should be at the RotaryEmbedding layer.
if hasattr(current_module, 'prepare_cos_sin'):
current_module.prepare_cos_sin(positions)
current_module.prepare_cos_sin(
positions, recompute_cos_sin=self.recompute_cos_sin)
else:
raise AttributeError(
"The module at the end of the path does not have \
a 'prepare_cos_sin' method.")

def forward(self, *args, **kwargs):
'''
if not get_pp_group().is_first_rank:
for key, tensor in kwargs['intermediate_tensors'].tensors.items():
kwargs['intermediate_tensors'][key] = torch.empty_like(tensor).copy_(tensor)
'''
kwargs = kwargs.copy()
selected_token_indices = kwargs.pop('selected_token_indices')
if 'warmup_mode' in kwargs:
@@ -415,8 +409,10 @@ def forward(self, *args, **kwargs):
if not get_pp_group().is_last_rank:
pass
else:
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
hidden_states = hidden_states.index_select(0, selected_token_indices)
hidden_states = hidden_states.view(
-1, hidden_states.shape[-1])
hidden_states = hidden_states.index_select(
0, selected_token_indices)

return hidden_states
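
The loop in _prepare_cos_sin walks self.layer_names, a path of child names and numeric indices (for example ['model', 'layers', '0', 'self_attn', 'rotary_emb']), down to the rotary-embedding module before calling prepare_cos_sin. A rough standalone illustration of that traversal with a hypothetical helper (not the runner's actual code):

import torch.nn as nn

def resolve_module(root: nn.Module, path: list) -> nn.Module:
    current = root
    for name in path:
        if name.isdigit():        # numeric element: index into a container module
            current = current[int(name)]
        else:                     # named element: look up a registered child
            current = getattr(current, name)
    return current

model = nn.Sequential(nn.Linear(4, 4),
                      nn.ModuleDict({"rotary_emb": nn.Identity()}))
print(resolve_module(model, ["1", "rotary_emb"]))  # Identity()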

@@ -1496,16 +1492,13 @@ def create_dummy_seq_group_metadata(self,

def profile_run(self) -> None:
num_layers = self.model_config.get_num_layers(self.parallel_config)
#print(f'num_layers = {num_layers} \n self.parallel_config = {self.parallel_config}')
kv_caches = [None] * num_layers
#print(f'kv_caches = {[kv_caches]} \n kv_caches len = {len([kv_caches])}')
bind_kv_cache(
self.vllm_config.compilation_config.static_forward_context,
[[None] * num_layers for _ in range(self.parallel_config.pipeline_parallel_size)])

[CI check failure] GitHub Actions / ruff (3.12): vllm/worker/hpu_model_runner.py:1498:81: E501 Line too long (94 > 80)
_, max_seq_len = self.bucketing_ctx.get_max_prompt_shape()
max_batch_size = min(self.max_num_seqs,
self.max_num_batched_tokens // max_seq_len)

self.warmup_scenario(max_batch_size, max_seq_len, True, kv_caches,
False, True)
return
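
During profiling, the forward context is bound to placeholder caches only: the outer list holds one inner list of num_layers empty slots per pipeline-parallel rank. Illustrative values below; num_layers and the parallel size come from the model and parallel configs:

# Shape of the placeholder structure handed to bind_kv_cache above (sketch).
num_layers, pp_size = 4, 2
dummy_kv_caches = [[None] * num_layers for _ in range(pp_size)]
print(len(dummy_kv_caches), len(dummy_kv_caches[0]))  # -> 2 4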
@@ -2150,8 +2143,7 @@ def execute_model(
assert model_input.lora_ids is not None
lora_mask, lora_logits_mask = self.create_lora_mask(
input_tokens, model_input.lora_ids,
attn_metadata.is_prompt)

attn_metadata.is_prompt)
execute_model_kwargs = {
"input_ids": input_tokens,
"positions": input_positions,
@@ -2211,7 +2203,6 @@ def try_revert_dummy_output_tokens():
self.trim_attn_metadata(
broadcast_data["attn_metadata"])
})

with self.profiler.record_event('internal', model_event_name):
hidden_states = self.model.forward(
**execute_model_kwargs,
@@ -2235,7 +2226,7 @@ def try_revert_dummy_output_tokens():
if num_steps == 1:
sampling_metadata.selected_token_indices = None
logits = self.model.compute_logits(hidden_states,
sampling_metadata)
sampling_metadata)
htorch.core.mark_step()
# Only perform sampling in the driver worker.
if not self.is_driver_worker:
2 changes: 1 addition & 1 deletion vllm/worker/hpu_worker.py
@@ -396,7 +396,7 @@ def _warm_up_model(self) -> None:
set_random_seed(self.model_config.seed)

@property
def do_metadata_broadcast(self) -> bool:
def do_metadata_broadcast(self) -> bool:
return self.parallel_config.tensor_parallel_size > 1

@property
4 changes: 1 addition & 3 deletions vllm/worker/worker_base.py
@@ -376,16 +376,14 @@ def execute_model(
start_time = time.perf_counter()
inputs = self.prepare_input(execute_model_req)
if inputs is None:
return None

return None
model_input, worker_input, kwargs = inputs
num_steps = worker_input.num_steps

self.execute_worker(worker_input)

if worker_input.num_seq_groups == 0:
return []

intermediate_tensors = None
orig_model_execute_time = 0.0
if not get_pp_group().is_first_rank:
