Skip to content

Commit

Permalink
New assistants api (#3097)
Browse files Browse the repository at this point in the history
  • Loading branch information
Weves authored Nov 11, 2024
1 parent 9d57f34 commit ba805f7
Show file tree
Hide file tree
Showing 20 changed files with 2,179 additions and 177 deletions.
9 changes: 9 additions & 0 deletions backend/alembic/versions/b156fa702355_chat_reworked.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,15 @@ def upgrade() -> None:


def downgrade() -> None:
# NOTE: you will lose all chat history. This is to satisfy the non-nullable constraints
# below
op.execute("DELETE FROM chat_feedback")
op.execute("DELETE FROM chat_message__search_doc")
op.execute("DELETE FROM document_retrieval_feedback")
op.execute("DELETE FROM document_retrieval_feedback")
op.execute("DELETE FROM chat_message")
op.execute("DELETE FROM chat_session")

op.drop_constraint(
"chat_feedback__chat_message_fk", "chat_feedback", type_="foreignkey"
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,56 @@ def upgrade() -> None:


def downgrade() -> None:
# Delete chat messages and feedback first since they reference chat sessions
# Get chat messages from sessions with null persona_id
chat_messages_query = """
SELECT id
FROM chat_message
WHERE chat_session_id IN (
SELECT id
FROM chat_session
WHERE persona_id IS NULL
)
"""

# Delete dependent records first
op.execute(
f"""
DELETE FROM document_retrieval_feedback
WHERE chat_message_id IN (
{chat_messages_query}
)
"""
)
op.execute(
f"""
DELETE FROM chat_message__search_doc
WHERE chat_message_id IN (
{chat_messages_query}
)
"""
)

# Delete chat messages
op.execute(
"""
DELETE FROM chat_message
WHERE chat_session_id IN (
SELECT id
FROM chat_session
WHERE persona_id IS NULL
)
"""
)

# Now we can safely delete the chat sessions
op.execute(
"""
DELETE FROM chat_session
WHERE persona_id IS NULL
"""
)

op.alter_column(
"chat_session",
"persona_id",
Expand Down
251 changes: 78 additions & 173 deletions backend/danswer/chat/process_message.py

Large diffs are not rendered by default.

9 changes: 8 additions & 1 deletion backend/danswer/db/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,13 @@ def get_tool_by_id(tool_id: int, db_session: Session) -> Tool:
return tool


def get_tool_by_name(tool_name: str, db_session: Session) -> Tool:
    """Look up a tool by its ``name`` column.

    Args:
        tool_name: Value compared for equality against ``Tool.name``.
        db_session: Session the lookup query is executed on.

    Returns:
        The scalar result of the query.

    Raises:
        ValueError: If the query yields no (truthy) result.
    """
    name_filter = select(Tool).where(Tool.name == tool_name)
    matched_tool = db_session.scalar(name_filter)
    if matched_tool:
        return matched_tool
    raise ValueError("Tool by specified name does not exist")


def create_tool(
name: str,
description: str | None,
Expand All @@ -37,7 +44,7 @@ def create_tool(
description=description,
in_code_tool_id=None,
openapi_schema=openapi_schema,
custom_headers=[header.dict() for header in custom_headers]
custom_headers=[header.model_dump() for header in custom_headers]
if custom_headers
else [],
user_id=user_id,
Expand Down
6 changes: 6 additions & 0 deletions backend/danswer/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,9 @@
from danswer.server.manage.slack_bot import router as slack_bot_management_router
from danswer.server.manage.users import router as user_router
from danswer.server.middleware.latency_logging import add_latency_logging_middleware
from danswer.server.openai_assistants_api.full_openai_assistants_api import (
get_full_openai_assistants_api_router,
)
from danswer.server.query_and_chat.chat_backend import router as chat_router
from danswer.server.query_and_chat.query_backend import (
admin_router as admin_query_router,
Expand Down Expand Up @@ -270,6 +273,9 @@ def get_application() -> FastAPI:
application, token_rate_limit_settings_router
)
include_router_with_global_prefix_prepended(application, indexing_router)
include_router_with_global_prefix_prepended(
application, get_full_openai_assistants_api_router()
)

if AUTH_TYPE == AuthType.DISABLED:
# Server logs this during auth setup verification step
Expand Down
273 changes: 273 additions & 0 deletions backend/danswer/server/openai_assistants_api/asssistants_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,273 @@
from typing import Any
from typing import Optional
from uuid import uuid4

from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Query
from pydantic import BaseModel
from sqlalchemy.orm import Session

from danswer.auth.users import current_user
from danswer.db.engine import get_session
from danswer.db.models import Persona
from danswer.db.models import User
from danswer.db.persona import get_persona_by_id
from danswer.db.persona import get_personas
from danswer.db.persona import mark_persona_as_deleted
from danswer.db.persona import upsert_persona
from danswer.db.persona import upsert_prompt
from danswer.db.tools import get_tool_by_name
from danswer.search.enums import RecencyBiasSetting
from danswer.utils.logger import setup_logger

# Module-level logger used by the endpoints below.
logger = setup_logger()


# All routes declared in this module are registered under the "/assistants" prefix.
router = APIRouter(prefix="/assistants")


# Base models
class AssistantObject(BaseModel):
    """Response shape returned for a single assistant."""

    # Identifier of the assistant (filled from the persona id by persona_to_assistant).
    id: int
    # Constant discriminator for this response type.
    object: str = "assistant"
    created_at: int
    name: str | None = None
    description: str | None = None
    model: str
    instructions: str | None = None
    tools: list[dict[str, Any]]
    file_ids: list[str]
    metadata: dict[str, Any] | None = None


class CreateAssistantRequest(BaseModel):
    """Request body for creating an assistant; only ``model`` is required."""

    model: str
    name: str | None = None
    description: str | None = None
    instructions: str | None = None
    tools: list[dict[str, Any]] | None = None
    file_ids: list[str] | None = None
    metadata: dict[str, Any] | None = None


class ModifyAssistantRequest(BaseModel):
    """Request body for modifying an assistant; every field is optional."""

    model: str | None = None
    name: str | None = None
    description: str | None = None
    instructions: str | None = None
    tools: list[dict[str, Any]] | None = None
    file_ids: list[str] | None = None
    metadata: dict[str, Any] | None = None


class DeleteAssistantResponse(BaseModel):
    """Response returned after an assistant delete request."""

    id: int
    # Constant discriminator for this response type.
    object: str = "assistant.deleted"
    deleted: bool


class ListAssistantsResponse(BaseModel):
    """One page of assistants plus the ids of its first and last entries."""

    # Constant discriminator for this response type.
    object: str = "list"
    data: list[AssistantObject]
    first_id: int | None = None
    last_id: int | None = None
    has_more: bool


def persona_to_assistant(persona: Persona) -> AssistantObject:
    """Convert a ``Persona`` row into the ``AssistantObject`` response model.

    Only the first prompt attached to the persona contributes the
    ``instructions`` field. ``created_at`` is always 0, and ``file_ids`` /
    ``metadata`` are always empty.
    """
    # Fall back to a fixed model name when the persona has no override.
    model_name = persona.llm_model_version_override or "gpt-3.5-turbo"

    system_instructions = None
    if persona.prompts:
        system_instructions = persona.prompts[0].system_prompt

    tool_entries = []
    for persona_tool in persona.tools:
        function_spec = {
            "name": persona_tool.name,
            "description": persona_tool.description,
            "schema": persona_tool.openapi_schema,
        }
        tool_entries.append(
            {"type": persona_tool.display_name, "function": function_spec}
        )

    return AssistantObject(
        id=persona.id,
        created_at=0,
        name=persona.name,
        description=persona.description,
        model=model_name,
        instructions=system_instructions,
        tools=tool_entries,
        file_ids=[],  # Assuming no file support for now
        metadata={},  # Assuming no metadata for now
    )


# API endpoints
@router.post("")
def create_assistant(
    request: CreateAssistantRequest,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> AssistantObject:
    """Create a persona (and optionally a prompt) from an assistant request.

    Every requested tool is resolved before anything is written, so a request
    naming an unknown tool is rejected without leaving a prompt behind.

    Raises:
        HTTPException: 404 if a requested tool cannot be found by name.
    """
    # Validate the tools first: the lookups are read-only, while the prompt
    # upsert below writes to the database.
    # NOTE(review): the lookup key is the tool entry's "type", matched against
    # Tool.name, whereas persona_to_assistant reports Tool.display_name as
    # "type" — confirm the two are meant to round-trip.
    tool_ids: list[int] = []
    for tool in request.tools or []:
        tool_type = tool.get("type")
        if not tool_type:
            # Entries without a "type" cannot be resolved and are ignored.
            continue

        try:
            tool_db = get_tool_by_name(tool_type, db_session)
        except ValueError as err:
            # An unknown tool fails the whole request rather than being skipped.
            logger.error(f"Tool {tool_type} not found in database")
            raise HTTPException(
                status_code=404, detail=f"Tool {tool_type} not found in database"
            ) from err
        tool_ids.append(tool_db.id)

    prompt = None
    if request.instructions:
        prompt = upsert_prompt(
            user=user,
            name=f"Prompt for {request.name or 'New Assistant'}",
            description="Auto-generated prompt",
            system_prompt=request.instructions,
            task_prompt="",
            include_citations=True,
            datetime_aware=True,
            personas=[],
            db_session=db_session,
        )

    persona = upsert_persona(
        user=user,
        name=request.name or f"Assistant-{uuid4()}",
        description=request.description or "",
        num_chunks=25,
        llm_relevance_filter=True,
        llm_filter_extraction=True,
        recency_bias=RecencyBiasSetting.AUTO,
        llm_model_provider_override=None,
        llm_model_version_override=request.model,
        starter_messages=None,
        is_public=False,
        db_session=db_session,
        # Without instructions, fall back to prompt id 0.
        prompt_ids=[prompt.id] if prompt else [0],
        document_set_ids=[],
        tool_ids=tool_ids,
        icon_color=None,
        icon_shape=None,
        is_visible=True,
    )

    if prompt:
        # Link the freshly created prompt back to its persona.
        prompt.personas = [persona]
        db_session.commit()

    return persona_to_assistant(persona)


""


@router.get("/{assistant_id}")
def retrieve_assistant(
    assistant_id: int,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> AssistantObject:
    """Return the assistant with the given id.

    Raises:
        HTTPException: 404 if the persona lookup raises ``ValueError`` or
            yields a falsy result.
    """
    found_persona = None
    try:
        found_persona = get_persona_by_id(
            persona_id=assistant_id,
            user=user,
            db_session=db_session,
            is_for_edit=False,
        )
    except ValueError:
        # A failed lookup is reported exactly like a missing persona.
        pass

    if found_persona:
        return persona_to_assistant(found_persona)
    raise HTTPException(status_code=404, detail="Assistant not found")


@router.post("/{assistant_id}")
def modify_assistant(
    assistant_id: int,
    request: ModifyAssistantRequest,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> AssistantObject:
    """Apply the explicitly provided request fields to an existing assistant.

    Request fields are mapped onto the persona the same way create_assistant
    stores them; they are not copied verbatim onto the ORM object:

    * ``name`` / ``description`` -> the persona attributes of the same name
    * ``model`` -> ``persona.llm_model_version_override``
    * ``instructions`` -> ``system_prompt`` of the persona's first prompt
    * ``tools`` -> Tool rows resolved by name from each entry's ``"type"``
    * ``file_ids`` / ``metadata`` -> ignored (not stored; persona_to_assistant
      always reports them empty)

    Raises:
        HTTPException: 404 if the assistant, or a requested tool, is not found.
    """
    try:
        persona = get_persona_by_id(
            persona_id=assistant_id,
            user=user,
            db_session=db_session,
            is_for_edit=True,
        )
    except ValueError:
        # Treat a failed lookup like a missing persona, matching retrieve_assistant.
        persona = None

    if not persona:
        raise HTTPException(status_code=404, detail="Assistant not found")

    # Only fields the client actually sent are considered.
    update_data = request.model_dump(exclude_unset=True)

    # Resolve tools before mutating anything so a bad tool name leaves the
    # persona untouched.
    resolved_tools = None
    if update_data.get("tools") is not None:
        resolved_tools = []
        for tool in update_data["tools"]:
            tool_type = tool.get("type")
            if not tool_type:
                continue
            try:
                resolved_tools.append(get_tool_by_name(tool_type, db_session))
            except ValueError as err:
                logger.error(f"Tool {tool_type} not found in database")
                raise HTTPException(
                    status_code=404, detail=f"Tool {tool_type} not found in database"
                ) from err

    if update_data.get("name") is not None:
        persona.name = update_data["name"]
    if "description" in update_data:
        # create_assistant stores a missing description as "", so do the same.
        persona.description = update_data["description"] or ""
    if "model" in update_data:
        persona.llm_model_version_override = update_data["model"]
    if update_data.get("instructions") is not None and persona.prompts:
        persona.prompts[0].system_prompt = update_data["instructions"]
    if resolved_tools is not None:
        persona.tools = resolved_tools

    db_session.commit()
    return persona_to_assistant(persona)


@router.delete("/{assistant_id}")
def delete_assistant(
    assistant_id: int,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> DeleteAssistantResponse:
    """Mark the assistant's persona as deleted.

    Raises:
        HTTPException: 404 if a ``ValueError`` is raised while deleting.
    """
    try:
        mark_persona_as_deleted(
            persona_id=int(assistant_id),
            user=user,
            db_session=db_session,
        )
        deletion_result = DeleteAssistantResponse(id=assistant_id, deleted=True)
    except ValueError:
        raise HTTPException(status_code=404, detail="Assistant not found")
    else:
        return deletion_result


@router.get("")
def list_assistants(
    limit: int = Query(20, le=100),
    order: str = Query("desc", regex="^(asc|desc)$"),
    after: Optional[int] = None,
    before: Optional[int] = None,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> ListAssistantsResponse:
    """List the personas visible to the user as a page of assistants.

    Args:
        limit: Maximum number of assistants to return (validated to be <= 100).
        order: ``"asc"`` or ``"desc"`` ordering by persona id.
        after: If given, only personas with an id greater than this are kept.
        before: If given, only personas with an id less than this are kept.

    Returns:
        The requested page; ``has_more`` is True only when further personas
        matched the filters but did not fit in the page.
    """
    personas = list(
        get_personas(
            user=user,
            db_session=db_session,
            get_editable=False,
            joinedload_all=True,
        )
    )

    # Compare against None so that a cursor of 0 (a valid id) is honoured.
    # NOTE(review): the cursors filter on raw id regardless of `order`; confirm
    # this matches the pagination semantics clients expect for "desc".
    if after is not None:
        personas = [p for p in personas if p.id > after]
    if before is not None:
        personas = [p for p in personas if p.id < before]

    personas.sort(key=lambda p: p.id, reverse=(order == "desc"))

    # A negative limit would otherwise slice from the end of the list.
    page_size = max(limit, 0)
    # Decide has_more before truncating; after truncation a full page would
    # always look like "more available".
    has_more = len(personas) > page_size
    page = personas[:page_size]

    assistants = [persona_to_assistant(p) for p in page]

    return ListAssistantsResponse(
        data=assistants,
        first_id=assistants[0].id if assistants else None,
        last_id=assistants[-1].id if assistants else None,
        has_more=has_more,
    )
Loading

0 comments on commit ba805f7

Please sign in to comment.