-
Notifications
You must be signed in to change notification settings - Fork 23
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
=
committed
Sep 22, 2023
1 parent
2d6331f
commit e567f76
Showing
8 changed files
with
241 additions
and
65 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,146 @@ | ||
# Learnable up-/down-sampling from Etmann et al., 2020, https://arxiv.org/abs/2005.05220 | ||
# The Fréchet derivative of the matrix exponential follows Al-Mohy and Higham, 2009, https://epubs.siam.org/doi/10.1137/080716426 | ||
|
||
export LearnableSqueezer | ||
|
||
"""
    LearnableSqueezer

Invertible layer that applies a strided convolution whose stencil is obtained
from the matrix exponential of a skew-symmetric matrix.

# Fields
- `stencil_pars`: learnable parameters; they fill the strictly upper triangle
  of the skew-symmetric matrix (and, negated, the strictly lower triangle).
- `pars2mat_idx`: pair of linear-index vectors (upper, lower) used to scatter
  `stencil_pars` into that matrix.
- `stencil_size`: tuple of stencil extents; also used as the convolution stride.
- `stencil`: cached convolution stencil, or `nothing` before the first `forward`.
- `cdims`: cached convolution dimensions, or `nothing` before the first `forward`.
- `logdet`: default for whether `forward`/`inverse` also return a log-determinant term.
- `reset`: when `true`, the next `forward` call rebuilds `stencil` and `cdims`.
- `log_mat`: cached skew-symmetric matrix, needed by `backward`.
"""
mutable struct LearnableSqueezer <: InvertibleNetwork
    stencil_pars::Parameter
    pars2mat_idx
    stencil_size
    stencil::Union{Nothing,AbstractArray}
    cdims::Union{Nothing,DenseConvDims}
    logdet::Bool
    reset::Bool
    log_mat::Union{Nothing,AbstractArray}
end

@Flux.functor LearnableSqueezer
|
||
|
||
# Constructor | ||
|
||
"""
    LearnableSqueezer(stencil_size::Integer...; logdet=false)

Build a learnable squeezer for the given stencil extents.

The parameters are drawn from `randn(Float32, ...)`; there is one parameter per
strictly-upper-triangular entry of a square matrix of side `prod(stencil_size)`.
The stencil itself is not built here: the layer starts with `reset = true`, so
it is computed on the first `forward` call.
"""
function LearnableSqueezer(stencil_size::Integer...; logdet::Bool=false)
    nσ = prod(stencil_size)
    # Number of entries strictly above the diagonal of an nσ × nσ matrix.
    npars = div(nσ * (nσ - 1), 2)
    pars = vec2par(randn(Float32, npars), (npars,))
    scatter_idx = _skew_symmetric_indices(nσ)
    return LearnableSqueezer(
        pars, scatter_idx, stencil_size, nothing, nothing, logdet, true, nothing
    )
end
|
||
|
||
# Forward/inverse/backward | ||
|
||
"""
    forward(X, C::LearnableSqueezer; logdet=nothing)

Apply the strided convolution with the layer's stencil to `X`.

If `C.reset` is set, the stencil (from the current parameters) and the
convolution dimensions (from `size(X)`) are rebuilt and cached first, and the
flag is cleared. `logdet = nothing` falls back to `C.logdet`. When the
log-determinant is requested, the second return value is always zero of type `T`.
"""
function forward(X::AbstractArray{T,N}, C::LearnableSqueezer; logdet::Union{Nothing,Bool}=nothing) where {T,N}
    with_logdet = isnothing(logdet) ? C.logdet : logdet

    # (Re)build the cached stencil and convolution dimensions when flagged
    if C.reset
        _compute_exponential_stencil!(C, size(X, N-1); set_log=true)
        C.cdims = DenseConvDims(size(X), size(C.stencil); stride=C.stencil_size)
        C.reset = false
    end

    # Strided convolution
    Y = conv(X, C.stencil, C.cdims)

    if with_logdet
        return Y, zero(T)
    end
    return Y
end
|
||
"""
    inverse(Y, C::LearnableSqueezer; logdet=nothing)

Map `Y` back through the layer by applying `∇conv_data` with the cached stencil
and convolution dimensions.

Throws an `ArgumentError` if the cache is flagged for recomputation
(`C.reset == true`), i.e. `forward` has to be evaluated first.
`logdet = nothing` falls back to `C.logdet`; when requested, the second return
value is always zero of type `T`.
"""
function inverse(Y::AbstractArray{T,N}, C::LearnableSqueezer; logdet::Union{Nothing,Bool}=nothing) where {T,N}
    with_logdet = isnothing(logdet) ? C.logdet : logdet
    if C.reset
        throw(ArgumentError("The learnable squeezer must be evaluated forward first!"))
    end

    # Adjoint of the forward convolution
    X = ∇conv_data(Y, C.stencil, C.cdims)

    return with_logdet ? (X, zero(T)) : X
end
|
||
"""
    backward(ΔY, Y, C::LearnableSqueezer; set_grad=true, trigger_recompute=true)

Recover the layer input `X` from `Y`, propagate the residual `ΔY` to `ΔX`, and
compute the gradient with respect to the stencil parameters.

- With `set_grad = true` the parameter gradient is stored in
  `C.stencil_pars.grad` and `(ΔX, X)` is returned.
- With `set_grad = false` nothing is stored and `(ΔX, Δstencil_pars, X)` is returned.
- With `trigger_recompute = true` the layer is flagged so that the next
  `forward` call rebuilds the stencil.

Throws an `ArgumentError` if `C.reset` is already set (no cached stencil).
"""
function backward(ΔY::AbstractArray{T,N}, Y::AbstractArray{T,N}, C::LearnableSqueezer; set_grad::Bool=true, trigger_recompute::Bool=true) where {T,N}
    if C.reset
        throw(ArgumentError("The learnable squeezer must be evaluated forward first!"))
    end

    # Adjoint convolution applied to both the output and its residual
    X = ∇conv_data(Y, C.stencil, C.cdims)
    ΔX = ∇conv_data(ΔY, C.stencil, C.cdims)

    # Chain: stencil gradient -> matrix gradient -> gradient w.r.t. the
    # skew-symmetric matrix -> gradient w.r.t. the parameters
    nc = size(X, N-1)
    Δfilter = ∇conv_filter(X, ΔY, C.cdims)
    Δexp_mat = _mat2stencil_adjoint(Δfilter, C.stencil_size, nc)
    Δlog_mat = _Frechet_exponential(C.log_mat', Δexp_mat)
    # Parameters enter the upper triangle with sign + and the lower with sign -
    Δstencil_pars = Δlog_mat[C.pars2mat_idx[1]] - Δlog_mat[C.pars2mat_idx[2]]
    if set_grad
        C.stencil_pars.grad = Δstencil_pars
    end

    # Flag the cached stencil as stale
    if trigger_recompute
        C.reset = true
    end

    return set_grad ? (ΔX, X) : (ΔX, Δstencil_pars, X)
end
|
||
|
||
# Internal utilities for LearnableSqueezer | ||
|
||
# Rebuild `C.stencil` for `nc` channels from the current parameter values:
# parameters -> skew-symmetric matrix -> matrix exponential -> stencil.
# With `set_log = true` the skew-symmetric matrix is also cached in `C.log_mat`.
function _compute_exponential_stencil!(C::LearnableSqueezer, nc::Integer; set_log::Bool=false)
    nσ = prod(C.stencil_size)
    skew = _pars2skewsymm(C.stencil_pars.data, C.pars2mat_idx, nσ)
    C.stencil = _mat2stencil(_exponential(skew), C.stencil_size, nc)
    return set_log && (C.log_mat = skew)
end
|
||
# Expand the matrix `A` into a convolution stencil for `nc` channels.
# `A` is viewed as an array of size (k..., k...) and copied onto every
# channel-diagonal block (c, c) of a zero array of size (k..., nc, k..., nc);
# the result is returned reshaped to (k..., nc, :).
function _mat2stencil(A::AbstractMatrix{T}, k::NTuple{N,Integer}, nc::Integer) where {T,N}
    block = reshape(A, k..., k...)
    out = fill!(similar(A, k..., nc, k..., nc), zero(T))
    colons = ntuple(_ -> Colon(), N)
    for c in 1:nc
        view(out, colons..., c, colons..., c) .= block
    end
    return reshape(out, k..., nc, :)
end
|
||
# Adjoint of `_mat2stencil`: view `stencil` as an array of size
# (k..., nc, k..., nc) and sum its channel-diagonal blocks (c, c), each reshaped
# to a prod(k) × prod(k) matrix. Off-diagonal channel blocks are ignored.
function _mat2stencil_adjoint(stencil::AbstractArray{T}, k::NTuple{N,Integer}, nc::Integer) where {T,N}
    nσ = prod(k)
    blocks = reshape(stencil, k..., nc, k..., nc)
    colons = ntuple(_ -> Colon(), N)
    acc = fill!(similar(blocks, nσ, nσ), zero(T))
    for c in 1:nc
        acc .+= reshape(view(blocks, colons..., c, colons..., c), nσ, nσ)
    end
    return acc
end
|
||
# Assemble an n × n matrix from the parameter vector `Apars`:
# `Apars` is written at the linear indices `pars2mat_idx[1]`, `-Apars` at the
# linear indices `pars2mat_idx[2]`, and the diagonal is set to zero last.
function _pars2skewsymm(Apars::AbstractVector{T}, pars2mat_idx::NTuple{2,AbstractVector{<:Integer}}, n::Integer) where T
    M = similar(Apars, n, n)
    view(M, pars2mat_idx[1]) .= Apars
    view(M, pars2mat_idx[2]) .= .-Apars
    view(M, diagind(M)) .= zero(T)
    return M
end
|
||
# Return the matrix exponential of `A` without modifying `A`.
# `exponential!` is run on a copy, and its *return value* is handed back
# rather than the scratch copy.
# NOTE(review): presumably `exponential!` documents its return value as the
# result and only says the argument's contents are "modified" — confirm against
# the ExponentialUtilities docs. Using the returned array is correct whether or
# not the result is also written into the copy.
function _exponential(A::AbstractMatrix{T}) where T
    return exponential!(copy(A))
end
|
||
# Linear indices of the strictly upper and strictly lower triangular entries of
# an nσ × nσ matrix. The k-th lower index is the transposed position of the
# k-th upper index, so the two vectors can be used pairwise.
function _skew_symmetric_indices(nσ::Integer)
    lin = reshape(1:nσ^2, nσ, nσ)
    # Upper triangle, enumerated row by row
    upper = Int[lin[r, c] for r in 1:nσ for c in (r + 1):nσ]
    # Lower triangle, enumerated column by column
    lower = Int[lin[r, c] for c in 1:nσ for r in (c + 1):nσ]
    return upper, lower
end
|
||
# Truncated power-series evaluation of the Fréchet derivative of the matrix
# exponential at `A` in direction `ΔA`, using terms of order 1 through `niter`.
# The order-k term is built from the previous one via
#     term_k = (term_{k-1} * A + (A^(k-1) / (k-1)!) * ΔA) / k
# starting from term_1 = ΔA. Neither input is modified.
function _Frechet_exponential(A::AbstractMatrix{T}, ΔA::AbstractMatrix{T}; niter::Int=40) where T
    total = copy(ΔA)        # running sum of the series, starts at the order-1 term
    term = copy(ΔA)         # current series term
    scaled_pow = copy(A)    # A^(k-1) / (k-1)! for the upcoming order k
    for order in 2:niter
        term .= (term * A .+ scaled_pow * ΔA) ./ order
        total .+= term
        # The scaled power is not needed after the last iteration
        if order < niter
            scaled_pow .= (scaled_pow * A) ./ order
        end
    end
    return total
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.