Skip to content

Commit

Permalink
Fixing Agents.jl integration
Browse files Browse the repository at this point in the history
  • Loading branch information
thevolatilebit committed Mar 17, 2024
1 parent 95058e6 commit 33aa338
Show file tree
Hide file tree
Showing 4 changed files with 176 additions and 84 deletions.
47 changes: 16 additions & 31 deletions src/integrations/AgentsIntegration/core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@ Initialize `ABMAgent`, incl. hierarchy of ABM's agents.
Configure the evolutionary step, logging, and step size by keyword arguments below.
# Arguments
- `agent_step!`, `model_step!`: same meaning as in `Agents.step!`
- in general, any kwarg accepted by `Agents.run!`, incl. `adata`, `mdata`
- any kwarg accepted by `Agents.run!`, incl. `adata`, `mdata`
- `when`, `when_model`: when to collect agents data, model data
true by default, and performs data collection at every step
if an `AbstractVector`, checks if `t ∈ when`; otherwise a function (model, t) -> ::Bool
Expand All @@ -34,8 +33,6 @@ mutable struct ABMAgent <: AbstractAlgebraicAgent

abm::Agents.AgentBasedModel

agent_step!::Any
model_step!::Any # evolutionary functions
kwargs::Any # kwargs propagated to `run!` (incl. `adata`, `mdata`)
when::Any
when_model::Any # when to collect agents data, model data
Expand All @@ -54,7 +51,6 @@ mutable struct ABMAgent <: AbstractAlgebraicAgent

## implement constructor
function ABMAgent(name::AbstractString, abm::Agents.AgentBasedModel;
agent_step! = Agents.dummystep, model_step! = Agents.dummystep,
when = true, when_model = when, step_size = 1.0,
tspan::NTuple{2, Float64} = (0.0, Inf), kwargs...)

Expand All @@ -63,8 +59,6 @@ mutable struct ABMAgent <: AbstractAlgebraicAgent
setup_agent!(i, name)

i.abm = abm
i.agent_step! = agent_step!
i.model_step! = model_step!
i.kwargs = kwargs
i.when = when
i.when_model = when_model
Expand Down Expand Up @@ -103,41 +97,32 @@ function _step!(a::ABMAgent)
a.when isa Bool ? a.when : a.when_model(a.abm, t)

df_agents, df_model = Agents.run!(a.abm, 1.0; a.kwargs...)

# append collected data
## df_agents
if collect_agents && ("step" names(df_agents))
if collect_agents && ("time" names(df_agents))
if a.t == a.tspan[1]
df_agents_0 = df_agents[df_agents.step .== 0.0, :]
df_agents_0[!, :step] = convert.(Float64, df_agents_0[!, :step])
df_agents_0[!, :step] .+= a.t
append!(a.df_agents, df_agents_0)
append!(a.df_agents, df_agents)
else
push!(a.df_agents, df_agents[end, :])
end
df_agents = df_agents[df_agents.step .== 1.0, :]
append!(a.df_agents, df_agents)
a.df_agents[(end - DataFrames.nrow(df_agents) + 1):end, :step] .+= a.t +
step_size - 1
end
## df_model
if collect_model && ("step" names(df_model))
if collect_model && ("time" names(df_model))
if a.t == a.tspan[1]
df_model_0 = df_model[df_model.step .== 0.0, :]
df_model_0[!, :step] = convert.(Float64, df_model_0[!, :step])
df_model_0[!, :step] .+= a.t
append!(a.df_model, df_model_0)
append!(a.df_model, df_model)
else
push!(a.df_model, df_model[end, :])
end
df_model = df_model[df_model.step .== 1.0, :]
append!(a.df_model, df_model)
a.df_model[(end - DataFrames.nrow(df_model) + 1):end, :step] .+= a.t +
step_size - 1
end

a.t += step_size
end

# if step is a float, need to retype the dataframe
function fix_float!(df, val)
if eltype(df[!, :step]) <: Int && !isa(val, Int)
df[!, :step] = convert.(Float64, df[!, :step])
if eltype(df[!, :time]) <: Int && !isa(val, Int)
df[!, :time] = convert.(Float64, df[!, :time])
end
end

Expand All @@ -149,10 +134,10 @@ end

function gettimeobservable(a::ABMAgent, t::Float64, obs)
df = a.df_model
@assert ("step" names(df)) && (string(obs) names(df))
@assert ("time" names(df)) && (string(obs) names(df))

# query dataframe
df[df.step .== Int(t), obs] |> first
df[df.time .== Int(t), obs] |> first
end

function _reinit!(a::ABMAgent)
Expand Down Expand Up @@ -189,10 +174,10 @@ end

function gettimeobservable(a::AAgent, t::Float64, obs)
df = getparent(a).df_agents
@assert ("step" names(df)) && (string(obs) names(df))
@assert ("time" names(df)) && (string(obs) names(df))

# query df
df[(df.step .== Int(t)) .& (df.id .== a.agent.id), obs] |> first
df[(df.time .== Int(t)) .& (df.id .== a.agent.id), obs] |> first
end

function print_custom(io::IO, mime::MIME"text/plain", a::ABMAgent)
Expand Down
148 changes: 148 additions & 0 deletions test/integrations/agents_sir.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
# Taken from https://github.com/JuliaDynamics/Agents.jl/blob/dc2ce2c8b9e805e7c0f6b2ead4d120f0b1590ef9/src/models/sir.jl

using LinearAlgebra
using StatsBase
using Random: Xoshiro

@agent struct PoorSoul(GraphAgent)
days_infected::Int # number of days since is infected
status::Symbol # 1: S, 2: I, 3:R
end

"""
```julia
sir(;
C = 8,
max_travel_rate = 0.01,
Ns = rand(50:5000, C),
β_und = rand(0.3:0.02:0.6, C),
β_det = β_und ./ 10,
infection_period = 30,
reinfection_probability = 0.05,
detection_time = 14,
death_rate = 0.02,
Is = [zeros(Int, length(Ns) - 1)..., 1],
seed = 19,
)
```
Same as in [SIR model for the spread of COVID-19](@ref).
"""
function sir(;
C = 8,
max_travel_rate = 0.01,
Ns = rand(50:5000, C),
β_und = rand(0.3:0.02:0.6, C),
β_det = β_und ./ 10,
infection_period = 30,
reinfection_probability = 0.05,
detection_time = 14,
death_rate = 0.02,
Is = [zeros(Int, length(Ns) - 1)..., 1],
seed = 19
)
rng = Xoshiro(seed)
migration_rates = zeros(C, C)
@assert length(Ns)==
length(Is)==
length(β_und)==
length(β_det)==
size(migration_rates, 1) "length of Ns, Is, and B, and number of rows/columns in migration_rates should be the same "
@assert size(migration_rates, 1)==size(migration_rates, 2) "migration_rates rates should be a square matrix"

for c in 1:C
for c2 in 1:C
migration_rates[c, c2] = (Ns[c] + Ns[c2]) / Ns[c]
end
end
maxM = maximum(migration_rates)
migration_rates = (migration_rates .* max_travel_rate) ./ maxM
migration_rates[diagind(migration_rates)] .= 1.0

## normalize migration_rates
migration_rates_sum = sum(migration_rates, dims = 2)
for c in 1:C
migration_rates[c, :] ./= migration_rates_sum[c]
end

properties = Dict(
:Ns => Ns,
:Is => Is,
:β_und => β_und,
:β_det => β_det,
:migration_rates => migration_rates,
:infection_period => infection_period,
:infection_period => infection_period,
:reinfection_probability => reinfection_probability,
:detection_time => detection_time,
:C => C,
:death_rate => death_rate
)

space = GraphSpace(Agents.Graphs.complete_graph(C))
model = ABM(PoorSoul, space; agent_step! = sir_agent_step!, properties, rng)

## Add initial individuals
for city in 1:C, n in 1:Ns[city]
ind = add_agent!(city, model, 0, :S) # Susceptible
end
## add infected individuals
for city in 1:C
inds = ids_in_position(city, model)
for n in 1:Is[city]
agent = model[inds[n]]
agent.status = :I # Infected
agent.days_infected = 1
end
end
return model, sir_agent_step!, dummystep
end

function sir_agent_step!(agent, model)
sir_migrate!(agent, model)
sir_transmit!(agent, model)
sir_update!(agent, model)
sir_recover_or_die!(agent, model)
end

function sir_migrate!(agent, model)
pid = agent.pos
m = sample(abmrng(model), 1:(model.C), Weights(model.migration_rates[pid, :]))
if m pid
move_agent!(agent, m, model)
end
end

function sir_transmit!(agent, model)
agent.status == :S && return
rate = if agent.days_infected < model.detection_time
model.β_und[agent.pos]
else
model.β_det[agent.pos]
end

n = rate * abs(randn(abmrng(model)))
n <= 0 && return

for contactID in ids_in_position(agent, model)
contact = model[contactID]
if contact.status == :S ||
(contact.status == :R && rand(abmrng(model)) model.reinfection_probability)
contact.status = :I
n -= 1
n <= 0 && return
end
end
end

sir_update!(agent, model) = agent.status == :I && (agent.days_infected += 1)

function sir_recover_or_die!(agent, model)
if agent.days_infected model.infection_period
if rand(abmrng(model)) model.death_rate
remove_agent!(agent, model)
else
agent.status = :R
agent.days_infected = 0
end
end
end
63 changes: 11 additions & 52 deletions test/integrations/agents_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@ using Plots
import Random
import StatsBase: sample, Weights

include("agents_sir.jl")

# test pure Agents.jl solution vs AlgebraicAgents.jl wrap

# Agents.jl
Random.seed!(2023)

# use Agents.jl predefined model, in https://juliadynamics.github.io/Agents.jl/stable/models/#Predefined-Models-1
abm_agents, agent_step, _ = Agents.Models.sir()
abm_agents, agent_step, _ = sir()

# data to collect
infected(x) = count(i == :I for i in x)
Expand All @@ -20,70 +22,27 @@ to_collect = [(:status, f) for f in (infected, recovered, length)]
Random.seed!(2023)

# use Agents.jl predefined model, in https://juliadynamics.github.io/Agents.jl/stable/models/#Predefined-Models-1
abm_algebraic, _, _ = Agents.Models.sir()
abm_algebraic, _, _ = sir()

# modify stepping functions
function agent_step!(agent, model)
@get_model model
extract_agent(model, agent)
migrate!(agent, model)
transmit!(agent, model)
update!(agent, model)
recover_or_die!(agent, model)
end

function migrate!(agent, model)
pid = agent.pos
m = sample(model.rng, 1:(model.C), Weights(model.migration_rates[pid, :]))
if m pid
move_agent!(agent, m, model)
end
end

function transmit!(agent, model)
agent.status == :S && return
rate = if agent.days_infected < model.detection_time
model.β_und[agent.pos]
else
model.β_det[agent.pos]
end

n = rate * abs(randn(model.rng))
n <= 0 && return

for contactID in ids_in_position(agent, model)
contact = model[contactID]
if contact.status == :S ||
(contact.status == :R && rand(model.rng) model.reinfection_probability)
contact.status = :I
n -= 1
n <= 0 && return
end
end
end

update!(agent, model) = agent.status == :I && (agent.days_infected += 1)

function recover_or_die!(agent, model)
if agent.days_infected model.infection_period
if rand(model.rng) model.death_rate
@a kill_agent!(agent, model)
else
agent.status = :R
agent.days_infected = 0
end
end
sir_migrate!(agent, model)
sir_transmit!(agent, model)
sir_update!(agent, model)
sir_recover_or_die!(agent, model)
end

@testset "Agents.jl and AlgebraicAgents.jl solution are equal" begin
Random.seed!(1)
abm_algebraic_wrap = ABMAgent("sir_model", abm_algebraic; agent_step!,
abm_algebraic_wrap = ABMAgent("sir_model", abm_algebraic;
tspan = (0.0, 10.0), adata = to_collect)
simulate(abm_algebraic_wrap)
data_algebraic = abm_algebraic_wrap.df_agents

Random.seed!(1)
data_agent, _ = run!(abm_agents, agent_step, 10; adata = to_collect)
data_agent, _ = run!(abm_agents, 10; adata = to_collect)

@test abm_algebraic_wrap.t == 10.0
@test data_algebraic == data_agent
Expand All @@ -92,7 +51,7 @@ end
end

@testset "plotting for ABM wraps" begin
abm_algebraic_wrap = ABMAgent("sir_model", abm_algebraic; agent_step!,
abm_algebraic_wrap = ABMAgent("sir_model", abm_algebraic;
tspan = (0.0, 10.0), adata = to_collect)
simulate(abm_algebraic_wrap)

Expand Down
2 changes: 1 addition & 1 deletion tutorials/agents/agents.jl
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ to_collect = [(:status, f) for f in (infected, recovered, length)]

# We wrap the model as an agent:

m = ABMAgent("sir_model", abm; agent_step!, tspan=(0., 100.), adata=to_collect)
m = ABMAgent("sir_model", abm; tspan=(0., 100.), adata=to_collect)

# And we simulate the dynamics:

Expand Down

0 comments on commit 33aa338

Please sign in to comment.