Work on decon in experimental.py
dmilkie committed Oct 16, 2023
1 parent bad68f9 commit 583cc02
Showing 1 changed file with 97 additions and 18 deletions.
115 changes: 97 additions & 18 deletions src/experimental.py
@@ -31,8 +31,8 @@
from joblib import Parallel, delayed
from scipy.interpolate import NearestNDInterpolator
from scipy.ndimage import shift, generate_binary_structure, binary_dilation
from skimage.filters import window

from scipy.signal import correlate
from scipy.optimize import minimize_scalar, minimize

import utils
import vis
@@ -42,7 +42,7 @@
from wavefront import Wavefront
from preloaded import Preloadedmodelclass
from embeddings import remove_interference_pattern
from preprocessing import prep_sample, optimal_rolling_strides, find_roi, get_tiles
from preprocessing import prep_sample, optimal_rolling_strides, find_roi, get_tiles, round_to_even

import logging
logger = logging.getLogger('')
@@ -2109,6 +2109,28 @@ def overlap_tile(volume_shape, tile_shape, border, target, tile_index):
return ranges


def f_to_minimize(defocus: float,
w: Wavefront,
corrected_psf: np.ndarray,
samplepsfgen: SyntheticPSF) -> float:
"""
Args:
defocus: amount of lightsheet defocus in microns
w: wavefront aberration
corrected_psf: corrected empirical psf to match (3D array)
samplepsfgen: PSF generator (has voxel sizes, wavelength, etc.) used to make the defocused 3D PSF.
Returns:
negative of the max correlation between 'corrected_psf' and the defocused psf (negated so that scipy.optimize.minimize finds the best-matching defocus)
"""
if isinstance(defocus, np.ndarray):
defocus = defocus.item()
kernel = samplepsfgen.single_psf(w, normed=True, lls_defocus_offset=defocus)
# kernel /= np.max(kernel)
return -np.max(correlate(corrected_psf, kernel, mode='same'))
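
# Minimal usage sketch for the objective above (hypothetical, not one of this file's call sites): assumes a
# Wavefront `w`, a measured `corrected_psf` array, and a SyntheticPSF `samplepsfgen` are already in scope.
# minimize_scalar is imported at the top of this file; the bounded search mirrors the one in decon() below.
#   res = minimize_scalar(
#       f_to_minimize,
#       bounds=(-0.6, 0.6),                      # search within +/- 0.6 um of lightsheet defocus
#       method='bounded',
#       args=(w, corrected_psf, samplepsfgen),
#   )
#   best_defocus = res.x                         # defocus (um) whose simulated PSF best matches the measurement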


@profile
def decon(
@@ -2140,13 +2162,21 @@ def decon(
lateral_voxel_size = predictions_settings['sample_voxel_size'][2]
window_size = predictions_settings['window_size']

psf_voxel_size = np.array([0.03, 0.03, 0.03])
tile_fov = np.array(window_size) * (axial_voxel_size, lateral_voxel_size, lateral_voxel_size)
psf_shape = np.full(shape=3, fill_value=round_to_even(np.min(tile_fov / psf_voxel_size) * 0.5), dtype=np.int32)
print(f" {axial_voxel_size=:0.03f}\n"
f"{lateral_voxel_size=:0.03f}\n"
f" psf_shape={psf_shape}\n"
f" psf_voxel_size={psf_voxel_size}um\n"
f" tile_fov={tile_fov}um\n")
samplepsfgen = SyntheticPSF(
psf_type=predictions_settings['psf_type'],
psf_shape=(64, 64, 64),
psf_shape=psf_shape,
lam_detection=wavelength,
x_voxel_size=lateral_voxel_size,
y_voxel_size=lateral_voxel_size,
z_voxel_size=axial_voxel_size
x_voxel_size=psf_voxel_size[0],
y_voxel_size=psf_voxel_size[1],
z_voxel_size=psf_voxel_size[2],
)

# tile id is the column header, rows are the predictions
Expand Down Expand Up @@ -2187,12 +2217,12 @@ def decon(
imwrite(savepath, decon_vol.astype(np.float32))

psfs = np.zeros(
(ztiles, ytiles*samplepsfgen.psf_shape[1], xtiles*samplepsfgen.psf_shape[2]),
(ztiles*samplepsfgen.psf_shape[0], ytiles*samplepsfgen.psf_shape[1], xtiles*samplepsfgen.psf_shape[2]),
dtype=np.float32
)

zw, yw, xw = window_size
kyw, kxw = samplepsfgen.psf_shape[1:]
kzw, kyw, kxw = samplepsfgen.psf_shape

logger.info(f"volume_size = {vol.shape}")
logger.info(f"window_size = {zw, yw, xw}")
@@ -2217,6 +2247,12 @@ def decon(
n_iters=iters,
skewed_decon=True,
deskew=0,
na=samplepsfgen.na_detection,
background=100,
wavelength=samplepsfgen.lam_detection * 1000, # wavelength in nm
nimm=samplepsfgen.refractive_index,
cleanup_otf=False,
dup_rev_z=True,
)
elif task == 'cocoa':
from experimental_benchmarks import predict_cocoa
@@ -2245,6 +2281,7 @@ def decon(
position=0
):
w = Wavefront(predictions.loc[z, y, x].values, lam_detection=wavelength)

kernel = samplepsfgen.single_psf(w, normed=False)
kernel /= np.max(kernel)

@@ -2259,32 +2296,70 @@ def decon(

kernel = out_k_m
elif task == 'decon':
stdout = silence(task == 'decon')
# stdout = silence(task == 'decon')
reconstructed = reconstruct_decon(tile, psf=kernel)
silence(False, stdout=stdout)
# silence(False, stdout=stdout)
else:
logger.error(f"Task of '{task}' is unknown")
return

decon_vol[tile_slice(target='dst', tile_index=(z, y, x))
] = reconstructed[tile_slice(target='extract', tile_index=(z, y, x))]

psfs[z, y * kyw:(y * kyw) + kyw, x * kxw:(x * kxw) + kxw] = np.max(kernel[:, 0:kyw, 0:kxw], axis=0) # mip view for later.
psfs[ z * kzw:(z * kzw) + kzw,
y * kyw:(y * kyw) + kyw,
x * kxw:(x * kxw) + kxw] = kernel[0:kzw, 0:kyw, 0:kxw]
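# store the full 3D kernel in the PSF mosaic (the earlier code kept only an axial max-intensity projection of each kernel)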

imwrite(f"{model_pred.with_suffix('')}_{task}_psfs.tif", psfs.astype(np.float32))
imwrite(f"{model_pred.with_suffix('')}_{task}_psfs.tif", psfs.astype(np.float32), resolution=(xw, yw))
imwrite(savepath, decon_vol)
else:
# identify all the unique PSFs that we need to decconvolve with
# identify all the unique PSFs that we need to deconvolve with
predictions['psf_id'] = predictions.groupby(predictions.columns.values.tolist(), sort=False).grouper.group_info[0]
groups = predictions.groupby('psf_id')

# for each psf_id, deconvolve the volume
# for each psf_id, deconvolve the entire volume
for psf_id in tqdm(predictions['psf_id'].unique(), desc=f'Do {task} entire vol with each psf, {iters} RL iterations', unit='vols to decon', position=0):
df = groups.get_group(psf_id).drop(columns=['p2v', 'psf_id'])

zernikes = df.values[0] # all rows in this group should be equal. Take the first one as the wavefront.
w = Wavefront(zernikes, lam_detection=wavelength)
kernel = samplepsfgen.single_psf(w, normed=False)
defocus = 0
z, y, x = df.index[0]
if np.count_nonzero(zernikes) > 0:
corrected_psf_path = Path(
str(model_pred / f'z{z}-y{y}-x{x}_corrected_psf.tif').replace("_predictions.csv", ""))
corrected_psf = np.zeros_like(imread(corrected_psf_path))

for index, _ in df.iterrows():  # only the tile indices are needed here; the group's zernikes were captured above
z, y, x = index
corrected_psf_path = Path(str(model_pred / f'z{z}-y{y}-x{x}_corrected_psf.tif').replace("_predictions.csv", ""))
corrected_psf += imread(corrected_psf_path)

corrected_psf /= np.max(corrected_psf)
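# sum of the per-tile corrected PSFs for this group, normalized to a peak of 1; this is the measurement
# that the simulated defocused PSF is matched against below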
res = minimize(
f_to_minimize,
x0=0,
args=(w, corrected_psf, samplepsfgen),
tol=samplepsfgen.z_voxel_size,
bounds=[(-0.6, 0.6)],
method='Nelder-Mead',
options={
'disp': False,
'initial_simplex': [[-samplepsfgen.z_voxel_size], [samplepsfgen.z_voxel_size]], # need 1st step > one z, overrides x0
}
)
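# tol of one z-voxel (0.03 um here): stop refining once the defocus step is roughly smaller than a single axial PSF voxel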
defocus = res.x[0]
# defocus_steps = np.linspace(-2, 2, 41)
# correlations = np.zeros_like(defocus_steps)
#
# for i, defocus in enumerate(defocus_steps):
# kernel = samplepsfgen.single_psf(w, normed=False, lls_defocus_offset=defocus)
# kernel /= np.max(kernel)
# correlations[i] = np.max(correlate(corrected_psf, kernel, mode='same'))
#
# defocus = defocus_steps[np.argmax(correlations)]
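# the bounded Nelder-Mead search above replaces this exhaustive 41-step sweep and usually needs far fewer PSF simulations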
print(f'\t Defocus for ({z:2d}, {y:2d}, {x:2d}) is {defocus: 0.2f} um, p2V is {w.peak2valley(na=samplepsfgen.na_detection):0.02f}')
kernel = samplepsfgen.single_psf(w, normed=False, lls_defocus_offset=defocus)
kernel /= np.max(kernel)

if task == 'cocoa':
Expand All @@ -2309,9 +2384,13 @@ def decon(
y * yw :(y * yw) + yw,
x * xw :(x * xw) + xw,
]
psfs[
z * kzw:(z * kzw) + kzw,
y * kyw:(y * kyw) + kyw,
x * kxw:(x * kxw) + kxw] = kernel[0:kzw, 0:kyw, 0:kxw]

imwrite(savepath, decon_vol.astype(np.float32))

imwrite(savepath, decon_vol.astype(np.float32), resolution=(xw, yw))
imwrite(f"{model_pred.with_suffix('')}_{task}_psfs.tif", psfs.astype(np.float32), resolution=(psf_shape[1], psf_shape[2]))

imwrite(savepath, decon_vol.astype(np.float32))
logger.info(f"Decon image saved to : \n{savepath.resolve()}")
