From 4fb65dcf73d8f263115d0659e8bd3437dfbaad6a Mon Sep 17 00:00:00 2001
From: Yuhong Sun
Date: Fri, 8 Nov 2024 14:54:15 -0800
Subject: [PATCH] Reenable OpenAI Tokenizer (#3062)

* k

* clean up test embeddings

* nit

* minor update to ensure consistency

* minor organizational update

* minor updates

---------

Co-authored-by: pablodanswer
---
 .../natural_language_processing/utils.py      | 76 +++++++++++++------
 .../modals/ChangeCredentialsModal.tsx         | 18 ++---
 .../modals/ProviderCreationModal.tsx          | 23 ++++--
 web/src/app/admin/embeddings/pages/utils.ts   | 34 +++++++++
 4 files changed, 109 insertions(+), 42 deletions(-)

diff --git a/backend/danswer/natural_language_processing/utils.py b/backend/danswer/natural_language_processing/utils.py
index d2b9a7d7f1e..a8250570e84 100644
--- a/backend/danswer/natural_language_processing/utils.py
+++ b/backend/danswer/natural_language_processing/utils.py
@@ -35,23 +35,31 @@ def decode(self, tokens: list[int]) -> str:
 class TiktokenTokenizer(BaseTokenizer):
     _instances: dict[str, "TiktokenTokenizer"] = {}
 
-    def __new__(cls, encoding_name: str = "cl100k_base") -> "TiktokenTokenizer":
-        if encoding_name not in cls._instances:
-            cls._instances[encoding_name] = super(TiktokenTokenizer, cls).__new__(cls)
-        return cls._instances[encoding_name]
+    def __new__(cls, model_name: str) -> "TiktokenTokenizer":
+        if model_name not in cls._instances:
+            cls._instances[model_name] = super(TiktokenTokenizer, cls).__new__(cls)
+        return cls._instances[model_name]
 
-    def __init__(self, encoding_name: str = "cl100k_base"):
+    def __init__(self, model_name: str):
         if not hasattr(self, "encoder"):
             import tiktoken
 
-            self.encoder = tiktoken.get_encoding(encoding_name)
+            self.encoder = tiktoken.encoding_for_model(model_name)
 
     def encode(self, string: str) -> list[int]:
-        # this returns no special tokens
+        # this ignores special tokens that the model is trained on, see encode_ordinary for details
         return self.encoder.encode_ordinary(string)
 
     def tokenize(self, string: str) -> list[str]:
-        return [self.encoder.decode([token]) for token in self.encode(string)]
+        encoded = self.encode(string)
+        decoded = [self.encoder.decode([token]) for token in encoded]
+
+        if len(decoded) != len(encoded):
+            logger.warning(
+                f"OpenAI tokenized length {len(decoded)} does not match encoded length {len(encoded)} for string: {string}"
+            )
+
+        return decoded
 
     def decode(self, tokens: list[int]) -> str:
         return self.encoder.decode(tokens)
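The hunk above stops hard-coding the cl100k_base encoding and instead resolves the vocabulary per model, with one cached TiktokenTokenizer instance per model name. A minimal sketch of the underlying tiktoken lookup, assuming tiktoken is installed (at the time of this PR, "text-embedding-3-small" still maps to cl100k_base):

    import tiktoken

    # encoding_for_model picks the vocabulary registered for the given model,
    # rather than assuming cl100k_base for everything
    enc = tiktoken.encoding_for_model("text-embedding-3-small")

    # encode_ordinary skips special tokens, mirroring TiktokenTokenizer.encode
    tokens = enc.encode_ordinary("hello world")
    assert enc.decode(tokens) == "hello world"
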
@@ -74,22 +82,35 @@ def decode(self, tokens: list[int]) -> str:
         return self.encoder.decode(tokens)
 
 
-_TOKENIZER_CACHE: dict[str, BaseTokenizer] = {}
+_TOKENIZER_CACHE: dict[tuple[EmbeddingProvider | None, str | None], BaseTokenizer] = {}
 
 
-def _check_tokenizer_cache(tokenizer_name: str) -> BaseTokenizer:
+def _check_tokenizer_cache(
+    model_provider: EmbeddingProvider | None, model_name: str | None
+) -> BaseTokenizer:
     global _TOKENIZER_CACHE
 
-    if tokenizer_name not in _TOKENIZER_CACHE:
-        if tokenizer_name == "openai":
-            _TOKENIZER_CACHE[tokenizer_name] = TiktokenTokenizer("cl100k_base")
-            return _TOKENIZER_CACHE[tokenizer_name]
+    id_tuple = (model_provider, model_name)
+
+    if id_tuple not in _TOKENIZER_CACHE:
+        if model_provider in [EmbeddingProvider.OPENAI, EmbeddingProvider.AZURE]:
+            if model_name is None:
+                raise ValueError(
+                    "model_name is required for OPENAI and AZURE embeddings"
+                )
+
+            _TOKENIZER_CACHE[id_tuple] = TiktokenTokenizer(model_name)
+            return _TOKENIZER_CACHE[id_tuple]
+
         try:
-            logger.debug(f"Initializing HuggingFaceTokenizer for: {tokenizer_name}")
-            _TOKENIZER_CACHE[tokenizer_name] = HuggingFaceTokenizer(tokenizer_name)
+            if model_name is None:
+                model_name = DOCUMENT_ENCODER_MODEL
+
+            logger.debug(f"Initializing HuggingFaceTokenizer for: {model_name}")
+            _TOKENIZER_CACHE[id_tuple] = HuggingFaceTokenizer(model_name)
         except Exception as primary_error:
             logger.error(
-                f"Error initializing HuggingFaceTokenizer for {tokenizer_name}: {primary_error}"
+                f"Error initializing HuggingFaceTokenizer for {model_name}: {primary_error}"
             )
             logger.warning(
                 f"Falling back to default embedding model: {DOCUMENT_ENCODER_MODEL}"
@@ -98,7 +119,7 @@ def _check_tokenizer_cache(tokenizer_name: str) -> BaseTokenizer:
         try:
             # Cache this tokenizer name to the default so we don't have to try to load it again
             # and fail again
-            _TOKENIZER_CACHE[tokenizer_name] = HuggingFaceTokenizer(
+            _TOKENIZER_CACHE[id_tuple] = HuggingFaceTokenizer(
                 DOCUMENT_ENCODER_MODEL
             )
         except Exception as fallback_error:
@@ -106,10 +127,10 @@ def _check_tokenizer_cache(tokenizer_name: str) -> BaseTokenizer:
             f"Error initializing fallback HuggingFaceTokenizer: {fallback_error}"
         )
         raise ValueError(
-            f"Failed to initialize tokenizer for {tokenizer_name} and fallback model"
+            f"Failed to initialize tokenizer for {model_name} and fallback model"
         ) from fallback_error
 
-    return _TOKENIZER_CACHE[tokenizer_name]
+    return _TOKENIZER_CACHE[id_tuple]
 
 
 _DEFAULT_TOKENIZER: BaseTokenizer = HuggingFaceTokenizer(DOCUMENT_ENCODER_MODEL)
@@ -118,11 +139,16 @@ def _check_tokenizer_cache(tokenizer_name: str) -> BaseTokenizer:
 def get_tokenizer(
     model_name: str | None, provider_type: EmbeddingProvider | str | None
 ) -> BaseTokenizer:
-    # Currently all of the viable models use the same sentencepiece tokenizer
-    # OpenAI uses a different one but currently it's not supported due to quality issues
-    # the inconsistent chunking makes using the sentencepiece tokenizer default better for now
-    # LLM tokenizers are specified by strings
-    global _DEFAULT_TOKENIZER
+    if provider_type is not None:
+        if isinstance(provider_type, str):
+            try:
+                provider_type = EmbeddingProvider(provider_type)
+            except ValueError:
+                logger.debug(
+                    f"Invalid provider_type '{provider_type}'. Falling back to default tokenizer."
+                )
+                return _DEFAULT_TOKENIZER
+        return _check_tokenizer_cache(provider_type, model_name)
     return _DEFAULT_TOKENIZER
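With the cache keyed by a (provider, model) tuple, OpenAI and Azure now route to TiktokenTokenizer, and an unknown provider string degrades to the sentencepiece default instead of raising. A rough usage sketch; the import paths are assumptions based on this tree's layout (EmbeddingProvider is taken to live in shared_configs.enums), and importing the module initializes the default HuggingFace tokenizer:

    from danswer.natural_language_processing.utils import get_tokenizer
    from shared_configs.enums import EmbeddingProvider

    # OpenAI routes to TiktokenTokenizer, cached per (provider, model) tuple
    tok = get_tokenizer(
        model_name="text-embedding-3-small",
        provider_type=EmbeddingProvider.OPENAI,
    )

    # Repeat lookups with the same tuple return the same cached instance
    assert tok is get_tokenizer(
        model_name="text-embedding-3-small",
        provider_type=EmbeddingProvider.OPENAI,
    )

    # An unrecognized provider string falls back to the default tokenizer
    fallback = get_tokenizer(model_name=None, provider_type="no-such-provider")
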
diff --git a/web/src/app/admin/embeddings/modals/ChangeCredentialsModal.tsx b/web/src/app/admin/embeddings/modals/ChangeCredentialsModal.tsx
index 044155fce3a..e8507a4fdef 100644
--- a/web/src/app/admin/embeddings/modals/ChangeCredentialsModal.tsx
+++ b/web/src/app/admin/embeddings/modals/ChangeCredentialsModal.tsx
@@ -11,6 +11,7 @@ import {
   LLM_PROVIDERS_ADMIN_URL,
 } from "../../configuration/llm/constants";
 import { mutate } from "swr";
+import { testEmbedding } from "../pages/utils";
 
 export function ChangeCredentialsModal({
   provider,
@@ -112,16 +113,15 @@ export function ChangeCredentialsModal({
   const normalizedProviderType = provider.provider_type
     .toLowerCase()
     .split(" ")[0];
+
   try {
-    const testResponse = await fetch("/api/admin/embedding/test-embedding", {
-      method: "POST",
-      headers: { "Content-Type": "application/json" },
-      body: JSON.stringify({
-        provider_type: normalizedProviderType,
-        api_key: apiKey,
-        api_url: apiUrl,
-        model_name: modelName,
-      }),
+    const testResponse = await testEmbedding({
+      provider_type: normalizedProviderType,
+      modelName,
+      apiKey,
+      apiUrl,
+      apiVersion: null,
+      deploymentName: null,
     });
 
     if (!testResponse.ok) {
diff --git a/web/src/app/admin/embeddings/modals/ProviderCreationModal.tsx b/web/src/app/admin/embeddings/modals/ProviderCreationModal.tsx
index 4ca22e2501e..a270d0498f8 100644
--- a/web/src/app/admin/embeddings/modals/ProviderCreationModal.tsx
+++ b/web/src/app/admin/embeddings/modals/ProviderCreationModal.tsx
@@ -110,20 +110,27 @@ export function ProviderCreationModal({
     setErrorMsg("");
     try {
       const customConfig = Object.fromEntries(values.custom_config);
+      const providerType = values.provider_type.toLowerCase().split(" ")[0];
+      const isOpenAI = providerType === "openai";
+
+      const testModelName =
+        isOpenAI || isAzure ? "text-embedding-3-small" : values.model_name;
+
+      const testEmbeddingPayload = {
+        provider_type: providerType,
+        api_key: values.api_key,
+        api_url: values.api_url,
+        model_name: testModelName,
+        api_version: values.api_version,
+        deployment_name: values.deployment_name,
+      };
 
       const initialResponse = await fetch(
         "/api/admin/embedding/test-embedding",
         {
           method: "POST",
           headers: { "Content-Type": "application/json" },
-          body: JSON.stringify({
-            provider_type: values.provider_type.toLowerCase().split(" ")[0],
-            api_key: values.api_key,
-            api_url: values.api_url,
-            model_name: values.model_name,
-            api_version: values.api_version,
-            deployment_name: values.deployment_name,
-          }),
+          body: JSON.stringify(testEmbeddingPayload),
         }
       );
"text-embedding-3-small" : modelName; + + const testResponse = await fetch("/api/admin/embedding/test-embedding", { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ + provider_type: provider_type, + api_key: apiKey, + api_url: apiUrl, + model_name: testModelName, + api_version: apiVersion, + deployment_name: deploymentName, + }), + }); + + return testResponse; +};