From 8abdf1befeb3353a76cd0bfbe0e7229977d39257 Mon Sep 17 00:00:00 2001 From: cszsolnai Date: Thu, 14 Nov 2024 10:49:14 +0100 Subject: [PATCH 01/14] Added model dumps --- CHANGELOG.md | 3 +++ .../app/routers/database/threads.py | 22 +++++++++---------- src/neuroagent/app/routers/database/tools.py | 6 ++--- swarm_copy/app/routers/threads.py | 18 +++++++-------- swarm_copy/app/routers/tools.py | 4 ++-- tests/app/database/test_threads.py | 4 ++-- tests/app/database/test_tools.py | 15 +++++-------- 7 files changed, 36 insertions(+), 36 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 687691c..e2491be 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +### Changed +- Return model dumps of DB schema objects. + ### Added - LLM evaluation logic - Integrated Alembic for managing chat history migrations diff --git a/src/neuroagent/app/routers/database/threads.py b/src/neuroagent/app/routers/database/threads.py index b836972..da7182e 100644 --- a/src/neuroagent/app/routers/database/threads.py +++ b/src/neuroagent/app/routers/database/threads.py @@ -1,7 +1,7 @@ """Conversation related CRUD operations.""" import logging -from typing import Annotated +from typing import Annotated, Any from fastapi import APIRouter, Depends, HTTPException from httpx import AsyncClient @@ -45,7 +45,7 @@ async def create_thread( virtual_lab_id: str, project_id: str, title: str = "title", -) -> ThreadsRead: +) -> dict[str, Any]: """Create thread. \f @@ -78,14 +78,14 @@ async def create_thread( session.add(new_thread) await session.commit() await session.refresh(new_thread) - return ThreadsRead(**new_thread.__dict__) + return ThreadsRead(**new_thread.__dict__).model_dump() @router.get("/") async def get_threads( session: Annotated[AsyncSession, Depends(get_session)], user_id: Annotated[str, Depends(get_user_id)], -) -> list[ThreadsRead]: +) -> list[dict[str, Any]]: """Get threads for a user. \f @@ -104,7 +104,7 @@ async def get_threads( query = select(Threads).where(Threads.user_sub == user_id) results = await session.execute(query) threads = results.all() - return [ThreadsRead(**thread[0].__dict__) for thread in threads] + return [ThreadsRead(**thread[0].__dict__).model_dump() for thread in threads] @router.get("/{thread_id}") @@ -112,7 +112,7 @@ async def get_thread( _: Annotated[Threads, Depends(get_object)], memory: Annotated[AsyncSqliteSaver | None, Depends(get_agent_memory)], thread_id: str, -) -> list[GetThreadsOutput]: +) -> list[dict[str, Any]]: """Get thread. \f @@ -147,7 +147,7 @@ async def get_thread( if not messages: return [] - output: list[GetThreadsOutput] = [] + output: list[dict[str, Any]] = [] # Reconstruct the conversation. Also output message_id for other endpoints for message in messages["channel_values"]["messages"]: if isinstance(message, HumanMessage): @@ -156,7 +156,7 @@ async def get_thread( message_id=message.id, entity="Human", message=message.content, - ) + ).model_dump() ) if isinstance(message, AIMessage) and message.content: output.append( @@ -164,7 +164,7 @@ async def get_thread( message_id=message.id, entity="AI", message=message.content, - ) + ).model_dump() ) return output @@ -174,7 +174,7 @@ async def update_thread_title( thread: Annotated[Threads, Depends(get_object)], session: Annotated[AsyncSession, Depends(get_session)], thread_update: ThreadsUpdate, -) -> ThreadsRead: +) -> dict[str, Any]: """Update thread. \f @@ -199,7 +199,7 @@ async def update_thread_title( await session.commit() await session.refresh(thread) thread_return = ThreadsRead(**thread.__dict__) # For mypy. - return thread_return + return thread_return.model_dump() @router.delete("/{thread_id}") diff --git a/src/neuroagent/app/routers/database/tools.py b/src/neuroagent/app/routers/database/tools.py index 9248c66..d96815b 100644 --- a/src/neuroagent/app/routers/database/tools.py +++ b/src/neuroagent/app/routers/database/tools.py @@ -24,7 +24,7 @@ async def get_tool_calls( memory: Annotated[AsyncSqliteSaver | None, Depends(get_agent_memory)], thread_id: str, message_id: str, -) -> list[ToolCallSchema]: +) -> list[dict[str, Any]]: """Get tool calls of a specific message. \f @@ -93,14 +93,14 @@ async def get_tool_calls( ) # From sub list, extract tool calls - tool_calls: list[ToolCallSchema] = [] + tool_calls: list[dict[str, Any]] = [] for message in message_list[previous_content_message + 1 : relevant_message]: if isinstance(message, AIMessage): tool_calls.extend( [ ToolCallSchema( call_id=tool["id"], name=tool["name"], arguments=tool["args"] - ) + ).model_dump() for tool in message.tool_calls ] ) diff --git a/swarm_copy/app/routers/threads.py b/swarm_copy/app/routers/threads.py index 58e4b2d..73325e7 100644 --- a/swarm_copy/app/routers/threads.py +++ b/swarm_copy/app/routers/threads.py @@ -2,7 +2,7 @@ import json import logging -from typing import Annotated +from typing import Annotated, Any from fastapi import APIRouter, Depends from httpx import AsyncClient @@ -37,7 +37,7 @@ async def create_thread( session: Annotated[AsyncSession, Depends(get_session)], user_id: Annotated[str, Depends(get_user_id)], title: str = "New chat", -) -> ThreadsRead: +) -> dict[str, Any]: """Create thread.""" # We first need to check if the combination thread/vlab/project is valid await validate_project( @@ -57,20 +57,20 @@ async def create_thread( await session.commit() await session.refresh(new_thread) - return ThreadsRead(**new_thread.__dict__) + return ThreadsRead(**new_thread.__dict__).model_dump() @router.get("/") async def get_threads( session: Annotated[AsyncSession, Depends(get_session)], user_id: Annotated[str, Depends(get_user_id)], -) -> list[ThreadsRead]: +) -> list[dict[str, Any]]: """Get threads for a user.""" thread_result = await session.execute( select(Threads).where(Threads.user_id == user_id) ) threads = thread_result.scalars().all() - return [ThreadsRead(**thread.__dict__) for thread in threads] + return [ThreadsRead(**thread.__dict__).model_dump() for thread in threads] @router.get("/{thread_id}") @@ -78,7 +78,7 @@ async def get_messages( session: Annotated[AsyncSession, Depends(get_session)], _: Annotated[Threads, Depends(get_thread)], # to check if thread exist thread_id: str, -) -> list[MessagesRead]: +) -> list[dict[str, Any]]: """Get all messages of the thread.""" messages_result = await session.execute( select(Messages) @@ -96,7 +96,7 @@ async def get_messages( MessagesRead( msg_content=json.loads(msg.content)["content"], **msg.__dict__, - ) + ).model_dump() ) return messages @@ -107,14 +107,14 @@ async def update_thread_title( session: Annotated[AsyncSession, Depends(get_session)], update_thread: ThreadUpdate, thread: Annotated[Threads, Depends(get_thread)], -) -> ThreadsRead: +) -> dict[str, Any]: """Update thread.""" thread_data = update_thread.model_dump(exclude_unset=True) for key, value in thread_data.items(): setattr(thread, key, value) await session.commit() await session.refresh(thread) - return ThreadsRead(**thread.__dict__) + return ThreadsRead(**thread.__dict__).model_dump() @router.delete("/{thread_id}") diff --git a/swarm_copy/app/routers/tools.py b/swarm_copy/app/routers/tools.py index bc4e06f..d8a0cc1 100644 --- a/swarm_copy/app/routers/tools.py +++ b/swarm_copy/app/routers/tools.py @@ -24,7 +24,7 @@ async def get_tool_calls( session: Annotated[AsyncSession, Depends(get_session)], thread_id: str, message_id: str, -) -> list[ToolCallSchema]: +) -> list[dict[str, Any]]: """Get tool calls of a specific message.""" # Find relevant messages relevant_message = await session.get(Messages, message_id) @@ -78,7 +78,7 @@ async def get_tool_calls( tool_call_id=tool["id"], name=tool["function"]["name"], arguments=json.loads(tool["function"]["arguments"]), - ) + ).model_dump() ) return tool_calls_response diff --git a/tests/app/database/test_threads.py b/tests/app/database/test_threads.py index 246509b..4b1d6b6 100644 --- a/tests/app/database/test_threads.py +++ b/tests/app/database/test_threads.py @@ -110,12 +110,12 @@ async def test_get_thread( message_id=messages[0]["message_id"], entity="Human", message="This is my query", - ).model_dump(), + ), GetThreadsOutput( message_id="run-42768b30-044a-4263-8c5c-da61429aa9da-0", entity="AI", message="Great answer", - ).model_dump(), + ), ] diff --git a/tests/app/database/test_tools.py b/tests/app/database/test_tools.py index 3ce9750..fe32781 100644 --- a/tests/app/database/test_tools.py +++ b/tests/app/database/test_tools.py @@ -53,15 +53,12 @@ async def test_get_tool_calls( message_id = messages[-1]["message_id"] tool_calls = app_client.get(f"/tools/{thread_id}/{message_id}").json() - assert ( - tool_calls[0] - == ToolCallSchema( - call_id="call_zHhwfNLSvGGHXMoILdIYtDVI", - name="get-morpho-tool", - arguments={ - "brain_region_id": "http://api.brain-map.org/api/v2/data/Structure/549" - }, - ).model_dump() + assert tool_calls[0] == ToolCallSchema( + call_id="call_zHhwfNLSvGGHXMoILdIYtDVI", + name="get-morpho-tool", + arguments={ + "brain_region_id": "http://api.brain-map.org/api/v2/data/Structure/549" + }, ) From e8321de526a9937caddcdc32ebba7b094824a031 Mon Sep 17 00:00:00 2001 From: cszsolnai Date: Thu, 14 Nov 2024 14:43:35 +0100 Subject: [PATCH 02/14] Adjusted unit tests --- tests/app/database/test_threads.py | 20 ++++++++++---------- tests/app/database/test_tools.py | 10 +++++----- 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/tests/app/database/test_threads.py b/tests/app/database/test_threads.py index 4b1d6b6..edafac9 100644 --- a/tests/app/database/test_threads.py +++ b/tests/app/database/test_threads.py @@ -106,16 +106,16 @@ async def test_get_thread( messages = app_client.get(f"/threads/{thread_id}").json() assert messages == [ - GetThreadsOutput( - message_id=messages[0]["message_id"], - entity="Human", - message="This is my query", - ), - GetThreadsOutput( - message_id="run-42768b30-044a-4263-8c5c-da61429aa9da-0", - entity="AI", - message="Great answer", - ), + { + "message_id": messages[0]["message_id"], + "entity": "Human", + "message": "This is my query", + }, + { + "message_id": "run-42768b30-044a-4263-8c5c-da61429aa9da-0", + "entity": "AI", + "message": "Great answer", + } ] diff --git a/tests/app/database/test_tools.py b/tests/app/database/test_tools.py index fe32781..b3d056e 100644 --- a/tests/app/database/test_tools.py +++ b/tests/app/database/test_tools.py @@ -53,13 +53,13 @@ async def test_get_tool_calls( message_id = messages[-1]["message_id"] tool_calls = app_client.get(f"/tools/{thread_id}/{message_id}").json() - assert tool_calls[0] == ToolCallSchema( - call_id="call_zHhwfNLSvGGHXMoILdIYtDVI", - name="get-morpho-tool", - arguments={ + assert tool_calls[0] == { + "call_id": "call_zHhwfNLSvGGHXMoILdIYtDVI", + "name": "get-morpho-tool", + "arguments": { "brain_region_id": "http://api.brain-map.org/api/v2/data/Structure/549" }, - ) + } @pytest.mark.httpx_mock(can_send_already_matched_responses=True) From ad8cb0dda5b824b0d7b59230ee53c248f38b54c8 Mon Sep 17 00:00:00 2001 From: cszsol Date: Thu, 14 Nov 2024 14:45:17 +0100 Subject: [PATCH 03/14] Update test_threads.py --- tests/app/database/test_threads.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/tests/app/database/test_threads.py b/tests/app/database/test_threads.py index edafac9..246509b 100644 --- a/tests/app/database/test_threads.py +++ b/tests/app/database/test_threads.py @@ -106,16 +106,16 @@ async def test_get_thread( messages = app_client.get(f"/threads/{thread_id}").json() assert messages == [ - { - "message_id": messages[0]["message_id"], - "entity": "Human", - "message": "This is my query", - }, - { - "message_id": "run-42768b30-044a-4263-8c5c-da61429aa9da-0", - "entity": "AI", - "message": "Great answer", - } + GetThreadsOutput( + message_id=messages[0]["message_id"], + entity="Human", + message="This is my query", + ).model_dump(), + GetThreadsOutput( + message_id="run-42768b30-044a-4263-8c5c-da61429aa9da-0", + entity="AI", + message="Great answer", + ).model_dump(), ] From dd5a926a40dbbafaa58d05569fdb5304fe5ebbe6 Mon Sep 17 00:00:00 2001 From: cszsol Date: Thu, 14 Nov 2024 14:46:13 +0100 Subject: [PATCH 04/14] Update test_tools.py --- tests/app/database/test_tools.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/tests/app/database/test_tools.py b/tests/app/database/test_tools.py index b3d056e..3ce9750 100644 --- a/tests/app/database/test_tools.py +++ b/tests/app/database/test_tools.py @@ -53,13 +53,16 @@ async def test_get_tool_calls( message_id = messages[-1]["message_id"] tool_calls = app_client.get(f"/tools/{thread_id}/{message_id}").json() - assert tool_calls[0] == { - "call_id": "call_zHhwfNLSvGGHXMoILdIYtDVI", - "name": "get-morpho-tool", - "arguments": { - "brain_region_id": "http://api.brain-map.org/api/v2/data/Structure/549" - }, - } + assert ( + tool_calls[0] + == ToolCallSchema( + call_id="call_zHhwfNLSvGGHXMoILdIYtDVI", + name="get-morpho-tool", + arguments={ + "brain_region_id": "http://api.brain-map.org/api/v2/data/Structure/549" + }, + ).model_dump() + ) @pytest.mark.httpx_mock(can_send_already_matched_responses=True) From 6f11c26ed4f9470885d58a28f2daecd553f6ebba Mon Sep 17 00:00:00 2001 From: kanesoban Date: Wed, 11 Dec 2024 13:28:34 +0100 Subject: [PATCH 05/14] Update threads.py From 8aad2ebbcb35ea30c9dfa1be52b861ca17f3fcb0 Mon Sep 17 00:00:00 2001 From: kanesoban Date: Wed, 11 Dec 2024 13:30:30 +0100 Subject: [PATCH 06/14] Update threads.py --- .../app/routers/database/threads.py | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/src/neuroagent/app/routers/database/threads.py b/src/neuroagent/app/routers/database/threads.py index da7182e..b836972 100644 --- a/src/neuroagent/app/routers/database/threads.py +++ b/src/neuroagent/app/routers/database/threads.py @@ -1,7 +1,7 @@ """Conversation related CRUD operations.""" import logging -from typing import Annotated, Any +from typing import Annotated from fastapi import APIRouter, Depends, HTTPException from httpx import AsyncClient @@ -45,7 +45,7 @@ async def create_thread( virtual_lab_id: str, project_id: str, title: str = "title", -) -> dict[str, Any]: +) -> ThreadsRead: """Create thread. \f @@ -78,14 +78,14 @@ async def create_thread( session.add(new_thread) await session.commit() await session.refresh(new_thread) - return ThreadsRead(**new_thread.__dict__).model_dump() + return ThreadsRead(**new_thread.__dict__) @router.get("/") async def get_threads( session: Annotated[AsyncSession, Depends(get_session)], user_id: Annotated[str, Depends(get_user_id)], -) -> list[dict[str, Any]]: +) -> list[ThreadsRead]: """Get threads for a user. \f @@ -104,7 +104,7 @@ async def get_threads( query = select(Threads).where(Threads.user_sub == user_id) results = await session.execute(query) threads = results.all() - return [ThreadsRead(**thread[0].__dict__).model_dump() for thread in threads] + return [ThreadsRead(**thread[0].__dict__) for thread in threads] @router.get("/{thread_id}") @@ -112,7 +112,7 @@ async def get_thread( _: Annotated[Threads, Depends(get_object)], memory: Annotated[AsyncSqliteSaver | None, Depends(get_agent_memory)], thread_id: str, -) -> list[dict[str, Any]]: +) -> list[GetThreadsOutput]: """Get thread. \f @@ -147,7 +147,7 @@ async def get_thread( if not messages: return [] - output: list[dict[str, Any]] = [] + output: list[GetThreadsOutput] = [] # Reconstruct the conversation. Also output message_id for other endpoints for message in messages["channel_values"]["messages"]: if isinstance(message, HumanMessage): @@ -156,7 +156,7 @@ async def get_thread( message_id=message.id, entity="Human", message=message.content, - ).model_dump() + ) ) if isinstance(message, AIMessage) and message.content: output.append( @@ -164,7 +164,7 @@ async def get_thread( message_id=message.id, entity="AI", message=message.content, - ).model_dump() + ) ) return output @@ -174,7 +174,7 @@ async def update_thread_title( thread: Annotated[Threads, Depends(get_object)], session: Annotated[AsyncSession, Depends(get_session)], thread_update: ThreadsUpdate, -) -> dict[str, Any]: +) -> ThreadsRead: """Update thread. \f @@ -199,7 +199,7 @@ async def update_thread_title( await session.commit() await session.refresh(thread) thread_return = ThreadsRead(**thread.__dict__) # For mypy. - return thread_return.model_dump() + return thread_return @router.delete("/{thread_id}") From ee4b97b69c30b5fdf48866e9521cfa9e4b07f083 Mon Sep 17 00:00:00 2001 From: kanesoban Date: Wed, 11 Dec 2024 13:31:42 +0100 Subject: [PATCH 07/14] Update tools.py --- src/neuroagent/app/routers/database/tools.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/neuroagent/app/routers/database/tools.py b/src/neuroagent/app/routers/database/tools.py index d96815b..9248c66 100644 --- a/src/neuroagent/app/routers/database/tools.py +++ b/src/neuroagent/app/routers/database/tools.py @@ -24,7 +24,7 @@ async def get_tool_calls( memory: Annotated[AsyncSqliteSaver | None, Depends(get_agent_memory)], thread_id: str, message_id: str, -) -> list[dict[str, Any]]: +) -> list[ToolCallSchema]: """Get tool calls of a specific message. \f @@ -93,14 +93,14 @@ async def get_tool_calls( ) # From sub list, extract tool calls - tool_calls: list[dict[str, Any]] = [] + tool_calls: list[ToolCallSchema] = [] for message in message_list[previous_content_message + 1 : relevant_message]: if isinstance(message, AIMessage): tool_calls.extend( [ ToolCallSchema( call_id=tool["id"], name=tool["name"], arguments=tool["args"] - ).model_dump() + ) for tool in message.tool_calls ] ) From f297cb6cfc11edfeadf2d1d05e03c1d1e6f2d6d2 Mon Sep 17 00:00:00 2001 From: kanesoban Date: Wed, 11 Dec 2024 13:33:18 +0100 Subject: [PATCH 08/14] Update threads.py --- swarm_copy/app/routers/threads.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/swarm_copy/app/routers/threads.py b/swarm_copy/app/routers/threads.py index 73325e7..58e4b2d 100644 --- a/swarm_copy/app/routers/threads.py +++ b/swarm_copy/app/routers/threads.py @@ -2,7 +2,7 @@ import json import logging -from typing import Annotated, Any +from typing import Annotated from fastapi import APIRouter, Depends from httpx import AsyncClient @@ -37,7 +37,7 @@ async def create_thread( session: Annotated[AsyncSession, Depends(get_session)], user_id: Annotated[str, Depends(get_user_id)], title: str = "New chat", -) -> dict[str, Any]: +) -> ThreadsRead: """Create thread.""" # We first need to check if the combination thread/vlab/project is valid await validate_project( @@ -57,20 +57,20 @@ async def create_thread( await session.commit() await session.refresh(new_thread) - return ThreadsRead(**new_thread.__dict__).model_dump() + return ThreadsRead(**new_thread.__dict__) @router.get("/") async def get_threads( session: Annotated[AsyncSession, Depends(get_session)], user_id: Annotated[str, Depends(get_user_id)], -) -> list[dict[str, Any]]: +) -> list[ThreadsRead]: """Get threads for a user.""" thread_result = await session.execute( select(Threads).where(Threads.user_id == user_id) ) threads = thread_result.scalars().all() - return [ThreadsRead(**thread.__dict__).model_dump() for thread in threads] + return [ThreadsRead(**thread.__dict__) for thread in threads] @router.get("/{thread_id}") @@ -78,7 +78,7 @@ async def get_messages( session: Annotated[AsyncSession, Depends(get_session)], _: Annotated[Threads, Depends(get_thread)], # to check if thread exist thread_id: str, -) -> list[dict[str, Any]]: +) -> list[MessagesRead]: """Get all messages of the thread.""" messages_result = await session.execute( select(Messages) @@ -96,7 +96,7 @@ async def get_messages( MessagesRead( msg_content=json.loads(msg.content)["content"], **msg.__dict__, - ).model_dump() + ) ) return messages @@ -107,14 +107,14 @@ async def update_thread_title( session: Annotated[AsyncSession, Depends(get_session)], update_thread: ThreadUpdate, thread: Annotated[Threads, Depends(get_thread)], -) -> dict[str, Any]: +) -> ThreadsRead: """Update thread.""" thread_data = update_thread.model_dump(exclude_unset=True) for key, value in thread_data.items(): setattr(thread, key, value) await session.commit() await session.refresh(thread) - return ThreadsRead(**thread.__dict__).model_dump() + return ThreadsRead(**thread.__dict__) @router.delete("/{thread_id}") From c85647980123debf772d5ba4321eeed3f7c26d73 Mon Sep 17 00:00:00 2001 From: kanesoban Date: Wed, 11 Dec 2024 13:33:39 +0100 Subject: [PATCH 09/14] Update tools.py --- swarm_copy/app/routers/tools.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/swarm_copy/app/routers/tools.py b/swarm_copy/app/routers/tools.py index d8a0cc1..bc4e06f 100644 --- a/swarm_copy/app/routers/tools.py +++ b/swarm_copy/app/routers/tools.py @@ -24,7 +24,7 @@ async def get_tool_calls( session: Annotated[AsyncSession, Depends(get_session)], thread_id: str, message_id: str, -) -> list[dict[str, Any]]: +) -> list[ToolCallSchema]: """Get tool calls of a specific message.""" # Find relevant messages relevant_message = await session.get(Messages, message_id) @@ -78,7 +78,7 @@ async def get_tool_calls( tool_call_id=tool["id"], name=tool["function"]["name"], arguments=json.loads(tool["function"]["arguments"]), - ).model_dump() + ) ) return tool_calls_response From bd125163bad66902bd06799e6c08d17955d8cd19 Mon Sep 17 00:00:00 2001 From: kanesoban Date: Wed, 11 Dec 2024 13:45:10 +0100 Subject: [PATCH 10/14] Changed tool returns to model dump --- swarm_copy/tools/bluenaas_memodel_getall.py | 10 +++++----- swarm_copy/tools/bluenaas_memodel_getone.py | 10 +++++----- swarm_copy/tools/bluenaas_scs_getall.py | 8 +++++--- swarm_copy/tools/bluenaas_scs_getone.py | 6 +++--- swarm_copy/tools/bluenaas_scs_post.py | 4 ++-- swarm_copy/tools/electrophys_tool.py | 4 ++-- swarm_copy/tools/get_morpho_tool.py | 6 +++--- swarm_copy/tools/kg_morpho_features_tool.py | 6 +++--- swarm_copy/tools/literature_search_tool.py | 6 +++--- swarm_copy/tools/morphology_features_tool.py | 4 ++-- swarm_copy/tools/resolve_entities_tool.py | 8 ++++---- swarm_copy/tools/traces_tool.py | 6 +++--- 12 files changed, 40 insertions(+), 38 deletions(-) diff --git a/swarm_copy/tools/bluenaas_memodel_getall.py b/swarm_copy/tools/bluenaas_memodel_getall.py index 8bda00e..59688d1 100644 --- a/swarm_copy/tools/bluenaas_memodel_getall.py +++ b/swarm_copy/tools/bluenaas_memodel_getall.py @@ -1,7 +1,7 @@ """BlueNaaS single cell stimulation, simulation and synapse placement tool.""" import logging -from typing import ClassVar, Literal +from typing import Any, ClassVar, Literal from pydantic import BaseModel, Field @@ -29,7 +29,7 @@ class InputMEModelGetAll(BaseModel): page_size: int = Field( default=20, description="Number of results returned by the API." ) - model_type: Literal["single-neuron-simulation", "synaptome-simulation"] = Field( + memodel_type: Literal["single-neuron-simulation", "synaptome-simulation"] = Field( default="single-neuron-simulation", description="Type of simulation to retrieve.", ) @@ -46,7 +46,7 @@ class MEModelGetAllTool(BaseTool): metadata: MEModelGetAllMetadata input_schema: InputMEModelGetAll - async def arun(self) -> PaginatedResponseUnionMEModelResponseSynaptomeModelResponse: + async def arun(self) -> dict[str, Any]: """Run the MEModelGetAll tool.""" logger.info( f"Running MEModelGetAll tool with inputs {self.input_schema.model_dump()}" @@ -55,7 +55,7 @@ async def arun(self) -> PaginatedResponseUnionMEModelResponseSynaptomeModelRespo response = await self.metadata.httpx_client.get( url=f"{self.metadata.bluenaas_url}/neuron-model/{self.metadata.vlab_id}/{self.metadata.project_id}/me-models", params={ - "simulation_type": self.input_schema.model_type, + "simulation_type": self.input_schema.memodel_type, "offset": self.input_schema.offset, "page_size": self.input_schema.page_size, }, @@ -64,4 +64,4 @@ async def arun(self) -> PaginatedResponseUnionMEModelResponseSynaptomeModelRespo breakpoint() return PaginatedResponseUnionMEModelResponseSynaptomeModelResponse( **response.json() - ) + ).model_dump() diff --git a/swarm_copy/tools/bluenaas_memodel_getone.py b/swarm_copy/tools/bluenaas_memodel_getone.py index 4f4a3b3..2f36f64 100644 --- a/swarm_copy/tools/bluenaas_memodel_getone.py +++ b/swarm_copy/tools/bluenaas_memodel_getone.py @@ -1,7 +1,7 @@ """BlueNaaS single cell stimulation, simulation and synapse placement tool.""" import logging -from typing import ClassVar +from typing import Any, ClassVar from urllib.parse import quote_plus from pydantic import BaseModel, Field @@ -24,7 +24,7 @@ class MEModelGetOneMetadata(BaseMetadata): class InputMEModelGetOne(BaseModel): """Inputs for the BlueNaaS single-neuron simulation.""" - model_id: str = Field( + memodel_id: str = Field( description="ID of the model to retrieve. Should be an https link." ) @@ -38,15 +38,15 @@ class MEModelGetOneTool(BaseTool): metadata: MEModelGetOneMetadata input_schema: InputMEModelGetOne - async def arun(self) -> MEModelResponse: + async def arun(self) -> dict[str, Any]: """Run the MEModelGetOne tool.""" logger.info( f"Running MEModelGetOne tool with inputs {self.input_schema.model_dump()}" ) response = await self.metadata.httpx_client.get( - url=f"{self.metadata.bluenaas_url}/neuron-model/{self.metadata.vlab_id}/{self.metadata.project_id}/{quote_plus(self.input_schema.model_id)}", + url=f"{self.metadata.bluenaas_url}/neuron-model/{self.metadata.vlab_id}/{self.metadata.project_id}/{quote_plus(self.input_schema.memodel_id)}", headers={"Authorization": f"Bearer {self.metadata.token}"}, ) - return MEModelResponse(**response.json()) + return MEModelResponse(**response.json()).model_dump() diff --git a/swarm_copy/tools/bluenaas_scs_getall.py b/swarm_copy/tools/bluenaas_scs_getall.py index 95897dc..533ed7f 100644 --- a/swarm_copy/tools/bluenaas_scs_getall.py +++ b/swarm_copy/tools/bluenaas_scs_getall.py @@ -1,7 +1,7 @@ """BlueNaaS single cell stimulation, simulation and synapse placement tool.""" import logging -from typing import ClassVar, Literal +from typing import Any, ClassVar, Literal from pydantic import BaseModel, Field @@ -47,7 +47,7 @@ class SCSGetAllTool(BaseTool): metadata: SCSGetAllMetadata input_schema: InputSCSGetAll - async def arun(self) -> PaginatedResponseSimulationDetailsResponse: + async def arun(self) -> dict[str, Any]: """Run the SCSGetAll tool.""" logger.info( f"Running SCSGetAll tool with inputs {self.input_schema.model_dump()}" @@ -63,4 +63,6 @@ async def arun(self) -> PaginatedResponseSimulationDetailsResponse: headers={"Authorization": f"Bearer {self.metadata.token}"}, ) - return PaginatedResponseSimulationDetailsResponse(**response.json()) + return PaginatedResponseSimulationDetailsResponse( + **response.json() + ).model_dump() diff --git a/swarm_copy/tools/bluenaas_scs_getone.py b/swarm_copy/tools/bluenaas_scs_getone.py index 4957be9..2575682 100644 --- a/swarm_copy/tools/bluenaas_scs_getone.py +++ b/swarm_copy/tools/bluenaas_scs_getone.py @@ -1,7 +1,7 @@ """BlueNaaS single cell stimulation, simulation and synapse placement tool.""" import logging -from typing import ClassVar +from typing import Any, ClassVar from pydantic import BaseModel, Field @@ -39,7 +39,7 @@ class SCSGetOneTool(BaseTool): metadata: SCSGetOneMetadata input_schema: InputSCSGetOne - async def arun(self) -> SimulationDetailsResponse: + async def arun(self) -> dict[str, Any]: """Run the SCSGetOne tool.""" logger.info( f"Running SCSGetOne tool with inputs {self.input_schema.model_dump()}" @@ -50,4 +50,4 @@ async def arun(self) -> SimulationDetailsResponse: headers={"Authorization": f"Bearer {self.metadata.token}"}, ) - return SimulationDetailsResponse(**response.json()) + return SimulationDetailsResponse(**response.json()).model_dump() diff --git a/swarm_copy/tools/bluenaas_scs_post.py b/swarm_copy/tools/bluenaas_scs_post.py index 6c8e154..7e3144c 100644 --- a/swarm_copy/tools/bluenaas_scs_post.py +++ b/swarm_copy/tools/bluenaas_scs_post.py @@ -94,7 +94,7 @@ class SCSPostTool(BaseTool): metadata: SCSPostMetadata input_schema: InputSCSPost - async def arun(self) -> SCSPostOutput: + async def arun(self) -> dict[str, Any]: """Run the SCSPost tool.""" logger.info( f"Running SCSPost tool with inputs {self.input_schema.model_dump()}" @@ -126,7 +126,7 @@ async def arun(self) -> SCSPostOutput: status=json_response["status"], name=json_response["name"], error=json_response["error"], - ) + ).model_dump() @staticmethod def create_json_api( diff --git a/swarm_copy/tools/electrophys_tool.py b/swarm_copy/tools/electrophys_tool.py index 00673c1..2984836 100644 --- a/swarm_copy/tools/electrophys_tool.py +++ b/swarm_copy/tools/electrophys_tool.py @@ -194,7 +194,7 @@ class ElectrophysFeatureTool(BaseTool): input_schema: ElectrophysInput metadata: ElectrophysMetadata - async def arun(self) -> FeatureOutput: + async def arun(self) -> dict[str, Any]: """Give features about trace.""" logger.info( f"Entering electrophys tool. Inputs: {self.input_schema.trace_id=}, {self.input_schema.calculated_feature=}," @@ -329,4 +329,4 @@ async def arun(self) -> FeatureOutput: ) return FeatureOutput( brain_region=metadata.brain_region, feature_dict=output_features - ) + ).model_dump() diff --git a/swarm_copy/tools/get_morpho_tool.py b/swarm_copy/tools/get_morpho_tool.py index dc8d4a6..45c72bb 100644 --- a/swarm_copy/tools/get_morpho_tool.py +++ b/swarm_copy/tools/get_morpho_tool.py @@ -70,7 +70,7 @@ class GetMorphoTool(BaseTool): input_schema: GetMorphoInput metadata: GetMorphoMetadata - async def arun(self) -> list[KnowledgeGraphOutput]: + async def arun(self) -> list[dict[str, Any]]: """From a brain region ID, extract morphologies. Returns @@ -175,7 +175,7 @@ def create_query( return entire_query @staticmethod - def _process_output(output: Any) -> list[KnowledgeGraphOutput]: + def _process_output(output: Any) -> list[dict[str, Any]]: """Process output to fit the KnowledgeGraphOutput pydantic class defined above. Parameters @@ -211,7 +211,7 @@ def _process_output(output: Any) -> list[KnowledgeGraphOutput]: if "subjectAge" in res["_source"] else None ), - ) + ).model_dump() for res in output["hits"]["hits"] ] return formatted_output diff --git a/swarm_copy/tools/kg_morpho_features_tool.py b/swarm_copy/tools/kg_morpho_features_tool.py index 24eeac8..7298636 100644 --- a/swarm_copy/tools/kg_morpho_features_tool.py +++ b/swarm_copy/tools/kg_morpho_features_tool.py @@ -186,7 +186,7 @@ class KGMorphoFeatureTool(BaseTool): input_schema: KGMorphoFeatureInput metadata: KGMorphoFeatureMetadata - async def arun(self) -> list[KGMorphoFeatureOutput]: + async def arun(self) -> list[dict[str, Any]]: """Run the tool async. Returns @@ -319,7 +319,7 @@ def create_query( return entire_query @staticmethod - def _process_output(output: Any) -> list[KGMorphoFeatureOutput]: + def _process_output(output: Any) -> list[dict[str, Any]]: """Process output. Parameters @@ -347,7 +347,7 @@ def _process_output(output: Any) -> list[KGMorphoFeatureOutput]: morphology_id=morpho_source["neuronMorphology"]["@id"], morphology_name=morpho_source["neuronMorphology"].get("name"), features=feature_output, - ) + ).model_dump() ) return formatted_output diff --git a/swarm_copy/tools/literature_search_tool.py b/swarm_copy/tools/literature_search_tool.py index 99880b9..92ebf7d 100644 --- a/swarm_copy/tools/literature_search_tool.py +++ b/swarm_copy/tools/literature_search_tool.py @@ -61,7 +61,7 @@ class LiteratureSearchTool(BaseTool): input_schema: LiteratureSearchInput metadata: LiteratureSearchMetadata - async def arun(self) -> list[ParagraphMetadata]: + async def arun(self) -> list[dict[str, Any]]: """Async search the scientific literature and returns citations. Returns @@ -91,7 +91,7 @@ async def arun(self) -> list[ParagraphMetadata]: return self._process_output(response.json()) @staticmethod - def _process_output(output: list[dict[str, Any]]) -> list[ParagraphMetadata]: + def _process_output(output: list[dict[str, Any]]) -> list[dict[str, Any]]: """Process output.""" paragraphs_metadata = [ ParagraphMetadata( @@ -101,7 +101,7 @@ def _process_output(output: list[dict[str, Any]]) -> list[ParagraphMetadata]: section=paragraph["section"], article_doi=paragraph["article_doi"], journal_issn=paragraph["journal_issn"], - ) + ).model_dump() for paragraph in output ] return paragraphs_metadata diff --git a/swarm_copy/tools/morphology_features_tool.py b/swarm_copy/tools/morphology_features_tool.py index 00bd349..31b2548 100644 --- a/swarm_copy/tools/morphology_features_tool.py +++ b/swarm_copy/tools/morphology_features_tool.py @@ -52,7 +52,7 @@ class MorphologyFeatureTool(BaseTool): input_schema: MorphologyFeatureInput metadata: MorphologyFeatureMetadata - async def arun(self) -> list[MorphologyFeatureOutput]: + async def arun(self) -> list[dict[str, Any]]: """Give features about morphology.""" logger.info( f"Entering morphology feature tool. Inputs: {self.input_schema.morphology_id=}" @@ -71,7 +71,7 @@ async def arun(self) -> list[MorphologyFeatureOutput]: return [ MorphologyFeatureOutput( brain_region=metadata.brain_region, feature_dict=features - ) + ).model_dump() ] def get_features(self, morphology_content: bytes, reader: str) -> dict[str, Any]: diff --git a/swarm_copy/tools/resolve_entities_tool.py b/swarm_copy/tools/resolve_entities_tool.py index 1264ac1..00503ad 100644 --- a/swarm_copy/tools/resolve_entities_tool.py +++ b/swarm_copy/tools/resolve_entities_tool.py @@ -1,7 +1,7 @@ """Tool to resolve the brain region from natural english to a KG ID.""" import logging -from typing import ClassVar +from typing import Any, ClassVar from pydantic import BaseModel, Field @@ -86,14 +86,14 @@ class ResolveEntitiesTool(BaseTool): async def arun( self, - ) -> list[BRResolveOutput | MTypeResolveOutput | EtypeResolveOutput]: + ) -> list[dict[str, Any]]: """Given a brain region in natural language, resolve its ID.""" logger.info( f"Entering Brain Region resolver tool. Inputs: {self.input_schema.brain_region=}, " f"{self.input_schema.mtype=}, {self.input_schema.etype=}" ) # Prepare the output list. - output: list[BRResolveOutput | MTypeResolveOutput | EtypeResolveOutput] = [] + output: list[dict[str, Any]] = [] # First resolve the brain regions. brain_regions = await resolve_query( @@ -138,7 +138,7 @@ async def arun( EtypeResolveOutput( etype_name=self.input_schema.etype, etype_id=ETYPE_IDS[self.input_schema.etype], - ) + ).model_dump() ) return output diff --git a/swarm_copy/tools/traces_tool.py b/swarm_copy/tools/traces_tool.py index 41028b2..616da8e 100644 --- a/swarm_copy/tools/traces_tool.py +++ b/swarm_copy/tools/traces_tool.py @@ -70,7 +70,7 @@ class GetTracesTool(BaseTool): input_schema: GetTracesInput metadata: GetTracesMetadata - async def arun(self) -> list[TracesOutput]: + async def arun(self) -> list[dict[str, Any]]: """From a brain region ID, extract traces.""" logger.info( f"Entering get trace tool. Inputs: {self.input_schema.brain_region_id=}, {self.input_schema.etype_id=}" @@ -152,7 +152,7 @@ def create_query( return entire_query @staticmethod - def _process_output(output: Any) -> list[TracesOutput]: + def _process_output(output: Any) -> list[dict[str, Any]]: """Process output to fit the TracesOutput pydantic class defined above. Parameters @@ -189,7 +189,7 @@ def _process_output(output: Any) -> list[TracesOutput]: if "subjectAge" in res["_source"] else None ), - ) + ).model_dump() for res in output["hits"]["hits"] ] return results From 0f9380691c18dceabcec2f69125286f7e8eea08d Mon Sep 17 00:00:00 2001 From: kanesoban Date: Wed, 11 Dec 2024 13:55:00 +0100 Subject: [PATCH 11/14] mypy --- src/neuroagent/cell_types.py | 2 +- swarm_copy/cell_types.py | 2 +- swarm_copy/tools/resolve_entities_tool.py | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/neuroagent/cell_types.py b/src/neuroagent/cell_types.py index df84d38..9c91e32 100644 --- a/src/neuroagent/cell_types.py +++ b/src/neuroagent/cell_types.py @@ -17,7 +17,7 @@ class CellTypesMeta: """ def __init__(self) -> None: - self.name_: dict[str, str] = {} + self.name_: dict[Any, Any | None] = {} self.descendants_ids: dict[str, set[str]] = {} def descendants(self, ids: str | set[str]) -> set[str]: diff --git a/swarm_copy/cell_types.py b/swarm_copy/cell_types.py index df84d38..9c91e32 100644 --- a/swarm_copy/cell_types.py +++ b/swarm_copy/cell_types.py @@ -17,7 +17,7 @@ class CellTypesMeta: """ def __init__(self) -> None: - self.name_: dict[str, str] = {} + self.name_: dict[Any, Any | None] = {} self.descendants_ids: dict[str, set[str]] = {} def descendants(self, ids: str | set[str]) -> set[str]: diff --git a/swarm_copy/tools/resolve_entities_tool.py b/swarm_copy/tools/resolve_entities_tool.py index 00503ad..935a5f9 100644 --- a/swarm_copy/tools/resolve_entities_tool.py +++ b/swarm_copy/tools/resolve_entities_tool.py @@ -108,7 +108,7 @@ async def arun( # Extend the resolved BRs. output.extend( [ - BRResolveOutput(brain_region_name=br["label"], brain_region_id=br["id"]) + BRResolveOutput(brain_region_name=br["label"], brain_region_id=br["id"]).model_dump() for br in brain_regions ] ) @@ -127,7 +127,7 @@ async def arun( # Extend the resolved mtypes. output.extend( [ - MTypeResolveOutput(mtype_name=mtype["label"], mtype_id=mtype["id"]) + MTypeResolveOutput(mtype_name=mtype["label"], mtype_id=mtype["id"]).model_dump() for mtype in mtypes ] ) From 850bca32cc675467f3c7d66fd3b26290b57b681e Mon Sep 17 00:00:00 2001 From: kanesoban Date: Wed, 11 Dec 2024 13:58:33 +0100 Subject: [PATCH 12/14] lint --- swarm_copy/tools/resolve_entities_tool.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/swarm_copy/tools/resolve_entities_tool.py b/swarm_copy/tools/resolve_entities_tool.py index 935a5f9..e6ab88d 100644 --- a/swarm_copy/tools/resolve_entities_tool.py +++ b/swarm_copy/tools/resolve_entities_tool.py @@ -108,7 +108,9 @@ async def arun( # Extend the resolved BRs. output.extend( [ - BRResolveOutput(brain_region_name=br["label"], brain_region_id=br["id"]).model_dump() + BRResolveOutput( + brain_region_name=br["label"], brain_region_id=br["id"] + ).model_dump() for br in brain_regions ] ) @@ -127,7 +129,9 @@ async def arun( # Extend the resolved mtypes. output.extend( [ - MTypeResolveOutput(mtype_name=mtype["label"], mtype_id=mtype["id"]).model_dump() + MTypeResolveOutput( + mtype_name=mtype["label"], mtype_id=mtype["id"] + ).model_dump() for mtype in mtypes ] ) From 2549d16365e241691fb1932f04bd2d0063f1786e Mon Sep 17 00:00:00 2001 From: kanesoban Date: Tue, 17 Dec 2024 14:07:14 +0100 Subject: [PATCH 13/14] Fixed fixture --- tests/conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/conftest.py b/tests/conftest.py index 88ee7ea..65c4ac7 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -107,7 +107,7 @@ def brain_region_json_path(): return br_path -@pytest.fixture +@pytest_asyncio.fixture async def fake_llm_with_tools(brain_region_json_path): class FakeFuntionChatModel(GenericFakeChatModel): def bind_tools(self, functions: list): From d7175aec7af57244a68c8cca9d3b92f2bca72d83 Mon Sep 17 00:00:00 2001 From: kanesoban Date: Wed, 18 Dec 2024 10:46:51 +0100 Subject: [PATCH 14/14] Removed breakpoint --- swarm_copy/tools/bluenaas_memodel_getall.py | 1 - 1 file changed, 1 deletion(-) diff --git a/swarm_copy/tools/bluenaas_memodel_getall.py b/swarm_copy/tools/bluenaas_memodel_getall.py index 59688d1..6c95c0e 100644 --- a/swarm_copy/tools/bluenaas_memodel_getall.py +++ b/swarm_copy/tools/bluenaas_memodel_getall.py @@ -61,7 +61,6 @@ async def arun(self) -> dict[str, Any]: }, headers={"Authorization": f"Bearer {self.metadata.token}"}, ) - breakpoint() return PaginatedResponseUnionMEModelResponseSynaptomeModelResponse( **response.json() ).model_dump()