Skip to content

Commit

Permalink
Merge pull request #187 from NeuralEnsemble/feat/get-unbranched-segme…
Browse files Browse the repository at this point in the history
…nt-groups

Extend `get_segment_groups_from_substring` to include an `unbranched` filter
  • Loading branch information
sanjayankur31 authored Mar 5, 2024
2 parents 0f8ef51 + 7098d90 commit f268dc7
Show file tree
Hide file tree
Showing 3 changed files with 164 additions and 117 deletions.
1 change: 1 addition & 0 deletions neuroml/nml/gds_imports-template.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import networkx as nx
import numpy
import natsort
import typing

import neuroml
import neuroml.neuro_lex_ids
95 changes: 46 additions & 49 deletions neuroml/nml/helper_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ def __repr__(self):
return str(self)
def distance_to(self, other_3d_point):
def distance_to(self, other_3d_point) -> float:
"""Find the distance between this point and another.
:param other_3d_point: other 3D point to calculate distance to
Expand Down Expand Up @@ -286,13 +286,13 @@ def distance_to(self, other_3d_point):
name="connection_cell_ids",
source='''\
def _get_cell_id(self, id_string):
def _get_cell_id(self, id_string: str) -> int:
if '[' in id_string:
return int(id_string.split('[')[1].split(']')[0])
else:
return int(id_string.split('/')[2])
def get_pre_cell_id(self):
def get_pre_cell_id(self) -> str:
"""Get the ID of the pre-synaptic cell
:returns: ID of pre-synaptic cell
Expand All @@ -301,7 +301,7 @@ def get_pre_cell_id(self):
return self._get_cell_id(self.pre_cell_id)
def get_post_cell_id(self):
def get_post_cell_id(self) -> str:
"""Get the ID of the post-synaptic cell
:returns: ID of post-synaptic cell
Expand All @@ -310,7 +310,7 @@ def get_post_cell_id(self):
return self._get_cell_id(self.post_cell_id)
def get_pre_segment_id(self):
def get_pre_segment_id(self) -> str:
"""Get the ID of the pre-synpatic segment
:returns: ID of pre-synaptic segment.
Expand All @@ -319,7 +319,7 @@ def get_pre_segment_id(self):
return int(self.pre_segment_id)
def get_post_segment_id(self):
def get_post_segment_id(self) -> str:
"""Get the ID of the post-synpatic segment
:returns: ID of post-synaptic segment.
Expand Down Expand Up @@ -368,7 +368,7 @@ def __str__(self):
return "Connection "+str(self.id)+": "+str(self.get_pre_info())+" -> "+str(self.get_post_info())+ \
", weight: "+'PERCENTAGEf' PERCENTAGE (float(self.weight))+", delay: "+'PERCENTAGE.5f' PERCENTAGE (self.get_delay_in_ms())+" ms"
def get_delay_in_ms(self):
def get_delay_in_ms(self) -> float:
"""Get connection delay in milli seconds
:returns: connection delay in milli seconds
Expand Down Expand Up @@ -411,7 +411,7 @@ def __str__(self):
name="elec_connection_instance_w",
source='''\
def get_weight(self):
def get_weight(self) -> float:
"""Get the weight of the connection
If a weight is not set (or is set to None), returns the default value
Expand All @@ -438,10 +438,10 @@ def __str__(self):
name="elec_connection_cell_ids",
source='''\
def _get_cell_id(self, id_string):
def _get_cell_id(self, id_string: str) -> int:
return int(float(id_string))
def get_pre_cell_id(self):
def get_pre_cell_id(self) -> float:
"""Get the ID of the pre-synaptic cell
:returns: ID of pre-synaptic cell
Expand All @@ -450,7 +450,7 @@ def get_pre_cell_id(self):
return self._get_cell_id(self.pre_cell)
def get_post_cell_id(self):
def get_post_cell_id(self) -> str:
"""Get the ID of the post-synaptic cell
:returns: ID of post-synaptic cell
Expand All @@ -459,7 +459,7 @@ def get_post_cell_id(self):
return self._get_cell_id(self.post_cell)
def get_pre_segment_id(self):
def get_pre_segment_id(self) -> str:
"""Get the ID of the pre-synpatic segment
:returns: ID of pre-synaptic segment.
Expand All @@ -468,7 +468,7 @@ def get_pre_segment_id(self):
return int(self.pre_segment)
def get_post_segment_id(self):
def get_post_segment_id(self) -> str:
"""Get the ID of the post-synpatic segment
:returns: ID of post-synaptic segment.
Expand Down Expand Up @@ -566,7 +566,7 @@ def _get_cell_id(self, id_string):
return int(float(id_string))
def get_pre_cell_id(self):
def get_pre_cell_id(self) -> str:
"""Get the ID of the pre-synaptic cell
:returns: ID of pre-synaptic cell
Expand All @@ -575,7 +575,7 @@ def get_pre_cell_id(self):
return self._get_cell_id(self.pre_cell)
def get_post_cell_id(self):
def get_post_cell_id(self) -> str:
"""Get the ID of the post-synaptic cell
:returns: ID of post-synaptic cell
Expand All @@ -584,7 +584,7 @@ def get_post_cell_id(self):
return self._get_cell_id(self.post_cell)
def get_pre_segment_id(self):
def get_pre_segment_id(self) -> str:
"""Get the ID of the pre-synpatic segment
:returns: ID of pre-synaptic segment.
Expand All @@ -593,7 +593,7 @@ def get_pre_segment_id(self):
return int(self.pre_segment)
def get_post_segment_id(self):
def get_post_segment_id(self) -> str:
"""Get the ID of the post-synpatic segment
:returns: ID of post-synaptic segment.
Expand Down Expand Up @@ -879,7 +879,7 @@ def summary(self, show_includes=True, show_non_network=True):
warn_count = 0
def get_by_id(self,id):
def get_by_id(self, id: str) -> typing.Optional[typing.Any]:
"""Get a component by specifying its ID.
:param id: id of Component to get
Expand Down Expand Up @@ -929,7 +929,7 @@ def append(self, element):
source='''\
warn_count = 0
def get_by_id(self,id):
def get_by_id(self, id: str) -> typing.Optional[typing.Any]:
"""Get a component by its ID
:param id: ID of component to find
Expand Down Expand Up @@ -972,8 +972,7 @@ def __str__(self):
source='''\
# Get segment object by its id
def get_segment(self, segment_id):
# type: (int) -> Segment
def get_segment(self, segment_id: int) -> Segment:
"""Get segment object by its id
:param segment_id: ID of segment
Expand All @@ -988,8 +987,7 @@ def get_segment(self, segment_id):
raise ValueError("Segment with id "+str(segment_id)+" not found in cell "+str(self.id))
def get_segments_by_substring(self, substring):
# type: (str) -> dict
def get_segments_by_substring(self, substring: str) -> typing.Dict[str, Segment]:
"""Get a dictionary of segment IDs and the segment matching the specified substring
:param substring: substring to match
Expand All @@ -1010,8 +1008,7 @@ def get_segments_by_substring(self, substring):
# Get the proximal point of a segment, even the proximal field is None and
# so the proximal point is on the parent (at a point set by fraction_along)
def get_actual_proximal(self, segment_id):
# type: (str) -> Point3DWithDiam
def get_actual_proximal(self, segment_id: str):
"""Get the proximal point of a segment.
If the proximal for the segment is set to None, calculate the proximal
Expand Down Expand Up @@ -1040,8 +1037,7 @@ def get_actual_proximal(self, segment_id):
return p
def get_segment_length(self, segment_id):
# type: (str) -> float
def get_segment_length(self, segment_id: str) -> float:
"""Get the length of the segment.
:param segment_id: ID of segment
Expand All @@ -1058,8 +1054,7 @@ def get_segment_length(self, segment_id):
return length
def get_segment_surface_area(self, segment_id):
# type: (str) -> float
def get_segment_surface_area(self, segment_id: str) -> float:
"""Get the surface area of the segment.
:param segment_id: ID of the segment
Expand All @@ -1076,8 +1071,7 @@ def get_segment_surface_area(self, segment_id):
return temp_seg.surface_area
def get_segment_volume(self, segment_id):
# type: (str) -> float
def get_segment_volume(self, segment_id: str) -> float:
"""Get volume of segment
:param segment_id: ID of the segment
Expand All @@ -1093,8 +1087,7 @@ def get_segment_volume(self, segment_id):
return temp_seg.volume
def get_segment_ids_vs_segments(self):
# type: () -> Dict
def get_segment_ids_vs_segments(self) -> typing.Dict[str, Segment]:
"""Get a dictionary of segment IDs and the segments in the cell.
:return: dictionary with segment ID as key, and segment as value
Expand All @@ -1107,9 +1100,8 @@ def get_segment_ids_vs_segments(self):
return segments
def get_all_segments_in_group(self,
segment_group,
assume_all_means_all=True):
# type: (SegmentGroup, bool) -> List[int]
segment_group: SegmentGroup,
assume_all_means_all: bool = True) -> typing.List[int]:
"""Get all the segments in a segment group of the cell.
:param segment_group: segment group to get all segments of
Expand Down Expand Up @@ -1149,12 +1141,12 @@ def get_all_segments_in_group(self,
def get_ordered_segments_in_groups(self,
group_list,
check_parentage=False,
include_cumulative_lengths=False,
include_path_lengths=False,
path_length_metric="Path Length from root"): # Only option supported
# type: (List, bool, bool, bool, str) -> Any
group_list: typing.List[str],
check_parentage: bool = False,
include_cumulative_lengths: bool = False,
include_path_lengths: bool = False,
path_length_metric: str = "Path Length from root" # Only option supported
) -> typing.Any:
"""
Get ordered list of segments in specified groups, with additional
information.
Expand Down Expand Up @@ -1304,8 +1296,7 @@ def get_ordered_segments_in_groups(self,
return ord_segs
def get_segment_group(self, sg_id):
# type: (str) -> SegmentGroup
def get_segment_group(self, sg_id: str) -> SegmentGroup:
"""Return the SegmentGroup object for the specified segment group id.
:param sg_id: id of segment group to find
Expand All @@ -1320,19 +1311,25 @@ def get_segment_group(self, sg_id):
raise ValueError("Segment group with id "+str(sg_id)+" not found in cell "+str(self.id))
def get_segment_groups_by_substring(self, substring):
# type: (str) -> dict
def get_segment_groups_by_substring(self, substring: str, unbranched: bool = False) -> typing.Dict[str, SegmentGroup]:
"""Get a dictionary of segment group IDs and the segment groups matching the specified substring
:param substring: substring to match
:param substring: substring to match, an empty string "" matches all
groups
:type substring: str
:param unbranced: toggle selecting unbranched segment groups
:type unbranched: bool
:return: dictionary with segment group ID as key, and segment group as value
:raises ValueError: if no matching segment groups are found in cell
"""
sgs = {}
for sg in self.morphology.segment_groups:
if substring in sg.id:
sgs[sg.id] = sg
if substring == "" or substring in sg.id:
if unbranched is True:
if sg.neuro_lex_id == neuroml.neuro_lex_ids.neuro_lex_ids["section"]:
sgs[sg.id] = sg
else:
sgs[sg.id] = sg
if len(sgs) == 0:
raise ValueError("Segment group with id matching "+str(substring)+" not found in cell "+str(self.id))
return sgs
Expand Down
Loading

0 comments on commit f268dc7

Please sign in to comment.