From 7b50239299059f94a2733a7e6c5e5ecf311b2e8c Mon Sep 17 00:00:00 2001
From: bischtob
Date: Thu, 22 Feb 2024 16:48:14 -0800
Subject: [PATCH] basic control net infra

---
 src/CliMAgen.jl       |   3 +-
 src/networks.jl       | 294 ++++++++++++++++++++++++++++++++++++++++++
 test/test_networks.jl |  36 +++++
 3 files changed, 332 insertions(+), 1 deletion(-)

diff --git a/src/CliMAgen.jl b/src/CliMAgen.jl
index 13fb0f61..18f741f2 100644
--- a/src/CliMAgen.jl
+++ b/src/CliMAgen.jl
@@ -35,9 +35,10 @@ export struct2dict, dict2nt
 export VarianceExplodingSDE
 export drift, diffusion, marginal_prob, score
 export vanilla_score_matching_loss, score_matching_loss
-export NoiseConditionalScoreNetwork, DenoisingDiffusionNetwork, ResnetBlockDDN, ResnetBlockNCSN, AttentionBlock, CircularConv
+export NoiseConditionalScoreNetwork, DenoisingDiffusionNetwork, ResnetBlockDDN, ResnetBlockNCSN, AttentionBlock, CircularConv, ControlledNoiseConditionalScoreNetwork
 export WarmupSchedule, ExponentialMovingAverage
 export train!, load_model_and_optimizer, save_model_and_optimizer
 export setup_sampler, Euler_Maruyama_sampler, Euler_Maruyama_ld_sampler, predictor_corrector_sampler
 export MeanSpatialScaling, StandardScaling, apply_preprocessing, invert_preprocessing
+
 end
diff --git a/src/networks.jl b/src/networks.jl
index db21c3b7..a4a3f18d 100644
--- a/src/networks.jl
+++ b/src/networks.jl
@@ -994,3 +994,297 @@ Carries on the spatial convolution respecting periodicity at the boundaries.
 function (layer::CircularConv)(x)
     layer.conv(NNlib.pad_circular(x, layer.pad))
 end
+
+"""
+    CliMAgen.ControlNet
+
+The struct containing the parameters and layers of the ControlNet architecture.
+"""
+struct ControlNet{N}
+    net::N
+    trainable::Bool # whether the network is trainable
+end
+
+"""
+    ControlNet(net::N; trainable::Bool=false)
+
+Creates a ControlNet with the given neural network `net` and whether it is trainable.
+"""
+function ControlNet(net::N; trainable::Bool=false) where N
+    return ControlNet{N}(net, trainable)
+end
+
+@functor ControlNet
+
+"""
+    (net::ControlNet)(x)
+
+Evaluates the neural network of the ControlNet model on `x`.
+"""
+function (c::ControlNet)(x)
+    return c.net(x)
+end
+
+"""
+    Flux.params(c::ControlNet)
+
+Returns the trainable parameters of the ControlNet: the parameters of the
+wrapped network if `c.trainable` is `true`, and an empty `Flux.Params()`
+collection (rather than `nothing`) if the ControlNet is frozen.
+"""
+Flux.params(c::ControlNet) = c.trainable ? Flux.params(c.net) : Flux.Params()
+
+"""
+    Flux.trainable(c::ControlNet)
+
+Restricts Flux's parameter traversal to the wrapped network when the
+ControlNet is trainable and to nothing when it is frozen, so that calling
+`Flux.params` on a model that contains a frozen ControlNet skips it as well.
+"""
+Flux.trainable(c::ControlNet) = c.trainable ? (; net=c.net) : (;)
+
+"""
+    CliMAgen.ControlledNoiseConditionalScoreNetwork
+
+The struct containing the parameters and layers
+of the Noise Conditional Score Network architecture,
+with the option to include a mean-bypass layer.
+
+# References
+Unet: https://arxiv.org/abs/1505.04597
+"""
+struct ControlledNoiseConditionalScoreNetwork{N}
+    "The layers of the network"
+    layers::NamedTuple
+    "A control network to condition the output of the U-net"
+    control_net::N
+    "A boolean indicating if a mean-bypass layer should be used"
+    mean_bypass::Bool
+    "A boolean indicating if the output of the mean-bypass layer should be scaled"
+    scale_mean_bypass::Bool
+    "A boolean indicating if the input is demeaned before being passed to the U-net"
+    shift_input::Bool
+    "A boolean indicating if the output of the Unet is demeaned"
+    shift_output::Bool
+    "A boolean indicating if a groupnorm should be used in the mean-bypass layer"
+    gnorm::Bool
+end
+
+"""
+    ControlledNoiseConditionalScoreNetwork(; control_net, kwargs...)
+
+Builds a ControlledNoiseConditionalScoreNetwork around the given `control_net`.
+
+# Keywords
+- `control_net`: network applied to the context `c`; its output is fed to
+  `Dense(embed_dim, ...)` layers, so it must produce `embed_dim` features.
+- `mean_bypass`, `scale_mean_bypass`, `shift_input`, `shift_output`, `gnorm`:
+  flags stored on the struct (see the field documentation).
+- `nspatial=2`: number of spatial dimensions.
+- `dropout_p=0.0f0`: passed as `p` to each residual block (presumably the dropout probability).
+- `num_residual=8`: number of residual blocks in the middle of the network.
+- `noised_channels=1`: number of input (and output) channels.
+- `channels=[32, 64, 128, 256]`: channel counts of the four resolution levels.
+- `embed_dim=256`: size of the time and control embeddings.
+- `scale=30.0f0`: scale passed to the `GaussianFourierProjection`.
+- `periodic=false`: use periodic (circular) convolutions if `true`.
+- `proj_kernelsize`, `outer_kernelsize`, `middle_kernelsize`,
+  `inner_kernelsize`: kernel sizes of the projection and upsampling layers.
+"""
+function ControlledNoiseConditionalScoreNetwork(; control_net,
+                                                  mean_bypass=false,
+                                                  scale_mean_bypass=false,
+                                                  shift_input=false,
+                                                  shift_output=false,
+                                                  gnorm=false,
+                                                  nspatial=2,
+                                                  dropout_p=0.0f0,
+                                                  num_residual=8,
+                                                  noised_channels=1,
+                                                  channels=[32, 64, 128, 256],
+                                                  embed_dim=256,
+                                                  scale=30.0f0,
+                                                  periodic=false,
+                                                  proj_kernelsize=3,
+                                                  outer_kernelsize=3,
+                                                  middle_kernelsize=3,
+                                                  inner_kernelsize=3)
+    if scale_mean_bypass && !mean_bypass
+        @error("Attempting to scale the mean bypass term without adding in a mean bypass connection.")
+    end
+    if gnorm && !mean_bypass
+        @error("Attempting to gnorm without adding in a mean bypass connection.")
+    end
+    inchannels = noised_channels
+    outchannels = noised_channels
+
+    # Kernel of size 1 along every spatial dimension (pointwise convolution)
+    pointwise = ntuple(_ -> 1, nspatial)
+    # Mean processing as indicated by boolean mean_bypass
+    if mean_bypass
+        if gnorm
+            mean_bypass_layers = (
+                mean_skip_1 = Conv(pointwise, inchannels => embed_dim),
+                mean_skip_2 = Conv(pointwise, embed_dim => embed_dim),
+                mean_skip_3 = Conv(pointwise, embed_dim => outchannels),
+                mean_gnorm_1 = GroupNorm(embed_dim, 32, swish),
+                mean_gnorm_2 = GroupNorm(embed_dim, 32, swish),
+                mean_dense_1 = Dense(embed_dim, embed_dim),
+                mean_dense_2 = Dense(embed_dim, embed_dim),
+            )
+        else
+            mean_bypass_layers = (
+                mean_skip_1 = Conv(pointwise, inchannels => embed_dim),
+                mean_skip_2 = Conv(pointwise, embed_dim => embed_dim),
+                mean_skip_3 = Conv(pointwise, embed_dim => outchannels),
+                mean_dense_1 = Dense(embed_dim, embed_dim),
+                mean_dense_2 = Dense(embed_dim, embed_dim),
+            )
+        end
+    else
+        mean_bypass_layers = ()
+    end
+
+    # Lifting/Projection layers depend on periodicity of data
+    if periodic
+        conv1 = CircularConv(3, nspatial, inchannels => channels[1]; stride=1)
+        tconv1 = CircularConv(proj_kernelsize, nspatial, channels[1] + channels[1] => outchannels; stride=1)
+    else
+        conv1 = Conv(ntuple(_ -> 3, nspatial), inchannels => channels[1], stride=1, pad=SamePad())
+        tconv1 = Conv(ntuple(_ -> proj_kernelsize, nspatial), channels[1] + channels[1] => outchannels, stride=1, pad=SamePad())
+    end
+
+    layers = (gaussfourierproj=GaussianFourierProjection(embed_dim, scale),
+              linear=Dense(embed_dim, embed_dim, swish),
+
+              # Lifting
+              conv1=conv1,
+              dense1=Dense(embed_dim, channels[1]),
+              control_dense1=Dense(embed_dim, channels[1]),
+              gnorm1=GroupNorm(channels[1], 4, swish),
+
+              # Encoding
+              conv2=Downsampling(channels[1] => channels[2], nspatial, kernel_size=3, periodic=periodic),
+              dense2=Dense(embed_dim, channels[2]),
+              control_dense2=Dense(embed_dim, channels[2]),
+              gnorm2=GroupNorm(channels[2], 32, swish),
+
+              conv3=Downsampling(channels[2] => channels[3], nspatial, kernel_size=3, periodic=periodic),
+              dense3=Dense(embed_dim, channels[3]),
+              control_dense3=Dense(embed_dim, channels[3]),
+              gnorm3=GroupNorm(channels[3], 32, swish),
+
+              conv4=Downsampling(channels[3] => channels[4], nspatial, kernel_size=3, periodic=periodic),
+              dense4=Dense(embed_dim, channels[4]),
+              control_dense4=Dense(embed_dim, channels[4]),
+
+              # Residual Blocks
+              resnet_blocks =
+              [ResnetBlockNCSN(channels[end], nspatial, embed_dim; p = dropout_p, periodic=periodic) for _ in range(1, length=num_residual)],
+
+              # Decoding
+              gnorm4=GroupNorm(channels[4], 32, swish),
+              tconv4=Upsampling(channels[4] => channels[3], nspatial, kernel_size=inner_kernelsize, periodic=periodic),
+              denset4=Dense(embed_dim, channels[3]),
+              control_denset4=Dense(embed_dim, channels[3]),
+              tgnorm4=GroupNorm(channels[3], 32, swish),
+
+              tconv3=Upsampling(channels[3]+channels[3] => channels[2], nspatial, kernel_size=middle_kernelsize, periodic=periodic),
+              denset3=Dense(embed_dim, channels[2]),
+              control_denset3=Dense(embed_dim, channels[2]),
+              tgnorm3=GroupNorm(channels[2], 32, swish),
+
+              tconv2=Upsampling(channels[2]+channels[2] => channels[1], nspatial, kernel_size=outer_kernelsize, periodic=periodic),
+              denset2=Dense(embed_dim, channels[1]),
+              control_denset2=Dense(embed_dim, channels[1]),
+              tgnorm2=GroupNorm(channels[1], 32, swish),
+
+              # Projection
+              tconv1=tconv1,
+              mean_bypass_layers...
+              )
+
+    return ControlledNoiseConditionalScoreNetwork(layers, control_net, mean_bypass, scale_mean_bypass, shift_input, shift_output, gnorm)
+end
+
+@functor ControlledNoiseConditionalScoreNetwork
+
+"""
+    (net::ControlledNoiseConditionalScoreNetwork)(x, c, t)
+
+Evaluates the neural network of the ControlledNoiseConditionalScoreNetwork
+model on (x,c,t), where `x` is the tensor of noised input,
+`c` is the tensor of contextual input, and `t` is a tensor of times.
+"""
+function (net::ControlledNoiseConditionalScoreNetwork)(x, c, t)
+    # Get size of spatial dimensions
+    nspatial = ndims(x) - 2
+
+    # Embeddings
+    embed = net.layers.gaussfourierproj(t)
+    embed = net.layers.linear(embed)
+    control_embed = net.control_net(c)
+
+    # Encoder
+    if net.shift_input
+        h1 = x .- mean(x, dims=(1:nspatial)) # remove mean of noised variables before input
+    else
+        h1 = x
+    end
+    h1 = net.layers.conv1(h1)
+    h1 = h1 .+ expand_dims(net.layers.dense1(embed) .+ net.layers.control_dense1(control_embed), nspatial)
+    h1 = net.layers.gnorm1(h1)
+    h2 = net.layers.conv2(h1)
+    h2 = h2 .+ expand_dims(net.layers.dense2(embed) .+ net.layers.control_dense2(control_embed), nspatial)
+    h2 = net.layers.gnorm2(h2)
+    h3 = net.layers.conv3(h2)
+    h3 = h3 .+ expand_dims(net.layers.dense3(embed) .+ net.layers.control_dense3(control_embed), nspatial)
+    h3 = net.layers.gnorm3(h3)
+    h4 = net.layers.conv4(h3)
+    h4 = h4 .+ expand_dims(net.layers.dense4(embed) .+ net.layers.control_dense4(control_embed), nspatial)
+
+    # middle
+    h = h4
+    for block in net.layers.resnet_blocks
+        h = block(h, embed .+ control_embed) # add in control embedding, can perhaps be done better.
+    end
+
+    # Decoder
+    h = net.layers.gnorm4(h)
+    h = net.layers.tconv4(h)
+    h = h .+ expand_dims(net.layers.denset4(embed) .+ net.layers.control_denset4(control_embed), nspatial)
+    h = net.layers.tgnorm4(h)
+    h = net.layers.tconv3(cat(h, h3; dims=nspatial+1))
+    h = h .+ expand_dims(net.layers.denset3(embed) .+ net.layers.control_denset3(control_embed), nspatial)
+    h = net.layers.tgnorm3(h)
+    h = net.layers.tconv2(cat(h, h2, dims=nspatial+1))
+    h = h .+ expand_dims(net.layers.denset2(embed) .+ net.layers.control_denset2(control_embed), nspatial)
+    h = net.layers.tgnorm2(h)
+    h = net.layers.tconv1(cat(h, h1, dims=nspatial+1))
+    if net.shift_output
+        h = h .- mean(h, dims=(1:nspatial)) # remove mean after output
+    end
+
+    # Mean processing of noised variable channels
+    if net.mean_bypass
+        hm = net.layers.mean_skip_1(mean(x, dims=(1:nspatial)))
+        hm = hm .+ expand_dims(net.layers.mean_dense_1(embed), nspatial)
+        if net.gnorm
+            hm = net.layers.mean_gnorm_1(hm)
+        end
+        hm = net.layers.mean_skip_2(hm)
+        hm = hm .+ expand_dims(net.layers.mean_dense_2(embed), nspatial)
+        if net.gnorm
+            hm = net.layers.mean_gnorm_2(hm)
+        end
+        hm = net.layers.mean_skip_3(hm)
+        if net.scale_mean_bypass
+            scale = convert(eltype(x), sqrt(prod(size(x)[1:nspatial])))
+            hm = hm ./ scale
+        end
+        # Add back in noised channel mean to noised channel spatial variations
+        return h .+ hm
+    else
+        return h
+    end
+end
diff --git a/test/test_networks.jl b/test/test_networks.jl
index 868bc6fa..29f29fe4 100644
--- a/test/test_networks.jl
+++ b/test/test_networks.jl
@@ -193,4 +193,40 @@ end
         sum(net(x, c, t) .^ 2)
     end
     @test loss isa Real
+end
+
+@testset "ControlNet" begin
+    # constructor
+    net = Dense(10, 5)
+    controlnet = CliMAgen.ControlNet(net, trainable=true)
+    @test controlnet.net == net
+    @test controlnet.trainable == true
+
+    x = randn(10)
+    @test controlnet(x) == net(x)
+
+    # a frozen ControlNet exposes no trainable parameters
+    frozen = CliMAgen.ControlNet(net)
+    @test frozen.trainable == false
+    @test length(Flux.params(frozen)) == 0
+end
+
+@testset "ControlledNoiseConditionalScoreNetwork" begin
+    # with controlnet
+    control_net = CliMAgen.ControlNet(Dense(11, 256), trainable=true)
+    net = CliMAgen.ControlledNoiseConditionalScoreNetwork(control_net=control_net, noised_channels=2)
+    ps = Flux.params(net)
+    k = 5
+    x = rand(Float32, 2^k, 2^k, 2, 11)
+    c = rand(Float32, 11)
+    t = rand(Float32)
+
+    # forward pass
+    @test net(x, c, t) |> size == (2^k, 2^k, 2, 11)
+
+    # backward pass of dummy loss
+    loss, grad = Flux.withgradient(ps) do
+        sum(net(x, c, t) .^ 2)
+    end
+    @test loss isa Real
 end
\ No newline at end of file