More Efficient Way to Apply a List of Functions in JAX? #23306
The code below is a minimized version of my problem (the actual function contents are different). I want to obtain the result of sequentially applying a list (or tuple) of functions to a variable. Initially, I tried using …; afterward, I used:

import jax
import jax.numpy as jnp

@jax.jit
def get_result(x):
    fn_list = (
        lambda x: x + 1,
        lambda x: x + 2,
        lambda x: x + 3,
    )

    def add_number(fn_index, x):
        fn = jnp.select([fn_index == 0, fn_index == 1, fn_index == 2], fn_list)
        return fn(x)

    result = jax.lax.fori_loop(0, len(fn_list), add_number, x)
    return result

However, this did not work. (You don't need to explain the cause of this error, but I would appreciate it if you could point me to any relevant documentation for reference.)

So, I used:

import jax

@jax.jit
def get_result(x):
    fn_list = (
        lambda x: x + 1,
        lambda x: x + 2,
        lambda x: x + 3,
    )

    def add_number(fn_index, x):
        return jax.lax.switch(fn_index, fn_list, x)

    result = jax.lax.fori_loop(0, len(fn_list), add_number, x)
    return result

This works well. However, in my case, the function corresponding to the index in the add_number part is used multiple times. The input for … In other words, I have to write …

Is there any other approach I could use in this situation? I want to store the function corresponding to the index during each iteration so that it can be reused in …
If your goal is to unconditionally apply all functions in sequence, the most efficient way to do so would be with a normal for loop:

for func in func_list:
    x = func(x)

If you're in a more complicated context where you need to refer to the function via a traced index (such as fn_index in your fori_loop), then the lax.switch method you mentioned is the most efficient approach. Your attempt to use jnp.select didn't work because select can only select between arrays, not between arbitrary objects like functions. (Since you asked about documentation: the jnp.select docs state that choicelist must be a "sequence of array-like values".)

If you want to apply the dynamically looked-up function multiple times within a loop step, the most convenient way to do so would be to partially evaluate jax.lax.switch with functools.partial:

import functools
import jax

# Inside the loop body, where fn_index is the traced loop index:
func = functools.partial(jax.lax.switch, fn_index, fn_list)
result1 = func(x)
result2 = func(y)

Then you can use …
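A minimal runnable sketch of the plain for loop approach, reusing the three placeholder lambdas from the question; FN_LIST and apply_all are hypothetical names chosen for illustration. Under jax.jit the Python loop runs while the function is being traced, so the calls are unrolled into a single compiled computation and no traced index is needed. As a general property of tracing, the compiled program grows with the length of the list, which is worth keeping in mind for very long lists.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-ins for the real functions (same as in the question).
FN_LIST = (
    lambda x: x + 1,
    lambda x: x + 2,
    lambda x: x + 3,
)


@jax.jit
def apply_all(x):
    # This Python loop executes at trace time, so the three calls are
    # unrolled into one traced computation.
    for fn in FN_LIST:
        x = fn(x)
    return x


print(float(apply_all(jnp.float32(0.0))))  # prints 6.0
```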
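The point about jnp.select can also be seen from the other side: select works once every choice is an array. The sketch below is a swapped-in workaround, not something suggested in the reply, and get_result_select is a hypothetical name. It applies every function to the current value and then selects among the results. It runs, but it computes all branches at every step, so it generally does more work than lax.switch, which executes only the selected branch when the index is not batched.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-ins for the real functions (same as in the question).
FN_LIST = (
    lambda x: x + 1,
    lambda x: x + 2,
    lambda x: x + 3,
)


@jax.jit
def get_result_select(x):
    def add_number(fn_index, x):
        # jnp.select's choicelist must hold arrays, so evaluate every
        # function first and select among the resulting arrays.
        candidates = [fn(x) for fn in FN_LIST]
        conditions = [fn_index == i for i in range(len(FN_LIST))]
        return jnp.select(conditions, candidates)

    return jax.lax.fori_loop(0, len(FN_LIST), add_number, x)


print(float(get_result_select(jnp.float32(0.0))))  # prints 6.0
```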
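The functools.partial suggestion can be placed in a complete program. This sketch assumes a second, hypothetical input y alongside x, carries the pair through fori_loop, and applies the looked-up function to both values inside each step; FN_LIST and the two-argument get_result are illustrative assumptions, not code from the thread. Note that func does not store a concrete function: each call emits its own lax.switch on the same traced index.

```python
import functools

import jax
import jax.numpy as jnp

# Hypothetical stand-ins for the real functions (same as in the question).
FN_LIST = (
    lambda x: x + 1,
    lambda x: x + 2,
    lambda x: x + 3,
)


@jax.jit
def get_result(x, y):
    def body(fn_index, carry):
        x_val, y_val = carry
        # Bind the traced index and the branch list once; each call to
        # func then dispatches to the branch selected by fn_index.
        func = functools.partial(jax.lax.switch, fn_index, FN_LIST)
        return func(x_val), func(y_val)

    return jax.lax.fori_loop(0, len(FN_LIST), body, (x, y))


x_out, y_out = get_result(jnp.float32(0.0), jnp.float32(10.0))
print(float(x_out), float(y_out))  # prints 6.0 16.0
```

With x = 0 and y = 10, the three loop steps add 1, 2 and 3 to each value, giving 6 and 16.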