Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor iteration 2 #6

Merged
merged 2 commits into from
Sep 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 4 additions & 7 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ NEUROAGENT_GENERATIVE__OPENAI__TOKEN=

# Important but not required
NEUROAGENT_AGENT__MODEL=
NEUROAGENT_AGENT__CHAT=

NEUROAGENT_KNOWLEDGE_GRAPH__USE_TOKEN=
NEUROAGENT_KNOWLEDGE_GRAPH__TOKEN=
NEUROAGENT_KNOWLEDGE_GRAPH__DOWNLOAD_HIERARCHY=
Expand All @@ -27,12 +27,9 @@ NEUROAGENT_TOOLS__TRACE__SEARCH_SIZE=

NEUROAGENT_TOOLS__KG_MORPHO__SEARCH_SIZE=

NEUROAGENT_GENERATIVE__LLM_TYPE= # can only be openai for now
NEUROAGENT_GENERATIVE__OPENAI__MODEL=
NEUROAGENT_GENERATIVE__OPENAI__TEMPERATURE=
NEUROAGENT_GENERATIVE__OPENAI__MAX_TOKENS=

NEUROAGENT_COHERE__TOKEN=
NEUROAGENT_OPENAI__MODEL=
NEUROAGENT_OPENAI__TEMPERATURE=
NEUROAGENT_OPENAI__MAX_TOKENS=

NEUROAGENT_LOGGING__LEVEL=
NEUROAGENT_LOGGING__EXTERNAL_PACKAGES=
Expand Down
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,4 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Changed
- Migration to pydantic V2.
- Deleted some legacy code.
53 changes: 0 additions & 53 deletions src/neuroagent/agents/base_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,61 +4,9 @@
from typing import Any, AsyncIterator

from langchain.chat_models.base import BaseChatModel
from langchain_core.messages import (
AIMessage,
ChatMessage,
FunctionMessage,
HumanMessage,
SystemMessage,
ToolMessage,
)
from langchain_core.prompts import (
ChatPromptTemplate,
HumanMessagePromptTemplate,
MessagesPlaceholder,
PromptTemplate,
SystemMessagePromptTemplate,
)
from langchain_core.tools import BaseTool
from pydantic import BaseModel, ConfigDict

BASE_PROMPT = ChatPromptTemplate(
input_variables=["agent_scratchpad", "input"],
input_types={
"chat_history": list[
AIMessage
| HumanMessage
| ChatMessage
| SystemMessage
| FunctionMessage
| ToolMessage
],
"agent_scratchpad": list[
AIMessage
| HumanMessage
| ChatMessage
| SystemMessage
| FunctionMessage
| ToolMessage
],
},
messages=[
SystemMessagePromptTemplate(
prompt=PromptTemplate(
input_variables=[],
template="""You are a helpful assistant helping scientists with neuro-scientific questions.
You must always specify in your answers from which brain regions the information is extracted.
Do not blindly repeat the brain region requested by the user, use the output of the tools instead.""",
)
),
MessagesPlaceholder(variable_name="chat_history", optional=True),
HumanMessagePromptTemplate(
prompt=PromptTemplate(input_variables=["input"], template="{input}")
),
MessagesPlaceholder(variable_name="agent_scratchpad"),
],
)


class AgentStep(BaseModel):
"""Class for agent decision steps."""
Expand All @@ -72,7 +20,6 @@ class AgentOutput(BaseModel):

response: str
steps: list[AgentStep]
plan: str | None = None


class BaseAgent(BaseModel, ABC):
Expand Down
31 changes: 4 additions & 27 deletions src/neuroagent/app/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@
class SettingsAgent(BaseModel):
"""Agent setting."""

model: str = "simple"
chat: str = "simple"
model: Literal["simple", "multi"] = "simple"

model_config = ConfigDict(frozen=True)

Expand Down Expand Up @@ -84,9 +83,9 @@ class SettingsLiterature(BaseModel):
"""Literature search API settings."""

url: str
retriever_k: int = 700
retriever_k: int = 500
use_reranker: bool = True
reranker_k: int = 5
reranker_k: int = 8

model_config = ConfigDict(frozen=True)

Expand Down Expand Up @@ -173,23 +172,6 @@ class SettingsOpenAI(BaseModel):
model_config = ConfigDict(frozen=True)


class SettingsGenerative(BaseModel):
"""Generative QA settings."""

llm_type: Literal["fake", "openai"] = "openai"
openai: SettingsOpenAI = SettingsOpenAI()

model_config = ConfigDict(frozen=True)


class SettingsCohere(BaseModel):
"""Settings cohere reranker."""

token: Optional[SecretStr] = None

model_config = ConfigDict(frozen=True)


class SettingsLogging(BaseModel):
"""Metadata settings."""

Expand Down Expand Up @@ -219,8 +201,7 @@ class Settings(BaseSettings):
knowledge_graph: SettingsKnowledgeGraph
agent: SettingsAgent = SettingsAgent() # has no required
db: SettingsDB = SettingsDB() # has no required
generative: SettingsGenerative = SettingsGenerative() # has no required
cohere: SettingsCohere = SettingsCohere() # has no required
openai: SettingsOpenAI = SettingsOpenAI() # has no required
logging: SettingsLogging = SettingsLogging() # has no required
keycloak: SettingsKeycloak = SettingsKeycloak() # has no required
misc: SettingsMisc = SettingsMisc() # has no required
Expand All @@ -240,10 +221,6 @@ def check_consistency(self) -> "Settings":
model validator is run during instantiation.

"""
# generative
if self.generative.llm_type == "openai":
if self.generative.openai.token is None:
raise ValueError("OpenAI token not provided")
if not self.keycloak.password and not self.keycloak.validate_token:
if not self.knowledge_graph.use_token:
raise ValueError("if no password is provided, please use token auth.")
Expand Down
84 changes: 42 additions & 42 deletions src/neuroagent/app/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,12 +308,12 @@ def get_language_model(
settings: Annotated[Settings, Depends(get_settings)],
) -> ChatOpenAI:
"""Get the language model."""
logger.info(f"OpenAI selected. Loading model {settings.generative.openai.model}.")
logger.info(f"OpenAI selected. Loading model {settings.openai.model}.")
return ChatOpenAI(
model_name=settings.generative.openai.model,
temperature=settings.generative.openai.temperature,
openai_api_key=settings.generative.openai.token.get_secret_value(), # type: ignore
max_tokens=settings.generative.openai.max_tokens,
model_name=settings.openai.model,
temperature=settings.openai.temperature,
openai_api_key=settings.openai.token.get_secret_value(), # type: ignore
max_tokens=settings.openai.max_tokens,
seed=78,
streaming=True,
)
Expand Down Expand Up @@ -369,43 +369,10 @@ def get_agent(
ElectrophysFeatureTool, Depends(get_electrophys_feature_tool)
],
traces_tool: Annotated[GetTracesTool, Depends(get_traces_tool)],
) -> BaseAgent | BaseMultiAgent:
"""Get the generative question answering service."""
tools = [
literature_tool,
br_resolver_tool,
morpho_tool,
morphology_feature_tool,
kg_morpho_feature_tool,
electrophys_feature_tool,
traces_tool,
]
logger.info("Load simple agent")
return SimpleAgent(llm=llm, tools=tools) # type: ignore


def get_chat_agent(
llm: Annotated[ChatOpenAI, Depends(get_language_model)],
memory: Annotated[BaseCheckpointSaver, Depends(get_agent_memory)],
literature_tool: Annotated[LiteratureSearchTool, Depends(get_literature_tool)],
br_resolver_tool: Annotated[
ResolveBrainRegionTool, Depends(get_brain_region_resolver_tool)
],
morpho_tool: Annotated[GetMorphoTool, Depends(get_morpho_tool)],
morphology_feature_tool: Annotated[
MorphologyFeatureTool, Depends(get_morphology_feature_tool)
],
kg_morpho_feature_tool: Annotated[
KGMorphoFeatureTool, Depends(get_kg_morpho_feature_tool)
],
electrophys_feature_tool: Annotated[
ElectrophysFeatureTool, Depends(get_electrophys_feature_tool)
],
traces_tool: Annotated[GetTracesTool, Depends(get_traces_tool)],
settings: Annotated[Settings, Depends(get_settings)],
) -> BaseAgent:
) -> BaseAgent | BaseMultiAgent:
"""Get the generative question answering service."""
if settings.agent.chat == "multi":
if settings.agent.model == "multi":
logger.info("Load multi-agent chat")
tools_list = [
("literature", [literature_tool]),
Expand All @@ -422,7 +389,6 @@ def get_chat_agent(
]
return SupervisorMultiAgent(llm=llm, agents=tools_list) # type: ignore
else:
logger.info("Load simple chat")
tools = [
literature_tool,
br_resolver_tool,
Expand All @@ -432,7 +398,41 @@ def get_chat_agent(
electrophys_feature_tool,
traces_tool,
]
return SimpleChatAgent(llm=llm, tools=tools, memory=memory) # type: ignore
logger.info("Load simple agent")
return SimpleAgent(llm=llm, tools=tools) # type: ignore


def get_chat_agent(
llm: Annotated[ChatOpenAI, Depends(get_language_model)],
memory: Annotated[BaseCheckpointSaver, Depends(get_agent_memory)],
literature_tool: Annotated[LiteratureSearchTool, Depends(get_literature_tool)],
br_resolver_tool: Annotated[
ResolveBrainRegionTool, Depends(get_brain_region_resolver_tool)
],
morpho_tool: Annotated[GetMorphoTool, Depends(get_morpho_tool)],
morphology_feature_tool: Annotated[
MorphologyFeatureTool, Depends(get_morphology_feature_tool)
],
kg_morpho_feature_tool: Annotated[
KGMorphoFeatureTool, Depends(get_kg_morpho_feature_tool)
],
electrophys_feature_tool: Annotated[
ElectrophysFeatureTool, Depends(get_electrophys_feature_tool)
],
traces_tool: Annotated[GetTracesTool, Depends(get_traces_tool)],
) -> BaseAgent:
"""Get the generative question answering service."""
logger.info("Load simple chat")
tools = [
literature_tool,
br_resolver_tool,
morpho_tool,
morphology_feature_tool,
kg_morpho_feature_tool,
electrophys_feature_tool,
traces_tool,
]
return SimpleChatAgent(llm=llm, tools=tools, memory=memory) # type: ignore


async def get_update_kg_hierarchy(
Expand Down
12 changes: 6 additions & 6 deletions src/neuroagent/app/routers/qa.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ async def run_agent(
) -> AgentOutput:
"""Run agent."""
logger.info("Running agent query.")
logger.info(f"User's query: {request.inputs}")
return await agent.arun(request.inputs)
logger.info(f"User's query: {request.query}")
return await agent.arun(request.query)


@router.post("/chat/{thread_id}", response_model=AgentOutput)
Expand All @@ -47,8 +47,8 @@ async def run_chat_agent(
) -> AgentOutput:
"""Run chat agent."""
logger.info("Running agent query.")
logger.info(f"User's query: {request.inputs}")
return await agent.arun(query=request.inputs, thread_id=thread_id)
logger.info(f"User's query: {request.query}")
return await agent.arun(query=request.query, thread_id=thread_id)


@router.post("/chat_streamed/{thread_id}")
Expand All @@ -60,5 +60,5 @@ async def run_streamed_chat_agent(
) -> StreamingResponse:
"""Run agent in streaming mode."""
logger.info("Running agent query.")
logger.info(f"User's query: {request.inputs}")
return StreamingResponse(agent.astream(query=request.inputs, thread_id=thread_id)) # type: ignore
logger.info(f"User's query: {request.query}")
return StreamingResponse(agent.astream(query=request.query, thread_id=thread_id)) # type: ignore
5 changes: 1 addition & 4 deletions src/neuroagent/app/schemas.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
"""Schemas."""

from typing import Any

from pydantic import BaseModel


class AgentRequest(BaseModel):
"""Class for agent request."""

inputs: str
parameters: dict[str, Any]
query: str
4 changes: 2 additions & 2 deletions tests/app/database/test_threads.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ async def test_get_thread(
# Fill the thread
app_client.post(
f"/qa/chat/{thread_id}",
json={"inputs": "This is my query", "parameters": {}},
json={"query": "This is my query"},
)

create_output = app_client.post("/threads/").json()
Expand Down Expand Up @@ -131,7 +131,7 @@ async def test_delete_thread(
# Fill the thread
app_client.post(
f"/qa/chat/{thread_id}",
json={"inputs": "This is my query", "parameters": {}},
json={"query": "This is my query"},
params={"thread_id": thread_id},
)
# Get the messages of the thread
Expand Down
4 changes: 2 additions & 2 deletions tests/app/database/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ async def test_get_tool_calls(
# Fill the thread
app_client.post(
f"/qa/chat/{thread_id}",
json={"inputs": "This is my query", "parameters": {}},
json={"query": "This is my query"},
params={"thread_id": thread_id},
)

Expand Down Expand Up @@ -121,7 +121,7 @@ async def test_get_tool_output(
# Fill the thread
app_client.post(
f"/qa/chat/{thread_id}",
json={"inputs": "This is my query", "parameters": {}},
json={"query": "This is my query"},
params={"thread_id": thread_id},
)

Expand Down
2 changes: 1 addition & 1 deletion tests/app/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ def test_required(monkeypatch, patch_required_env):

assert settings.tools.literature.url == "https://fake_url"
assert settings.knowledge_graph.base_url == "https://fake_url/api/nexus/v1"
assert settings.generative.openai.token.get_secret_value() == "dummy"
assert settings.openai.token.get_secret_value() == "dummy"
assert settings.knowledge_graph.use_token
assert settings.knowledge_graph.token.get_secret_value() == "token"

Expand Down
Loading
Loading