Skip to content

Commit

Permalink
Reenable OpenAI Tokenizer (#3062)
Browse files Browse the repository at this point in the history
* k

* clean up test embeddings

* nit

* minor update to ensure consistency

* minor organizational update

* minor updates

---------

Co-authored-by: pablodanswer <pablo@danswer.ai>
  • Loading branch information
yuhongsun96 and pablonyx authored Nov 8, 2024
1 parent 2bbc5d5 commit 4fb65dc
Show file tree
Hide file tree
Showing 4 changed files with 109 additions and 42 deletions.
76 changes: 51 additions & 25 deletions backend/danswer/natural_language_processing/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,23 +35,31 @@ def decode(self, tokens: list[int]) -> str:
class TiktokenTokenizer(BaseTokenizer):
_instances: dict[str, "TiktokenTokenizer"] = {}

def __new__(cls, encoding_name: str = "cl100k_base") -> "TiktokenTokenizer":
if encoding_name not in cls._instances:
cls._instances[encoding_name] = super(TiktokenTokenizer, cls).__new__(cls)
return cls._instances[encoding_name]
def __new__(cls, model_name: str) -> "TiktokenTokenizer":
if model_name not in cls._instances:
cls._instances[model_name] = super(TiktokenTokenizer, cls).__new__(cls)
return cls._instances[model_name]

def __init__(self, encoding_name: str = "cl100k_base"):
def __init__(self, model_name: str):
if not hasattr(self, "encoder"):
import tiktoken

self.encoder = tiktoken.get_encoding(encoding_name)
self.encoder = tiktoken.encoding_for_model(model_name)

def encode(self, string: str) -> list[int]:
# this returns no special tokens
# this ignores special tokens that the model is trained on, see encode_ordinary for details
return self.encoder.encode_ordinary(string)

def tokenize(self, string: str) -> list[str]:
return [self.encoder.decode([token]) for token in self.encode(string)]
encoded = self.encode(string)
decoded = [self.encoder.decode([token]) for token in encoded]

if len(decoded) != len(encoded):
logger.warning(
f"OpenAI tokenized length {len(decoded)} does not match encoded length {len(encoded)} for string: {string}"
)

return decoded

def decode(self, tokens: list[int]) -> str:
return self.encoder.decode(tokens)
Expand All @@ -74,22 +82,35 @@ def decode(self, tokens: list[int]) -> str:
return self.encoder.decode(tokens)


_TOKENIZER_CACHE: dict[str, BaseTokenizer] = {}
_TOKENIZER_CACHE: dict[tuple[EmbeddingProvider | None, str | None], BaseTokenizer] = {}


def _check_tokenizer_cache(tokenizer_name: str) -> BaseTokenizer:
def _check_tokenizer_cache(
model_provider: EmbeddingProvider | None, model_name: str | None
) -> BaseTokenizer:
global _TOKENIZER_CACHE

if tokenizer_name not in _TOKENIZER_CACHE:
if tokenizer_name == "openai":
_TOKENIZER_CACHE[tokenizer_name] = TiktokenTokenizer("cl100k_base")
return _TOKENIZER_CACHE[tokenizer_name]
id_tuple = (model_provider, model_name)

if id_tuple not in _TOKENIZER_CACHE:
if model_provider in [EmbeddingProvider.OPENAI, EmbeddingProvider.AZURE]:
if model_name is None:
raise ValueError(
"model_name is required for OPENAI and AZURE embeddings"
)

_TOKENIZER_CACHE[id_tuple] = TiktokenTokenizer(model_name)
return _TOKENIZER_CACHE[id_tuple]

try:
logger.debug(f"Initializing HuggingFaceTokenizer for: {tokenizer_name}")
_TOKENIZER_CACHE[tokenizer_name] = HuggingFaceTokenizer(tokenizer_name)
if model_name is None:
model_name = DOCUMENT_ENCODER_MODEL

logger.debug(f"Initializing HuggingFaceTokenizer for: {model_name}")
_TOKENIZER_CACHE[id_tuple] = HuggingFaceTokenizer(model_name)
except Exception as primary_error:
logger.error(
f"Error initializing HuggingFaceTokenizer for {tokenizer_name}: {primary_error}"
f"Error initializing HuggingFaceTokenizer for {model_name}: {primary_error}"
)
logger.warning(
f"Falling back to default embedding model: {DOCUMENT_ENCODER_MODEL}"
Expand All @@ -98,18 +119,18 @@ def _check_tokenizer_cache(tokenizer_name: str) -> BaseTokenizer:
try:
# Cache this tokenizer name to the default so we don't have to try to load it again
# and fail again
_TOKENIZER_CACHE[tokenizer_name] = HuggingFaceTokenizer(
_TOKENIZER_CACHE[id_tuple] = HuggingFaceTokenizer(
DOCUMENT_ENCODER_MODEL
)
except Exception as fallback_error:
logger.error(
f"Error initializing fallback HuggingFaceTokenizer: {fallback_error}"
)
raise ValueError(
f"Failed to initialize tokenizer for {tokenizer_name} and fallback model"
f"Failed to initialize tokenizer for {model_name} and fallback model"
) from fallback_error

return _TOKENIZER_CACHE[tokenizer_name]
return _TOKENIZER_CACHE[id_tuple]


_DEFAULT_TOKENIZER: BaseTokenizer = HuggingFaceTokenizer(DOCUMENT_ENCODER_MODEL)
Expand All @@ -118,11 +139,16 @@ def _check_tokenizer_cache(tokenizer_name: str) -> BaseTokenizer:
def get_tokenizer(
model_name: str | None, provider_type: EmbeddingProvider | str | None
) -> BaseTokenizer:
# Currently all of the viable models use the same sentencepiece tokenizer
# OpenAI uses a different one but currently it's not supported due to quality issues
# the inconsistent chunking makes using the sentencepiece tokenizer default better for now
# LLM tokenizers are specified by strings
global _DEFAULT_TOKENIZER
if provider_type is not None:
if isinstance(provider_type, str):
try:
provider_type = EmbeddingProvider(provider_type)
except ValueError:
logger.debug(
f"Invalid provider_type '{provider_type}'. Falling back to default tokenizer."
)
return _DEFAULT_TOKENIZER
return _check_tokenizer_cache(provider_type, model_name)
return _DEFAULT_TOKENIZER


Expand Down
18 changes: 9 additions & 9 deletions web/src/app/admin/embeddings/modals/ChangeCredentialsModal.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import {
LLM_PROVIDERS_ADMIN_URL,
} from "../../configuration/llm/constants";
import { mutate } from "swr";
import { testEmbedding } from "../pages/utils";

export function ChangeCredentialsModal({
provider,
Expand Down Expand Up @@ -112,16 +113,15 @@ export function ChangeCredentialsModal({
const normalizedProviderType = provider.provider_type
.toLowerCase()
.split(" ")[0];

try {
const testResponse = await fetch("/api/admin/embedding/test-embedding", {
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({
provider_type: normalizedProviderType,
api_key: apiKey,
api_url: apiUrl,
model_name: modelName,
}),
const testResponse = await testEmbedding({
provider_type: normalizedProviderType,
modelName,
apiKey,
apiUrl,
apiVersion: null,
deploymentName: null,
});

if (!testResponse.ok) {
Expand Down
23 changes: 15 additions & 8 deletions web/src/app/admin/embeddings/modals/ProviderCreationModal.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -110,20 +110,27 @@ export function ProviderCreationModal({
setErrorMsg("");
try {
const customConfig = Object.fromEntries(values.custom_config);
const providerType = values.provider_type.toLowerCase().split(" ")[0];
const isOpenAI = providerType === "openai";

const testModelName =
isOpenAI || isAzure ? "text-embedding-3-small" : values.model_name;

const testEmbeddingPayload = {
provider_type: providerType,
api_key: values.api_key,
api_url: values.api_url,
model_name: testModelName,
api_version: values.api_version,
deployment_name: values.deployment_name,
};

const initialResponse = await fetch(
"/api/admin/embedding/test-embedding",
{
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({
provider_type: values.provider_type.toLowerCase().split(" ")[0],
api_key: values.api_key,
api_url: values.api_url,
model_name: values.model_name,
api_version: values.api_version,
deployment_name: values.deployment_name,
}),
body: JSON.stringify(testEmbeddingPayload),
}
);

Expand Down
34 changes: 34 additions & 0 deletions web/src/app/admin/embeddings/pages/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,37 @@ export const deleteSearchSettings = async (search_settings_id: number) => {
});
return response;
};

export const testEmbedding = async ({
provider_type,
modelName,
apiKey,
apiUrl,
apiVersion,
deploymentName,
}: {
provider_type: string;
modelName: string;
apiKey: string | null;
apiUrl: string | null;
apiVersion: string | null;
deploymentName: string | null;
}) => {
const testModelName =
provider_type === "openai" ? "text-embedding-3-small" : modelName;

const testResponse = await fetch("/api/admin/embedding/test-embedding", {
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({
provider_type: provider_type,
api_key: apiKey,
api_url: apiUrl,
model_name: testModelName,
api_version: apiVersion,
deployment_name: deploymentName,
}),
});

return testResponse;
};

0 comments on commit 4fb65dc

Please sign in to comment.