Skip to content

Commit

Permalink
max sites refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
gottacatchenall committed May 7, 2024
1 parent 9961cc0 commit db55da1
Show file tree
Hide file tree
Showing 19 changed files with 197 additions and 166 deletions.
2 changes: 1 addition & 1 deletion docs/src/vignettes/entropize.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,5 +30,5 @@ pixel scale:
```@example 1
U = entropize(measurements)
locations =
seed(BalancedAcceptance(; numpoints = 100, uncertainty=U))
seed(BalancedAcceptance(; numsites = 100, uncertainty=U))
```
8 changes: 4 additions & 4 deletions docs/src/vignettes/overview.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ less) uncertainty. To start with, we will extract 200 candidate points, *i.e.*


```@example 1
candidates = seed(BalancedAcceptance(; numpoints = 200));
candidates = seed(BalancedAcceptance(; numsites = 200));
```

We can have a look at the first five points:
Expand All @@ -47,7 +47,7 @@ case, `AdaptiveSpatial`, which performs adaptive spatial sampling (maximizing
the distribution of entropy while minimizing spatial auto-correlation).

```@example 1
locations = refine(candidates, AdaptiveSpatial(; numpoints = 50, uncertainty=U))
locations = refine(candidates, AdaptiveSpatial(; numsites = 50, uncertainty=U))
locations[1:5]
```

Expand All @@ -64,8 +64,8 @@ functions have a curried version that allows chaining them together using pipes

```@example 1
locations =
seed(BalancedAcceptance(; numpoints = 200)) |>
refine(AdaptiveSpatial(; numpoints = 50, uncertainty=U))
seed(BalancedAcceptance(; numsites = 200)) |>
refine(AdaptiveSpatial(; numsites = 50, uncertainty=U))
```

This works because `seed` and `refine` have curried versions that can be used
Expand Down
4 changes: 2 additions & 2 deletions docs/src/vignettes/uniqueness.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,13 @@ heatmap(uncert)
Now we'll get a set of candidate points from a BalancedAcceptance seeder that has no bias toward higher uncertainty values.

```@example 1
candpts = seed(BalancedAcceptance(numpoints=100));
candpts = seed(BalancedAcceptance(numsites=100));
```

Now we'll `refine` our `100` candidate points down to the 30 most environmentally unique.

```@example 1
finalpts = refine(candpts, Uniqueness(;numpoints=30, layers=layers))
finalpts = refine(candpts, Uniqueness(;numsites=30, layers=layers))
heatmap(uncert)
scatter!([p[1] for p in candpts], [p[2] for p in candpts], color=:white)
scatter!([p[1] for p in finalpts], [p[2] for p in finalpts], color=:dodgerblue, msc=:white)
Expand Down
49 changes: 36 additions & 13 deletions src/adaptivespatial.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,30 +3,32 @@
...
**numpoints**, an Integer (def. 50), specifying the number of points to use.
**numsites**, an Integer (def. 50), specifying the number of points to use.
"""
Base.@kwdef mutable struct AdaptiveSpatial{T <: Integer, F<: AbstractFloat} <: BONRefiner
numpoints::T = 30
numsites::T = 30
uncertainty::Array{F,2} = rand(50,50)
function AdaptiveSpatial(numpoints, uncertainty)
if numpoints < one(numpoints)
throw(
ArgumentError(
"You cannot have an AdaptiveSpatial with fewer than one point",
),
)
end
return new{typeof(numpoints), typeof(uncertainty[begin])}(numpoints, uncertainty)
function AdaptiveSpatial(numsites, uncertainty)
as = new{typeof(numsites), typeof(uncertainty[begin])}(numsites, uncertainty)
check_arguments(as)
return as
end
end

function check_arguments(as::AdaptiveSpatial)
check(TooFewSites, as)

max_num_sites = prod(size(as.uncertainty))
check(TooManySites, as, max_num_sites)
end

function _generate!(
coords::Vector{CartesianIndex},
pool::Vector{CartesianIndex},
sampler::AdaptiveSpatial,
)
# Distance matrix (inlined)
d = zeros(Float64, Int((sampler.numpoints * (sampler.numpoints - 1)) / 2))
d = zeros(Float64, Int((sampler.numsites * (sampler.numsites - 1)) / 2))

# Start with the point with maximum entropy
imax = last(findmax([uncertainty[i] for i in pool]))
Expand All @@ -36,7 +38,7 @@ function _generate!(
best_score = 0.0
best_s = 1

for i in 2:(sampler.numpoints)
for i in 2:(sampler.numsites)
for (ci, cs) in enumerate(pool)
coords[i] = cs
# Distance update
Expand Down Expand Up @@ -77,3 +79,24 @@ function _D(a1::T, a2::T) where {T <: CartesianIndex{2}}
return sqrt((x1 - x2)^2.0 + (y1 - y2)^2.0)
end


# ====================================================
#
# Tests
#
# =====================================================

@testitem "AdaptiveSpatial default constructor works" begin
@test typeof(AdaptiveSpatial()) <: AdaptiveSpatial
end

@testitem "AdaptiveSpatial has right subtypes" begin
@test AdaptiveSpatial <: BONRefiner
@test AdaptiveSpatial <: BONSampler
end

@testitem "AdaptiveSpatial requires positive number of sites" begin
@test_throws TooFewSites AdaptiveSpatial(numsites = 1)
@test_throws TooFewSites AdaptiveSpatial(numsites = 0)
@test_throws TooFewSites AdaptiveSpatial(numsites = -1)
end
27 changes: 12 additions & 15 deletions src/balancedacceptance.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,20 @@ A `BONSeeder` that uses Balanced-Acceptance Sampling (Van-dem-Bates et al. 2017
https://doi.org/10.1111/2041-210X.13003)
"""
Base.@kwdef struct BalancedAcceptance{I <: Integer} <: BONSeeder
numpoints::I = 30
numsites::I = 30
dims::Tuple{I, I} = (50, 50)
function BalancedAcceptance(numpoints, dims)
bas = new{typeof(numpoints)}(numpoints, dims)
function BalancedAcceptance(numsites, dims)
bas = new{typeof(numsites)}(numsites, dims)
check_arguments(bas)
return bas
end
end

maxsites(bas::BalancedAcceptance) = prod(bas.dims)

function check_arguments(bas::BalancedAcceptance)
check(TooFewSites, bas)
max_num_sites = prod(bas.dims)
return max_num_sites >= bas.numpoints || throw(
TooManySites(
"Number of sites to select $(bas.numpoints) is greater than number of possible sites $(max_num_sites)",
),
)
check(TooManySites, bas, maxsites(bas))
end

function _generate!(
Expand All @@ -48,9 +45,9 @@ end
end

@testitem "BalancedAcceptance requires positive number of sites" begin
@test_throws TooFewSites BalancedAcceptance(numpoints = 1)
@test_throws TooFewSites BalancedAcceptance(numpoints = 0)
@test_throws TooFewSites BalancedAcceptance(numpoints = -1)
@test_throws TooFewSites BalancedAcceptance(numsites = 1)
@test_throws TooFewSites BalancedAcceptance(numsites = 0)
@test_throws TooFewSites BalancedAcceptance(numsites = -1)
end

@testitem "BalancedAcceptance can't be run with too many sites" begin
Expand All @@ -65,7 +62,7 @@ end
coords = seed(bas)

@test typeof(coords) <: Vector{CartesianIndex}
@test length(coords) == bas.numpoints
@test length(coords) == bas.numsites
end

@testitem "BalancedAcceptance can generate a custom number of points as positional argument" begin
Expand All @@ -78,6 +75,6 @@ end

@testitem "BalancedAcceptance can take number of points as keyword argument" begin
N = 40
bas = BalancedAcceptance(; numpoints = N)
@test bas.numpoints == N
bas = BalancedAcceptance(; numsites = N)
@test bas.numsites == N
end
33 changes: 19 additions & 14 deletions src/cubesampling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,31 +5,36 @@ A `BONRefiner` that uses Cube Sampling (Tillé 2011)
...
**numpoints**, an Integer (def. 50), specifying the number of points to use.
**numsites**, an Integer (def. 50), specifying the number of points to use.
**fast**, a Boolean (def. true) indicating whether to use the fast flight algorithm.
**x**, a Matrix of auxillary variables for the candidate points, with one row for each variable and one column for each candidate point.
**πₖ**, a Float Vector indicating the probabilities of inclusion for each candidate point; should sum to numpoints value.
**πₖ**, a Float Vector indicating the probabilities of inclusion for each candidate point; should sum to numsites value.
"""

Base.@kwdef struct CubeSampling{I <: Integer, M <: Matrix, V <: Vector} <: BONRefiner
numpoints::I = 50
numsites::I = 50
fast::Bool = true
x::M = rand(0:4, 3, 50)
πₖ::V = zeros(size(x, 2))
function CubeSampling(numpoints, fast, x, πₖ)
cs = new{typeof(numpoints), typeof(x), typeof(πₖ)}(numpoints, fast, x, πₖ)
function CubeSampling(numsites, fast, x, πₖ)
cs = new{typeof(numsites), typeof(x), typeof(πₖ)}(numsites, fast, x, πₖ)
_check_arguments(cs)

return cs
end
end

function check_arguments(cs::CubeSampling)
check(TooFewSites, cs)
if numpoints > length(πₖ)
numsites(cubesampling::CubeSampling) = cubesampling.numsites
maxsites(cubesampling::CubeSampling) = size(cubesampling.x, 2)

function check_arguments(cubesampling::CubeSampling)
check(TooFewSites, cubesampling)
check(TooManySites, maxsites(cubesampling))

if numsites > length(πₖ)
throw(
ArgumentError(
"You cannot select more points than the number of candidate points.",
Expand All @@ -49,10 +54,10 @@ function check_conditions(coords, pool, sampler)
πₖ = sampler.πₖ
if sum(sampler.πₖ) == 0
@info "Probabilities of inclusion were not provided, so we assume equal probability design."
πₖ = [sampler.numpoints / length(pool) for _ in eachindex(pool)]
πₖ = [sampler.numsites / length(pool) for _ in eachindex(pool)]
end
if round(Int, sum(πₖ)) != sampler.numpoints
@warn "The inclusion probabilities sum to $(round(Int, sum(πₖ))), which will be your sample size instead of numpoints."
if round(Int, sum(πₖ)) != sampler.numsites
@warn "The inclusion probabilities sum to $(round(Int, sum(πₖ))), which will be your sample size instead of numsites."
end
if length(pool) != length(πₖ)
throw(
Expand Down Expand Up @@ -442,7 +447,7 @@ end
# =====================================================

@testitem "CubeSampling throws exception with too few points" begin
@test_throws TooFewSites CubeSampling(numpoints = -1)
@test_throws TooFewSites CubeSampling(numpoints = 0)
@test_throws TooFewSites CubeSampling(numpoints = 1)
@test_throws TooFewSites CubeSampling(numsites = -1)
@test_throws TooFewSites CubeSampling(numsites = 0)
@test_throws TooFewSites CubeSampling(numsites = 1)
end
12 changes: 7 additions & 5 deletions src/exceptions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,19 @@ Base.showerror(io::IO, e::E) where {E <: BONException} =
)

function _check_arguments(sampler::S) where {S <: Union{BONSeeder, BONRefiner}}
return sampler.numpoints > 1 || throw(TooFewSites(sampler.numpoints))
end

function check(TooFewSites, sampler)
return sampler.numpoints > 1 || throw(TooFewSites(sampler.numpoints))
return sampler.numsites > 1 || throw(TooFewSites(sampler.numsites))
end

@kwdef struct TooFewSites <: BONException
message = "Number of sites to select must be at least two."
end
function check(TooFewSites, sampler)
return sampler.numsites > 1 || throw(TooFewSites())
end

@kwdef struct TooManySites <: BONException
message = "Cannot select more sites than there are candidates."
end
function check(TooManySites, sampler, max_sites)
return sampler.numsites <= max_sites || throw(TooManySites())
end
12 changes: 7 additions & 5 deletions src/fractaltriad.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
Base.@kwdef struct FractalTriad{IT <: Integer, FT <: AbstractFloat} <: SpatialSampler
numpoints::IT = 50
padding::FT = 0.1
Base.@kwdef struct FractalTriad{I <: Integer, F <: AbstractFloat} <: SpatialSampler
numsites::I = 50
horizontal_padding::F = 0.1
vetical_padding::F = 0.1
dims::Tuple{I, I}
end

function _generate!(ft::FractalTriad, sdm::M) where {M <: AbstractMatrix}
response = zeros(ft.numpoints, 2)
function _generate!(ft::FractalTriad)
response = zeros(ft.numsites, 2)

return response
end
14 changes: 7 additions & 7 deletions src/refine.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@ function refine!(
pool::Vector{CartesianIndex},
sampler::ST,
) where {ST <: BONRefiner}
if length(coords) != sampler.numpoints
if length(coords) != sampler.numsites
throw(
DimensionMismatch(
"The length of the coordinate vector must match the `numpoints` fields of the sampler",
"The length of the coordinate vector must match the `numsites` fields of the sampler",
),
)
end
Expand All @@ -33,10 +33,10 @@ The curried version of `refine!`, which returns a function that acts on the inpu
coordinate pool passed to the curried function (`p` below).
"""
function refine!(coords::Vector{CartesianIndex}, sampler::ST) where {ST <: BONRefiner}
if length(coords) != sampler.numpoints
if length(coords) != sampler.numsites
throw(
DimensionMismatch(
"The length of the coordinate vector must match the `numpoints` fields of the sampler",
"The length of the coordinate vector must match the `numsites` fields of the sampler",
),
)
end
Expand All @@ -46,14 +46,14 @@ end
"""
refine(pool::Vector{CartesianIndex}, sampler::ST)
Refines a set of candidate sampling locations and returns a vector `coords` of length numpoints
Refines a set of candidate sampling locations and returns a vector `coords` of length numsites
from a vector of coordinates `pool` using `sampler`, where `sampler` is a [`BONRefiner`](@ref).
"""
function refine(
pool::Vector{CartesianIndex},
sampler::ST,
) where {ST <: BONRefiner}
coords = Vector{CartesianIndex}(undef, sampler.numpoints)
coords = Vector{CartesianIndex}(undef, sampler.numsites)
return refine!(coords, copy(pool), sampler)
end

Expand All @@ -63,7 +63,7 @@ end
Returns a curried function of `refine`
"""
function refine(sampler::ST) where {ST <: BONRefiner}
coords = Vector{CartesianIndex}(undef, sampler.numpoints)
coords = Vector{CartesianIndex}(undef, sampler.numsites)
_inner(p) = refine!(coords, first(p), sampler)
return _inner
end
8 changes: 4 additions & 4 deletions src/seed.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@ function seed!(
coords::Vector{CartesianIndex},
sampler::ST,
) where {ST <: BONSeeder}
length(coords) != sampler.numpoints &&
length(coords) != sampler.numsites &&
throw(
DimensionMismatch(
"The length of the coordinate vector must match the `numpoints` fiel s of the sampler",
"The length of the coordinate vector must match the `numsites` fiel s of the sampler",
),
)
return _generate!(coords, sampler)
Expand All @@ -22,10 +22,10 @@ end
"""
seed(sampler::ST)
Produces a set of candidate sampling locations in a vector `coords` of length numpoints
Produces a set of candidate sampling locations in a vector `coords` of length numsites
from a raster using `sampler`, where `sampler` is a [`BONSeeder`](@ref).
"""
function seed(sampler::ST) where {ST <: BONSeeder}
coords = Vector{CartesianIndex}(undef, sampler.numpoints)
coords = Vector{CartesianIndex}(undef, sampler.numsites)
return seed!(coords, sampler)
end
Loading

0 comments on commit db55da1

Please sign in to comment.