Skip to content

Commit

Permalink
merge main into me_model
Browse files Browse the repository at this point in the history
  • Loading branch information
Mustafa Kerem Kurban committed Sep 27, 2024
2 parents cb37988 + 4f69726 commit f662684
Show file tree
Hide file tree
Showing 10 changed files with 46 additions and 19 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ jobs:
run: |
pip install --upgrade pip
pip install mypy==1.8.0
pip install -e ".[dev]"
pip install ".[dev]"
- name: Running mypy and tests
run: |
mypy src/
Expand Down
6 changes: 4 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

### Added
- Add get morphoelectric (me) model tool
## [0.1.1] - 26.09.2024

### Fixed
- Fixed a bug that prevented AsyncSqlite checkpoint to access the DB in streamed endpoints.

## [0.1.0] - 19.09.2024

Expand Down
2 changes: 1 addition & 1 deletion src/neuroagent/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
"""Neuroagent package."""

__version__ = "0.1.0"
__version__ = "0.1.1"
24 changes: 24 additions & 0 deletions src/neuroagent/agents/base_agent.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
"""Base agent."""

from abc import ABC, abstractmethod
from contextlib import asynccontextmanager
from typing import Any, AsyncIterator

from langchain.chat_models.base import BaseChatModel
from langchain_core.tools import BaseTool
from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
from pydantic import BaseModel, ConfigDict


Expand Down Expand Up @@ -47,3 +49,25 @@ def astream(self, *args: Any, **kwargs: Any) -> AsyncIterator[str]:
@abstractmethod
def _process_output(*args: Any, **kwargs: Any) -> AgentOutput:
"""Format the output."""


class AsyncSqliteSaverWithPrefix(AsyncSqliteSaver):
"""Wrapper around the AsyncSqliteSaver that accepts a connection string with prefix."""

@classmethod
@asynccontextmanager
async def from_conn_string(
cls, conn_string: str
) -> AsyncIterator["AsyncSqliteSaver"]:
"""Create a new AsyncSqliteSaver instance from a connection string.
Args:
conn_string (str): The SQLite connection string. It can have the 'sqlite:///' prefix.
Yields
------
AsyncSqliteSaverWithPrefix: A new AsyncSqliteSaverWithPrefix instance.
"""
conn_string = conn_string.split("///")[-1]
async with super().from_conn_string(conn_string) as memory:
yield AsyncSqliteSaverWithPrefix(memory.conn)
8 changes: 4 additions & 4 deletions src/neuroagent/agents/simple_chat_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,12 +73,9 @@ async def astream(
streamed_response = self.agent.astream_events(
{"messages": query}, version="v2", config=config
)
is_streaming = False
async for event in streamed_response:
kind = event["event"]

# newline everytime model starts streaming.
if kind == "on_chat_model_start":
yield "\n\n"
# check for the model stream.
if kind == "on_chat_model_stream":
# check if we are calling the tools.
Expand All @@ -95,6 +92,9 @@ async def astream(

content = data_chunk.content
if content:
if not is_streaming:
yield "\n<begin_llm_response>\n"
is_streaming = True
yield content
yield "\n"

Expand Down
6 changes: 3 additions & 3 deletions src/neuroagent/app/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from langchain_openai import ChatOpenAI
from langgraph.checkpoint.base import BaseCheckpointSaver
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
from sqlalchemy import create_engine
from sqlalchemy.engine import Engine
from sqlalchemy.exc import SQLAlchemyError
Expand All @@ -23,6 +22,7 @@
SimpleAgent,
SimpleChatAgent,
)
from neuroagent.agents.base_agent import AsyncSqliteSaverWithPrefix
from neuroagent.app.config import Settings
from neuroagent.cell_types import CellTypesMeta
from neuroagent.multi_agents import BaseMultiAgent, SupervisorMultiAgent
Expand Down Expand Up @@ -345,8 +345,8 @@ async def get_agent_memory(
"""Get the agent checkpointer."""
if connection_string:
if connection_string.startswith("sqlite"):
async with AsyncSqliteSaver.from_conn_string(
connection_string.split("///")[-1]
async with AsyncSqliteSaverWithPrefix.from_conn_string(
connection_string
) as memory:
await memory.setup()
yield memory
Expand Down
1 change: 1 addition & 0 deletions src/neuroagent/scripts/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Neuroagent scripts."""
4 changes: 2 additions & 2 deletions tests/agents/test_simple_chat_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,8 @@ async def test_astream(fake_llm_with_tools, httpx_mock):

msg_list = "".join([el async for el in response])
assert (
msg_list == "\n\n\nCalling tool : get-morpho-tool with arguments :"
' {"brain_region_id":"http://api.brain-map.org/api/v2/data/Structure/549"}\n\nGreat'
msg_list == "\nCalling tool : get-morpho-tool with arguments :"
' {"brain_region_id":"http://api.brain-map.org/api/v2/data/Structure/549"}\n<begin_llm_response>\nGreat'
" answer\n"
)

Expand Down
10 changes: 5 additions & 5 deletions tests/test_resolving.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ async def test_sparql_exact_resolve(httpx_mock, get_resolve_query_output):
}
]

httpx_mock.reset(assert_all_responses_were_requested=False)
httpx_mock.reset()

mtype = "Interneuron"
mocked_response = get_resolve_query_output[1]
Expand Down Expand Up @@ -84,7 +84,7 @@ async def test_sparql_fuzzy_resolve(httpx_mock, get_resolve_query_output):
"id": "http://api.brain-map.org/api/v2/data/Structure/463",
},
]
httpx_mock.reset(assert_all_responses_were_requested=False)
httpx_mock.reset()

mtype = "Interneu"
mocked_response = get_resolve_query_output[3]
Expand Down Expand Up @@ -143,7 +143,7 @@ async def test_es_resolve(httpx_mock, get_resolve_query_output):
"id": "http://api.brain-map.org/api/v2/data/Structure/184",
},
]
httpx_mock.reset(assert_all_responses_were_requested=True)
httpx_mock.reset()

mtype = "Ventral neuron"
mocked_response = get_resolve_query_output[5]
Expand Down Expand Up @@ -222,7 +222,7 @@ async def test_resolve_query(httpx_mock, get_resolve_query_output):
"id": "http://api.brain-map.org/api/v2/data/Structure/463",
},
]
httpx_mock.reset(assert_all_responses_were_requested=True)
httpx_mock.reset()

httpx_mock.add_response(url=url, json=get_resolve_query_output[0])

Expand Down Expand Up @@ -253,7 +253,7 @@ async def test_resolve_query(httpx_mock, get_resolve_query_output):
"id": "http://api.brain-map.org/api/v2/data/Structure/549",
}
]
httpx_mock.reset(assert_all_responses_were_requested=True)
httpx_mock.reset()
httpx_mock.add_response(
url=url,
json={
Expand Down
2 changes: 1 addition & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ async def test_get_file_from_KG_errors(httpx_mock):
)
assert not_found.value.args[0] == "No file url was found."

httpx_mock.reset(assert_all_responses_were_requested=True)
httpx_mock.reset()
# no file found corresponding to file_url
test_file_url = "http://test_url.com"
json_response = {
Expand Down

0 comments on commit f662684

Please sign in to comment.