Skip to content

Commit

Permalink
refactor: refactored embedder and chroma client
Browse files Browse the repository at this point in the history
  • Loading branch information
umbertogriffo committed Sep 13, 2024
1 parent de6eda4 commit 0594b3f
Show file tree
Hide file tree
Showing 6 changed files with 228 additions and 280 deletions.
75 changes: 22 additions & 53 deletions chatbot/bot/memory/embedder.py
Original file line number Diff line number Diff line change
@@ -1,80 +1,49 @@
from abc import ABC, abstractmethod
from typing import Any

import sentence_transformers

class Embedder(ABC):
    """Abstract interface for text-embedding backends.

    Concrete implementations must map raw text to dense float vectors,
    both for whole documents (batch) and single queries.
    """

    @abstractmethod
    def embed_documents(self, texts: list[str]) -> list[list[float]]:
        """Embed search docs."""

    @abstractmethod
    def embed_query(self, text: str) -> list[float]:
        """Embed query text."""
class HuggingFaceEmbedder(Embedder):
"""HuggingFace sentence_transformers embedding models.
To use, you should have the ``sentence_transformers`` python package installed.
"""

client: Any #: :meta private:
model_name: str = "all-MiniLM-L6-v2"
"""Model name to use."""
cache_folder: str | None = None
"""Path to store models.
Can be also set by SENTENCE_TRANSFORMERS_HOME environment variable."""
model_kwargs: dict[str, Any] = {}
"""Keyword arguments to pass to the model."""
encode_kwargs: dict[str, Any] = {}
"""Keyword arguments to pass when calling the `encode` method of the model."""
multi_process: bool = False
"""Run encode() on multiple GPUs."""

def __init__(self, **kwargs: Any):
"""Initialize the sentence_transformer."""
super().__init__(**kwargs)
try:
import sentence_transformers

except ImportError as exc:
raise ImportError(
"Could not import sentence_transformers python package. "
"Please install it with `pip install sentence-transformers`."
) from exc
class Embedder:
def __init__(self, model_name: str = "all-MiniLM-L6-v2", cache_folder: str | None = None, **kwargs: Any):
"""
Initialize the Embedder class with the specified parameters.
self.client = sentence_transformers.SentenceTransformer(
self.model_name, cache_folder=self.cache_folder, **self.model_kwargs
)
Args:
**kwargs (Any): Additional keyword arguments to pass to the SentenceTransformer model.
"""
self.client = sentence_transformers.SentenceTransformer(model_name, cache_folder=cache_folder, **kwargs)

def embed_documents(self, texts: list[str]) -> list[list[float]]:
"""Compute doc embeddings using a HuggingFace transformer model.
def embed_documents(self, texts: list[str], multi_process: bool = False, **encode_kwargs: Any) -> list[list[float]]:
"""
Compute document embeddings using a transformer model.
Args:
texts: The list of texts to embed.
texts (list[str]): The list of texts to embed.
multi_process (bool): If True, use multiple processes to compute embeddings.
**encode_kwargs (Any): Additional keyword arguments to pass when calling the `encode` method of the model.
Returns:
List of embeddings, one for each text.
list[list[float]]: A list of embeddings, one for each text.
"""
import sentence_transformers

texts = list(map(lambda x: x.replace("\n", " "), texts))
if self.multi_process:
if multi_process:
pool = self.client.start_multi_process_pool()
embeddings = self.client.encode_multi_process(texts, pool)
sentence_transformers.SentenceTransformer.stop_multi_process_pool(pool)
else:
embeddings = self.client.encode(texts, **self.encode_kwargs)
embeddings = self.client.encode(texts, show_progress_bar=True, **encode_kwargs)

return embeddings.tolist()

def embed_query(self, text: str) -> list[float]:
"""Compute query embeddings using a HuggingFace transformer model.
"""
Compute query embeddings using a transformer model.
Args:
text: The text to embed.
text (str): The text to embed.
Returns:
Embeddings for the text.
list[float]: Embeddings for the text.
"""
return self.embed_documents([text])[0]
4 changes: 2 additions & 2 deletions chatbot/cli/rag_chatbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from bot.client.lama_cpp_client import LamaCppClient
from bot.conversation.conversation_retrieval import ConversationRetrieval
from bot.conversation.ctx_strategy import get_ctx_synthesis_strategies, get_ctx_synthesis_strategy
from bot.memory.embedder import HuggingFaceEmbedder
from bot.memory.embedder import Embedder
from bot.memory.vector_memory import VectorMemory
from bot.model.model_settings import get_model_setting, get_models
from helpers.log import get_logger
Expand Down Expand Up @@ -135,7 +135,7 @@ def main(parameters):

conversation = ConversationRetrieval(llm)

embedding = HuggingFaceEmbedder()
embedding = Embedder()
index = VectorMemory(vector_store_path=str(vector_store_path), embedding=embedding)

loop(conversation, synthesis_strategy, index, parameters)
Expand Down
7 changes: 4 additions & 3 deletions chatbot/memory_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
from pathlib import Path
from typing import List

from bot.memory.embedder import HuggingFaceEmbedder
from bot.memory.embedder import Embedder
from bot.memory.vector_memory import VectorMemory
from document_loader.format import Format
from document_loader.loader import DirectoryLoader
from document_loader.text_splitter import Format, create_recursive_text_splitter
from document_loader.text_splitter import create_recursive_text_splitter
from entities.document import Document
from helpers.log import get_logger

Expand Down Expand Up @@ -62,7 +63,7 @@ def build_memory_index(docs_path: Path, vector_store_path: str, chunk_size: int,
logger.info(f"Number of generated chunks: {len(chunks)}")

logger.info("Creating memory index...")
embedding = HuggingFaceEmbedder()
embedding = Embedder()
VectorMemory.create_memory_index(embedding, chunks, vector_store_path)
logger.info("Memory Index has been created successfully!")

Expand Down
4 changes: 2 additions & 2 deletions chatbot/rag_chatbot_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
get_ctx_synthesis_strategies,
get_ctx_synthesis_strategy,
)
from bot.memory.embedder import HuggingFaceEmbedder
from bot.memory.embedder import Embedder
from bot.memory.vector_memory import VectorMemory
from bot.model.model_settings import get_model_setting, get_models
from helpers.log import get_logger
Expand Down Expand Up @@ -51,7 +51,7 @@ def load_index(vector_store_path: Path) -> VectorMemory:
Returns:
VectorMemory: An instance of the VectorMemory class with the loaded index.
"""
embedding = HuggingFaceEmbedder()
embedding = Embedder()
index = VectorMemory(vector_store_path=str(vector_store_path), embedding=embedding)

return index
Expand Down
Loading

0 comments on commit 0594b3f

Please sign in to comment.