diff --git a/src/parameter_indexing.jl b/src/parameter_indexing.jl index de2d6ca..f1fac3f 100644 --- a/src/parameter_indexing.jl +++ b/src/parameter_indexing.jl @@ -609,7 +609,7 @@ struct MultipleSetters{S} <: AbstractSetIndexer end function (ms::MultipleSetters)(prob, val) - map((s!, v) -> s!(prob, v), ms.setters, val) + broadcast((s!, v) -> s!(prob, v), ms.setters, val) end for (t1, t2) in [ diff --git a/test/parameter_indexing_test.jl b/test/parameter_indexing_test.jl index 5b392ab..f71a157 100644 --- a/test/parameter_indexing_test.jl +++ b/test/parameter_indexing_test.jl @@ -54,7 +54,8 @@ for sys in [ ([1, [:b, :c]], [p[1], p[2:3]], [new_p[1], new_p[2:3]], false), ([1, (:b, :c)], [p[1], (p[2], p[3])], [new_p[1], (new_p[2], new_p[3])], false), ((1, [:b, :c]), (p[1], p[2:3]), (new_p[1], new_p[2:3]), true), - ((1, (:b, :c)), (p[1], (p[2], p[3])), (new_p[1], (new_p[2], new_p[3])), true) + ((1, (:b, :c)), (p[1], (p[2], p[3])), (new_p[1], (new_p[2], new_p[3])), true), + ([:a, :b], p[1:2], 42, true) ] get = getp(sys, sym) set! = setp(sys, sym) @@ -77,13 +78,13 @@ for sys in [ end @test fi.counter[] == 1 - @test get(fi) == newval + @test all(get(fi) .== newval) set!(fi, oldval) @test get(fi) == oldval @test fi.counter[] == 2 fi.ps[sym] = newval - @test get(fi) == newval + @test all(get(fi) .== newval) @test fi.counter[] == 3 fi.ps[sym] = oldval @test get(fi) == oldval @@ -98,7 +99,7 @@ for sys in [ else set!(p, newval) end - @test get(p) == newval + @test all(get(p) .== newval) set!(p, oldval) @test get(p) == oldval @test fi.counter[] == 4 @@ -150,6 +151,12 @@ end Base.getindex(mpo::MyParameterObject, i) = mpo.p[i] +# check throws if setp dimensions do not match +sys = SymbolCache([:x, :y, :z], [:a, :b, :c, :d], [:t]) +fi = FakeIntegrator(sys, [1.0, 2.0, 3.0], 0.0, Ref(0)) +@test_throws DimensionMismatch setp(fi, 1:2)(fi, [-1.0, -2.0, -3.0]) +@test_throws DimensionMismatch setp(fi, 1:3)(fi, [-1.0, -2.0]) + struct FakeSolution sys::SymbolCache u::Vector{Vector{Float64}}