diff --git a/nobrainer/dataset.py b/nobrainer/dataset.py
index 791737fd..30c5be53 100644
--- a/nobrainer/dataset.py
+++ b/nobrainer/dataset.py
@@ -121,6 +121,10 @@ def from_tfrecords(
     )
     block_length = len([0 for _ in first_shard])
+    if not n_volumes:
+        # Infer the total, assuming every shard holds as many records as the first.
+        n_volumes = block_length * len(files)
+
     dataset = dataset.interleave(
         map_func=lambda x: tf.data.TFRecordDataset(
             x, compression_type=compression_type
diff --git a/nobrainer/distrubuted_learning/dwc.py b/nobrainer/distributed_learning/dwc.py
similarity index 100%
rename from nobrainer/distrubuted_learning/dwc.py
rename to nobrainer/distributed_learning/dwc.py
diff --git a/nobrainer/models/__init__.py b/nobrainer/models/__init__.py
index 8bc4d125..a4842bc7 100644
--- a/nobrainer/models/__init__.py
+++ b/nobrainer/models/__init__.py
@@ -1,10 +1,32 @@
+from pprint import pprint
+
+from .attention_unet import attention_unet
+from .attention_unet_with_inception import attention_unet_with_inception
 from .autoencoder import autoencoder
+from .bayesian_meshnet import variational_meshnet
 from .dcgan import dcgan
 from .highresnet import highresnet
 from .meshnet import meshnet
 from .progressiveae import progressiveae
 from .progressivegan import progressivegan
 from .unet import unet
+from .unetr import unetr
+
+__all__ = ["get", "available_models", "list_available_models"]
+
+_models = {
+    "highresnet": highresnet,
+    "meshnet": meshnet,
+    "unet": unet,
+    "autoencoder": autoencoder,
+    "progressivegan": progressivegan,
+    "progressiveae": progressiveae,
+    "dcgan": dcgan,
+    "attention_unet": attention_unet,
+    "attention_unet_with_inception": attention_unet_with_inception,
+    "unetr": unetr,
+    "variational_meshnet": variational_meshnet,
+}
 
 
 def get(name):
@@ -21,20 +43,20 @@ def get(name):
     if not isinstance(name, str):
         raise ValueError("Model name must be a string.")
 
-    models = {
-        "highresnet": highresnet,
-        "meshnet": meshnet,
-        "unet": unet,
-        "autoencoder": autoencoder,
-        "progressivegan": progressivegan,
-        "progressiveae": progressiveae,
-        "dcgan": dcgan,
-    }
-
     try:
-        return models[name.lower()]
+        return _models[name.lower()]
     except KeyError:
-        avail = ", ".join(models.keys())
+        avail = ", ".join(_models.keys())
         raise ValueError(
             "Unknown model: '{}'. Available models are {}.".format(name, avail)
         )
+
+
+def available_models():
+    """Return the names of all registered models."""
+    return list(_models)
+
+
+def list_available_models():
+    """Pretty-print the names of all registered models."""
+    pprint(available_models())
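For reference, a minimal sketch of the registry surface added above (the import path comes from this diff; the printed names are illustrative):

    from nobrainer.models import available_models, get, list_available_models

    model_fn = get("unet")      # look up a model-constructor function by name
    names = available_models()  # e.g. ["highresnet", "meshnet", "unet", ...]
    list_available_models()     # pretty-print the same list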
diff --git a/nobrainer/models/attention_unet.py b/nobrainer/models/attention_unet.py
new file mode 100644
index 00000000..7b155daf
--- /dev/null
+++ b/nobrainer/models/attention_unet.py
@@ -0,0 +1,80 @@
+"""Model definition for Attention U-Net.
+
+Adapted from https://github.com/nikhilroxtomar/Semantic-Segmentation-Architecture/blob/main/TensorFlow/attention-unet.py
+"""  # noqa: E501
+
+import tensorflow.keras.layers as L
+from tensorflow.keras.models import Model
+
+
+def conv_block(x, num_filters):
+    x = L.Conv3D(num_filters, 3, padding="same")(x)
+    x = L.BatchNormalization()(x)
+    x = L.Activation("relu")(x)
+
+    x = L.Conv3D(num_filters, 3, padding="same")(x)
+    x = L.BatchNormalization()(x)
+    x = L.Activation("relu")(x)
+
+    return x
+
+
+def encoder_block(x, num_filters):
+    x = conv_block(x, num_filters)
+    p = L.MaxPool3D()(x)
+    return x, p
+
+
+def attention_gate(g, s, num_filters):
+    Wg = L.Conv3D(num_filters, 1, padding="same")(g)
+    Wg = L.BatchNormalization()(Wg)
+
+    Ws = L.Conv3D(num_filters, 1, padding="same")(s)
+    Ws = L.BatchNormalization()(Ws)
+
+    out = L.Activation("relu")(Wg + Ws)
+    out = L.Conv3D(num_filters, 1, padding="same")(out)
+    out = L.Activation("sigmoid")(out)
+
+    return out * s
+
+
+def decoder_block(x, s, num_filters):
+    x = L.UpSampling3D()(x)
+    s = attention_gate(x, s, num_filters)
+    x = L.Concatenate()([x, s])
+    x = conv_block(x, num_filters)
+    return x
+
+
+def attention_unet(n_classes, input_shape):
+    # Inputs
+    inputs = L.Input(input_shape)
+
+    # Encoder
+    s1, p1 = encoder_block(inputs, 64)
+    s2, p2 = encoder_block(p1, 128)
+    s3, p3 = encoder_block(p2, 256)
+
+    b1 = conv_block(p3, 512)
+
+    # Decoder
+    d1 = decoder_block(b1, s3, 256)
+    d2 = decoder_block(d1, s2, 128)
+    d3 = decoder_block(d2, s1, 64)
+
+    # Outputs
+    outputs = L.Conv3D(n_classes, 1, padding="same")(d3)
+
+    final_activation = "sigmoid" if n_classes == 1 else "softmax"
+    outputs = L.Activation(final_activation)(outputs)
+
+    # Model
+    return Model(inputs=inputs, outputs=outputs, name="Attention_U-Net")
+
+
+if __name__ == "__main__":
+    n_classes = 50
+    input_shape = (256, 256, 256, 3)
+    model = attention_unet(n_classes, input_shape)
+    model.summary()
diff --git a/nobrainer/models/attention_unet_with_inception.py b/nobrainer/models/attention_unet_with_inception.py
new file mode 100644
index 00000000..81809317
--- /dev/null
+++ b/nobrainer/models/attention_unet_with_inception.py
@@ -0,0 +1,332 @@
+"""Attention U-Net with inception layers.
+
+Adapted from https://github.com/robinvvinod/unet
+"""
+
+from tensorflow.keras import layers
+import tensorflow.keras.backend as K
+from tensorflow.keras.models import Model
+
+K.set_image_data_format("channels_last")
+
+
+def expend_as(tensor, rep):
+    # Repeat the tensor `rep` times along the channel axis (axis=4), so a
+    # tensor of shape (..., N) becomes (..., N * rep). Wrapped in a Lambda
+    # layer so the repeat is tracked by Keras.
+
+    my_repeat = layers.Lambda(
+        lambda x, repnum: K.repeat_elements(x, repnum, axis=4),
+        arguments={"repnum": rep},
+    )(tensor)
+    return my_repeat
+
+
+def conv3d_block(
+    input_tensor,
+    n_filters,
+    kernel_size=3,
+    batchnorm=True,
+    strides=1,
+    dilation_rate=1,
+    recurrent=1,
+):
+    # A wrapper around Keras Conv3D to serve as a building block for
+    # downsampling layers. Includes options for batch normalization,
+    # dilation, and recurrence.
+
+    conv = layers.Conv3D(
+        filters=n_filters,
+        kernel_size=kernel_size,
+        strides=strides,
+        kernel_initializer="he_normal",
+        padding="same",
+        dilation_rate=dilation_rate,
+    )(input_tensor)
+    if batchnorm:
+        conv = layers.BatchNormalization()(conv)
+    output = layers.LeakyReLU(alpha=0.1)(conv)
+
+    for _ in range(recurrent - 1):
+        conv = layers.Conv3D(
+            filters=n_filters,
+            kernel_size=kernel_size,
+            strides=1,
+            kernel_initializer="he_normal",
+            padding="same",
+            dilation_rate=dilation_rate,
+        )(output)
+        if batchnorm:
+            conv = layers.BatchNormalization()(conv)
+        res = layers.LeakyReLU(alpha=0.1)(conv)
+        output = layers.Add()([output, res])
+
+    return output
+
+
+def AttnGatingBlock(x, g, inter_shape):
+    shape_x = K.int_shape(x)
+    shape_g = K.int_shape(g)
+
+    # Getting the gating signal to the same number of filters as the inter_shape
+    phi_g = layers.Conv3D(
+        filters=inter_shape, kernel_size=1, strides=1, padding="same"
+    )(g)
+
+    # Getting the x signal to the same shape as the gating signal
+    theta_x = layers.Conv3D(
+        filters=inter_shape,
+        kernel_size=3,
+        strides=(
+            shape_x[1] // shape_g[1],
+            shape_x[2] // shape_g[2],
+            shape_x[3] // shape_g[3],
+        ),
+        padding="same",
+    )(x)
+
+    # Element-wise addition of the gating and x signals
+    add_xg = layers.add([phi_g, theta_x])
+    add_xg = layers.Activation("relu")(add_xg)
+
+    # 1x1x1 convolution
+    psi = layers.Conv3D(filters=1, kernel_size=1, padding="same")(add_xg)
+    psi = layers.Activation("sigmoid")(psi)
+    shape_sigmoid = K.int_shape(psi)
+
+    # Upsampling psi back to the original dimensions of x signal
+    upsample_sigmoid_xg = layers.UpSampling3D(
+        size=(
+            shape_x[1] // shape_sigmoid[1],
+            shape_x[2] // shape_sigmoid[2],
+            shape_x[3] // shape_sigmoid[3],
+        )
+    )(psi)
+
+    # Expanding the filter axis to the number of filters in the original x signal
+    upsample_sigmoid_xg = expend_as(upsample_sigmoid_xg, shape_x[4])
+
+    # Element-wise multiplication of attention coefficients back onto original x signal
+    attn_coefficients = layers.multiply([upsample_sigmoid_xg, x])
+
+    # Final 1x1x1 convolution to consolidate attention signal to original x dimensions
+    output = layers.Conv3D(
+        filters=shape_x[4], kernel_size=1, strides=1, padding="same"
+    )(attn_coefficients)
+    output = layers.BatchNormalization()(output)
+    return output
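+
+
+# Shape sketch for AttnGatingBlock (illustrative values, not from the original
+# source): with n_filters=16 and a 512**3 input to the network below, x = c3
+# is (B, 64, 64, 64, 128) and g = b0 is (B, 32, 32, 32, 256). theta_x strides
+# x down to (B, 32, 32, 32, inter_shape); psi collapses the fused signal to a
+# single sigmoid channel; UpSampling3D restores (B, 64, 64, 64, 1); and
+# expend_as broadcasts the map to 128 channels before gating x element-wise.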
+
+
+def transpose_block(
+    input_tensor,
+    skip_tensor,
+    n_filters,
+    kernel_size=3,
+    strides=1,
+    batchnorm=True,
+    recurrent=1,
+):
+    # A wrapper of the Keras Conv3DTranspose block to serve as a building block
+    # for upsampling layers
+
+    shape_x = K.int_shape(input_tensor)
+    shape_xskip = K.int_shape(skip_tensor)
+
+    conv = layers.Conv3DTranspose(
+        filters=n_filters,
+        kernel_size=kernel_size,
+        padding="same",
+        strides=(
+            shape_xskip[1] // shape_x[1],
+            shape_xskip[2] // shape_x[2],
+            shape_xskip[3] // shape_x[3],
+        ),
+        kernel_initializer="he_normal",
+    )(input_tensor)
+    conv = layers.LeakyReLU(alpha=0.1)(conv)
+
+    act = conv3d_block(
+        conv,
+        n_filters=n_filters,
+        kernel_size=kernel_size,
+        strides=1,
+        batchnorm=batchnorm,
+        dilation_rate=1,
+        recurrent=recurrent,
+    )
+    output = layers.Concatenate(axis=4)([act, skip_tensor])
+    return output
+
+
+# The building blocks above are combined into the full network below.
+def inception_block(
+    input_tensor,
+    n_filters,
+    kernel_size=3,
+    strides=1,
+    batchnorm=True,
+    recurrent=1,
+    layers_list=[],
+):
+    # Inception-style convolutional block similar to InceptionNet.
+    # The first convolution follows the function arguments; subsequent
+    # inception convolutions follow the (kernel_size, dilation_rate) pairs
+    # given in the `layers_list` argument.
+
+    # `layers_list` is a nested list: each inner list is one secondary branch,
+    # given as a sequence of (kernel_size, dilation_rate) tuples.
+
+    # E.g. layers_list=[[(3, 1), (3, 1)], [(5, 1)], [(3, 1), (3, 2)]]
+    # implements three secondary branches:
+    # Branch 1 => 3x3x3 dilation 1 followed by another 3x3x3 dilation 1
+    # Branch 2 => 5x5x5 dilation 1
+    # Branch 3 => 3x3x3 dilation 1 followed by 3x3x3 dilation 2
+
+    res = conv3d_block(
+        input_tensor,
+        n_filters=n_filters,
+        kernel_size=kernel_size,
+        strides=strides,
+        batchnorm=batchnorm,
+        dilation_rate=1,
+        recurrent=recurrent,
+    )
+
+    temp = []
+    for layer in layers_list:
+        local_res = res
+        for conv in layer:
+            incep_kernel_size = conv[0]
+            incep_dilation_rate = conv[1]
+            local_res = conv3d_block(
+                local_res,
+                n_filters=n_filters,
+                kernel_size=incep_kernel_size,
+                strides=1,
+                batchnorm=batchnorm,
+                dilation_rate=incep_dilation_rate,
+                recurrent=recurrent,
+            )
+        temp.append(local_res)
+
+    temp = layers.concatenate(temp)
+    res = conv3d_block(
+        temp,
+        n_filters=n_filters,
+        kernel_size=1,
+        strides=1,
+        batchnorm=batchnorm,
+        dilation_rate=1,
+    )
+
+    shortcut = conv3d_block(
+        input_tensor,
+        n_filters=n_filters,
+        kernel_size=1,
+        strides=strides,
+        batchnorm=batchnorm,
+        dilation_rate=1,
+    )
+    if batchnorm:
+        shortcut = layers.BatchNormalization()(shortcut)
+
+    output = layers.Add()([shortcut, res])
+    return output
+
+
+def attention_unet_with_inception(
+    n_classes, input_shape, batch_size=None, n_filters=16, batchnorm=True
+):
+    # contracting path
+
+    inputs = layers.Input(shape=input_shape, batch_size=batch_size)
+
+    c0 = inception_block(
+        inputs,
+        n_filters=n_filters,
+        batchnorm=batchnorm,
+        strides=1,
+        recurrent=2,
+        layers_list=[[(3, 1), (3, 1)], [(3, 2)]],
+    )  # 512x512x512
+
+    c1 = inception_block(
+        c0,
+        n_filters=n_filters * 2,
+        batchnorm=batchnorm,
+        strides=2,
+        recurrent=2,
+        layers_list=[[(3, 1), (3, 1)], [(3, 2)]],
+    )  # 256x256x256
+
+    c2 = inception_block(
+        c1,
+        n_filters=n_filters * 4,
+        batchnorm=batchnorm,
+        strides=2,
+        recurrent=2,
+        layers_list=[[(3, 1), (3, 1)], [(3, 2)]],
+    )  # 128x128x128
+
+    c3 = inception_block(
+        c2,
+        n_filters=n_filters * 8,
+        batchnorm=batchnorm,
+        strides=2,
+        recurrent=2,
+        layers_list=[[(3, 1), (3, 1)], [(3, 2)]],
+    )  # 64x64x64
+
+    # bridge
+
+    b0 = inception_block(
+        c3,
+        n_filters=n_filters * 16,
+        batchnorm=batchnorm,
+        strides=2,
+        recurrent=2,
+        layers_list=[[(3, 1), (3, 1)], [(3, 2)]],
+    )  # 32x32x32
+
+    # expansive path
+
+    attn0 = AttnGatingBlock(c3, b0, n_filters * 16)
+    u0 = transpose_block(
+        b0, attn0, n_filters=n_filters * 8, batchnorm=batchnorm, recurrent=2
+    )  # 64x64x64
+
+    attn1 = AttnGatingBlock(c2, u0, n_filters * 8)
+    u1 = transpose_block(
+        u0, attn1, n_filters=n_filters * 4, batchnorm=batchnorm, recurrent=2
+    )  # 128x128x128
+
+    attn2 = AttnGatingBlock(c1, u1, n_filters * 4)
+    u2 = transpose_block(
+        u1, attn2, n_filters=n_filters * 2, batchnorm=batchnorm, recurrent=2
+    )  # 256x256x256
+
+    u3 = transpose_block(
+        u2, c0, n_filters=n_filters, batchnorm=batchnorm, recurrent=2
+    )  # 512x512x512
+
+    # One logit channel per class (sigmoid for n_classes == 1, softmax otherwise).
+    outputs = layers.Conv3D(filters=n_classes, kernel_size=1, strides=1)(u3)
+
+    final_activation = "sigmoid" if n_classes == 1 else "softmax"
+    outputs = layers.Activation(final_activation)(outputs)
+
+    model = Model(inputs=[inputs], outputs=[outputs])
+    return model
+
+
+if __name__ == "__main__":
+    model = attention_unet_with_inception(n_classes=1, input_shape=(256, 256, 256, 1))
+    model.summary()
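The encoder above applies four stride-2 downsamplings (c1, c2, c3, and the bridge), so input spatial dimensions should be divisible by 16. A minimal call sketch (values illustrative):

    model = attention_unet_with_inception(
        n_classes=3, input_shape=(64, 64, 64, 1), n_filters=8
    )
    model.compile(optimizer="adam", loss="categorical_crossentropy")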
diff --git a/nobrainer/models/tests/models_test.py b/nobrainer/models/tests/models_test.py
index 68f4b985..45cffbdc 100644
--- a/nobrainer/models/tests/models_test.py
+++ b/nobrainer/models/tests/models_test.py
@@ -1,10 +1,15 @@
+import os
+
 import numpy as np
 import pytest
 import tensorflow as tf
 
 from nobrainer.bayesian_utils import default_mean_field_normal_fn
 
+from ..attention_unet import attention_unet
+from ..attention_unet_with_inception import attention_unet_with_inception
 from ..autoencoder import autoencoder
+from ..bayesian_meshnet import variational_meshnet
 from ..bayesian_vnet import bayesian_vnet
 from ..bayesian_vnet_semi import bayesian_vnet_semi
 from ..brainsiam import brainsiam
@@ -14,9 +19,12 @@
 from ..progressivegan import progressivegan
 from ..unet import unet
 from ..unet_lstm import unet_lstm
+from ..unetr import unetr
 from ..vnet import vnet
 from ..vox2vox import Vox_ensembler, vox_gan
 
+IN_GITHUB_ACTIONS = os.getenv("GITHUB_ACTIONS") == "true"
+
 
 def model_test(model_cls, n_classes, input_shape, kwds={}):
     """Tests for models."""
@@ -241,3 +249,27 @@ def test_vox2vox():
     pred_shape = (1, 2, 2, 2, 1)
     out = vox_discriminator(inputs=[y, x])
     assert out.shape == pred_shape
+
+
+def test_attention_unet():
+    model_test(attention_unet, n_classes=1, input_shape=(1, 64, 64, 64, 1))
+
+
+def test_attention_unet_with_inception():
+    model_test(
+        attention_unet_with_inception, n_classes=1, input_shape=(1, 64, 64, 64, 1)
+    )
+
+
+@pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Cannot test in GitHub Actions")
+def test_unetr():
+    model_test(unetr, n_classes=1, input_shape=(1, 96, 96, 96, 1))
+
+
+def test_variational_meshnet():
+    model_test(
+        variational_meshnet,
+        n_classes=1,
+        input_shape=(1, 128, 128, 128, 1),
+        kwds={"filters": 4},
+    )
diff --git a/nobrainer/models/unetr.py b/nobrainer/models/unetr.py
new file mode 100644
index 00000000..30861617
--- /dev/null
+++ b/nobrainer/models/unetr.py
@@ -0,0 +1,379 @@
+"""UNETR implementation in TensorFlow 2.
+
+Adapted from https://www.kaggle.com/code/usharengaraju/tensorflow-unetr-w-b
+"""
+import math
+
+import tensorflow as tf
+
+
+class SingleDeconv3DBlock(tf.keras.layers.Layer):
+    def __init__(self, filters):
+        super(SingleDeconv3DBlock, self).__init__()
+        self.block = tf.keras.layers.Conv3DTranspose(
+            filters=filters,
+            kernel_size=2,
+            strides=2,
+            padding="valid",
+            output_padding=None,
+        )
+
+    def call(self, inputs):
+        return self.block(inputs)
+
+
+class SingleConv3DBlock(tf.keras.layers.Layer):
+    def __init__(self, filters, kernel_size):
+        super(SingleConv3DBlock, self).__init__()
+        self.block = tf.keras.layers.Conv3D(
+            filters=filters, kernel_size=kernel_size, strides=1, padding="same"
+        )
+
+    def call(self, inputs):
+        return self.block(inputs)
+
+
+class Conv3DBlock(tf.keras.layers.Layer):
+    def __init__(self, filters, kernel_size=(3, 3, 3)):
+        super(Conv3DBlock, self).__init__()
+        self.a = tf.keras.Sequential(
+            [
+                SingleConv3DBlock(filters, kernel_size=kernel_size),
+                tf.keras.layers.BatchNormalization(),
+                tf.keras.layers.Activation("relu"),
+            ]
+        )
+
+    def call(self, inputs):
+        return self.a(inputs)
+
+
+class Deconv3DBlock(tf.keras.layers.Layer):
+    def __init__(self, filters, kernel_size=(3, 3, 3)):
+        super(Deconv3DBlock, self).__init__()
+        self.a = tf.keras.Sequential(
+            [
+                SingleDeconv3DBlock(filters=filters),
+                SingleConv3DBlock(filters=filters, kernel_size=kernel_size),
+                tf.keras.layers.BatchNormalization(),
+                tf.keras.layers.Activation("relu"),
+            ]
+        )
+
+    def call(self, inputs):
+        return self.a(inputs)
+
+
+class SelfAttention(tf.keras.layers.Layer):
+    def __init__(self, num_heads, embed_dim, dropout):
+        super(SelfAttention, self).__init__()
+
+        self.num_attention_heads = num_heads
+        self.attention_head_size = int(embed_dim / num_heads)
+        self.all_head_size = self.num_attention_heads * self.attention_head_size
+
+        self.query = tf.keras.layers.Dense(self.all_head_size)
+        self.key = tf.keras.layers.Dense(self.all_head_size)
+        self.value = tf.keras.layers.Dense(self.all_head_size)
+
+        self.out = tf.keras.layers.Dense(embed_dim)
+        self.attn_dropout = tf.keras.layers.Dropout(dropout)
+        self.proj_dropout = tf.keras.layers.Dropout(dropout)
+
+        self.softmax = tf.keras.layers.Softmax()
+
+        self.vis = False
+
+    def transpose_for_scores(self, x):
+        # Split the last axis into (num_heads, head_size) and move heads forward.
+        new_x_shape = list(
+            x.shape[:-1] + (self.num_attention_heads, self.attention_head_size)
+        )
+        new_x_shape[0] = -1
+        y = tf.reshape(x, new_x_shape)
+        return tf.transpose(y, perm=[0, 2, 1, 3])
+
+    def call(self, hidden_states):
+        mixed_query_layer = self.query(hidden_states)
+        mixed_key_layer = self.key(hidden_states)
+        mixed_value_layer = self.value(hidden_states)
+
+        query_layer = self.transpose_for_scores(mixed_query_layer)
+        key_layer = self.transpose_for_scores(mixed_key_layer)
+        value_layer = self.transpose_for_scores(mixed_value_layer)
+        # Scaled dot-product attention.
+        attention_scores = query_layer @ tf.transpose(key_layer, perm=[0, 1, 3, 2])
+        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
+        attention_probs = self.softmax(attention_scores)
+        weights = attention_probs if self.vis else None
+        attention_probs = self.attn_dropout(attention_probs)
+
+        context_layer = attention_probs @ value_layer
+        context_layer = tf.transpose(context_layer, perm=[0, 2, 1, 3])
+        new_context_layer_shape = list(context_layer.shape[:-2] + (self.all_head_size,))
+        new_context_layer_shape[0] = -1
+        context_layer = tf.reshape(context_layer, new_context_layer_shape)
+        attention_output = self.out(context_layer)
+        attention_output = self.proj_dropout(attention_output)
+        return attention_output, weights
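+
+
+# Shape sketch for SelfAttention (illustrative): with embed_dim=768 and
+# num_heads=12, an input of shape (B, N, 768) is projected to Q, K, and V of
+# shape (B, N, 768), split per head into (B, 12, N, 64), attended over the N
+# patch positions, then merged back to (B, N, 768).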
+
+
+class Mlp(tf.keras.layers.Layer):
+    def __init__(self, output_features, drop=0.0):
+        super(Mlp, self).__init__()
+        self.a = tf.keras.layers.Dense(units=output_features, activation=tf.nn.gelu)
+        self.b = tf.keras.layers.Dropout(drop)
+
+    def call(self, inputs):
+        x = self.a(inputs)
+        return self.b(x)
+
+
+class PositionwiseFeedForward(tf.keras.layers.Layer):
+    def __init__(self, d_model=768, d_ff=2048, dropout=0.1):
+        super(PositionwiseFeedForward, self).__init__()
+        self.a = tf.keras.layers.Dense(units=d_ff)
+        self.b = tf.keras.layers.Dense(units=d_model)
+        self.c = tf.keras.layers.Dropout(dropout)
+
+    def call(self, inputs):
+        return self.b(self.c(tf.nn.relu(self.a(inputs))))
+
+
+# Patch embedding: project the volume into a sequence of per-patch vectors and
+# add a learned positional embedding for each patch position.
+class PatchEmbedding(tf.keras.layers.Layer):
+    def __init__(self, cube_size, patch_size, embed_dim):
+        super(PatchEmbedding, self).__init__()
+        self.num_of_patches = int(
+            (cube_size[0] * cube_size[1] * cube_size[2])
+            / (patch_size * patch_size * patch_size)
+        )
+        self.patch_size = patch_size
+        self.embed_dim = embed_dim
+
+        self.projection = tf.keras.layers.Dense(embed_dim)
+
+        self.positionalEmbedding = tf.keras.layers.Embedding(
+            self.num_of_patches, embed_dim
+        )
+        # Non-overlapping patches: a strided Conv3D projects each patch_size**3
+        # block to an embed_dim-dimensional vector.
+        self.lyer = tf.keras.layers.Conv3D(
+            filters=self.embed_dim,
+            kernel_size=self.patch_size,
+            strides=self.patch_size,
+            padding="valid",
+        )
+
+    def call(self, inputs):
+        patches = self.lyer(inputs)
+        # The channel axis is already embed_dim wide, so flattening the patch
+        # grid into a sequence works for any embed_dim.
+        patches = tf.reshape(patches, (tf.shape(inputs)[0], -1, self.embed_dim))
+        patches = self.projection(patches)
+        positions = tf.range(0, self.num_of_patches, 1)[tf.newaxis, ...]
+        positionalEmbedding = self.positionalEmbedding(positions)
+        patches = patches + positionalEmbedding
+
+        return patches, positionalEmbedding
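+
+
+# Patch arithmetic (illustrative): for a 96x96x96 volume with patch_size=16,
+# the strided Conv3D yields a (B, 6, 6, 6, embed_dim) grid, which flattens to
+# a sequence of num_of_patches = 96**3 / 16**3 = 216 patch embeddings.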
+
+
+# Transformer encoder block: pre-norm self-attention and feed-forward, each
+# wrapped in a residual connection.
+class TransformerLayer(tf.keras.layers.Layer):
+    def __init__(self, embed_dim, num_heads, dropout, cube_size, patch_size):
+        super(TransformerLayer, self).__init__()
+        self.attention_norm = tf.keras.layers.LayerNormalization(epsilon=1e-6)
+        self.mlp_norm = tf.keras.layers.LayerNormalization(epsilon=1e-6)
+        self.mlp = PositionwiseFeedForward(embed_dim, 2048)
+        self.attn = SelfAttention(num_heads, embed_dim, dropout)
+
+    def call(self, x, training=True):
+        h = x
+        x = self.attention_norm(x)
+        x, weights = self.attn(x)
+        x = x + h
+        h = x
+
+        x = self.mlp_norm(x)
+        x = self.mlp(x)
+
+        x = x + h
+
+        return x, weights
+
+
+class TransformerEncoder(tf.keras.layers.Layer):
+    def __init__(
+        self,
+        embed_dim,
+        num_heads,
+        cube_size,
+        patch_size,
+        num_layers=12,
+        dropout=0.1,
+        extract_layers=[3, 6, 9, 12],
+    ):
+        super(TransformerEncoder, self).__init__()
+        self.embeddings = PatchEmbedding(cube_size, patch_size, embed_dim)
+        self.extract_layers = extract_layers
+        self.encoders = [
+            TransformerLayer(embed_dim, num_heads, dropout, cube_size, patch_size)
+            for _ in range(num_layers)
+        ]
+
+    def call(self, inputs, training=True):
+        # Collect the hidden states of the layers named in extract_layers; they
+        # become the skip connections for the UNETR decoder.
+        extract_layers = []
+        x = inputs
+        x, _ = self.embeddings(x)
+
+        for depth, layer in enumerate(self.encoders):
+            x, _ = layer(x, training=training)
+            if depth + 1 in self.extract_layers:
+                extract_layers.append(x)
+
+        return extract_layers
+
+
+class UNETR(tf.keras.Model):
+    def __init__(
+        self,
+        img_shape=(96, 96, 96),
+        input_dim=3,
+        output_dim=3,
+        embed_dim=768,
+        patch_size=16,
+        num_heads=12,
+        dropout=0.1,
+    ):
+        super(UNETR, self).__init__()
+        self.input_dim = input_dim
+        self.output_dim = output_dim
+        self.embed_dim = embed_dim
+        self.img_shape = img_shape
+        self.patch_size = patch_size
+        self.num_heads = num_heads
+        self.dropout = dropout
+        self.num_layers = 12
+        self.ext_layers = [3, 6, 9, 12]
+
+        self.patch_dim = [int(x / patch_size) for x in img_shape]
+        self.transformer = TransformerEncoder(
+            self.embed_dim,
+            self.num_heads,
+            self.img_shape,
+            self.patch_size,
+            self.num_layers,
+            self.dropout,
+            self.ext_layers,
+        )
+
+        # U-Net Decoder
+        self.decoder0 = tf.keras.Sequential(
+            [Conv3DBlock(32, (3, 3, 3)), Conv3DBlock(64, (3, 3, 3))]
+        )
+
+        self.decoder3 = tf.keras.Sequential(
+            [Deconv3DBlock(512), Deconv3DBlock(256), Deconv3DBlock(128)]
+        )
+
+        self.decoder6 = tf.keras.Sequential([Deconv3DBlock(512), Deconv3DBlock(256)])
+
+        self.decoder9 = Deconv3DBlock(512)
+
+        self.decoder12_upsampler = SingleDeconv3DBlock(512)
+
+        self.decoder9_upsampler = tf.keras.Sequential(
+            [
+                Conv3DBlock(512),
+                Conv3DBlock(512),
+                Conv3DBlock(512),
+                SingleDeconv3DBlock(256),
+            ]
+        )
+
+        self.decoder6_upsampler = tf.keras.Sequential(
+            [Conv3DBlock(256), Conv3DBlock(256), SingleDeconv3DBlock(128)]
+        )
+
+        self.decoder3_upsampler = tf.keras.Sequential(
+            [Conv3DBlock(128), Conv3DBlock(128), SingleDeconv3DBlock(64)]
+        )
+
+        self.decoder0_header = tf.keras.Sequential(
+            [Conv3DBlock(64), Conv3DBlock(64), SingleConv3DBlock(output_dim, (1, 1, 1))]
+        )
+
+    def call(self, x):
+        z = self.transformer(x)
+        z0, z3, z6, z9, z12 = x, z[0], z[1], z[2], z[3]
+        # Fold each extracted patch sequence (B, N, E) back into a 3D feature
+        # grid; with channels last, the sequence reshapes directly.
+        z3 = tf.reshape(z3, [-1, *self.patch_dim, self.embed_dim])
+        z6 = tf.reshape(z6, [-1, *self.patch_dim, self.embed_dim])
+        z9 = tf.reshape(z9, [-1, *self.patch_dim, self.embed_dim])
+        z12 = tf.reshape(z12, [-1, *self.patch_dim, self.embed_dim])
+        z12 = self.decoder12_upsampler(z12)
+        z9 = self.decoder9(z9)
+        z9 = self.decoder9_upsampler(tf.concat([z9, z12], 4))
+        z6 = self.decoder6(z6)
+        z6 = self.decoder6_upsampler(tf.concat([z6, z9], 4))
+        z3 = self.decoder3(z3)
+        z3 = self.decoder3_upsampler(tf.concat([z3, z6], 4))
+        z0 = self.decoder0(z0)
+        output = self.decoder0_header(tf.concat([z0, z3], 4))
+        return output
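+
+
+# Decoder shape sketch (illustrative): with img_shape=(96, 96, 96) and
+# patch_size=16, each extracted sequence folds to (B, 6, 6, 6, 768); the
+# decoder then works back up 6 -> 12 -> 24 -> 48 -> 96, fusing z12, z9, z6,
+# z3, and finally z0 at full resolution.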
+
+
+def unetr(
+    n_classes=1,
+    input_shape=(96, 96, 96, 3),
+    embed_dim=768,
+    patch_size=16,
+    num_heads=12,
+    dropout=0.1,
+):
+    *img_shape, input_dim = input_shape
+
+    inputs = tf.keras.layers.Input([*img_shape, input_dim], name="input_image")
+
+    z = UNETR(
+        img_shape=img_shape,
+        input_dim=input_dim,
+        output_dim=n_classes,
+        embed_dim=embed_dim,
+        patch_size=patch_size,
+        num_heads=num_heads,
+        dropout=dropout,
+    )(inputs)
+
+    final_activation = "sigmoid" if n_classes == 1 else "softmax"
+    output = tf.keras.layers.Activation(final_activation)(z)
+
+    return tf.keras.Model(inputs=[inputs], outputs=[output])
+
+
+if __name__ == "__main__":
+    input_shape = (96, 96, 96, 3)
+    sub1 = unetr(input_shape=input_shape, n_classes=1)
+    sub1.summary()
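Each of the four decoder upsampling stages doubles the spatial resolution, so the wrapper effectively assumes the default patch_size=16, and the spatial dimensions of input_shape must be divisible by 16. A minimal call sketch (values illustrative):

    model = unetr(n_classes=1, input_shape=(64, 64, 64, 1))
    model.compile(optimizer="adam", loss="binary_crossentropy")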
diff --git a/nobrainer/processing/segmentation.py b/nobrainer/processing/segmentation.py
index 66e1768d..7ba9a32a 100644
--- a/nobrainer/processing/segmentation.py
+++ b/nobrainer/processing/segmentation.py
@@ -5,6 +5,7 @@
 
 from .base import BaseEstimator
 from .. import losses, metrics
+from ..models import available_models, list_available_models
 
 logging.getLogger().setLevel(logging.INFO)
 
@@ -23,12 +24,25 @@ def __init__(
             self.base_model = base_model.__name__
         else:
             self.base_model = base_model
+
+        if self.base_model and self.base_model not in available_models():
+            raise ValueError(
+                "Unknown model: '{}'. Available models are {}.".format(
+                    self.base_model, available_models()
+                )
+            )
+
         self.model_ = None
         self.model_args = model_args or {}
         self.block_shape_ = None
         self.volume_shape_ = None
         self.scalar_labels_ = None
 
+    def add_model(self, base_model, model_args=None):
+        """Add a segmentation model."""
+        self.base_model = base_model
+        self.model_args = model_args or {}
+
     def fit(
         self,
         dataset_train,
@@ -39,6 +53,8 @@
         opt_args=None,
         loss=losses.dice,
         metrics=metrics.dice,
+        callbacks=None,
+        verbose=1,
     ):
         """Train a segmentation model"""
         # TODO: check validity of datasets
@@ -82,7 +98,12 @@ def _compile():
         _compile()
         self.model_.summary()
 
-        callbacks = []
+        if callbacks is not None and not isinstance(callbacks, list):
+            raise TypeError("callbacks must be a list or None")
+
+        if callbacks is None:
+            callbacks = []
+
         if self.checkpoint_tracker:
             callbacks.append(self.checkpoint_tracker)
         self.model_.fit(
@@ -90,10 +111,11 @@
             epochs=epochs,
             steps_per_epoch=dataset_train.get_steps_per_epoch(),
             validation_data=dataset_validate.dataset if dataset_validate else None,
-            validation_steps=dataset_validate.get_steps_per_epoch()
-            if dataset_validate
-            else None,
+            validation_steps=(
+                dataset_validate.get_steps_per_epoch() if dataset_validate else None
+            ),
             callbacks=callbacks,
+            verbose=verbose,
         )
 
         return self
@@ -111,3 +133,8 @@
             batch_size=batch_size,
             normalizer=normalizer,
         )
+
+    @classmethod
+    def list_available_models(cls):
+        """Print the names of all available base models."""
+        list_available_models()
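Taken together, the new fit() arguments might be exercised like this (a sketch only: the estimator class name and the dataset objects are assumed, not shown in this diff):

    import tensorflow as tf

    from nobrainer.models import unet
    from nobrainer.processing.segmentation import Segmentation  # assumed class name

    bem = Segmentation(unet, model_args=dict(batchnorm=True))
    bem.fit(
        dataset_train,  # assumed nobrainer Dataset objects
        dataset_validate=dataset_validate,
        epochs=2,
        callbacks=[tf.keras.callbacks.TerminateOnNaN()],  # must be a list or None
        verbose=1,
    )
    Segmentation.list_available_models()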