Skip to content

Commit

Permalink
Litellm bump (#2195)
Browse files Browse the repository at this point in the history
* ran bump-pydantic

* replace root_validator with model_validator

* mostly working. some alternate assistant error. changed root_validator and typing_extensions

* working generation chat. changed type

* replacing .dict with .model_dump

* argument needed to bring model_dump up to parity with dict()

* fix a few remaining issues -- working with llama and gpt

* updating requirements file

* more requirement updates

* more requirement updates

* fix to make search work

* return type fix:

* halfway types change

* fixes for mypy and pydantic:

* endpoint fix

* fix pydantic protected namespaces

* it works!

* removed unnecessary None initializations

* better logging

* changed default values to empty lists

* mypy fixes

* fixed array defaulting

---------

Co-authored-by: hagen-danswer <hagen@danswer.ai>
  • Loading branch information
josvdw and hagen-danswer authored Aug 28, 2024
1 parent 657d205 commit 50c1743
Show file tree
Hide file tree
Showing 52 changed files with 230 additions and 223 deletions.
2 changes: 1 addition & 1 deletion backend/danswer/auth/noauth_user.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
def set_no_auth_user_preferences(
store: DynamicConfigStore, preferences: UserPreferences
) -> None:
store.store(KV_NO_AUTH_USER_PREFERENCES_KEY, preferences.dict())
store.store(KV_NO_AUTH_USER_PREFERENCES_KEY, preferences.model_dump())


def load_no_auth_user_preferences(store: DynamicConfigStore) -> UserPreferences:
Expand Down
5 changes: 3 additions & 2 deletions backend/danswer/chat/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,12 @@ class QADocsResponse(RetrievalDocs):
applied_time_cutoff: datetime | None
recency_bias_multiplier: float

def dict(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore
initial_dict = super().dict(*args, **kwargs) # type: ignore
def model_dump(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore
initial_dict = super().model_dump(mode="json", *args, **kwargs) # type: ignore
initial_dict["applied_time_cutoff"] = (
self.applied_time_cutoff.isoformat() if self.applied_time_cutoff else None
)

return initial_dict


Expand Down
2 changes: 1 addition & 1 deletion backend/danswer/chat/process_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -813,4 +813,4 @@ def stream_chat_message(
is_connected=is_connected,
)
for obj in objects:
yield get_json_line(obj.dict())
yield get_json_line(obj.model_dump())
2 changes: 1 addition & 1 deletion backend/danswer/chat/tools.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import TypedDict
from typing_extensions import TypedDict # noreorder

from pydantic import BaseModel

Expand Down
2 changes: 1 addition & 1 deletion backend/danswer/connectors/gmail/connector_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def update_gmail_credential_access_tokens(
) -> OAuthCredentials | None:
app_credentials = get_google_app_gmail_cred()
flow = InstalledAppFlow.from_client_config(
app_credentials.dict(),
app_credentials.model_dump(),
scopes=SCOPES,
redirect_uri=_build_frontend_gmail_redirect(),
)
Expand Down
2 changes: 1 addition & 1 deletion backend/danswer/connectors/google_drive/connector_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def update_credential_access_tokens(
) -> OAuthCredentials | None:
app_credentials = get_google_app_cred()
flow = InstalledAppFlow.from_client_config(
app_credentials.dict(),
app_credentials.model_dump(),
scopes=SCOPES,
redirect_uri=_build_frontend_google_drive_redirect(),
)
Expand Down
9 changes: 5 additions & 4 deletions backend/danswer/connectors/zulip/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Optional

from pydantic import BaseModel
from pydantic import Field


class Message(BaseModel):
Expand All @@ -18,11 +19,11 @@ class Message(BaseModel):
sender_realm_str: str
subject: str
topic_links: Optional[List[Any]] = None
last_edit_timestamp: Optional[int] = None
edit_history: Any
last_edit_timestamp: Optional[int]
edit_history: Any = None
reactions: List[Any]
submessages: List[Any]
flags: List[str] = []
flags: List[str] = Field(default_factory=list)
display_recipient: Optional[str] = None
type: Optional[str] = None
stream_id: int
Expand All @@ -39,4 +40,4 @@ class GetMessagesResponse(BaseModel):
found_newest: Optional[bool] = None
history_limited: Optional[bool] = None
anchor: Optional[str] = None
messages: List[Message] = []
messages: List[Message] = Field(default_factory=list)
7 changes: 5 additions & 2 deletions backend/danswer/db/document_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,11 @@
from danswer.db.models import UserRole
from danswer.server.features.document_set.models import DocumentSetCreationRequest
from danswer.server.features.document_set.models import DocumentSetUpdateRequest
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import fetch_versioned_implementation

logger = setup_logger()


def _add_user_filters(
stmt: Select, user: User | None, get_editable: bool = True
Expand Down Expand Up @@ -233,9 +236,9 @@ def insert_document_set(
)

db_session.commit()
except:
except Exception as e:
db_session.rollback()
raise
logger.error(f"Error creating document set: {e}")

return new_document_set_row, ds_cc_pairs

Expand Down
4 changes: 2 additions & 2 deletions backend/danswer/db/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,10 @@ def upsert_cloud_embedding_provider(
.first()
)
if existing_provider:
for key, value in provider.dict().items():
for key, value in provider.model_dump().items():
setattr(existing_provider, key, value)
else:
new_provider = CloudEmbeddingProviderModel(**provider.dict())
new_provider = CloudEmbeddingProviderModel(**provider.model_dump())
db_session.add(new_provider)
existing_provider = new_provider
db_session.commit()
Expand Down
2 changes: 1 addition & 1 deletion backend/danswer/db/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing import Literal
from typing import NotRequired
from typing import Optional
from typing import TypedDict
from typing_extensions import TypedDict # noreorder
from uuid import UUID

from fastapi_users_db_sqlalchemy import SQLAlchemyBaseOAuthAccountTableUUID
Expand Down
2 changes: 1 addition & 1 deletion backend/danswer/file_store/models.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import base64
from enum import Enum
from typing import NotRequired
from typing import TypedDict
from typing_extensions import TypedDict # noreorder

from pydantic import BaseModel

Expand Down
2 changes: 1 addition & 1 deletion backend/danswer/indexing/embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def embed_chunks(
title_embed_dict[title] = title_embedding

new_embedded_chunk = IndexChunk(
**chunk.dict(),
**chunk.model_dump(),
embeddings=ChunkEmbedding(
full_embedding=chunk_embeddings[0],
mini_chunk_embeddings=chunk_embeddings[1:],
Expand Down
5 changes: 2 additions & 3 deletions backend/danswer/indexing/indexing_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Protocol

from pydantic import BaseModel
from pydantic import ConfigDict
from sqlalchemy.orm import Session

from danswer.access.access import get_access_for_documents
Expand Down Expand Up @@ -40,9 +41,7 @@
class DocumentBatchPrepareContext(BaseModel):
updatable_docs: list[Document]
id_to_db_doc_map: dict[str, DBDocument]

class Config:
arbitrary_types_allowed = True
model_config = ConfigDict(arbitrary_types_allowed=True)


class IndexingPipelineProtocol(Protocol):
Expand Down
16 changes: 11 additions & 5 deletions backend/danswer/indexing/models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import TYPE_CHECKING

from pydantic import BaseModel
from pydantic import Field

from danswer.access.models import DocumentAccess
from danswer.connectors.models import Document
Expand All @@ -24,9 +25,8 @@ class BaseChunk(BaseModel):
chunk_id: int
blurb: str # The first sentence(s) of the first Section of the chunk
content: str
source_links: dict[
int, str
] | None # Holds the link and the offsets into the raw Chunk text
# Holds the link and the offsets into the raw Chunk text
source_links: dict[int, str] | None
section_continuation: bool # True if this Chunk's start is not at the start of a Section


Expand All @@ -47,7 +47,7 @@ class DocAwareChunk(BaseChunk):

mini_chunk_texts: list[str] | None

large_chunk_reference_ids: list[int] = []
large_chunk_reference_ids: list[int] = Field(default_factory=list)

def to_short_descriptor(self) -> str:
"""Used when logging the identity of a chunk"""
Expand Down Expand Up @@ -85,7 +85,7 @@ def from_index_chunk(
document_sets: set[str],
boost: int,
) -> "DocMetadataAwareIndexChunk":
index_chunk_data = index_chunk.dict()
index_chunk_data = index_chunk.model_dump()
return cls(
**index_chunk_data,
access=access,
Expand All @@ -102,6 +102,9 @@ class EmbeddingModelDetail(BaseModel):
provider_type: EmbeddingProvider | None = None
api_key: str | None = None

# This disables the "model_" protected namespace for pydantic
model_config = {"protected_namespaces": ()}

@classmethod
def from_db_model(
cls,
Expand All @@ -123,6 +126,9 @@ class IndexingSetting(EmbeddingModelDetail):
index_name: str | None
multipass_indexing: bool

# This disables the "model_" protected namespace for pydantic
model_config = {"protected_namespaces": ()}

@classmethod
def from_db_model(cls, search_settings: "SearchSettings") -> "IndexingSetting":
return cls(
Expand Down
21 changes: 8 additions & 13 deletions backend/danswer/llm/answering/models.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
from collections.abc import Callable
from collections.abc import Iterator
from typing import Any
from typing import TYPE_CHECKING

from langchain.schema.messages import AIMessage
from langchain.schema.messages import BaseMessage
from langchain.schema.messages import HumanMessage
from langchain.schema.messages import SystemMessage
from pydantic import BaseModel
from pydantic import ConfigDict
from pydantic import Field
from pydantic import root_validator
from pydantic import model_validator

from danswer.chat.models import AnswerQuestionStreamReturn
from danswer.configs.constants import MessageType
Expand Down Expand Up @@ -117,22 +117,19 @@ class AnswerStyleConfig(BaseModel):
default_factory=DocumentPruningConfig
)

@root_validator
def check_quotes_and_citation(cls, values: dict[str, Any]) -> dict[str, Any]:
citation_config = values.get("citation_config")
quotes_config = values.get("quotes_config")

if citation_config is None and quotes_config is None:
@model_validator(mode="after")
def check_quotes_and_citation(self) -> "AnswerStyleConfig":
if self.citation_config is None and self.quotes_config is None:
raise ValueError(
"One of `citation_config` or `quotes_config` must be provided"
)

if citation_config is not None and quotes_config is not None:
if self.citation_config is not None and self.quotes_config is not None:
raise ValueError(
"Only one of `citation_config` or `quotes_config` must be provided"
)

return values
return self


class PromptConfig(BaseModel):
Expand Down Expand Up @@ -160,6 +157,4 @@ def from_model(
include_citations=model.include_citations,
)

# needed so that this can be passed into lru_cache funcs
class Config:
frozen = True
model_config = ConfigDict(frozen=True)
19 changes: 12 additions & 7 deletions backend/danswer/llm/chat_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,13 +62,13 @@ def _convert_litellm_message_to_langchain_message(
litellm_message: litellm.Message,
) -> BaseMessage:
# Extracting the basic attributes from the litellm message
content = litellm_message.content
content = litellm_message.content or ""
role = litellm_message.role

# Handling function calls and tool calls if present
tool_calls = (
cast(
list[litellm.utils.ChatCompletionMessageToolCall],
list[litellm.ChatCompletionMessageToolCall],
litellm_message.tool_calls,
)
if hasattr(litellm_message, "tool_calls")
Expand All @@ -87,7 +87,7 @@ def _convert_litellm_message_to_langchain_message(
"args": json.loads(tool_call.function.arguments),
"id": tool_call.id,
}
for tool_call in tool_calls
for tool_call in (tool_calls if tool_calls else [])
],
)
elif role == "system":
Expand Down Expand Up @@ -296,9 +296,11 @@ def _invoke_implementation(
response = cast(
litellm.ModelResponse, self._completion(prompt, tools, tool_choice, False)
)
return _convert_litellm_message_to_langchain_message(
response.choices[0].message
)
choice = response.choices[0]
if hasattr(choice, "message"):
return _convert_litellm_message_to_langchain_message(choice.message)
else:
raise ValueError("Unexpected response choice type")

def _stream_implementation(
self,
Expand All @@ -314,7 +316,10 @@ def _stream_implementation(
return

output = None
response = self._completion(prompt, tools, tool_choice, True)
response = cast(
litellm.CustomStreamWrapper,
self._completion(prompt, tools, tool_choice, True),
)
try:
for part in response:
if len(part["choices"]) == 0:
Expand Down
9 changes: 6 additions & 3 deletions backend/danswer/llm/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,12 @@ class LLMConfig(BaseModel):
model_provider: str
model_name: str
temperature: float
api_key: str | None
api_base: str | None
api_version: str | None
api_key: str | None = None
api_base: str | None = None
api_version: str | None = None

# This disables the "model_" protected namespace for pydantic
model_config = {"protected_namespaces": ()}


def log_prompt(prompt: LanguageModelInput) -> None:
Expand Down
3 changes: 3 additions & 0 deletions backend/danswer/llm/override_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ class LLMOverride(BaseModel):
model_version: str | None = None
temperature: float | None = None

# This disables the "model_" protected namespace for pydantic
model_config = {"protected_namespaces": ()}


class PromptOverride(BaseModel):
system_prompt: str | None = None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def __init__(
def _make_model_server_request(self, embed_request: EmbedRequest) -> EmbedResponse:
def _make_request() -> EmbedResponse:
response = requests.post(
self.embed_server_endpoint, json=embed_request.dict()
self.embed_server_endpoint, json=embed_request.model_dump()
)
try:
response.raise_for_status()
Expand Down Expand Up @@ -255,7 +255,7 @@ def predict(self, query: str, passages: list[str]) -> list[float]:
)

response = requests.post(
self.rerank_server_endpoint, json=rerank_request.dict()
self.rerank_server_endpoint, json=rerank_request.model_dump()
)
response.raise_for_status()

Expand Down Expand Up @@ -288,7 +288,7 @@ def predict(
)

response = requests.post(
self.intent_server_endpoint, json=intent_request.dict()
self.intent_server_endpoint, json=intent_request.model_dump()
)
response.raise_for_status()

Expand Down
2 changes: 1 addition & 1 deletion backend/danswer/one_shot_answer/answer_question.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ def stream_search_answer(
db_session=session,
)
for obj in objects:
yield get_json_line(obj.dict())
yield get_json_line(obj.model_dump())


def get_search_answer(
Expand Down
Loading

0 comments on commit 50c1743

Please sign in to comment.