This repository has been archived by the owner on Sep 12, 2024. It is now read-only.

feature: implement auto embedding #181

Merged 5 commits on Dec 22, 2023
98 changes: 98 additions & 0 deletions autollm/auto/embedding.py
@@ -0,0 +1,98 @@
import asyncio
from typing import Any, List

from litellm import embedding as lite_embedding
from llama_index.bridge.pydantic import Field
from llama_index.embeddings.base import BaseEmbedding, Embedding


class AutoEmbedding(BaseEmbedding):
"""
Custom embedding class for flexible and efficient text embedding.

This class interfaces with the LiteLLM library to use its embedding functionality, making it compatible
with a wide range of LLM models.
"""

# Define the model attribute using Pydantic's Field
model: str = Field(default="text-embedding-ada-002", description="The name of the embedding model.")

def __init__(self, model: str, **kwargs: Any) -> None:
"""
Initialize the AutoEmbedding with a specific model.

Args:
model (str): ID of the embedding model to use.
**kwargs (Any): Additional keyword arguments.
"""
super().__init__(**kwargs)
self.model = model # Set the model ID for embedding

def _get_query_embedding(self, query: str) -> Embedding:
"""
Synchronously get the embedding for a query string.

Args:
query (str): The query text to embed.

Returns:
Embedding: The embedding vector.
"""
response = lite_embedding(model=self.model, input=[query])
return self._parse_embedding_response(response)

async def _aget_query_embedding(self, query: str) -> Embedding:
"""
Asynchronously get the embedding for a query string.

Args:
query (str): The query text to embed.

Returns:
Embedding: The embedding vector.
"""
response = await asyncio.to_thread(lite_embedding, model=self.model, input=[query])
return self._parse_embedding_response(response)

def _get_text_embedding(self, text: str) -> Embedding:
"""
Synchronously get the embedding for a text string.

Args:
text (str): The text to embed.

Returns:
Embedding: The embedding vector.
"""
return self._get_query_embedding(text)

async def _aget_text_embedding(self, text: str) -> Embedding:
"""
Asynchronously get the embedding for a text string.

Args:
text (str): The text to embed.

Returns:
Embedding: The embedding vector.
"""
return await self._aget_query_embedding(text)

def _parse_embedding_response(self, response: Any) -> Embedding:
"""
Parse the embedding response from LiteLLM and extract the embedding data.

Args:
response: The response object from LiteLLM's embedding function.

Returns:
List[float]: The extracted embedding list.
"""
try:
if 'data' in response and len(response['data']) > 0 and 'embedding' in response['data'][0]:
return response['data'][0]['embedding']
else:
raise ValueError("Invalid response structure from embedding function.")
except (TypeError, KeyError, IndexError) as e:
# Handle any parsing errors
raise ValueError(f"Error parsing embedding response: {e}")
4 changes: 2 additions & 2 deletions autollm/auto/llm.py
@@ -1,7 +1,7 @@
from typing import Optional

from llama_index.llms import LiteLLM
-from llama_index.llms.base import LLM
+from llama_index.llms.base import BaseLLM


class AutoLiteLLM:
@@ -14,7 +14,7 @@ def from_defaults(
model: str = "gpt-3.5-turbo",
max_tokens: Optional[int] = 256,
temperature: float = 0.1,
-api_base: Optional[str] = None) -> LLM:
+api_base: Optional[str] = None) -> BaseLLM:
"""
Create any LLM by model name. Check https://docs.litellm.ai/docs/providers for a list of
supported models.
19 changes: 11 additions & 8 deletions autollm/auto/query_engine.py
@@ -8,6 +8,7 @@
from llama_index.response_synthesizers import get_response_synthesizer
from llama_index.schema import BaseNode

from autollm.auto.embedding import AutoEmbedding
from autollm.auto.llm import AutoLiteLLM
from autollm.auto.service_context import AutoServiceContext
from autollm.auto.vector_store_index import AutoVectorStoreIndex
@@ -26,7 +27,7 @@ def create_query_engine(
system_prompt: str = None,
query_wrapper_prompt: Union[str, BasePromptTemplate] = None,
enable_cost_calculator: bool = True,
-embed_model: Union[str, EmbedType] = "default",  # ["default", "local"]
+embed_model: Optional[str] = "text-embedding-ada-002",
chunk_size: Optional[int] = 512,
chunk_overlap: Optional[int] = 100,
context_window: Optional[int] = None,
@@ -106,9 +107,12 @@ def create_query_engine(

llm = AutoLiteLLM.from_defaults(
model=llm_model, api_base=llm_api_base, max_tokens=llm_max_tokens, temperature=llm_temperature)

embedding = AutoEmbedding(model=embed_model)

service_context = AutoServiceContext.from_defaults(
llm=llm,
-embed_model=embed_model,
+embed_model=embedding,
system_prompt=system_prompt,
query_wrapper_prompt=query_wrapper_prompt,
enable_cost_calculator=enable_cost_calculator,
@@ -173,7 +177,7 @@ class AutoQueryEngine:
system_prompt=None,
query_wrapper_prompt=None,
enable_cost_calculator=True,
embed_model="default", # ["default", "local"]
embed_model="text-embedding-ada-002",
chunk_size=512,
chunk_overlap=None,
context_window=None,
@@ -212,15 +216,15 @@ def from_defaults(
documents: Optional[Sequence[Document]] = None,
nodes: Optional[Sequence[BaseNode]] = None,
# llm_params
-llm_model: str = "gpt-3.5-turbo",
+llm_model: Optional[str] = "gpt-3.5-turbo",
llm_api_base: Optional[str] = None,
llm_max_tokens: Optional[int] = None,
-llm_temperature: float = 0.1,
+llm_temperature: Optional[float] = 0.1,
# service_context_params
system_prompt: str = None,
query_wrapper_prompt: Union[str, BasePromptTemplate] = None,
enable_cost_calculator: bool = True,
-embed_model: Union[str, EmbedType] = "default",  # ["default", "local"]
+embed_model: Optional[str] = "text-embedding-ada-002",
chunk_size: Optional[int] = 512,
chunk_overlap: Optional[int] = 200,
context_window: Optional[int] = None,
@@ -253,8 +257,7 @@ def from_defaults(
system_prompt (str): The system prompt to use for the query engine.
query_wrapper_prompt (Union[str, BasePromptTemplate]): The query wrapper prompt to use for the query engine.
enable_cost_calculator (bool): Flag to enable cost calculator logging.
-embed_model (Union[str, EmbedType]): The embedding model to use for generating embeddings. "default" for OpenAI,
-    "local" for HuggingFace or use full identifier (e.g., local:intfloat/multilingual-e5-large)
+embed_model (str): The LiteLLM embedding model ID to use for generating embeddings (e.g., "text-embedding-ada-002").
chunk_size (int): The token chunk size for each chunk.
chunk_overlap (int): The token overlap between each chunk.
context_window (int): The maximum context size that will get sent to the LLM.
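To illustrate the new wiring end to end, a hedged sketch (not part of the diff; assumes an OPENAI_API_KEY and a throwaway in-memory document): the embed_model string is now forwarded to AutoEmbedding(model=...) rather than resolved through llama-index's EmbedType.

from llama_index import Document

from autollm.auto.query_engine import AutoQueryEngine

query_engine = AutoQueryEngine.from_defaults(
    documents=[Document(text="autollm now routes embeddings through LiteLLM.")],
    llm_model="gpt-3.5-turbo",
    embed_model="text-embedding-ada-002",  # forwarded to AutoEmbedding(model=...)
)
print(query_engine.query("How are embeddings generated?"))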
2 changes: 1 addition & 1 deletion autollm/auto/service_context.py
Original file line number Diff line number Diff line change
@@ -32,7 +32,7 @@ def from_defaults(
query_wrapper_prompt: Union[str, BasePromptTemplate] = None,
enable_cost_calculator: bool = False,
chunk_size: Optional[int] = 512,
-chunk_overlap: Optional[int] = 200,
+chunk_overlap: Optional[int] = 100,
context_window: Optional[int] = None,
enable_title_extractor: bool = False,
enable_summary_extractor: bool = False,
4 changes: 2 additions & 2 deletions requirements.txt
@@ -1,5 +1,5 @@
-llama-index==0.9.10
-litellm==1.1.1
+llama-index==0.9.21
+litellm==1.15.6
uvicorn
fastapi
python-dotenv
2 changes: 1 addition & 1 deletion tests/config.yaml
@@ -6,7 +6,7 @@ tasks:
llm_temperature: 0.1
system_prompt: "You are a friendly chatbot that can summarize documents." # System prompt for this task
enable_cost_calculator: true
embed_model: "default"
embed_model: "text-embedding-ada-002"
chunk_size: 512
chunk_overlap: 64
context_window: 2048
4 changes: 2 additions & 2 deletions tests/test_auto_lite_llm.py
@@ -1,5 +1,5 @@
from llama_index import Document, ServiceContext, VectorStoreIndex
-from llama_index.llms.base import LLM
+from llama_index.llms.base import BaseLLM
from llama_index.query_engine import BaseQueryEngine

from autollm.auto.llm import AutoLiteLLM
@@ -11,7 +11,7 @@ def test_auto_lite_llm():
llm = AutoLiteLLM.from_defaults(model="gpt-3.5-turbo")

# Check if the llm is an instance of LLM
-assert isinstance(llm, LLM)
+assert isinstance(llm, BaseLLM)

service_context = ServiceContext.from_defaults(llm=llm)

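A hypothetical companion test for the new class, mirroring test_auto_lite_llm above (not included in this PR, and it would call the live embedding API):

from llama_index.embeddings.base import BaseEmbedding

from autollm.auto.embedding import AutoEmbedding


def test_auto_embedding():
    embedding = AutoEmbedding(model="text-embedding-ada-002")

    # AutoEmbedding should satisfy llama-index's embedding interface
    assert isinstance(embedding, BaseEmbedding)

    vector = embedding.get_query_embedding("hello world")
    assert isinstance(vector, list) and len(vector) > 0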