[Fix] Fix when missing both pad and eos token (#287)
* fix when missing both pad and eos token

* update pad_token_id impl
Leymore authored Aug 31, 2023
1 parent 166022f commit e810974
Showing 6 changed files with 36 additions and 9 deletions.
37 changes: 31 additions & 6 deletions opencompass/models/huggingface.py
@@ -40,6 +40,8 @@ class HuggingFace(BaseModel):
prediction tokens before decoding. Defaults to False.
batch_padding (bool): If False, inference will be performed in a for-loop
without batch padding.
pad_token_id (int): The id of the padding token. Defaults to None. If the
value is negative, (#vocab + pad_token_id) will be used instead.
Note:
About ``extract_pred_after_decode``: Commonly, we should extract the
@@ -59,7 +61,8 @@ def __init__(self,
model_kwargs: dict = dict(device_map='auto'),
meta_template: Optional[Dict] = None,
extract_pred_after_decode: bool = False,
batch_padding: bool = False):
batch_padding: bool = False,
pad_token_id: Optional[int] = None):
super().__init__(path=path,
max_seq_len=max_seq_len,
tokenizer_only=tokenizer_only,
@@ -69,6 +72,7 @@ def __init__(self,
hf_cache_dir = os.getenv('HF_MODEL_HUB', None)
patch_hf_auto_model(hf_cache_dir)
self.logger = get_logger()
self.pad_token_id = pad_token_id
self._load_tokenizer(path=path,
tokenizer_path=tokenizer_path,
tokenizer_kwargs=tokenizer_kwargs)
@@ -84,10 +88,31 @@ def _load_tokenizer(self, path: str, tokenizer_path: Optional[str],
from transformers import AutoTokenizer
self.tokenizer = AutoTokenizer.from_pretrained(
tokenizer_path if tokenizer_path else path, **tokenizer_kwargs)
if self.tokenizer.pad_token_id is None:
self.logger.warning('pad_token_id is not set for the tokenizer. '
'Using eos_token_id as pad_token_id.')
self.tokenizer.pad_token = self.tokenizer.eos_token

# A patch for some models without pad_token_id
if self.pad_token_id is not None:
if self.pad_token_id < 0:
self.pad_token_id += self.tokenizer.vocab_size
if self.tokenizer.pad_token_id is None:
self.logger.warning(
f'Using {self.pad_token_id} as pad_token_id')
elif self.tokenizer.pad_token_id != self.pad_token_id:
self.logger.warning(
f'pad_token_id is not consistent with the tokenizer. Using {self.pad_token_id} as pad_token_id' # noqa
)
self.tokenizer.pad_token_id = self.pad_token_id
elif self.tokenizer.pad_token_id is None:
self.logger.warning('pad_token_id is not set for the tokenizer.')
if self.tokenizer.eos_token is not None:
self.logger.warning('Using eos_token_id as pad_token_id.')
self.logger.warning(
f'Using eos_token: {self.tokenizer.eos_token}'  # noqa
)
self.tokenizer.pad_token = self.tokenizer.eos_token
else:
raise ValueError(
'pad_token_id is not set for this tokenizer. Try to set pad_token_id via passing `pad_token_id={PAD_TOKEN_ID}` in model_cfg. You may find pad_token_id in `generation.json`' # noqa
)

# A patch for llama when batch_padding = True
if 'decapoda-research/llama' in path or \
@@ -298,7 +323,7 @@ def _get_ppl(self,
"""

outputs, inputs = self.get_logits(inputs)
shift_logits = outputs[..., :-1, :].contiguous()
shift_logits = outputs[..., :-1, :].contiguous().float()

shift_labels = inputs['tokens']['input_ids'][..., 1:].contiguous()

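For reference, the new option is read from the model config. Below is a minimal, hypothetical sketch of a HuggingFace model entry that sets it; the abbr/path and the other values are placeholders, not taken from this commit. A negative pad_token_id is resolved against the tokenizer's vocabulary size, as implemented in _load_tokenizer above.

```python
# Hypothetical OpenCompass model config sketch; only `pad_token_id` is the
# option introduced by this commit, all other values are placeholders.
from opencompass.models import HuggingFace

models = [
    dict(
        type=HuggingFace,
        abbr='my-hf-model',              # placeholder abbreviation
        path='some-org/some-model',      # placeholder HF model id
        max_seq_len=2048,
        max_out_len=100,
        batch_size=8,
        batch_padding=True,
        # New: explicit padding token id for tokenizers that define neither
        # pad_token nor eos_token. A negative value is interpreted as
        # (vocab_size + pad_token_id), e.g. -1 -> last token in the vocab.
        pad_token_id=0,
        run_cfg=dict(num_gpus=1),
    )
]
```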
2 changes: 1 addition & 1 deletion opencompass/models/intern_model.py
@@ -104,7 +104,7 @@ def get_ppl(self,
"""
outputs, inputs = self.generator.get_logits(input_texts)

shift_logits = outputs[..., :-1, :].contiguous()
shift_logits = outputs[..., :-1, :].contiguous().float()
shift_labels = inputs['tokens'][..., 1:].contiguous()

loss_fct = torch.nn.CrossEntropyLoss(
2 changes: 1 addition & 1 deletion opencompass/models/llama2.py
@@ -84,7 +84,7 @@ def get_ppl(self,
# forward
outputs = self.model.forward(tokens, 0)
# compute ppl
shift_logits = outputs[..., :-1, :].contiguous()
shift_logits = outputs[..., :-1, :].contiguous().float()
shift_labels = tokens[..., 1:].contiguous()
shift_logits = shift_logits.view(-1, shift_logits.size(-1))
shift_labels = shift_labels.view(-1)
2 changes: 1 addition & 1 deletion opencompass/openicl/icl_inferencer/icl_clp_inferencer.py
@@ -215,7 +215,7 @@ def __get_cond_prob(self,
else:
outputs, _ = self.model.get_logits(input_texts)

shift_logits = outputs[..., :-1, :].contiguous()
shift_logits = outputs[..., :-1, :].contiguous().float()

shift_logits = F.log_softmax(shift_logits, dim=-1)
log_probs = []
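All four `.float()` additions (huggingface.py, intern_model.py, llama2.py, and the CLP inferencer above) address the same numerical issue: when the model runs in float16/bfloat16, taking log_softmax or cross-entropy directly on half-precision logits can lose precision or overflow. A standalone sketch of the pattern, with illustrative shapes that are not taken from this commit:

```python
# Illustrative only: why the logits are upcast before the PPL computation.
import torch
import torch.nn.functional as F

batch, seq_len, vocab = 2, 16, 32000
outputs = torch.randn(batch, seq_len, vocab, dtype=torch.float16)  # half-precision LM logits
input_ids = torch.randint(0, vocab, (batch, seq_len))

# Same shift-by-one pattern as in the diff: position t predicts token t+1.
# Casting to float32 keeps the log-probabilities numerically stable even
# though the forward pass ran in float16.
shift_logits = outputs[..., :-1, :].contiguous().float()
shift_labels = input_ids[..., 1:].contiguous()

log_probs = F.log_softmax(shift_logits, dim=-1)       # CLP-inferencer style
loss = torch.nn.CrossEntropyLoss(reduction='none')(   # PPL style
    shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
print(log_probs.dtype, loss.shape)  # torch.float32, (batch * (seq_len - 1),)
```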
1 change: 1 addition & 0 deletions opencompass/utils/run.py
@@ -90,6 +90,7 @@ def get_config_from_arg(args) -> Config:
max_out_len=args.max_out_len,
batch_padding=not args.no_batch_padding,
batch_size=args.batch_size,
pad_token_id=args.pad_token_id,
run_cfg=dict(num_gpus=args.num_gpus))
models.append(model)
return Config(dict(models=models, datasets=datasets),
1 change: 1 addition & 0 deletions run.py
@@ -177,6 +177,7 @@ def parse_hf_args(hf_parser):
default=False)
hf_parser.add_argument('--batch-size', type=int)
hf_parser.add_argument('--num-gpus', type=int)
hf_parser.add_argument('--pad-token-id', type=int)


def main():
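The two one-line additions above wire the same option into the ad-hoc CLI: parse_hf_args() registers --pad-token-id, and get_config_from_arg() forwards it into each generated model dict. A minimal sketch of that flow (simplified, not the project's actual CLI code):

```python
# Simplified sketch of how --pad-token-id travels from the CLI into a model
# config; the real wiring lives in run.py and opencompass/utils/run.py.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--pad-token-id', type=int)   # flag added by this commit
parser.add_argument('--batch-size', type=int, default=8)
parser.add_argument('--num-gpus', type=int, default=1)
args = parser.parse_args(['--pad-token-id', '0'])

# get_config_from_arg() puts the value into the model dict; HuggingFace's
# _load_tokenizer() later applies it when the tokenizer has no pad_token_id.
model = dict(
    batch_size=args.batch_size,
    pad_token_id=args.pad_token_id,
    run_cfg=dict(num_gpus=args.num_gpus),
)
print(model)  # {'batch_size': 8, 'pad_token_id': 0, 'run_cfg': {'num_gpus': 1}}
```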
