Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reaction tangent controller #138

Merged
merged 42 commits into from
Aug 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
b6aaeec
init
AbdAlazezAhmed Aug 13, 2024
83d4b6b
now it stops
AbdAlazezAhmed Aug 14, 2024
93d1eba
.
AbdAlazezAhmed Aug 14, 2024
ec3bd11
quick quick
AbdAlazezAhmed Aug 14, 2024
85bc9ba
restructuring
AbdAlazezAhmed Aug 14, 2024
a5ccab2
test?
AbdAlazezAhmed Aug 14, 2024
0bf9c9e
error and a couple of returns
AbdAlazezAhmed Aug 14, 2024
aaf7eec
some docs
AbdAlazezAhmed Aug 14, 2024
03827ed
ugh zero default
AbdAlazezAhmed Aug 14, 2024
6451030
.
AbdAlazezAhmed Aug 14, 2024
7a39637
Now it doesn't allocate
AbdAlazezAhmed Aug 15, 2024
5d78f91
some formatting
AbdAlazezAhmed Aug 15, 2024
f909309
example formatting
AbdAlazezAhmed Aug 15, 2024
ca0cd7c
retcode
AbdAlazezAhmed Aug 15, 2024
e5dbc21
add \pi so it doesn't align with tstops hehe
AbdAlazezAhmed Aug 15, 2024
d2af0f9
nans but commented out
AbdAlazezAhmed Aug 15, 2024
071ecd4
assume only one problem
AbdAlazezAhmed Aug 15, 2024
921ab74
docs oopsie
AbdAlazezAhmed Aug 15, 2024
c98c348
adaptive test UwU
AbdAlazezAhmed Aug 16, 2024
bad3da1
Merge branch 'main' of https://github.com/AbdAlazezAhmed/Thunderbolt.…
AbdAlazezAhmed Aug 16, 2024
b5b4195
Better?
AbdAlazezAhmed Aug 19, 2024
0ccbd02
Docs
AbdAlazezAhmed Aug 19, 2024
b505899
oopsie
AbdAlazezAhmed Aug 19, 2024
6ef8ba3
Merge branch 'main' into RTC
AbdAlazezAhmed Aug 20, 2024
f3ad8d3
Next: Error checking for NaNs
AbdAlazezAhmed Aug 20, 2024
f1a72de
Merge branch 'RTC' of https://github.com/AbdAlazezAhmed/Thunderbolt.j…
AbdAlazezAhmed Aug 20, 2024
c685bf3
Try mabe CI works now?
AbdAlazezAhmed Aug 20, 2024
06a541b
Don't merge, homeoffice push
AbdAlazezAhmed Aug 21, 2024
ffe2128
nans test
AbdAlazezAhmed Aug 26, 2024
61ecbc9
multiple pwodef
AbdAlazezAhmed Aug 26, 2024
fbeed40
Inf
AbdAlazezAhmed Aug 26, 2024
e39a773
indentation
AbdAlazezAhmed Aug 26, 2024
e19fd50
access u from caches directly
AbdAlazezAhmed Aug 26, 2024
c981cf7
Merge branch 'main' into RTC
termi-official Aug 28, 2024
09b7993
Merge branch 'main' into RTC
termi-official Aug 28, 2024
68bd74c
remove unroll filter
AbdAlazezAhmed Aug 30, 2024
b7f6773
adaptivity -> time/rtc
AbdAlazezAhmed Aug 30, 2024
934a4d3
remove unused type
AbdAlazezAhmed Aug 30, 2024
0843020
yes
AbdAlazezAhmed Aug 30, 2024
c912e1a
or not to do
AbdAlazezAhmed Aug 30, 2024
3a64dcc
use only current R
AbdAlazezAhmed Aug 30, 2024
5023e38
inf
AbdAlazezAhmed Aug 30, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/src/api-reference/solver.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,9 @@ Thunderbolt.OS.LieTrotterGodunov
Thunderbolt.OS.GenericSplitFunction
Thunderbolt.OS.OperatorSplittingIntegrator
```

## Operator Splitting Adaptivity

```@docs
Thunderbolt.ReactionTangentController
```
10 changes: 10 additions & 0 deletions docs/src/assets/references.bib
Original file line number Diff line number Diff line change
Expand Up @@ -278,3 +278,13 @@ @article{PotDubRicVinGul:2006:cmb
pages={2425-2435},
doi={10.1109/TBME.2006.880875}
}
@article{OgiBalPer:2023:seats,
author = {Ogiermann, Dennis and Perotti, Luigi E. and Balzani, Daniel},
journal = {International Journal for Numerical Methods in Biomedical Engineering},
title = {A simple and efficient adaptive time stepping technique for low-order operator splitting schemes applied to cardiac electrophysiology},
year = {2023}
volume = {39},
number = {2},
pages = {e3670},
doi = {https://doi.org/10.1002/cnm.3670},
}
26 changes: 14 additions & 12 deletions examples/conduction-velocity-benchmark.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,18 +63,20 @@ steady_state_initializer!(u₀, odeform)

# io = ParaViewWriter("spiral-wave-test")


timestepper = OS.LieTrotterGodunov((
BackwardEulerSolver(
solution_vector_type=Vector{Float32},
system_matrix_type=Thunderbolt.ThreadedSparseMatrixCSR{Float32, Int32},
inner_solver=LinearSolve.KrylovJL_CG(atol=1.0f-6, rtol=1.0f-5),
),
AdaptiveForwardEulerSubstepper(
solution_vector_type=Vector{Float32},
reaction_threshold=0.1f0,
),
))
timestepper = Thunderbolt.ReactionTangentController(
OS.LieTrotterGodunov((
BackwardEulerSolver(
solution_vector_type=Vector{Float32},
system_matrix_type=Thunderbolt.ThreadedSparseMatrixCSR{Float32, Int32},
inner_solver=LinearSolve.KrylovJL_CG(atol=1.0f-6, rtol=1.0f-5),
),
AdaptiveForwardEulerSubstepper(
solution_vector_type=Vector{Float32},
reaction_threshold=0.1f0,
)
)),
0.5, 1.0, (0.01, 0.3)
)

problem = OS.OperatorSplittingProblem(odeform, u₀, tspan)

Expand Down
1 change: 1 addition & 0 deletions src/Thunderbolt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ include("solver/linear.jl")
include("solver/nonlinear.jl")
include("solver/time_integration.jl")


include("processing/ecg.jl")

include("io.jl")
Expand Down
68 changes: 59 additions & 9 deletions src/solver/operator_splitting/integrator.jl
termi-official marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ function DiffEqBase.__init(
callback = nothing,
advance_to_tstop = false,
save_func = (u, t) -> copy(u), # custom kwarg
dtchangeable = true, # custom kwarg
dtchangeable = DiffEqBase.isadaptive(alg), # custom kwarg
kwargs...,
)
(; u0, p) = prob
Expand All @@ -73,14 +73,17 @@ function DiffEqBase.__init(
callback = DiffEqBase.CallbackSet(callback)

cache = init_cache(prob, alg; dt, kwargs...)

u = cache.u
uprev = cache.uprev

subintegrators = build_subintegrators_recursive(prob.f, prob.f.synchronizers, p, cache, cache.u, cache.uprev, t0, dt, 1:length(u0), cache.u, tstops, _tstops, saveat, _saveat)
subintegrators = build_subintegrators_recursive(prob.f, prob.f.synchronizers, p, cache, u, uprev, t0, dt, 1:length(u0), u, tstops, _tstops, saveat, _saveat)

integrator = OperatorSplittingIntegrator(
prob.f,
alg,
cache.u,
cache.uprev,
u,
uprev,
p,
t0,
copy(t0),
Expand Down Expand Up @@ -111,6 +114,7 @@ function DiffEqBase.reinit!(
tstops = integrator._tstops,
saveat = integrator._saveat,
reinit_callbacks = true,
reinit_retcode = true
)
(t0,tf) = tspan
integrator.u .= u0
Expand All @@ -126,6 +130,9 @@ function DiffEqBase.reinit!(
saving_callback = integrator.callback.discrete_callbacks[end]
DiffEqBase.initialize!(saving_callback, u0, t0, integrator)
end
if reinit_retcode
integrator.sol = DiffEqBase.solution_new_retcode(integrator.sol, DiffEqBase.ReturnCode.Default)
end
end

# called by DiffEqBase.solve
Expand All @@ -137,37 +144,50 @@ end
# either called directly (after init), or by DiffEqBase.solve (via __solve)
function DiffEqBase.solve!(integrator::OperatorSplittingIntegrator)
while !isempty(integrator.tstops)
DiffEqBase.check_error!(integrator) ∉ (DiffEqBase.ReturnCode.Success, DiffEqBase.ReturnCode.Default) && return
__step!(integrator)
end
DiffEqBase.finalize!(integrator.callback, integrator.u, integrator.t, integrator)
integrator.sol = DiffEqBase.solution_new_retcode(integrator.sol, DiffEqBase.ReturnCode.Success)
if DiffEqBase.NAN_CHECK(integrator.u)
integrator.sol = DiffEqBase.solution_new_retcode(integrator.sol, DiffEqBase.ReturnCode.Failure)
else
integrator.sol = DiffEqBase.solution_new_retcode(integrator.sol, DiffEqBase.ReturnCode.Success)
end
return integrator.sol
end

function DiffEqBase.step!(integrator::OperatorSplittingIntegrator)
if integrator.advance_to_tstop
tstop = first(integrator.tstops)
while !reached_tstop(integrator, tstop)
DiffEqBase.check_error!(integrator) ∉ (DiffEqBase.ReturnCode.Success, DiffEqBase.ReturnCode.Default) && return
__step!(integrator)
end
else
DiffEqBase.check_error!(integrator) ∉ (DiffEqBase.ReturnCode.Success, DiffEqBase.ReturnCode.Default) && return
__step!(integrator)
end
end

function DiffEqBase.check_error!(integrator::OperatorSplittingIntegrator)
if DiffEqBase.NAN_CHECK(integrator._dt) # replace with https://github.com/SciML/OrdinaryDiffEq.jl/blob/373a8eec8024ef1acc6c5f0c87f479aa0cf128c3/lib/OrdinaryDiffEqCore/src/iterator_interface.jl#L5-L6 after moving to sciml integrators
integrator.sol = DiffEqBase.solution_new_retcode(integrator.sol, DiffEqBase.ReturnCode.Failure)
end
return integrator.sol.retcode
end

function DiffEqBase.step!(integrator::OperatorSplittingIntegrator, dt, stop_at_tdt = false)
# OridinaryDiffEq lets dt be negative if tdir is -1, but that's inconsistent
dt <= zero(dt) && error("dt must be positive")
stop_at_tdt && !integrator.dtchangeable && error("Cannot stop at t + dt if dtchangeable is false")
tnext = integrator.t + tdir(integrator) * dt
stop_at_tdt && DiffEqBase.add_tstop!(integrator, tnext)
while !reached_tstop(integrator, tnext, stop_at_tdt)
DiffEqBase.check_error!(integrator) ∉ (DiffEqBase.ReturnCode.Success, DiffEqBase.ReturnCode.Default) && return
__step!(integrator)
end
end



# TimeChoiceIterator API
@inline function DiffEqBase.get_tmp_cache(integrator::OperatorSplittingIntegrator)
DiffEqBase.get_tmp_cache(integrator, integrator.alg, integrator.cache)
Expand All @@ -184,7 +204,35 @@ function (integrator::OperatorSplittingIntegrator)(tmp, t)
linear_interpolation!(tmp, t, integrator.uprev, integrator.u, integrator.tprev, integrator.t)
end

"""
stepsize_controller!(::OperatorSplittingIntegrator)
Updates the controller using the current state of the integrator if the operator splitting algorithm is adaptive.
"""
@inline function stepsize_controller!(integrator::OperatorSplittingIntegrator)
algorithm = integrator.alg
DiffEqBase.isadaptive(algorithm) || return nothing
stepsize_controller!(integrator, algorithm)
end

"""
step_accept_controller!(::OperatorSplittingIntegrator)
Updates `_dt` of the integrator if the step is accepted and the operator splitting algorithm is adaptive.
"""
@inline function step_accept_controller!(integrator::OperatorSplittingIntegrator)
algorithm = integrator.alg
DiffEqBase.isadaptive(algorithm) || return nothing
step_accept_controller!(integrator, algorithm, nothing)
end

"""
step_reject_controller!(::OperatorSplittingIntegrator)
Updates `_dt` of the integrator if the step is rejected and the the operator splitting algorithm is adaptive.
"""
@inline function step_reject_controller!(integrator::OperatorSplittingIntegrator)
algorithm = integrator.alg
DiffEqBase.isadaptive(algorithm) || return nothing
step_reject_controller!(integrator, algorithm, nothing)
end

# helper functions for dealing with time-reversed integrators in the same way
# that OrdinaryDiffEq.jl does
Expand Down Expand Up @@ -228,10 +276,13 @@ function __step!(integrator)
!isempty(tstops) && dtchangeable ? tdir(integrator) * min(_dt, abs(first(tstops) - integrator.t)) :
tdir(integrator) * _dt

# Propagate information down into the subintegrators
synchronize_subintegrators!(integrator)
termi-official marked this conversation as resolved.
Show resolved Hide resolved
tnext = integrator.t + integrator.dt

# Solve inner problems
advance_solution_to!(integrator, tnext)
stepsize_controller!(integrator)

# Update integrator
# increment t by dt, rounding to the first tstop if that is roughly
Expand All @@ -242,8 +293,7 @@ function __step!(integrator)
integrator.tprev = integrator.t
integrator.t = !isempty(tstops) && abs(first(tstops) - tnext) < max_t_error ? first(tstops) : tnext

# Propagate information down into the subintegrators
synchronize_subintegrators!(integrator)
step_accept_controller!(integrator)

# remove tstops that were just reached
while !isempty(tstops) && reached_tstop(integrator, first(tstops))
Expand Down
6 changes: 3 additions & 3 deletions src/solver/operator_splitting/solver.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

# Lie-Trotter-Godunov Splitting Implementation
"""
LieTrotterGodunov <: AbstractOperatorSplittingAlgorithm
Expand All @@ -9,6 +8,8 @@ struct LieTrotterGodunov{AlgTupleType} <: AbstractOperatorSplittingAlgorithm
# transfer_algs::TransferTupleType # Tuple of transfer algorithms from the master solution into the individual ones
end

@inline DiffEqBase.isadaptive(::AbstractOperatorSplittingAlgorithm) = false

struct LieTrotterGodunovCache{uType, tmpType, iiType} <: AbstractOperatorSplittingCache
u::uType
uprev::uType # True previous solution
Expand Down Expand Up @@ -54,5 +55,4 @@ end
advance_solution_to!(subinteg, inner_caches[i], tnext)
finalize_local_step!(subinteg)
end
end

end
107 changes: 107 additions & 0 deletions src/solver/time/rtc.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
"""
ReactionTangentController{LTG <: OS.LieTrotterGodunov, T <: Real} <: OS.AbstractOperatorSplittingAlgorithm
A timestep length controller for [`LieTrotterGodunov`](@ref) [Lie:1880:tti,Tro:1959:psg,God:1959:dmn](@cite)
operator splitting using the reaction tangent as proposed in [OgiBalPer:2023:seats](@cite)
The next timestep length is calculated as
```math
\\sigma\\left(R_{\\max }\\right):=\\left(1.0-\\frac{1}{1+\\exp \\left(\\left(\\sigma_{\\mathrm{c}}-R_{\\max }\\right) \\cdot \\sigma_{\\mathrm{s}}\\right)}\\right) \\cdot\\left(\\Delta t_{\\max }-\\Delta t_{\\min }\\right)+\\Delta t_{\\min }
```
# Fields
- `ltg`::`LTG`: `LieTrotterGodunov` algorithm
- `σ_s::T`: steepness
- `σ_c::T`: offset in R axis
- `Δt_bounds::NTuple{2,T}`: lower and upper timestep length bounds
"""
struct ReactionTangentController{LTG <: OS.LieTrotterGodunov, T <: Real} <: OS.AbstractOperatorSplittingAlgorithm
ltg::LTG
σ_s::T
σ_c::T
Δt_bounds::NTuple{2,T}
end

mutable struct ReactionTangentControllerCache{T <: Real, LTGCache <: OS.LieTrotterGodunovCache, uType} <: OS.AbstractOperatorSplittingCache
const ltg_cache::LTGCache
u::uType
uprev::uType # True previous solution
R::T
function ReactionTangentControllerCache(ltg_cache::LTGCache, R::T) where {T, LTGCache <: OS.LieTrotterGodunovCache}
uType = typeof(ltg_cache.u)
return new{T, LTGCache, uType}(ltg_cache, ltg_cache.u, ltg_cache.uprev, R)
end
end

@inline DiffEqBase.get_tmp_cache(integrator::OS.OperatorSplittingIntegrator, alg::OS.AbstractOperatorSplittingAlgorithm, cache::ReactionTangentControllerCache) = DiffEqBase.get_tmp_cache(integrator, alg, cache.ltg_cache)

@inline function OS.advance_solution_to!(subintegrators::Tuple, cache::ReactionTangentControllerCache, tnext)
OS.advance_solution_to!(subintegrators, cache.ltg_cache, tnext)
end

@inline DiffEqBase.isadaptive(::ReactionTangentController) = true

"""
get_reaction_tangent(integrator::OS.OperatorSplittingIntegrator)
Returns the maximal reaction magnitude using the [`PointwiseODEFunction`](@ref) of an operator splitting integrator that uses [`LieTrotterGodunov`](@ref) [Lie:1880:tti,Tro:1959:psg,God:1959:dmn](@cite).
It is assumed that the problem containing the reaction tangent is a [`PointwiseODEFunction`](@ref).
"""
@inline function get_reaction_tangent(integrator::OS.OperatorSplittingIntegrator)
R, _ = _get_reaction_tangent(integrator.subintegrators)
return R
end

@inline @unroll function _get_reaction_tangent(subintegrators, n_reaction_tangents::Int = 0)
R = 0.0
@unroll for subintegrator in subintegrators
if subintegrator isa Tuple
R, n_reaction_tangents = _get_reaction_tangent(subintegrator, n_reaction_tangents)
elseif subintegrator.f isa PointwiseODEFunction
n_reaction_tangents += 1
φₘidx = transmembranepotential_index(subintegrator.f.ode)
R = max(R, maximum(@view subintegrator.cache.dumat[:, φₘidx]))
end
end
@assert n_reaction_tangents == 1 "No or multiple integrators using PointwiseODEFunction found"
return (R, n_reaction_tangents)
end

@inline function OS.stepsize_controller!(integrator::OS.OperatorSplittingIntegrator, alg::ReactionTangentController)
integrator.cache.R = get_reaction_tangent(integrator)
return nothing
end

@inline function OS.step_accept_controller!(integrator::OS.OperatorSplittingIntegrator, alg::ReactionTangentController, q)
@unpack R = integrator.cache
@unpack σ_s, σ_c, Δt_bounds = alg

if isinf(σ_s)
integrator._dt = R > σ_c ? Δt_bounds[1] : Δt_bounds[2]
else
integrator._dt = (1 - 1/(1+exp((σ_c - R)*σ_s)))*(Δt_bounds[2] - Δt_bounds[1]) + Δt_bounds[1]
end
return nothing
end

@inline function OS.step_reject_controller!(integrator::OS.OperatorSplittingIntegrator, alg::ReactionTangentController, q)
return nothing # Do nothing
end

# Dispatch for outer construction
function OS.init_cache(prob::OS.OperatorSplittingProblem, alg::ReactionTangentController; dt, kwargs...)
@unpack f = prob
@assert f isa GenericSplitFunction

u = copy(prob.u0)
uprev = copy(prob.u0)

# Build inner integrator
return OS.construct_inner_cache(f, alg, u, uprev)
end

# Dispatch for recursive construction
function OS.construct_inner_cache(f::OS.AbstractOperatorSplitFunction, alg::ReactionTangentController, u::AbstractArray{T}, uprev::AbstractArray) where T <: Number
ltg_cache = OS.construct_inner_cache(f, alg.ltg, u, uprev)
return ReactionTangentControllerCache(ltg_cache, zero(T))
end

function OS.build_subintegrators_recursive(f::GenericSplitFunction, synchronizers::Tuple, p::Tuple, cache::ReactionTangentControllerCache, u::AbstractArray, uprev::AbstractArray, t, dt, dof_range, uparent, tstops, _tstops, saveat, _saveat)
OS.build_subintegrators_recursive(f, synchronizers, p, cache.ltg_cache, u, uprev, t, dt, dof_range, uparent, tstops, _tstops, saveat, _saveat)
end
2 changes: 2 additions & 0 deletions src/solver/time_integration.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,5 @@ include("time/time_integrator.jl")
include("time/euler.jl")
include("time/load_stepping.jl")
include("time/partitioned_solver.jl")
include("time/rtc.jl")

2 changes: 1 addition & 1 deletion src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,7 @@ end
# Transfer the element data into a vector
function ea_collapse!(b::Vector, bes::EAVector)
ndofs = size(b, 1)
@batch minbatch= max(1, ndofs÷Threads.nthreads()) for dof ∈ 1:ndofs
@batch minbatch=max(1, ndofs÷Threads.nthreads()) for dof ∈ 1:ndofs
_ea_collapse_kernel!(b, dof, bes)
end
end
Expand Down
Loading