Added model dumps #49

Merged
merged 17 commits · Dec 18, 2024
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -7,6 +7,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

### Changed
- Return model dumps of DB schema objects.

### Added
- LLM evaluation logic
- Integrated Alembic for managing chat history migrations
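The CHANGELOG entry above ("Return model dumps of DB schema objects") refers to Pydantic v2's `model_dump()`, which serializes a validated model into a plain `dict`. A minimal sketch of the idea, using an illustrative schema rather than one of the project's actual response models:

```python
from typing import Any

from pydantic import BaseModel


class ExampleResponse(BaseModel):
    """Illustrative schema; the real response models are defined per tool."""

    id: str
    status: str


def as_tool_output(payload: dict[str, Any]) -> dict[str, Any]:
    # Validate the raw JSON payload against the schema, then hand back a
    # plain dict instead of the model instance itself.
    return ExampleResponse(**payload).model_dump()


print(as_tool_output({"id": "sim-1", "status": "success"}))
# -> {'id': 'sim-1', 'status': 'success'}
```

The apparent motivation is that plain dicts serialize easily and callers no longer need to import each response class, though the PR itself does not spell this out.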
2 changes: 1 addition & 1 deletion src/neuroagent/cell_types.py
@@ -17,7 +17,7 @@ class CellTypesMeta:
"""

def __init__(self) -> None:
self.name_: dict[str, str] = {}
self.name_: dict[Any, Any | None] = {}
self.descendants_ids: dict[str, set[str]] = {}

def descendants(self, ids: str | set[str]) -> set[str]:
2 changes: 1 addition & 1 deletion swarm_copy/cell_types.py
@@ -17,7 +17,7 @@ class CellTypesMeta:
"""

def __init__(self) -> None:
self.name_: dict[str, str] = {}
self.name_: dict[Any, Any | None] = {}
self.descendants_ids: dict[str, set[str]] = {}

def descendants(self, ids: str | set[str]) -> set[str]:
7 changes: 3 additions & 4 deletions swarm_copy/tools/bluenaas_memodel_getall.py
@@ -1,7 +1,7 @@
"""BlueNaaS single cell stimulation, simulation and synapse placement tool."""

import logging
from typing import ClassVar, Literal
from typing import Any, ClassVar, Literal

from pydantic import BaseModel, Field

@@ -46,7 +46,7 @@ class MEModelGetAllTool(BaseTool):
metadata: MEModelGetAllMetadata
input_schema: InputMEModelGetAll

async def arun(self) -> PaginatedResponseUnionMEModelResponseSynaptomeModelResponse:
async def arun(self) -> dict[str, Any]:
"""Run the MEModelGetAll tool."""
logger.info(
f"Running MEModelGetAll tool with inputs {self.input_schema.model_dump()}"
@@ -61,7 +61,6 @@ async def arun(self) -> PaginatedResponseUnionMEModelResponseSynaptomeModelResponse:
},
headers={"Authorization": f"Bearer {self.metadata.token}"},
)
breakpoint()
return PaginatedResponseUnionMEModelResponseSynaptomeModelResponse(
**response.json()
)
).model_dump()
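The same pattern repeats in the tools below: `arun` still validates the upstream JSON against the response schema, but the declared return type becomes `dict[str, Any]` because the validated model is dumped before being returned. Since `model_dump()` recurses into nested models, a paginated wrapper and its items all come back as plain dicts. A sketch with made-up field names (the real `PaginatedResponseUnionMEModelResponseSynaptomeModelResponse` presumably mirrors the BlueNaaS API):

```python
from typing import Any

from pydantic import BaseModel


class MEModelSummary(BaseModel):
    """Illustrative item schema, not the real BlueNaaS fields."""

    id: str
    name: str


class PaginatedMEModels(BaseModel):
    """Illustrative stand-in for the paginated response wrapper."""

    page: int
    results: list[MEModelSummary]


raw = {"page": 1, "results": [{"id": "memodel-1", "name": "L5_TPC:A"}]}

dumped: dict[str, Any] = PaginatedMEModels(**raw).model_dump()
assert dumped == raw  # nested models are dumped to plain dicts as well
```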
6 changes: 3 additions & 3 deletions swarm_copy/tools/bluenaas_memodel_getone.py
@@ -1,7 +1,7 @@
"""BlueNaaS single cell stimulation, simulation and synapse placement tool."""

import logging
from typing import ClassVar
from typing import Any, ClassVar
from urllib.parse import quote_plus

from pydantic import BaseModel, Field
@@ -38,7 +38,7 @@ class MEModelGetOneTool(BaseTool):
metadata: MEModelGetOneMetadata
input_schema: InputMEModelGetOne

async def arun(self) -> MEModelResponse:
async def arun(self) -> dict[str, Any]:
"""Run the MEModelGetOne tool."""
logger.info(
f"Running MEModelGetOne tool with inputs {self.input_schema.model_dump()}"
@@ -49,4 +49,4 @@ async def arun(self) -> MEModelResponse:
headers={"Authorization": f"Bearer {self.metadata.token}"},
)

return MEModelResponse(**response.json())
return MEModelResponse(**response.json()).model_dump()
8 changes: 5 additions & 3 deletions swarm_copy/tools/bluenaas_scs_getall.py
@@ -1,7 +1,7 @@
"""BlueNaaS single cell stimulation, simulation and synapse placement tool."""

import logging
from typing import ClassVar, Literal
from typing import Any, ClassVar, Literal

from pydantic import BaseModel, Field

@@ -47,7 +47,7 @@ class SCSGetAllTool(BaseTool):
metadata: SCSGetAllMetadata
input_schema: InputSCSGetAll

async def arun(self) -> PaginatedResponseSimulationDetailsResponse:
async def arun(self) -> dict[str, Any]:
"""Run the SCSGetAll tool."""
logger.info(
f"Running SCSGetAll tool with inputs {self.input_schema.model_dump()}"
@@ -63,4 +63,6 @@ async def arun(self) -> PaginatedResponseSimulationDetailsResponse:
headers={"Authorization": f"Bearer {self.metadata.token}"},
)

return PaginatedResponseSimulationDetailsResponse(**response.json())
return PaginatedResponseSimulationDetailsResponse(
**response.json()
).model_dump()
6 changes: 3 additions & 3 deletions swarm_copy/tools/bluenaas_scs_getone.py
@@ -1,7 +1,7 @@
"""BlueNaaS single cell stimulation, simulation and synapse placement tool."""

import logging
from typing import ClassVar
from typing import Any, ClassVar

from pydantic import BaseModel, Field

@@ -39,7 +39,7 @@ class SCSGetOneTool(BaseTool):
metadata: SCSGetOneMetadata
input_schema: InputSCSGetOne

async def arun(self) -> SimulationDetailsResponse:
async def arun(self) -> dict[str, Any]:
"""Run the SCSGetOne tool."""
logger.info(
f"Running SCSGetOne tool with inputs {self.input_schema.model_dump()}"
@@ -50,4 +50,4 @@ async def arun(self) -> SimulationDetailsResponse:
headers={"Authorization": f"Bearer {self.metadata.token}"},
)

return SimulationDetailsResponse(**response.json())
return SimulationDetailsResponse(**response.json()).model_dump()
4 changes: 2 additions & 2 deletions swarm_copy/tools/bluenaas_scs_post.py
@@ -94,7 +94,7 @@ class SCSPostTool(BaseTool):
metadata: SCSPostMetadata
input_schema: InputSCSPost

async def arun(self) -> SCSPostOutput:
async def arun(self) -> dict[str, Any]:
"""Run the SCSPost tool."""
logger.info(
f"Running SCSPost tool with inputs {self.input_schema.model_dump()}"
@@ -126,7 +126,7 @@ async def arun(self) -> SCSPostOutput:
status=json_response["status"],
name=json_response["name"],
error=json_response["error"],
)
).model_dump()

@staticmethod
def create_json_api(
4 changes: 2 additions & 2 deletions swarm_copy/tools/electrophys_tool.py
@@ -194,7 +194,7 @@ class ElectrophysFeatureTool(BaseTool):
input_schema: ElectrophysInput
metadata: ElectrophysMetadata

async def arun(self) -> FeatureOutput:
async def arun(self) -> dict[str, Any]:
"""Give features about trace."""
logger.info(
f"Entering electrophys tool. Inputs: {self.input_schema.trace_id=}, {self.input_schema.calculated_feature=},"
@@ -329,4 +329,4 @@ async def arun(self) -> FeatureOutput:
)
return FeatureOutput(
brain_region=metadata.brain_region, feature_dict=output_features
)
).model_dump()
6 changes: 3 additions & 3 deletions swarm_copy/tools/get_morpho_tool.py
@@ -70,7 +70,7 @@ class GetMorphoTool(BaseTool):
input_schema: GetMorphoInput
metadata: GetMorphoMetadata

async def arun(self) -> list[KnowledgeGraphOutput]:
async def arun(self) -> list[dict[str, Any]]:
"""From a brain region ID, extract morphologies.

Returns
@@ -175,7 +175,7 @@ def create_query(
return entire_query

@staticmethod
def _process_output(output: Any) -> list[KnowledgeGraphOutput]:
def _process_output(output: Any) -> list[dict[str, Any]]:
"""Process output to fit the KnowledgeGraphOutput pydantic class defined above.

Parameters
@@ -211,7 +211,7 @@ def _process_output(output: Any) -> list[KnowledgeGraphOutput]:
if "subjectAge" in res["_source"]
else None
),
)
).model_dump()
for res in output["hits"]["hits"]
]
return formatted_output
6 changes: 3 additions & 3 deletions swarm_copy/tools/kg_morpho_features_tool.py
@@ -186,7 +186,7 @@ class KGMorphoFeatureTool(BaseTool):
input_schema: KGMorphoFeatureInput
metadata: KGMorphoFeatureMetadata

async def arun(self) -> list[KGMorphoFeatureOutput]:
async def arun(self) -> list[dict[str, Any]]:
"""Run the tool async.

Returns
@@ -319,7 +319,7 @@ def create_query(
return entire_query

@staticmethod
def _process_output(output: Any) -> list[KGMorphoFeatureOutput]:
def _process_output(output: Any) -> list[dict[str, Any]]:
"""Process output.

Parameters
@@ -347,7 +347,7 @@ def _process_output(
morphology_id=morpho_source["neuronMorphology"]["@id"],
morphology_name=morpho_source["neuronMorphology"].get("name"),
features=feature_output,
)
).model_dump()
)

return formatted_output
6 changes: 3 additions & 3 deletions swarm_copy/tools/literature_search_tool.py
@@ -61,7 +61,7 @@ class LiteratureSearchTool(BaseTool):
input_schema: LiteratureSearchInput
metadata: LiteratureSearchMetadata

async def arun(self) -> list[ParagraphMetadata]:
async def arun(self) -> list[dict[str, Any]]:
"""Async search the scientific literature and returns citations.

Returns
@@ -91,7 +91,7 @@ async def arun(self) -> list[ParagraphMetadata]:
return self._process_output(response.json())

@staticmethod
def _process_output(output: list[dict[str, Any]]) -> list[ParagraphMetadata]:
def _process_output(output: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""Process output."""
paragraphs_metadata = [
ParagraphMetadata(
@@ -101,7 +101,7 @@ def _process_output(output: list[dict[str, Any]]) -> list[ParagraphMetadata]:
section=paragraph["section"],
article_doi=paragraph["article_doi"],
journal_issn=paragraph["journal_issn"],
)
).model_dump()
for paragraph in output
]
return paragraphs_metadata
4 changes: 2 additions & 2 deletions swarm_copy/tools/morphology_features_tool.py
@@ -52,7 +52,7 @@ class MorphologyFeatureTool(BaseTool):
input_schema: MorphologyFeatureInput
metadata: MorphologyFeatureMetadata

async def arun(self) -> list[MorphologyFeatureOutput]:
async def arun(self) -> list[dict[str, Any]]:
"""Give features about morphology."""
logger.info(
f"Entering morphology feature tool. Inputs: {self.input_schema.morphology_id=}"
@@ -71,7 +71,7 @@ async def arun(self) -> list[MorphologyFeatureOutput]:
return [
MorphologyFeatureOutput(
brain_region=metadata.brain_region, feature_dict=features
)
).model_dump()
]

def get_features(self, morphology_content: bytes, reader: str) -> dict[str, Any]:
16 changes: 10 additions & 6 deletions swarm_copy/tools/resolve_entities_tool.py
@@ -1,7 +1,7 @@
"""Tool to resolve the brain region from natural english to a KG ID."""

import logging
from typing import ClassVar
from typing import Any, ClassVar

from pydantic import BaseModel, Field

@@ -86,14 +86,14 @@ class ResolveEntitiesTool(BaseTool):

async def arun(
self,
) -> list[BRResolveOutput | MTypeResolveOutput | EtypeResolveOutput]:
) -> list[dict[str, Any]]:
"""Given a brain region in natural language, resolve its ID."""
logger.info(
f"Entering Brain Region resolver tool. Inputs: {self.input_schema.brain_region=}, "
f"{self.input_schema.mtype=}, {self.input_schema.etype=}"
)
# Prepare the output list.
output: list[BRResolveOutput | MTypeResolveOutput | EtypeResolveOutput] = []
output: list[dict[str, Any]] = []

# First resolve the brain regions.
brain_regions = await resolve_query(
@@ -108,7 +108,9 @@ async def arun(
# Extend the resolved BRs.
output.extend(
[
BRResolveOutput(brain_region_name=br["label"], brain_region_id=br["id"])
BRResolveOutput(
brain_region_name=br["label"], brain_region_id=br["id"]
).model_dump()
for br in brain_regions
]
)
@@ -127,7 +129,9 @@ async def arun(
# Extend the resolved mtypes.
output.extend(
[
MTypeResolveOutput(mtype_name=mtype["label"], mtype_id=mtype["id"])
MTypeResolveOutput(
mtype_name=mtype["label"], mtype_id=mtype["id"]
).model_dump()
for mtype in mtypes
]
)
@@ -138,7 +142,7 @@ async def arun(
EtypeResolveOutput(
etype_name=self.input_schema.etype,
etype_id=ETYPE_IDS[self.input_schema.etype],
)
).model_dump()
)

return output
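In `resolve_entities_tool.py` the change also collapses the union return type `list[BRResolveOutput | MTypeResolveOutput | EtypeResolveOutput]` into `list[dict[str, Any]]`, since each resolved entity is dumped before it is appended. A minimal sketch of that shape, reusing the field names from the diff above but with illustrative values:

```python
from typing import Any

from pydantic import BaseModel


class BRResolveOutput(BaseModel):
    brain_region_name: str
    brain_region_id: str


class MTypeResolveOutput(BaseModel):
    mtype_name: str
    mtype_id: str


output: list[dict[str, Any]] = []
output.append(
    BRResolveOutput(
        brain_region_name="Thalamus", brain_region_id="example-brain-region-id"
    ).model_dump()
)
output.append(
    MTypeResolveOutput(
        mtype_name="example-mtype", mtype_id="example-mtype-id"
    ).model_dump()
)
# Different entity kinds can share one list because every element is a plain dict.
```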
6 changes: 3 additions & 3 deletions swarm_copy/tools/traces_tool.py
@@ -71,7 +71,7 @@ class GetTracesTool(BaseTool):
input_schema: GetTracesInput
metadata: GetTracesMetadata

async def arun(self) -> list[TracesOutput]:
async def arun(self) -> list[dict[str, Any]]:
"""From a brain region ID, extract traces."""
logger.info(
f"Entering get trace tool. Inputs: {self.input_schema.brain_region_id=}, {self.input_schema.etype_id=}"
@@ -153,7 +153,7 @@ def create_query(
return entire_query

@staticmethod
def _process_output(output: Any) -> list[TracesOutput]:
def _process_output(output: Any) -> list[dict[str, Any]]:
"""Process output to fit the TracesOutput pydantic class defined above.

Parameters
@@ -190,7 +190,7 @@ def _process_output(output: Any) -> list[TracesOutput]:
if "subjectAge" in res["_source"]
else None
),
)
).model_dump()
for res in output["hits"]["hits"]
]
return results