diff --git a/backend/alembic/versions/b156fa702355_chat_reworked.py b/backend/alembic/versions/b156fa702355_chat_reworked.py index c80ab6a0fb1..a6d75fb508b 100644 --- a/backend/alembic/versions/b156fa702355_chat_reworked.py +++ b/backend/alembic/versions/b156fa702355_chat_reworked.py @@ -288,6 +288,15 @@ def upgrade() -> None: def downgrade() -> None: + # NOTE: you will lose all chat history. This is to satisfy the non-nullable constraints + # below + op.execute("DELETE FROM chat_feedback") + op.execute("DELETE FROM chat_message__search_doc") + op.execute("DELETE FROM document_retrieval_feedback") + op.execute("DELETE FROM document_retrieval_feedback") + op.execute("DELETE FROM chat_message") + op.execute("DELETE FROM chat_session") + op.drop_constraint( "chat_feedback__chat_message_fk", "chat_feedback", type_="foreignkey" ) diff --git a/backend/alembic/versions/c99d76fcd298_add_nullable_to_persona_id_in_chat_.py b/backend/alembic/versions/c99d76fcd298_add_nullable_to_persona_id_in_chat_.py index 58fcf482c85..222605189fe 100644 --- a/backend/alembic/versions/c99d76fcd298_add_nullable_to_persona_id_in_chat_.py +++ b/backend/alembic/versions/c99d76fcd298_add_nullable_to_persona_id_in_chat_.py @@ -23,6 +23,56 @@ def upgrade() -> None: def downgrade() -> None: + # Delete chat messages and feedback first since they reference chat sessions + # Get chat messages from sessions with null persona_id + chat_messages_query = """ + SELECT id + FROM chat_message + WHERE chat_session_id IN ( + SELECT id + FROM chat_session + WHERE persona_id IS NULL + ) + """ + + # Delete dependent records first + op.execute( + f""" + DELETE FROM document_retrieval_feedback + WHERE chat_message_id IN ( + {chat_messages_query} + ) + """ + ) + op.execute( + f""" + DELETE FROM chat_message__search_doc + WHERE chat_message_id IN ( + {chat_messages_query} + ) + """ + ) + + # Delete chat messages + op.execute( + """ + DELETE FROM chat_message + WHERE chat_session_id IN ( + SELECT id + FROM chat_session + WHERE persona_id IS NULL + ) + """ + ) + + # Now we can safely delete the chat sessions + op.execute( + """ + DELETE FROM chat_session + WHERE persona_id IS NULL + """ + ) + op.alter_column( "chat_session", "persona_id", diff --git a/backend/danswer/chat/process_message.py b/backend/danswer/chat/process_message.py index 314e432b86a..45e447d2c6a 100644 --- a/backend/danswer/chat/process_message.py +++ b/backend/danswer/chat/process_message.py @@ -19,16 +19,10 @@ from danswer.chat.models import QADocsResponse from danswer.chat.models import StreamingError from danswer.chat.models import StreamStopInfo -from danswer.configs.app_configs import AZURE_DALLE_API_BASE -from danswer.configs.app_configs import AZURE_DALLE_API_KEY -from danswer.configs.app_configs import AZURE_DALLE_API_VERSION -from danswer.configs.app_configs import AZURE_DALLE_DEPLOYMENT_NAME -from danswer.configs.chat_configs import BING_API_KEY from danswer.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE from danswer.configs.chat_configs import DISABLE_LLM_CHOOSE_SEARCH from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT from danswer.configs.constants import MessageType -from danswer.configs.model_configs import GEN_AI_TEMPERATURE from danswer.db.chat import attach_files_to_chat_message from danswer.db.chat import create_db_search_doc from danswer.db.chat import create_new_chat_message @@ -41,7 +35,6 @@ from danswer.db.chat import translate_db_message_to_chat_message_detail from danswer.db.chat import translate_db_search_doc_to_server_search_doc from danswer.db.engine import get_session_context_manager -from danswer.db.llm import fetch_existing_llm_providers from danswer.db.models import SearchDoc as DbSearchDoc from danswer.db.models import ToolCall from danswer.db.models import User @@ -61,14 +54,13 @@ from danswer.llm.exceptions import GenAIDisabledException from danswer.llm.factory import get_llms_for_persona from danswer.llm.factory import get_main_llm_from_tuple -from danswer.llm.interfaces import LLMConfig from danswer.llm.utils import litellm_exception_to_error_msg from danswer.natural_language_processing.utils import get_tokenizer -from danswer.search.enums import LLMEvaluationType from danswer.search.enums import OptionalSearchSetting from danswer.search.enums import QueryFlow from danswer.search.enums import SearchType from danswer.search.models import InferenceSection +from danswer.search.models import RetrievalDetails from danswer.search.retrieval.search_runner import inference_sections_from_ids from danswer.search.utils import chunks_or_sections_to_search_docs from danswer.search.utils import dedupe_documents @@ -77,14 +69,14 @@ from danswer.server.query_and_chat.models import ChatMessageDetail from danswer.server.query_and_chat.models import CreateChatMessageRequest from danswer.server.utils import get_json_line -from danswer.tools.built_in_tools import get_built_in_tool_by_id from danswer.tools.force import ForceUseTool -from danswer.tools.models import DynamicSchemaInfo from danswer.tools.models import ToolResponse from danswer.tools.tool import Tool -from danswer.tools.tool_implementations.custom.custom_tool import ( - build_custom_tools_from_openapi_schema_and_headers, -) +from danswer.tools.tool_constructor import construct_tools +from danswer.tools.tool_constructor import CustomToolConfig +from danswer.tools.tool_constructor import ImageGenerationToolConfig +from danswer.tools.tool_constructor import InternetSearchToolConfig +from danswer.tools.tool_constructor import SearchToolConfig from danswer.tools.tool_implementations.custom.custom_tool import ( CUSTOM_TOOL_RESPONSE_ID, ) @@ -95,9 +87,6 @@ from danswer.tools.tool_implementations.images.image_generation_tool import ( ImageGenerationResponse, ) -from danswer.tools.tool_implementations.images.image_generation_tool import ( - ImageGenerationTool, -) from danswer.tools.tool_implementations.internet_search.internet_search_tool import ( INTERNET_SEARCH_RESPONSE_ID, ) @@ -122,9 +111,6 @@ SECTION_RELEVANCE_LIST_ID, ) from danswer.tools.tool_runner import ToolCallFinalResult -from danswer.tools.utils import compute_all_tool_tokens -from danswer.tools.utils import explicit_tool_calling_supported -from danswer.utils.headers import header_dict_to_header_list from danswer.utils.logger import setup_logger from danswer.utils.timing import log_generator_function_time @@ -295,7 +281,6 @@ def stream_chat_message_objects( max_document_percentage: float = CHAT_TARGET_CHUNK_PERCENTAGE, # if specified, uses the last user message and does not create a new user message based # on the `new_msg_req.message`. Currently, requires a state where the last message is a - use_existing_user_message: bool = False, litellm_additional_headers: dict[str, str] | None = None, custom_tool_additional_headers: dict[str, str] | None = None, is_connected: Callable[[], bool] | None = None, @@ -307,6 +292,9 @@ def stream_chat_message_objects( 3. [always] A set of streamed LLM tokens or an error anywhere along the line if something fails 4. [always] Details on the final AI response message that is created """ + use_existing_user_message = new_msg_req.use_existing_user_message + existing_assistant_message_id = new_msg_req.existing_assistant_message_id + # Currently surrounding context is not supported for chat # Chat is already token heavy and harder for the model to process plus it would roll history over much faster new_msg_req.chunks_above = 0 @@ -428,12 +416,20 @@ def stream_chat_message_objects( final_msg, history_msgs = create_chat_chain( chat_session_id=chat_session_id, db_session=db_session ) - if final_msg.message_type != MessageType.USER: - raise RuntimeError( - "The last message was not a user message. Cannot call " - "`stream_chat_message_objects` with `is_regenerate=True` " - "when the last message is not a user message." - ) + if existing_assistant_message_id is None: + if final_msg.message_type != MessageType.USER: + raise RuntimeError( + "The last message was not a user message. Cannot call " + "`stream_chat_message_objects` with `is_regenerate=True` " + "when the last message is not a user message." + ) + else: + if final_msg.id != existing_assistant_message_id: + raise RuntimeError( + "The last message was not the existing assistant message. " + f"Final message id: {final_msg.id}, " + f"existing assistant message id: {existing_assistant_message_id}" + ) # Disable Query Rephrasing for the first message # This leads to a better first response since the LLM rephrasing the question @@ -504,13 +500,19 @@ def stream_chat_message_objects( ), max_window_percentage=max_document_percentage, ) - reserved_message_id = reserve_message_id( - db_session=db_session, - chat_session_id=chat_session_id, - parent_message=user_message.id - if user_message is not None - else parent_message.id, - message_type=MessageType.ASSISTANT, + + # we don't need to reserve a message id if we're using an existing assistant message + reserved_message_id = ( + final_msg.id + if existing_assistant_message_id is not None + else reserve_message_id( + db_session=db_session, + chat_session_id=chat_session_id, + parent_message=user_message.id + if user_message is not None + else parent_message.id, + message_type=MessageType.ASSISTANT, + ) ) yield MessageResponseIDInfo( user_message_id=user_message.id if user_message else None, @@ -525,7 +527,13 @@ def stream_chat_message_objects( partial_response = partial( create_new_chat_message, chat_session_id=chat_session_id, - parent_message=final_msg, + # if we're using an existing assistant message, then this will just be an + # update operation, in which case the parent should be the parent of + # the latest. If we're creating a new assistant message, then the parent + # should be the latest message (latest user message) + parent_message=( + final_msg if existing_assistant_message_id is None else parent_message + ), prompt_id=prompt_id, overridden_model=overridden_model, # message=, @@ -537,6 +545,7 @@ def stream_chat_message_objects( # reference_docs=, db_session=db_session, commit=False, + reserved_message_id=reserved_message_id, ) if not final_msg.prompt: @@ -560,142 +569,39 @@ def stream_chat_message_objects( structured_response_format=new_msg_req.structured_response_format, ) - # find out what tools to use - search_tool: SearchTool | None = None - tool_dict: dict[int, list[Tool]] = {} # tool_id to tool - for db_tool_model in persona.tools: - # handle in-code tools specially - if db_tool_model.in_code_tool_id: - tool_cls = get_built_in_tool_by_id(db_tool_model.id, db_session) - if tool_cls.__name__ == SearchTool.__name__ and not latest_query_files: - search_tool = SearchTool( - db_session=db_session, - user=user, - persona=persona, - retrieval_options=retrieval_options, - prompt_config=prompt_config, - llm=llm, - fast_llm=fast_llm, - pruning_config=document_pruning_config, - answer_style_config=answer_style_config, - selected_sections=selected_sections, - chunks_above=new_msg_req.chunks_above, - chunks_below=new_msg_req.chunks_below, - full_doc=new_msg_req.full_doc, - evaluation_type=( - LLMEvaluationType.BASIC - if persona.llm_relevance_filter - else LLMEvaluationType.SKIP - ), - ) - tool_dict[db_tool_model.id] = [search_tool] - elif tool_cls.__name__ == ImageGenerationTool.__name__: - img_generation_llm_config: LLMConfig | None = None - if ( - llm - and llm.config.api_key - and llm.config.model_provider == "openai" - ): - img_generation_llm_config = LLMConfig( - model_provider=llm.config.model_provider, - model_name="dall-e-3", - temperature=GEN_AI_TEMPERATURE, - api_key=llm.config.api_key, - api_base=llm.config.api_base, - api_version=llm.config.api_version, - ) - elif ( - llm.config.model_provider == "azure" - and AZURE_DALLE_API_KEY is not None - ): - img_generation_llm_config = LLMConfig( - model_provider="azure", - model_name=f"azure/{AZURE_DALLE_DEPLOYMENT_NAME}", - temperature=GEN_AI_TEMPERATURE, - api_key=AZURE_DALLE_API_KEY, - api_base=AZURE_DALLE_API_BASE, - api_version=AZURE_DALLE_API_VERSION, - ) - else: - llm_providers = fetch_existing_llm_providers(db_session) - openai_provider = next( - iter( - [ - llm_provider - for llm_provider in llm_providers - if llm_provider.provider == "openai" - ] - ), - None, - ) - if not openai_provider or not openai_provider.api_key: - raise ValueError( - "Image generation tool requires an OpenAI API key" - ) - img_generation_llm_config = LLMConfig( - model_provider=openai_provider.provider, - model_name="dall-e-3", - temperature=GEN_AI_TEMPERATURE, - api_key=openai_provider.api_key, - api_base=openai_provider.api_base, - api_version=openai_provider.api_version, - ) - tool_dict[db_tool_model.id] = [ - ImageGenerationTool( - api_key=cast(str, img_generation_llm_config.api_key), - api_base=img_generation_llm_config.api_base, - api_version=img_generation_llm_config.api_version, - additional_headers=litellm_additional_headers, - model=img_generation_llm_config.model_name, - ) - ] - elif tool_cls.__name__ == InternetSearchTool.__name__: - bing_api_key = BING_API_KEY - if not bing_api_key: - raise ValueError( - "Internet search tool requires a Bing API key, please contact your Danswer admin to get it added!" - ) - tool_dict[db_tool_model.id] = [ - InternetSearchTool( - api_key=bing_api_key, - answer_style_config=answer_style_config, - prompt_config=prompt_config, - ) - ] - - continue - - # handle all custom tools - if db_tool_model.openapi_schema: - tool_dict[db_tool_model.id] = cast( - list[Tool], - build_custom_tools_from_openapi_schema_and_headers( - db_tool_model.openapi_schema, - dynamic_schema_info=DynamicSchemaInfo( - chat_session_id=chat_session_id, - message_id=user_message.id if user_message else None, - ), - custom_headers=(db_tool_model.custom_headers or []) - + ( - header_dict_to_header_list( - custom_tool_additional_headers or {} - ) - ), - ), - ) - + tool_dict = construct_tools( + persona=persona, + prompt_config=prompt_config, + db_session=db_session, + user=user, + llm=llm, + fast_llm=fast_llm, + search_tool_config=SearchToolConfig( + answer_style_config=answer_style_config, + document_pruning_config=document_pruning_config, + retrieval_options=retrieval_options or RetrievalDetails(), + selected_sections=selected_sections, + chunks_above=new_msg_req.chunks_above, + chunks_below=new_msg_req.chunks_below, + full_doc=new_msg_req.full_doc, + latest_query_files=latest_query_files, + ), + internet_search_tool_config=InternetSearchToolConfig( + answer_style_config=answer_style_config, + ), + image_generation_tool_config=ImageGenerationToolConfig( + additional_headers=litellm_additional_headers, + ), + custom_tool_config=CustomToolConfig( + chat_session_id=chat_session_id, + message_id=user_message.id if user_message else None, + additional_headers=custom_tool_additional_headers, + ), + ) tools: list[Tool] = [] for tool_list in tool_dict.values(): tools.extend(tool_list) - # factor in tool definition size when pruning - document_pruning_config.tool_num_tokens = compute_all_tool_tokens( - tools, llm_tokenizer - ) - document_pruning_config.using_tool_message = explicit_tool_calling_supported( - llm_provider, llm_model_name - ) - # LLM prompt building, response capturing, etc. answer = Answer( is_connected=is_connected, @@ -871,7 +777,6 @@ def stream_chat_message_objects( tool_name_to_tool_id[tool.name] = tool_id gen_ai_response_message = partial_response( - reserved_message_id=reserved_message_id, message=answer.llm_answer, rephrased_query=( qa_docs_response.rephrased_query if qa_docs_response else None @@ -879,9 +784,11 @@ def stream_chat_message_objects( reference_docs=reference_db_search_docs, files=ai_message_files, token_count=len(llm_tokenizer_encode_func(answer.llm_answer)), - citations=message_specific_citations.citation_map - if message_specific_citations - else None, + citations=( + message_specific_citations.citation_map + if message_specific_citations + else None + ), error=None, tool_call=( ToolCall( @@ -915,7 +822,6 @@ def stream_chat_message_objects( def stream_chat_message( new_msg_req: CreateChatMessageRequest, user: User | None, - use_existing_user_message: bool = False, litellm_additional_headers: dict[str, str] | None = None, custom_tool_additional_headers: dict[str, str] | None = None, is_connected: Callable[[], bool] | None = None, @@ -925,7 +831,6 @@ def stream_chat_message( new_msg_req=new_msg_req, user=user, db_session=db_session, - use_existing_user_message=use_existing_user_message, litellm_additional_headers=litellm_additional_headers, custom_tool_additional_headers=custom_tool_additional_headers, is_connected=is_connected, diff --git a/backend/danswer/db/tools.py b/backend/danswer/db/tools.py index 0fd126d0065..a89dafef385 100644 --- a/backend/danswer/db/tools.py +++ b/backend/danswer/db/tools.py @@ -24,6 +24,13 @@ def get_tool_by_id(tool_id: int, db_session: Session) -> Tool: return tool +def get_tool_by_name(tool_name: str, db_session: Session) -> Tool: + tool = db_session.scalar(select(Tool).where(Tool.name == tool_name)) + if not tool: + raise ValueError("Tool by specified name does not exist") + return tool + + def create_tool( name: str, description: str | None, @@ -37,7 +44,7 @@ def create_tool( description=description, in_code_tool_id=None, openapi_schema=openapi_schema, - custom_headers=[header.dict() for header in custom_headers] + custom_headers=[header.model_dump() for header in custom_headers] if custom_headers else [], user_id=user_id, diff --git a/backend/danswer/main.py b/backend/danswer/main.py index 2ba3615dcca..0aff801c8f4 100644 --- a/backend/danswer/main.py +++ b/backend/danswer/main.py @@ -74,6 +74,9 @@ from danswer.server.manage.slack_bot import router as slack_bot_management_router from danswer.server.manage.users import router as user_router from danswer.server.middleware.latency_logging import add_latency_logging_middleware +from danswer.server.openai_assistants_api.full_openai_assistants_api import ( + get_full_openai_assistants_api_router, +) from danswer.server.query_and_chat.chat_backend import router as chat_router from danswer.server.query_and_chat.query_backend import ( admin_router as admin_query_router, @@ -270,6 +273,9 @@ def get_application() -> FastAPI: application, token_rate_limit_settings_router ) include_router_with_global_prefix_prepended(application, indexing_router) + include_router_with_global_prefix_prepended( + application, get_full_openai_assistants_api_router() + ) if AUTH_TYPE == AuthType.DISABLED: # Server logs this during auth setup verification step diff --git a/backend/danswer/server/openai_assistants_api/asssistants_api.py b/backend/danswer/server/openai_assistants_api/asssistants_api.py new file mode 100644 index 00000000000..66a04e8969c --- /dev/null +++ b/backend/danswer/server/openai_assistants_api/asssistants_api.py @@ -0,0 +1,273 @@ +from typing import Any +from typing import Optional +from uuid import uuid4 + +from fastapi import APIRouter +from fastapi import Depends +from fastapi import HTTPException +from fastapi import Query +from pydantic import BaseModel +from sqlalchemy.orm import Session + +from danswer.auth.users import current_user +from danswer.db.engine import get_session +from danswer.db.models import Persona +from danswer.db.models import User +from danswer.db.persona import get_persona_by_id +from danswer.db.persona import get_personas +from danswer.db.persona import mark_persona_as_deleted +from danswer.db.persona import upsert_persona +from danswer.db.persona import upsert_prompt +from danswer.db.tools import get_tool_by_name +from danswer.search.enums import RecencyBiasSetting +from danswer.utils.logger import setup_logger + +logger = setup_logger() + + +router = APIRouter(prefix="/assistants") + + +# Base models +class AssistantObject(BaseModel): + id: int + object: str = "assistant" + created_at: int + name: Optional[str] = None + description: Optional[str] = None + model: str + instructions: Optional[str] = None + tools: list[dict[str, Any]] + file_ids: list[str] + metadata: Optional[dict[str, Any]] = None + + +class CreateAssistantRequest(BaseModel): + model: str + name: Optional[str] = None + description: Optional[str] = None + instructions: Optional[str] = None + tools: Optional[list[dict[str, Any]]] = None + file_ids: Optional[list[str]] = None + metadata: Optional[dict[str, Any]] = None + + +class ModifyAssistantRequest(BaseModel): + model: Optional[str] = None + name: Optional[str] = None + description: Optional[str] = None + instructions: Optional[str] = None + tools: Optional[list[dict[str, Any]]] = None + file_ids: Optional[list[str]] = None + metadata: Optional[dict[str, Any]] = None + + +class DeleteAssistantResponse(BaseModel): + id: int + object: str = "assistant.deleted" + deleted: bool + + +class ListAssistantsResponse(BaseModel): + object: str = "list" + data: list[AssistantObject] + first_id: Optional[int] = None + last_id: Optional[int] = None + has_more: bool + + +def persona_to_assistant(persona: Persona) -> AssistantObject: + return AssistantObject( + id=persona.id, + created_at=0, + name=persona.name, + description=persona.description, + model=persona.llm_model_version_override or "gpt-3.5-turbo", + instructions=persona.prompts[0].system_prompt if persona.prompts else None, + tools=[ + { + "type": tool.display_name, + "function": { + "name": tool.name, + "description": tool.description, + "schema": tool.openapi_schema, + }, + } + for tool in persona.tools + ], + file_ids=[], # Assuming no file support for now + metadata={}, # Assuming no metadata for now + ) + + +# API endpoints +@router.post("") +def create_assistant( + request: CreateAssistantRequest, + user: User | None = Depends(current_user), + db_session: Session = Depends(get_session), +) -> AssistantObject: + prompt = None + if request.instructions: + prompt = upsert_prompt( + user=user, + name=f"Prompt for {request.name or 'New Assistant'}", + description="Auto-generated prompt", + system_prompt=request.instructions, + task_prompt="", + include_citations=True, + datetime_aware=True, + personas=[], + db_session=db_session, + ) + + tool_ids = [] + for tool in request.tools or []: + tool_type = tool.get("type") + if not tool_type: + continue + + try: + tool_db = get_tool_by_name(tool_type, db_session) + tool_ids.append(tool_db.id) + except ValueError: + # Skip tools that don't exist in the database + logger.error(f"Tool {tool_type} not found in database") + raise HTTPException( + status_code=404, detail=f"Tool {tool_type} not found in database" + ) + + persona = upsert_persona( + user=user, + name=request.name or f"Assistant-{uuid4()}", + description=request.description or "", + num_chunks=25, + llm_relevance_filter=True, + llm_filter_extraction=True, + recency_bias=RecencyBiasSetting.AUTO, + llm_model_provider_override=None, + llm_model_version_override=request.model, + starter_messages=None, + is_public=False, + db_session=db_session, + prompt_ids=[prompt.id] if prompt else [0], + document_set_ids=[], + tool_ids=tool_ids, + icon_color=None, + icon_shape=None, + is_visible=True, + ) + + if prompt: + prompt.personas = [persona] + db_session.commit() + + return persona_to_assistant(persona) + + +"" + + +@router.get("/{assistant_id}") +def retrieve_assistant( + assistant_id: int, + user: User | None = Depends(current_user), + db_session: Session = Depends(get_session), +) -> AssistantObject: + try: + persona = get_persona_by_id( + persona_id=assistant_id, + user=user, + db_session=db_session, + is_for_edit=False, + ) + except ValueError: + persona = None + + if not persona: + raise HTTPException(status_code=404, detail="Assistant not found") + return persona_to_assistant(persona) + + +@router.post("/{assistant_id}") +def modify_assistant( + assistant_id: int, + request: ModifyAssistantRequest, + user: User | None = Depends(current_user), + db_session: Session = Depends(get_session), +) -> AssistantObject: + persona = get_persona_by_id( + persona_id=assistant_id, + user=user, + db_session=db_session, + is_for_edit=True, + ) + if not persona: + raise HTTPException(status_code=404, detail="Assistant not found") + + update_data = request.model_dump(exclude_unset=True) + for key, value in update_data.items(): + setattr(persona, key, value) + + if "instructions" in update_data and persona.prompts: + persona.prompts[0].system_prompt = update_data["instructions"] + + db_session.commit() + return persona_to_assistant(persona) + + +@router.delete("/{assistant_id}") +def delete_assistant( + assistant_id: int, + user: User | None = Depends(current_user), + db_session: Session = Depends(get_session), +) -> DeleteAssistantResponse: + try: + mark_persona_as_deleted( + persona_id=int(assistant_id), + user=user, + db_session=db_session, + ) + return DeleteAssistantResponse(id=assistant_id, deleted=True) + except ValueError: + raise HTTPException(status_code=404, detail="Assistant not found") + + +@router.get("") +def list_assistants( + limit: int = Query(20, le=100), + order: str = Query("desc", regex="^(asc|desc)$"), + after: Optional[int] = None, + before: Optional[int] = None, + user: User | None = Depends(current_user), + db_session: Session = Depends(get_session), +) -> ListAssistantsResponse: + personas = list( + get_personas( + user=user, + db_session=db_session, + get_editable=False, + joinedload_all=True, + ) + ) + + # Apply filtering based on after and before + if after: + personas = [p for p in personas if p.id > int(after)] + if before: + personas = [p for p in personas if p.id < int(before)] + + # Apply ordering + personas.sort(key=lambda p: p.id, reverse=(order == "desc")) + + # Apply limit + personas = personas[:limit] + + assistants = [persona_to_assistant(p) for p in personas] + + return ListAssistantsResponse( + data=assistants, + first_id=assistants[0].id if assistants else None, + last_id=assistants[-1].id if assistants else None, + has_more=len(personas) == limit, + ) diff --git a/backend/danswer/server/openai_assistants_api/full_openai_assistants_api.py b/backend/danswer/server/openai_assistants_api/full_openai_assistants_api.py new file mode 100644 index 00000000000..2b2fe93e96e --- /dev/null +++ b/backend/danswer/server/openai_assistants_api/full_openai_assistants_api.py @@ -0,0 +1,19 @@ +from fastapi import APIRouter + +from danswer.server.openai_assistants_api.asssistants_api import ( + router as assistants_router, +) +from danswer.server.openai_assistants_api.messages_api import router as messages_router +from danswer.server.openai_assistants_api.runs_api import router as runs_router +from danswer.server.openai_assistants_api.threads_api import router as threads_router + + +def get_full_openai_assistants_api_router() -> APIRouter: + router = APIRouter(prefix="/openai-assistants") + + router.include_router(assistants_router) + router.include_router(runs_router) + router.include_router(threads_router) + router.include_router(messages_router) + + return router diff --git a/backend/danswer/server/openai_assistants_api/messages_api.py b/backend/danswer/server/openai_assistants_api/messages_api.py new file mode 100644 index 00000000000..c28c349f277 --- /dev/null +++ b/backend/danswer/server/openai_assistants_api/messages_api.py @@ -0,0 +1,235 @@ +import uuid +from datetime import datetime +from typing import Any +from typing import Literal +from typing import Optional + +from fastapi import APIRouter +from fastapi import Depends +from fastapi import HTTPException +from pydantic import BaseModel +from pydantic import Field +from sqlalchemy.orm import Session + +from danswer.auth.users import current_user +from danswer.configs.constants import MessageType +from danswer.db.chat import create_new_chat_message +from danswer.db.chat import get_chat_message +from danswer.db.chat import get_chat_messages_by_session +from danswer.db.chat import get_chat_session_by_id +from danswer.db.chat import get_or_create_root_message +from danswer.db.engine import get_session +from danswer.db.models import User +from danswer.llm.utils import check_number_of_tokens + +router = APIRouter(prefix="") + + +Role = Literal["user", "assistant"] + + +class MessageContent(BaseModel): + type: Literal["text"] + text: str + + +class Message(BaseModel): + id: str = Field(default_factory=lambda: f"msg_{uuid.uuid4()}") + object: Literal["thread.message"] = "thread.message" + created_at: int = Field(default_factory=lambda: int(datetime.now().timestamp())) + thread_id: str + role: Role + content: list[MessageContent] + file_ids: list[str] = [] + assistant_id: Optional[str] = None + run_id: Optional[str] = None + metadata: Optional[dict[str, Any]] = None # Change this line to use dict[str, Any] + + +class CreateMessageRequest(BaseModel): + role: Role + content: str + file_ids: list[str] = [] + metadata: Optional[dict] = None + + +class ListMessagesResponse(BaseModel): + object: Literal["list"] = "list" + data: list[Message] + first_id: str + last_id: str + has_more: bool + + +@router.post("/threads/{thread_id}/messages") +def create_message( + thread_id: str, + message: CreateMessageRequest, + user: User | None = Depends(current_user), + db_session: Session = Depends(get_session), +) -> Message: + user_id = user.id if user else None + + try: + chat_session = get_chat_session_by_id( + chat_session_id=uuid.UUID(thread_id), + user_id=user_id, + db_session=db_session, + ) + except ValueError: + raise HTTPException(status_code=404, detail="Chat session not found") + + chat_messages = get_chat_messages_by_session( + chat_session_id=chat_session.id, + user_id=user.id if user else None, + db_session=db_session, + ) + latest_message = ( + chat_messages[-1] + if chat_messages + else get_or_create_root_message(chat_session.id, db_session) + ) + + new_message = create_new_chat_message( + chat_session_id=chat_session.id, + parent_message=latest_message, + message=message.content, + prompt_id=chat_session.persona.prompts[0].id, + token_count=check_number_of_tokens(message.content), + message_type=( + MessageType.USER if message.role == "user" else MessageType.ASSISTANT + ), + db_session=db_session, + ) + + return Message( + id=str(new_message.id), + thread_id=thread_id, + role="user", + content=[MessageContent(type="text", text=message.content)], + file_ids=message.file_ids, + metadata=message.metadata, + ) + + +@router.get("/threads/{thread_id}/messages") +def list_messages( + thread_id: str, + limit: int = 20, + order: Literal["asc", "desc"] = "desc", + after: Optional[str] = None, + before: Optional[str] = None, + user: User | None = Depends(current_user), + db_session: Session = Depends(get_session), +) -> ListMessagesResponse: + user_id = user.id if user else None + + try: + chat_session = get_chat_session_by_id( + chat_session_id=uuid.UUID(thread_id), + user_id=user_id, + db_session=db_session, + ) + except ValueError: + raise HTTPException(status_code=404, detail="Chat session not found") + + messages = get_chat_messages_by_session( + chat_session_id=chat_session.id, + user_id=user_id, + db_session=db_session, + ) + + # Apply filtering based on after and before + if after: + messages = [m for m in messages if str(m.id) >= after] + if before: + messages = [m for m in messages if str(m.id) <= before] + + # Apply ordering + messages = sorted(messages, key=lambda m: m.id, reverse=(order == "desc")) + + # Apply limit + messages = messages[:limit] + + data = [ + Message( + id=str(m.id), + thread_id=thread_id, + role="user" if m.message_type == "user" else "assistant", + content=[MessageContent(type="text", text=m.message)], + created_at=int(m.time_sent.timestamp()), + ) + for m in messages + ] + + return ListMessagesResponse( + data=data, + first_id=str(data[0].id) if data else "", + last_id=str(data[-1].id) if data else "", + has_more=len(messages) == limit, + ) + + +@router.get("/threads/{thread_id}/messages/{message_id}") +def retrieve_message( + thread_id: str, + message_id: int, + user: User | None = Depends(current_user), + db_session: Session = Depends(get_session), +) -> Message: + user_id = user.id if user else None + + try: + chat_message = get_chat_message( + chat_message_id=message_id, + user_id=user_id, + db_session=db_session, + ) + except ValueError: + raise HTTPException(status_code=404, detail="Message not found") + + return Message( + id=str(chat_message.id), + thread_id=thread_id, + role="user" if chat_message.message_type == "user" else "assistant", + content=[MessageContent(type="text", text=chat_message.message)], + created_at=int(chat_message.time_sent.timestamp()), + ) + + +class ModifyMessageRequest(BaseModel): + metadata: dict + + +@router.post("/threads/{thread_id}/messages/{message_id}") +def modify_message( + thread_id: str, + message_id: int, + request: ModifyMessageRequest, + user: User | None = Depends(current_user), + db_session: Session = Depends(get_session), +) -> Message: + user_id = user.id if user else None + + try: + chat_message = get_chat_message( + chat_message_id=message_id, + user_id=user_id, + db_session=db_session, + ) + except ValueError: + raise HTTPException(status_code=404, detail="Message not found") + + # Update metadata + # TODO: Uncomment this once we have metadata in the chat message + # chat_message.metadata = request.metadata + # db_session.commit() + + return Message( + id=str(chat_message.id), + thread_id=thread_id, + role="user" if chat_message.message_type == "user" else "assistant", + content=[MessageContent(type="text", text=chat_message.message)], + created_at=int(chat_message.time_sent.timestamp()), + metadata=request.metadata, + ) diff --git a/backend/danswer/server/openai_assistants_api/runs_api.py b/backend/danswer/server/openai_assistants_api/runs_api.py new file mode 100644 index 00000000000..616afcad168 --- /dev/null +++ b/backend/danswer/server/openai_assistants_api/runs_api.py @@ -0,0 +1,344 @@ +from typing import Literal +from typing import Optional +from uuid import UUID + +from fastapi import APIRouter +from fastapi import BackgroundTasks +from fastapi import Depends +from fastapi import HTTPException +from pydantic import BaseModel +from sqlalchemy.orm import Session + +from danswer.auth.users import current_user +from danswer.chat.process_message import stream_chat_message_objects +from danswer.configs.constants import MessageType +from danswer.db.chat import create_new_chat_message +from danswer.db.chat import get_chat_message +from danswer.db.chat import get_chat_messages_by_session +from danswer.db.chat import get_chat_session_by_id +from danswer.db.chat import get_or_create_root_message +from danswer.db.engine import get_session +from danswer.db.models import ChatMessage +from danswer.db.models import User +from danswer.search.models import RetrievalDetails +from danswer.server.query_and_chat.models import ChatMessageDetail +from danswer.server.query_and_chat.models import CreateChatMessageRequest +from danswer.tools.tool_implementations.search.search_tool import SearchTool +from danswer.utils.logger import setup_logger + + +logger = setup_logger() + + +router = APIRouter() + + +class RunRequest(BaseModel): + assistant_id: int + model: Optional[str] = None + instructions: Optional[str] = None + additional_instructions: Optional[str] = None + tools: Optional[list[dict]] = None + metadata: Optional[dict] = None + + +RunStatus = Literal[ + "queued", + "in_progress", + "requires_action", + "cancelling", + "cancelled", + "failed", + "completed", + "expired", +] + + +class RunResponse(BaseModel): + id: str + object: Literal["thread.run"] + created_at: int + assistant_id: int + thread_id: UUID + status: RunStatus + started_at: Optional[int] = None + expires_at: Optional[int] = None + cancelled_at: Optional[int] = None + failed_at: Optional[int] = None + completed_at: Optional[int] = None + last_error: Optional[dict] = None + model: str + instructions: str + tools: list[dict] + file_ids: list[str] + metadata: Optional[dict] = None + + +def process_run_in_background( + message_id: int, + parent_message_id: int, + chat_session_id: UUID, + assistant_id: int, + instructions: str, + tools: list[dict], + user: User | None, + db_session: Session, +) -> None: + # Get the latest message in the chat session + chat_session = get_chat_session_by_id( + chat_session_id=chat_session_id, + user_id=user.id if user else None, + db_session=db_session, + ) + + search_tool_retrieval_details = RetrievalDetails() + for tool in tools: + if tool["type"] == SearchTool.__name__ and ( + retrieval_details := tool.get("retrieval_details") + ): + search_tool_retrieval_details = RetrievalDetails.model_validate( + retrieval_details + ) + break + + new_msg_req = CreateChatMessageRequest( + chat_session_id=chat_session_id, + parent_message_id=int(parent_message_id) if parent_message_id else None, + message=instructions, + file_descriptors=[], + prompt_id=chat_session.persona.prompts[0].id, + search_doc_ids=None, + retrieval_options=search_tool_retrieval_details, # Adjust as needed + query_override=None, + regenerate=None, + llm_override=None, + prompt_override=None, + alternate_assistant_id=assistant_id, + use_existing_user_message=True, + existing_assistant_message_id=message_id, + ) + + run_message = get_chat_message(message_id, user.id if user else None, db_session) + try: + for packet in stream_chat_message_objects( + new_msg_req=new_msg_req, + user=user, + db_session=db_session, + ): + if isinstance(packet, ChatMessageDetail): + # Update the run status and message content + run_message = get_chat_message( + message_id, user.id if user else None, db_session + ) + if run_message: + # this handles cancelling + if run_message.error: + return + + run_message.message = packet.message + run_message.message_type = MessageType.ASSISTANT + db_session.commit() + except Exception as e: + logger.exception("Error processing run in background") + run_message.error = str(e) + db_session.commit() + return + + db_session.refresh(run_message) + if run_message.token_count == 0: + run_message.error = "No tokens generated" + db_session.commit() + + +@router.post("/threads/{thread_id}/runs") +def create_run( + thread_id: UUID, + run_request: RunRequest, + background_tasks: BackgroundTasks, + user: User | None = Depends(current_user), + db_session: Session = Depends(get_session), +) -> RunResponse: + try: + chat_session = get_chat_session_by_id( + chat_session_id=thread_id, + user_id=user.id if user else None, + db_session=db_session, + ) + except ValueError: + raise HTTPException(status_code=404, detail="Thread not found") + + chat_messages = get_chat_messages_by_session( + chat_session_id=chat_session.id, + user_id=user.id if user else None, + db_session=db_session, + ) + latest_message = ( + chat_messages[-1] + if chat_messages + else get_or_create_root_message(chat_session.id, db_session) + ) + + # Create a new "run" (chat message) in the session + new_message = create_new_chat_message( + chat_session_id=chat_session.id, + parent_message=latest_message, + message="", + prompt_id=chat_session.persona.prompts[0].id, + token_count=0, + message_type=MessageType.ASSISTANT, + db_session=db_session, + commit=False, + ) + db_session.flush() + latest_message.latest_child_message = new_message.id + db_session.commit() + + # Schedule the background task + background_tasks.add_task( + process_run_in_background, + new_message.id, + latest_message.id, + chat_session.id, + run_request.assistant_id, + run_request.instructions or "", + run_request.tools or [], + user, + db_session, + ) + + return RunResponse( + id=str(new_message.id), + object="thread.run", + created_at=int(new_message.time_sent.timestamp()), + assistant_id=run_request.assistant_id, + thread_id=chat_session.id, + status="queued", + model=run_request.model or "default_model", + instructions=run_request.instructions or "", + tools=run_request.tools or [], + file_ids=[], + metadata=run_request.metadata, + ) + + +@router.get("/threads/{thread_id}/runs/{run_id}") +def retrieve_run( + thread_id: UUID, + run_id: str, + user: User | None = Depends(current_user), + db_session: Session = Depends(get_session), +) -> RunResponse: + # Retrieve the chat message (which represents a "run" in DAnswer) + chat_message = get_chat_message( + chat_message_id=int(run_id), # Convert string run_id to int + user_id=user.id if user else None, + db_session=db_session, + ) + if not chat_message: + raise HTTPException(status_code=404, detail="Run not found") + + chat_session = chat_message.chat_session + + # Map DAnswer status to OpenAI status + run_status: RunStatus = "queued" + if chat_message.message: + run_status = "in_progress" + if chat_message.token_count != 0: + run_status = "completed" + if chat_message.error: + run_status = "cancelled" + + return RunResponse( + id=run_id, + object="thread.run", + created_at=int(chat_message.time_sent.timestamp()), + assistant_id=chat_session.persona_id or 0, + thread_id=chat_session.id, + status=run_status, + started_at=int(chat_message.time_sent.timestamp()), + completed_at=( + int(chat_message.time_sent.timestamp()) if chat_message.message else None + ), + model=chat_session.current_alternate_model or "default_model", + instructions="", # DAnswer doesn't store per-message instructions + tools=[], # DAnswer doesn't have a direct equivalent for tools + file_ids=( + [file["id"] for file in chat_message.files] if chat_message.files else [] + ), + metadata=None, # DAnswer doesn't store metadata for individual messages + ) + + +@router.post("/threads/{thread_id}/runs/{run_id}/cancel") +def cancel_run( + thread_id: UUID, + run_id: str, + user: User | None = Depends(current_user), + db_session: Session = Depends(get_session), +) -> RunResponse: + # In DAnswer, we don't have a direct equivalent to cancelling a run + # We'll simulate it by marking the message as "cancelled" + chat_message = ( + db_session.query(ChatMessage).filter(ChatMessage.id == run_id).first() + ) + if not chat_message: + raise HTTPException(status_code=404, detail="Run not found") + + chat_message.error = "Cancelled" + db_session.commit() + + return retrieve_run(thread_id, run_id, user, db_session) + + +@router.get("/threads/{thread_id}/runs") +def list_runs( + thread_id: UUID, + limit: int = 20, + order: Literal["asc", "desc"] = "desc", + after: Optional[str] = None, + before: Optional[str] = None, + user: User | None = Depends(current_user), + db_session: Session = Depends(get_session), +) -> list[RunResponse]: + # In DAnswer, we'll treat each message in a chat session as a "run" + chat_messages = get_chat_messages_by_session( + chat_session_id=thread_id, + user_id=user.id if user else None, + db_session=db_session, + ) + + # Apply pagination + if after: + chat_messages = [msg for msg in chat_messages if str(msg.id) > after] + if before: + chat_messages = [msg for msg in chat_messages if str(msg.id) < before] + + # Apply ordering + chat_messages = sorted( + chat_messages, key=lambda msg: msg.time_sent, reverse=(order == "desc") + ) + + # Apply limit + chat_messages = chat_messages[:limit] + + return [ + retrieve_run(thread_id, str(msg.id), user, db_session) for msg in chat_messages + ] + + +@router.get("/threads/{thread_id}/runs/{run_id}/steps") +def list_run_steps( + run_id: str, + limit: int = 20, + order: Literal["asc", "desc"] = "desc", + after: Optional[str] = None, + before: Optional[str] = None, + user: User | None = Depends(current_user), + db_session: Session = Depends(get_session), +) -> list[dict]: # You may want to create a specific model for run steps + # DAnswer doesn't have an equivalent to run steps + # We'll return an empty list to maintain API compatibility + return [] + + +# Additional helper functions can be added here if needed diff --git a/backend/danswer/server/openai_assistants_api/threads_api.py b/backend/danswer/server/openai_assistants_api/threads_api.py new file mode 100644 index 00000000000..ffc3a3016dc --- /dev/null +++ b/backend/danswer/server/openai_assistants_api/threads_api.py @@ -0,0 +1,156 @@ +from typing import Optional +from uuid import UUID + +from fastapi import APIRouter +from fastapi import Depends +from fastapi import HTTPException +from pydantic import BaseModel +from sqlalchemy.orm import Session + +from danswer.auth.users import current_user +from danswer.db.chat import create_chat_session +from danswer.db.chat import delete_chat_session +from danswer.db.chat import get_chat_session_by_id +from danswer.db.chat import get_chat_sessions_by_user +from danswer.db.chat import update_chat_session +from danswer.db.engine import get_session +from danswer.db.models import User +from danswer.server.query_and_chat.models import ChatSessionDetails +from danswer.server.query_and_chat.models import ChatSessionsResponse + +router = APIRouter(prefix="/threads") + + +# Models +class Thread(BaseModel): + id: UUID + object: str = "thread" + created_at: int + metadata: Optional[dict[str, str]] = None + + +class CreateThreadRequest(BaseModel): + messages: Optional[list[dict]] = None + metadata: Optional[dict[str, str]] = None + + +class ModifyThreadRequest(BaseModel): + metadata: Optional[dict[str, str]] = None + + +# API Endpoints +@router.post("") +def create_thread( + request: CreateThreadRequest, + user: User | None = Depends(current_user), + db_session: Session = Depends(get_session), +) -> Thread: + user_id = user.id if user else None + new_chat_session = create_chat_session( + db_session=db_session, + description="", # Leave the naming till later to prevent delay + user_id=user_id, + persona_id=0, + ) + + return Thread( + id=new_chat_session.id, + created_at=int(new_chat_session.time_created.timestamp()), + metadata=request.metadata, + ) + + +@router.get("/{thread_id}") +def retrieve_thread( + thread_id: UUID, + user: User | None = Depends(current_user), + db_session: Session = Depends(get_session), +) -> Thread: + user_id = user.id if user else None + try: + chat_session = get_chat_session_by_id( + chat_session_id=thread_id, + user_id=user_id, + db_session=db_session, + ) + except ValueError: + raise HTTPException(status_code=404, detail="Thread not found") + + return Thread( + id=chat_session.id, + created_at=int(chat_session.time_created.timestamp()), + metadata=None, # Assuming we don't store metadata in our current implementation + ) + + +@router.post("/{thread_id}") +def modify_thread( + thread_id: UUID, + request: ModifyThreadRequest, + user: User | None = Depends(current_user), + db_session: Session = Depends(get_session), +) -> Thread: + user_id = user.id if user else None + try: + chat_session = update_chat_session( + db_session=db_session, + user_id=user_id, + chat_session_id=thread_id, + description=None, # Not updating description + sharing_status=None, # Not updating sharing status + ) + except ValueError: + raise HTTPException(status_code=404, detail="Thread not found") + + return Thread( + id=chat_session.id, + created_at=int(chat_session.time_created.timestamp()), + metadata=request.metadata, + ) + + +@router.delete("/{thread_id}") +def delete_thread( + thread_id: UUID, + user: User | None = Depends(current_user), + db_session: Session = Depends(get_session), +) -> dict: + user_id = user.id if user else None + try: + delete_chat_session( + user_id=user_id, + chat_session_id=thread_id, + db_session=db_session, + ) + except ValueError: + raise HTTPException(status_code=404, detail="Thread not found") + + return {"id": str(thread_id), "object": "thread.deleted", "deleted": True} + + +@router.get("") +def list_threads( + user: User | None = Depends(current_user), + db_session: Session = Depends(get_session), +) -> ChatSessionsResponse: + user_id = user.id if user else None + chat_sessions = get_chat_sessions_by_user( + user_id=user_id, + deleted=False, + db_session=db_session, + ) + + return ChatSessionsResponse( + sessions=[ + ChatSessionDetails( + id=chat.id, + name=chat.description, + persona_id=chat.persona_id, + time_created=chat.time_created.isoformat(), + shared_status=chat.shared_status, + folder_id=chat.folder_id, + current_alternate_model=chat.current_alternate_model, + ) + for chat in chat_sessions + ] + ) diff --git a/backend/danswer/server/query_and_chat/chat_backend.py b/backend/danswer/server/query_and_chat/chat_backend.py index 41176a0453f..64c667cfd16 100644 --- a/backend/danswer/server/query_and_chat/chat_backend.py +++ b/backend/danswer/server/query_and_chat/chat_backend.py @@ -347,7 +347,6 @@ def stream_generator() -> Generator[str, None, None]: for packet in stream_chat_message( new_msg_req=chat_message_req, user=user, - use_existing_user_message=chat_message_req.use_existing_user_message, litellm_additional_headers=extract_headers( request.headers, LITELLM_PASS_THROUGH_HEADERS ), diff --git a/backend/danswer/server/query_and_chat/models.py b/backend/danswer/server/query_and_chat/models.py index 13b3b1ec0a8..6e905a20708 100644 --- a/backend/danswer/server/query_and_chat/models.py +++ b/backend/danswer/server/query_and_chat/models.py @@ -108,6 +108,9 @@ class CreateChatMessageRequest(ChunkContext): # used for seeded chats to kick off the generation of an AI answer use_existing_user_message: bool = False + # used for "OpenAI Assistants API" + existing_assistant_message_id: int | None = None + # forces the LLM to return a structured response, see # https://platform.openai.com/docs/guides/structured-outputs/introduction structured_response_format: dict | None = None diff --git a/backend/danswer/tools/tool_constructor.py b/backend/danswer/tools/tool_constructor.py new file mode 100644 index 00000000000..dacbe5ad112 --- /dev/null +++ b/backend/danswer/tools/tool_constructor.py @@ -0,0 +1,255 @@ +from typing import cast +from uuid import UUID + +from pydantic import BaseModel +from pydantic import Field +from sqlalchemy.orm import Session + +from danswer.configs.app_configs import AZURE_DALLE_API_BASE +from danswer.configs.app_configs import AZURE_DALLE_API_KEY +from danswer.configs.app_configs import AZURE_DALLE_API_VERSION +from danswer.configs.app_configs import AZURE_DALLE_DEPLOYMENT_NAME +from danswer.configs.chat_configs import BING_API_KEY +from danswer.configs.model_configs import GEN_AI_TEMPERATURE +from danswer.db.llm import fetch_existing_llm_providers +from danswer.db.models import Persona +from danswer.db.models import User +from danswer.file_store.models import InMemoryChatFile +from danswer.llm.answering.models import AnswerStyleConfig +from danswer.llm.answering.models import CitationConfig +from danswer.llm.answering.models import DocumentPruningConfig +from danswer.llm.answering.models import PromptConfig +from danswer.llm.interfaces import LLM +from danswer.llm.interfaces import LLMConfig +from danswer.natural_language_processing.utils import get_tokenizer +from danswer.search.enums import LLMEvaluationType +from danswer.search.models import InferenceSection +from danswer.search.models import RetrievalDetails +from danswer.tools.built_in_tools import get_built_in_tool_by_id +from danswer.tools.models import DynamicSchemaInfo +from danswer.tools.tool import Tool +from danswer.tools.tool_implementations.custom.custom_tool import ( + build_custom_tools_from_openapi_schema_and_headers, +) +from danswer.tools.tool_implementations.images.image_generation_tool import ( + ImageGenerationTool, +) +from danswer.tools.tool_implementations.internet_search.internet_search_tool import ( + InternetSearchTool, +) +from danswer.tools.tool_implementations.search.search_tool import SearchTool +from danswer.tools.utils import compute_all_tool_tokens +from danswer.tools.utils import explicit_tool_calling_supported +from danswer.utils.headers import header_dict_to_header_list +from danswer.utils.logger import setup_logger + +logger = setup_logger() + + +def _get_image_generation_config(llm: LLM, db_session: Session) -> LLMConfig: + """Helper function to get image generation LLM config based on available providers""" + if llm and llm.config.api_key and llm.config.model_provider == "openai": + return LLMConfig( + model_provider=llm.config.model_provider, + model_name="dall-e-3", + temperature=GEN_AI_TEMPERATURE, + api_key=llm.config.api_key, + api_base=llm.config.api_base, + api_version=llm.config.api_version, + ) + + if llm.config.model_provider == "azure" and AZURE_DALLE_API_KEY is not None: + return LLMConfig( + model_provider="azure", + model_name=f"azure/{AZURE_DALLE_DEPLOYMENT_NAME}", + temperature=GEN_AI_TEMPERATURE, + api_key=AZURE_DALLE_API_KEY, + api_base=AZURE_DALLE_API_BASE, + api_version=AZURE_DALLE_API_VERSION, + ) + + # Fallback to checking for OpenAI provider in database + llm_providers = fetch_existing_llm_providers(db_session) + openai_provider = next( + iter( + [ + llm_provider + for llm_provider in llm_providers + if llm_provider.provider == "openai" + ] + ), + None, + ) + + if not openai_provider or not openai_provider.api_key: + raise ValueError("Image generation tool requires an OpenAI API key") + + return LLMConfig( + model_provider=openai_provider.provider, + model_name="dall-e-3", + temperature=GEN_AI_TEMPERATURE, + api_key=openai_provider.api_key, + api_base=openai_provider.api_base, + api_version=openai_provider.api_version, + ) + + +class SearchToolConfig(BaseModel): + answer_style_config: AnswerStyleConfig = Field( + default_factory=lambda: AnswerStyleConfig(citation_config=CitationConfig()) + ) + document_pruning_config: DocumentPruningConfig = Field( + default_factory=DocumentPruningConfig + ) + retrieval_options: RetrievalDetails = Field(default_factory=RetrievalDetails) + selected_sections: list[InferenceSection] | None = None + chunks_above: int = 0 + chunks_below: int = 0 + full_doc: bool = False + latest_query_files: list[InMemoryChatFile] | None = None + + +class InternetSearchToolConfig(BaseModel): + answer_style_config: AnswerStyleConfig = Field( + default_factory=lambda: AnswerStyleConfig( + citation_config=CitationConfig(all_docs_useful=True) + ) + ) + + +class ImageGenerationToolConfig(BaseModel): + additional_headers: dict[str, str] | None = None + + +class CustomToolConfig(BaseModel): + chat_session_id: UUID | None = None + message_id: int | None = None + additional_headers: dict[str, str] | None = None + + +def construct_tools( + persona: Persona, + prompt_config: PromptConfig, + db_session: Session, + user: User | None, + llm: LLM, + fast_llm: LLM, + search_tool_config: SearchToolConfig | None = None, + internet_search_tool_config: InternetSearchToolConfig | None = None, + image_generation_tool_config: ImageGenerationToolConfig | None = None, + custom_tool_config: CustomToolConfig | None = None, +) -> dict[int, list[Tool]]: + """Constructs tools based on persona configuration and available APIs""" + tool_dict: dict[int, list[Tool]] = {} + + for db_tool_model in persona.tools: + if db_tool_model.in_code_tool_id: + tool_cls = get_built_in_tool_by_id(db_tool_model.id, db_session) + + # Handle Search Tool + if tool_cls.__name__ == SearchTool.__name__: + if not search_tool_config: + search_tool_config = SearchToolConfig() + + search_tool = SearchTool( + db_session=db_session, + user=user, + persona=persona, + retrieval_options=search_tool_config.retrieval_options, + prompt_config=prompt_config, + llm=llm, + fast_llm=fast_llm, + pruning_config=search_tool_config.document_pruning_config, + answer_style_config=search_tool_config.answer_style_config, + selected_sections=search_tool_config.selected_sections, + chunks_above=search_tool_config.chunks_above, + chunks_below=search_tool_config.chunks_below, + full_doc=search_tool_config.full_doc, + evaluation_type=( + LLMEvaluationType.BASIC + if persona.llm_relevance_filter + else LLMEvaluationType.SKIP + ), + ) + tool_dict[db_tool_model.id] = [search_tool] + + # Handle Image Generation Tool + elif tool_cls.__name__ == ImageGenerationTool.__name__: + if not image_generation_tool_config: + image_generation_tool_config = ImageGenerationToolConfig() + + img_generation_llm_config = _get_image_generation_config( + llm, db_session + ) + + tool_dict[db_tool_model.id] = [ + ImageGenerationTool( + api_key=cast(str, img_generation_llm_config.api_key), + api_base=img_generation_llm_config.api_base, + api_version=img_generation_llm_config.api_version, + additional_headers=image_generation_tool_config.additional_headers, + model=img_generation_llm_config.model_name, + ) + ] + + # Handle Internet Search Tool + elif tool_cls.__name__ == InternetSearchTool.__name__: + if not internet_search_tool_config: + internet_search_tool_config = InternetSearchToolConfig() + + if not BING_API_KEY: + raise ValueError( + "Internet search tool requires a Bing API key, please contact your Danswer admin to get it added!" + ) + tool_dict[db_tool_model.id] = [ + InternetSearchTool( + api_key=BING_API_KEY, + answer_style_config=internet_search_tool_config.answer_style_config, + prompt_config=prompt_config, + ) + ] + + # Handle custom tools + elif db_tool_model.openapi_schema: + if not custom_tool_config: + custom_tool_config = CustomToolConfig() + + tool_dict[db_tool_model.id] = cast( + list[Tool], + build_custom_tools_from_openapi_schema_and_headers( + db_tool_model.openapi_schema, + dynamic_schema_info=DynamicSchemaInfo( + chat_session_id=custom_tool_config.chat_session_id, + message_id=custom_tool_config.message_id, + ), + custom_headers=(db_tool_model.custom_headers or []) + + ( + header_dict_to_header_list( + custom_tool_config.additional_headers or {} + ) + ), + ), + ) + + tools: list[Tool] = [] + for tool_list in tool_dict.values(): + tools.extend(tool_list) + + # factor in tool definition size when pruning + if search_tool_config: + search_tool_config.document_pruning_config.tool_num_tokens = ( + compute_all_tool_tokens( + tools, + get_tokenizer( + model_name=llm.config.model_name, + provider_type=llm.config.model_provider, + ), + ) + ) + search_tool_config.document_pruning_config.using_tool_message = ( + explicit_tool_calling_supported( + llm.config.model_provider, llm.config.model_name + ) + ) + + return tool_dict diff --git a/backend/tests/integration/common_utils/managers/user.py b/backend/tests/integration/common_utils/managers/user.py index 2b9aa6e189d..43286c6a716 100644 --- a/backend/tests/integration/common_utils/managers/user.py +++ b/backend/tests/integration/common_utils/managers/user.py @@ -13,6 +13,14 @@ from tests.integration.common_utils.test_models import DATestUser +DOMAIN = "test.com" +DEFAULT_PASSWORD = "test" + + +def build_email(name: str) -> str: + return f"{name}@test.com" + + class UserManager: @staticmethod def create( @@ -23,9 +31,9 @@ def create( name = f"test{str(uuid4())}" if email is None: - email = f"{name}@test.com" + email = build_email(name) - password = "test" + password = DEFAULT_PASSWORD body = { "email": email, diff --git a/backend/tests/integration/openai_assistants_api/conftest.py b/backend/tests/integration/openai_assistants_api/conftest.py new file mode 100644 index 00000000000..172247dc391 --- /dev/null +++ b/backend/tests/integration/openai_assistants_api/conftest.py @@ -0,0 +1,55 @@ +from typing import Optional +from uuid import UUID + +import pytest +import requests + +from tests.integration.common_utils.constants import API_SERVER_URL +from tests.integration.common_utils.constants import GENERAL_HEADERS +from tests.integration.common_utils.managers.llm_provider import LLMProviderManager +from tests.integration.common_utils.managers.user import build_email +from tests.integration.common_utils.managers.user import DEFAULT_PASSWORD +from tests.integration.common_utils.managers.user import UserManager +from tests.integration.common_utils.test_models import DATestLLMProvider +from tests.integration.common_utils.test_models import DATestUser + +BASE_URL = f"{API_SERVER_URL}/openai-assistants" + + +@pytest.fixture +def admin_user() -> DATestUser | None: + try: + return UserManager.create("admin_user") + except Exception: + pass + + try: + return UserManager.login_as_user( + DATestUser( + id="", + email=build_email("admin_user"), + password=DEFAULT_PASSWORD, + headers=GENERAL_HEADERS, + ) + ) + except Exception: + pass + + return None + + +@pytest.fixture +def llm_provider(admin_user: DATestUser | None) -> DATestLLMProvider: + return LLMProviderManager.create(user_performing_action=admin_user) + + +@pytest.fixture +def thread_id(admin_user: Optional[DATestUser]) -> UUID: + # Create a thread to use in the tests + response = requests.post( + f"{BASE_URL}/threads", # Updated endpoint path + json={}, + headers=admin_user.headers if admin_user else GENERAL_HEADERS, + ) + assert response.status_code == 200 + return UUID(response.json()["id"]) diff --git a/backend/tests/integration/openai_assistants_api/test_assistants.py b/backend/tests/integration/openai_assistants_api/test_assistants.py new file mode 100644 index 00000000000..14f270f1a0e --- /dev/null +++ b/backend/tests/integration/openai_assistants_api/test_assistants.py @@ -0,0 +1,151 @@ +import requests + +from tests.integration.common_utils.constants import API_SERVER_URL +from tests.integration.common_utils.constants import GENERAL_HEADERS +from tests.integration.common_utils.test_models import DATestUser + +ASSISTANTS_URL = f"{API_SERVER_URL}/openai-assistants/assistants" + + +def test_create_assistant(admin_user: DATestUser | None) -> None: + response = requests.post( + ASSISTANTS_URL, + json={ + "model": "gpt-3.5-turbo", + "name": "Test Assistant", + "description": "A test assistant", + "instructions": "You are a helpful assistant.", + }, + headers=admin_user.headers if admin_user else GENERAL_HEADERS, + ) + assert response.status_code == 200 + data = response.json() + assert data["name"] == "Test Assistant" + assert data["description"] == "A test assistant" + assert data["model"] == "gpt-3.5-turbo" + assert data["instructions"] == "You are a helpful assistant." + + +def test_retrieve_assistant(admin_user: DATestUser | None) -> None: + # First, create an assistant + create_response = requests.post( + ASSISTANTS_URL, + json={"model": "gpt-3.5-turbo", "name": "Retrieve Test"}, + headers=admin_user.headers if admin_user else GENERAL_HEADERS, + ) + assert create_response.status_code == 200 + assistant_id = create_response.json()["id"] + + # Now, retrieve the assistant + response = requests.get( + f"{ASSISTANTS_URL}/{assistant_id}", + headers=admin_user.headers if admin_user else GENERAL_HEADERS, + ) + assert response.status_code == 200 + data = response.json() + assert data["id"] == assistant_id + assert data["name"] == "Retrieve Test" + + +def test_modify_assistant(admin_user: DATestUser | None) -> None: + # First, create an assistant + create_response = requests.post( + ASSISTANTS_URL, + json={"model": "gpt-3.5-turbo", "name": "Modify Test"}, + headers=admin_user.headers if admin_user else GENERAL_HEADERS, + ) + assert create_response.status_code == 200 + assistant_id = create_response.json()["id"] + + # Now, modify the assistant + response = requests.post( + f"{ASSISTANTS_URL}/{assistant_id}", + json={"name": "Modified Assistant", "instructions": "New instructions"}, + headers=admin_user.headers if admin_user else GENERAL_HEADERS, + ) + assert response.status_code == 200 + data = response.json() + assert data["id"] == assistant_id + assert data["name"] == "Modified Assistant" + assert data["instructions"] == "New instructions" + + +def test_delete_assistant(admin_user: DATestUser | None) -> None: + # First, create an assistant + create_response = requests.post( + ASSISTANTS_URL, + json={"model": "gpt-3.5-turbo", "name": "Delete Test"}, + headers=admin_user.headers if admin_user else GENERAL_HEADERS, + ) + assert create_response.status_code == 200 + assistant_id = create_response.json()["id"] + + # Now, delete the assistant + response = requests.delete( + f"{ASSISTANTS_URL}/{assistant_id}", + headers=admin_user.headers if admin_user else GENERAL_HEADERS, + ) + assert response.status_code == 200 + data = response.json() + assert data["id"] == assistant_id + assert data["deleted"] is True + + +def test_list_assistants(admin_user: DATestUser | None) -> None: + # Create multiple assistants + for i in range(3): + requests.post( + ASSISTANTS_URL, + json={"model": "gpt-3.5-turbo", "name": f"List Test {i}"}, + headers=admin_user.headers if admin_user else GENERAL_HEADERS, + ) + + # Now, list the assistants + response = requests.get( + ASSISTANTS_URL, + headers=admin_user.headers if admin_user else GENERAL_HEADERS, + ) + assert response.status_code == 200 + data = response.json() + assert data["object"] == "list" + assert len(data["data"]) >= 3 # At least the 3 we just created + assert all(assistant["object"] == "assistant" for assistant in data["data"]) + + +def test_list_assistants_pagination(admin_user: DATestUser | None) -> None: + # Create 5 assistants + for i in range(5): + requests.post( + ASSISTANTS_URL, + json={"model": "gpt-3.5-turbo", "name": f"Pagination Test {i}"}, + headers=admin_user.headers if admin_user else GENERAL_HEADERS, + ) + + # List assistants with limit + response = requests.get( + f"{ASSISTANTS_URL}?limit=2", + headers=admin_user.headers if admin_user else GENERAL_HEADERS, + ) + assert response.status_code == 200 + data = response.json() + assert len(data["data"]) == 2 + assert data["has_more"] is True + + # Get next page + before = data["last_id"] + response = requests.get( + f"{ASSISTANTS_URL}?limit=2&before={before}", + headers=admin_user.headers if admin_user else GENERAL_HEADERS, + ) + assert response.status_code == 200 + data = response.json() + assert len(data["data"]) == 2 + + +def test_assistant_not_found(admin_user: DATestUser | None) -> None: + non_existent_id = -99 + response = requests.get( + f"{ASSISTANTS_URL}/{non_existent_id}", + headers=admin_user.headers if admin_user else GENERAL_HEADERS, + ) + assert response.status_code == 404 diff --git a/backend/tests/integration/openai_assistants_api/test_messages.py b/backend/tests/integration/openai_assistants_api/test_messages.py new file mode 100644 index 00000000000..cbcf6869435 --- /dev/null +++ b/backend/tests/integration/openai_assistants_api/test_messages.py @@ -0,0 +1,133 @@ +import uuid +from typing import Optional + +import pytest +import requests + +from tests.integration.common_utils.constants import API_SERVER_URL +from tests.integration.common_utils.constants import GENERAL_HEADERS +from tests.integration.common_utils.test_models import DATestUser + +BASE_URL = f"{API_SERVER_URL}/openai-assistants/threads" + + +@pytest.fixture +def thread_id(admin_user: Optional[DATestUser]) -> str: + response = requests.post( + BASE_URL, + json={}, + headers=admin_user.headers if admin_user else GENERAL_HEADERS, + ) + assert response.status_code == 200 + return response.json()["id"] + + +def test_create_message(admin_user: Optional[DATestUser], thread_id: str) -> None: + response = requests.post( + f"{BASE_URL}/{thread_id}/messages", # URL structure matches API + json={ + "role": "user", + "content": "Hello, world!", + "file_ids": [], + "metadata": {"key": "value"}, + }, + headers=admin_user.headers if admin_user else GENERAL_HEADERS, + ) + assert response.status_code == 200 + + response_json = response.json() + assert "id" in response_json + assert response_json["thread_id"] == thread_id + assert response_json["role"] == "user" + assert response_json["content"] == [{"type": "text", "text": "Hello, world!"}] + assert response_json["metadata"] == {"key": "value"} + + +def test_list_messages(admin_user: Optional[DATestUser], thread_id: str) -> None: + # Create a message first + requests.post( + f"{BASE_URL}/{thread_id}/messages", + json={"role": "user", "content": "Test message"}, + headers=admin_user.headers if admin_user else GENERAL_HEADERS, + ) + + # Now, list the messages + response = requests.get( + f"{BASE_URL}/{thread_id}/messages", + headers=admin_user.headers if admin_user else GENERAL_HEADERS, + ) + assert response.status_code == 200 + + response_json = response.json() + assert response_json["object"] == "list" + assert isinstance(response_json["data"], list) + assert len(response_json["data"]) > 0 + assert "first_id" in response_json + assert "last_id" in response_json + assert "has_more" in response_json + + +def test_retrieve_message(admin_user: Optional[DATestUser], thread_id: str) -> None: + # Create a message first + create_response = requests.post( + f"{BASE_URL}/{thread_id}/messages", + json={"role": "user", "content": "Test message"}, + headers=admin_user.headers if admin_user else GENERAL_HEADERS, + ) + message_id = create_response.json()["id"] + + # Now, retrieve the message + response = requests.get( + f"{BASE_URL}/{thread_id}/messages/{message_id}", + headers=admin_user.headers if admin_user else GENERAL_HEADERS, + ) + assert response.status_code == 200 + + response_json = response.json() + assert response_json["id"] == message_id + assert response_json["thread_id"] == thread_id + assert response_json["role"] == "user" + assert response_json["content"] == [{"type": "text", "text": "Test message"}] + + +def test_modify_message(admin_user: Optional[DATestUser], thread_id: str) -> None: + # Create a message first + create_response = requests.post( + f"{BASE_URL}/{thread_id}/messages", + json={"role": "user", "content": "Test message"}, + headers=admin_user.headers if admin_user else GENERAL_HEADERS, + ) + message_id = create_response.json()["id"] + + # Now, modify the message + response = requests.post( + f"{BASE_URL}/{thread_id}/messages/{message_id}", + json={"metadata": {"new_key": "new_value"}}, + headers=admin_user.headers if admin_user else GENERAL_HEADERS, + ) + assert response.status_code == 200 + + response_json = response.json() + assert response_json["id"] == message_id + assert response_json["thread_id"] == thread_id + assert response_json["metadata"] == {"new_key": "new_value"} + + +def test_error_handling(admin_user: Optional[DATestUser]) -> None: + non_existent_thread_id = str(uuid.uuid4()) + non_existent_message_id = -99 + + # Test with non-existent thread + response = requests.post( + f"{BASE_URL}/{non_existent_thread_id}/messages", + json={"role": "user", "content": "Test message"}, + headers=admin_user.headers if admin_user else GENERAL_HEADERS, + ) + assert response.status_code == 404 + + # Test with non-existent message + response = requests.get( + f"{BASE_URL}/{non_existent_thread_id}/messages/{non_existent_message_id}", + headers=admin_user.headers if admin_user else GENERAL_HEADERS, + ) + assert response.status_code == 404 diff --git a/backend/tests/integration/openai_assistants_api/test_runs.py b/backend/tests/integration/openai_assistants_api/test_runs.py new file mode 100644 index 00000000000..2ee0dbd4ba9 --- /dev/null +++ b/backend/tests/integration/openai_assistants_api/test_runs.py @@ -0,0 +1,137 @@ +from uuid import UUID + +import pytest +import requests + +from tests.integration.common_utils.constants import API_SERVER_URL +from tests.integration.common_utils.constants import GENERAL_HEADERS +from tests.integration.common_utils.test_models import DATestLLMProvider +from tests.integration.common_utils.test_models import DATestUser + +BASE_URL = f"{API_SERVER_URL}/openai-assistants" + + +@pytest.fixture +def run_id(admin_user: DATestUser | None, thread_id: UUID) -> str: + """Create a run and return its ID.""" + response = requests.post( + f"{BASE_URL}/threads/{thread_id}/runs", + json={ + "assistant_id": 0, + }, + headers=admin_user.headers if admin_user else GENERAL_HEADERS, + ) + assert response.status_code == 200 + return response.json()["id"] + + +def test_create_run( + admin_user: DATestUser | None, thread_id: UUID, llm_provider: DATestLLMProvider +) -> None: + response = requests.post( + f"{BASE_URL}/threads/{thread_id}/runs", + json={ + "assistant_id": 0, + "model": "gpt-3.5-turbo", + "instructions": "Test instructions", + }, + headers=admin_user.headers if admin_user else GENERAL_HEADERS, + ) + assert response.status_code == 200 + + response_json = response.json() + assert "id" in response_json + assert response_json["object"] == "thread.run" + assert "created_at" in response_json + assert response_json["assistant_id"] == 0 + assert UUID(response_json["thread_id"]) == thread_id + assert response_json["status"] == "queued" + assert response_json["model"] == "gpt-3.5-turbo" + assert response_json["instructions"] == "Test instructions" + + +def test_retrieve_run( + admin_user: DATestUser | None, + thread_id: UUID, + run_id: str, + llm_provider: DATestLLMProvider, +) -> None: + retrieve_response = requests.get( + f"{BASE_URL}/threads/{thread_id}/runs/{run_id}", + headers=admin_user.headers if admin_user else GENERAL_HEADERS, + ) + assert retrieve_response.status_code == 200 + + response_json = retrieve_response.json() + assert response_json["id"] == run_id + assert response_json["object"] == "thread.run" + assert "created_at" in response_json + assert UUID(response_json["thread_id"]) == thread_id + + +def test_cancel_run( + admin_user: DATestUser | None, + thread_id: UUID, + run_id: str, + llm_provider: DATestLLMProvider, +) -> None: + cancel_response = requests.post( + f"{BASE_URL}/threads/{thread_id}/runs/{run_id}/cancel", + headers=admin_user.headers if admin_user else GENERAL_HEADERS, + ) + assert cancel_response.status_code == 200 + + response_json = cancel_response.json() + assert response_json["id"] == run_id + assert response_json["status"] == "cancelled" + + +def test_list_runs( + admin_user: DATestUser | None, thread_id: UUID, llm_provider: DATestLLMProvider +) -> None: + # Create a few runs + for _ in range(3): + requests.post( + f"{BASE_URL}/threads/{thread_id}/runs", + json={ + "assistant_id": 0, + }, + headers=admin_user.headers if admin_user else GENERAL_HEADERS, + ) + + # Now, list the runs + list_response = requests.get( + f"{BASE_URL}/threads/{thread_id}/runs", + headers=admin_user.headers if admin_user else GENERAL_HEADERS, + ) + assert list_response.status_code == 200 + + response_json = list_response.json() + assert isinstance(response_json, list) + assert len(response_json) >= 3 + + for run in response_json: + assert "id" in run + assert run["object"] == "thread.run" + assert "created_at" in run + assert UUID(run["thread_id"]) == thread_id + assert "status" in run + assert "model" in run + + +def test_list_run_steps( + admin_user: DATestUser | None, + thread_id: UUID, + run_id: str, + llm_provider: DATestLLMProvider, +) -> None: + steps_response = requests.get( + f"{BASE_URL}/threads/{thread_id}/runs/{run_id}/steps", + headers=admin_user.headers if admin_user else GENERAL_HEADERS, + ) + assert steps_response.status_code == 200 + + response_json = steps_response.json() + assert isinstance(response_json, list) + # Since DAnswer doesn't have an equivalent to run steps, we expect an empty list + assert len(response_json) == 0 diff --git a/backend/tests/integration/openai_assistants_api/test_threads.py b/backend/tests/integration/openai_assistants_api/test_threads.py new file mode 100644 index 00000000000..4ae128b2612 --- /dev/null +++ b/backend/tests/integration/openai_assistants_api/test_threads.py @@ -0,0 +1,132 @@ +from uuid import UUID + +import requests + +from danswer.db.models import ChatSessionSharedStatus +from tests.integration.common_utils.constants import API_SERVER_URL +from tests.integration.common_utils.constants import GENERAL_HEADERS +from tests.integration.common_utils.test_models import DATestUser + +THREADS_URL = f"{API_SERVER_URL}/openai-assistants/threads" + + +def test_create_thread(admin_user: DATestUser | None) -> None: + response = requests.post( + THREADS_URL, + json={"messages": None, "metadata": {"key": "value"}}, + headers=admin_user.headers if admin_user else GENERAL_HEADERS, + ) + assert response.status_code == 200 + + response_json = response.json() + assert "id" in response_json + assert response_json["object"] == "thread" + assert "created_at" in response_json + assert response_json["metadata"] == {"key": "value"} + + +def test_retrieve_thread(admin_user: DATestUser | None) -> None: + # First, create a thread + create_response = requests.post( + THREADS_URL, + json={}, + headers=admin_user.headers if admin_user else GENERAL_HEADERS, + ) + assert create_response.status_code == 200 + thread_id = create_response.json()["id"] + + # Now, retrieve the thread + retrieve_response = requests.get( + f"{THREADS_URL}/{thread_id}", + headers=admin_user.headers if admin_user else GENERAL_HEADERS, + ) + assert retrieve_response.status_code == 200 + + response_json = retrieve_response.json() + assert response_json["id"] == thread_id + assert response_json["object"] == "thread" + assert "created_at" in response_json + + +def test_modify_thread(admin_user: DATestUser | None) -> None: + # First, create a thread + create_response = requests.post( + THREADS_URL, + json={}, + headers=admin_user.headers if admin_user else GENERAL_HEADERS, + ) + assert create_response.status_code == 200 + thread_id = create_response.json()["id"] + + # Now, modify the thread + modify_response = requests.post( + f"{THREADS_URL}/{thread_id}", + json={"metadata": {"new_key": "new_value"}}, + headers=admin_user.headers if admin_user else GENERAL_HEADERS, + ) + assert modify_response.status_code == 200 + + response_json = modify_response.json() + assert response_json["id"] == thread_id + assert response_json["metadata"] == {"new_key": "new_value"} + + +def test_delete_thread(admin_user: DATestUser | None) -> None: + # First, create a thread + create_response = requests.post( + THREADS_URL, + json={}, + headers=admin_user.headers if admin_user else GENERAL_HEADERS, + ) + assert create_response.status_code == 200 + thread_id = create_response.json()["id"] + + # Now, delete the thread + delete_response = requests.delete( + f"{THREADS_URL}/{thread_id}", + headers=admin_user.headers if admin_user else GENERAL_HEADERS, + ) + assert delete_response.status_code == 200 + + response_json = delete_response.json() + assert response_json["id"] == thread_id + assert response_json["object"] == "thread.deleted" + assert response_json["deleted"] is True + + +def test_list_threads(admin_user: DATestUser | None) -> None: + # Create a few threads + for _ in range(3): + requests.post( + THREADS_URL, + json={}, + headers=admin_user.headers if admin_user else GENERAL_HEADERS, + ) + + # Now, list the threads + list_response = requests.get( + THREADS_URL, + headers=admin_user.headers if admin_user else GENERAL_HEADERS, + ) + assert list_response.status_code == 200 + + response_json = list_response.json() + assert "sessions" in response_json + assert len(response_json["sessions"]) >= 3 + + for session in response_json["sessions"]: + assert "id" in session + assert "name" in session + assert "persona_id" in session + assert "time_created" in session + assert "shared_status" in session + assert "folder_id" in session + assert "current_alternate_model" in session + + # Validate UUID + UUID(session["id"]) + + # Validate shared_status + assert session["shared_status"] in [ + status.value for status in ChatSessionSharedStatus + ] diff --git a/examples/assistants-api/topics_analyzer.py b/examples/assistants-api/topics_analyzer.py new file mode 100644 index 00000000000..a8ef6c27a23 --- /dev/null +++ b/examples/assistants-api/topics_analyzer.py @@ -0,0 +1,125 @@ +import argparse +import os +import time +from datetime import datetime +from datetime import timedelta +from datetime import timezone + +from openai import OpenAI + + +ASSISTANT_NAME = "Topic Analyzer" +SYSTEM_PROMPT = """ +You are a helpful assistant that analyzes topics by searching through available \ +documents and providing insights. These available documents come from common \ +workplace tools like Slack, emails, Confluence, Google Drive, etc. + +When analyzing a topic: +1. Search for relevant information using the search tool +2. Synthesize the findings into clear insights +3. Highlight key trends, patterns, or notable developments +4. Maintain objectivity and cite sources where relevant +""" +USER_PROMPT = """ +Please analyze and provide insights about this topic: {topic}. + +IMPORTANT: do not mention things that are not relevant to the specified topic. \ +If there is no relevant information, just say "No relevant information found." +""" + + +def wait_on_run(client: OpenAI, run, thread): + while run.status == "queued" or run.status == "in_progress": + run = client.beta.threads.runs.retrieve( + thread_id=thread.id, + run_id=run.id, + ) + time.sleep(0.5) + return run + + +def show_response(messages) -> None: + # Get only the assistant's response text + for message in messages.data[::-1]: + if message.role == "assistant": + for content in message.content: + if content.type == "text": + print(content.text) + break + + +def analyze_topics(topics: list[str]) -> None: + openai_api_key = os.environ.get( + "OPENAI_API_KEY", "" + ) + danswer_api_key = os.environ.get( + "DANSWER_API_KEY", "" + ) + client = OpenAI( + api_key=openai_api_key, + base_url="http://localhost:8080/openai-assistants", + default_headers={ + "Authorization": f"Bearer {danswer_api_key}", + }, + ) + + # Create an assistant if it doesn't exist + try: + assistants = client.beta.assistants.list(limit=100) + # Find the Topic Analyzer assistant if it exists + assistant = next((a for a in assistants.data if a.name == ASSISTANT_NAME)) + client.beta.assistants.delete(assistant.id) + except Exception: + pass + + assistant = client.beta.assistants.create( + name=ASSISTANT_NAME, + instructions=SYSTEM_PROMPT, + tools=[{"type": "SearchTool"}], # type: ignore + model="gpt-4o", + ) + + # Process each topic individually + for topic in topics: + thread = client.beta.threads.create() + message = client.beta.threads.messages.create( + thread_id=thread.id, + role="user", + content=USER_PROMPT.format(topic=topic), + ) + + run = client.beta.threads.runs.create( + thread_id=thread.id, + assistant_id=assistant.id, + tools=[ + { # type: ignore + "type": "SearchTool", + "retrieval_details": { + "run_search": "always", + "filters": { + "time_cutoff": str( + datetime.now(timezone.utc) - timedelta(days=7) + ) + }, + }, + } + ], + ) + + run = wait_on_run(client, run, thread) + messages = client.beta.threads.messages.list( + thread_id=thread.id, order="asc", after=message.id + ) + print(f"\nAnalysis for topic: {topic}") + print("-" * 40) + show_response(messages) + print() + + +# Example usage +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Analyze specific topics") + parser.add_argument("topics", nargs="+", help="Topics to analyze (one or more)") + + args = parser.parse_args() + analyze_topics(args.topics)