Commit c6cd856

Fixed printing errors to log file

jacobpennington committed Jun 27, 2024
1 parent 3e010bb commit c6cd856
Showing 2 changed files with 159 additions and 137 deletions.
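Both files get the same fix: the body of the sorting pipeline is wrapped in a try/except whose handler calls logger.exception() and then re-raises, so the full traceback reaches the log file even when the surrounding GUI or console output is lost. A minimal, self-contained sketch of the pattern (the log file name, function name, and messages are illustrative, not taken from the commit):

    import logging

    logging.basicConfig(filename='example.log', level=logging.DEBUG)
    logger = logging.getLogger(__name__)

    def run_pipeline():
        try:
            raise RuntimeError('simulated sorting failure')
        except:
            # logger.exception logs at ERROR level and appends the full
            # traceback of the active exception to the configured handlers.
            logger.exception('Encountered error in `run_pipeline`:')
            # A bare raise preserves the original exception and traceback
            # for the caller.
            raise

    try:
        run_pipeline()
    except RuntimeError as e:
        print(f'caller still sees the error: {e}')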
159 changes: 86 additions & 73 deletions kilosort/gui/sorter.py
@@ -48,79 +48,92 @@ def run(self):
results_dir.mkdir(parents=True)

setup_logger(results_dir)
logger.info(f"Kilosort version {kilosort.__version__}")
logger.info(f"Sorting {self.data_path}")
logger.info('-'*40)

tic0 = time.time()

# TODO: make these options in GUI
do_CAR = True
invert_sign = False

if not do_CAR:
logger.info("Skipping common average reference.")

if probe['chanMap'].max() >= settings['n_chan_bin']:
raise ValueError(
'Largest value of chanMap exceeds channel count of data; '
'make sure chanMap is 0-indexed.'
)

if settings['nt0min'] is None:
settings['nt0min'] = int(20 * settings['nt']/61)
data_dtype = settings['data_dtype']
device = self.device
save_preprocessed_copy = settings['save_preprocessed_copy']

ops = initialize_ops(settings, probe, data_dtype, do_CAR,
invert_sign, device, save_preprocessed_copy)
# Remove some entries that don't need to be printed twice, then
# pretty-print the rest for the log file.
ops_copy = ops.copy()
_ = ops_copy.pop('settings')
_ = ops_copy.pop('probe')
print_ops = pprint.pformat(ops_copy, indent=4, sort_dicts=False)
logger.debug(f"Initial ops:\n{print_ops}\n")

# TODO: add support for file object through data conversion
# Set preprocessing and drift correction parameters
ops = compute_preprocessing(ops, self.device, tic0=tic0,
file_object=self.file_object)
np.random.seed(1)
torch.cuda.manual_seed_all(1)
torch.random.manual_seed(1)
ops, bfile, st0 = compute_drift_correction(
ops, self.device, tic0=tic0, progress_bar=self.progress_bar,
file_object=self.file_object
)

# Check scale of data for log file
b1 = bfile.padded_batch_to_torch(0).cpu().numpy()
logger.debug(f"First batch min, max: {b1.min(), b1.max()}")

if save_preprocessed_copy:
save_preprocessing(results_dir / 'temp_wh.dat', ops, bfile)

# Will be None if nblocks = 0 (no drift correction)
if st0 is not None:
self.dshift = ops['dshift']
self.st0 = st0
self.plotDataReady.emit('drift')

# Sort spikes and save results
st, tF, Wall0, clu0 = detect_spikes(ops, self.device, bfile, tic0=tic0,
progress_bar=self.progress_bar)

self.Wall0 = Wall0
self.wPCA = torch.clone(ops['wPCA'].cpu()).numpy()
self.clu0 = clu0
self.plotDataReady.emit('diagnostics')

clu, Wall = cluster_spikes(st, tF, ops, self.device, bfile, tic0=tic0,
progress_bar=self.progress_bar)
ops, similar_templates, is_ref, est_contam_rate, kept_spikes = \
save_sorting(ops, results_dir, st, clu, tF, Wall, bfile.imin, tic0)

try:
logger.info(f"Kilosort version {kilosort.__version__}")
logger.info(f"Sorting {self.data_path}")
logger.info('-'*40)

tic0 = time.time()

# TODO: make these options in GUI
do_CAR = True
invert_sign = False

if not do_CAR:
logger.info("Skipping common average reference.")

if probe['chanMap'].max() >= settings['n_chan_bin']:
raise ValueError(
'Largest value of chanMap exceeds channel count of data; '
'make sure chanMap is 0-indexed.'
)

if settings['nt0min'] is None:
settings['nt0min'] = int(20 * settings['nt']/61)
data_dtype = settings['data_dtype']
device = self.device
save_preprocessed_copy = settings['save_preprocessed_copy']

ops = initialize_ops(settings, probe, data_dtype, do_CAR,
invert_sign, device, save_preprocessed_copy)
# Remove some entries that don't need to be printed twice,
# then pretty-print the rest for the log file.
ops_copy = ops.copy()
_ = ops_copy.pop('settings')
_ = ops_copy.pop('probe')
print_ops = pprint.pformat(ops_copy, indent=4, sort_dicts=False)
logger.debug(f"Initial ops:\n{print_ops}\n")

# TODO: add support for file object through data conversion
# Set preprocessing and drift correction parameters
ops = compute_preprocessing(ops, self.device, tic0=tic0,
file_object=self.file_object)
np.random.seed(1)
torch.cuda.manual_seed_all(1)
torch.random.manual_seed(1)
ops, bfile, st0 = compute_drift_correction(
ops, self.device, tic0=tic0, progress_bar=self.progress_bar,
file_object=self.file_object
)

# Check scale of data for log file
b1 = bfile.padded_batch_to_torch(0).cpu().numpy()
logger.debug(f"First batch min, max: {b1.min(), b1.max()}")

if save_preprocessed_copy:
save_preprocessing(results_dir / 'temp_wh.dat', ops, bfile)

# Will be None if nblocks = 0 (no drift correction)
if st0 is not None:
self.dshift = ops['dshift']
self.st0 = st0
self.plotDataReady.emit('drift')

# Sort spikes and save results
st, tF, Wall0, clu0 = detect_spikes(
ops, self.device, bfile, tic0=tic0,
progress_bar=self.progress_bar
)

self.Wall0 = Wall0
self.wPCA = torch.clone(ops['wPCA'].cpu()).numpy()
self.clu0 = clu0
self.plotDataReady.emit('diagnostics')

clu, Wall = cluster_spikes(
st, tF, ops, self.device, bfile, tic0=tic0,
progress_bar=self.progress_bar
)
ops, similar_templates, is_ref, est_contam_rate, kept_spikes = \
save_sorting(ops, results_dir, st, clu, tF, Wall, bfile.imin, tic0)

except:
# This makes sure the full traceback is written to the log file.
logger.exception('Encountered error in `run_kilosort`:')
# Annoyingly, this will print the error message twice on the console,
# but I haven't found a good way around that.
raise

self.ops = ops
self.st = st[kept_spikes]
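setup_logger(results_dir) is defined elsewhere in the package and is not shown in this diff. For the change above to work as described, it only needs to attach a file handler rooted in the results directory; a hypothetical sketch (the file name and format string are assumptions, not the package's actual implementation):

    import logging
    from pathlib import Path

    def setup_logger(results_dir):
        # Hypothetical: send all log records to a file in the results
        # directory so logger.exception tracebacks are captured there.
        log_file = Path(results_dir) / 'kilosort4.log'
        handler = logging.FileHandler(log_file, mode='w')
        handler.setFormatter(
            logging.Formatter('%(name)s : %(levelname)s : %(message)s')
        )
        root = logging.getLogger()
        root.addHandler(handler)
        root.setLevel(logging.DEBUG)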
137 changes: 73 additions & 64 deletions kilosort/run_kilosort.py
@@ -112,74 +112,82 @@ def run_kilosort(settings, probe=None, probe_name=None, filename=None,
filename, data_dir, results_dir, probe = \
set_files(settings, filename, probe, probe_name, data_dir, results_dir)
setup_logger(results_dir)
logger.info(f"Kilosort version {kilosort.__version__}")
logger.info(f"Sorting {filename}")
logger.info('-'*40)

if data_dtype is None:
logger.info(
"Interpreting binary file as default dtype='int16'. If data was "
"saved in a different format, specify `data_dtype`."
try:
logger.info(f"Kilosort version {kilosort.__version__}")
logger.info(f"Sorting {filename}")
logger.info('-'*40)

if data_dtype is None:
logger.info(
"Interpreting binary file as default dtype='int16'. If data was "
"saved in a different format, specify `data_dtype`."
)
data_dtype = 'int16'

if not do_CAR:
logger.info("Skipping common average reference.")

if device is None:
if torch.cuda.is_available():
logger.info('Using GPU for PyTorch computations. '
'Specify `device` to change this.')
device = torch.device('cuda')
else:
logger.info('Using CPU for PyTorch computations. '
'Specify `device` to change this.')
device = torch.device('cpu')

if probe['chanMap'].max() >= settings['n_chan_bin']:
raise ValueError(
'Largest value of chanMap exceeds channel count of data; '
'make sure chanMap is 0-indexed.'
)
data_dtype = 'int16'

if not do_CAR:
logger.info("Skipping common average reference.")

if device is None:
if torch.cuda.is_available():
logger.info('Using GPU for PyTorch computations. '
'Specify `device` to change this.')
device = torch.device('cuda')
else:
logger.info('Using CPU for PyTorch computations. '
'Specify `device` to change this.')
device = torch.device('cpu')

if probe['chanMap'].max() >= settings['n_chan_bin']:
raise ValueError(
'Largest value of chanMap exceeds channel count of data; '
'make sure chanMap is 0-indexed.'
)

tic0 = time.time()
ops = initialize_ops(settings, probe, data_dtype, do_CAR, invert_sign,
device, save_preprocessed_copy)
# Remove some entries that don't need to be printed twice, then
# pretty-print the rest for the log file.
ops_copy = ops.copy()
_ = ops_copy.pop('settings')
_ = ops_copy.pop('probe')
print_ops = pprint.pformat(ops_copy, indent=4, sort_dicts=False)
logger.debug(f"Initial ops:\n{print_ops}\n")


# Set preprocessing and drift correction parameters
ops = compute_preprocessing(ops, device, tic0=tic0, file_object=file_object)
np.random.seed(1)
torch.cuda.manual_seed_all(1)
torch.random.manual_seed(1)
ops, bfile, st0 = compute_drift_correction(
ops, device, tic0=tic0, progress_bar=progress_bar,
file_object=file_object
)

# Check scale of data for log file
b1 = bfile.padded_batch_to_torch(0).cpu().numpy()
logger.debug(f"First batch min, max: {b1.min(), b1.max()}")

if save_preprocessed_copy:
io.save_preprocessing(results_dir / 'temp_wh.dat', ops, bfile)
tic0 = time.time()
ops = initialize_ops(settings, probe, data_dtype, do_CAR, invert_sign,
device, save_preprocessed_copy)
# Remove some entries that don't need to be printed twice, then
# pretty-print the rest for the log file.
ops_copy = ops.copy()
_ = ops_copy.pop('settings')
_ = ops_copy.pop('probe')
print_ops = pprint.pformat(ops_copy, indent=4, sort_dicts=False)
logger.debug(f"Initial ops:\n{print_ops}\n")


# Set preprocessing and drift correction parameters
ops = compute_preprocessing(ops, device, tic0=tic0, file_object=file_object)
np.random.seed(1)
torch.cuda.manual_seed_all(1)
torch.random.manual_seed(1)
ops, bfile, st0 = compute_drift_correction(
ops, device, tic0=tic0, progress_bar=progress_bar,
file_object=file_object
)

# Sort spikes and save results
st, tF, _, _ = detect_spikes(ops, device, bfile, tic0=tic0,
progress_bar=progress_bar)
clu, Wall = cluster_spikes(st, tF, ops, device, bfile, tic0=tic0,
progress_bar=progress_bar)
ops, similar_templates, is_ref, est_contam_rate, kept_spikes = \
save_sorting(ops, results_dir, st, clu, tF, Wall, bfile.imin, tic0,
save_extra_vars=save_extra_vars,
save_preprocessed_copy=save_preprocessed_copy)
# Check scale of data for log file
b1 = bfile.padded_batch_to_torch(0).cpu().numpy()
logger.debug(f"First batch min, max: {b1.min(), b1.max()}")

if save_preprocessed_copy:
io.save_preprocessing(results_dir / 'temp_wh.dat', ops, bfile)

# Sort spikes and save results
st, tF, _, _ = detect_spikes(ops, device, bfile, tic0=tic0,
progress_bar=progress_bar)
clu, Wall = cluster_spikes(st, tF, ops, device, bfile, tic0=tic0,
progress_bar=progress_bar)
ops, similar_templates, is_ref, est_contam_rate, kept_spikes = \
save_sorting(ops, results_dir, st, clu, tF, Wall, bfile.imin, tic0,
save_extra_vars=save_extra_vars,
save_preprocessed_copy=save_preprocessed_copy)
except:
# This makes sure the full traceback is written to the log file.
logger.exception('Encountered error in `run_kilosort`:')
# Annoyingly, this will print the error message twice on the console,
# but I haven't found a good way around that.
raise

return ops, st, clu, tF, Wall, similar_templates, \
is_ref, est_contam_rate, kept_spikes
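On the duplicated console message the except block notes: the traceback is printed once by the logging console handler and once more by the interpreter when the bare raise propagates. One possible workaround, not part of this commit, is to filter exception records out of the console handler so only the file handler keeps them:

    import logging

    file_handler = logging.FileHandler('example.log')
    console = logging.StreamHandler()
    # Drop records that carry exception info from the console handler;
    # the interpreter prints that traceback anyway when the error
    # re-raises.
    console.addFilter(lambda record: record.exc_info is None)
    logging.basicConfig(level=logging.DEBUG,
                        handlers=[file_handler, console])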
@@ -435,6 +443,7 @@ def compute_drift_correction(ops, device, tic0=np.nan, progress_bar=None,
Wrapped file object for handling data.
"""

tic = time.time()
logger.info(' ')
logger.info('Computing drift correction.')
