Skip to content

Commit

Permalink
feat: add model selection when creating new Agent (#795)
Browse files Browse the repository at this point in the history
* Add new config step to create agent

* Fix chat not using agents model
  • Loading branch information
tianjing-li authored Oct 4, 2024
1 parent ced5e4b commit 3cb622f
Show file tree
Hide file tree
Showing 8 changed files with 1,374 additions and 389 deletions.
15 changes: 0 additions & 15 deletions src/backend/routers/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,21 +53,6 @@ async def chat_stream(
ctx.with_model(chat_request.model)
agent_id = chat_request.agent_id
ctx.with_agent_id(agent_id)
user_id = ctx.get_user_id()

if agent_id:
agent = validate_agent_exists(session, agent_id, user_id)
agent_schema = Agent.model_validate(agent)
ctx.with_agent(agent_schema)
agent_tool_metadata = (
agent_tool_metadata_crud.get_all_agent_tool_metadata_by_agent_id(
session, agent_id
)
)
agent_tool_metadata_schema = [
AgentToolMetadata.model_validate(x) for x in agent_tool_metadata
]
ctx.with_agent_tool_metadata(agent_tool_metadata_schema)

(
session,
Expand Down
2 changes: 1 addition & 1 deletion src/backend/schemas/cohere_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class CohereChatRequest(BaseChatRequest):
""",
)
model: str | None = Field(
default="command-r",
default="command-r-plus",
title="The model to use for generating the response.",
)
temperature: float | None = Field(
Expand Down
16 changes: 14 additions & 2 deletions src/backend/services/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from backend.chat.collate import to_dict
from backend.chat.enums import StreamEvent
from backend.config.tools import AVAILABLE_TOOLS
from backend.crud import agent_tool_metadata as agent_tool_metadata_crud
from backend.crud import conversation as conversation_crud
from backend.crud import message as message_crud
from backend.crud import tool_call as tool_call_crud
Expand All @@ -25,7 +26,7 @@
)
from backend.database_models.tool_call import ToolCall as ToolCallModel
from backend.schemas import CohereChatRequest
from backend.schemas.agent import Agent
from backend.schemas.agent import Agent, AgentToolMetadata
from backend.schemas.chat import (
BaseChatRequest,
ChatMessage,
Expand Down Expand Up @@ -76,7 +77,7 @@ def process_chat(
ctx.with_deployment_config()
agent_id = ctx.get_agent_id()

if agent_id is not None:
if agent_id:
agent = validate_agent_exists(session, agent_id, user_id)
agent_schema = Agent.model_validate(agent)
ctx.with_agent(agent_schema)
Expand All @@ -86,11 +87,22 @@ def process_chat(
status_code=404, detail=f"Agent with ID {agent_id} not found."
)

agent_tool_metadata = (
agent_tool_metadata_crud.get_all_agent_tool_metadata_by_agent_id(
session, agent_id
)
)
agent_tool_metadata_schema = [
AgentToolMetadata.model_validate(x) for x in agent_tool_metadata
]
ctx.with_agent_tool_metadata(agent_tool_metadata_schema)

# if tools are not provided in the chat request, use the agent's tools
if not chat_request.tools:
chat_request.tools = [Tool(name=tool) for tool in agent.tools]

# Set the agent settings in the chat request
chat_request.model = agent.model
chat_request.preamble = agent.preamble

should_store = chat_request.chat_history is None and not is_custom_tool_call(
Expand Down
Loading

0 comments on commit 3cb622f

Please sign in to comment.