Skip to content

Commit

Permalink
Allow Non-Vector Input for Infinite Parameter Supports (#358)
Browse files Browse the repository at this point in the history
* initial changes

* more changes

* finalize generalization of supports specification

* doctest update
  • Loading branch information
pulsipher authored Jul 30, 2024
1 parent c1f2545 commit c64333f
Show file tree
Hide file tree
Showing 6 changed files with 257 additions and 127 deletions.
4 changes: 2 additions & 2 deletions docs/src/guide/derivative.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
```@meta
DocTestFilters = [r"≥|>=", r" == | = ", r" ∈ | in ", r" for all | ∀ ", r"d|∂",
r"integral|∫", r".*scalar_parameters.jl:785"]
r"integral|∫", r".*scalar_parameters.jl:813"]
```

# [Derivative Operators](@id deriv_docs)
Expand Down Expand Up @@ -503,7 +503,7 @@ julia> derivative_constraints(d1)
julia> add_supports(t, 0.2)
┌ Warning: Support/method changes will invalidate existing derivative evaluation constraints that have been added to the InfiniteModel. Thus, these are being deleted.
└ @ InfiniteOpt ~/work/infiniteopt/InfiniteOpt.jl/src/scalar_parameters.jl:785
└ @ InfiniteOpt ~/work/infiniteopt/InfiniteOpt.jl/src/scalar_parameters.jl:813
julia> has_derivative_constraints(d1)
false
Expand Down
135 changes: 98 additions & 37 deletions src/array_parameters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -156,23 +156,46 @@ end

## Process the supports format via dispatch
# Vector{<:Real}
function _process_supports(
function _process_array_supports(
_error::Function,
supps::Vector{<:Real},
domain,
sig_digits
)
supps = round.(supps, sigdigits = sig_digits)
if !supports_in_domain(reshape(supps, length(supps), 1), domain)
_error("Support violates the infinite domain.")
end
supps = round.(supps, sigdigits = sig_digits)
return DataStructures.OrderedDict{Vector{Float64}, Set{DataType}}(
supps => Set([UserDefined])
)
end

# Vector{Matrix{<:Real}}
function _process_array_supports(
_error::Function,
vect_supps::Vector{<:Matrix{<:Real}},
domain,
sig_digits
)
supps = first(vect_supps)
if any(arr != supps for arr in vect_supps)
_error("Cannot specify a matrix of supports for individual infinite parameters.")
elseif size(supps, 1) != length(vect_supps)
_error("Matrix of supports does not match the dimension of infinite parameters. " *
"Ensure the number of rows is equal to the number of parameters.")
end
rounded_supps = round.(supps, sigdigits = sig_digits)
if !supports_in_domain(rounded_supps, domain)
_error("Supports violate the infinite domain.")
end
return DataStructures.OrderedDict{Vector{Float64}, Set{DataType}}(
s => Set([UserDefined]) for s in eachcol(rounded_supps)
)
end

# Vector{Vector{<:Real}}
function _process_supports(
function _process_array_supports(
_error::Function,
vect_supps::Vector{<:Vector{<:Real}},
domain,
Expand All @@ -183,15 +206,35 @@ function _process_supports(
_error("Inconsistent support dimensions.")
end
supps = permutedims(reduce(hcat, vect_supps))
supps = round.(supps, sigdigits = sig_digits)
if !supports_in_domain(supps, domain)
_error("Supports violate the infinite domain.")
end
supps = round.(supps, sigdigits = sig_digits)
return DataStructures.OrderedDict{Vector{Float64}, Set{DataType}}(
s => Set([UserDefined]) for s in eachcol(supps)
)
end

# Vector of other collections
function _process_array_supports(
_error::Function,
vect_supps::Vector{<:Union{UnitRange{T}, StepRange{T}, StepRangeLen{T}, NTuple{N, T}, Base.Generator}},
domain,
sig_digits
) where {N, T <: Real}
return _process_array_supports(_error, collect.(vect_supps), domain, sig_digits)
end

# Fallback
function _process_array_supports(
_error::Function,
vect_supps,
domain,
sig_digits
)
error("Unrecognized support input. Please consult the docs.")
end

## Use dispatch to make the formatting of the derivative method vector
# Valid vector
function _process_derivative_methods(
Expand Down Expand Up @@ -222,7 +265,7 @@ function _build_parameters(
orig_inds::Collections.ContainerIndices;
num_supports::Int = 0,
sig_digits::Int = DefaultSigDigits,
supports::Union{Vector{<:Real}, Vector{<:Vector{<:Real}}} = Float64[],
supports = Float64[],
derivative_method::Vector = [],
extra_kwargs...
)
Expand All @@ -235,7 +278,7 @@ function _build_parameters(
domain = round_domain(domain, sig_digits)
# we have supports
if !isempty(supports)
supp_dict = _process_supports(_error, supports, domain, sig_digits)
supp_dict = _process_array_supports(_error, supports, domain, sig_digits)
# we want to generate supports
elseif !iszero(num_supports)
supps, label = generate_support_values(
Expand All @@ -257,10 +300,11 @@ function _build_parameters(
end

"""
add_parameters(model::InfiniteModel,
params::DependentParameters,
names::Vector{String}
)::Vector{GeneralVariableRef}
add_parameters(
model::InfiniteModel,
params::DependentParameters,
names::Vector{String}
)::Vector{GeneralVariableRef}
Add `params` to `model` and return an appropriate container of the dependent
infinite parameter references. This is intended as an internal method for use
Expand Down Expand Up @@ -988,8 +1032,10 @@ function significant_digits(pref::DependentParameterRef)
end

"""
num_supports(pref::DependentParameterRef;
[label::Type{<:AbstractSupportLabel} = PublicLabel])::Int
num_supports(
pref::DependentParameterRef;
[label::Type{<:AbstractSupportLabel} = PublicLabel]
)::Int
Return the number of support points associated with a single dependent infinite
parameter `pref`. Specify a subset of supports via `label` to only count the
Expand Down Expand Up @@ -1018,8 +1064,10 @@ function num_supports(
end

"""
num_supports(prefs::AbstractArray{<:DependentParameterRef};
[label::Type{<:AbstractSupportLabel} = PublicLabel])::Int
num_supports(
prefs::AbstractArray{<:DependentParameterRef};
[label::Type{<:AbstractSupportLabel} = PublicLabel]
)::Int
Return the number of support points associated with dependent infinite
parameters `prefs`. Errors if not all from the same underlying object.
Expand Down Expand Up @@ -1072,8 +1120,10 @@ function has_supports(prefs::AbstractArray{<:DependentParameterRef})
end

"""
supports(pref::DependentParameterRef;
[label::Type{<:AbstractSupportLabel} = PublicLabel])::Vector{Float64}
supports(
pref::DependentParameterRef;
[label::Type{<:AbstractSupportLabel} = PublicLabel]
)::Vector{Float64}
Return the support points associated with `pref`. A subset of supports can be
returned via `label` to return just the supports associated with `label`. By
Expand Down Expand Up @@ -1103,9 +1153,10 @@ function supports(
end

"""
supports(prefs::AbstractArray{<:DependentParameterRef};
[label::Type{<:AbstractSupportLabel} = PublicLabel]
)::Union{Vector{<:AbstractArray{<:Real}}, Array{Float64, 2}}
supports(
prefs::AbstractArray{<:DependentParameterRef};
[label::Type{<:AbstractSupportLabel} = PublicLabel]
)::Union{Vector{<:AbstractArray{<:Real}}, Array{Float64, 2}}
Return the support points associated with `prefs`. Errors if not all of the
infinite dependent parameters are from the same object. This will return a
Expand Down Expand Up @@ -1204,10 +1255,12 @@ function _make_support_matrix(
end

"""
set_supports(prefs::AbstractArray{<:DependentParameterRef},
supports::Vector{<:AbstractArray{<:Real}};
[force::Bool = false,
label::Type{<:AbstractSupportLabel} = UserDefined])::Nothing
set_supports(
prefs::AbstractArray{<:DependentParameterRef},
supports::Vector{<:AbstractArray{<:Real}};
[force::Bool = false,
label::Type{<:AbstractSupportLabel} = UserDefined]
)::Nothing
Specify the support points for `prefs`. Errors if the supports violate the domain
of the infinite domain, if the dimensions don't match up properly,
Expand All @@ -1217,10 +1270,12 @@ supports and `force = false`. Note that it is strongly preferred to use
`add_supports` if possible to avoid destroying measure dependencies.
```julia
set_supports(prefs::Vector{DependentParameterRef},
supports::Array{<:Real, 2};
[force::Bool = false,
label::Type{<:AbstractSupportLabel} = UserDefined])::Nothing
set_supports(
prefs::Vector{DependentParameterRef},
supports::Array{<:Real, 2};
[force::Bool = false,
label::Type{<:AbstractSupportLabel} = UserDefined]
)::Nothing
```
Specify the supports for a vector `prefs` of dependent infinite parameters.
Here rows of `supports` correspond to `prefs` and the columns correspond to the
Expand Down Expand Up @@ -1260,14 +1315,14 @@ function set_supports(
label::Type{<:AbstractSupportLabel} = UserDefined
)
domain = infinite_domain(prefs) # this does a check on prefs
supports = round.(supports, sigdigits = significant_digits(first(prefs)))
if has_supports(prefs) && !force
error("Unable set supports for $prefs since they already have supports." *
" Consider using `add_supports` or use `force = true` to " *
"overwrite the existing supports.")
elseif !supports_in_domain(supports, domain)
error("Supports violate the domain of the infinite domain.")
end
supports = round.(supports, sigdigits = significant_digits(first(prefs)))
_update_parameter_supports(prefs, supports, label)
return
end
Expand All @@ -1278,19 +1333,23 @@ function set_supports(pref::DependentParameterRef, supports; kwargs...)
end

"""
add_supports(prefs::AbstractArray{<:DependentParameterRef},
supports::Vector{<:AbstractArray{<:Real}};
[label::Type{<:AbstractSupportLabel} = UserDefined])::Nothing
add_supports(
prefs::AbstractArray{<:DependentParameterRef},
supports::Vector{<:AbstractArray{<:Real}};
[label::Type{<:AbstractSupportLabel} = UserDefined]
)::Nothing
Add additional support points for `prefs`. Errors if the supports violate the domain
of the infinite domain, if the dimensions don't match up properly,
if `prefs` and `supports` have different indices, or not all of the `prefs` are
from the same dependent infinite parameter container.
```julia
add_supports(prefs::Vector{DependentParameterRef},
supports::Array{<:Real, 2};
[label::Type{<:AbstractSupportLabel} = UserDefined])::Nothing
add_supports(
prefs::Vector{DependentParameterRef},
supports::Array{<:Real, 2};
[label::Type{<:AbstractSupportLabel} = UserDefined]
)::Nothing
```
Specify the supports for a vector `prefs` of dependent infinite parameters.
Here rows of `supports` correspond to `prefs` and the columns correspond to the
Expand Down Expand Up @@ -1335,10 +1394,10 @@ function add_supports(
check::Bool = true
)
domain = infinite_domain(prefs) # this does a check on prefs
supports = round.(supports, sigdigits = significant_digits(first(prefs)))
if check && !supports_in_domain(supports, domain)
error("Supports violate the domain of the infinite domain.")
end
supports = round.(supports, sigdigits = significant_digits(first(prefs)))
current_supports = _parameter_supports(first(prefs))
added_new_support = false
for i in 1:size(supports, 2)
Expand Down Expand Up @@ -1370,8 +1429,10 @@ function add_supports(pref::DependentParameterRef, supports; kwargs...)
end

"""
delete_supports(prefs::AbstractArray{<:DependentParameterRef};
[label::Type{<:AbstractSupportLabel} = All])::Nothing
delete_supports(
prefs::AbstractArray{<:DependentParameterRef};
[label::Type{<:AbstractSupportLabel} = All]
)::Nothing
Delete the support points for `prefs`. Errors if any of the parameters are
used by a measure or if not all belong to the same set of dependent parameters.
Expand Down
Loading

0 comments on commit c64333f

Please sign in to comment.