More Efficient Way to Apply a List of Functions in JAX? #23306

Answered by jakevdp
helpingstar asked this question in Q&A

If your goal is to unconditionally apply all functions in sequence, the most efficient way to do so would be with a normal for loop:

for func in func_list:
  x = func(x)
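
Here is a minimal, self-contained sketch of the plain-loop approach. It assumes a hypothetical `func_list` of three elementwise functions, chosen only for illustration. Under `jax.jit`, the Python loop runs once while the function is traced. The compiled program is therefore just the functions applied in order, with no loop or dispatch overhead at run time.

```python
import jax
import jax.numpy as jnp

# Hypothetical list of functions, for illustration only; any callables that
# map an array to an array would work here.
func_list = [jnp.sin, jnp.cos, lambda x: x * 2.0]

@jax.jit
def apply_all(x):
    # Plain Python loop: it executes during tracing, so the traced program
    # is simply sin -> cos -> multiply-by-2 applied in sequence.
    for func in func_list:
        x = func(x)
    return x

x = jnp.arange(3.0)
print(apply_all(x))
```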

If you're in a more complicated context where you need to refer to the function via a traced index (such as fn_index in your fori_loop), then the lax.switch method you mentioned is the most efficient approach. Your attempt to use jnp.select didn't work because select can only choose between arrays, not between arbitrary objects like functions. Since you asked about documentation: the jnp.select docs state that choicelist must be a "sequence of array-like values".
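
The following is a rough sketch of dispatching on a traced index with lax.switch inside lax.fori_loop. The functions in `func_list` and the `body` helper are hypothetical stand-ins, not the code from the original question. Inside `fori_loop` the counter `i` is a traced value, so it cannot index a Python list directly. `lax.switch(index, branches, *operands)` takes that traced index and calls the selected branch.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Hypothetical branch functions; lax.switch requires every branch to take the
# same operands and return outputs with matching shapes and dtypes.
func_list = [jnp.sin, jnp.cos, lambda x: x * 2.0]

def body(i, x):
    # i is traced inside fori_loop, so func_list[i] would fail;
    # lax.switch selects the branch using the traced index instead.
    fn_index = i % len(func_list)
    return lax.switch(fn_index, func_list, x)

@jax.jit
def apply_all_dynamic(x):
    # Loop bounds are concrete Python ints; the carry x keeps its shape/dtype.
    return lax.fori_loop(0, len(func_list), body, x)

x = jnp.arange(3.0)
print(apply_all_dynamic(x))
```

With these particular branches, the result matches applying the functions in list order. Here the index is simply `i % len(func_list)`, but any traced integer could be passed to `lax.switch` in its place.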

If you want to apply the dynamically l…

Answer selected by helpingstar