Skip to content

Commit

Permalink
Merge pull request SpikeInterface#1777 from alejoe91/check-probe-in-w…
Browse files Browse the repository at this point in the history
…aveform-extractor

Check probe in waveform extractor and add `has_probe` method
  • Loading branch information
samuelgarcia authored Jul 6, 2023
2 parents d6d9769 + e493178 commit 20763a5
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 6 deletions.
35 changes: 29 additions & 6 deletions src/spikeinterface/core/baserecordingsnippets.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,9 @@ def has_scaled(self):
else:
return True

def has_probe(self):
return "contact_vector" in self.get_property_keys()

def is_filtered(self):
# the is_filtered is handle with annotation
return self._annotations.get("is_filtered", False)
Expand Down Expand Up @@ -220,10 +223,7 @@ def get_probegroup(self):
raise ValueError("There is no Probe attached to this recording. Use set_probe(...) to attach one.")
else:
warn("There is no Probe attached to this recording. Creating a dummy one with contact positions")
ndim = positions.shape[1]
probe = Probe(ndim=ndim)
probe.set_contacts(positions=positions, shapes="circle", shape_params={"radius": 5})
probe.set_device_channel_indices(np.arange(self.get_num_channels(), dtype="int64"))
probe = self.create_dummy_probe_from_locations(positions)
# probe.create_auto_shape()
probegroup = ProbeGroup()
probegroup.add_probe(probe)
Expand All @@ -248,9 +248,9 @@ def _extra_metadata_to_folder(self, folder):
probegroup = self.get_probegroup()
write_probeinterface(folder / "probe.json", probegroup)

def set_dummy_probe_from_locations(self, locations, shape="circle", shape_params={"radius": 1}, axes="xy"):
def create_dummy_probe_from_locations(self, locations, shape="circle", shape_params={"radius": 1}, axes="xy"):
"""
Sets a 'dummy' probe based on locations.
Creates a 'dummy' probe based on locations.
Parameters
----------
Expand All @@ -262,6 +262,11 @@ def set_dummy_probe_from_locations(self, locations, shape="circle", shape_params
Shape parameters, by default {"radius": 1}
axes : str, optional
If ndim is 3, indicates the axes that define the plane of the electrodes, by default "xy"
Returns
-------
probe : Probe
The created probe
"""
ndim = locations.shape[1]
probe = Probe(ndim=2)
Expand All @@ -275,6 +280,24 @@ def set_dummy_probe_from_locations(self, locations, shape="circle", shape_params
if ndim == 3:
probe = probe.to_3d(axes=axes)

return probe

def set_dummy_probe_from_locations(self, locations, shape="circle", shape_params={"radius": 1}, axes="xy"):
"""
Sets a 'dummy' probe based on locations.
Parameters
----------
locations : np.array
Array with channel locations (num_channels, ndim) [ndim can be 2 or 3]
shape : str, optional
Electrode shapes, by default "circle"
shape_params : dict, optional
Shape parameters, by default {"radius": 1}
axes : str, optional
If ndim is 3, indicates the axes that define the plane of the electrodes, by default "xy"
"""
probe = self.create_dummy_probe_from_locations(locations, shape=shape, shape_params=shape_params, axes=axes)
self.set_probe(probe, in_place=True)

def set_channel_locations(self, locations, channel_ids=None):
Expand Down
4 changes: 4 additions & 0 deletions src/spikeinterface/core/waveform_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1531,6 +1531,10 @@ def extract_waveforms(

estimate_kwargs, job_kwargs = split_job_kwargs(kwargs)

assert (
recording.has_probe()
), "Recording must have a probe to extract waveforms. Use the `set_probe()` or `set_dummy_probe_from_locations()` methods."

if mode == "folder":
assert folder is not None
folder = Path(folder)
Expand Down

0 comments on commit 20763a5

Please sign in to comment.