Commit
gptj: fix missing token_idx (huggingface#1234)
envsp authored and regisss committed Aug 13, 2024
1 parent ff4a6b8 commit 95c8104
Showing 1 changed file with 15 additions and 12 deletions.
27 changes: 15 additions & 12 deletions optimum/habana/transformers/models/gptj/modeling_gptj.py
@@ -162,10 +162,10 @@ def forward(
         use_cache: Optional[bool] = False,
         output_attentions: Optional[bool] = False,
         token_idx: Optional[torch.Tensor] = None,
-        reuse_cache: Optional[bool] = False,
-        cache_idx: Optional[int] = None,
         sin: Optional[torch.Tensor] = None,
         cos: Optional[torch.Tensor] = None,
+        reuse_cache: Optional[bool] = False,
+        cache_idx: Optional[int] = None,
     ) -> Union[
         Tuple[torch.Tensor, Tuple[torch.Tensor]],
         Optional[Tuple[torch.Tensor, Tuple[torch.Tensor], Tuple[torch.Tensor, ...]]],
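The hunk above only reorders the trailing keyword parameters of the attention module's forward so that sin/cos precede reuse_cache/cache_idx; no parameter is added or removed here. A minimal, self-contained sketch (toy functions, not the repository's code) of why such a reorder is transparent to keyword callers but would change the meaning of positional calls:

# Toy illustration only -- stand-in functions, not the repository's code.
def forward_old(token_idx=None, reuse_cache=False, cache_idx=None, sin=None, cos=None):
    return {"reuse_cache": reuse_cache, "cache_idx": cache_idx, "sin": sin, "cos": cos}

def forward_new(token_idx=None, sin=None, cos=None, reuse_cache=False, cache_idx=None):
    return {"reuse_cache": reuse_cache, "cache_idx": cache_idx, "sin": sin, "cos": cos}

# Keyword call sites see no difference between the two orderings.
assert forward_old(token_idx=1, sin=0.1, cos=0.9, reuse_cache=True) == forward_new(
    token_idx=1, sin=0.1, cos=0.9, reuse_cache=True
)

# A positional call, however, now binds the second argument to `sin` instead of
# `reuse_cache`, so any positional call sites must be updated with the signature.
print(forward_old(1, True))  # binds True to reuse_cache
print(forward_new(1, True))  # binds True to sin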
@@ -221,19 +221,22 @@ def forward(
                 key = torch.cat([past_key, key], dim=-2)
                 value = torch.cat([past_value, value], dim=-2)

-        if use_cache is True and token_idx is not None:
+        if use_cache is True:
             if reuse_cache:
                 key = self.k_cache(key, 2, token_idx)
                 value = self.v_cache(value, 2, token_idx)
                 present = (self.k_cache.get_shape(), self.v_cache.get_shape())
             else:
-                if layer_past is None:
-                    past_key = torch.zeros(key.shape, dtype=self.k_proj.weight.dtype, device=key.device)
-                    past_value = torch.zeros(key.shape, dtype=self.k_proj.weight.dtype, device=key.device)
-                    layer_past = (past_key, past_value)
-                key = self.k_cache.update(layer_past[0], key, 2, token_idx, self.inp_seq_len)
-                value = self.v_cache.update(layer_past[1], value, 2, token_idx, self.inp_seq_len)
-                present = layer_past
+                if token_idx is not None:
+                    if layer_past is None:
+                        past_key = torch.zeros(key.shape, dtype=self.k_proj.weight.dtype, device=key.device)
+                        past_value = torch.zeros(key.shape, dtype=self.k_proj.weight.dtype, device=key.device)
+                        layer_past = (past_key, past_value)
+                    key = self.k_cache.update(layer_past[0], key, 2, token_idx, self.inp_seq_len)
+                    value = self.v_cache.update(layer_past[1], value, 2, token_idx, self.inp_seq_len)
+                    present = layer_past
+                else:
+                    present = (key.to(hidden_states.dtype), value)

         if cache_idx is not None and q_len == 1:
             key = key[:, :, :cache_idx, :]
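This hunk carries the actual fix. Previously the whole cache-update block was guarded by `use_cache is True and token_idx is not None`, so a call with `use_cache=True` but no `token_idx` (presumably the "missing token_idx" in the commit title) would leave `present` unassigned for later use. The new code moves the `token_idx` check inside the branch and adds an `else` that returns the already-concatenated key/value as the present state. A minimal, self-contained sketch of that control-flow difference (toy stand-ins, not the real `self.k_cache`/`self.v_cache` objects):

# Toy reproduction of the guard change, not the repository's classes.
def old_guard(use_cache, token_idx, key, value):
    if use_cache is True and token_idx is not None:
        present = (key, value)      # simplified stand-in for the cache-update branch
    return present                  # UnboundLocalError when use_cache=True, token_idx=None

def new_guard(use_cache, token_idx, key, value):
    if use_cache is True:
        if token_idx is not None:
            present = (key, value)  # static-shape path: write at position token_idx
        else:
            present = (key, value)  # dynamic path: return the concatenated key/value
    return present

print(new_guard(True, None, "k", "v"))  # ('k', 'v') -- the case the old guard missed
try:
    old_guard(True, None, "k", "v")
except UnboundLocalError as err:
    print("old guard:", err)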
@@ -288,10 +291,10 @@ def forward(
         use_cache: Optional[bool] = False,
         output_attentions: Optional[bool] = False,
         token_idx: Optional[torch.Tensor] = None,
-        reuse_cache: Optional[bool] = False,
-        cache_idx: Optional[int] = None,
         sin: Optional[torch.Tensor] = None,
         cos: Optional[torch.Tensor] = None,
+        reuse_cache: Optional[bool] = False,
+        cache_idx: Optional[int] = None,
     ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
         """
         Copied from GPTJBlock.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gptj/modeling_gptj.py
