You could pass the constants by closure instead; modifying your example, it might look something like this:

    import jax
    import jax.numpy as jnp

    def main_func(a_const, b_const, c_const, d_const, e_const, variable, array: jnp.ndarray):
        def compute_variables(variable, x):
            # add a dozen lines of logic here; the *_const values are read from the enclosing scope
            return (x + a_const - b_const * variable, variable)

        return jax.lax.scan(compute_variables, init=variable, xs=array)

The general approach is: you can reference …
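To make the closure suggestion concrete, here is a small runnable comparison of the two styles. It assumes a one-line update rule standing in for the real logic; `carry_everything` and `use_closure` are illustrative names, not anything from the thread. Both functions return the same values. When the constants are traced values (here, arguments seen by `jax.make_jaxpr`), the parameters of the resulting scan equation indicate that `lax.scan` already receives closed-over values as loop-invariant constants (`num_consts`) rather than as carry (`num_carry`), which is roughly what a dedicated `consts` argument would provide.

```python
import jax
import jax.numpy as jnp


def carry_everything(a_const, b_const, variable, array):
    # Constants threaded through the carry: they must be unpacked and
    # returned unchanged on every step.
    def step(carry, x):
        a, b, v = carry
        new_v = x + a - b * v
        return (a, b, new_v), v

    (_, _, final_v), ys = jax.lax.scan(step, (a_const, b_const, variable), array)
    return final_v, ys


def use_closure(a_const, b_const, variable, array):
    # Constants captured from the enclosing scope: only `v` is carried.
    def step(v, x):
        return x + a_const - b_const * v, v

    return jax.lax.scan(step, variable, array)


xs = jnp.arange(5.0)
out_carry = carry_everything(0.5, 0.1, 0.0, xs)
out_closure = use_closure(0.5, 0.1, 0.0, xs)

# Trace use_closure with abstract inputs and look at the scan equation:
# the closed-over a_const/b_const show up as consts, not as carry.
closed = jax.make_jaxpr(use_closure)(0.5, 0.1, 0.0, xs)
scan_eqn = next(e for e in closed.jaxpr.eqns if e.primitive.name == "scan")
print("num_consts:", scan_eqn.params["num_consts"])
print("num_carry:", scan_eqn.params["num_carry"])
```

With plain Python floats and no tracing, the constants are simply embedded as literals in the loop body, so nothing is carried for them either way.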
Sup,
I recently had to write a module with 25 `scan`s in a single file... Well, it was not so simple, because we had 5 variables in the carry, and most of them were constants. Minimal reproducible example here:
I believe in ML there are also hyperparameters which need to be passed around and are needlessly optimized for by `jit` in every line. If there are constants, XLA may not bother optimizing functions for them in the same way that `static_argnums` makes functions faster.
Would `scan` maybe get an optional `consts` parameter through which the constant parameters would be passed? It's a QOL feature which may also speed up `scan` and other loop computations a little bit. Is it big enough to be implemented?
(The same goes for `fori_loop` and `while_loop`.)
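As far as this thread shows, `jax.lax.scan`, `fori_loop` and `while_loop` have no `consts` argument, but the requested QOL can be approximated today with a thin wrapper. The helpers below (`scan_with_consts`, `fori_loop_with_consts`, `while_loop_with_consts`) and the convention of passing `consts` as the first argument of the loop body are hypothetical, chosen for this sketch; they are not JAX API. Internally they just bind `consts` with `functools.partial`, i.e. the same closure technique as in the reply.

```python
import functools

import jax
import jax.numpy as jnp


# Hypothetical helpers: the loop body receives `consts` as its first argument,
# so constants never have to be threaded through the carry by hand.
def scan_with_consts(f, init, xs, consts=()):
    # f(consts, carry, x) -> (new_carry, y)
    return jax.lax.scan(functools.partial(f, consts), init, xs)


def fori_loop_with_consts(lower, upper, body_fun, init_val, consts=()):
    # body_fun(consts, i, val) -> new_val
    return jax.lax.fori_loop(lower, upper, functools.partial(body_fun, consts), init_val)


def while_loop_with_consts(cond_fun, body_fun, init_val, consts=()):
    # cond_fun(consts, val) -> bool, body_fun(consts, val) -> new_val
    return jax.lax.while_loop(
        functools.partial(cond_fun, consts), functools.partial(body_fun, consts), init_val
    )


def step(consts, v, x):
    a, b = consts
    return x + a - b * v, v


def accumulate(consts, i, total):
    scale, offset = consts
    return total + scale * i + offset


def below_limit(consts, val):
    limit, _ = consts
    return val < limit


def add_step(consts, val):
    _, step_size = consts
    return val + step_size


final_v, ys = scan_with_consts(step, 0.0, jnp.arange(5.0), consts=(0.5, 0.1))
total = fori_loop_with_consts(0, 4, accumulate, 0.0, consts=(2.0, 1.0))
reached = while_loop_with_consts(below_limit, add_step, 0.0, consts=(10.0, 3.0))
print(final_v, ys, total, reached)
```

Since the wrapper only captures values, it removes the carry boilerplate but is not evidence that a built-in `consts` parameter would compile to anything faster than the closure version.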