From c9edc3061b3e9bf0c0360d6e85a712dbd3a2ada2 Mon Sep 17 00:00:00 2001
From: "Benjamin R. Laney"
Date: Fri, 8 Mar 2024 16:52:53 -0700
Subject: [PATCH] Update support for encodings hashes when hosting cache files.

---
 tiktoken/load.py              | 31 ++++++++++++------
 tiktoken_ext/openai_public.py | 62 ++++++++++++++++++++++-------------
 2 files changed, 61 insertions(+), 32 deletions(-)

diff --git a/tiktoken/load.py b/tiktoken/load.py
index cc0a6a6d..3c6d295b 100644
--- a/tiktoken/load.py
+++ b/tiktoken/load.py
@@ -32,7 +32,11 @@ def check_hash(data: bytes, expected_hash: str) -> bool:
     return actual_hash == expected_hash
 
 
-def read_file_cached(blobpath: str, expected_hash: Optional[str] = None) -> bytes:
+def read_file_cached(
+    blobpath: str,
+    expected_hash: Optional[str] = None,
+    is_self_hosting: bool = False,
+) -> bytes:
     user_specified_cache = True
     if "TIKTOKEN_CACHE_DIR" in os.environ:
         cache_dir = os.environ["TIKTOKEN_CACHE_DIR"]
@@ -52,9 +56,20 @@ def read_file_cached(blobpath: str, expected_hash: Optional[str] = None) -> byte
     if os.path.exists(cache_path):
         with open(cache_path, "rb") as f:
             data = f.read()
-        if expected_hash is None or check_hash(data, expected_hash):
+        if expected_hash is None:
             return data
+        if check_hash(data, expected_hash):
+            return data
+
+        if is_self_hosting:
+            raise ValueError(
+                f"Hash mismatch for data from {blobpath} (expected {expected_hash}). "
+                "This may indicate a change in the `tiktoken` encodings for this version. "
+                "Please update the hosted encodings, or remove/unset `ENCODINGS_HOST` "
+                "to refresh the cache from the central host (`openaipublic`)."
+            )
+
 
         # the cached file does not match the hash, remove it and re-fetch
         try:
             os.remove(cache_path)
@@ -83,10 +98,8 @@ def read_file_cached(blobpath: str, expected_hash: Optional[str] = None) -> byte
 
 
 def data_gym_to_mergeable_bpe_ranks(
-    vocab_bpe_file: str,
-    encoder_json_file: str,
-    vocab_bpe_hash: Optional[str] = None,
-    encoder_json_hash: Optional[str] = None,
+    vocab_bpe_contents: str,
+    encoder_json_contents: str,
 ) -> dict[bytes, int]:
     # NB: do not add caching to this function
     rank_to_intbyte = [b for b in range(2**8) if chr(b).isprintable() and chr(b) != " "]
@@ -101,7 +114,6 @@ def data_gym_to_mergeable_bpe_ranks(
     assert len(rank_to_intbyte) == 2**8
 
     # vocab_bpe contains the merges along with associated ranks
-    vocab_bpe_contents = read_file_cached(vocab_bpe_file, vocab_bpe_hash).decode()
     bpe_merges = [tuple(merge_str.split()) for merge_str in vocab_bpe_contents.split("\n")[1:-1]]
 
     def decode_data_gym(value: str) -> bytes:
@@ -118,7 +130,7 @@ def decode_data_gym(value: str) -> bytes:
     # check that the encoder file matches the merges file
     # this sanity check is important since tiktoken assumes that ranks are ordered the same
     # as merge priority
-    encoder_json = json.loads(read_file_cached(encoder_json_file, encoder_json_hash))
+    encoder_json = json.loads(encoder_json_contents)
     encoder_json_loaded = {decode_data_gym(k): v for k, v in encoder_json.items()}
     # drop these two special tokens if present, since they're not mergeable bpe tokens
     encoder_json_loaded.pop(b"<|endoftext|>", None)
@@ -141,10 +153,9 @@ def dump_tiktoken_bpe(bpe_ranks: dict[bytes, int], tiktoken_bpe_file: str) -> No
 
 
 def load_tiktoken_bpe(
-    tiktoken_bpe_file: str, expected_hash: Optional[str] = None
+    contents: bytes
 ) -> dict[bytes, int]:
     # NB: do not add caching to this function
-    contents = read_file_cached(tiktoken_bpe_file, expected_hash)
     return {
         base64.b64decode(token): int(rank)
         for token, rank in (line.split() for line in contents.splitlines() if line)
diff --git a/tiktoken_ext/openai_public.py b/tiktoken_ext/openai_public.py
index a7ad22ab..9f557e0d 100644
--- a/tiktoken_ext/openai_public.py
+++ b/tiktoken_ext/openai_public.py
@@ -1,5 +1,5 @@
 import os
-from tiktoken.load import data_gym_to_mergeable_bpe_ranks, load_tiktoken_bpe
+from tiktoken.load import data_gym_to_mergeable_bpe_ranks, load_tiktoken_bpe, read_file_cached
 
 ENDOFTEXT = "<|endoftext|>"
 FIM_PREFIX = "<|fim_prefix|>"
@@ -7,14 +7,40 @@
 FIM_SUFFIX = "<|fim_suffix|>"
 ENDOFPROMPT = "<|endofprompt|>"
-ENCODINGS_HOST = os.getenv("ENCODINGS_HOST", "https://openaipublic.blob.core.windows.net")
+ENCODINGS_HOST = os.getenv("ENCODINGS_HOST", None)
+
+if "ENCODINGS_HOST" in os.environ:
+    ENCODINGS_HOST = os.environ["ENCODINGS_HOST"]
+    IS_HOSTING_ENCODINGS = True
+else:
+    ENCODINGS_HOST = "https://openaipublic.blob.core.windows.net"
+    IS_HOSTING_ENCODINGS = False
+
+VOCAB_BPE_FILE = f"{ENCODINGS_HOST}/gpt-2/encodings/main/vocab.bpe"
+VOCAB_BPE_HASH = "1ce1664773c50f3e0cc8842619a93edc4624525b728b188a9e0be33b7726adc5"
+ENCODER_JSON_FILE = f"{ENCODINGS_HOST}/gpt-2/encodings/main/encoder.json"
+ENCODER_JSON_HASH = "196139668be63f3b5d6574427317ae82f612a97c5d1cdaf36ed2256dbf636783"
+R50K_BASE_FILE = f"{ENCODINGS_HOST}/encodings/r50k_base.tiktoken"
+R50K_BASE_HASH = "306cd27f03c1a714eca7108e03d66b7dc042abe8c258b44c199a7ed9838dd930"
+P50K_BASE_FILE = f"{ENCODINGS_HOST}/encodings/p50k_base.tiktoken"
+P50K_BASE_HASH = "94b5ca7dff4d00767bc256fdd1b27e5b17361d7b8a5f968547f9f23eb70d2069"
+CL100K_BASE_FILE = f"{ENCODINGS_HOST}/encodings/cl100k_base.tiktoken"
+CL100K_BASE_HASH = "223921b76ee99bde995b7ff738513eef100fb51d18c93597a113bcffe865b2a7"
 
 
 def gpt2():
+    vocab_bpe_contents = read_file_cached(
+        VOCAB_BPE_FILE,
+        VOCAB_BPE_HASH,
+        IS_HOSTING_ENCODINGS,
+    ).decode()
+    encoder_json_contents = read_file_cached(
+        ENCODER_JSON_FILE,
+        ENCODER_JSON_HASH,
+        IS_HOSTING_ENCODINGS,
+    )
     mergeable_ranks = data_gym_to_mergeable_bpe_ranks(
-        vocab_bpe_file=f"{ENCODINGS_HOST}/gpt-2/encodings/main/vocab.bpe",
-        encoder_json_file=f"{ENCODINGS_HOST}/gpt-2/encodings/main/encoder.json",
-        vocab_bpe_hash="1ce1664773c50f3e0cc8842619a93edc4624525b728b188a9e0be33b7726adc5",
-        encoder_json_hash="196139668be63f3b5d6574427317ae82f612a97c5d1cdaf36ed2256dbf636783",
+        vocab_bpe_contents=vocab_bpe_contents,
+        encoder_json_contents=encoder_json_contents,
     )
     return {
         "name": "gpt2",
@@ -29,10 +55,8 @@ def gpt2():
 
 
 def r50k_base():
-    mergeable_ranks = load_tiktoken_bpe(
-        f"{ENCODINGS_HOST}/encodings/r50k_base.tiktoken",
-        expected_hash="306cd27f03c1a714eca7108e03d66b7dc042abe8c258b44c199a7ed9838dd930",
-    )
+    contents = read_file_cached(R50K_BASE_FILE, R50K_BASE_HASH, IS_HOSTING_ENCODINGS)
+    mergeable_ranks = load_tiktoken_bpe(contents)
     return {
         "name": "r50k_base",
         "explicit_n_vocab": 50257,
@@ -43,10 +67,8 @@ def r50k_base():
 
 
 def p50k_base():
-    mergeable_ranks = load_tiktoken_bpe(
-        f"{ENCODINGS_HOST}/encodings/p50k_base.tiktoken",
-        expected_hash="94b5ca7dff4d00767bc256fdd1b27e5b17361d7b8a5f968547f9f23eb70d2069",
-    )
+    contents = read_file_cached(P50K_BASE_FILE, P50K_BASE_HASH, IS_HOSTING_ENCODINGS)
+    mergeable_ranks = load_tiktoken_bpe(contents)
     return {
         "name": "p50k_base",
         "explicit_n_vocab": 50281,
@@ -57,10 +79,8 @@ def p50k_base():
 
 
 def p50k_edit():
-    mergeable_ranks = load_tiktoken_bpe(
-        f"{ENCODINGS_HOST}/encodings/p50k_base.tiktoken",
-        expected_hash="94b5ca7dff4d00767bc256fdd1b27e5b17361d7b8a5f968547f9f23eb70d2069",
-    )
+    contents = read_file_cached(P50K_BASE_FILE, P50K_BASE_HASH, IS_HOSTING_ENCODINGS)
+    mergeable_ranks = load_tiktoken_bpe(contents)
     special_tokens = {ENDOFTEXT: 50256, FIM_PREFIX: 50281, FIM_MIDDLE: 50282, FIM_SUFFIX: 50283}
     return {
         "name": "p50k_edit",
@@ -71,10 +91,8 @@ def p50k_edit():
 
 
 def cl100k_base():
-    mergeable_ranks = load_tiktoken_bpe(
-        f"{ENCODINGS_HOST}/encodings/cl100k_base.tiktoken",
-        expected_hash="223921b76ee99bde995b7ff738513eef100fb51d18c93597a113bcffe865b2a7",
-    )
+    contents = read_file_cached(CL100K_BASE_FILE, CL100K_BASE_HASH, IS_HOSTING_ENCODINGS)
+    mergeable_ranks = load_tiktoken_bpe(contents)
    special_tokens = {
        ENDOFTEXT: 100257,
        FIM_PREFIX: 100258,
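
Note for reviewers, not part of the patch: with `ENCODINGS_HOST` set, `IS_HOSTING_ENCODINGS` is True and a hash mismatch now raises instead of silently re-fetching from `openaipublic`, so an operator mirroring the encoding files should verify them against the pinned hashes before serving. The pinned values are SHA-256 hex digests of the raw file bytes, which is the comparison `check_hash` in tiktoken/load.py performs. A minimal verification sketch follows; the mirror root and file layout are assumptions, not part of this change:

    # verify_mirror.py -- sketch only; MIRROR_ROOT is a hypothetical document root
    import hashlib
    import os

    # relative paths match the URL layout expected under ENCODINGS_HOST;
    # hashes are the constants pinned in tiktoken_ext/openai_public.py
    PINNED = {
        "encodings/r50k_base.tiktoken": "306cd27f03c1a714eca7108e03d66b7dc042abe8c258b44c199a7ed9838dd930",
        "encodings/p50k_base.tiktoken": "94b5ca7dff4d00767bc256fdd1b27e5b17361d7b8a5f968547f9f23eb70d2069",
        "encodings/cl100k_base.tiktoken": "223921b76ee99bde995b7ff738513eef100fb51d18c93597a113bcffe865b2a7",
    }
    MIRROR_ROOT = "/srv/encodings-mirror"  # assumed local path of the mirrored files

    for relpath, expected in PINNED.items():
        with open(os.path.join(MIRROR_ROOT, relpath), "rb") as f:
            # same sha256-hexdigest comparison that check_hash() applies at load time
            actual = hashlib.sha256(f.read()).hexdigest()
        if actual != expected:
            raise SystemExit(f"{relpath}: got {actual}, want {expected}")
    print("mirror is in sync with the pinned hashes")

Once the mirror checks out, clients opt in by setting the variable before importing `tiktoken`, e.g. `ENCODINGS_HOST=https://encodings.example.internal` (a made-up host); unsetting it falls back to the central host.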