Optionally Store Initial Scan Carry #14692
which produces a series
Would it be worth adding a flag to the scan to retain the initial state as the first element of its stacked output? Something like
Replies: 1 comment 1 reply
The return value of the scanned function is a tuple of `(carry, out)`, where `carry` and `out` do not have to be the same type. I think in your case, `keep_init=True` only makes sense if they are the same type.

So I would write a helper function just for this purpose, if it occurs very often:

```python
import jax
import jax.numpy as jnp

def unfold(f, initial, length):
    # Each step emits the new carry, so xs holds states 1..length-1.
    def step(x, _):
        nx = f(x)
        return nx, nx
    _, xs = jax.lax.scan(step, initial, None, length=length - 1)
    # Give `initial` a leading axis so it can be prepended to the stacked states.
    return jnp.concatenate([jnp.asarray(initial)[None], xs])
```
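If the carry is a pytree (say, a tuple of arrays) rather than a single array, the same idea can probably be applied leaf by leaf. The following is only a sketch: `unfold_tree` is a hypothetical name, not part of JAX, and it assumes `f` returns a state with the same structure, shapes, and dtypes as `initial`, which `jax.lax.scan` requires of the carry anyway.

```python
import jax
import jax.numpy as jnp

def unfold_tree(f, initial, length):
    """Hypothetical pytree variant of `unfold` (illustrative, not a JAX API)."""
    def step(x, _):
        nx = f(x)
        return nx, nx

    # xs has the same tree structure as the carry, with a leading axis of length-1.
    _, xs = jax.lax.scan(step, initial, None, length=length - 1)
    # Prepend each leaf of `initial` to the matching stacked leaf of `xs`.
    return jax.tree_util.tree_map(
        lambda x0, stacked: jnp.concatenate([jnp.asarray(x0)[None], stacked]),
        initial,
        xs,
    )

# Example: a (position, velocity) state advanced with a fixed time step of 0.1.
state0 = (jnp.array(0.0), jnp.array(1.0))
positions, velocities = unfold_tree(lambda s: (s[0] + 0.1 * s[1], s[1]), state0, 4)
```

Here `positions` and `velocities` each have a leading axis of length 4, with the initial state in slot 0 followed by the three scanned states.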