Commit
Added more opts in learnable squeezers
committed Sep 29, 2023
1 parent 72f0fd4 commit 89a1302
Showing 2 changed files with 40 additions and 24 deletions.
46 changes: 32 additions & 14 deletions src/layers/learnable_squeezer.jl
@@ -17,9 +17,10 @@ mutable struct LearnableSqueezer <: InvertibleNetwork
logdet::Bool
reversed::Bool

# Intermediate computations related to the stencil exponential
_log_mat::Union{AbstractArray,Nothing}
_niter_exp_derivative::Integer
# Internal parameters related to the stencil exponential or derivative thereof
log_mat::Union{AbstractArray,Nothing}
niter_exp_derivative::Union{Nothing,Real}
tol_exp_derivative::Union{Nothing,Real}

end

@@ -28,15 +29,15 @@ end

# Constructor

function LearnableSqueezer(stencil_size::Integer...; logdet::Bool=false, zero_init::Bool=false, niter_exp_derivative::Integer=40, reversed::Bool=false)
function LearnableSqueezer(stencil_size::Integer...; logdet::Bool=false, zero_init::Bool=false, niter_exp_derivative::Union{Nothing,Integer}=nothing, tol_exp_derivative::Union{Nothing,Real}=nothing, reversed::Bool=false)

σ = prod(stencil_size)
zero_init ? (stencil_pars = vec2par(zeros(Float32, div(σ*(σ-1), 2)), (div(σ*(σ-1), 2), ))) :
            (stencil_pars = vec2par(glorot_uniform(div(σ*(σ-1), 2)), (div(σ*(σ-1), 2), )))
pars2mat_idx = _skew_symmetric_indices(σ)
return LearnableSqueezer(stencil_pars, pars2mat_idx, stencil_size, nothing, nothing,
true, logdet, reversed,
nothing, niter_exp_derivative)
nothing, niter_exp_derivative, tol_exp_derivative)

end
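
For review context, a minimal usage sketch of the two new keywords. The call below is hypothetical; per the defaults added further down, leaving both options as nothing falls back to 100 series terms and a tolerance of eps(T).

using InvertibleNetworks

# Hypothetical example: cap the truncated series for the exponential's
# derivative at 60 terms and allow an early exit below a relative tolerance.
C = LearnableSqueezer(2, 2; logdet=true,
                      niter_exp_derivative=60,  # default: 100
                      tol_exp_derivative=1f-6)  # default: eps(T)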

@@ -89,12 +90,12 @@ function backward(ΔY::AbstractArray{T,N}, Y::AbstractArray{T,N}, C::LearnableSq

# Parameter gradient
Δstencil = _mat2stencil_adjoint(∇conv_filter(X, ΔY, C.cdims), C.stencil_size, size(X, N-1))
ΔA = _Frechet_derivative_exponential(C._log_mat', Δstencil; niter=C._niter_exp_derivative)
ΔA = _Frechet_derivative_exponential(C.log_mat', Δstencil; niter=C.niter_exp_derivative, tol=isnothing(C.tol_exp_derivative) ? nothing : T(C.tol_exp_derivative))
Δstencil_pars = ΔA[C.pars2mat_idx[1]]-ΔA[C.pars2mat_idx[2]]
set_grad && (C.stencil_pars.grad = Δstencil_pars)

# Trigger recomputation
trigger_recompute && (C.trigger_recompute = true)
trigger_recompute && trigger_recompute!(C)

return set_grad ? (ΔX, X) : (ΔX, Δstencil_pars, X)
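
The transpose in C.log_mat' is the key step: under the Frobenius inner product, the adjoint of the Fréchet derivative E ↦ L(A, E) of the matrix exponential is G ↦ L(A', G). A standalone numerical check of that identity, with finite differences standing in for the exact derivative:

using LinearAlgebra

A, E, G = randn(4, 4), randn(4, 4), randn(4, 4)
L(A, E) = (t = 1e-6; (exp(A + t*E) - exp(A - t*E)) / 2t)  # FD Frechet derivative
dot(L(A, E), G) ≈ dot(E, L(A', G))                        # true up to FD error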

@@ -109,20 +110,22 @@ function backward_inv(ΔX::AbstractArray{T,N}, X::AbstractArray{T,N}, C::Learnab

# Parameter gradient
Δstencil = _mat2stencil_adjoint(∇conv_filter(X, ΔY, C.cdims), C.stencil_size, size(X, N-1))
ΔA = _Frechet_derivative_exponential(C._log_mat', Δstencil; niter=C._niter_exp_derivative)
ΔA = _Frechet_derivative_exponential(C.log_mat', Δstencil; niter=C.niter_exp_derivative, tol=isnothing(C.tol_exp_derivative) ? nothing : T(C.tol_exp_derivative))
Δstencil_pars = ΔA[C.pars2mat_idx[1]]-ΔA[C.pars2mat_idx[2]]
set_grad && (C.stencil_pars.grad = -Δstencil_pars)

# Trigger recomputation
trigger_recompute && (C.trigger_recompute = true)
trigger_recompute && trigger_recompute!(C)

return set_grad ? (ΔY, Y) : (ΔY, -Δstencil_pars, Y)

end

tag_as_reversed!(C::LearnableSqueezer, tag::Bool) = (C.reversed = tag; return C)

set_params!(C::LearnableSqueezer, θ::AbstractVector{<:Parameter}) = (C.stencil_pars = θ[1]; C.trigger_recompute = true)
set_params!(C::LearnableSqueezer, θ::AbstractVector{<:Parameter}) = (C.stencil_pars = θ[1]; trigger_recompute!(C))

trigger_recompute!(C::LearnableSqueezer) = (C.trigger_recompute = true)
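
Routing invalidation through trigger_recompute! makes the caching policy explicit: set_params! and the backward passes only mark the cached stencil stale, and the exponential is rebuilt once on the next forward evaluation. A standalone sketch of the pattern, with hypothetical names:

using LinearAlgebra

mutable struct LazyStencil
    log_mat::Matrix{Float32}               # skew-symmetric generator
    cache::Union{Nothing,Matrix{Float32}}  # exp(log_mat), built on demand
    trigger_recompute::Bool
end

function stencil!(S::LazyStencil)
    if S.trigger_recompute || isnothing(S.cache)
        S.cache = exp(S.log_mat)           # the expensive step, done once
        S.trigger_recompute = false
    end
    return S.cache
end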


# Internal utilities for LearnableSqueezer
@@ -131,7 +134,7 @@ function _compute_exponential_stencil!(C::LearnableSqueezer, nc::Integer; set_lo
n = prod(C.stencil_size)
log_mat = _pars2skewsymm(C.stencil_pars.data, C.pars2mat_idx, n)
C.stencil = _mat2stencil(_exponential(log_mat), C.stencil_size, nc)
set_log && (C._log_mat = log_mat)
set_log && (C.log_mat = log_mat)
end

function _mat2stencil(A::AbstractMatrix{T}, k::NTuple{N,Integer}, nc::Integer) where {T,N}
@@ -178,14 +181,29 @@ function _skew_symmetric_indices(σ::Integer)
return idx_u, idx_l
end
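
The index pairs returned here implement the layer's parametrization: σ(σ-1)/2 free parameters fill the strict upper triangle of a skew-symmetric matrix, whose exponential is orthogonal, so the stencil is invertible by construction. A standalone illustration (not the package's internals):

using LinearAlgebra

function skewsymm(θ::AbstractVector{T}, σ::Integer) where T
    A = zeros(T, σ, σ)
    k = 0
    for j = 2:σ, i = 1:j-1     # strict upper triangle
        k += 1
        A[i, j] =  θ[k]
        A[j, i] = -θ[k]        # enforce A' == -A
    end
    return A
end

σ = 4                          # σ = prod(stencil_size)
θ = randn(Float32, div(σ*(σ-1), 2))
Q = exp(skewsymm(θ, σ))
Q'Q ≈ I                        # true: exp of skew-symmetric is orthogonal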

function _Frechet_derivative_exponential(A::AbstractMatrix{T}, ΔA::AbstractMatrix{T}; niter::Int=40) where T
function _Frechet_derivative_exponential(A::AbstractMatrix{T}, ΔA::AbstractMatrix{T}; niter::Union{Nothing,Integer}=nothing, tol::Union{Nothing,T}=nothing) where T

# Set default options
isnothing(niter) && (niter = 100)
isnothing(tol) && (tol = eps(T))

# Allocating arrays
dA = copy(ΔA)
Mk = copy(ΔA)
Apowk = copy(A)

@inbounds for k = 2:niter
Mk .= Mk*A+Apowk*ΔA; Mk ./= k

# Truncated series
Mk .= (Mk*A+Apowk*ΔA)/k
Apowk .= (Apowk*A)/k
dA .+= Mk
(k < niter) && (Apowk .= Apowk*A; Apowk ./= k)

# Convergence check
~isnothing(tol) && (norm(Mk)/norm(dA) < tol) && break

end

return dA

end
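
The loop above uses the recursion M_1 = ΔA, P_1 = A, then M_k = (M_{k-1}A + P_{k-1}ΔA)/k and P_k = (P_{k-1}A)/k, so P_k = A^k/k! and dA = Σ_k M_k is the Fréchet derivative of exp at A in direction ΔA; the new tol option stops the sum once the relative update norm(Mk)/norm(dA) falls below it. A standalone re-implementation, checked against central finite differences (not the package's function):

using LinearAlgebra

function frechet_exp(A, ΔA; niter=100, tol=eps(eltype(A)))
    dA = copy(ΔA); Mk = copy(ΔA); Pk = copy(A)  # M_1 = ΔA, P_1 = A
    for k = 2:niter
        Mk = (Mk*A + Pk*ΔA)/k                   # M_k, using P_{k-1} = A^(k-1)/(k-1)!
        Pk = (Pk*A)/k
        dA += Mk
        norm(Mk)/norm(dA) < tol && break        # relative-update stopping rule
    end
    return dA
end

A, ΔA, t = randn(4, 4), randn(4, 4), 1e-6
fd = (exp(A + t*ΔA) - exp(A - t*ΔA)) / 2t       # finite-difference reference
norm(fd - frechet_exp(A, ΔA)) / norm(fd)        # small, roughly 1e-10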
18 changes: 8 additions & 10 deletions test/test_layers/test_learnable_squeezer.jl
@@ -54,13 +54,12 @@ for N = 1:3


# Gradient test (parameters)
using CUDA
T = Float64
C = LearnableSqueezer(k[1:N]...) |> device; C.stencil_pars.data = cu(C.stencil_pars.data)
X = CUDA.randn(T, n[1:N]..., nc, batchsize)
ΔY_ = CUDA.randn(T, div.(n, k)[1:N]..., prod(k[1:N])*nc, batchsize)
C = LearnableSqueezer(k[1:N]...) |> device; C.stencil_pars.data = InvertibleNetworks.CUDA.cu(C.stencil_pars.data)
X = InvertibleNetworks.CUDA.randn(T, n[1:N]..., nc, batchsize)
ΔY_ = InvertibleNetworks.CUDA.randn(T, div.(n, k)[1:N]..., prod(k[1:N])*nc, batchsize)
θ = copy(C.stencil_pars.data)
Δθ = CUDA.randn(T, size(θ)); Δθ *= norm(θ)/norm(Δθ)
Δθ = InvertibleNetworks.CUDA.randn(T, size(θ)); Δθ *= norm(θ)/norm(Δθ)

t = T(1e-5)
set_params!(C, [Parameter(θ+t*Δθ/2)])
@@ -77,13 +76,12 @@


# Gradient test (parameters, inv)
using CUDA
T = Float64
Crev = reverse(LearnableSqueezer(k[1:N]...)) |> device; Crev.stencil_pars.data = cu(Crev.stencil_pars.data)
Y = CUDA.randn(T, div.(n, k)[1:N]..., prod(k[1:N])*nc, batchsize)
ΔX_ = CUDA.randn(T, n[1:N]..., nc, batchsize)
Crev = reverse(LearnableSqueezer(k[1:N]...)) |> device; Crev.stencil_pars.data = InvertibleNetworks.CUDA.cu(Crev.stencil_pars.data)
Y = InvertibleNetworks.CUDA.randn(T, div.(n, k)[1:N]..., prod(k[1:N])*nc, batchsize)
ΔX_ = InvertibleNetworks.CUDA.randn(T, n[1:N]..., nc, batchsize)
θ = deepcopy(Crev.stencil_pars.data)
Δθ = CUDA.randn(T, size(θ)); Δθ *= norm(θ)/norm(Δθ)
Δθ = InvertibleNetworks.CUDA.randn(T, size(θ)); Δθ *= norm(θ)/norm(Δθ)

t = T(1e-5)
set_params!(Crev, [Parameter(θ+t*Δθ/2)])
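
Both test blocks follow the same centered finite-difference gradient test: perturb the parameters by ±t·Δθ/2, difference the loss, and compare against the inner product of the computed gradient with the perturbation; the mismatch should shrink as O(t²) when the gradient is correct. A generic sketch of the pattern, with hypothetical names:

using LinearAlgebra

function gradtest(f, g, θ, Δθ; t=1e-5)
    lhs = (f(θ .+ (t/2).*Δθ) - f(θ .- (t/2).*Δθ)) / t  # centered difference
    rhs = dot(g(θ), Δθ)                                # directional derivative
    return abs(lhs - rhs) / abs(rhs)                   # ~O(t^2) if g is correct
end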
