Revert mark_step in mixtral model from PR #1260 (#1273)
yeonsily authored and ssarkar2 committed Aug 23, 2024
1 parent 1548fe3 commit 144569c
Showing 1 changed file with 0 additions and 13 deletions.
optimum/habana/transformers/models/mixtral/modeling_mixtral.py (0 additions, 13 deletions)
@@ -471,7 +471,6 @@ def forward(
         reuse_cache: Optional[bool] = False,
         flash_attention_recompute: Optional[bool] = False,
         cache_idx: int = None,
-        lazy_mode: Optional[bool] = True,
         **kwargs,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """
@@ -481,10 +480,7 @@ def forward(
         - add new args reuse_cache
         - add new args flash_attention_recompute
         - add new args cache_idx
-        - add new args lazy_mode
         """
-        if lazy_mode:
-            htcore.mark_step()
         residual = hidden_states
 
         hidden_states = self.input_layernorm(hidden_states)
@@ -504,16 +500,12 @@ def forward(
             cache_idx=cache_idx,
         )
         hidden_states = residual + hidden_states
-        if lazy_mode:
-            htcore.mark_step()
 
         # Fully Connected
         residual = hidden_states
         hidden_states = self.post_attention_layernorm(hidden_states)
         hidden_states, router_logits = self.block_sparse_moe(hidden_states)
         hidden_states = residual + hidden_states
-        if lazy_mode:
-            htcore.mark_step()
 
         outputs = (hidden_states,)
 
@@ -554,7 +546,6 @@ def forward(
         reuse_cache: Optional[bool] = False,
         flash_attention_recompute: Optional[bool] = False,
         cache_idx: int = None,
-        lazy_mode: Optional[bool] = True,
     ) -> Union[Tuple, MoeModelOutputWithPast]:
         """
         Copied from MixtralModel.forward: https://github.com/huggingface/transformers/blob/v4.37.0/src/transformers/models/mixtral/modeling_mixtral.py#L1069
@@ -684,7 +675,6 @@ def forward(
                     reuse_cache=reuse_cache,
                     flash_attention_recompute=flash_attention_recompute,
                     cache_idx=cache_idx,
-                    lazy_mode=lazy_mode,
                 )
 
             hidden_states = layer_outputs[0]
@@ -759,7 +749,6 @@ def forward(
         reuse_cache: Optional[bool] = None,
         flash_attention_recompute: Optional[bool] = False,
         cache_idx: int = None,
-        lazy_mode: Optional[bool] = True,
     ) -> Union[Tuple, MoeCausalLMOutputWithPast]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_router_logits = (
@@ -788,7 +777,6 @@ def forward(
             reuse_cache=reuse_cache,
             flash_attention_recompute=flash_attention_recompute,
             cache_idx=cache_idx,
-            lazy_mode=lazy_mode,
         )
 
         hidden_states = outputs[0]
@@ -893,7 +881,6 @@ def prepare_inputs_for_generation(
                 "reuse_cache": reuse_cache,
                 "flash_attention_recompute": kwargs.get("flash_attention_recompute"),
                 "cache_idx": kwargs.get("cache_idx"),
-                "lazy_mode": kwargs.get("lazy_mode"),
             }
         )
         return model_inputs
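
Context for the revert: on Habana Gaudi, `habana_frameworks.torch.core.mark_step()` flushes the lazily accumulated operation graph to the device, and PR #1260 had inserted such graph breaks around the attention and MoE stages of each Mixtral decoder layer behind a `lazy_mode` flag. The sketch below is an illustration of that general pattern only, not code from this repository; the `attn_block` and `moe_block` callables are hypothetical stand-ins for the layer's sub-modules.

import torch
import habana_frameworks.torch.core as htcore  # Habana lazy-mode backend


def layer_forward_with_breaks(attn_block, moe_block, hidden_states: torch.Tensor, lazy_mode: bool = True) -> torch.Tensor:
    # Shape of the reverted pattern: explicit graph breaks guarded by lazy_mode.
    if lazy_mode:
        htcore.mark_step()  # flush the accumulated lazy graph before attention

    hidden_states = hidden_states + attn_block(hidden_states)  # attention + residual
    if lazy_mode:
        htcore.mark_step()  # break between the attention and MoE stages

    hidden_states = hidden_states + moe_block(hidden_states)  # sparse MoE + residual
    if lazy_mode:
        htcore.mark_step()  # break after the MoE stage

    return hidden_states

After this revert, the Mixtral `forward` methods shown in the diff no longer accept a `lazy_mode` argument, and `prepare_inputs_for_generation` no longer forwards it.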
