Add Truncation To Sentencepiece (openvinotoolkit#225)
* Update Tests

* Move checks to helpers functions
* Add tests for chat templates
* Filter skipped tests when the pass rate is calculated

* Do not use regex to add prefix space in sentencepiece

* Automatically Trigger Handling Special Tokens With Re

* Add diff to test results
* Fix chatglm + re2 splitting

* Del Unused Variable

* Add Basic Truncation To Sentencepiece

* Update Sentencepiece Special Tokens Handling

* Fix Gemma SP Model

* Add Env Var To Regulate Tests Output
apaniukov authored Aug 23, 2024
1 parent 016aa8e commit c808365
Showing 9 changed files with 1,859 additions and 500 deletions.
68 changes: 46 additions & 22 deletions README.md
@@ -428,18 +428,18 @@ This report is autogenerated and includes tokenizers and detokenizers tests. The
<tbody>
<tr>
<td >BPE</td>
<td >95.56</td>
<td >5928</td>
<td >95.57</td>
<td >5932</td>
</tr>
<tr>
<td >SentencePiece</td>
<td >83.98</td>
<td >4762</td>
<td >86.56</td>
<td >5698</td>
</tr>
<tr>
<td >Tiktoken</td>
<td >96.82</td>
<td >346</td>
<td >97.17</td>
<td >494</td>
</tr>
<tr>
<td >WordPiece</td>
@@ -495,7 +495,7 @@ This report is autogenerated and includes tokenizers and detokenizers tests. The
<td >BPE</td>
<td >NousResearch/Meta-Llama-3-8B-Instruct</td>
<td >100.00</td>
<td >239</td>
<td >241</td>
</tr>
<tr>
<td >BPE</td>
@@ -531,7 +531,7 @@ This report is autogenerated and includes tokenizers and detokenizers tests. The
<td >BPE</td>
<td >deepseek-ai/deepseek-coder-6.7b-instruct</td>
<td >100.00</td>
<td >255</td>
<td >257</td>
</tr>
<tr>
<td >BPE</td>
@@ -608,7 +608,7 @@ This report is autogenerated and includes tokenizers and detokenizers tests. The
<tr>
<td >SentencePiece</td>
<td >NousResearch/Llama-2-13b-hf</td>
<td >97.49</td>
<td >94.98</td>
<td >239</td>
</tr>
<tr>
@@ -632,25 +632,25 @@ This report is autogenerated and includes tokenizers and detokenizers tests. The
<tr>
<td >SentencePiece</td>
<td >THUDM/chatglm3-6b</td>
<td >51.63</td>
<td >153</td>
<td >50.97</td>
<td >155</td>
</tr>
<tr>
<td >SentencePiece</td>
<td >THUDM/chatglm3-6b_slow</td>
<td >50.34</td>
<td >149</td>
<td >49.67</td>
<td >151</td>
</tr>
<tr>
<td >SentencePiece</td>
<td >camembert-base</td>
<td >50.63</td>
<td >52.30</td>
<td >239</td>
</tr>
<tr>
<td >SentencePiece</td>
<td >camembert-base_slow</td>
<td >78.03</td>
<td >78.92</td>
<td >223</td>
</tr>
<tr>
@@ -668,15 +668,27 @@ This report is autogenerated and includes tokenizers and detokenizers tests. The
<tr>
<td >SentencePiece</td>
<td >facebook/musicgen-small</td>
<td >82.43</td>
<td >84.52</td>
<td >239</td>
</tr>
<tr>
<td >SentencePiece</td>
<td >facebook/musicgen-small_slow</td>
<td >78.03</td>
<td >78.92</td>
<td >223</td>
</tr>
<tr>
<td >SentencePiece</td>
<td >microsoft/Phi-3-mini-128k-instruct</td>
<td >99.17</td>
<td >241</td>
</tr>
<tr>
<td >SentencePiece</td>
<td >microsoft/Phi-3-mini-128k-instruct_slow</td>
<td >99.11</td>
<td >225</td>
</tr>
<tr>
<td >SentencePiece</td>
<td >microsoft/deberta-v3-base</td>
@@ -689,6 +701,18 @@ This report is autogenerated and includes tokenizers and detokenizers tests. The
<td >100.00</td>
<td >223</td>
</tr>
<tr>
<td >SentencePiece</td>
<td >mlx-community/quantized-gemma-7b-it</td>
<td >96.68</td>
<td >241</td>
</tr>
<tr>
<td >SentencePiece</td>
<td >mlx-community/quantized-gemma-7b-it_slow</td>
<td >98.22</td>
<td >225</td>
</tr>
<tr>
<td >SentencePiece</td>
<td >rinna/bilingual-gpt-neox-4b</td>
@@ -716,13 +740,13 @@ This report is autogenerated and includes tokenizers and detokenizers tests. The
<tr>
<td >SentencePiece</td>
<td >xlm-roberta-base</td>
<td >93.31</td>
<td >96.23</td>
<td >239</td>
</tr>
<tr>
<td >SentencePiece</td>
<td >xlm-roberta-base_slow</td>
<td >96.41</td>
<td >97.76</td>
<td >223</td>
</tr>
<tr>
@@ -741,13 +765,13 @@ This report is autogenerated and includes tokenizers and detokenizers tests. The
<td >Tiktoken</td>
<td >Qwen/Qwen-14B-Chat</td>
<td >100.00</td>
<td >181</td>
<td >255</td>
</tr>
<tr>
<td >Tiktoken</td>
<td >THUDM/glm-4-9b</td>
<td >93.33</td>
<td >165</td>
<td >94.14</td>
<td >239</td>
</tr>
<tr>
<td >WordPiece</td>
2 changes: 1 addition & 1 deletion python/openvino_tokenizers/convert_tokenizer.py
@@ -25,7 +25,7 @@ def convert_tokenizer(
detokenizer_input_type: Type = Type.i64,
streaming_detokenizer: bool = False,
use_max_padding: bool = False,
handle_special_tokens_with_re: bool = False,
handle_special_tokens_with_re: Optional[bool] = None,
) -> Union[Model, Tuple[Model, Model]]:
ov_tokenizers = None
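A minimal usage sketch (assuming transformers and openvino_tokenizers are installed): with the new default of None, the converter itself decides whether special tokens should be split out with regex before the SentencePiece subgraph.

from transformers import AutoTokenizer
from openvino_tokenizers import convert_tokenizer

# Example checkpoint taken from the test table above; any SentencePiece-based
# tokenizer is handled the same way.
hf_tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-128k-instruct")

# handle_special_tokens_with_re is left at its new default (None), so regex-based
# special-token handling is switched on automatically for SentencePiece BPE models.
ov_tokenizer, ov_detokenizer = convert_tokenizer(hf_tokenizer, with_detokenizer=True)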

86 changes: 58 additions & 28 deletions python/openvino_tokenizers/hf_parser.py
@@ -3,6 +3,7 @@
# SPDX-License-Identifier: Apache-2.0

import json
import sys
import tempfile
from copy import deepcopy
from functools import partial
@@ -383,11 +384,13 @@ def decoding(
def parse_special_tokens(hf_tokenizer: PreTrainedTokenizerBase, only_special_tokens: bool = True) -> Dict[int, str]:
# the order matters
result = {}
result.update({
idx: added_token.content
for idx, added_token in getattr(hf_tokenizer, "added_tokens_decoder", {}).items()
if not only_special_tokens or added_token.special
})
result.update(
{
idx: added_token.content
for idx, added_token in getattr(hf_tokenizer, "added_tokens_decoder", {}).items()
if not only_special_tokens or added_token.special
}
)
if hasattr(hf_tokenizer, "tokenizer") and hasattr(hf_tokenizer.tokenizer, "index_special_tokens"):
result.update(hf_tokenizer.tokenizer.index_special_tokens)
if hasattr(hf_tokenizer, "special_tokens"):
@@ -476,13 +479,25 @@ def is_sentencepiece_model(hf_tokenizer: PreTrainedTokenizerBase) -> bool:
return False


def is_sentencepiece_bpe_model(hf_tokenizer: PreTrainedTokenizerBase) -> bool:
with tempfile.TemporaryDirectory() as tmp:
hf_tokenizer.save_pretrained(tmp)
vocab_file = Path(tmp) / hf_tokenizer.vocab_files_names["vocab_file"]
model_pb = import_protobuf()
model = model_pb.ModelProto()
with open(vocab_file, "rb") as model_file:
model.ParseFromString(model_file.read())
return model.trainer_spec.model_type == 2 # UNIGRAM=1 BPE=2 WORD=3 CHAR=4
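As a side note, the same model-type check can be reproduced on a saved tokenizer.model file with the sentencepiece package directly; a hedged sketch (the file path is an assumption):

from sentencepiece import sentencepiece_model_pb2 as sp_pb2

sp_model = sp_pb2.ModelProto()
with open("tokenizer.model", "rb") as model_file:  # hypothetical path
    sp_model.ParseFromString(model_file.read())
print(sp_model.trainer_spec.model_type)  # 1=UNIGRAM, 2=BPE, 3=WORD, 4=CHAR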


def align_model_file(
model: "ModelProto", # noqa
model: "ModelProto", # noqa
hf_tokenizer: PreTrainedTokenizerBase,
added_tokens: Optional[Dict[int, str]] = None,
) -> None:
if added_tokens is None:
added_tokens = hf_tokenizer.added_tokens_decoder

def is_byte(token: str) -> bool:
return len(token) == 6 and token.startswith("<0x") and token.endswith(">")

@@ -494,14 +509,18 @@ def is_byte(token: str) -> bool:
return

scores = np.array([piece.score for piece in model.pieces])
score_delta = np.mean(scores[np.where(scores < 0)])
score_delta = np.abs(np.mean(np.diff(scores[np.where(scores < 0)])))

for idx in range(hf_tokenizer.vocab_size):
token = added_tokens.get(idx, sorted_vocab.get(idx))

not_used = token is None
token = f"<new_token_{idx}>" if not_used else token

# gemma-7b has "\t" instead of byte representation
if token == "\t" and model.pieces[idx].piece == "<0x09>":
token = "<0x09>"

if token in existing:
new_pieces.append(existing[token])
continue
@@ -584,23 +603,19 @@ def modify_sentencepiece_model(
elif not skip_special_tokens and new_piece.type == 3:
new_piece.type = 4 # change control type to userdef type

if hf_tokenizer.is_fast:
assert True

if to_add:
while len(model.pieces) + 1 <= idx:
# to place special token in particular idx we have to extend vocab first
missing_piece = deepcopy(new_piece)
missing_piece.piece = hf_tokenizer.decode(len(model.pieces), skip_special_tokens=False) or f"<empty_{len(model.pieces)}>"
missing_piece.piece = (
hf_tokenizer.decode(len(model.pieces), skip_special_tokens=False) or f"<empty_{len(model.pieces)}>"
)
missing_piece.type = 4
model.pieces.insert(idx, missing_piece)
bos_eos = ("<bos>", "<eos>", "<s>", "</s>")
if (
idx < len(model.pieces)
and (
(model.pieces[idx].type not in (2, 3) or model.pieces[idx].piece == token)
or (token in bos_eos and model.pieces[idx].piece in bos_eos)
)
if idx < len(model.pieces) and (
(model.pieces[idx].type not in (2, 3) or model.pieces[idx].piece == token)
or (token in bos_eos and model.pieces[idx].piece in bos_eos)
):
model.pieces.pop(idx)
model.pieces.insert(idx, new_piece)
@@ -617,10 +632,11 @@
unk_token = next(piece for piece in model.pieces if piece.type == 2)
model.trainer_spec.unk_surface = unk_token.piece

has_bytes = any(piece.type == 6 for piece in model.pieces)
if byte_fallback is not None:
model.trainer_spec.byte_fallback = byte_fallback
model.trainer_spec.byte_fallback = byte_fallback and has_bytes

if byte_fallback is False:
if byte_fallback is False and has_bytes:
for piece in model.pieces:
if piece.type == 6:
piece.type = 5 # change BYTE type to UNUSED
@@ -637,11 +653,14 @@ def convert_sentencepiece_model_tokenizer(
skip_special_tokens: bool = False,
clean_up_tokenization_spaces: Optional[bool] = False,
add_prefix_space: Optional[bool] = None,
handle_special_tokens_with_re: bool = False,
handle_special_tokens_with_re: Optional[bool] = None,
) -> Union[Model, Tuple[Model, Model]]:
if not is_sentencepiece_model(hf_tokenizer):
raise OVTypeError("Cannot convert tokenizer of this type without `.model` file.")

if handle_special_tokens_with_re is None:
handle_special_tokens_with_re = is_sentencepiece_bpe_model(hf_tokenizer)

is_chatglm = getattr(hf_tokenizer, "name", None) == "GLMTokenizer"
add_bos_token = add_eos_token = None
if is_chatglm:
@@ -710,7 +729,7 @@ def convert_sentencepiece_model_tokenizer(
elif prepend_scheme == "never":
add_prefix_space = False
elif prepend_scheme == "first":
add_prefix_space = False
add_prefix_space = True

# metaspace can be emulated with sequence of normalizers
if add_prefix_space is None:
@@ -719,7 +738,7 @@
prepend_scheme = "never"

elif add_prefix_space is None and isinstance(hf_tokenizer, PreTrainedTokenizerFast):
add_prefix_space = not add_bos_token
add_prefix_space = True

add_tokens = parse_special_tokens(hf_tokenizer, only_special_tokens=False)

Expand All @@ -728,7 +747,7 @@ def convert_sentencepiece_model_tokenizer(
add_tokens=add_tokens,
hf_tokenizer=hf_tokenizer,
skip_special_tokens=False,
add_prefix_space=add_prefix_space and not handle_special_tokens_with_re,
add_prefix_space=add_prefix_space,
byte_fallback=byte_fallback,
)
sp_model = np.frombuffer(sp_model_string, dtype=np.uint8)
@@ -749,11 +768,6 @@
input_node.set_friendly_name("string_input")
next_node = input_node.outputs()

if prepend_scheme == "first" or (add_prefix_space and handle_special_tokens_with_re):
next_node = _get_factory().create("StringTensorUnpack", next_node).outputs()
next_node = RegexNormalizationStep.add_prefix_whitespace_to_not_whitespace_regex().get_ov_subgraph(next_node)
next_node = _get_factory().create("StringTensorPack", next_node).outputs()

do_left_padding = hf_tokenizer.padding_side == "left"

if handle_special_tokens_with_re:
@@ -818,6 +832,22 @@ convert_sentencepiece_model_tokenizer(
"Reverse", [scattered_input_ids, make_constant_node(np.array([-1]))], {"mode": "index"}
)

if 0 < (max_length := getattr(hf_tokenizer, "model_max_length", -1)) < 2**17:
scattered_input_ids = opset.slice(
scattered_input_ids,
start=[-max_length] if do_left_padding else [0],
stop=[sys.maxsize] if do_left_padding else [max_length],
step=[1],
axes=[-1],
)
attention_mask = opset.slice(
attention_mask,
start=[-max_length] if do_left_padding else [0],
stop=[sys.maxsize] if do_left_padding else [max_length],
step=[1],
axes=[-1],
)

scattered_input_ids.output(0).tensor.add_names({TOKEN_IDS_INPUT_NAME})
outputs = scattered_input_ids.outputs()

4 changes: 4 additions & 0 deletions python/openvino_tokenizers/tokenizer_pipeline.py
@@ -140,6 +140,10 @@ def add_prefix_whitespace_regex(cls) -> "RegexNormalizationStep":
def add_prefix_whitespace_to_not_whitespace_regex(cls) -> "RegexNormalizationStep":
return cls(regex_search_pattern=r"^([^ ])", replace_term=r" \1")

@classmethod
def replace_spaces_metaspace(cls) -> "RegexNormalizationStep":
return cls(regex_search_pattern=r" ", replace_term=r"▁")

@classmethod
def prepend_regex(cls, string: str) -> "RegexNormalizationStep":
return cls(regex_search_pattern=r"(^)(.+)", replace_term=rf"{string}\2")
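The new replace_spaces_metaspace step added above rewrites plain spaces to the SentencePiece metaspace character U+2581 ("▁"); a one-line illustration of the same substitution in plain Python:

import re

print(re.sub(r" ", "\u2581", "Add truncation to SentencePiece"))  # Add▁truncation▁to▁SentencePiece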