
Optionally Store Initial Scan Carry #14692

The scanned function returns a tuple (carry, out), where carry and out do not need to have the same type. In your case, a keep_init=True option would only make sense when they do, because otherwise there is no obvious way to prepend the initial carry to the stacked outputs.
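A minimal sketch of such a mismatch. The step function and values are made up for illustration. The carry is a scalar running total, but each step emits a pair. A keep_init-style flag would have no initial value to put in front of the second output.

```python
import jax
import jax.numpy as jnp

# Illustrative step: the carry is a scalar running total, while the per-step
# output is a pair (running total, input doubled). The output structure
# differs from the carry, so "prepend the initial carry" is not well defined.
def step(total, x):
    new_total = total + x
    return new_total, (new_total, 2 * x)

xs = jnp.arange(1, 5)            # [1, 2, 3, 4]
init = jnp.zeros((), xs.dtype)   # scalar carry with the same dtype as xs
final_total, (totals, doubled) = jax.lax.scan(step, init, xs)
```

Here totals has one entry per input (1, 3, 6, 10) and the initial 0 is not included.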

If this pattern comes up often, I would write a small helper just for this purpose:

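import jax
import jax.numpy as jnp

def unfold(f, initial, length):
    def step(x, _):
        nx = f(x)
        return nx, nx  # new state is both the carry and the output
    # Run length - 1 steps; `initial` supplies the first element.
    _, xs = jax.lax.scan(step, initial, None, length=length - 1)
    # Add a leading axis to `initial` so it concatenates with the stacked outputs.
    return jnp.concatenate([initial[None], xs])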
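The unfold helper above concatenates a single array. If the state is a pytree, for example a tuple of arrays, the same idea can be applied leaf by leaf. Below is a sketch under that assumption. The name unfold_tree and the move example are hypothetical and not part of JAX. It prepends the initial state to each leaf of the scan output with jax.tree_util.tree_map.

```python
import jax
import jax.numpy as jnp

# Hypothetical pytree variant of `unfold`. Assumes f maps a state to a new
# state with the same tree structure, shapes and dtypes (as lax.scan requires).
def unfold_tree(f, initial, length):
    def step(x, _):
        nx = f(x)
        return nx, nx  # carry and output are the same new state
    _, xs = jax.lax.scan(step, initial, None, length=length - 1)
    # Prepend the initial state to every leaf along the leading (time) axis.
    return jax.tree_util.tree_map(
        lambda init, rest: jnp.concatenate([init[None], rest]), initial, xs
    )

# Illustrative usage: (position, velocity) under constant velocity, 4 states.
def move(state):
    pos, vel = state
    return pos + vel, vel

positions, velocities = unfold_tree(
    move, (jnp.float32(0.0), jnp.float32(1.5)), length=4
)
```

With these inputs, positions should be [0.0, 1.5, 3.0, 4.5]. The first entry is the initial state, and the remaining three come from the scan.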

Answer selected by zombie-einstein
Category
Ideas
2 participants