synthseg in nobrainer
hvgazula committed Jun 11, 2024
1 parent f0396cb commit a2b857a
Showing 9 changed files with 7,185 additions and 0 deletions.
6 changes: 6 additions & 0 deletions nobrainer/ext/lab2im/__init__.py
@@ -0,0 +1,6 @@
from . import edit_tensors
from . import edit_volumes
from . import image_generator
from . import lab2im_model
from . import layers
from . import utils
346 changes: 346 additions & 0 deletions nobrainer/ext/lab2im/edit_tensors.py
@@ -0,0 +1,346 @@
"""
This file contains functions to handle keras/tensorflow tensors.
- blurring_sigma_for_downsampling
- gaussian_kernel
- sobel_kernels
- unit_kernel
- resample_tensor
- expand_dims
If you use this code, please cite the first SynthSeg paper:
https://github.com/BBillot/lab2im/blob/master/bibtex.bib
Copyright 2020 Benjamin Billot
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
compliance with the License. You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is
distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied. See the License for the specific language governing permissions and limitations under the
License.
"""


# python imports
import numpy as np
import tensorflow as tf
import keras.layers as KL
import keras.backend as K
from itertools import combinations

# project imports
from ext.lab2im import utils

# third-party imports
import ext.neuron.layers as nrn_layers
from ext.neuron.utils import volshape_to_meshgrid


def blurring_sigma_for_downsampling(current_res, downsample_res, mult_coef=None, thickness=None):
    """Compute standard deviations of 1d gaussian masks for image blurring before downsampling.
    :param downsample_res: resolution to downsample to. Can be a 1d numpy array or list, or a tensor.
    :param current_res: resolution of the volume before downsampling.
    Can be a 1d numpy array or list or tensor of the same length as downsample_res.
    :param mult_coef: (optional) multiplicative coefficient for the blurring kernel. Default is 0.75.
    :param thickness: (optional) slice thickness in each dimension. Must be the same type as downsample_res.
    :return: standard deviation of the blurring masks given as the same type as downsample_res (list or tensor).
    """

    if not tf.is_tensor(downsample_res):

        # get blurring resolution (min between downsample_res and thickness)
        current_res = np.array(current_res)
        downsample_res = np.array(downsample_res)
        if thickness is not None:
            downsample_res = np.minimum(downsample_res, np.array(thickness))

        # get std deviation for blurring kernels
        if mult_coef is None:
            sigma = 0.75 * downsample_res / current_res
            sigma[downsample_res == current_res] = 0.5
        else:
            sigma = mult_coef * downsample_res / current_res
        sigma[downsample_res == 0] = 0

    else:

        # reformat data resolution at which we blur
        if thickness is not None:
            down_res = KL.Lambda(lambda x: tf.math.minimum(x[0], x[1]))([downsample_res, thickness])
        else:
            down_res = downsample_res

        # get std deviation for blurring kernels
        if mult_coef is None:
            sigma = KL.Lambda(lambda x: tf.where(tf.math.equal(x, tf.convert_to_tensor(current_res, dtype='float32')),
                                                 0.5, 0.75 * x / tf.convert_to_tensor(current_res, dtype='float32')))(down_res)
        else:
            sigma = KL.Lambda(lambda x: mult_coef * x / tf.convert_to_tensor(current_res, dtype='float32'))(down_res)
        sigma = KL.Lambda(lambda x: tf.where(tf.math.equal(x[0], 0.), 0., x[1]))([down_res, sigma])

    return sigma
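

# A minimal usage sketch for blurring_sigma_for_downsampling (illustrative values, not part of
# the original lab2im source). With numpy/list inputs the function runs eagerly: axes already at
# the target resolution get a sigma of 0.5, while the axis downsampled from 1mm to 3mm gets
# 0.75 * 3 / 1 = 2.25.
example_sigma = blurring_sigma_for_downsampling(current_res=[1., 1., 1.], downsample_res=[1., 1., 3.])
# example_sigma -> array([0.5, 0.5, 2.25])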


def gaussian_kernel(sigma, max_sigma=None, blur_range=None, separable=True):
    """Build gaussian kernels of the specified standard deviation. The outputs are given as tensorflow tensors.
    :param sigma: standard deviation of the tensors. Can be given as a list/numpy array or as tensors. In each case,
    sigma must have the same length as the number of dimensions of the volume that will be blurred with the output
    tensors (e.g. sigma must have 3 values for 3D volumes).
    :param max_sigma: (optional) maximum possible value for sigma. Mandatory when sigma is given as a tensor, as it
    fixes the size of the returned kernels.
    :param blur_range: (optional) if not None, sigma is multiplied by a random factor uniformly sampled in
    [1/blur_range, blur_range].
    :param separable: (optional) whether to return n_dims separable 1d kernels (default), or one n_dims kernel.
    :return: a list of 1d gaussian kernels (with None for dimensions where the window size is 1) if separable is True,
    else a single gaussian kernel with n_dims dimensions.
    """
    # convert sigma into a tensor
    if not tf.is_tensor(sigma):
        sigma_tens = tf.convert_to_tensor(utils.reformat_to_list(sigma), dtype='float32')
    else:
        assert max_sigma is not None, 'max_sigma must be provided when sigma is given as a tensor'
        sigma_tens = sigma
    shape = sigma_tens.get_shape().as_list()

    # get n_dims and batchsize
    if shape[0] is not None:
        n_dims = shape[0]
        batchsize = None
    else:
        n_dims = shape[1]
        batchsize = tf.split(tf.shape(sigma_tens), [1, -1])[0]

    # reformat max_sigma
    if max_sigma is not None:  # dynamic blurring
        max_sigma = np.array(utils.reformat_to_list(max_sigma, length=n_dims))
    else:  # sigma is fixed
        max_sigma = np.array(utils.reformat_to_list(sigma, length=n_dims))

    # randomise the blurring std dev and/or split it between dimensions
    if blur_range is not None:
        if blur_range != 1:
            sigma_tens = sigma_tens * tf.random.uniform(tf.shape(sigma_tens), minval=1 / blur_range, maxval=blur_range)

    # get size of blurring kernels
    windowsize = np.int32(np.ceil(2.5 * max_sigma) / 2) * 2 + 1

    if separable:

        split_sigma = tf.split(sigma_tens, [1] * n_dims, axis=-1)

        kernels = list()
        comb = np.array(list(combinations(list(range(n_dims)), n_dims - 1))[::-1])
        for (i, wsize) in enumerate(windowsize):

            if wsize > 1:

                # build meshgrid and replicate it along batch dim if dynamic blurring
                locations = tf.cast(tf.range(0, wsize), 'float32') - (wsize - 1) / 2
                if batchsize is not None:
                    locations = tf.tile(tf.expand_dims(locations, axis=0),
                                        tf.concat([batchsize, tf.ones(tf.shape(tf.shape(locations)), dtype='int32')],
                                                  axis=0))
                    comb[i] += 1

                # compute gaussians
                exp_term = -K.square(locations) / (2 * split_sigma[i] ** 2)
                g = tf.exp(exp_term - tf.math.log(np.sqrt(2 * np.pi) * split_sigma[i]))
                g = g / tf.reduce_sum(g)

                for axis in comb[i]:
                    g = tf.expand_dims(g, axis=axis)
                kernels.append(tf.expand_dims(tf.expand_dims(g, -1), -1))

            else:
                kernels.append(None)

    else:

        # build meshgrid
        mesh = [tf.cast(f, 'float32') for f in volshape_to_meshgrid(windowsize, indexing='ij')]
        diff = tf.stack([mesh[f] - (windowsize[f] - 1) / 2 for f in range(len(windowsize))], axis=-1)

        # replicate meshgrid to batch size and reshape sigma_tens
        if batchsize is not None:
            diff = tf.tile(tf.expand_dims(diff, axis=0),
                           tf.concat([batchsize, tf.ones(tf.shape(tf.shape(diff)), dtype='int32')], axis=0))
            for i in range(n_dims):
                sigma_tens = tf.expand_dims(sigma_tens, axis=1)
        else:
            for i in range(n_dims):
                sigma_tens = tf.expand_dims(sigma_tens, axis=0)

        # compute gaussians
        sigma_is_0 = tf.equal(sigma_tens, 0)
        exp_term = -K.square(diff) / (2 * tf.where(sigma_is_0, tf.ones_like(sigma_tens), sigma_tens) ** 2)
        norms = exp_term - tf.math.log(tf.where(sigma_is_0, tf.ones_like(sigma_tens), np.sqrt(2 * np.pi) * sigma_tens))
        kernels = K.sum(norms, -1)
        kernels = tf.exp(kernels)
        kernels /= tf.reduce_sum(kernels)
        kernels = tf.expand_dims(tf.expand_dims(kernels, -1), -1)

    return kernels
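

# A minimal usage sketch for gaussian_kernel (illustrative shapes, not part of the original
# source). With a fixed sigma, the separable 1d kernels can be applied one axis at a time:
example_kernels = gaussian_kernel([1., 1., 1.])         # three 1d kernels, each of width 3
example_volume = tf.random.normal([1, 32, 32, 32, 1])   # batch of one single-channel 3d volume
for example_k in example_kernels:
    if example_k is not None:                           # None marks width-1 (no-op) dimensions
        example_volume = tf.nn.convolution(example_volume, example_k, padding='SAME')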


def sobel_kernels(n_dims):
    """Returns Sobel kernels to compute spatial derivatives on an image of n dimensions."""

    in_dir = tf.convert_to_tensor([1, 0, -1], dtype='float32')
    orthogonal_dir = tf.convert_to_tensor([1, 2, 1], dtype='float32')
    comb = np.array(list(combinations(list(range(n_dims)), n_dims - 1))[::-1])

    list_kernels = list()
    for dim in range(n_dims):

        sublist_kernels = list()
        for axis in range(n_dims):

            kernel = in_dir if axis == dim else orthogonal_dir
            for i in comb[axis]:
                kernel = tf.expand_dims(kernel, axis=i)
            sublist_kernels.append(tf.expand_dims(tf.expand_dims(kernel, -1), -1))

        list_kernels.append(sublist_kernels)

    return list_kernels
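

# A minimal usage sketch for sobel_kernels (illustrative, not part of the original source).
# list_kernels[d] holds the separable 1d factors for the derivative along dimension d:
# the [1, 0, -1] filter on axis d and the [1, 2, 1] smoothing filter on the other axes.
example_sobel = sobel_kernels(3)
example_image = tf.random.normal([1, 32, 32, 32, 1])
example_grad0 = example_image
for example_k in example_sobel[0]:                      # derivative along the first spatial axis
    example_grad0 = tf.nn.convolution(example_grad0, example_k, padding='SAME')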


def unit_kernel(dist_threshold, n_dims, max_dist_threshold=None):
    """Build kernel with values of 1 for voxels at a distance < dist_threshold from the centre, and 0 otherwise.
    The outputs are given as tensorflow tensors.
    :param dist_threshold: maximum distance from the centre up to which voxels will have a value of 1. Can be a tensor
    of size (batch_size, 1), or a float.
    :param n_dims: dimension of the kernel to return (excluding batch and channel dimensions).
    :param max_dist_threshold: if dist_threshold is a tensor, max_dist_threshold must be given. It represents the
    maximum value that will be passed to dist_threshold. Must be a float.
    """

    # convert dist_threshold into a tensor
    if not tf.is_tensor(dist_threshold):
        dist_threshold_tens = tf.convert_to_tensor(utils.reformat_to_list(dist_threshold), dtype='float32')
    else:
        assert max_dist_threshold is not None, 'max_dist_threshold must be provided when dist_threshold is a tensor'
        dist_threshold_tens = tf.cast(dist_threshold, 'float32')
    shape = dist_threshold_tens.get_shape().as_list()

    # get batchsize
    batchsize = None if shape[0] is not None else tf.split(tf.shape(dist_threshold_tens), [1, -1])[0]

    # set max_dist_threshold into an array
    if max_dist_threshold is None:  # dist_threshold is fixed (i.e. dist_threshold will not change at each mini-batch)
        max_dist_threshold = dist_threshold

    # get size of blurring kernels
    windowsize = np.array([max_dist_threshold * 2 + 1] * n_dims, dtype='int32')

    # build tensor representing the distance from the centre
    mesh = [tf.cast(f, 'float32') for f in volshape_to_meshgrid(windowsize, indexing='ij')]
    dist = tf.stack([mesh[f] - (windowsize[f] - 1) / 2 for f in range(len(windowsize))], axis=-1)
    dist = tf.sqrt(tf.reduce_sum(tf.square(dist), axis=-1))

    # replicate distance to batch size and reshape dist_threshold_tens
    if batchsize is not None:
        dist = tf.tile(tf.expand_dims(dist, axis=0),
                       tf.concat([batchsize, tf.ones(tf.shape(tf.shape(dist)), dtype='int32')], axis=0))
        for i in range(n_dims - 1):
            dist_threshold_tens = tf.expand_dims(dist_threshold_tens, axis=1)
    else:
        for i in range(n_dims - 1):
            dist_threshold_tens = tf.expand_dims(dist_threshold_tens, axis=0)

    # build final kernel by thresholding distance tensor
    kernel = tf.where(tf.less_equal(dist, dist_threshold_tens), tf.ones_like(dist), tf.zeros_like(dist))
    kernel = tf.expand_dims(tf.expand_dims(kernel, -1), -1)

    return kernel
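

# A minimal usage sketch for unit_kernel (illustrative, not part of the original source):
# a binary ball of radius 2, usable e.g. to count neighbours within that distance.
example_ball = unit_kernel(dist_threshold=2., n_dims=3)  # shape (5, 5, 5, 1, 1)
example_mask = tf.cast(tf.random.uniform([1, 32, 32, 32, 1]) > 0.5, 'float32')
example_counts = tf.nn.convolution(example_mask, example_ball, padding='SAME')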


def resample_tensor(tensor,
                    resample_shape,
                    interp_method='linear',
                    subsample_res=None,
                    volume_res=None,
                    build_reliability_map=False):
    """This function resamples a volume to resample_shape. It does not apply any pre-filtering.
    A prior downsampling step can be added if subsample_res is specified. In this case, volume_res should also be
    specified, in order to calculate the downsampling ratio. A reliability map can also be returned to indicate which
    slices were interpolated during resampling from the downsampled to final tensor.
    :param tensor: tensor
    :param resample_shape: list or numpy array of size (n_dims,)
    :param interp_method: (optional) interpolation method for resampling, 'linear' (default) or 'nearest'
    :param subsample_res: (optional) if not None, this triggers a downsampling of the volume, prior to the resampling
    step. List or numpy array of size (n_dims,). Default is None.
    :param volume_res: (optional) if subsample_res is not None, this should be provided to compute the downsampling
    ratio. List or numpy array of size (n_dims,). Default is None.
    :param build_reliability_map: whether to return a reliability map along with the resampled tensor. This map
    indicates which slices of the resampled tensor are interpolated (0=interpolated, 1=real slice,
    in between=degree of realness).
    :return: resampled volume, with reliability map if necessary.
    """

    # reformat resolutions to lists
    subsample_res = utils.reformat_to_list(subsample_res)
    volume_res = utils.reformat_to_list(volume_res)
    n_dims = len(resample_shape)

    # downsample image
    tensor_shape = tensor.get_shape().as_list()[1:-1]
    downsample_shape = tensor_shape  # will be modified if we actually downsample

    if subsample_res is not None:
        assert volume_res is not None, 'volume_res must be given when providing a subsampling resolution.'
        assert len(subsample_res) == len(volume_res), 'subsample_res and volume_res must have the same length, ' \
                                                      'had {0}, and {1}'.format(len(subsample_res), len(volume_res))
        if subsample_res != volume_res:

            # get shape at which we downsample
            downsample_shape = [int(tensor_shape[i] * volume_res[i] / subsample_res[i]) for i in range(n_dims)]

            # downsample volume
            tensor._keras_shape = tuple(tensor.get_shape().as_list())
            tensor = nrn_layers.Resize(size=downsample_shape, interp_method='nearest')(tensor)

    # resample image at target resolution
    if resample_shape != downsample_shape:  # if we didn't downsample, downsample_shape = tensor_shape
        tensor._keras_shape = tuple(tensor.get_shape().as_list())
        tensor = nrn_layers.Resize(size=resample_shape, interp_method=interp_method)(tensor)

    # compute reliability maps if necessary and return results
    if build_reliability_map:

        # compute maps only if we downsampled
        if downsample_shape != tensor_shape:

            # compute upsampling factors
            upsampling_factors = np.array(resample_shape) / np.array(downsample_shape)

            # build reliability map
            reliability_map = 1
            for i in range(n_dims):
                loc_float = np.arange(0, resample_shape[i], upsampling_factors[i])
                loc_floor = np.int32(np.floor(loc_float))
                loc_ceil = np.int32(np.clip(loc_floor + 1, 0, resample_shape[i] - 1))
                tmp_reliability_map = np.zeros(resample_shape[i])
                tmp_reliability_map[loc_floor] = 1 - (loc_float - loc_floor)
                tmp_reliability_map[loc_ceil] = tmp_reliability_map[loc_ceil] + (loc_float - loc_floor)
                shape = [1, 1, 1]
                shape[i] = resample_shape[i]
                reliability_map = reliability_map * np.reshape(tmp_reliability_map, shape)
            shape = KL.Lambda(lambda x: tf.shape(x))(tensor)
            mask = KL.Lambda(lambda x: tf.reshape(tf.convert_to_tensor(reliability_map, dtype='float32'),
                                                  shape=x))(shape)

        # otherwise just return an all-one tensor
        else:
            mask = KL.Lambda(lambda x: tf.ones_like(x))(tensor)

        return tensor, mask

    else:
        return tensor
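

# A minimal sketch of the resampling path (illustrative shapes; like the function itself, it
# assumes the legacy keras graph-building setup used throughout this file). The call below
# simulates 3mm-thick slices acquired on a 1mm isotropic grid, resamples back to the original
# shape, and returns a map flagging which output slices are real versus interpolated.
example_input = KL.Input(shape=(160, 160, 160, 1))
example_resampled, example_reliability = resample_tensor(example_input,
                                                         resample_shape=[160, 160, 160],
                                                         subsample_res=[1., 1., 3.],
                                                         volume_res=[1., 1., 1.],
                                                         build_reliability_map=True)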


def expand_dims(tensor, axis=0):
    """Expand the dimensions of the input tensor along the provided axes (given as an integer or a list)."""
    axis = utils.reformat_to_list(axis)
    for ax in axis:
        tensor = tf.expand_dims(tensor, axis=ax)
    return tensor
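

# A minimal usage sketch for expand_dims (illustrative): add batch and channel dimensions
# in one call.
example_expanded = expand_dims(tf.ones([160, 160, 160]), axis=[0, -1])  # shape (1, 160, 160, 160, 1)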