diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 9adb97f..a20df73 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -84,7 +84,7 @@ jobs: run: | pip install --upgrade pip pip install mypy==1.8.0 - pip install -e ".[dev]" + pip install ".[dev]" - name: Running mypy and tests run: | mypy src/ diff --git a/CHANGELOG.md b/CHANGELOG.md index 6725f78..d71d1fc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,8 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] -### Added -- Add get morphoelectric (me) model tool +## [0.1.1] - 26.09.2024 + +### Fixed +- Fixed a bug that prevented AsyncSqlite checkpoint to access the DB in streamed endpoints. ## [0.1.0] - 19.09.2024 diff --git a/src/neuroagent/__init__.py b/src/neuroagent/__init__.py index b4d4ea2..508d96b 100644 --- a/src/neuroagent/__init__.py +++ b/src/neuroagent/__init__.py @@ -1,3 +1,3 @@ """Neuroagent package.""" -__version__ = "0.1.0" +__version__ = "0.1.1" diff --git a/src/neuroagent/agents/base_agent.py b/src/neuroagent/agents/base_agent.py index 9ecf545..347e3bd 100644 --- a/src/neuroagent/agents/base_agent.py +++ b/src/neuroagent/agents/base_agent.py @@ -1,10 +1,12 @@ """Base agent.""" from abc import ABC, abstractmethod +from contextlib import asynccontextmanager from typing import Any, AsyncIterator from langchain.chat_models.base import BaseChatModel from langchain_core.tools import BaseTool +from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver from pydantic import BaseModel, ConfigDict @@ -47,3 +49,25 @@ def astream(self, *args: Any, **kwargs: Any) -> AsyncIterator[str]: @abstractmethod def _process_output(*args: Any, **kwargs: Any) -> AgentOutput: """Format the output.""" + + +class AsyncSqliteSaverWithPrefix(AsyncSqliteSaver): + """Wrapper around the AsyncSqliteSaver that accepts a connection string with prefix.""" + + @classmethod + @asynccontextmanager + async def from_conn_string( + cls, conn_string: str + ) -> AsyncIterator["AsyncSqliteSaver"]: + """Create a new AsyncSqliteSaver instance from a connection string. + + Args: + conn_string (str): The SQLite connection string. It can have the 'sqlite:///' prefix. + + Yields + ------ + AsyncSqliteSaverWithPrefix: A new AsyncSqliteSaverWithPrefix instance. + """ + conn_string = conn_string.split("///")[-1] + async with super().from_conn_string(conn_string) as memory: + yield AsyncSqliteSaverWithPrefix(memory.conn) diff --git a/src/neuroagent/agents/simple_chat_agent.py b/src/neuroagent/agents/simple_chat_agent.py index 274d18b..882b8d7 100644 --- a/src/neuroagent/agents/simple_chat_agent.py +++ b/src/neuroagent/agents/simple_chat_agent.py @@ -73,12 +73,9 @@ async def astream( streamed_response = self.agent.astream_events( {"messages": query}, version="v2", config=config ) + is_streaming = False async for event in streamed_response: kind = event["event"] - - # newline everytime model starts streaming. - if kind == "on_chat_model_start": - yield "\n\n" # check for the model stream. if kind == "on_chat_model_stream": # check if we are calling the tools. @@ -95,6 +92,9 @@ async def astream( content = data_chunk.content if content: + if not is_streaming: + yield "\n\n" + is_streaming = True yield content yield "\n" diff --git a/src/neuroagent/app/dependencies.py b/src/neuroagent/app/dependencies.py index dfcd9f2..430491d 100644 --- a/src/neuroagent/app/dependencies.py +++ b/src/neuroagent/app/dependencies.py @@ -11,7 +11,6 @@ from langchain_openai import ChatOpenAI from langgraph.checkpoint.base import BaseCheckpointSaver from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver -from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver from sqlalchemy import create_engine from sqlalchemy.engine import Engine from sqlalchemy.exc import SQLAlchemyError @@ -23,6 +22,7 @@ SimpleAgent, SimpleChatAgent, ) +from neuroagent.agents.base_agent import AsyncSqliteSaverWithPrefix from neuroagent.app.config import Settings from neuroagent.cell_types import CellTypesMeta from neuroagent.multi_agents import BaseMultiAgent, SupervisorMultiAgent @@ -345,8 +345,8 @@ async def get_agent_memory( """Get the agent checkpointer.""" if connection_string: if connection_string.startswith("sqlite"): - async with AsyncSqliteSaver.from_conn_string( - connection_string.split("///")[-1] + async with AsyncSqliteSaverWithPrefix.from_conn_string( + connection_string ) as memory: await memory.setup() yield memory diff --git a/src/neuroagent/scripts/__init__.py b/src/neuroagent/scripts/__init__.py new file mode 100644 index 0000000..cc662d0 --- /dev/null +++ b/src/neuroagent/scripts/__init__.py @@ -0,0 +1 @@ +"""Neuroagent scripts.""" diff --git a/tests/agents/test_simple_chat_agent.py b/tests/agents/test_simple_chat_agent.py index e580574..72f2498 100644 --- a/tests/agents/test_simple_chat_agent.py +++ b/tests/agents/test_simple_chat_agent.py @@ -85,8 +85,8 @@ async def test_astream(fake_llm_with_tools, httpx_mock): msg_list = "".join([el async for el in response]) assert ( - msg_list == "\n\n\nCalling tool : get-morpho-tool with arguments :" - ' {"brain_region_id":"http://api.brain-map.org/api/v2/data/Structure/549"}\n\nGreat' + msg_list == "\nCalling tool : get-morpho-tool with arguments :" + ' {"brain_region_id":"http://api.brain-map.org/api/v2/data/Structure/549"}\n\nGreat' " answer\n" ) diff --git a/tests/test_resolving.py b/tests/test_resolving.py index 2e794ce..6d4f25a 100644 --- a/tests/test_resolving.py +++ b/tests/test_resolving.py @@ -33,7 +33,7 @@ async def test_sparql_exact_resolve(httpx_mock, get_resolve_query_output): } ] - httpx_mock.reset(assert_all_responses_were_requested=False) + httpx_mock.reset() mtype = "Interneuron" mocked_response = get_resolve_query_output[1] @@ -84,7 +84,7 @@ async def test_sparql_fuzzy_resolve(httpx_mock, get_resolve_query_output): "id": "http://api.brain-map.org/api/v2/data/Structure/463", }, ] - httpx_mock.reset(assert_all_responses_were_requested=False) + httpx_mock.reset() mtype = "Interneu" mocked_response = get_resolve_query_output[3] @@ -143,7 +143,7 @@ async def test_es_resolve(httpx_mock, get_resolve_query_output): "id": "http://api.brain-map.org/api/v2/data/Structure/184", }, ] - httpx_mock.reset(assert_all_responses_were_requested=True) + httpx_mock.reset() mtype = "Ventral neuron" mocked_response = get_resolve_query_output[5] @@ -222,7 +222,7 @@ async def test_resolve_query(httpx_mock, get_resolve_query_output): "id": "http://api.brain-map.org/api/v2/data/Structure/463", }, ] - httpx_mock.reset(assert_all_responses_were_requested=True) + httpx_mock.reset() httpx_mock.add_response(url=url, json=get_resolve_query_output[0]) @@ -253,7 +253,7 @@ async def test_resolve_query(httpx_mock, get_resolve_query_output): "id": "http://api.brain-map.org/api/v2/data/Structure/549", } ] - httpx_mock.reset(assert_all_responses_were_requested=True) + httpx_mock.reset() httpx_mock.add_response( url=url, json={ diff --git a/tests/test_utils.py b/tests/test_utils.py index 7f30bc3..c43e33b 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -185,7 +185,7 @@ async def test_get_file_from_KG_errors(httpx_mock): ) assert not_found.value.args[0] == "No file url was found." - httpx_mock.reset(assert_all_responses_were_requested=True) + httpx_mock.reset() # no file found corresponding to file_url test_file_url = "http://test_url.com" json_response = {