Does JAX support a PRNG manager to ease the use of PRNG keys? #17999
It seems that JAX code requires splitting keys manually every time we need a new key. Does JAX have any functionality like the following, so that PRNG keys are easier to use?

```python
# Intended usage:
key = random.PRNGKey(0)  # or any other initial seed
with PRNGManager(key) as manager:
    output = random.normal(manager.new_key(), (3, 3))
    sequence = jnp.stack([random.normal(subkey) for subkey in manager.new_n_keys(5)])
```

Currently, we have to split keys manually and keep track of the current key:

```python
import jax.numpy as jnp
import jax.random as random

key = random.PRNGKey(0)  # or any other initial seed
key, new_key = random.split(key, 2)
output = random.normal(new_key, (3, 3))
key, *new_keys = random.split(key, 6)
sequence = jnp.stack([random.normal(subkey) for subkey in new_keys])
```

A possible implementation of `PRNGManager`:

```python
from typing import List

import jax.numpy as jnp
import jax.random as random


class PRNGManager:
    """Context manager that hands out fresh PRNG keys by splitting internally."""

    def __init__(self, key: jnp.ndarray):
        self.initial_key = key
        self.current_key = None

    def __enter__(self) -> "PRNGManager":
        # Start the key sequence from the initial seed key.
        self.current_key = self.initial_key
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # You can add any cleanup or handling here if needed.
        pass

    def new_key(self) -> jnp.ndarray:
        # Keep keys[0] as the new state and hand out keys[1].
        keys = random.split(self.current_key, 2)
        self.current_key = keys[0]
        return keys[1]

    def new_n_keys(self, n: int) -> List[jnp.ndarray]:
        # Split off n subkeys at once, keeping keys[0] as the new state.
        keys = random.split(self.current_key, n + 1)
        self.current_key = keys[0]
        return list(keys[1:])
```
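For comparison, here is a sketch of two alternatives (our own suggestion, not an official JAX API). In eager code, a plain Python generator such as the hypothetical `key_stream` below hides the split-and-carry bookkeeping much like `PRNGManager.new_key`. One caveat with any stateful manager, as we understand JAX's tracing model: if a manager created outside a `jax.jit`-compiled function is used inside it, the update to `current_key` runs only while tracing, so it is not repeated on cached calls, and the traced key can leak out of the trace. Inside jitted code it is therefore safer to thread the key explicitly, for example as the carry of `jax.lax.scan`, as in the hypothetical `sample_sequence`.

```python
import jax
import jax.numpy as jnp


def key_stream(key):
    """Yield fresh subkeys forever (hypothetical helper, not a JAX API).

    Same bookkeeping as `key, subkey = random.split(key)`, but the
    current key lives in the generator's local state.
    """
    while True:
        key, subkey = jax.random.split(key)
        yield subkey


@jax.jit
def sample_sequence(key):
    """Draw 5 normal samples inside jit by threading the key through scan."""

    def step(carry_key, _):
        # Split the carried key: one half continues, the other is consumed.
        carry_key, subkey = jax.random.split(carry_key)
        return carry_key, jax.random.normal(subkey)

    # xs=None with an explicit length runs `step` five times.
    final_key, samples = jax.lax.scan(step, key, None, length=5)
    return final_key, samples


# Eager usage: no manual key tracking at the call site.
keys = key_stream(jax.random.PRNGKey(0))
output = jax.random.normal(next(keys), (3, 3))
sequence = jnp.stack([jax.random.normal(next(keys)) for _ in range(5)])

# Jitted usage: the caller keeps the returned key for the next call.
key = jax.random.PRNGKey(0)
key, samples = sample_sequence(key)
```

Note that `key_stream` splits one key at a time, so its subkeys differ from those produced by `random.split(key, 6)` in the manual version; both are valid ways to derive independent keys, but they will not reproduce each other's numbers.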
Answered by jakevdp on Oct 7, 2023 (replies: 1 comment, 1 reply)
Seems to be a duplicate of #17998.
Answer selected by youkaichao