Skip to content

Commit

Permalink
fix all weird bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
BoBer78 committed Sep 16, 2024
1 parent 410fbd4 commit 596110b
Show file tree
Hide file tree
Showing 16 changed files with 69 additions and 76 deletions.
5 changes: 4 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

### Removed
- Github action to create the docs.


### Changed
- Migration to pydantic V2.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ dependencies = [
"langgraph-checkpoint-postgres",
"langgraph-checkpoint-sqlite",
"neurom",
"psycopg-binary",
"psycopg2-binary",
"pydantic-settings",
"python-dotenv",
Expand Down
11 changes: 4 additions & 7 deletions src/neuroagent/agents/base_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from typing import Any, AsyncIterator

from langchain.chat_models.base import BaseChatModel
from langchain.llms.base import BaseLLM
from langchain_core.messages import (
AIMessage,
ChatMessage,
Expand All @@ -21,9 +20,7 @@
SystemMessagePromptTemplate,
)
from langchain_core.tools import BaseTool
from pydantic import BaseModel as BaseModelV2
from pydantic import ConfigDict
from pydantic.v1 import BaseModel
from pydantic import BaseModel, ConfigDict

BASE_PROMPT = ChatPromptTemplate(
input_variables=["agent_scratchpad", "input"],
Expand Down Expand Up @@ -63,14 +60,14 @@
)


class AgentStep(BaseModelV2):
class AgentStep(BaseModel):
"""Class for agent decision steps."""

tool_name: str
arguments: dict[str, Any] | str


class AgentOutput(BaseModelV2):
class AgentOutput(BaseModel):
"""Class for agent response."""

response: str
Expand All @@ -81,7 +78,7 @@ class AgentOutput(BaseModelV2):
class BaseAgent(BaseModel, ABC):
"""Base class for services."""

llm: BaseLLM | BaseChatModel
llm: BaseChatModel
tools: list[BaseTool]
agent: Any

Expand Down
18 changes: 8 additions & 10 deletions src/neuroagent/agents/simple_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@

from langchain_core.messages import AIMessage
from langgraph.prebuilt import create_react_agent
from pydantic import ConfigDict
from pydantic.v1 import root_validator
from pydantic import model_validator

from neuroagent.agents import AgentOutput, AgentStep, BaseAgent

Expand All @@ -16,20 +15,19 @@
class SimpleAgent(BaseAgent):
"""Simple Agent class."""

model_config = ConfigDict(arbitrary_types_allowed=True)

@root_validator(pre=True)
def create_agent(cls, values: dict[str, Any]) -> dict[str, Any]:
@model_validator(mode="before")
@classmethod
def create_agent(cls, data: dict[str, Any]) -> dict[str, Any]:
"""Instantiate the clients upon class creation."""
# Initialise the agent with the tools
values["agent"] = create_react_agent(
model=values["llm"],
tools=values["tools"],
data["agent"] = create_react_agent(
model=data["llm"],
tools=data["tools"],
state_modifier="""You are a helpful assistant helping scientists with neuro-scientific questions.
You must always specify in your answers from which brain regions the information is extracted.
Do no blindly repeat the brain region requested by the user, use the output of the tools instead.""",
)
return values
return data

def run(self, query: str) -> Any:
"""Run the agent against a query.
Expand Down
22 changes: 9 additions & 13 deletions src/neuroagent/agents/simple_chat_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from langchain_core.messages import AIMessage, HumanMessage
from langgraph.checkpoint.base import BaseCheckpointSaver
from langgraph.prebuilt import create_react_agent
from pydantic.v1 import root_validator
from pydantic import model_validator

from neuroagent.agents import AgentOutput, AgentStep, BaseAgent

Expand All @@ -18,23 +18,19 @@ class SimpleChatAgent(BaseAgent):

memory: BaseCheckpointSaver

class Config:
"""Config."""

arbitrary_types_allowed = True

@root_validator(pre=True)
def create_agent(cls, values: dict[str, Any]) -> dict[str, Any]:
@model_validator(mode="before")
@classmethod
def create_agent(cls, data: dict[str, Any]) -> dict[str, Any]:
"""Instantiate the clients upon class creation."""
values["agent"] = create_react_agent(
model=values["llm"],
tools=values["tools"],
checkpointer=values["memory"],
data["agent"] = create_react_agent(
model=data["llm"],
tools=data["tools"],
checkpointer=data["memory"],
state_modifier="""You are a helpful assistant helping scientists with neuro-scientific questions.
You must always specify in your answers from which brain regions the information is extracted.
Do no blindly repeat the brain region requested by the user, use the output of the tools instead.""",
)
return values
return data

def run(self, session_id: str, query: str) -> Any:
"""Run the agent against a query."""
Expand Down
6 changes: 2 additions & 4 deletions src/neuroagent/multi_agents/base_multi_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@
from typing import Any, AsyncIterator

from langchain.chat_models.base import BaseChatModel
from langchain.llms.base import BaseLLM
from pydantic import ConfigDict
from pydantic.v1 import BaseModel
from pydantic import BaseModel, ConfigDict

from neuroagent.agents import AgentOutput
from neuroagent.tools.base_tool import BasicTool
Expand All @@ -15,7 +13,7 @@
class BaseMultiAgent(BaseModel, ABC):
"""Base class for multi agents."""

llm: BaseLLM | BaseChatModel
llm: BaseChatModel
main_agent: Any
agents: list[tuple[str, list[BasicTool]]]

Expand Down
20 changes: 10 additions & 10 deletions src/neuroagent/multi_agents/supervisor_multi_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,7 @@
from langgraph.graph import END, START, StateGraph
from langgraph.graph.graph import CompiledGraph
from langgraph.prebuilt import create_react_agent
from pydantic import ConfigDict
from pydantic.v1 import root_validator
from pydantic import ConfigDict, model_validator

from neuroagent.agents import AgentOutput, AgentStep
from neuroagent.multi_agents.base_multi_agent import BaseMultiAgent
Expand All @@ -39,8 +38,9 @@ class SupervisorMultiAgent(BaseMultiAgent):

model_config = ConfigDict(arbitrary_types_allowed=True)

@root_validator(pre=True)
def create_main_agent(cls, values: dict[str, Any]) -> dict[str, Any]:
@model_validator(mode="before")
@classmethod
def create_main_agent(cls, data: dict[str, Any]) -> dict[str, Any]:
"""Instantiate the clients upon class creation."""
logger.info("Creating main agent, supervisor and all the agents with tools.")
system_prompt = (
Expand All @@ -50,7 +50,7 @@ def create_main_agent(cls, values: dict[str, Any]) -> dict[str, Any]:
" task and respond with their results and status. When finished,"
" respond with FINISH."
)
agents_list = [elem[0] for elem in values["agents"]]
agents_list = [elem[0] for elem in data["agents"]]
logger.info(f"List of agents name: {agents_list}")

options = ["FINISH"] + agents_list
Expand Down Expand Up @@ -84,14 +84,14 @@ def create_main_agent(cls, values: dict[str, Any]) -> dict[str, Any]:
),
]
).partial(options=str(options), members=", ".join(agents_list))
values["main_agent"] = (
data["main_agent"] = (
prompt
| values["llm"].bind_functions(
| data["llm"].bind_functions(
functions=[function_def], function_call="route"
)
| JsonOutputFunctionsParser()
)
values["summarizer"] = (
data["summarizer"] = (
PromptTemplate.from_template(
"""You are an helpful assistant. Here is the question of the user: {question}.
And here are the results of the different tools used to answer: {responses}.
Expand All @@ -101,10 +101,10 @@ def create_main_agent(cls, values: dict[str, Any]) -> dict[str, Any]:
Please formulate a complete response to give to the user ONLY based on the results.
"""
)
| values["llm"]
| data["llm"]
)

return values
return data

@staticmethod
async def agent_node(
Expand Down
26 changes: 12 additions & 14 deletions src/neuroagent/tools/base_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,33 +4,30 @@
import logging
from typing import Any

from langchain_core.pydantic_v1 import ValidationError
from langchain_core.tools import BaseTool, ToolException
from pydantic.v1 import BaseModel, root_validator
from pydantic import BaseModel, ValidationError, model_validator

logger = logging.getLogger(__name__)


def process_validation_error(error: ValidationError) -> str:
"""Handle validation errors when tool inputs are wrong."""
error_list = []

# not happy with this solution but it is to extract the name of the input class
name = str(error.model).split(".")[-1].strip(">")
name = error.title
# We have to iterate, in case there are multiple errors.
try:
for err in error.errors():
if "ctx" in err:
if err["type"] == "literal_error":
error_list.append(
{
"Validation error": (
f'Wrong value: {err["ctx"]["given"]} for input'
f'Wrong value: provided {err["input"]} for input'
f' {err["loc"][0]}. Try again and change this problematic'
" input."
)
}
)
elif "loc" in err and err["msg"] == "field required":
elif err["type"] == "missing":
error_list.append(
{
"Validation error": (
Expand Down Expand Up @@ -70,17 +67,18 @@ class BasicTool(BaseTool):
name: str = "base"
description: str = "Base tool from which regular tools should inherit."

@root_validator(pre=True)
def handle_errors(cls, values: dict[str, Any]) -> dict[str, Any]:
@model_validator(mode="before")
@classmethod
def handle_errors(cls, data: dict[str, Any]) -> dict[str, Any]:
"""Instantiate the clients upon class creation."""
values["handle_validation_error"] = process_validation_error
values["handle_tool_error"] = process_tool_error
return values
data["handle_validation_error"] = process_validation_error
data["handle_tool_error"] = process_tool_error
return data


class BaseToolOutput(BaseModel):
"""Base class for tool outputs."""

def __repr__(self) -> str:
"""Representation method."""
return self.json()
return self.model_dump_json()
2 changes: 1 addition & 1 deletion src/neuroagent/tools/electrophys_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@

from bluepyefe.extract import extract_efeatures
from efel.units import get_unit
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.tools import ToolException
from pydantic import BaseModel, Field

from neuroagent.tools.base_tool import BaseToolOutput, BasicTool
from neuroagent.utils import get_kg_data
Expand Down
2 changes: 1 addition & 1 deletion src/neuroagent/tools/get_morpho_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
import logging
from typing import Any, Optional, Type

from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.tools import ToolException
from pydantic import BaseModel, Field

from neuroagent.cell_types import get_celltypes_descendants
from neuroagent.tools.base_tool import BaseToolOutput, BasicTool
Expand Down
13 changes: 7 additions & 6 deletions src/neuroagent/tools/kg_morpho_features_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
import logging
from typing import Any, Literal, Type

from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
from langchain_core.tools import ToolException
from pydantic import BaseModel, Field, model_validator

from neuroagent.tools.base_tool import BaseToolOutput, BasicTool
from neuroagent.utils import get_descendants_id
Expand Down Expand Up @@ -121,13 +121,14 @@ class FeatureInput(BaseModel):
)
feat_range: FeatRangeInput | None = None

@root_validator(pre=True)
def check_if_list(cls, values: Any) -> dict[str, str | list[float | int] | None]:
@model_validator(mode="before")
@classmethod
def check_if_list(cls, data: Any) -> dict[str, str | list[float | int] | None]:
"""Validate that the values passed to the constructor are a dictionary."""
if isinstance(values, list) and len(values) == 1:
data_dict = values[0]
if isinstance(data, list) and len(data) == 1:
data_dict = data[0]
else:
data_dict = values
data_dict = data
return data_dict


Expand Down
2 changes: 1 addition & 1 deletion src/neuroagent/tools/literature_search_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
import logging
from typing import Any, Type

from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.tools import ToolException
from pydantic import BaseModel, Field

from neuroagent.tools.base_tool import BaseToolOutput, BasicTool

Expand Down
2 changes: 1 addition & 1 deletion src/neuroagent/tools/morphology_features_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@

import neurom
import numpy as np
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.tools import ToolException
from neurom.io.utils import load_morphology
from pydantic import BaseModel, Field

from neuroagent.tools.base_tool import BaseToolOutput, BasicTool
from neuroagent.utils import get_kg_data
Expand Down
2 changes: 1 addition & 1 deletion src/neuroagent/tools/resolve_brain_region_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
import logging
from typing import Any, Optional, Type

from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.tools import ToolException
from pydantic import BaseModel, Field

from neuroagent.resolving import resolve_query
from neuroagent.tools.base_tool import BaseToolOutput, BasicTool
Expand Down
2 changes: 1 addition & 1 deletion src/neuroagent/tools/traces_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
import logging
from typing import Any, Literal, Optional, Type

from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.tools import ToolException
from pydantic import BaseModel, Field

from neuroagent.tools.base_tool import BaseToolOutput, BasicTool
from neuroagent.utils import get_descendants_id
Expand Down
Loading

0 comments on commit 596110b

Please sign in to comment.