This repository has been archived by the owner on Sep 12, 2024. It is now read-only.

Commit 708b0dc: update dependencies and configurations
fcakyon committed Nov 19, 2023
1 parent 26e011e
Showing 9 changed files with 162 additions and 103 deletions.
8 changes: 6 additions & 2 deletions README.md
@@ -332,10 +332,14 @@ switching from Llama-Index? We've got you covered.

>>> vector_store = LanceDBVectorStore(uri="./.lancedb")
>>> storage_context = StorageContext.from_defaults(vector_store=vector_store)
>>> index = VectorStoreIndex.from_documents(documents=documents)
>>> service_context = ServiceContext.from_defaults()
>>> index = VectorStoreIndex.from_documents(
documents=documents,
storage_context=storage_context,
service_context=service_context,
)

>>> query_engine = AutoQueryEngine.from_instances(index, service_context)
>>> query_engine = AutoQueryEngine.from_instances(index)
```

</details>
2 changes: 1 addition & 1 deletion autollm/__init__.py
@@ -4,7 +4,7 @@
and vector databases, along with various utility functions.
"""

__version__ = '0.1.0.dev1'
__version__ = '0.1.0.dev2'
__author__ = 'safevideo'
__license__ = 'AGPL-3.0'

127 changes: 86 additions & 41 deletions autollm/auto/query_engine.py
@@ -4,6 +4,9 @@
from llama_index.embeddings.utils import EmbedType
from llama_index.indices.query.base import BaseQueryEngine
from llama_index.schema import BaseNode
from llama_index.response_synthesizers import get_response_synthesizer
from llama_index.prompts.base import PromptTemplate
from llama_index.prompts.prompt_type import PromptType

from autollm.auto.llm import AutoLiteLLM
from autollm.auto.service_context import AutoServiceContext
@@ -25,15 +28,22 @@ def create_query_engine(
enable_cost_calculator: bool = True,
embed_model: Union[str, EmbedType] = "default", # ["default", "local"]
chunk_size: Optional[int] = 512,
chunk_overlap: Optional[int] = None,
chunk_overlap: Optional[int] = 200,
context_window: Optional[int] = None,
enable_title_extractor: bool = False,
enable_summary_extractor: bool = False,
enable_qa_extractor: bool = False,
enable_keyword_extractor: bool = False,
enable_entity_extractor: bool = False,
# query_engine_params
similarity_top_k: int = 6,
response_mode: str = "compact",
refine_prompt: str = None,
structured_answer_filtering: bool = False,
# vector_store_params
vector_store_type: str = "LanceDBVectorStore",
lancedb_uri: str = "./.lancedb",
lancedb_table_name: str = "vectors",
enable_metadata_extraction: bool = False,
# Deprecated parameters
llm_params: dict = None,
vector_store_params: dict = None,
@@ -45,15 +55,30 @@
Parameters:
documents (Sequence[Document]): Sequence of llama_index.Document instances.
nodes (Sequence[BaseNode]): Sequence of llama_index.BaseNode instances.
llm_model (str): The LLM model to use for the query engine.
llm_max_tokens (int): The maximum number of tokens to be generated as LLM output.
llm_temperature (float): The temperature to use for the LLM.
llm_api_base (str): The API base to use for the LLM.
system_prompt (str): The system prompt to use for the query engine.
query_wrapper_prompt (str): The query wrapper prompt to use for the query engine.
enable_cost_calculator (bool): Flag to enable cost calculator logging.
embed_model (Union[str, EmbedType]): The embedding model to use for generating embeddings. "default" for OpenAI,
"local" for HuggingFace or use full identifier (e.g., local:intfloat/multilingual-e5-large)
llm_params (dict): Parameters for the LLM.
vector_store_params (dict): Parameters for the vector store.
service_context_params (dict): Parameters for the service context.
query_engine_params (dict): Parameters for the query engine.
chunk_size (int): The token chunk size for each chunk.
chunk_overlap (int): The token overlap between each chunk.
context_window (int): The maximum context size that will get sent to the LLM.
enable_title_extractor (bool): Flag to enable title extractor.
enable_summary_extractor (bool): Flag to enable summary extractor.
enable_qa_extractor (bool): Flag to enable question answering extractor.
enable_keyword_extractor (bool): Flag to enable keyword extractor.
enable_entity_extractor (bool): Flag to enable entity extractor.
similarity_top_k (int): The number of similar documents to return.
response_mode (str): The response mode to use for the query engine.
refine_prompt (str): The refine prompt to use for the query engine.
vector_store_type (str): The vector store type to use for the query engine.
lancedb_uri (str): The URI to use for the LanceDB vector store.
lancedb_table_name (str): The table name to use for the LanceDB vector store.
Returns:
A llama_index.BaseQueryEngine instance.
@@ -89,18 +114,35 @@ def create_query_engine(
enable_cost_calculator=enable_cost_calculator,
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
context_window=context_window)
context_window=context_window,
enable_title_extractor=enable_title_extractor,
enable_summary_extractor=enable_summary_extractor,
enable_qa_extractor=enable_qa_extractor,
enable_keyword_extractor=enable_keyword_extractor,
enable_entity_extractor=enable_entity_extractor,
)
vector_store_index = AutoVectorStoreIndex.from_defaults(
vector_store_type=vector_store_type,
lancedb_uri=lancedb_uri,
lancedb_table_name=lancedb_table_name,
enable_metadata_extraction=enable_metadata_extraction,
documents=documents,
nodes=nodes,
service_context=service_context,
**vector_store_kwargs)
if refine_prompt is not None:
refine_prompt_template = PromptTemplate(
refine_prompt, prompt_type=PromptType.REFINE
)
else:
refine_prompt_template = None
response_synthesizer = get_response_synthesizer(
service_context=service_context,
response_mode=response_mode,
refine_template=refine_prompt_template,
structured_answer_filtering=structured_answer_filtering
)

return vector_store_index.as_query_engine(similarity_top_k=similarity_top_k)
return vector_store_index.as_query_engine(similarity_top_k=similarity_top_k, response_synthesizer=response_synthesizer)
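The new `refine_prompt`, `response_mode`, and `structured_answer_filtering` arguments are forwarded to llama_index's response synthesizer. A minimal standalone sketch of that wiring, assuming the llama_index 0.9-era imports used above (the refine prompt string itself is illustrative, not from this commit):

```python
from llama_index.prompts.base import PromptTemplate
from llama_index.prompts.prompt_type import PromptType
from llama_index.response_synthesizers import get_response_synthesizer

# Illustrative refine prompt; the placeholders match llama_index's default
# refine template variables.
refine_prompt = (
    "The original query is: {query_str}\n"
    "We have an existing answer: {existing_answer}\n"
    "Refine the existing answer using the new context below if it helps.\n"
    "------------\n"
    "{context_msg}\n"
    "------------\n"
)

# Same call pattern as in the updated create_query_engine above.
response_synthesizer = get_response_synthesizer(
    response_mode="compact",
    refine_template=PromptTemplate(refine_prompt, prompt_type=PromptType.REFINE),
    structured_answer_filtering=False,
)
```

The synthesizer is then handed to `as_query_engine(...)` alongside `similarity_top_k`, as the return statement above shows.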


class AutoQueryEngine:
@@ -177,10 +219,13 @@ def from_defaults(
enable_cost_calculator: bool = True,
embed_model: Union[str, EmbedType] = "default", # ["default", "local"]
chunk_size: Optional[int] = 512,
chunk_overlap: Optional[int] = None,
chunk_overlap: Optional[int] = 200,
context_window: Optional[int] = None,
# query_engine_params
similarity_top_k: int = 6,
response_mode: str = "compact",
refine_prompt: str = None,
structured_answer_filtering: bool = False,
# vector_store_params
vector_store_type: str = "LanceDBVectorStore",
lancedb_uri: str = "./.lancedb",
@@ -195,17 +240,32 @@
"""
Create an AutoQueryEngine from default parameters.
Parameters:
documents (Sequence[Document]): Sequence of llama_index.Document instances.
system_prompt (str): The system prompt to use for the query engine.
query_wrapper_prompt (str): The query wrapper prompt to use for the query engine.
enable_cost_calculator (bool): Flag to enable cost calculator logging.
embed_model (Union[str, EmbedType]): The embedding model to use for generating embeddings. "default" for OpenAI,
"local" for HuggingFace or use full identifier (e.g., local:intfloat/multilingual-e5-large)
llm_params (dict): Parameters for the LLM.
vector_store_params (dict): Parameters for the vector store.
service_context_params (dict): Parameters for the service context.
query_engine_params (dict): Parameters for the query engine.
Parameters:
documents (Sequence[Document]): Sequence of llama_index.Document instances.
nodes (Sequence[BaseNode]): Sequence of llama_index.BaseNode instances.
llm_model (str): The LLM model to use for the query engine.
llm_max_tokens (int): The maximum number of tokens to be generated as LLM output.
llm_temperature (float): The temperature to use for the LLM.
llm_api_base (str): The API base to use for the LLM.
system_prompt (str): The system prompt to use for the query engine.
query_wrapper_prompt (str): The query wrapper prompt to use for the query engine.
enable_cost_calculator (bool): Flag to enable cost calculator logging.
embed_model (Union[str, EmbedType]): The embedding model to use for generating embeddings. "default" for OpenAI,
"local" for HuggingFace or use full identifier (e.g., local:intfloat/multilingual-e5-large)
chunk_size (int): The token chunk size for each chunk.
chunk_overlap (int): The token overlap between each chunk.
context_window (int): The maximum context size that will get sent to the LLM.
enable_title_extractor (bool): Flag to enable title extractor.
enable_summary_extractor (bool): Flag to enable summary extractor.
enable_qa_extractor (bool): Flag to enable question answering extractor.
enable_keyword_extractor (bool): Flag to enable keyword extractor.
enable_entity_extractor (bool): Flag to enable entity extractor.
similarity_top_k (int): The number of similar documents to return.
response_mode (str): The response mode to use for the query engine.
refine_prompt (str): The refine prompt to use for the query engine.
vector_store_type (str): The vector store type to use for the query engine.
lancedb_uri (str): The URI to use for the LanceDB vector store.
lancedb_table_name (str): The table name to use for the LanceDB vector store.
Returns:
A llama_index.BaseQueryEngine instance.
@@ -229,6 +289,9 @@ def from_defaults(
context_window=context_window,
# query_engine_params
similarity_top_k=similarity_top_k,
response_mode=response_mode,
refine_prompt=refine_prompt,
structured_answer_filtering=structured_answer_filtering,
# vector_store_params
vector_store_type=vector_store_type,
lancedb_uri=lancedb_uri,
@@ -302,23 +365,5 @@ def from_config(
return create_query_engine(
documents=documents,
nodes=nodes,
llm_model=config.get('llm_model'),
llm_api_base=config.get('llm_api_base'),
llm_max_tokens=config.get('llm_max_tokens'),
llm_temperature=config.get('llm_temperature'),
system_prompt=config.get('system_prompt'),
query_wrapper_prompt=config.get('query_wrapper_prompt'),
enable_cost_calculator=config.get('enable_cost_calculator'),
embed_model=config.get('embed_model'),
chunk_size=config.get('chunk_size'),
chunk_overlap=config.get('chunk_overlap'),
context_window=config.get('context_window'),
similarity_top_k=config.get('similarity_top_k'),
vector_store_type=config.get('vector_store_type'),
lancedb_uri=config.get('lancedb_uri'),
lancedb_table_name=config.get('lancedb_table_name'),
enable_metadata_extraction=config.get('enable_metadata_extraction'),
llm_params=config.get('llm_params'),
vector_store_params=config.get('vector_store_params'),
service_context_params=config.get('service_context_params'),
**config.get('vector_store_kwargs', {}))
**config,
)
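Taken together, a hedged end-to-end sketch of the updated public API. The document text, model choice, and the top-level `from autollm import AutoQueryEngine` re-export are assumptions for illustration; the keyword arguments come from the `from_defaults` docstring above:

```python
from llama_index import Document

from autollm import AutoQueryEngine  # assumed top-level re-export

# Illustrative documents; in practice these come from a document reader.
documents = [Document(text="autollm is maintained by SafeVideo.")]

query_engine = AutoQueryEngine.from_defaults(
    documents=documents,
    llm_model="gpt-3.5-turbo",
    chunk_size=512,
    chunk_overlap=200,
    similarity_top_k=6,
    response_mode="compact",
    structured_answer_filtering=False,
    vector_store_type="LanceDBVectorStore",
    lancedb_uri="./.lancedb",
    lancedb_table_name="vectors",
)

print(query_engine.query("Who maintains autollm?"))
```

With `from_config` now forwarding the whole config via `**config`, any of these keyword arguments can equally be supplied as keys in the config file.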
32 changes: 29 additions & 3 deletions autollm/auto/service_context.py
@@ -6,6 +6,8 @@
from llama_index.llms.utils import LLMType
from llama_index.prompts import PromptTemplate
from llama_index.prompts.base import BasePromptTemplate
from llama_index.text_splitter import SentenceSplitter
from llama_index.extractors import TitleExtractor, SummaryExtractor, QuestionsAnsweredExtractor, KeywordExtractor, EntityExtractor

from autollm.callbacks.cost_calculating import CostCalculatingHandler
from autollm.utils.llm_utils import set_default_prompt_template
@@ -24,8 +26,13 @@ def from_defaults(
query_wrapper_prompt: Union[str, BasePromptTemplate] = None,
enable_cost_calculator: bool = False,
chunk_size: Optional[int] = 512,
chunk_overlap: Optional[int] = None,
chunk_overlap: Optional[int] = 200,
context_window: Optional[int] = None,
enable_title_extractor: bool = False,
enable_summary_extractor: bool = False,
enable_qa_extractor: bool = False,
enable_keyword_extractor: bool = False,
enable_entity_extractor: bool = False,
**kwargs) -> ServiceContext:
"""
Create a ServiceContext with default parameters with extended enable_token_counting functionality. If
@@ -40,6 +47,11 @@
chunk_size (int): The token chunk size for each chunk.
chunk_overlap (int): The token overlap between each chunk.
context_window (int): The maximum context size that will get sent to the LLM.
enable_title_extractor (bool): Flag to enable title extractor.
enable_summary_extractor (bool): Flag to enable summary extractor.
enable_qa_extractor (bool): Flag to enable question answering extractor.
enable_keyword_extractor (bool): Flag to enable keyword extractor.
enable_entity_extractor (bool): Flag to enable entity extractor.
**kwargs: Arbitrary keyword arguments.
Returns:
@@ -53,12 +65,26 @@

callback_manager: CallbackManager = kwargs.get('callback_manager', CallbackManager())
if enable_cost_calculator:
model = llm.metadata.model_name if not "default" else "gpt-3.5-turbo"
callback_manager.add_handler(CostCalculatingHandler(model=model, verbose=True))
llm_model_name = llm.metadata.model_name if not "default" else "gpt-3.5-turbo"
callback_manager.add_handler(CostCalculatingHandler(model_name=llm_model_name, verbose=True))

sentence_splitter = SentenceSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
transformations = [sentence_splitter]
if enable_entity_extractor:
transformations.append(EntityExtractor())
if enable_keyword_extractor:
transformations.append(KeywordExtractor(llm=llm, keywords=5))
if enable_summary_extractor:
transformations.append(SummaryExtractor(llm=llm, summaries=["prev", "self"]))
if enable_title_extractor:
transformations.append(TitleExtractor(llm=llm, nodes=5))
if enable_qa_extractor:
transformations.append(QuestionsAnsweredExtractor(llm=llm, questions=5))

service_context = ServiceContext.from_defaults(
llm=llm,
embed_model=embed_model,
transformations=transformations,
system_prompt=system_prompt,
query_wrapper_prompt=query_wrapper_prompt,
chunk_size=chunk_size,
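A minimal usage sketch for the extended `AutoServiceContext.from_defaults`, assuming default LLM and embedding settings; only parameters visible in the signature above are used, and which extractors to enable is illustrative:

```python
from autollm.auto.service_context import AutoServiceContext

service_context = AutoServiceContext.from_defaults(
    enable_cost_calculator=True,     # adds the CostCalculatingHandler callback
    chunk_size=512,
    chunk_overlap=200,               # matches the new default
    enable_title_extractor=True,     # appends TitleExtractor to transformations
    enable_qa_extractor=True,        # appends QuestionsAnsweredExtractor
)
```

Each enabled flag appends the corresponding extractor to the `SentenceSplitter`-led transformations list passed to `ServiceContext.from_defaults`.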
44 changes: 5 additions & 39 deletions autollm/auto/vector_store_index.py
@@ -1,15 +1,6 @@
from typing import Optional, Sequence

from llama_index import Document, ServiceContext, StorageContext, VectorStoreIndex
from llama_index.node_parser import SimpleNodeParser
from llama_index.node_parser.extractors import (
EntityExtractor,
KeywordExtractor,
MetadataExtractor,
QuestionsAnsweredExtractor,
SummaryExtractor,
TitleExtractor,
)
from llama_index.schema import BaseNode


@@ -36,7 +27,6 @@ def from_defaults(
vector_store_type: str = "LanceDBVectorStore",
lancedb_uri: str = "./.lancedb",
lancedb_table_name: str = "vectors",
enable_metadata_extraction: bool = False,
documents: Optional[Sequence[Document]] = None,
nodes: Optional[Sequence[BaseNode]] = None,
service_context: Optional[ServiceContext] = None,
@@ -46,7 +36,6 @@
Parameters:
vector_store_type (str): The class name of the vector store (e.g., 'LanceDBVectorStore', 'SimpleVectorStore'..)
enable_metadata_extraction (bool): Whether to enable automated metadata extraction as questions, keywords, entities, or summaries.
documents (Optional[Sequence[Document]]): Documents to initialize the vector store index from.
nodes (Optional[Sequence[BaseNode]]): Nodes to initialize the vector store index from.
service_context (Optional[ServiceContext]): Service context to initialize the vector store index from.
@@ -79,36 +68,13 @@ def from_defaults(
# Initialize vector store index from documents or nodes
storage_context = StorageContext.from_defaults(vector_store=vector_store)

# Get llm from service context for metadata extraction
llm = service_context.llm if service_context is not None else None

if documents is not None:
# TODO: create_index_from_documents() function
if enable_metadata_extraction:
metadata_extractor = MetadataExtractor(
extractors=[
TitleExtractor(llm=llm, nodes=5),
QuestionsAnsweredExtractor(llm=llm, questions=3),
SummaryExtractor(llm=llm, summaries=["prev", "self"]),
KeywordExtractor(llm=llm, keywords=10),
EntityExtractor(prediction_threshold=0.5)
], )
node_parser = SimpleNodeParser.from_defaults(metadata_extractor=metadata_extractor)
nodes = node_parser.get_nodes_from_documents(documents)
index = VectorStoreIndex(
nodes=nodes,
storage_context=storage_context,
service_context=service_context,
show_progress=True)
# Initialize index without metadata extraction
else:
index = VectorStoreIndex.from_documents(
documents=documents,
storage_context=storage_context,
service_context=service_context,
show_progress=True)
index = VectorStoreIndex.from_documents(
documents=documents,
storage_context=storage_context,
service_context=service_context,
show_progress=True)
else:
# TODO: create_index_from_nodes() function
index = VectorStoreIndex(
nodes=nodes,
storage_context=storage_context,
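With metadata extraction moved out of the index builder, a usage sketch of the simplified `AutoVectorStoreIndex.from_defaults`; the `documents` and `service_context` objects are assumed to exist, e.g. from the sketches above:

```python
from autollm.auto.vector_store_index import AutoVectorStoreIndex

index = AutoVectorStoreIndex.from_defaults(
    vector_store_type="LanceDBVectorStore",
    lancedb_uri="./.lancedb",
    lancedb_table_name="vectors",
    documents=documents,               # or pass nodes=... instead
    service_context=service_context,   # metadata extractors run via its transformations
)
```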