Commit e567f76
Implemented learnable squeezers
committed Sep 22, 2023
1 parent 2d6331f commit e567f76
Showing 8 changed files with 241 additions and 65 deletions.
1 change: 1 addition & 0 deletions Project.toml
@@ -6,6 +6,7 @@ version = "2.2.5"
[deps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ExponentialUtilities = "d4d017d3-3776-5f7e-afef-a10c40355c18"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
JOLI = "bb331ad6-a1cf-11e9-23da-9bcb53c69f6f"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
2 changes: 2 additions & 0 deletions src/InvertibleNetworks.jl
@@ -9,6 +9,7 @@ using LinearAlgebra, Random
using Statistics, Wavelets
using JOLI
using NNlib, Flux, ChainRulesCore
using ExponentialUtilities

# Overloads and reexports
import Base.size, Base.length, Base.getindex, Base.reverse, Base.reverse!, Base.getproperty
@@ -61,6 +62,7 @@ include("layers/invertible_layer_irim.jl")
include("layers/invertible_layer_glow.jl")
include("layers/invertible_layer_hyperbolic.jl")
include("layers/invertible_layer_hint.jl")
include("layers/learnable_squeezer.jl")

# Invertible network architectures
include("networks/invertible_network_hint_multiscale.jl")
4 changes: 2 additions & 2 deletions src/layers/invertible_layer_hint.jl
@@ -72,8 +72,8 @@ function get_depth(n_in)
end

# Constructor for given coupling layer and 1 x 1 convolution
CouplingLayerHINT(CL::AbstractArray{CouplingLayerBasic, 1}, C::Union{Conv1x1, Nothing};
logdet=false, permute="none", activation::ActivationFunction=SigmoidLayer()) = CouplingLayerHINT(CL, C, logdet, permute, false)
CouplingLayerHINT(CL::AbstractArray{CouplingLayerBasic, 1}, C::Union{Conv1x1, Nothing}; logdet=false, permute="none") =
CouplingLayerHINT(CL, C, logdet, permute, false)

# 2D Constructor from input dimensions
function CouplingLayerHINT(n_in::Int64, n_hidden::Int64; logdet=false, permute="none",
146 changes: 146 additions & 0 deletions src/layers/learnable_squeezer.jl
@@ -0,0 +1,146 @@
# Learnable up-/down-sampling from Etmann et al., 2020, https://arxiv.org/abs/2005.05220
# The Frechet derivative of the matrix exponential is from Al-Mohy and Higham, 2009, https://epubs.siam.org/doi/10.1137/080716426
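# Construction sketch: the squeezing stencil is built as the matrix exponential exp(A) of a
# skew-symmetric matrix A assembled from the learnable parameters; exp(A) is orthogonal with
# unit determinant, so the layer is volume-preserving and its log-determinant contribution is zero.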

export LearnableSqueezer

mutable struct LearnableSqueezer <: InvertibleNetwork
stencil_pars::Parameter
pars2mat_idx
stencil_size
stencil::Union{AbstractArray,Nothing}
cdims::Union{DenseConvDims,Nothing}
logdet::Bool
reset::Bool
log_mat::Union{AbstractArray,Nothing}
end

@Flux.functor LearnableSqueezer


# Constructor

function LearnableSqueezer(stencil_size::Integer...; logdet::Bool=false)

σ = prod(stencil_size)
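# A σ x σ skew-symmetric matrix has div(σ*(σ-1), 2) free parameters (its strict upper triangle)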
stencil_pars = vec2par(randn(Float32, div(σ*(σ-1), 2)), (div(σ*(σ-1), 2), ))
pars2mat_idx = _skew_symmetric_indices(σ)
return LearnableSqueezer(stencil_pars, pars2mat_idx, stencil_size, nothing, nothing, logdet, true, nothing)

end


# Forward/inverse/backward

function forward(X::AbstractArray{T,N}, C::LearnableSqueezer; logdet::Union{Nothing,Bool}=nothing) where {T,N}
isnothing(logdet) && (logdet = C.logdet)

# Compute exponential stencil
if C.reset
_compute_exponential_stencil!(C, size(X, N-1); set_log=true)
C.cdims = DenseConvDims(size(X), size(C.stencil); stride=C.stencil_size)
C.reset = false
end

# Convolution
X = conv(X, C.stencil, C.cdims)

return logdet ? (X, convert(T, 0)) : X

end

function inverse(Y::AbstractArray{T,N}, C::LearnableSqueezer; logdet::Union{Nothing,Bool}=nothing) where {T,N}
isnothing(logdet) && (logdet = C.logdet)
C.reset && throw(ArgumentError("The learnable squeezer must be evaluated forward first!"))

# Convolution (adjoint)
Y = ∇conv_data(Y, C.stencil, C.cdims)

return logdet ? (Y, convert(T, 0)) : Y

end

function backward(ΔY::AbstractArray{T,N}, Y::AbstractArray{T,N}, C::LearnableSqueezer; set_grad::Bool=true, trigger_recompute::Bool=true) where {T,N}
C.reset && throw(ArgumentError("The learnable squeezer must be evaluated forward first!"))

# Convolution (adjoint)
X = ∇conv_data(Y, C.stencil, C.cdims)
ΔX = ∇conv_data(ΔY, C.stencil, C.cdims)

# Parameter gradient
Δstencil = _mat2stencil_adjoint(∇conv_filter(X, ΔY, C.cdims), C.stencil_size, size(X, N-1))
ΔA = _Frechet_exponential(C.log_mat', Δstencil)
Δstencil_pars = ΔA[C.pars2mat_idx[1]]-ΔA[C.pars2mat_idx[2]]
set_grad && (C.stencil_pars.grad = Δstencil_pars)

# Trigger recomputation
trigger_recompute && (C.reset = true)

return set_grad ? (ΔX, X) : (ΔX, Δstencil_pars, X)

end


# Internal utilities for LearnableSqueezer

function _compute_exponential_stencil!(C::LearnableSqueezer, nc::Integer; set_log::Bool=false)
n = prod(C.stencil_size)
log_mat = _pars2skewsymm(C.stencil_pars.data, C.pars2mat_idx, n)
C.stencil = _mat2stencil(_exponential(log_mat), C.stencil_size, nc)
set_log && (C.log_mat = log_mat)
end

function _mat2stencil(A::AbstractMatrix{T}, k::NTuple{N,Integer}, nc::Integer) where {T,N}
stencil = similar(A, k..., nc, k..., nc); fill!(stencil, 0)
@inbounds for i = 1:nc
selectdim(selectdim(stencil, N+1, i), 2*N+1, i) .= reshape(A, k..., k...)
end
return reshape(stencil, k..., nc, :)
end

function _mat2stencil_adjoint(stencil::AbstractArray{T}, k::NTuple{N,Integer}, nc::Integer) where {T,N}
stencil = reshape(stencil, k..., nc, k..., nc)
A = similar(stencil, prod(k), prod(k)); fill!(A, 0)
@inbounds for i = 1:nc
A .+= reshape(selectdim(selectdim(stencil, N+1, i), 2*N+1, i), prod(k), prod(k))
end
return A
end

function _pars2skewsymm(Apars::AbstractVector{T}, pars2mat_idx::NTuple{2,AbstractVector{<:Integer}}, n::Integer) where T
A = similar(Apars, n, n)
A[pars2mat_idx[1]] .= Apars
A[pars2mat_idx[2]] .= -Apars
A[diagind(A)] .= 0
return A
end

function _exponential(A::AbstractMatrix{T}) where T
expA = copy(A)
exponential!(expA)
return expA
end

function _skew_symmetric_indices(nσ::Integer)
CIs = reshape(1:nσ^2, nσ, nσ)
idx_u = Vector{Int}(undef, 0)
idx_l = Vector{Int}(undef, 0)
for i=1:nσ, j=i+1:nσ # Indices related to (strictly) upper triangular part
push!(idx_u, CIs[i,j])
end
for j=1:nσ, i=j+1:nσ # Indices related to (strictly) lower triangular part
push!(idx_l, CIs[i,j])
end
return idx_u, idx_l
end
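# Example: for nσ = 3, reshape(1:9, 3, 3) holds column-major linear indices, giving
# idx_u = [4, 7, 8] (strict upper triangle) and idx_l = [2, 3, 6] (strict lower triangle).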

function _Frechet_exponential(A::AbstractMatrix{T}, ΔA::AbstractMatrix{T}; niter::Int=40) where T
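# Truncated series for the Frechet derivative of the matrix exponential:
# L(A, ΔA) ≈ Σ_{k=1}^{niter} M_k, with M_1 = ΔA and M_k = (M_{k-1}*A + (A^(k-1)/(k-1)!)*ΔA)/k,
# so that M_k = (1/k!) Σ_{j=0}^{k-1} A^j * ΔA * A^(k-1-j).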
dA = copy(ΔA)
Mk = copy(ΔA)
Apowk = copy(A)
@inbounds for k = 2:niter
Mk .= Mk*A+Apowk*ΔA; Mk ./= k
dA .+= Mk
(k < niter) && (Apowk .= Apowk*A; Apowk ./= k)
end
return dA
end
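
For reference, a minimal usage sketch of the new layer, based on the constructor and the forward/inverse/backward signatures above; the call style, array sizes, and variable names are illustrative assumptions, not taken from the commit:

using InvertibleNetworks

C = LearnableSqueezer(2, 2; logdet=true)          # 2x2 learnable squeezing stencil
X = randn(Float32, 16, 16, 4, 10)                 # two spatial dims, 4 channels, batch of 10
Y, lgdet = forward(X, C)                          # Y has size (8, 8, 16, 10); lgdet == 0 (orthogonal stencil)
X_ = inverse(Y, C)                                # recovers X up to floating-point round-off
ΔX, _ = backward(randn(Float32, size(Y)), Y, C)   # adjoint pass; fills C.stencil_pars.grad
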
2 changes: 1 addition & 1 deletion src/utils/neuralnet.jl
@@ -79,7 +79,7 @@ function get_params(I::InvertibleNetwork)
params
end

get_params(::Nothing) = Array{Parameter}(undef, 0)
get_params(::Any) = Array{Parameter}(undef, 0)
get_params(A::AbstractArray{T}) where {T <: Union{InvertibleNetwork, Nothing}} = vcat([get_params(A[i]) for i in 1:length(A)]...)
get_params(A::AbstractMatrix{T}) where {T <: Union{InvertibleNetwork, Nothing}} = vcat([get_params(A[i, j]) for i=1:size(A, 1) for j in 1:size(A, 2)]...)
get_params(RN::ReversedNetwork) = get_params(RN.I)
84 changes: 23 additions & 61 deletions src/utils/parameter.jl
@@ -46,11 +46,11 @@ length(x::Parameter) = length(x.data)
or
clear_grad!(P::AbstractArray{Parameter, 1})
clear_grad!(P::AbstractVector{<:Parameter})
Set gradients of each `Parameter` in the network layer to `nothing`.
"""
function clear_grad!(P::AbstractArray{Parameter, 1})
function clear_grad!(P::AbstractVector{<:Parameter})
for j=1:length(P)
P[j].grad = nothing
end
@@ -60,8 +60,8 @@ function get_grads(p::Parameter)
return Parameter(p.grad)
end

function get_grads(pvec::Array{Parameter, 1})
g = Array{Parameter, 1}(undef, length(pvec))
function get_grads(pvec::AbstractVector{<:Parameter})
g = Vector{Parameter}(undef, length(pvec))
for i = 1:length(pvec)
g[i] = get_grads(pvec[i])
end
@@ -75,13 +75,13 @@ function set_params!(pold::Parameter, pnew::Parameter)
pold.grad = pnew.grad
end

function set_params!(pold::Array{Parameter, 1}, pnew::Array{Parameter, 1})
function set_params!(pold::AbstractVector{<:Parameter}, pnew::AbstractVector{<:Parameter})
for i = 1:length(pold)
set_params!(pold[i], pnew[i])
end
end

function set_params!(pold::Array{Parameter, 1}, pnew::Array{Any, 1})
function set_params!(pold::AbstractVector{<:Parameter}, pnew::AbstractVector{<:Any})
for i = 1:length(pold)
set_params!(pold[i], pnew[i])
end
@@ -91,75 +91,37 @@ end

## Algebraic utilities for parameters

function dot(p1::Parameter, p2::Parameter)
return dot(p1.data, p2.data)
end

function norm(p::Parameter)
return norm(p.data)
end

function +(p1::Parameter, p2::Parameter)
return Parameter(p1.data+p2.data)
end

function +(p1::Parameter, p2::T) where {T<:Real}
return Parameter(p1.data+p2)
end

function +(p1::T, p2::Parameter) where {T<:Real}
return p2+p1
end

function -(p1::Parameter, p2::Parameter)
return Parameter(p1.data-p2.data)
end

function -(p1::Parameter, p2::T) where {T<:Real}
return Parameter(p1.data-p2)
end

function -(p1::T, p2::Parameter) where {T<:Real}
return -(p2-p1)
end

function -(p::Parameter)
return Parameter(-p.data)
end

function *(p1::Parameter, p2::T) where {T<:Real}
return Parameter(p1.data*p2)
end

function *(p1::T, p2::Parameter) where {T<:Real}
return p2*p1
end

function /(p1::Parameter, p2::T) where {T<:Real}
return Parameter(p1.data/p2)
end

function /(p1::T, p2::Parameter) where {T<:Real}
return Parameter(p1/p2.data)
end
dot(p1::Parameter, p2::Parameter) = dot(p1.data, p2.data)
norm(p::Parameter) = norm(p.data)
+(p1::Parameter, p2::Parameter) = Parameter(p1.data+p2.data)
+(p1::Parameter, p2::T) where {T<:Real} = Parameter(p1.data+p2)
+(p1::T, p2::Parameter) where {T<:Real} = p2+p1
-(p1::Parameter, p2::Parameter) = Parameter(p1.data-p2.data)
-(p1::Parameter, p2::T) where {T<:Real} = Parameter(p1.data-p2)
-(p1::T, p2::Parameter) where {T<:Real} = -(p2-p1)
-(p::Parameter) = Parameter(-p.data)
*(p1::Parameter, p2::T) where {T<:Real} = Parameter(p1.data*p2)
*(p1::T, p2::Parameter) where {T<:Real} = p2*p1
/(p1::Parameter, p2::T) where {T<:Real} = Parameter(p1.data/p2)
/(p1::T, p2::Parameter) where {T<:Real} = Parameter(p1/p2.data)
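# Example (illustrative): with p = Parameter(randn(Float32, 3)), expressions such as
# 2f0*p + 1f0 or p/2f0 return new Parameter objects wrapping the transformed data.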

# Shape manipulation

par2vec(x::Parameter) = vec(x.data), size(x.data)


function vec2par(x::AbstractArray{T, 1}, s::NTuple{N, Int64}) where {T, N}
function vec2par(x::AbstractVector{T}, s::NTuple{N, Integer}) where {T, N}
return Parameter(reshape(x, s))
end

function par2vec(x::Array{Parameter, 1})
function par2vec(x::AbstractVector{<:Parameter})
v = cat([vec(x[i].data) for i=1:length(x)]..., dims=1)
s = cat([size(x[i].data) for i=1:length(x)]..., dims=1)
return v, s
end

function vec2par(x::AbstractArray{T, 1}, s::Array{Any, 1}) where T
xpar = Array{Parameter, 1}(undef, length(s))
function vec2par(x::AbstractVector{T}, s::AbstractVector) where T
xpar = Vector{Parameter}(undef, length(s))
idx_i = 0
for i = 1:length(s)
xpar[i] = vec2par(x[idx_i+1:idx_i+prod(s[i])], s[i])
3 changes: 2 additions & 1 deletion test/runtests.jl
@@ -34,7 +34,8 @@ layers = ["test_layers/test_residual_block.jl",
"test_layers/test_conditional_res_block.jl",
"test_layers/test_hyperbolic_layer.jl",
"test_layers/test_actnorm.jl",
"test_layers/test_layer_affine.jl"]
"test_layers/test_layer_affine.jl",
"test_layers/test_learnable_squeezer.jl"]

networks = ["test_networks/test_unrolled_loop.jl",
"test_networks/test_generator.jl",