Merge pull request #82 from slimgroup/summ_net
Summ net
rafaelorozco authored Jun 9, 2023
2 parents 671e9c1 + 7fb07d9 commit cc8a408
Showing 6 changed files with 288 additions and 38 deletions.
2 changes: 2 additions & 0 deletions src/InvertibleNetworks.jl
@@ -74,6 +74,8 @@ include("networks/invertible_network_conditional_glow.jl")
include("networks/invertible_network_conditional_hint.jl")
include("networks/invertible_network_conditional_hint_multiscale.jl")

include("networks/summarized_net.jl")

# Jacobians
include("utils/jacobian.jl")

14 changes: 7 additions & 7 deletions src/layers/layer_resnet.jl
@@ -5,18 +5,18 @@
export ResNet


function ResNet(n_in::Int64, n_hidden::Int64, nblocks::Int64; k::Int64=3, p::Int64=1, s::Int64=1, norm::Union{Nothing, String}="batch", n_out::Union{Nothing, Int64}=nothing)

function ResNet(n_in::Int64, n_hidden::Int64, nblocks::Int64; k::Int64=3, p::Int64=1, s::Int64=1, norm::Union{Nothing, String}="batch", n_out::Union{Nothing, Int64}=nothing,ndims=2)
k1 = Tuple(k for i=1:ndims)
resnet_blocks = Array{Any, 1}(undef, nblocks)
for i = 1:nblocks-1
# Normalization layer
(norm == "batch") && (NormLayer = BatchNorm(n_hidden))
(norm === nothing) && (NormLayer = identity)

# Skip-connection
resnet_blocks[i] = SkipConnection(Chain(Conv((k, k), n_in => n_hidden; stride = s, pad = p),
resnet_blocks[i] = SkipConnection(Chain(Conv(k1, n_in => n_hidden; stride = s, pad = p),
NormLayer, x->relu.(x),
Conv((k, k), n_hidden => n_in; stride = s, pad = p)), +)
Conv(k1, n_hidden => n_in; stride = s, pad = p)), +)
end

# Last layer
@@ -26,12 +26,12 @@ function ResNet(n_in::Int64, n_hidden::Int64, nblocks::Int64; k::Int64=3, p::Int
(norm === nothing) && (NormLayer = identity)

# Skip-connection
resnet_blocks[end] = SkipConnection(Chain(Conv((k, k), n_in => n_hidden; stride = s, pad = p),
resnet_blocks[end] = SkipConnection(Chain(Conv(k1, n_in => n_hidden; stride = s, pad = p),
NormLayer, x->relu.(x),
Conv((k, k), n_hidden => n_in; stride = s, pad = p)), +)
Conv(k1, n_hidden => n_in; stride = s, pad = p)), +)
else
# Simple convolution
resnet_blocks[end] = Conv((k, k), n_in => n_out; stride = s, pad = p)
resnet_blocks[end] = Conv(k1, n_in => n_out; stride = s, pad = p)
end
return FluxBlock(Chain(resnet_blocks...))

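With the kernel size now built from the `ndims` keyword, the same constructor produces 2D or 3D residual blocks. A minimal usage sketch follows; the channel counts, hidden widths, input shapes, and the `FluxBlock` `forward` call are illustrative assumptions, not part of this diff:

```julia
using InvertibleNetworks

# 2D case (default ndims=2): the scalar k=3 becomes a (3, 3) kernel
RB2 = ResNet(4, 32, 3)
X2  = randn(Float32, 16, 16, 4, 2)        # (nx, ny, n_in, batch)
Y2  = RB2.forward(X2)

# 3D case: the same scalar k=3 expands to a (3, 3, 3) kernel
RB3 = ResNet(4, 32, 3; ndims=3)
X3  = randn(Float32, 16, 16, 16, 4, 2)    # (nx, ny, nz, n_in, batch)
Y3  = RB3.forward(X3)
```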
34 changes: 16 additions & 18 deletions src/networks/invertible_network_conditional_glow.jl
@@ -6,18 +6,20 @@
export NetworkConditionalGlow, NetworkConditionalGlow3D

"""
G = NetworkGlow(n_in, n_hidden, L, K; k1=3, k2=1, p1=1, p2=0, s1=1, s2=1)
G = NetworkGlow(n_in, n_cond, n_hidden, L, K; k1=3, k2=1, p1=1, p2=0, s1=1, s2=1)
G = NetworkGlow3D(n_in, n_hidden, L, K; k1=3, k2=1, p1=1, p2=0, s1=1, s2=1)
G = NetworkGlow3D(n_in, n_cond, n_hidden, L, K; k1=3, k2=1, p1=1, p2=0, s1=1, s2=1)
Create an invertible network based on the Glow architecture. Each flow step in the inner loop
Create a conditional invertible network based on the Glow architecture. Each flow step in the inner loop
consists of an activation normalization layer, followed by an invertible coupling layer with
1x1 convolutions and a residual block. The outer loop performs a squeezing operation prior
to the inner loop, and a splitting operation afterwards.
*Input*:
- 'n_in': number of input channels
- `n_in`: number of input channels of the variable to sample
- `n_cond`: number of input channels of the condition
- `n_hidden`: number of hidden units in residual blocks
@@ -46,9 +48,9 @@ export NetworkConditionalGlow, NetworkConditionalGlow3D
*Usage:*
- Forward mode: `Y, logdet = G.forward(X)`
- Forward mode: `ZX, ZC, logdet = G.forward(X, C)`
- Backward mode: `ΔX, X = G.backward(ΔY, Y)`
- Backward mode: `ΔX, X, ΔC = G.backward(ΔZX, ZX, ZC)`
*Trainable parameters:*
@@ -73,7 +75,7 @@ end
@Flux.functor NetworkConditionalGlow

# Constructor
function NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K;freeze_conv=false, split_scales=false, rb_activation::ActivationFunction=ReLUlayer(), k1=3, k2=1, p1=1, p2=0, s1=1, s2=1, ndims=2, squeezer::Squeezer=ShuffleLayer(), activation::ActivationFunction=SigmoidLayer())
function NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K; freeze_conv=false, split_scales=false, rb_activation::ActivationFunction=ReLUlayer(), k1=3, k2=1, p1=1, p2=0, s1=1, s2=1, ndims=2, squeezer::Squeezer=ShuffleLayer(), activation::ActivationFunction=SigmoidLayer())
AN = Array{ActNorm}(undef, L, K) # activation normalization
AN_C = ActNorm(n_cond; logdet=false) # activation normalization for condition
CL = Array{ConditionalLayerGlow}(undef, L, K) # coupling layers w/ 1x1 convolution and residual block
@@ -106,7 +108,6 @@ function forward(X::AbstractArray{T, N}, C::AbstractArray{T, N}, G::NetworkCondi
G.split_scales && (Z_save = array_of_array(X, G.L-1))
orig_shape = size(X)

# Dont need logdet for condition
C = G.AN_C.forward(C)

logdet = 0
@@ -147,37 +148,34 @@ function inverse(X::AbstractArray{T, N}, C::AbstractArray{T, N}, G::NetworkCondi
end

# Backward pass and compute gradients
function backward(ΔX::AbstractArray{T, N}, X::AbstractArray{T, N}, C::AbstractArray{T, N}, G::NetworkConditionalGlow) where {T, N}

function backward(ΔX::AbstractArray{T, N}, X::AbstractArray{T, N}, C::AbstractArray{T, N}, G::NetworkConditionalGlow;) where {T, N}
# Split data and gradients
if G.split_scales
ΔZ_save, ΔX = split_states(ΔX[:], G.Z_dims)
Z_save, X = split_states(X[:], G.Z_dims)
end

ΔC_total = T(0) .* C

ΔC = T(0) .* C
for i=G.L:-1:1
if G.split_scales && i < G.L
X = tensor_cat(X, Z_save[i])
ΔX = tensor_cat(ΔX, ΔZ_save[i])
end
for j=G.K:-1:1
ΔX, X, ΔC = G.CL[i, j].backward(ΔX, X, C)
ΔX, X, ΔC_ = G.CL[i, j].backward(ΔX, X, C)
ΔX, X = G.AN[i, j].backward(ΔX, X)
ΔC_total += ΔC
ΔC += ΔC_
end

if G.split_scales
C = G.squeezer.inverse(C)
ΔC_total = G.squeezer.inverse(ΔC_total)
ΔC = G.squeezer.inverse(ΔC)
X = G.squeezer.inverse(X)
ΔX = G.squeezer.inverse(ΔX)

end
end

ΔC_total, C = G.AN_C.backward(ΔC_total, C)

return ΔX, X
ΔC, C = G.AN_C.backward(ΔC, C)
return ΔX, X, ΔC
end
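The forward and backward signatures documented above can be exercised end to end. Below is a hedged sketch of one forward and one backward pass; the array shapes, channel counts, and the Gaussian-likelihood gradient are assumptions for illustration, not values fixed by this commit:

```julia
using InvertibleNetworks

n_in, n_cond, n_hidden, L, K = 2, 1, 32, 2, 4
G = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K)

X = randn(Float32, 16, 16, n_in, 4)      # variable to sample
C = randn(Float32, 16, 16, n_cond, 4)    # condition

# Forward pass: latent for X, transformed condition, and log-determinant
ZX, ZC, logdet = G.forward(X, C)

# Backward pass: propagate a gradient w.r.t. ZX through the network;
# the gradient w.r.t. the condition, ΔC, is now returned as well
ΔZX = ZX ./ size(ZX)[end]                # example gradient (standard-normal likelihood)
ΔX, X_, ΔC = G.backward(ΔZX, ZX, ZC)
```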
56 changes: 56 additions & 0 deletions src/networks/summarized_net.jl
@@ -0,0 +1,56 @@
export SummarizedNet

"""
G = SummarizedNet(cond_net, sum_net)
Create a summarized neural conditional approximator from the conditional approximator `cond_net` and the summary network `sum_net`.
*Input*:
- `cond_net`: invertible conditional distribution approximator
- `sum_net`: summary network. Should be a Flux layer and invariant to the dimension of interest.
*Output*:
- `G`: summarized network.
*Usage:*
- Forward mode: `ZX, ZY, logdet = G.forward(X, Y)`
- Backward mode: `ΔX, X, ΔY = G.backward(ΔZX, ZX, ZY; Y_save=Y)`
- Inverse mode: `ZX, ZY, logdet = G.inverse(ZX, ZY)`
*Trainable parameters:*
- None in `G` itself
- Trainable parameters in the conditional approximator `G.cond_net` and the summary network `G.sum_net`.
See also: [`ActNorm`](@ref), [`CouplingLayerGlow!`](@ref), [`get_params`](@ref), [`clear_grad!`](@ref)
"""
struct SummarizedNet <: InvertibleNetwork
cond_net::InvertibleNetwork
sum_net
end

@Flux.functor SummarizedNet

# Forward pass
function forward(X::AbstractArray{T, N}, Y::AbstractArray{T, N}, S::SummarizedNet) where {T, N}
S.cond_net(X, S.sum_net(Y))
end

# Inverse pass
function inverse(X::AbstractArray{T, N}, Y::AbstractArray{T, N}, S::SummarizedNet) where {T, N}
S.cond_net.inverse(X, Y)
end

# Backward pass and compute gradients
function backward(ΔX::AbstractArray{T, N}, X::AbstractArray{T, N}, Y::AbstractArray{T, N}, S::SummarizedNet; Y_save=nothing) where {T, N}
ΔX, X, ΔY = S.cond_net.backward(ΔX,X,Y)
ΔY = S.sum_net.backward(ΔY, Y_save)
return ΔX, X, ΔY
end
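As a hedged usage sketch, the conditional Glow above can be paired with a `ResNet` summary network. The channel counts and shapes below are illustrative, and it is assumed that the summary network preserves the condition's channel count and exposes the `backward` interface used here:

```julia
using InvertibleNetworks

n_in, n_cond, n_hidden, L, K = 2, 1, 32, 2, 4
cond_net = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K)
sum_net  = ResNet(n_cond, 16, 3)              # summary network keeping n_cond channels

G = SummarizedNet(cond_net, sum_net)

X = randn(Float32, 16, 16, n_in, 4)           # variable to sample
Y = randn(Float32, 16, 16, n_cond, 4)         # raw observation / condition

ZX, ZY, logdet = G.forward(X, Y)              # ZY is the summarized condition
ΔX, X_, ΔY = G.backward(ZX ./ size(ZX)[end], ZX, ZY; Y_save=Y)
```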
6 changes: 5 additions & 1 deletion src/utils/neuralnet.jl
@@ -134,4 +134,8 @@ end

# Make invertible nets callable objects
(net::Invertible)(X::AbstractArray{T,N} where {T, N}) = forward_net(net, X, getfield.(get_params(net), :data))
forward_net(net::Invertible, X::AbstractArray{T,N}, ::Any) where {T, N} = net.forward(X)
forward_net(net::Invertible, X::AbstractArray{T,N}, ::Any) where {T, N} = net.forward(X)

# Make conditional invertible nets callable objects
(net::Invertible)(X::AbstractArray{T,N}, Y::AbstractArray{T,N}) where {T, N} = forward_net(net, X, Y, getfield.(get_params(net), :data))
forward_net(net::Invertible, X::AbstractArray{T,N}, Y::AbstractArray{T,N}, ::Any) where {T, N} = net.forward(X,Y)
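With this addition, conditional invertible networks can be called directly on a pair of inputs. A small sketch, with network sizes and array shapes assumed for illustration:

```julia
using InvertibleNetworks

G = NetworkConditionalGlow(2, 1, 32, 2, 4)
X = randn(Float32, 16, 16, 2, 4)
C = randn(Float32, 16, 16, 1, 4)

ZX, ZC, logdet = G(X, C)    # equivalent to G.forward(X, C)
```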