diff --git a/src/experimental.py b/src/experimental.py index 7e4aaee2..ea96f00e 100644 --- a/src/experimental.py +++ b/src/experimental.py @@ -31,8 +31,8 @@ from joblib import Parallel, delayed from scipy.interpolate import NearestNDInterpolator from scipy.ndimage import shift, generate_binary_structure, binary_dilation -from skimage.filters import window - +from scipy.signal import correlate +from scipy.optimize import minimize_scalar, minimize import utils import vis @@ -42,7 +42,7 @@ from wavefront import Wavefront from preloaded import Preloadedmodelclass from embeddings import remove_interference_pattern -from preprocessing import prep_sample, optimal_rolling_strides, find_roi, get_tiles +from preprocessing import prep_sample, optimal_rolling_strides, find_roi, get_tiles, round_to_even import logging logger = logging.getLogger('') @@ -2109,6 +2109,28 @@ def overlap_tile(volume_shape, tile_shape, border, target, tile_index): return ranges +def f_to_minimize(defocus: float, + w: Wavefront, + corrected_psf: np.ndarray, + samplepsfgen: SyntheticPSF) -> float: + """ + + Args: + defocus: amount of lightsheet defocus in microns + w: wavefront aberration + corrected_psf: corrected empirical psf to match (3D array) + samplepsfgen: PSF generator (has voxel sizes, wavelength, etc..) which will make the defocus'ed 3D PSF. + + Returns: + max correlation amount between 'corrected_psf' and 'defocused psf' + + """ + if isinstance(defocus, np.ndarray): + defocus = defocus.item() + kernel = samplepsfgen.single_psf(w, normed=True, lls_defocus_offset=defocus) + # kernel /= np.max(kernel) + return np.max(correlate(corrected_psf, kernel, mode='same')) + @profile def decon( @@ -2140,13 +2162,21 @@ def decon( lateral_voxel_size = predictions_settings['sample_voxel_size'][2] window_size = predictions_settings['window_size'] + psf_voxel_size = np.array([0.03, 0.03, 0.03]) + tile_fov = np.array(window_size) * (axial_voxel_size, lateral_voxel_size, lateral_voxel_size) + psf_shape = np.full(shape=3, fill_value=round_to_even(np.min(tile_fov / psf_voxel_size) * 0.5), dtype=np.int32) + print(f" {axial_voxel_size=:0.03f}\n" + f"{lateral_voxel_size=:0.03f}\n" + f" psf_shape={psf_shape}\n" + f" psf_voxel_size={psf_voxel_size}um\n" + f" tile_fov={tile_fov}um\n") samplepsfgen = SyntheticPSF( psf_type=predictions_settings['psf_type'], - psf_shape=(64, 64, 64), + psf_shape=psf_shape, lam_detection=wavelength, - x_voxel_size=lateral_voxel_size, - y_voxel_size=lateral_voxel_size, - z_voxel_size=axial_voxel_size + x_voxel_size=psf_voxel_size[0], + y_voxel_size=psf_voxel_size[1], + z_voxel_size=psf_voxel_size[2], ) # tile id is the column header, rows are the predictions @@ -2187,12 +2217,12 @@ def decon( imwrite(savepath, decon_vol.astype(np.float32)) psfs = np.zeros( - (ztiles, ytiles*samplepsfgen.psf_shape[1], xtiles*samplepsfgen.psf_shape[2]), + (ztiles*samplepsfgen.psf_shape[0], ytiles*samplepsfgen.psf_shape[1], xtiles*samplepsfgen.psf_shape[2]), dtype=np.float32 ) zw, yw, xw = window_size - kyw, kxw = samplepsfgen.psf_shape[1:] + kzw, kyw, kxw = samplepsfgen.psf_shape logger.info(f"volume_size = {vol.shape}") logger.info(f"window_size = {zw, yw, xw}") @@ -2217,6 +2247,12 @@ def decon( n_iters=iters, skewed_decon=True, deskew=0, + na=samplepsfgen.na_detection, + background=100, + wavelength=samplepsfgen.lam_detection * 1000, # wavelength in nm + nimm=samplepsfgen.refractive_index, + cleanup_otf=False, + dup_rev_z=True, ) elif task == 'cocoa': from experimental_benchmarks import predict_cocoa @@ -2245,6 +2281,7 @@ def decon( position=0 ): w = Wavefront(predictions.loc[z, y, x].values, lam_detection=wavelength) + kernel = samplepsfgen.single_psf(w, normed=False) kernel /= np.max(kernel) @@ -2259,9 +2296,9 @@ def decon( kernel = out_k_m elif task == 'decon': - stdout = silence(task == 'decon') + # stdout = silence(task == 'decon') reconstructed = reconstruct_decon(tile, psf=kernel) - silence(False, stdout=stdout) + # silence(False, stdout=stdout) else: logger.error(f"Task of '{task}' is unknown") return @@ -2269,22 +2306,60 @@ def decon( decon_vol[tile_slice(target='dst', tile_index=(z, y, x)) ] = reconstructed[tile_slice(target='extract', tile_index=(z, y, x))] - psfs[z, y * kyw:(y * kyw) + kyw, x * kxw:(x * kxw) + kxw] = np.max(kernel[:, 0:kyw, 0:kxw], axis=0) # mip view for later. + psfs[ z * kzw:(z * kzw) + kzw, + y * kyw:(y * kyw) + kyw, + x * kxw:(x * kxw) + kxw] = kernel[0:kzw, 0:kyw, 0:kxw] - imwrite(f"{model_pred.with_suffix('')}_{task}_psfs.tif", psfs.astype(np.float32)) + imwrite(f"{model_pred.with_suffix('')}_{task}_psfs.tif", psfs.astype(np.float32), resolution=(xw, yw)) imwrite(savepath, decon_vol) else: - # identify all the unique PSFs that we need to decconvolve with + # identify all the unique PSFs that we need to deconvolve with predictions['psf_id'] = predictions.groupby(predictions.columns.values.tolist(), sort=False).grouper.group_info[0] groups = predictions.groupby('psf_id') - # for each psf_id, deconvolve the volume + # for each psf_id, deconvolve the entire volume for psf_id in tqdm(predictions['psf_id'].unique(), desc=f'Do {task} entire vol with each psf, {iters} RL iterations', unit='vols to decon', position=0): df = groups.get_group(psf_id).drop(columns=['p2v', 'psf_id']) zernikes = df.values[0] # all rows in this group should be equal. Take the first one as the wavefront. w = Wavefront(zernikes, lam_detection=wavelength) - kernel = samplepsfgen.single_psf(w, normed=False) + defocus = 0 + z, y, x = df.index[0] + if np.count_nonzero(zernikes) > 0: + corrected_psf_path = Path( + str(model_pred / f'z{z}-y{y}-x{x}_corrected_psf.tif').replace("_predictions.csv", "")) + corrected_psf = np.zeros_like(imread(corrected_psf_path)) + + for index, zernikes in df.iterrows(): + z, y, x = index + corrected_psf_path = Path(str(model_pred / f'z{z}-y{y}-x{x}_corrected_psf.tif').replace("_predictions.csv", "")) + corrected_psf += imread(corrected_psf_path) + + corrected_psf /= np.max(corrected_psf) + res = minimize( + f_to_minimize, + x0=0, + args=(w, corrected_psf, samplepsfgen), + tol=samplepsfgen.z_voxel_size, + bounds=[(-0.6, 0.6)], + method='Nelder-Mead', + options={ + 'disp': False, + 'initial_simplex': [[-samplepsfgen.z_voxel_size], [samplepsfgen.z_voxel_size]], # need 1st step > one z, overrides x0 + } + ) + defocus = res.x[0] + # defocus_steps = np.linspace(-2, 2, 41) + # correlations = np.zeros_like(defocus_steps) + # + # for i, defocus in enumerate(defocus_steps): + # kernel = samplepsfgen.single_psf(w, normed=False, lls_defocus_offset=defocus) + # kernel /= np.max(kernel) + # correlations[i] = np.max(correlate(corrected_psf, kernel, mode='same')) + # + # defocus = defocus_steps[np.argmax(correlations)] + print(f'\t Defocus for ({z:2d}, {y:2d}, {x:2d}) is {defocus: 0.2f} um, p2V is {w.peak2valley(na=samplepsfgen.na_detection):0.02f}') + kernel = samplepsfgen.single_psf(w, normed=False, lls_defocus_offset=defocus) kernel /= np.max(kernel) if task == 'cocoa': @@ -2309,9 +2384,13 @@ def decon( y * yw :(y * yw) + yw, x * xw :(x * xw) + xw, ] + psfs[ + z * kzw:(z * kzw) + kzw, + y * kyw:(y * kyw) + kyw, + x * kxw:(x * kxw) + kxw] = kernel[0:kzw, 0:kyw, 0:kxw] - imwrite(savepath, decon_vol.astype(np.float32)) - + imwrite(savepath, decon_vol.astype(np.float32), resolution=(xw, yw)) + imwrite(f"{model_pred.with_suffix('')}_{task}_psfs.tif", psfs.astype(np.float32), resolution=(psf_shape[1], psf_shape[2])) imwrite(savepath, decon_vol.astype(np.float32)) logger.info(f"Decon image saved to : \n{savepath.resolve()}")