diff --git a/ephyviewer/epochencoder.py b/ephyviewer/epochencoder.py index c140ce2..31ba3ba 100644 --- a/ephyviewer/epochencoder.py +++ b/ephyviewer/epochencoder.py @@ -89,6 +89,10 @@ def __init__(self, **kargs): self.refresh_table() + # IMPORTANT: Any time the contents of self.source are changed (e.g., by + # calling self.source.add_epoch), this flag must be set to True! + self.has_unsaved_changes = False + def make_params(self): # Create parameters @@ -298,12 +302,13 @@ def make_param_controller(self): def closeEvent(self, event): - text = 'Do you want to save epoch encoder changes before closing?' - title = 'Save?' - mb = QT.QMessageBox.question(self, title,text, - QT.QMessageBox.Ok , QT.QMessageBox.Discard) - if mb==QT.QMessageBox.Ok: - self.source.save() + if self.has_unsaved_changes: + text = 'Do you want to save epoch encoder changes before closing?' + title = 'Save?' + mb = QT.QMessageBox.question(self, title,text, + QT.QMessageBox.Ok , QT.QMessageBox.Discard) + if mb==QT.QMessageBox.Ok: + self.source.save() self.thread.quit() self.thread.wait() @@ -460,6 +465,8 @@ def on_label_shortcut(self): label, modifier_used = self.label_shortcuts.get(self.sender(), None) if label is None: return + self.has_unsaved_changes = True + range_selection_is_enabled = self.but_range.isChecked() if range_selection_is_enabled: @@ -488,6 +495,7 @@ def on_label_shortcut(self): self.refresh_table() def on_merge_neighbors(self): + self.has_unsaved_changes = True self.source.merge_neighbors() self.refresh() self.refresh_table() @@ -497,6 +505,7 @@ def on_fill_blank(self): dia = tools.ParamDialog(params, title='Fill blank method', parent=self) dia.resize(300, 100) if dia.exec_(): + self.has_unsaved_changes = True d = dia.get() method = d['method'] self.source.fill_blank(method=method) @@ -506,6 +515,7 @@ def on_fill_blank(self): def on_save(self): self.source.save() + self.has_unsaved_changes = False def on_spin_limit_changed(self, v): self.region.blockSignals(True) @@ -523,6 +533,8 @@ def on_region_changed(self): self.spin_limit2.blockSignals(False) def apply_region(self): + self.has_unsaved_changes = True + rgn = self.region.getRegion() t = rgn[0] duration = rgn[1] - rgn[0] @@ -539,6 +551,8 @@ def apply_region(self): self.refresh_table() def delete_region(self): + self.has_unsaved_changes = True + rgn = self.region.getRegion() self.source.delete_in_between(rgn[0], rgn[1]) @@ -628,6 +642,8 @@ def on_seek_table(self): def on_change_label(self, id, new_label): + self.has_unsaved_changes = True + # get index corresponding to epoch id ind = self.source.id_to_ind[id] @@ -642,6 +658,7 @@ def delete_selected_epoch(self): selected_ind = self.table_widget.selectedIndexes() if len(selected_ind)==0: return + self.has_unsaved_changes = True ind = selected_ind[0].row() self.source.delete_epoch(ind) self.refresh() @@ -653,6 +670,7 @@ def duplicate_selected_epoch(self): selected_ind = self.table_widget.selectedIndexes() if len(selected_ind)==0: return + self.has_unsaved_changes = True ind = selected_ind[0].row() self.source.add_epoch(self.source.ep_times[ind], self.source.ep_durations[ind], self.source.ep_labels[ind]) self.refresh() @@ -667,6 +685,7 @@ def split_selected_epoch(self): ind = selected_ind[0].row() if self.t <= self.source.ep_times[ind] or self.source.ep_stops[ind] <= self.t: return + self.has_unsaved_changes = True self.source.split_epoch(ind, self.t) self.refresh() self.refresh_table()