Skip to content

Commit

Permalink
support CUDA-aware MPI; replace GPU buttons by text; close all window…
Browse files Browse the repository at this point in the history
…s at once

1. closes #55
2. closes #56
3. closes #57
4. parse_scan_range() is moved from the main window's method to
   nsls2ptycho.core.utils.parse_range() with minor changes
  • Loading branch information
leofang committed Jul 31, 2019
1 parent a526063 commit 218f8e4
Show file tree
Hide file tree
Showing 6 changed files with 206 additions and 304 deletions.
2 changes: 1 addition & 1 deletion nsls2ptycho/core/ptycho
1 change: 1 addition & 0 deletions nsls2ptycho/core/ptycho_param.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ def __init__(self):
self.gpus = [1, 2, 3] # should be a list of gpu numbers, ex: [0, 2, 3]
self.gpu_batch_size = 256 # should be 4^n, ex: 4, 16, 64, 256, 1024, 4096, ...
self.use_NCCL = False
self.use_CUDA_MPI = False
self.mpi_file_path = '' # full path to a valid MPI machine file

### [adv param group] ###
Expand Down
32 changes: 32 additions & 0 deletions nsls2ptycho/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,3 +115,35 @@ def set_flush_early(mpirun_command):
if 'MPICH' in MPI.get_vendor()[0] or 'MVAPICH' in MPI.get_vendor()[0]:
mpirun_command.insert(-2, "-u") # force flush asap (MPICH is weird...)
return mpirun_command


def parse_range(batch_items, every_nth_item = 1, reverse_sort = True):
'''
Note the range is inclusive on both ends.
Ex: 1238 - 1242 with step size 2 --> [1238, 1240, 1242]
'''
scan_range = []
scan_numbers = []

if batch_items == '':
raise ValueError("No item list is given for batch processing.")

# first parse items and separate them into two catogories
slist = batch_items.split(',')
for item in slist:
if '-' in item:
sublist = item.split('-')
scan_range.append((int(sublist[0].strip()), int(sublist[1].strip())))
elif len(item) == 0: # for empty string
continue
else:
scan_numbers.append(int(item.strip()))

# next generate all legit items from the chosen ranges and make a sorted item list
for item in scan_range:
scan_numbers = scan_numbers + list(range(item[0], item[1]+1, every_nth_item))

if reverse_sort:
scan_numbers.sort(reverse=True)

return scan_numbers
90 changes: 36 additions & 54 deletions nsls2ptycho/ptycho_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from PyQt5.QtWidgets import QFileDialog, QAction

from nsls2ptycho.ui import ui_ptycho
from nsls2ptycho.core.utils import clean_shared_memory, get_mpi_num_processes
from nsls2ptycho.core.utils import clean_shared_memory, get_mpi_num_processes, parse_range
from nsls2ptycho.core.ptycho_param import Param
from nsls2ptycho.core.ptycho_recon import PtychoReconWorker, PtychoReconFakeWorker, HardWorker
from nsls2ptycho.core.ptycho_qt_utils import PtychoStream
Expand Down Expand Up @@ -91,9 +91,7 @@ def __init__(self, parent=None, param:Param=None):
self.actionClear_shared_memory.triggered.connect(self.clearSharedMemory)

self.btn_MPI_file.clicked.connect(self.setMPIfile)
self.btn_gpu_all = [self.btn_gpu_0, self.btn_gpu_1, self.btn_gpu_2, self.btn_gpu_3]
for btn in self.btn_gpu_all:
btn.clicked.connect(self.resetMPIFlg)
self.le_gpus.textChanged.connect(self.resetMPIFlg)

# setup
self.sp_pha_max.setMaximum(pi)
Expand Down Expand Up @@ -274,11 +272,7 @@ def update_param_from_gui(self):
p.pha_max = float(self.sp_pha_max.value())

p.gpu_flag = self.ck_gpu_flag.isChecked()
gpus = []
for btn_gpu, id in zip(self.btn_gpu_all, range(len(self.btn_gpu_all))):
if btn_gpu.isChecked():
gpus.append(id)
p.gpus = gpus
p.gpus = parse_range(self.le_gpus.text(), reverse_sort=False)
p.gpu_batch_size = int(self.cb_gpu_batch_size.currentText())

# adv param group
Expand Down Expand Up @@ -333,6 +327,7 @@ def update_param_from_gui(self):
p.profiler_flag = self.ck_profiler_flag.isChecked()
p.postprocessing_flag = self.ck_postprocessing_flag.isChecked()
p.use_NCCL = self.rb_nccl.isChecked()
p.use_CUDA_MPI = self.rb_cuda_mpi.isChecked()

# TODO: organize them
#self.ck_init_obj_dpc_flag.setChecked(p.init_obj_dpc_flag)
Expand Down Expand Up @@ -392,16 +387,20 @@ def update_gui_from_param(self):
self.sp_pha_min.setValue(float(p.pha_min))

self.ck_gpu_flag.setChecked(p.gpu_flag)
for btn_gpu, id in zip(self.btn_gpu_all, range(len(self.btn_gpu_all))):
btn_gpu.setChecked(id in p.gpus)
gpu_str = ''
for i, dev_id in enumerate(p.gpus):
gpu_str += str(dev_id)
if i != len(p.gpus) - 1:
gpu_str += ', '
self.le_gpus.setText(gpu_str)
self.cb_gpu_batch_size.setCurrentIndex(p.get_gpu_batch_index())

# set MPI file path from param
if p.mpi_file_path != '':
mpi_filename = os.path.basename(p.mpi_file_path)
self.le_MPI_file_path.setText(mpi_filename)
for btn in self.btn_gpu_all:
btn.setChecked(False)
# TODO: does this make sense?
self.le_gpus.setText('')

# adv param group
self.sp_ccd_pixel_um.setValue(p.ccd_pixel_um)
Expand Down Expand Up @@ -459,6 +458,7 @@ def update_gui_from_param(self):
self.ck_profiler_flag.setChecked(p.profiler_flag)
self.ck_postprocessing_flag.setChecked(p.postprocessing_flag)
self.rb_nccl.setChecked(p.use_NCCL)
self.rb_cuda_mpi.setChecked(p.use_CUDA_MPI)

# batch param group, necessary?

Expand Down Expand Up @@ -839,17 +839,14 @@ def checkGpuAvail(self):
self.ck_gpu_flag.setChecked(False)
self.ck_gpu_flag.setEnabled(False)
self.param.gpu_flag = False
for button in self.btn_gpu_all:
button.setChecked(False)
self.le_gpus.setText('')
self.le_gpus.setEnabled(False)
self.cb_gpu_batch_size.setEnabled(False)


def updateGpuFlg(self):
flag = self.ck_gpu_flag.isChecked()
self.btn_gpu_0.setEnabled(flag)
self.btn_gpu_1.setEnabled(flag)
self.btn_gpu_2.setEnabled(flag)
self.btn_gpu_3.setEnabled(flag)
self.le_gpus.setEnabled(flag)
self.rb_nccl.setEnabled(flag)
if not flag and self.rb_nccl.isChecked():
self.rb_mpi.setChecked(True)
Expand Down Expand Up @@ -908,8 +905,7 @@ def setMPIfile(self):
self.param.mpi_file_path = filename
#print(filename)
self.le_MPI_file_path.setText(mpi_filename)
for btn in self.btn_gpu_all:
btn.setChecked(False)
self.le_gpus.setText('')


def resetMPIFlg(self):
Expand All @@ -918,38 +914,6 @@ def resetMPIFlg(self):
self.le_MPI_file_path.setText('')


# adapted from dpc_batch.py
def parse_scan_range(self):
'''
Note the range is inclusive on both ends.
Ex: 1238 - 1242 with step size 2 --> [1238, 1240, 1242]
'''
scan_range = []
scan_numbers = []
batch_items = self.le_batch_items.text()
every_nth_scan = self.sp_batch_step.value()

if batch_items == '':
raise ValueError("No item list is given for batch processing.")

# first parse items and separate them into two catogories
slist = batch_items.split(',')
for item in slist:
if '-' in item:
sublist = item.split('-')
scan_range.append((int(sublist[0].strip()), int(sublist[1].strip())))
else:
scan_numbers.append(int(item.strip()))

# next generate all legit items from the chosen ranges and make a sorted item list
for item in scan_range:
scan_numbers = scan_numbers + list(range(item[0], item[1]+1, every_nth_scan))
scan_numbers.sort(reverse=True)
print(scan_numbers)

return scan_numbers


def batchStart(self):
'''
Currently only support load from h5.
Expand All @@ -959,7 +923,8 @@ def batchStart(self):
return

try:
self._scan_numbers = self.parse_scan_range()
self._scan_numbers = parse_range(self.le_batch_items.text(), self.sp_batch_step.value())
print(self._scan_numbers)
# TODO: is there a way to lock all widgets to prevent accidental parameter changes in the middle?

# fire up
Expand Down Expand Up @@ -1309,6 +1274,23 @@ def exception_handler(self, ex):
print("[ERROR] " + str(ex), file=sys.stderr)


def closeEvent(self, event):
# Overwrite the class's default
message = "Are you sure you want to quit the app?"
ans = QtWidgets.QMessageBox.question(self, "Warning", message, QtWidgets.QMessageBox.Yes, QtWidgets.QMessageBox.No)
if ans == QtWidgets.QMessageBox.Yes:
self.stop()
if self.reconStepWindow is not None:
self.reconStepWindow.close()
if self.roiWindow is not None:
self.roiWindow.close()
if self.scanWindow is not None:
self.scanWindow.close()
event.accept()
else:
event.ignore()


# reimplementing __del__ is useless, so use the signal QApplication.aboutToQuit
def destructor(self):
try:
Expand Down
Loading

0 comments on commit 218f8e4

Please sign in to comment.