Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Jun 11, 2024
1 parent 9567980 commit 006e946
Show file tree
Hide file tree
Showing 16 changed files with 4,043 additions and 1,866 deletions.
7 changes: 1 addition & 6 deletions nobrainer/ext/lab2im/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1 @@
from . import edit_tensors
from . import edit_volumes
from . import image_generator
from . import lab2im_model
from . import layers
from . import utils
from . import edit_tensors, edit_volumes, image_generator, lab2im_model, layers, utils
197 changes: 144 additions & 53 deletions nobrainer/ext/lab2im/edit_tensors.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,14 @@
"""

from itertools import combinations

import keras.backend as K
import keras.layers as KL

# python imports
import numpy as np
import tensorflow as tf
import keras.layers as KL
import keras.backend as K
from itertools import combinations

# project imports
from nobrainer.ext.lab2im import utils
Expand All @@ -38,7 +39,9 @@
from nobrainer.ext.neuron.utils import volshape_to_meshgrid


def blurring_sigma_for_downsampling(current_res, downsample_res, mult_coef=None, thickness=None):
def blurring_sigma_for_downsampling(
current_res, downsample_res, mult_coef=None, thickness=None
):
"""Compute standard deviations of 1d gaussian masks for image blurring before downsampling.
:param downsample_res: resolution to downsample to. Can be a 1d numpy array or list, or a tensor.
:param current_res: resolution of the volume before downsampling.
Expand Down Expand Up @@ -68,17 +71,32 @@ def blurring_sigma_for_downsampling(current_res, downsample_res, mult_coef=None,

# reformat data resolution at which we blur
if thickness is not None:
down_res = KL.Lambda(lambda x: tf.math.minimum(x[0], x[1]))([downsample_res, thickness])
down_res = KL.Lambda(lambda x: tf.math.minimum(x[0], x[1]))(
[downsample_res, thickness]
)
else:
down_res = downsample_res

# get std deviation for blurring kernels
if mult_coef is None:
sigma = KL.Lambda(lambda x: tf.where(tf.math.equal(x, tf.convert_to_tensor(current_res, dtype='float32')),
0.5, 0.75 * x / tf.convert_to_tensor(current_res, dtype='float32')))(down_res)
sigma = KL.Lambda(
lambda x: tf.where(
tf.math.equal(
x, tf.convert_to_tensor(current_res, dtype="float32")
),
0.5,
0.75 * x / tf.convert_to_tensor(current_res, dtype="float32"),
)
)(down_res)
else:
sigma = KL.Lambda(lambda x: mult_coef * x / tf.convert_to_tensor(current_res, dtype='float32'))(down_res)
sigma = KL.Lambda(lambda x: tf.where(tf.math.equal(x[0], 0.), 0., x[1]))([down_res, sigma])
sigma = KL.Lambda(
lambda x: mult_coef
* x
/ tf.convert_to_tensor(current_res, dtype="float32")
)(down_res)
sigma = KL.Lambda(lambda x: tf.where(tf.math.equal(x[0], 0.0), 0.0, x[1]))(
[down_res, sigma]
)

return sigma

Expand All @@ -95,9 +113,13 @@ def gaussian_kernel(sigma, max_sigma=None, blur_range=None, separable=True):
"""
# convert sigma into a tensor
if not tf.is_tensor(sigma):
sigma_tens = tf.convert_to_tensor(utils.reformat_to_list(sigma), dtype='float32')
sigma_tens = tf.convert_to_tensor(
utils.reformat_to_list(sigma), dtype="float32"
)
else:
assert max_sigma is not None, 'max_sigma must be provided when sigma is given as a tensor'
assert (
max_sigma is not None
), "max_sigma must be provided when sigma is given as a tensor"
sigma_tens = sigma
shape = sigma_tens.get_shape().as_list()

Expand All @@ -118,7 +140,9 @@ def gaussian_kernel(sigma, max_sigma=None, blur_range=None, separable=True):
# randomise the burring std dev and/or split it between dimensions
if blur_range is not None:
if blur_range != 1:
sigma_tens = sigma_tens * tf.random.uniform(tf.shape(sigma_tens), minval=1 / blur_range, maxval=blur_range)
sigma_tens = sigma_tens * tf.random.uniform(
tf.shape(sigma_tens), minval=1 / blur_range, maxval=blur_range
)

# get size of blurring kernels
windowsize = np.int32(np.ceil(2.5 * max_sigma) / 2) * 2 + 1
Expand All @@ -129,16 +153,23 @@ def gaussian_kernel(sigma, max_sigma=None, blur_range=None, separable=True):

kernels = list()
comb = np.array(list(combinations(list(range(n_dims)), n_dims - 1))[::-1])
for (i, wsize) in enumerate(windowsize):
for i, wsize in enumerate(windowsize):

if wsize > 1:

# build meshgrid and replicate it along batch dim if dynamic blurring
locations = tf.cast(tf.range(0, wsize), 'float32') - (wsize - 1) / 2
locations = tf.cast(tf.range(0, wsize), "float32") - (wsize - 1) / 2
if batchsize is not None:
locations = tf.tile(tf.expand_dims(locations, axis=0),
tf.concat([batchsize, tf.ones(tf.shape(tf.shape(locations)), dtype='int32')],
axis=0))
locations = tf.tile(
tf.expand_dims(locations, axis=0),
tf.concat(
[
batchsize,
tf.ones(tf.shape(tf.shape(locations)), dtype="int32"),
],
axis=0,
),
)
comb[i] += 1

# compute gaussians
Expand All @@ -156,13 +187,23 @@ def gaussian_kernel(sigma, max_sigma=None, blur_range=None, separable=True):
else:

# build meshgrid
mesh = [tf.cast(f, 'float32') for f in volshape_to_meshgrid(windowsize, indexing='ij')]
diff = tf.stack([mesh[f] - (windowsize[f] - 1) / 2 for f in range(len(windowsize))], axis=-1)
mesh = [
tf.cast(f, "float32")
for f in volshape_to_meshgrid(windowsize, indexing="ij")
]
diff = tf.stack(
[mesh[f] - (windowsize[f] - 1) / 2 for f in range(len(windowsize))], axis=-1
)

# replicate meshgrid to batch size and reshape sigma_tens
if batchsize is not None:
diff = tf.tile(tf.expand_dims(diff, axis=0),
tf.concat([batchsize, tf.ones(tf.shape(tf.shape(diff)), dtype='int32')], axis=0))
diff = tf.tile(
tf.expand_dims(diff, axis=0),
tf.concat(
[batchsize, tf.ones(tf.shape(tf.shape(diff)), dtype="int32")],
axis=0,
),
)
for i in range(n_dims):
sigma_tens = tf.expand_dims(sigma_tens, axis=1)
else:
Expand All @@ -171,8 +212,14 @@ def gaussian_kernel(sigma, max_sigma=None, blur_range=None, separable=True):

# compute gaussians
sigma_is_0 = tf.equal(sigma_tens, 0)
exp_term = -K.square(diff) / (2 * tf.where(sigma_is_0, tf.ones_like(sigma_tens), sigma_tens)**2)
norms = exp_term - tf.math.log(tf.where(sigma_is_0, tf.ones_like(sigma_tens), np.sqrt(2 * np.pi) * sigma_tens))
exp_term = -K.square(diff) / (
2 * tf.where(sigma_is_0, tf.ones_like(sigma_tens), sigma_tens) ** 2
)
norms = exp_term - tf.math.log(
tf.where(
sigma_is_0, tf.ones_like(sigma_tens), np.sqrt(2 * np.pi) * sigma_tens
)
)
kernels = K.sum(norms, -1)
kernels = tf.exp(kernels)
kernels /= tf.reduce_sum(kernels)
Expand All @@ -184,8 +231,8 @@ def gaussian_kernel(sigma, max_sigma=None, blur_range=None, separable=True):
def sobel_kernels(n_dims):
"""Returns sobel kernels to compute spatial derivative on image of n dimensions."""

in_dir = tf.convert_to_tensor([1, 0, -1], dtype='float32')
orthogonal_dir = tf.convert_to_tensor([1, 2, 1], dtype='float32')
in_dir = tf.convert_to_tensor([1, 0, -1], dtype="float32")
orthogonal_dir = tf.convert_to_tensor([1, 2, 1], dtype="float32")
comb = np.array(list(combinations(list(range(n_dims)), n_dims - 1))[::-1])

list_kernels = list()
Expand Down Expand Up @@ -216,50 +263,74 @@ def unit_kernel(dist_threshold, n_dims, max_dist_threshold=None):

# convert dist_threshold into a tensor
if not tf.is_tensor(dist_threshold):
dist_threshold_tens = tf.convert_to_tensor(utils.reformat_to_list(dist_threshold), dtype='float32')
dist_threshold_tens = tf.convert_to_tensor(
utils.reformat_to_list(dist_threshold), dtype="float32"
)
else:
assert max_dist_threshold is not None, 'max_sigma must be provided when dist_threshold is given as a tensor'
dist_threshold_tens = tf.cast(dist_threshold, 'float32')
assert (
max_dist_threshold is not None
), "max_sigma must be provided when dist_threshold is given as a tensor"
dist_threshold_tens = tf.cast(dist_threshold, "float32")
shape = dist_threshold_tens.get_shape().as_list()

# get batchsize
batchsize = None if shape[0] is not None else tf.split(tf.shape(dist_threshold_tens), [1, -1])[0]
batchsize = (
None
if shape[0] is not None
else tf.split(tf.shape(dist_threshold_tens), [1, -1])[0]
)

# set max_dist_threshold into an array
if max_dist_threshold is None: # dist_threshold is fixed (i.e. dist_threshold will not change at each mini-batch)
if (
max_dist_threshold is None
): # dist_threshold is fixed (i.e. dist_threshold will not change at each mini-batch)
max_dist_threshold = dist_threshold

# get size of blurring kernels
windowsize = np.array([max_dist_threshold * 2 + 1]*n_dims, dtype='int32')
windowsize = np.array([max_dist_threshold * 2 + 1] * n_dims, dtype="int32")

# build tensor representing the distance from the centre
mesh = [tf.cast(f, 'float32') for f in volshape_to_meshgrid(windowsize, indexing='ij')]
dist = tf.stack([mesh[f] - (windowsize[f] - 1) / 2 for f in range(len(windowsize))], axis=-1)
mesh = [
tf.cast(f, "float32") for f in volshape_to_meshgrid(windowsize, indexing="ij")
]
dist = tf.stack(
[mesh[f] - (windowsize[f] - 1) / 2 for f in range(len(windowsize))], axis=-1
)
dist = tf.sqrt(tf.reduce_sum(tf.square(dist), axis=-1))

# replicate distance to batch size and reshape sigma_tens
if batchsize is not None:
dist = tf.tile(tf.expand_dims(dist, axis=0),
tf.concat([batchsize, tf.ones(tf.shape(tf.shape(dist)), dtype='int32')], axis=0))
dist = tf.tile(
tf.expand_dims(dist, axis=0),
tf.concat(
[batchsize, tf.ones(tf.shape(tf.shape(dist)), dtype="int32")], axis=0
),
)
for i in range(n_dims - 1):
dist_threshold_tens = tf.expand_dims(dist_threshold_tens, axis=1)
else:
for i in range(n_dims - 1):
dist_threshold_tens = tf.expand_dims(dist_threshold_tens, axis=0)

# build final kernel by thresholding distance tensor
kernel = tf.where(tf.less_equal(dist, dist_threshold_tens), tf.ones_like(dist), tf.zeros_like(dist))
kernel = tf.where(
tf.less_equal(dist, dist_threshold_tens),
tf.ones_like(dist),
tf.zeros_like(dist),
)
kernel = tf.expand_dims(tf.expand_dims(kernel, -1), -1)

return kernel


def resample_tensor(tensor,
resample_shape,
interp_method='linear',
subsample_res=None,
volume_res=None,
build_reliability_map=False):
def resample_tensor(
tensor,
resample_shape,
interp_method="linear",
subsample_res=None,
volume_res=None,
build_reliability_map=False,
):
"""This function resamples a volume to resample_shape. It does not apply any pre-filtering.
A prior downsampling step can be added if subsample_res is specified. In this case, volume_res should also be
specified, in order to calculate the downsampling ratio. A reliability map can also be returned to indicate which
Expand All @@ -286,22 +357,35 @@ def resample_tensor(tensor,
downsample_shape = tensor_shape # will be modified if we actually downsample

if subsample_res is not None:
assert volume_res is not None, 'volume_res must be given when providing a subsampling resolution.'
assert len(subsample_res) == len(volume_res), 'subsample_res and volume_res must have the same length, ' \
'had {0}, and {1}'.format(len(subsample_res), len(volume_res))
assert (
volume_res is not None
), "volume_res must be given when providing a subsampling resolution."
assert len(subsample_res) == len(volume_res), (
"subsample_res and volume_res must have the same length, "
"had {0}, and {1}".format(len(subsample_res), len(volume_res))
)
if subsample_res != volume_res:

# get shape at which we downsample
downsample_shape = [int(tensor_shape[i] * volume_res[i] / subsample_res[i]) for i in range(n_dims)]
downsample_shape = [
int(tensor_shape[i] * volume_res[i] / subsample_res[i])
for i in range(n_dims)
]

# downsample volume
tensor._keras_shape = tuple(tensor.get_shape().as_list())
tensor = nrn_layers.Resize(size=downsample_shape, interp_method='nearest')(tensor)
tensor = nrn_layers.Resize(size=downsample_shape, interp_method="nearest")(
tensor
)

# resample image at target resolution
if resample_shape != downsample_shape: # if we didn't downsample downsample_shape = tensor_shape
if (
resample_shape != downsample_shape
): # if we didn't downsample downsample_shape = tensor_shape
tensor._keras_shape = tuple(tensor.get_shape().as_list())
tensor = nrn_layers.Resize(size=resample_shape, interp_method=interp_method)(tensor)
tensor = nrn_layers.Resize(size=resample_shape, interp_method=interp_method)(
tensor
)

# compute reliability maps if necessary and return results
if build_reliability_map:
Expand All @@ -320,13 +404,20 @@ def resample_tensor(tensor,
loc_ceil = np.int32(np.clip(loc_floor + 1, 0, resample_shape[i] - 1))
tmp_reliability_map = np.zeros(resample_shape[i])
tmp_reliability_map[loc_floor] = 1 - (loc_float - loc_floor)
tmp_reliability_map[loc_ceil] = tmp_reliability_map[loc_ceil] + (loc_float - loc_floor)
tmp_reliability_map[loc_ceil] = tmp_reliability_map[loc_ceil] + (
loc_float - loc_floor
)
shape = [1, 1, 1]
shape[i] = resample_shape[i]
reliability_map = reliability_map * np.reshape(tmp_reliability_map, shape)
reliability_map = reliability_map * np.reshape(
tmp_reliability_map, shape
)
shape = KL.Lambda(lambda x: tf.shape(x))(tensor)
mask = KL.Lambda(lambda x: tf.reshape(tf.convert_to_tensor(reliability_map, dtype='float32'),
shape=x))(shape)
mask = KL.Lambda(
lambda x: tf.reshape(
tf.convert_to_tensor(reliability_map, dtype="float32"), shape=x
)
)(shape)

# otherwise just return an all-one tensor
else:
Expand Down
Loading

0 comments on commit 006e946

Please sign in to comment.