Skip to content

Commit

Permalink
Added convenience methods for building sparse/hybrid vector dbs (#250)
Browse files Browse the repository at this point in the history
* Added convenience methods for building sparse/hybrid vector dbs

* Switched to factory model for init
  • Loading branch information
whitead authored Mar 8, 2024
1 parent 7b28a34 commit a512c2f
Show file tree
Hide file tree
Showing 5 changed files with 123 additions and 89 deletions.
6 changes: 6 additions & 0 deletions paperqa/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
OpenAILLMModel,
SentenceTransformerEmbeddingModel,
SparseEmbeddingModel,
embedding_model_factory,
llm_model_factory,
vector_store_factory,
)
from .version import __version__

Expand All @@ -40,4 +43,7 @@
"LangchainVectorStore",
"print_callback",
"LLMResult",
"vector_store_factory",
"llm_model_factory",
"embedding_model_factory",
]
110 changes: 28 additions & 82 deletions paperqa/docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,15 @@
from pydantic import BaseModel, ConfigDict, Field, model_validator

from .llms import (
LangchainEmbeddingModel,
LangchainLLMModel,
HybridEmbeddingModel,
LLMModel,
NumpyVectorStore,
OpenAIEmbeddingModel,
OpenAILLMModel,
SentenceTransformerEmbeddingModel,
VectorStore,
get_score,
is_openai_model,
llm_model_factory,
vector_store_factory,
)
from .paths import PAPERQA_DIR
from .readers import read_doc
Expand Down Expand Up @@ -97,94 +96,31 @@ class Docs(BaseModel):

def __init__(self, **data):
# We do it here because we need to move things to private attributes
embedding_client: Any | None = None
client: Any | None = None
if "embedding_client" in data:
embedding_client = data.pop("embedding_client")
# convenience to pull embedding_client from client if reasonable
elif (
"client" in data
and data["client"] is not None
and type(data["client"]) == AsyncOpenAI
):
# convenience
embedding_client = data["client"]
elif "embedding" in data and data["embedding"] != "default":
embedding_client = None
else:
embedding_client = AsyncOpenAI()
if "client" in data:
client = data.pop("client")
elif "llm_model" in data and data["llm_model"] is not None:
# except if it is an OpenAILLMModel
client = (
AsyncOpenAI() if type(data["llm_model"]) == OpenAILLMModel else None
)
else:
client = AsyncOpenAI()
# backwards compatibility
if "doc_index" in data:
data["docs_index"] = data.pop("doc_index")
super().__init__(**data)
self._client = client
self._embedding_client = embedding_client
# more convenience
if (
type(self.texts_index.embedding_model) == OpenAIEmbeddingModel
and embedding_client is None
):
self._embedding_client = self._client

# run this here (instead of automatically) so it has access to privates
# If I ever figure out a better way of validating privates
# I can move this back to the decorator
Docs.make_llm_names_consistent(self)
self.set_client(client, embedding_client)

@model_validator(mode="before")
@classmethod
def setup_alias_models(cls, data: Any) -> Any: # noqa: C901, PLR0912
def setup_alias_models(cls, data: Any) -> Any:
if isinstance(data, dict):
if "llm" in data and data["llm"] != "default":
if is_openai_model(data["llm"]):
data["llm_model"] = OpenAILLMModel(config={"model": data["llm"]})
elif data["llm"] == "langchain":
data["llm_model"] = LangchainLLMModel()
else:
raise ValueError(f"Could not guess model type for {data['llm']}. ")
data["llm_model"] = llm_model_factory(data["llm"])
if "summary_llm" in data and data["summary_llm"] is not None:
if is_openai_model(data["summary_llm"]):
data["summary_llm_model"] = OpenAILLMModel(
config={"model": data["summary_llm"]}
)
else:
raise ValueError(f"Could not guess model type for {data['llm']}. ")
data["summary_llm_model"] = llm_model_factory(data["summary_llm"])
if "embedding" in data and data["embedding"] != "default":
if data["embedding"] == "langchain":
if "texts_index" not in data:
data["texts_index"] = NumpyVectorStore(
embedding_model=LangchainEmbeddingModel()
)
if "docs_index" not in data:
data["docs_index"] = NumpyVectorStore(
embedding_model=LangchainEmbeddingModel()
)
elif data["embedding"] == "sentence-transformers":
if "texts_index" not in data:
data["texts_index"] = NumpyVectorStore(
embedding_model=SentenceTransformerEmbeddingModel()
)
if "docs_index" not in data:
data["docs_index"] = NumpyVectorStore(
embedding_model=SentenceTransformerEmbeddingModel()
)
else:
# must be an openai model
if "texts_index" not in data:
data["texts_index"] = NumpyVectorStore(
embedding_model=OpenAIEmbeddingModel(name=data["embedding"])
)
if "docs_index" not in data:
data["docs_index"] = NumpyVectorStore(
embedding_model=OpenAIEmbeddingModel(name=data["embedding"])
)
if "texts_index" not in data:
data["texts_index"] = vector_store_factory(data["embedding"])
if "docs_index" not in data:
data["docs_index"] = vector_store_factory(data["embedding"])
return data

@model_validator(mode="after")
Expand Down Expand Up @@ -255,14 +191,24 @@ def __setstate__(self, state):

def set_client(
self,
client: AsyncOpenAI | None = None,
embedding_client: AsyncOpenAI | None = None,
client: Any | None = None,
embedding_client: Any | None = None,
):
if client is None:
if client is None and isinstance(self.llm_model, OpenAILLMModel):
client = AsyncOpenAI()
self._client = client
if embedding_client is None:
embedding_client = client if type(client) == AsyncOpenAI else AsyncOpenAI()
if embedding_client is None: # noqa: SIM102
# check if we have an openai embedding model in use
if isinstance(self.texts_index.embedding_model, OpenAIEmbeddingModel) or (
isinstance(self.texts_index.embedding_model, HybridEmbeddingModel)
and any(
isinstance(m, OpenAIEmbeddingModel)
for m in self.texts_index.embedding_model.models
)
):
embedding_client = (
client if isinstance(client, AsyncOpenAI) else AsyncOpenAI()
)
self._embedding_client = embedding_client
Docs.make_llm_names_consistent(self)

Expand Down
41 changes: 38 additions & 3 deletions paperqa/llms.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ async def embed_documents(self, client: Any, texts: list[str]) -> list[list[floa
class SparseEmbeddingModel(EmbeddingModel):
"""This is a very simple keyword search model - probably best to be mixed with others."""

name: str = "sparse-embed"
name: str = "sparse"
ndim: int = 256
enc: Any = Field(default_factory=lambda: tiktoken.get_encoding("cl100k_base"))

Expand Down Expand Up @@ -429,7 +429,7 @@ async def achat(self, client: Any, messages: list[dict[str, str]]) -> str:
messages=[m for m in messages if m["role"] != "system"],
**process_llm_config(self.config, "max_tokens"),
)
return completion.content or ""
return str(completion.content) or ""

async def achat_iter(self, client: Any, messages: list[dict[str, str]]) -> Any:
aclient = self._check_client(client)
Expand Down Expand Up @@ -525,7 +525,7 @@ class VectorStore(BaseModel, ABC):

embedding_model: EmbeddingModel = Field(default=OpenAIEmbeddingModel())
# can be tuned for different tasks
mmr_lambda: float = Field(default=0.5)
mmr_lambda: float = Field(default=0.9)
model_config = ConfigDict(extra="forbid")

@abstractmethod
Expand Down Expand Up @@ -815,3 +815,38 @@ def get_score(text: str) -> int:
if len(text) < 100: # noqa: PLR2004
return 1
return 5


def llm_model_factory(llm: str) -> LLMModel:
    """Pick and construct an LLM wrapper from a model-name alias.

    Resolution order:
      * ``"default"`` -> ``OpenAILLMModel()`` with its default config.
      * any name accepted by ``is_openai_model`` -> ``OpenAILLMModel``
        configured with that model name.
      * ``"langchain"`` -> ``LangchainLLMModel()``.
      * any name containing ``"claude"`` -> ``AnthropicLLMModel``
        configured with that model name.

    Raises:
        ValueError: if the alias matches none of the cases above.
    """
    # Guard clauses instead of a nested if/elif chain; the order of the
    # checks is the same as before, so ambiguous names resolve identically.
    if llm == "default":
        return OpenAILLMModel()
    if is_openai_model(llm):
        return OpenAILLMModel(config={"model": llm})
    if llm == "langchain":
        return LangchainLLMModel()
    if "claude" in llm:
        return AnthropicLLMModel(config={"model": llm})
    raise ValueError(f"Could not guess model type for {llm}. ")


def embedding_model_factory(embedding: str) -> EmbeddingModel:
    """Construct an embedding model from a short alias.

    Recognized aliases:
      * ``"default"`` -> ``OpenAIEmbeddingModel()`` using its own default
        name (mirrors how ``llm_model_factory`` treats ``"default"``).
      * ``"langchain"`` -> ``LangchainEmbeddingModel()``.
      * ``"sentence-transformers"`` -> ``SentenceTransformerEmbeddingModel()``.
      * ``"sparse"`` -> ``SparseEmbeddingModel()``.
      * ``"hybrid-<name>"`` -> ``HybridEmbeddingModel`` combining an
        ``OpenAIEmbeddingModel`` named ``<name>`` with a
        ``SparseEmbeddingModel``. A bare ``"hybrid"`` (no ``-<name>``
        suffix) uses the default ``OpenAIEmbeddingModel()`` for the dense
        half.
      * anything else is taken to be the name of an OpenAI embedding model.
    """
    if embedding == "default":
        # "default" is a sentinel, not a real model name; without this
        # branch it would be forwarded as OpenAIEmbeddingModel(name="default").
        return OpenAIEmbeddingModel()
    if embedding == "langchain":
        return LangchainEmbeddingModel()
    if embedding == "sentence-transformers":
        return SentenceTransformerEmbeddingModel()
    if embedding.startswith("hybrid"):
        # Everything after the first "-" is the dense model's name, e.g.
        # "hybrid-text-embedding-3-small" -> "text-embedding-3-small".
        _, _, dense_name = embedding.partition("-")
        # An empty suffix used to yield OpenAIEmbeddingModel(name=""), which
        # is not a usable model name; fall back to the class default instead.
        dense_model = (
            OpenAIEmbeddingModel(name=dense_name)
            if dense_name
            else OpenAIEmbeddingModel()
        )
        return HybridEmbeddingModel(models=[dense_model, SparseEmbeddingModel()])
    if embedding == "sparse":
        return SparseEmbeddingModel()
    return OpenAIEmbeddingModel(name=embedding)


def vector_store_factory(embedding: str) -> NumpyVectorStore:
    """Return a ``NumpyVectorStore`` backed by the model for ``embedding``.

    The alias is resolved by ``embedding_model_factory``.
    """
    model = embedding_model_factory(embedding)
    return NumpyVectorStore(embedding_model=model)
2 changes: 1 addition & 1 deletion paperqa/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "4.1.1"
__version__ = "4.2.0"
53 changes: 50 additions & 3 deletions tests/test_paperqa.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,10 +486,23 @@ def accum(x):
assert completion.prompt_count > 0
assert completion.completion_count > 0
assert str(completion) == "".join(outputs)
assert type(completion.text) is str # noqa: E721

completion = await call({"animal": "duck"}) # type: ignore[call-arg]
assert completion.seconds_to_first_token == 0
assert completion.seconds_to_last_token > 0
assert type(completion.text) is str # noqa: E721

docs = Docs(llm="claude-3-sonnet-20240229", client=client)
await docs.aadd_url(
"https://en.wikipedia.org/wiki/National_Flag_of_Canada_Day",
citation="WikiMedia Foundation, 2023, Accessed now",
dockey="test",
)
answer = await docs.aget_evidence(
Answer(question="What is the national flag of Canada?")
)
await docs.aquery("What is the national flag of Canada?", answer=answer)


def test_docs():
Expand Down Expand Up @@ -730,8 +743,19 @@ def test_sparse_embedding():
)
assert any(docs.docs["test"].embedding) # type: ignore[arg-type]

# test alias
docs = Docs(embedding="sparse")
assert docs._embedding_client is None
assert docs.embedding.startswith("sparse") # type: ignore[union-attr]
docs.add_url(
"https://en.wikipedia.org/wiki/Frederick_Bates_(politician)",
citation="WikiMedia Foundation, 2023, Accessed now",
dockey="test",
)
assert any(docs.docs["test"].embedding) # type: ignore[arg-type]


def test_hyrbrid_embedding():
def test_hybrid_embedding():
model = HybridEmbeddingModel(
models=[
OpenAIEmbeddingModel(),
Expand All @@ -751,6 +775,19 @@ def test_hyrbrid_embedding():
)
assert any(docs.docs["test"].embedding) # type: ignore[arg-type]

# now try via alias
docs = Docs(
embedding="hybrid-text-embedding-3-small",
)
assert type(docs._embedding_client) is AsyncOpenAI
assert docs.embedding.startswith("hybrid") # type: ignore[union-attr]
docs.add_url(
"https://en.wikipedia.org/wiki/Frederick_Bates_(politician)",
citation="WikiMedia Foundation, 2023, Accessed now",
dockey="test",
)
assert any(docs.docs["test"].embedding) # type: ignore[arg-type]


def test_sentence_transformer_embedding():
from paperqa import SentenceTransformerEmbeddingModel
Expand Down Expand Up @@ -1038,7 +1075,10 @@ def test_docs_pickle() -> None:
docs = Docs(
llm_model=OpenAILLMModel(
config={"temperature": 0.0, "model": "gpt-3.5-turbo"}
)
),
summary_llm_model=OpenAILLMModel(
config={"temperature": 0.0, "model": "gpt-3.5-turbo"}
),
)
assert docs._client is not None
old_config = docs.llm_model.config
Expand All @@ -1054,6 +1094,7 @@ def test_docs_pickle() -> None:
assert docs2._client is not None
assert docs2.llm_model.config == old_config
assert docs2.summary_llm_model.config == old_sconfig
print(old_config, old_sconfig)
assert len(docs.docs) == len(docs2.docs)
for _ in range(4): # Retry a few times, because this is flaky
docs_context = docs.get_evidence(
Expand All @@ -1072,7 +1113,8 @@ def test_docs_pickle() -> None:
k=3,
max_sources=1,
).context
if strings_similarity(s1=docs_context, s2=docs2_context) > 0.75:
# It is shocking how unrepeatable this is
if strings_similarity(s1=docs_context, s2=docs2_context) > 0.50:
break
else:
raise AssertionError("Failed to attain similar contexts, even with retrying.")
Expand Down Expand Up @@ -1514,6 +1556,11 @@ def test_embedding_name_consistency():
)
assert docs.embedding == "test"

docs = Docs(embedding="hybrid-text-embedding-ada-002")
assert type(docs.docs_index.embedding_model) is HybridEmbeddingModel
assert docs.docs_index.embedding_model.models[0].name == "text-embedding-ada-002"
assert docs.docs_index.embedding_model.models[1].name == "sparse"


def test_external_texts_index():
docs = Docs(jit_texts_index=True)
Expand Down

0 comments on commit a512c2f

Please sign in to comment.