From 40a5a023b8a18e9fdb7950cb6c084c6000dbf675 Mon Sep 17 00:00:00 2001 From: Mustafa Kerem Kurban Date: Fri, 27 Sep 2024 19:15:44 +0200 Subject: [PATCH] implement kwargs in structured manner --- src/neuroagent/tools/bluenaas_tool.py | 60 ++++++++++++++++++++------- 1 file changed, 45 insertions(+), 15 deletions(-) diff --git a/src/neuroagent/tools/bluenaas_tool.py b/src/neuroagent/tools/bluenaas_tool.py index b055e1b..d3cc050 100644 --- a/src/neuroagent/tools/bluenaas_tool.py +++ b/src/neuroagent/tools/bluenaas_tool.py @@ -11,6 +11,7 @@ logger = logging.getLogger(__name__) +#TODO : since bluenaaas has multiple endpoints for synapse generation ,nexus query, simulation, how should we handle those? class SynapseSimulationConfig(BaseModel): id: str delay: int @@ -24,20 +25,20 @@ class SimulationStimulusConfig(BaseModel): amplitudes: Annotated[list[float], Len(min_length=1, max_length=15)] class CurrentInjectionConfig(BaseModel): - injectTo: str + injectTo: str #TODO: could be soma, dendrite/basal, apical, AIS or how its named in the platform stimulus: SimulationStimulusConfig class RecordingLocation(BaseModel): - section: str - offset: float + section: str #TODO: how to constrain this, we need to query available section ids from the obtained me model's morphology + offset: float # TODO: should be between 0-1 class SimulationConditionsConfig(BaseModel): - celsius: float - vinit: float - hypamp: float - max_time: float - time_step: float - seed: int + celsius: float # TODO: ideally this should be controlled as models dont perform well outside calibrated temperature range + vinit: float # TODO: usually default value is set to -70 mV but can change depending on the simulation + hypamp: float # this is usually set to experimentally defined value ! + max_time: float # user defined + time_step: float # usually 0.025 ms + seed: int # can be any nubmer class SimulationWithSynapseBody(BaseModel): directCurrentConfig: CurrentInjectionConfig @@ -47,7 +48,7 @@ class SimulationWithSynapseBody(BaseModel): class InputBlueNaaS(BaseModel): """Inputs for the BlueNaaS single-neuron simulation.""" - model_id: str = Field( + me_model_id: str = Field( description=( "ID of the neuron model to be used in the simulation. The model ID can be" " fetched using the 'get-me-model-tool'." @@ -78,7 +79,6 @@ class InputBlueNaaS(BaseModel): " (max_time, in ms), time step (time_step, in ms), and random seed (seed)." ) ) - #TODO: implement synaptome simulation simulationType: Literal["single-neuron-simulation","synaptome-simulation"] = Field( description=( "Type of the simulation. it can be single neuron simulation for simulation" @@ -107,14 +107,44 @@ class BlueNaaSTool(BasicTool): metadata: dict[str, Any] args_schema: Type[BaseModel] = InputBlueNaaS - async def _arun(self, **kwargs) -> BlueNaaSOutput: - """Run a single-neuron simulation using the BlueNaaS service.""" - logger.info(f"Running BlueNaaS tool with inputs: {kwargs}") + async def _arun(self, + me_model_id: str, + synapses: List[SynapseSimulationConfig], + currentInjection: CurrentInjectionConfig, + recordFrom: List[RecordingLocation], + conditions: SimulationConditionsConfig, + simulationType: Literal["single-neuron-simulation","synaptome-simulation"], + simulationDuration: int + ) -> BaseToolOutput: + """ + Run the BlueNaaS tool. + + Args: + me_model_id: ID of the neuron model to be used in the simulation. + synapses: List of synapse configurations. + currentInjection: Configuration for current injection. + recordFrom: List of sections to record from during the simulation. + conditions: Simulation conditions. + simulationType: Type of the simulation. + simulationDuration: Duration of the simulation in milliseconds. + + Returns: + BaseToolOutput: Output of the BlueNaaS tool. + """ + logger.info(f"Running BlueNaaS tool with inputs: {locals()}") try: response = await self.metadata["httpx_client"].post( url=self.metadata["url"], headers={"Authorization": f"Bearer {self.metadata['token']}"}, - json=kwargs, + json={ + "model_id": me_model_id, + "synapses": [synapse.dict() for synapse in synapses], + "currentInjection": currentInjection.dict(), + "recordFrom": [record.dict() for record in recordFrom], + "conditions": conditions.dict(), + "type": simulationType, + "simulationDuration": simulationDuration + }, ) response_data = response.json() return BlueNaaSOutput(status="success", result=response_data)