Skip to content

Commit

Permalink
Added save preprocessing option to gui, updated logging
Browse files Browse the repository at this point in the history
  • Loading branch information
jacobpennington committed Jun 12, 2024
1 parent 691f1b0 commit dff0bd4
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 11 deletions.
1 change: 1 addition & 0 deletions kilosort/gui/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,7 @@ def set_parameters(self):
self.num_channels = settings["n_chan_bin"]

params = settings.copy()
params['save_preprocessed_copy'] = self.run_box.save_preproc_check.isChecked()

assert params

Expand Down
13 changes: 10 additions & 3 deletions kilosort/gui/run_box.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,11 @@ def __init__(self, parent):

self.run_all_button = QtWidgets.QPushButton("Run")
self.spike_sort_button = QtWidgets.QPushButton("Spikesort")

self.save_preproc_check = QtWidgets.QCheckBox("Save Preprocessed Copy")
self.save_preproc_check.setCheckState(QtCore.Qt.CheckState.Unchecked)

self.buttons = [
self.run_all_button,
self.run_all_button
]

self.data_path = None
Expand All @@ -43,7 +45,7 @@ def __init__(self, parent):
self.remote_widgets = None

self.progress_bar = QtWidgets.QProgressBar()
self.layout.addWidget(self.progress_bar, 2, 0, 2, 2)
self.layout.addWidget(self.progress_bar, 3, 0, 2, 2)

self.setup()

Expand All @@ -56,6 +58,7 @@ def setup(self):
)

self.layout.addWidget(self.run_all_button, 0, 0, 2, 2)
self.layout.addWidget(self.save_preproc_check, 2, 0, 1, 2)

self.setLayout(self.layout)

Expand All @@ -74,8 +77,12 @@ def reenable_buttons(self):
def disable_all_input(self, value):
if value:
self.disable_all_buttons()
# This is done separate from other buttons so that it can be checked
# on or off without needing to load data.
self.save_preproc_check.setEnabled(False)
else:
self.reenable_buttons()
self.save_preproc_check.setEnabled(True)

def set_data_path(self, data_path):
self.data_path = data_path
Expand Down
34 changes: 31 additions & 3 deletions kilosort/gui/sorter.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import time
import pprint
import logging
logger = logging.getLogger(__name__)

Expand All @@ -7,11 +8,14 @@
from qtpy import QtCore

#from kilosort.gui.logger import setup_logger
import kilosort
from kilosort.run_kilosort import (
setup_logger, initialize_ops, compute_preprocessing, compute_drift_correction,
detect_spikes, cluster_spikes, save_sorting
)

from kilosort.io import save_preprocessing

#logger = setup_logger(__name__)


Expand Down Expand Up @@ -44,23 +48,40 @@ def run(self):
results_dir.mkdir(parents=True)

setup_logger(results_dir)
logger.info(f"Kilosort version {kilosort.__version__}")
logger.info(f"Sorting {self.data_path}")
logger.info('-'*40)

tic0 = time.time()

# TODO: make these options in GUI
do_CAR=True
invert_sign=False

if not do_CAR:
print("Skipping common average reference.")
logger.info("Skipping common average reference.")

if probe['chanMap'].max() >= settings['n_chan_bin']:
raise ValueError(
f'Largest value of chanMap exceeds channel count of data, '
'make sure chanMap is 0-indexed.'
)

if settings['nt0min'] is None:
settings['nt0min'] = int(20 * settings['nt']/61)
data_dtype = settings['data_dtype']
device = self.device
save_preprocessed_copy = settings['save_preprocessed_copy']

ops = initialize_ops(settings, probe, data_dtype, do_CAR,
invert_sign, device)
invert_sign, device, save_preprocessed_copy)
# Remove some stuff that doesn't need to be printed twice, then pretty-print
# format for log file.
ops_copy = ops.copy()
_ = ops_copy.pop('settings')
_ = ops_copy.pop('probe')
print_ops = pprint.pformat(ops_copy, indent=4, sort_dicts=False)
logger.debug(f"Initial ops:\n{print_ops}\n")

# TODO: add support for file object through data conversion
# Set preprocessing and drift correction parameters
Expand All @@ -74,6 +95,13 @@ def run(self):
file_object=self.file_object
)

# Check scale of data for log file
b1 = bfile.padded_batch_to_torch(0).cpu().numpy()
logger.debug(f"First batch min, max: {b1.min(), b1.max()}")

if save_preprocessed_copy:
save_preprocessing(results_dir / 'temp_wh.dat', ops, bfile)

# Will be None if nblocks = 0 (no drift correction)
if st0 is not None:
self.dshift = ops['dshift']
Expand Down
13 changes: 8 additions & 5 deletions kilosort/run_kilosort.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,8 @@ def run_kilosort(settings, probe=None, probe_name=None, filename=None,
)

tic0 = time.time()
ops = initialize_ops(settings, probe, data_dtype, do_CAR, invert_sign, device)
ops = initialize_ops(settings, probe, data_dtype, do_CAR, invert_sign,
device, save_preprocessed_copy)
# Remove some stuff that doesn't need to be printed twice, then pretty-print
# format for log file.
ops_copy = ops.copy()
Expand All @@ -162,13 +163,13 @@ def run_kilosort(settings, probe=None, probe_name=None, filename=None,
file_object=file_object
)

if save_preprocessed_copy:
io.save_preprocessing(results_dir / 'temp_wh.dat', ops, bfile)

# Check scale of data for log file
b1 = bfile.padded_batch_to_torch(0).cpu().numpy()
logger.debug(f"First batch min, max: {b1.min(), b1.max()}")

if save_preprocessed_copy:
io.save_preprocessing(results_dir / 'temp_wh.dat', ops, bfile)

# Sort spikes and save results
st,tF, _, _ = detect_spikes(ops, device, bfile, tic0=tic0,
progress_bar=progress_bar)
Expand Down Expand Up @@ -264,7 +265,8 @@ def setup_logger(results_dir):
numba_log.setLevel(logging.INFO)


def initialize_ops(settings, probe, data_dtype, do_CAR, invert_sign, device) -> dict:
def initialize_ops(settings, probe, data_dtype, do_CAR, invert_sign,
device, save_preprocesed_copy) -> dict:
"""Package settings and probe information into a single `ops` dictionary."""

if settings['nt0min'] is None:
Expand All @@ -280,6 +282,7 @@ def initialize_ops(settings, probe, data_dtype, do_CAR, invert_sign, device) ->
ops['Nchan'] = len(probe['chanMap'])
ops['n_chan_bin'] = settings['n_chan_bin']
ops['torch_device'] = str(device)
ops['save_preprocessed_copy'] = save_preprocesed_copy

if not settings['templates_from_data'] and settings['nt'] != 61:
raise ValueError('If using pre-computed universal templates '
Expand Down

0 comments on commit dff0bd4

Please sign in to comment.