
Commit

Merge pull request #295 from neuronets/dev
Dev
satra committed Mar 22, 2024
2 parents 474f2a7 + d271fc9 commit c5263f0
Showing 8 changed files with 883 additions and 16 deletions.
3 changes: 3 additions & 0 deletions nobrainer/dataset.py
@@ -121,6 +121,9 @@ def from_tfrecords(
         )
         block_length = len([0 for _ in first_shard])
 
+        if not n_volumes:
+            n_volumes = block_length * len(files)
+
         dataset = dataset.interleave(
             map_func=lambda x: tf.data.TFRecordDataset(
                 x, compression_type=compression_type
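Taken together, the three added lines give `n_volumes` a default inferred from the shards themselves: the first shard is iterated once to count its records (`block_length`), and the total is taken to be that count times the number of shard files, which is exact only when every shard holds the same number of volumes. A standalone sketch of that arithmetic, with hypothetical shard names and a stand-in iterable for the first shard:

# Standalone sketch of the n_volumes default added above; the shard paths
# and the per-shard record count (8) are hypothetical.
files = ["shard-000.tfrec", "shard-001.tfrec", "shard-002.tfrec"]
first_shard = range(8)  # stand-in for iterating the first shard's records

# Count the records in the first shard, exactly as the diff does.
block_length = len([0 for _ in first_shard])

n_volumes = None  # the caller did not supply a count
if not n_volumes:
    # Assume every shard holds block_length volumes.
    n_volumes = block_length * len(files)

print(n_volumes)  # -> 24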
File renamed without changes.
44 changes: 32 additions & 12 deletions nobrainer/models/__init__.py
@@ -1,10 +1,32 @@
+from pprint import pprint
+
+from .attention_unet import attention_unet
+from .attention_unet_with_inception import attention_unet_with_inception
 from .autoencoder import autoencoder
+from .bayesian_meshnet import variational_meshnet
 from .dcgan import dcgan
 from .highresnet import highresnet
 from .meshnet import meshnet
 from .progressiveae import progressiveae
 from .progressivegan import progressivegan
 from .unet import unet
+from .unetr import unetr
+
+__all__ = ["get", "list_available_models"]
+
+_models = {
+    "highresnet": highresnet,
+    "meshnet": meshnet,
+    "unet": unet,
+    "autoencoder": autoencoder,
+    "progressivegan": progressivegan,
+    "progressiveae": progressiveae,
+    "dcgan": dcgan,
+    "attention_unet": attention_unet,
+    "attention_unet_with_inception": attention_unet_with_inception,
+    "unetr": unetr,
+    "variational_meshnet": variational_meshnet,
+}
 
 
 def get(name):
@@ -21,20 +43,18 @@ def get(name):
     if not isinstance(name, str):
         raise ValueError("Model name must be a string.")
 
-    models = {
-        "highresnet": highresnet,
-        "meshnet": meshnet,
-        "unet": unet,
-        "autoencoder": autoencoder,
-        "progressivegan": progressivegan,
-        "progressiveae": progressiveae,
-        "dcgan": dcgan,
-    }
-
     try:
-        return models[name.lower()]
+        return _models[name.lower()]
     except KeyError:
-        avail = ", ".join(models.keys())
+        avail = ", ".join(_models.keys())
         raise ValueError(
             "Unknown model: '{}'. Available models are {}.".format(name, avail)
         )
+
+
+def available_models():
+    return list(_models)
+
+
+def list_available_models():
+    pprint(available_models())
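Hoisting the table to module level means lookup and discovery now share a single `_models` registry. A usage sketch, assuming a nobrainer install that includes this patch:

# get, available_models, and list_available_models come straight from this
# diff; only the model names in the output comments are abbreviated.
from nobrainer.models import available_models, get, list_available_models

print(available_models())  # ['highresnet', 'meshnet', 'unet', ...]
list_available_models()    # pretty-prints the same list via pprint

unet = get("UNet")  # lookup is case-insensitive thanks to name.lower()

try:
    get("resnet50")
except ValueError as err:
    print(err)  # Unknown model: 'resnet50'. Available models are highresnet, ...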
80 changes: 80 additions & 0 deletions nobrainer/models/attention_unet.py
@@ -0,0 +1,80 @@
"""Model definition for Attention U-Net.
Adapted from https://github.com/nikhilroxtomar/Semantic-Segmentation-Architecture/blob/main/TensorFlow/attention-unet.py
"""  # noqa: E501

from tensorflow.keras import layers
import tensorflow.keras.layers as L
from tensorflow.keras.models import Model


def conv_block(x, num_filters):
    x = L.Conv3D(num_filters, 3, padding="same")(x)
    x = L.BatchNormalization()(x)
    x = L.Activation("relu")(x)

    x = L.Conv3D(num_filters, 3, padding="same")(x)
    x = L.BatchNormalization()(x)
    x = L.Activation("relu")(x)

    return x


def encoder_block(x, num_filters):
    x = conv_block(x, num_filters)
    p = L.MaxPool3D()(x)
    return x, p


def attention_gate(g, s, num_filters):
    Wg = L.Conv3D(num_filters, 1, padding="same")(g)
    Wg = L.BatchNormalization()(Wg)

    Ws = L.Conv3D(num_filters, 1, padding="same")(s)
    Ws = L.BatchNormalization()(Ws)

    out = L.Activation("relu")(Wg + Ws)
    out = L.Conv3D(num_filters, 1, padding="same")(out)
    out = L.Activation("sigmoid")(out)

    return out * s


def decoder_block(x, s, num_filters):
    x = L.UpSampling3D()(x)
    s = attention_gate(x, s, num_filters)
    x = L.Concatenate()([x, s])
    x = conv_block(x, num_filters)
    return x


def attention_unet(n_classes, input_shape):
    """Inputs"""
    inputs = L.Input(input_shape)

    """ Encoder """
    s1, p1 = encoder_block(inputs, 64)
    s2, p2 = encoder_block(p1, 128)
    s3, p3 = encoder_block(p2, 256)

    b1 = conv_block(p3, 512)

    """ Decoder """
    d1 = decoder_block(b1, s3, 256)
    d2 = decoder_block(d1, s2, 128)
    d3 = decoder_block(d2, s1, 64)

    """ Outputs """
    outputs = L.Conv3D(n_classes, 1, padding="same")(d3)

    final_activation = "sigmoid" if n_classes == 1 else "softmax"
    outputs = layers.Activation(final_activation)(outputs)

    """ Model """
    return Model(inputs=inputs, outputs=outputs, name="Attention_U-Net")


if __name__ == "__main__":
    n_classes = 50
    input_shape = (256, 256, 256, 3)
    model = attention_unet(n_classes, input_shape)
    model.summary()
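The `__main__` demo above builds a 256-cubed, three-channel model, which is heavy to instantiate; a lighter smoke test of the same constructor might look like the following, where the 32-cubed single-channel input and single class are arbitrary choices and the import path simply mirrors the new file's location:

# Hypothetical smoke test; assumes a nobrainer install that includes this patch.
from nobrainer.models.attention_unet import attention_unet

model = attention_unet(n_classes=1, input_shape=(32, 32, 32, 1))

# The three MaxPool3D/UpSampling3D pairs cancel out, so spatial dimensions are
# preserved, and n_classes == 1 selects the sigmoid head.
assert model.output_shape == (None, 32, 32, 32, 1)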