Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

don't collect observed ArraySymbolics #81

Open
wants to merge 4 commits into
base: master
Choose a base branch
from

Conversation

hexaeder
Copy link

@hexaeder hexaeder commented May 30, 2024

If is_observed returns true for ArraySymbolic, forward directly to observed instead of collecting first.

Checklist

  • Appropriate tests were added
  • Any code changes were done in a way that does not break public API
  • All documentation related to code changes were updated
  • The new code follows the
    contributor guidelines, in particular the SciML Style Guide and
    COLPRAC.
  • Any new documentation only uses public API

If `ArraySymbolic` `is_observed` forward directly to `observed` instead
of collecting first.
Copy link
Member

@AayushSabharwal AayushSabharwal left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good! Just needs a minor change and some tests. Thanks for the PR

src/state_indexing.jl Outdated Show resolved Hide resolved
@hexaeder
Copy link
Author

I removed the tuple stuff and added some tests to ensure that ArraySymbolics are forwarede "as is" without collecting if is_(parameter/variable/observed) returns true.

@AayushSabharwal
Copy link
Member

I just ran into some edge case that is tested in SciMLBase downstream in my own work. It will be broken by the changes in this PR. Could you refactor the check in _getp so that instead of checking is_observed(sys, p) it checks any(x -> is_observed(sys, x), collect(p))? Also, it looks like the downstream environment needs to cap SymbolicUtils to <1.6. Please add that here as well so that CI runs.

Copy link

codecov bot commented Jun 3, 2024

Codecov Report

Attention: Patch coverage is 0% with 5 lines in your changes are missing coverage. Please review.

Project coverage is 27.79%. Comparing base (71408bb) to head (47c85fe).
Report is 4 commits behind head on master.

Current head 47c85fe differs from pull request most recent head d0c98ba

Please upload reports for the commit d0c98ba to get more accurate results.

Files Patch % Lines
src/state_indexing.jl 0.00% 5 Missing ⚠️
Additional details and impacted files
@@             Coverage Diff             @@
##           master      #81       +/-   ##
===========================================
- Coverage   86.15%   27.79%   -58.37%     
===========================================
  Files          11       11               
  Lines         513      518        +5     
===========================================
- Hits          442      144      -298     
- Misses         71      374      +303     
Flag Coverage Δ
docs 27.79% <0.00%> (-0.72%) ⬇️

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@hexaeder
Copy link
Author

hexaeder commented Jun 3, 2024

Could you refactor the check in _getp so that instead of checking is_observed(sys, p) it checks any(x -> is_observed(sys, x), collect(p))?

I guess you're referring to _getu as I am not altering _getp at all? I think changing to any(x -> is_observed(sys, x), collect(p)) is kinda the opposite of what this PR is trying to achieve. Without this PR:

  • user requests ArraySymbolic sym,
    • if is_variable(sym) -> call variable_index(sym) without collecting
    • if is_parameter(sym) -> call parameter_index(sym) without collecting
    • else collect and try again on the ::NotSymbolic Array of ScalarSymbolic

My PR tries to apply the same behavior to observed states, meaning that is_observed(sym) returns true, call observed(sym) directly without collecting sym. If is_observed(sym) is false, it will collect and call it again on the collected array, which will then hit the code path which does more or less exactly what you suggested anyway

for (t1, t2) in [
(ScalarSymbolic, Any),
(ArraySymbolic, Any),
(NotSymbolic, Union{<:Tuple, <:AbstractArray})
]
@eval function _getu(sys, ::NotSymbolic, ::$t1, sym::$t2)
num_observed = count(x -> is_observed(sys, x), sym)
if num_observed <= 1
getters = getu.((sys,), sym)
return MultipleGetters(getters)
else
obs = observed(sys, sym isa Tuple ? collect(sym) : sym)
getter = if is_time_dependent(sys)
TimeDependentObservedFunction(obs)
else
TimeIndependentObservedFunction(obs)
end
if sym isa Tuple
getter = AsTupleWrapper(getter)
end
return getter
end
end
end

So I am not sure what happens in your edge cases, but maybe is_observed does not fall back to some is_observed(_::ArraySymbolic) = false even though observed cannot handle the ArraySymbolic?

@AayushSabharwal
Copy link
Member

The problem is that if you have an ODESystem with parameters p[1], p[2], p[3] (all scalarized, so they're stored as 3 scalars instead of a single vector) and try to do getu(sys, p[2:3]), is_parameter(sys, p[2:3]) is false, but is_observed(sys, p[2:3]) is true, which means that it will generate a function for p[2:3]. These generated functions break type inference, so the operation becomes type-unstable. It's also unnecessarily expensive when in reality this should just be a couple of getindex calls.

I'll check this case with your PR shortly, since it might just not break inference in which case it's mostly fine.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants