This repository has been archived by the owner on Jan 5, 2025. It is now read-only.

Adding fastembed with claude instant #407

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions llm-server/enums/embedding_provider.py
@@ -6,3 +6,5 @@ class EmbeddingProvider(Enum):
     BARD = "bard"
     azure = "azure"
     llama2 = "llama2"
+    openchat = "openchat"
+    fastembed = "fastembed"
8 changes: 5 additions & 3 deletions llm-server/requirements.txt
@@ -50,7 +50,7 @@ hpack==4.0.0
 httpcore==0.18.0
 httptools==0.6.1
 httpx==0.25.0
-huggingface-hub==0.17.3
+huggingface-hub==0.19.4
 humanfriendly==10.0
 Hypercorn==0.15.0
 hyperframe==6.0.1
@@ -143,7 +143,7 @@ sympy==1.12
 taskgroup==0.0.0a4
 tenacity==8.2.3
 tiktoken==0.5.1
-tokenizers==0.14.1
+tokenizers==0.15.0
 tomli==2.0.1
 tqdm==4.66.1
 trio==0.23.1
@@ -166,4 +166,6 @@ wrapt==1.16.0
 wsproto==1.2.0
 yarl==1.9.2
 zipp==3.17.0
-lxml==4.9.3
+lxml==4.9.3
+fastembed==0.1.3
+anthropic==0.7.8
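
The two new pins line up with the PR title: fastembed==0.1.3 backs the new embedding provider wired up below, and anthropic==0.7.8 is the SDK generation that exposes Claude Instant through the legacy completions API. A hedged sketch of what the anthropic pin enables (not part of this diff; the model name and prompt are illustrative):

# Hedged sketch, not part of this diff: calling Claude Instant through the
# legacy completions API that anthropic==0.7.8 ships.
from anthropic import AI_PROMPT, HUMAN_PROMPT, Anthropic

client = Anthropic()  # reads ANTHROPIC_API_KEY from the environment
completion = client.completions.create(
    model="claude-instant-1.2",  # illustrative model name
    max_tokens_to_sample=256,
    prompt=f"{HUMAN_PROMPT} Summarize what an embedding provider does.{AI_PROMPT}",
)
print(completion.completion)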
2 changes: 1 addition & 1 deletion llm-server/shared/utils/opencopilot_utils/embedding_type.py
@@ -7,4 +7,4 @@ class EmbeddingProvider(Enum):
     azure = "azure"
     llama2 = "llama2"
     openchat = "openchat"
-
+    fastembed = "fastembed"
25 changes: 18 additions & 7 deletions llm-server/shared/utils/opencopilot_utils/get_embeddings.py
@@ -6,20 +6,24 @@
 from .embedding_type import EmbeddingProvider
 from langchain.embeddings.base import Embeddings
 from utils.get_logger import CustomLogger
+from langchain.embeddings.fastembed import FastEmbedEmbeddings

 logger = CustomLogger(module_name=__name__)

 LOCAL_IP = os.getenv("LOCAL_IP", "host.docker.internal")

+
 @lru_cache(maxsize=1)
 def get_embeddings():
-    embedding_provider = os.environ.get("EMBEDDING_PROVIDER", EmbeddingProvider.OPENAI.value)
+    embedding_provider = os.environ.get(
+        "EMBEDDING_PROVIDER", EmbeddingProvider.OPENAI.value
+    )

     if embedding_provider == EmbeddingProvider.azure.value:
         deployment = os.environ.get("AZURE_OPENAI_EMBEDDING_MODEL_NAME")
         client = os.environ.get("AZURE_OPENAI_API_TYPE")

         # These keys should be set
         # os.environ["OPENAI_API_TYPE"] = "azure"
         # os.environ["OPENAI_API_BASE"] = "https://<your-endpoint.openai.azure.com/"
@@ -32,19 +36,26 @@ def get_embeddings():
             client=client,
             chunk_size=8,
         )
+
     elif embedding_provider == EmbeddingProvider.openchat.value:
         logger.info("Got ollama embedding provider", provider=embedding_provider)
         return OllamaEmbeddings(base_url=f"{LOCAL_IP}:11434", model="openchat")

-    elif embedding_provider == EmbeddingProvider.OPENAI.value or embedding_provider is None:
-
+    elif (
+        embedding_provider == EmbeddingProvider.OPENAI.value
+        or embedding_provider is None
+    ):
         if embedding_provider is None:
             warnings.warn("No embedding provider specified. Defaulting to OpenAI.")
         return OpenAIEmbeddings()
+    elif embedding_provider == EmbeddingProvider.fastembed.value:
+        return FastEmbedEmbeddings()

     else:
-        available_providers = ", ".join([service.value for service in EmbeddingProvider])
+        available_providers = ", ".join(
+            [service.value for service in EmbeddingProvider]
+        )
         raise ValueError(
             f"Embedding service '{embedding_provider}' is not currently available. "
             f"Available services: {available_providers}"
-        )
+        )
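
With the diff applied, the new branch can be exercised end to end. A minimal usage sketch, assuming the llm-server directory is on the import path; note that because get_embeddings() is wrapped in lru_cache(maxsize=1), the provider is resolved once per process, so EMBEDDING_PROVIDER must be set before the first call:

# Minimal usage sketch (assumes the llm-server root is on PYTHONPATH).
import os

os.environ["EMBEDDING_PROVIDER"] = "fastembed"  # must precede the first call

from shared.utils.opencopilot_utils.get_embeddings import get_embeddings

embeddings = get_embeddings()  # returns FastEmbedEmbeddings()
vector = embeddings.embed_query("What does this PR change?")
print(len(vector))  # 384 for FastEmbed's default BGE-small model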