Skip to content

Commit

Permalink
Merge pull request #59 from leofang/dev
Browse files Browse the repository at this point in the history
Support batch cropping
  • Loading branch information
leofang authored Aug 5, 2019
2 parents b4f63b8 + 835727e commit 211041e
Show file tree
Hide file tree
Showing 7 changed files with 250 additions and 21 deletions.
2 changes: 1 addition & 1 deletion nsls2ptycho/core/ptycho
2 changes: 1 addition & 1 deletion nsls2ptycho/core/ptycho_param.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def __init__(self):
self.pha_min = -1.0 #

self.gpu_flag = True # whether to use GPU
self.gpus = [1, 2, 3] # should be a list of gpu numbers, ex: [0, 2, 3]
self.gpus = [0, 1] # should be a list of gpu numbers, ex: [0, 2, 3]
self.gpu_batch_size = 256 # should be 4^n, ex: 4, 16, 64, 256, 1024, 4096, ...
self.use_NCCL = False
self.use_CUDA_MPI = False
Expand Down
2 changes: 1 addition & 1 deletion nsls2ptycho/core/widgets/mplcanvastool.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ def reset(self):
self.canvas.draw()

def draw_image(self, image, cmap='gray', init_roi=False, use_log=False):
print(cmap, init_roi, use_log)
#print(cmap, init_roi, use_log)
if use_log:
print('log scale')
image_data = np.nan_to_num(np.log(image + 1.))
Expand Down
107 changes: 89 additions & 18 deletions nsls2ptycho/ptycho_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
import h5py
import numpy as np
from numpy import pi
import matplotlib.pyplot as plt
import traceback

# for frontend-backend communication
Expand All @@ -46,6 +45,8 @@


class MainWindow(QtWidgets.QMainWindow, ui_ptycho.Ui_MainWindow):
_mainwindow_signal = QtCore.pyqtSignal()

def __init__(self, parent=None, param:Param=None):
super().__init__(parent)
self.setupUi(self)
Expand Down Expand Up @@ -76,6 +77,8 @@ def __init__(self, parent=None, param:Param=None):
self.ck_position_correction_flag.clicked.connect(self.updateCorrFlg)
self.ck_refine_data_flag.clicked.connect(self.updateRefineDataFlg)
self.ck_postprocessing_flag.clicked.connect(self.showNoPostProcessingWarning)
self.ck_batch_crop_flag.clicked.connect(self.updateBatchCropDataFlg)
self.cb_dataloader.currentTextChanged.connect(self.updateBatchCropDataFlg)

self.btn_recon_start.clicked.connect(self.start)
self.btn_recon_stop.clicked.connect(self.stop)
Expand Down Expand Up @@ -144,6 +147,7 @@ def __init__(self, parent=None, param:Param=None):
self.updatePcFlg()
self.updateCorrFlg()
self.updateRefineDataFlg()
self.updateBatchCropDataFlg()
self.checkGpuAvail()
self.updateGpuFlg()
self.resetExperimentalParameters() # probably not necessary
Expand Down Expand Up @@ -180,8 +184,6 @@ def resetButtons(self):
self.btn_recon_batch_start.setEnabled(True)
self.btn_recon_batch_stop.setEnabled(False)
self.recon_bar.setValue(0)
#plt.ioff()
plt.close('all')
# close the mmap arrays
# removing these arrays, can be changed later if needed
if self._prb is not None:
Expand Down Expand Up @@ -573,11 +575,12 @@ def start(self, batch_mode=False):


def stop(self, batch_mode=False):
if self._ptycho_gpu_thread is not None and self._ptycho_gpu_thread.isRunning():
if self._ptycho_gpu_thread is not None:
if batch_mode:
self._ptycho_gpu_thread.finished.disconnect(self._batch_manager)
self._ptycho_gpu_thread.kill() # first kill the mpi processes
self._ptycho_gpu_thread.quit() # then quit QThread gracefully
if self._ptycho_gpu_thread.isRunning():
self._ptycho_gpu_thread.kill() # first kill the mpi processes
self._ptycho_gpu_thread.quit() # then quit QThread gracefully
self._ptycho_gpu_thread = None
self.resetButtons()
if self.reconStepWindow is not None:
Expand Down Expand Up @@ -883,6 +886,20 @@ def updateRefineDataFlg(self):
self.param.refine_data_flag = flag


def updateBatchCropDataFlg(self):
if self.cb_dataloader.currentText() != "Load from databroker":
flag = False
self.ck_batch_crop_flag.setChecked(flag)
self.ck_batch_crop_flag.setEnabled(flag)
else:
flag = self.ck_batch_crop_flag.isChecked()
self.ck_batch_crop_flag.setEnabled(True)
self.sp_batch_x0.setEnabled(flag)
self.sp_batch_y0.setEnabled(flag)
self.sp_batch_width.setEnabled(flag)
self.sp_batch_height.setEnabled(flag)


def showNoPostProcessingWarning(self):
if not self.ck_postprocessing_flag.isChecked():
print("[WARNING] Post-processing is turned off. No result will be written to disk!", file=sys.stderr)
Expand Down Expand Up @@ -915,12 +932,14 @@ def resetMPIFlg(self):


def batchStart(self):
'''
Currently only support load from h5.
'''
if self.cb_dataloader.currentText() == "Load from databroker":
print("[WARNING] Batch mode with databroker is not yet supported. Abort.", file=sys.stderr)
if not self.ck_batch_crop_flag.isChecked() and not self.ck_batch_run_flag.isChecked():
print("[WARNING] Choose least one action (Crop or Run). Stop.", file=sys.stderr)
return

if self.cb_dataloader.currentText() == "Load from databroker":
if not self.ck_batch_crop_flag.isChecked():
print("[WARNING] Batch mode with databroker is set, but \"Crop data\" is not.\n"
"[WARNING] Will attempt to load h5 from working directory", file=sys.stderr)

try:
self._scan_numbers = parse_range(self.le_batch_items.text(), self.sp_batch_step.value())
Expand All @@ -944,10 +963,19 @@ def batchStop(self):
'''
Brute-force abortion of the entire batch. No resumption is possible.
'''
#self._ptycho_gpu_thread.finished.disconnect(self._batch_manager)
self._scan_numbers = None
self.le_scan_num.textChanged.connect(self.forceLoad)
self.stop(True)
if self.roiWindow is not None:
if self.roiWindow._worker_thread is not None:
self.roiWindow._worker_thread.disconnect()
## thread.terminate() freezes the whole GUI -- why?
#if self.roiWindow._worker_thread.isRunning():
# self.roiWindow._worker_thread.terminate()
# self.roiWindow._worker_thread.wait()
self.roiWindow._worker_thread = None
self.roiWindow = None
self.resetButtons()


def _batch_manager(self):
Expand All @@ -958,19 +986,61 @@ def _batch_manager(self):
is not helping.
'''
# TODO: think what if anything goes wrong in the middle. Is this robust?
if self._scan_numbers is None:
return

if len(self._scan_numbers) > 0:
scan_num = self._scan_numbers.pop()
print("begin processing scan " + str(scan_num) + "...")
print("[BATCH] begin processing scan " + str(scan_num) + "...")
self.le_scan_num.setText(str(scan_num))
self.loadExpParam()
self.start(True)
self.btn_recon_batch_start.setEnabled(False)
self.btn_recon_batch_stop.setEnabled(True)

if self.ck_batch_crop_flag.isChecked():
self._batch_crop() # also handles "Run" if needed
elif self.ck_batch_run_flag.isChecked():
self._batch_run() # h5 exists, just "Run"
else:
raise
else:
print("batch processing complete!")
print("[BATCH] batch processing complete!")
self._scan_numbers = None
self.le_scan_num.textChanged.connect(self.forceLoad)
self.resetButtons()
if self.roiWindow is not None:
self.roiWindow = None


def _batch_crop(self):
# ugly hack: pretend the ROI window exists, take the first frame for finding bad pixels,
# mimic human input, and run the reconstruction (if checked)

# first get params from databroker
eventloop = self._batch_eventloop = QtCore.QEventLoop()
self._mainwindow_signal.connect(eventloop.quit)
self.loadExpParam()
eventloop.exec()

# then invoke the h5 worker in RoiWindow
if self.roiWindow is not None:
self.roiWindow.close()
img = self._viewDataFrameBroker(0)
self.roiWindow = RoiWindow(image=img, main_window=self)
#self.roiWindow.roi_changed.connect(self._get_roi_slot)
self.roiWindow.canvas._eventHandler.set_curr_roi(self.roiWindow.canvas.ax,
(self.sp_batch_x0.value(), self.sp_batch_y0.value()),
self.sp_batch_width.value(), self.sp_batch_height.value())
#print("ROI:", self.roiWindow.canvas.get_red_roi())
self.roiWindow.save_to_h5()
#self.btn_recon_batch_stop.clicked.connect(self.roiWindow._worker_thread.terminate)
if not self.ck_batch_run_flag.isChecked():
self.roiWindow._worker_thread.finished.connect(self._batch_manager)
else:
self.roiWindow._worker_thread.finished.connect(self._batch_run)


def _batch_run(self):
self.start(True)


def switchProbeBatch(self):
Expand Down Expand Up @@ -1081,7 +1151,7 @@ def _get_roi_slot(self, x0, y0, width, height):
print(x0, y0, width, height)


def loadExpParam(self):
def loadExpParam(self):
scan_num = self.le_scan_num.text()

try:
Expand Down Expand Up @@ -1126,7 +1196,7 @@ def _loadExpParamBroker(self, scan_id:int):
thread.start()


def _setExpParamBroker(self, it, metadata:dict):
def _setExpParamBroker(self, it, metadata:dict):
'''
Notes:
1. The parameter "it" is just a placeholder for the signal
Expand Down Expand Up @@ -1155,6 +1225,7 @@ def _setExpParamBroker(self, it, metadata:dict):
self.cb_scan_type.setCurrentText(metadata['scan_type'])
self._scan_points = metadata['points']
print("done")
self._mainwindow_signal.emit()


def setLoadButton(self):
Expand Down
1 change: 1 addition & 0 deletions nsls2ptycho/roi_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,7 @@ def save_to_h5(self):
self.cx, self.cy, threshold, badpixels, blue_rois)
thread.finished.connect(lambda: self.btn_save_to_h5.setEnabled(True))
thread.exception_handler = master.exception_handler
thread.setTerminationEnabled()
self.btn_save_to_h5.setEnabled(False)
thread.start()

Expand Down
50 changes: 50 additions & 0 deletions nsls2ptycho/ui/ui_ptycho.py
Original file line number Diff line number Diff line change
Expand Up @@ -961,6 +961,49 @@ def setupUi(self, MainWindow):
self.btn_recon_batch_stop.setMaximumSize(QtCore.QSize(80, 16777215))
self.btn_recon_batch_stop.setObjectName("btn_recon_batch_stop")
self.horizontalLayout_11.addWidget(self.btn_recon_batch_stop)
self.ck_batch_run_flag = QtWidgets.QCheckBox(self.tab_3)
self.ck_batch_run_flag.setGeometry(QtCore.QRect(20, 220, 141, 20))
self.ck_batch_run_flag.setChecked(True)
self.ck_batch_run_flag.setObjectName("ck_batch_run_flag")
self.widget = QtWidgets.QWidget(self.tab_3)
self.widget.setGeometry(QtCore.QRect(20, 190, 644, 31))
self.widget.setObjectName("widget")
self.horizontalLayout_16 = QtWidgets.QHBoxLayout(self.widget)
self.horizontalLayout_16.setContentsMargins(0, 0, 0, 0)
self.horizontalLayout_16.setObjectName("horizontalLayout_16")
self.ck_batch_crop_flag = QtWidgets.QCheckBox(self.widget)
self.ck_batch_crop_flag.setObjectName("ck_batch_crop_flag")
self.horizontalLayout_16.addWidget(self.ck_batch_crop_flag)
spacerItem16 = QtWidgets.QSpacerItem(28, 20, QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Minimum)
self.horizontalLayout_16.addItem(spacerItem16)
self.label_62 = QtWidgets.QLabel(self.widget)
self.label_62.setObjectName("label_62")
self.horizontalLayout_16.addWidget(self.label_62)
self.sp_batch_x0 = QtWidgets.QSpinBox(self.widget)
self.sp_batch_x0.setMaximum(100000000)
self.sp_batch_x0.setObjectName("sp_batch_x0")
self.horizontalLayout_16.addWidget(self.sp_batch_x0)
self.label_63 = QtWidgets.QLabel(self.widget)
self.label_63.setObjectName("label_63")
self.horizontalLayout_16.addWidget(self.label_63)
self.sp_batch_y0 = QtWidgets.QSpinBox(self.widget)
self.sp_batch_y0.setMaximum(100000000)
self.sp_batch_y0.setObjectName("sp_batch_y0")
self.horizontalLayout_16.addWidget(self.sp_batch_y0)
self.label_64 = QtWidgets.QLabel(self.widget)
self.label_64.setObjectName("label_64")
self.horizontalLayout_16.addWidget(self.label_64)
self.sp_batch_width = QtWidgets.QSpinBox(self.widget)
self.sp_batch_width.setMaximum(100000000)
self.sp_batch_width.setObjectName("sp_batch_width")
self.horizontalLayout_16.addWidget(self.sp_batch_width)
self.label_65 = QtWidgets.QLabel(self.widget)
self.label_65.setObjectName("label_65")
self.horizontalLayout_16.addWidget(self.label_65)
self.sp_batch_height = QtWidgets.QSpinBox(self.widget)
self.sp_batch_height.setMaximum(100000000)
self.sp_batch_height.setObjectName("sp_batch_height")
self.horizontalLayout_16.addWidget(self.sp_batch_height)
self.tabWidget.addTab(self.tab_3, "")
self.verticalLayout_5.addWidget(self.tabWidget)
self.console_info = QtWidgets.QTextEdit(self.centralwidget)
Expand Down Expand Up @@ -1168,6 +1211,13 @@ def retranslateUi(self, MainWindow):
self.le_prb_path_batch.setToolTip(_translate("MainWindow", "Set probe filename template. Ex: \"recon_*_t1_probe_ave.npy\", where \"*\" will be replaced by the scan number."))
self.btn_recon_batch_start.setText(_translate("MainWindow", "start"))
self.btn_recon_batch_stop.setText(_translate("MainWindow", "stop"))
self.ck_batch_run_flag.setText(_translate("MainWindow", "Run reconstruction"))
self.ck_batch_crop_flag.setToolTip(_translate("MainWindow", "This is effective only when \"Load from databroker\" is set."))
self.ck_batch_crop_flag.setText(_translate("MainWindow", "Crop data:"))
self.label_62.setText(_translate("MainWindow", "x0"))
self.label_63.setText(_translate("MainWindow", "y0"))
self.label_64.setText(_translate("MainWindow", "w"))
self.label_65.setText(_translate("MainWindow", "h"))
self.tabWidget.setTabText(self.tabWidget.indexOf(self.tab_3), _translate("MainWindow", "Batch mode"))
self.menuFile.setTitle(_translate("MainWindow", "File"))
self.menuWindows.setTitle(_translate("MainWindow", "Windows"))
Expand Down
Loading

0 comments on commit 211041e

Please sign in to comment.