From e567f764139ec1045984387785c7389cf4c970ce Mon Sep 17 00:00:00 2001
From: = <=>
Date: Fri, 22 Sep 2023 18:54:58 +0200
Subject: [PATCH] Implemented learnable squeezers

---
 Project.toml                                |   1 +
 src/InvertibleNetworks.jl                   |   2 +
 src/layers/invertible_layer_hint.jl         |   4 +-
 src/layers/learnable_squeezer.jl            | 146 ++++++++++++++++++++
 src/utils/neuralnet.jl                      |   2 +-
 src/utils/parameter.jl                      |  84 +++--------
 test/runtests.jl                            |   3 +-
 test/test_layers/test_learnable_squeezer.jl |  64 +++++++++
 8 files changed, 241 insertions(+), 65 deletions(-)
 create mode 100644 src/layers/learnable_squeezer.jl
 create mode 100644 test/test_layers/test_learnable_squeezer.jl

diff --git a/Project.toml b/Project.toml
index 31afe013..db8b155c 100644
--- a/Project.toml
+++ b/Project.toml
@@ -6,6 +6,7 @@ version = "2.2.5"
 [deps]
 CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
+ExponentialUtilities = "d4d017d3-3776-5f7e-afef-a10c40355c18"
 Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
 JOLI = "bb331ad6-a1cf-11e9-23da-9bcb53c69f6f"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
diff --git a/src/InvertibleNetworks.jl b/src/InvertibleNetworks.jl
index 6a65ab93..472a49a6 100644
--- a/src/InvertibleNetworks.jl
+++ b/src/InvertibleNetworks.jl
@@ -9,6 +9,7 @@ using LinearAlgebra, Random
 using Statistics, Wavelets
 using JOLI
 using NNlib, Flux, ChainRulesCore
+using ExponentialUtilities
 
 # Overloads and reexports
 import Base.size, Base.length, Base.getindex, Base.reverse, Base.reverse!, Base.getproperty
@@ -61,6 +62,7 @@ include("layers/invertible_layer_irim.jl")
 include("layers/invertible_layer_glow.jl")
 include("layers/invertible_layer_hyperbolic.jl")
 include("layers/invertible_layer_hint.jl")
+include("layers/learnable_squeezer.jl")
 
 # Invertible network architectures
 include("networks/invertible_network_hint_multiscale.jl")
diff --git a/src/layers/invertible_layer_hint.jl b/src/layers/invertible_layer_hint.jl
index 40f3eb59..7c977ad6 100644
--- a/src/layers/invertible_layer_hint.jl
+++ b/src/layers/invertible_layer_hint.jl
@@ -72,8 +72,8 @@ function get_depth(n_in)
 end
 
 # Constructor for given coupling layer and 1 x 1 convolution
-CouplingLayerHINT(CL::AbstractArray{CouplingLayerBasic, 1}, C::Union{Conv1x1, Nothing};
-    logdet=false, permute="none", activation::ActivationFunction=SigmoidLayer()) = CouplingLayerHINT(CL, C, logdet, permute, false)
+CouplingLayerHINT(CL::AbstractArray{CouplingLayerBasic, 1}, C::Union{Conv1x1, Nothing}; logdet=false, permute="none") =
+    CouplingLayerHINT(CL, C, logdet, permute, false)
 
 # 2D Constructor from input dimensions
 function CouplingLayerHINT(n_in::Int64, n_hidden::Int64; logdet=false, permute="none",
diff --git a/src/layers/learnable_squeezer.jl b/src/layers/learnable_squeezer.jl
new file mode 100644
index 00000000..ca6572e6
--- /dev/null
+++ b/src/layers/learnable_squeezer.jl
@@ -0,0 +1,146 @@
+# Learnable up-/down-sampling from Etmann et al., 2020, https://arxiv.org/abs/2005.05220
+# The Frechet derivative of the matrix exponential is from Al-Mohy and Higham, 2009, https://epubs.siam.org/doi/10.1137/080716426
+
+export LearnableSqueezer
+
+mutable struct LearnableSqueezer <: InvertibleNetwork
+    stencil_pars::Parameter
+    pars2mat_idx
+    stencil_size
+    stencil::Union{AbstractArray,Nothing}
+    cdims::Union{DenseConvDims,Nothing}
+    logdet::Bool
+    reset::Bool
+    log_mat::Union{AbstractArray,Nothing}
+end
+
+@Flux.functor LearnableSqueezer
+
+
+# Constructor
+
+function LearnableSqueezer(stencil_size::Integer...; logdet::Bool=false)
+
+    σ = prod(stencil_size)
+    stencil_pars = vec2par(randn(Float32, div(σ*(σ-1), 2)), (div(σ*(σ-1), 2), ))
+    pars2mat_idx = _skew_symmetric_indices(σ)
+    return LearnableSqueezer(stencil_pars, pars2mat_idx, stencil_size, nothing, nothing, logdet, true, nothing)
+
+end
+
+
+# Forward/inverse/backward
+
+function forward(X::AbstractArray{T,N}, C::LearnableSqueezer; logdet::Union{Nothing,Bool}=nothing) where {T,N}
+    isnothing(logdet) && (logdet = C.logdet)
+
+    # Compute exponential stencil
+    if C.reset
+        _compute_exponential_stencil!(C, size(X, N-1); set_log=true)
+        C.cdims = DenseConvDims(size(X), size(C.stencil); stride=C.stencil_size)
+        C.reset = false
+    end
+
+    # Convolution
+    X = conv(X, C.stencil, C.cdims)
+
+    return logdet ? (X, convert(T, 0)) : X
+
+end
+
+function inverse(Y::AbstractArray{T,N}, C::LearnableSqueezer; logdet::Union{Nothing,Bool}=nothing) where {T,N}
+    isnothing(logdet) && (logdet = C.logdet)
+    C.reset && throw(ArgumentError("The learnable squeezer must be evaluated forward first!"))
+
+    # Convolution (adjoint)
+    Y = ∇conv_data(Y, C.stencil, C.cdims)
+
+    return logdet ? (Y, convert(T, 0)) : Y
+
+end
+
+function backward(ΔY::AbstractArray{T,N}, Y::AbstractArray{T,N}, C::LearnableSqueezer; set_grad::Bool=true, trigger_recompute::Bool=true) where {T,N}
+    C.reset && throw(ArgumentError("The learnable squeezer must be evaluated forward first!"))
+
+    # Convolution (adjoint)
+    X = ∇conv_data(Y, C.stencil, C.cdims)
+    ΔX = ∇conv_data(ΔY, C.stencil, C.cdims)
+
+    # Parameter gradient
+    Δstencil = _mat2stencil_adjoint(∇conv_filter(X, ΔY, C.cdims), C.stencil_size, size(X, N-1))
+    ΔA = _Frechet_exponential(C.log_mat', Δstencil)
+    Δstencil_pars = ΔA[C.pars2mat_idx[1]]-ΔA[C.pars2mat_idx[2]]
+    set_grad && (C.stencil_pars.grad = Δstencil_pars)
+
+    # Trigger recomputation
+    trigger_recompute && (C.reset = true)
+
+    return set_grad ? (ΔX, X) : (ΔX, Δstencil_pars, X)
+
+end
+
+
+# Internal utilities for LearnableSqueezer
+
+function _compute_exponential_stencil!(C::LearnableSqueezer, nc::Integer; set_log::Bool=false)
+    n = prod(C.stencil_size)
+    log_mat = _pars2skewsymm(C.stencil_pars.data, C.pars2mat_idx, n)
+    C.stencil = _mat2stencil(_exponential(log_mat), C.stencil_size, nc)
+    set_log && (C.log_mat = log_mat)
+end
+
+function _mat2stencil(A::AbstractMatrix{T}, k::NTuple{N,Integer}, nc::Integer) where {T,N}
+    stencil = similar(A, k..., nc, k..., nc); fill!(stencil, 0)
+    @inbounds for i = 1:nc
+        selectdim(selectdim(stencil, N+1, i), 2*N+1, i) .= reshape(A, k..., k...)
+    end
+    return reshape(stencil, k..., nc, :)
+end
+
+function _mat2stencil_adjoint(stencil::AbstractArray{T}, k::NTuple{N,Integer}, nc::Integer) where {T,N}
+    stencil = reshape(stencil, k..., nc, k..., nc)
+    A = similar(stencil, prod(k), prod(k)); fill!(A, 0)
+    @inbounds for i = 1:nc
+        A .+= reshape(selectdim(selectdim(stencil, N+1, i), 2*N+1, i), prod(k), prod(k))
+    end
+    return A
+end
+
+function _pars2skewsymm(Apars::AbstractVector{T}, pars2mat_idx::NTuple{2,AbstractVector{<:Integer}}, n::Integer) where T
+    A = similar(Apars, n, n)
+    A[pars2mat_idx[1]] .= Apars
+    A[pars2mat_idx[2]] .= -Apars
+    A[diagind(A)] .= 0
+    return A
+end
+
+function _exponential(A::AbstractMatrix{T}) where T
+    expA = copy(A)
+    exponential!(expA)
+    return expA
+end
+
+function _skew_symmetric_indices(nσ::Integer)
+    CIs = reshape(1:nσ^2, nσ, nσ)
+    idx_u = Vector{Int}(undef, 0)
+    idx_l = Vector{Int}(undef, 0)
+    for i=1:nσ, j=i+1:nσ # Indices related to (strictly) upper triangular part
+        push!(idx_u, CIs[i,j])
+    end
+    for j=1:nσ, i=j+1:nσ # Indices related to (strictly) lower triangular part
+        push!(idx_l, CIs[i,j])
+    end
+    return idx_u, idx_l
+end
+
+function _Frechet_exponential(A::AbstractMatrix{T}, ΔA::AbstractMatrix{T}; niter::Int=40) where T
+    dA = copy(ΔA)
+    Mk = copy(ΔA)
+    Apowk = copy(A)
+    @inbounds for k = 2:niter
+        Mk .= Mk*A+Apowk*ΔA; Mk ./= k
+        dA .+= Mk
+        (k < niter) && (Apowk .= Apowk*A; Apowk ./= k)
+    end
+    return dA
+end
\ No newline at end of file
diff --git a/src/utils/neuralnet.jl b/src/utils/neuralnet.jl
index 48b2ec3f..bd1d165f 100644
--- a/src/utils/neuralnet.jl
+++ b/src/utils/neuralnet.jl
@@ -79,7 +79,7 @@ function get_params(I::InvertibleNetwork)
     params
 end
 
-get_params(::Nothing) = Array{Parameter}(undef, 0)
+get_params(::Any) = Array{Parameter}(undef, 0)
 get_params(A::AbstractArray{T}) where {T <: Union{InvertibleNetwork, Nothing}} = vcat([get_params(A[i]) for i in 1:length(A)]...)
 get_params(A::AbstractMatrix{T}) where {T <: Union{InvertibleNetwork, Nothing}} = vcat([get_params(A[i, j]) for i=1:size(A, 1) for j in 1:size(A, 2)]...)
 get_params(RN::ReversedNetwork) = get_params(RN.I)
diff --git a/src/utils/parameter.jl b/src/utils/parameter.jl
index 9aac1625..8955ce8e 100644
--- a/src/utils/parameter.jl
+++ b/src/utils/parameter.jl
@@ -46,11 +46,11 @@ length(x::Parameter) = length(x.data)
 
 or
 
-    clear_grad!(P::AbstractArray{Parameter, 1})
+    clear_grad!(P::AbstractVector{<:Parameter})
 
 Set gradients of each `Parameter` in the network layer to `nothing`.
""" -function clear_grad!(P::AbstractArray{Parameter, 1}) +function clear_grad!(P::AbstractVector{<:Parameter}) for j=1:length(P) P[j].grad = nothing end @@ -60,8 +60,8 @@ function get_grads(p::Parameter) return Parameter(p.grad) end -function get_grads(pvec::Array{Parameter, 1}) - g = Array{Parameter, 1}(undef, length(pvec)) +function get_grads(pvec::AbstractVector{<:Parameter}) + g = Vector{Parameter}(undef, length(pvec)) for i = 1:length(pvec) g[i] = get_grads(pvec[i]) end @@ -75,13 +75,13 @@ function set_params!(pold::Parameter, pnew::Parameter) pold.grad = pnew.grad end -function set_params!(pold::Array{Parameter, 1}, pnew::Array{Parameter, 1}) +function set_params!(pold::AbstractVector{<:Parameter}, pnew::AbstractVector{<:Parameter}) for i = 1:length(pold) set_params!(pold[i], pnew[i]) end end -function set_params!(pold::Array{Parameter, 1}, pnew::Array{Any, 1}) +function set_params!(pold::AbstractVector{<:Parameter}, pnew::AbstractVector{<:Any}) for i = 1:length(pold) set_params!(pold[i], pnew[i]) end @@ -91,75 +91,37 @@ end ## Algebraic utilities for parameters -function dot(p1::Parameter, p2::Parameter) - return dot(p1.data, p2.data) -end - -function norm(p::Parameter) - return norm(p.data) -end - -function +(p1::Parameter, p2::Parameter) - return Parameter(p1.data+p2.data) -end - -function +(p1::Parameter, p2::T) where {T<:Real} - return Parameter(p1.data+p2) -end - -function +(p1::T, p2::Parameter) where {T<:Real} - return p2+p1 -end - -function -(p1::Parameter, p2::Parameter) - return Parameter(p1.data-p2.data) -end - -function -(p1::Parameter, p2::T) where {T<:Real} - return Parameter(p1.data-p2) -end - -function -(p1::T, p2::Parameter) where {T<:Real} - return -(p2-p1) -end - -function -(p::Parameter) - return Parameter(-p.data) -end - -function *(p1::Parameter, p2::T) where {T<:Real} - return Parameter(p1.data*p2) -end - -function *(p1::T, p2::Parameter) where {T<:Real} - return p2*p1 -end - -function /(p1::Parameter, p2::T) where {T<:Real} - return Parameter(p1.data/p2) -end - -function /(p1::T, p2::Parameter) where {T<:Real} - return Parameter(p1/p2.data) -end +dot(p1::Parameter, p2::Parameter) = dot(p1.data, p2.data) +norm(p::Parameter) = norm(p.data) ++(p1::Parameter, p2::Parameter) = Parameter(p1.data+p2.data) ++(p1::Parameter, p2::T) where {T<:Real} = Parameter(p1.data+p2) ++(p1::T, p2::Parameter) where {T<:Real} = p2+p1 +-(p1::Parameter, p2::Parameter) = Parameter(p1.data-p2.data) +-(p1::Parameter, p2::T) where {T<:Real} = Parameter(p1.data-p2) +-(p1::T, p2::Parameter) where {T<:Real} = -(p2-p1) +-(p::Parameter) = Parameter(-p.data) +*(p1::Parameter, p2::T) where {T<:Real} = Parameter(p1.data*p2) +*(p1::T, p2::Parameter) where {T<:Real} = p2*p1 +/(p1::Parameter, p2::T) where {T<:Real} = Parameter(p1.data/p2) +/(p1::T, p2::Parameter) where {T<:Real} = Parameter(p1/p2.data) # Shape manipulation par2vec(x::Parameter) = vec(x.data), size(x.data) -function vec2par(x::AbstractArray{T, 1}, s::NTuple{N, Int64}) where {T, N} +function vec2par(x::AbstractVector{T}, s::NTuple{N, Integer}) where {T, N} return Parameter(reshape(x, s)) end -function par2vec(x::Array{Parameter, 1}) +function par2vec(x::AbstractVector{<:Parameter}) v = cat([vec(x[i].data) for i=1:length(x)]..., dims=1) s = cat([size(x[i].data) for i=1:length(x)]..., dims=1) return v, s end -function vec2par(x::AbstractArray{T, 1}, s::Array{Any, 1}) where T - xpar = Array{Parameter, 1}(undef, length(s)) +function vec2par(x::AbstractVector{T}, s::AbstractVector) where T + xpar = AbstractVector{<:Parameter}(undef, 
     idx_i = 0
     for i = 1:length(s)
         xpar[i] = vec2par(x[idx_i+1:idx_i+prod(s[i])], s[i])
diff --git a/test/runtests.jl b/test/runtests.jl
index fb23ebee..dd2411d9 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -34,7 +34,8 @@ layers = ["test_layers/test_residual_block.jl",
           "test_layers/test_conditional_res_block.jl",
           "test_layers/test_hyperbolic_layer.jl",
           "test_layers/test_actnorm.jl",
-          "test_layers/test_layer_affine.jl"]
+          "test_layers/test_layer_affine.jl",
+          "test_layers/test_learnable_squeezer.jl"]
 
 networks = ["test_networks/test_unrolled_loop.jl",
             "test_networks/test_generator.jl",
diff --git a/test/test_layers/test_learnable_squeezer.jl b/test/test_layers/test_learnable_squeezer.jl
new file mode 100644
index 00000000..42e0c711
--- /dev/null
+++ b/test/test_layers/test_learnable_squeezer.jl
@@ -0,0 +1,64 @@
+using InvertibleNetworks, LinearAlgebra, Test, Flux, Random, CUDA
+device = InvertibleNetworks.CUDA.functional() ? gpu : cpu
+Random.seed!(11)
+
+
+# Dimensions
+n = (2*17, 3*11, 4*7)
+nc = 4
+batchsize = 3
+k = (2, 3, 4)
+
+for N = 1:3
+
+    # Initialize operator
+    C = LearnableSqueezer(k[1:N]...) |> device
+
+
+    # Test invertibility
+    X = randn(Float32, n[1:N]..., nc, batchsize) |> device
+    Y = randn(Float32, div.(n, k)[1:N]..., prod(k[1:N])*nc, batchsize) |> device
+    @test X ≈ C.inverse(C.forward(X)) rtol=1f-6
+    @test Y ≈ C.forward(C.inverse(Y)) rtol=1f-6
+
+
+    # Test backward/inverse coherence
+    ΔY = randn(Float32, div.(n, k)[1:N]..., prod(k[1:N])*nc, batchsize) |> device
+    Y = randn(Float32, div.(n, k)[1:N]..., prod(k[1:N])*nc, batchsize) |> device
+    X_ = C.inverse(Y)
+    _, X = C.backward(ΔY, Y)
+    @test X ≈ X_ rtol=1f-6
+
+
+    # Gradient test (input)
+    ΔY = randn(Float32, div.(n, k)[1:N]..., prod(k[1:N])*nc, batchsize) |> device
+    ΔX = randn(Float32, n[1:N]..., nc, batchsize) |> device
+    X = randn(Float32, n[1:N]..., nc, batchsize) |> device
+    Y = C.forward(X)
+    ΔX_, _ = C.backward(ΔY, Y)
+    @test dot(ΔX, ΔX_) ≈ dot(C.forward(ΔX), ΔY) rtol=1f-4
+
+
+    # Gradient test (parameters)
+    T = Float64
+    C = LearnableSqueezer(k[1:N]...) |> device; C.stencil_pars.data = cu(C.stencil_pars.data)
+    X = CUDA.randn(T, n[1:N]..., nc, batchsize)
+    ΔY_ = CUDA.randn(T, div.(n, k)[1:N]..., prod(k[1:N])*nc, batchsize)
+    θ = copy(C.stencil_pars.data)
+    Δθ = CUDA.randn(T, size(θ)); Δθ *= norm(θ)/norm(Δθ)
+
+    t = T(1e-5)
+    C.stencil_pars.data = θ+t*Δθ/2; C.reset = true
+    Yp1 = C.forward(X)
+    C.stencil_pars.data = θ-t*Δθ/2; C.reset = true
+    Ym1 = C.forward(X)
+    ΔY = (Yp1-Ym1)/t
+    C.stencil_pars.data = θ; C.reset = true
+    Y = C.forward(X)
+    C.backward(ΔY_, Y)
+    Δθ_ = C.stencil_pars.grad
+
+    @test dot(ΔY, ΔY_) ≈ dot(Δθ, Δθ_) rtol=T(1e-4)
+
+end
\ No newline at end of file
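
Usage sketch (illustrative only, not part of the patch): the snippet below exercises the newly exported LearnableSqueezer the same way the added tests do, with a forward/inverse round trip followed by a backward pass that populates the stencil-parameter gradient. The concrete sizes (a 2x2 stencil, 4 channels, batch size 3) are assumed example values, not anything fixed by the patch.

# CPU example, assuming the patch above is applied to InvertibleNetworks.jl.
using InvertibleNetworks, LinearAlgebra

C  = LearnableSqueezer(2, 2)                # learnable 2x2 down-sampling stencil (logdet=false by default)
X  = randn(Float32, 16, 16, 4, 3)           # (nx, ny, n_channels, batchsize); sizes chosen for illustration
Y  = C.forward(X)                           # size (8, 8, 16, 3): spatial dims divided by the stencil, channels times prod(stencil)
X_ = C.inverse(Y)                           # the stencil is a matrix exponential of a skew-symmetric matrix, so the adjoint inverts exactly
isapprox(X, X_; rtol=1f-5)

# backward returns the input-space gradient plus the recomputed input and stores
# the gradient with respect to the stencil parameters on the Parameter object.
ΔY      = randn(Float32, size(Y)...)
ΔX, Xre = C.backward(ΔY, Y)
Δθ      = C.stencil_pars.grad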