diff --git a/algorithmic_efficiency/spec.py b/algorithmic_efficiency/spec.py index 381d52f32..945ad5e36 100644 --- a/algorithmic_efficiency/spec.py +++ b/algorithmic_efficiency/spec.py @@ -79,7 +79,7 @@ def __eq__(self, other): nn.Module] ParameterTypeTree = Dict[ParameterKey, Dict[ParameterKey, ParameterType]] -RandomState = Any # Union[jax.random.PRNGKey, int, bytes, ...] +RandomState = Any # Union[jax.random.key, int, bytes, ...] OptimizerState = Union[Dict[str, Any], Tuple[Any, Any]] Hyperparameters = Any diff --git a/algorithmic_efficiency/workloads/fastmri/input_pipeline.py b/algorithmic_efficiency/workloads/fastmri/input_pipeline.py index 8f6ddafd1..3ae59fa8d 100644 --- a/algorithmic_efficiency/workloads/fastmri/input_pipeline.py +++ b/algorithmic_efficiency/workloads/fastmri/input_pipeline.py @@ -204,7 +204,7 @@ def process_example(example_index, example): process_rng, example_index) else: # NOTE(dsuo): we use fixed randomness for eval. - process_rng = tf.cast(jax.random.PRNGKey(_EVAL_SEED), tf.int64) + process_rng = tf.cast(jax.random.key(_EVAL_SEED), tf.int64) return _process_example(*example, process_rng) ds = ds.enumerate().map(process_example, num_parallel_calls=16) diff --git a/algorithmic_efficiency/workloads/mnist/workload.py b/algorithmic_efficiency/workloads/mnist/workload.py index dcc195170..869d55882 100644 --- a/algorithmic_efficiency/workloads/mnist/workload.py +++ b/algorithmic_efficiency/workloads/mnist/workload.py @@ -23,7 +23,7 @@ def _normalize(image: spec.Tensor, mean: float, stddev: float) -> spec.Tensor: def _build_mnist_dataset( - data_rng: jax.random.PRNGKey, + data_rng: jax.random.key, num_train_examples: int, num_validation_examples: int, train_mean: float, diff --git a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py index d4817226d..6440951c9 100644 --- a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py @@ -69,7 +69,7 @@ def loss_fn( return loss_dict def _build_input_queue(self, - data_rng: jax.random.PRNGKey, + data_rng: jax.random.key, split: str, data_dir: str, global_batch_size: int): diff --git a/algorithmic_efficiency/workloads/ogbg/workload.py b/algorithmic_efficiency/workloads/ogbg/workload.py index a32f385cb..beea35883 100644 --- a/algorithmic_efficiency/workloads/ogbg/workload.py +++ b/algorithmic_efficiency/workloads/ogbg/workload.py @@ -95,7 +95,7 @@ def eval_period_time_sec(self) -> int: return 4 * 60 def _build_input_queue(self, - data_rng: jax.random.PRNGKey, + data_rng: jax.random.key, split: str, data_dir: str, global_batch_size: int): diff --git a/algorithmic_efficiency/workloads/utils.py b/algorithmic_efficiency/workloads/utils.py index 7719f91fb..286a03447 100644 --- a/algorithmic_efficiency/workloads/utils.py +++ b/algorithmic_efficiency/workloads/utils.py @@ -6,7 +6,7 @@ def print_jax_model_summary(model, fake_inputs): """Prints a summary of the jax module.""" tabulate_fn = nn.tabulate( model, - jax.random.PRNGKey(0), + jax.random.key(0), console_kwargs={ 'force_terminal': False, 'force_jupyter': False, 'width': 240 }, diff --git a/algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py b/algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py index 046d5e469..2c6ee6c0c 100644 --- a/algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py +++ b/algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py @@ -105,7 +105,7 @@ def initialize_cache(self, config = models.TransformerConfig(deterministic=True, decode=True) target_shape = (inputs.shape[0], max_decode_len) + inputs.shape[2:] initial_variables = models.Transformer(config).init( - jax.random.PRNGKey(0), + jax.random.key(0), jnp.ones(inputs.shape, jnp.float32), jnp.ones(target_shape, jnp.float32)) return initial_variables['cache'] diff --git a/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py b/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py index 0ba49c2f6..3f3be6380 100644 --- a/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py @@ -238,7 +238,7 @@ def model_fn( return logits_batch, None def _build_input_queue(self, - data_rng: jax.random.PRNGKey, + data_rng: jax.random.key, split: str, data_dir: str, global_batch_size: int, diff --git a/algorithmic_efficiency/workloads/wmt/workload.py b/algorithmic_efficiency/workloads/wmt/workload.py index 68ebdc94b..1a2626bd2 100644 --- a/algorithmic_efficiency/workloads/wmt/workload.py +++ b/algorithmic_efficiency/workloads/wmt/workload.py @@ -116,7 +116,7 @@ def glu(self) -> bool: return False def _build_input_queue(self, - data_rng: jax.random.PRNGKey, + data_rng: jax.random.key, split: str, data_dir: str, global_batch_size: int, diff --git a/docker/scripts/check_gpu.py b/docker/scripts/check_gpu.py index 08aa6cd81..740da41e7 100644 --- a/docker/scripts/check_gpu.py +++ b/docker/scripts/check_gpu.py @@ -2,7 +2,7 @@ print('JAX identified %d GPU devices' % jax.local_device_count()) print('Generating RNG seed for CUDA sanity check ... ') -rng = jax.random.PRNGKey(0) +rng = jax.random.key(0) data_rng, shuffle_rng = jax.random.split(rng, 2) if jax.local_device_count() == 8 and data_rng is not None: diff --git a/tests/modeldiffs/criteo1tb/compare.py b/tests/modeldiffs/criteo1tb/compare.py index adbade983..387166bcb 100644 --- a/tests/modeldiffs/criteo1tb/compare.py +++ b/tests/modeldiffs/criteo1tb/compare.py @@ -71,7 +71,7 @@ def sd_transform(sd): jax_model_kwargs = dict( augmented_and_preprocessed_input_batch=jax_batch, mode=spec.ForwardPassMode.EVAL, - rng=jax.random.PRNGKey(0), + rng=jax.random.key(0), update_batch_norm=False) out_diff( diff --git a/tests/modeldiffs/criteo1tb_embed_init/compare.py b/tests/modeldiffs/criteo1tb_embed_init/compare.py index 0748e2d71..5ba3bdebb 100644 --- a/tests/modeldiffs/criteo1tb_embed_init/compare.py +++ b/tests/modeldiffs/criteo1tb_embed_init/compare.py @@ -70,7 +70,7 @@ def sd_transform(sd): jax_model_kwargs = dict( augmented_and_preprocessed_input_batch=jax_batch, mode=spec.ForwardPassMode.EVAL, - rng=jax.random.PRNGKey(0), + rng=jax.random.key(0), update_batch_norm=False) out_diff( diff --git a/tests/modeldiffs/criteo1tb_layernorm/compare.py b/tests/modeldiffs/criteo1tb_layernorm/compare.py index 0a6e5c5ac..676c236bd 100644 --- a/tests/modeldiffs/criteo1tb_layernorm/compare.py +++ b/tests/modeldiffs/criteo1tb_layernorm/compare.py @@ -82,7 +82,7 @@ def sd_transform(sd): jax_model_kwargs = dict( augmented_and_preprocessed_input_batch=jax_batch, # mode=spec.ForwardPassMode.EVAL, - rng=jax.random.PRNGKey(0), + rng=jax.random.key(0), update_batch_norm=False) out_diff( diff --git a/tests/modeldiffs/criteo1tb_resnet/compare.py b/tests/modeldiffs/criteo1tb_resnet/compare.py index 288442594..e572db16a 100644 --- a/tests/modeldiffs/criteo1tb_resnet/compare.py +++ b/tests/modeldiffs/criteo1tb_resnet/compare.py @@ -89,7 +89,7 @@ def sd_transform(sd): jax_model_kwargs = dict( augmented_and_preprocessed_input_batch=jax_batch, mode=spec.ForwardPassMode.EVAL, - rng=jax.random.PRNGKey(0), + rng=jax.random.key(0), update_batch_norm=False) out_diff( diff --git a/tests/modeldiffs/diff.py b/tests/modeldiffs/diff.py index f96fa672b..9cc0c1b57 100644 --- a/tests/modeldiffs/diff.py +++ b/tests/modeldiffs/diff.py @@ -13,7 +13,7 @@ def torch2jax(jax_workload, key_transform=None, sd_transform=None, init_kwargs=dict(dropout_rate=0.0, aux_dropout_rate=0.0)): - jax_params, model_state = jax_workload.init_model_fn(jax.random.PRNGKey(0), + jax_params, model_state = jax_workload.init_model_fn(jax.random.key(0), **init_kwargs) pytorch_model, _ = pytorch_workload.init_model_fn([0], **init_kwargs) jax_params = jax_utils.unreplicate(jax_params).unfreeze() diff --git a/tests/modeldiffs/fastmri/compare.py b/tests/modeldiffs/fastmri/compare.py index 56b74b32d..654b30dc7 100644 --- a/tests/modeldiffs/fastmri/compare.py +++ b/tests/modeldiffs/fastmri/compare.py @@ -73,7 +73,7 @@ def sort_key(k): jax_model_kwargs = dict( augmented_and_preprocessed_input_batch=jax_batch, mode=spec.ForwardPassMode.EVAL, - rng=jax.random.PRNGKey(0), + rng=jax.random.key(0), update_batch_norm=False) out_diff( diff --git a/tests/modeldiffs/fastmri_layernorm/compare.py b/tests/modeldiffs/fastmri_layernorm/compare.py index 23ccf26d7..407fb51c2 100644 --- a/tests/modeldiffs/fastmri_layernorm/compare.py +++ b/tests/modeldiffs/fastmri_layernorm/compare.py @@ -80,7 +80,7 @@ def sort_key(k): jax_model_kwargs = dict( augmented_and_preprocessed_input_batch=jax_batch, mode=spec.ForwardPassMode.EVAL, - rng=jax.random.PRNGKey(0), + rng=jax.random.key(0), update_batch_norm=False) out_diff( diff --git a/tests/modeldiffs/fastmri_model_size/compare.py b/tests/modeldiffs/fastmri_model_size/compare.py index b61516c29..0bd4f7b9f 100644 --- a/tests/modeldiffs/fastmri_model_size/compare.py +++ b/tests/modeldiffs/fastmri_model_size/compare.py @@ -73,7 +73,7 @@ def sort_key(k): jax_model_kwargs = dict( augmented_and_preprocessed_input_batch=jax_batch, mode=spec.ForwardPassMode.EVAL, - rng=jax.random.PRNGKey(0), + rng=jax.random.key(0), update_batch_norm=False) out_diff( diff --git a/tests/modeldiffs/fastmri_tanh/compare.py b/tests/modeldiffs/fastmri_tanh/compare.py index 0f455387c..f5b9dbfba 100644 --- a/tests/modeldiffs/fastmri_tanh/compare.py +++ b/tests/modeldiffs/fastmri_tanh/compare.py @@ -73,7 +73,7 @@ def sort_key(k): jax_model_kwargs = dict( augmented_and_preprocessed_input_batch=jax_batch, mode=spec.ForwardPassMode.EVAL, - rng=jax.random.PRNGKey(0), + rng=jax.random.key(0), update_batch_norm=False) out_diff( diff --git a/tests/modeldiffs/imagenet_resnet/compare.py b/tests/modeldiffs/imagenet_resnet/compare.py index fb730f1bf..2d77bcbbc 100644 --- a/tests/modeldiffs/imagenet_resnet/compare.py +++ b/tests/modeldiffs/imagenet_resnet/compare.py @@ -90,7 +90,7 @@ def sd_transform(sd): jax_model_kwargs = dict( augmented_and_preprocessed_input_batch=jax_batch, mode=spec.ForwardPassMode.EVAL, - rng=jax.random.PRNGKey(0), + rng=jax.random.key(0), update_batch_norm=False) out_diff( diff --git a/tests/modeldiffs/imagenet_resnet/gelu_compare.py b/tests/modeldiffs/imagenet_resnet/gelu_compare.py index 6c8adbec2..0e070d08c 100644 --- a/tests/modeldiffs/imagenet_resnet/gelu_compare.py +++ b/tests/modeldiffs/imagenet_resnet/gelu_compare.py @@ -37,7 +37,7 @@ jax_model_kwargs = dict( augmented_and_preprocessed_input_batch=jax_batch, mode=spec.ForwardPassMode.EVAL, - rng=jax.random.PRNGKey(0), + rng=jax.random.key(0), update_batch_norm=False) out_diff( diff --git a/tests/modeldiffs/imagenet_resnet/silu_compare.py b/tests/modeldiffs/imagenet_resnet/silu_compare.py index 7668cdbd9..fcaef588a 100644 --- a/tests/modeldiffs/imagenet_resnet/silu_compare.py +++ b/tests/modeldiffs/imagenet_resnet/silu_compare.py @@ -37,7 +37,7 @@ jax_model_kwargs = dict( augmented_and_preprocessed_input_batch=jax_batch, mode=spec.ForwardPassMode.EVAL, - rng=jax.random.PRNGKey(0), + rng=jax.random.key(0), update_batch_norm=False) out_diff( diff --git a/tests/modeldiffs/imagenet_vit/compare.py b/tests/modeldiffs/imagenet_vit/compare.py index ba21e63da..fffccf45c 100644 --- a/tests/modeldiffs/imagenet_vit/compare.py +++ b/tests/modeldiffs/imagenet_vit/compare.py @@ -103,7 +103,7 @@ def key_transform(k): jax_model_kwargs = dict( augmented_and_preprocessed_input_batch=jax_batch, mode=spec.ForwardPassMode.EVAL, - rng=jax.random.PRNGKey(0), + rng=jax.random.key(0), update_batch_norm=False) out_diff( diff --git a/tests/modeldiffs/imagenet_vit_glu/compare.py b/tests/modeldiffs/imagenet_vit_glu/compare.py index 2c0aa546d..834272a19 100644 --- a/tests/modeldiffs/imagenet_vit_glu/compare.py +++ b/tests/modeldiffs/imagenet_vit_glu/compare.py @@ -39,7 +39,7 @@ jax_model_kwargs = dict( augmented_and_preprocessed_input_batch=jax_batch, mode=spec.ForwardPassMode.EVAL, - rng=jax.random.PRNGKey(0), + rng=jax.random.key(0), update_batch_norm=False) out_diff( diff --git a/tests/modeldiffs/imagenet_vit_map/compare.py b/tests/modeldiffs/imagenet_vit_map/compare.py index e7c4c2ee8..78b7fc420 100644 --- a/tests/modeldiffs/imagenet_vit_map/compare.py +++ b/tests/modeldiffs/imagenet_vit_map/compare.py @@ -50,7 +50,7 @@ def sd_transform(sd): jax_model_kwargs = dict( augmented_and_preprocessed_input_batch=jax_batch, mode=spec.ForwardPassMode.EVAL, - rng=jax.random.PRNGKey(0), + rng=jax.random.key(0), update_batch_norm=False) out_diff( diff --git a/tests/modeldiffs/imagenet_vit_postln/compare.py b/tests/modeldiffs/imagenet_vit_postln/compare.py index 8a9063cac..fa3dbeb75 100644 --- a/tests/modeldiffs/imagenet_vit_postln/compare.py +++ b/tests/modeldiffs/imagenet_vit_postln/compare.py @@ -39,7 +39,7 @@ jax_model_kwargs = dict( augmented_and_preprocessed_input_batch=jax_batch, mode=spec.ForwardPassMode.EVAL, - rng=jax.random.PRNGKey(0), + rng=jax.random.key(0), update_batch_norm=False) out_diff( diff --git a/tests/modeldiffs/librispeech_conformer/compare.py b/tests/modeldiffs/librispeech_conformer/compare.py index cfe6c7381..6c061ad26 100644 --- a/tests/modeldiffs/librispeech_conformer/compare.py +++ b/tests/modeldiffs/librispeech_conformer/compare.py @@ -78,7 +78,7 @@ def sd_transform(sd): jax_model_kwargs = dict( augmented_and_preprocessed_input_batch=jax_batch, mode=spec.ForwardPassMode.EVAL, - rng=jax.random.PRNGKey(0), + rng=jax.random.key(0), update_batch_norm=False) out_diff( diff --git a/tests/modeldiffs/librispeech_conformer_attention_temperature/compare.py b/tests/modeldiffs/librispeech_conformer_attention_temperature/compare.py index 8480fca02..8ad181a58 100644 --- a/tests/modeldiffs/librispeech_conformer_attention_temperature/compare.py +++ b/tests/modeldiffs/librispeech_conformer_attention_temperature/compare.py @@ -78,7 +78,7 @@ def sd_transform(sd): jax_model_kwargs = dict( augmented_and_preprocessed_input_batch=jax_batch, mode=spec.ForwardPassMode.EVAL, - rng=jax.random.PRNGKey(0), + rng=jax.random.key(0), update_batch_norm=False) out_diff( diff --git a/tests/modeldiffs/librispeech_conformer_gelu/compare.py b/tests/modeldiffs/librispeech_conformer_gelu/compare.py index caa9b09b9..c7c925e68 100644 --- a/tests/modeldiffs/librispeech_conformer_gelu/compare.py +++ b/tests/modeldiffs/librispeech_conformer_gelu/compare.py @@ -78,7 +78,7 @@ def sd_transform(sd): jax_model_kwargs = dict( augmented_and_preprocessed_input_batch=jax_batch, mode=spec.ForwardPassMode.EVAL, - rng=jax.random.PRNGKey(0), + rng=jax.random.key(0), update_batch_norm=False) out_diff( diff --git a/tests/modeldiffs/librispeech_conformer_layernorm/compare.py b/tests/modeldiffs/librispeech_conformer_layernorm/compare.py index 1a94d3c77..9a575cfe8 100644 --- a/tests/modeldiffs/librispeech_conformer_layernorm/compare.py +++ b/tests/modeldiffs/librispeech_conformer_layernorm/compare.py @@ -78,7 +78,7 @@ def sd_transform(sd): jax_model_kwargs = dict( augmented_and_preprocessed_input_batch=jax_batch, mode=spec.ForwardPassMode.EVAL, - rng=jax.random.PRNGKey(0), + rng=jax.random.key(0), update_batch_norm=False) out_diff( diff --git a/tests/modeldiffs/librispeech_deepspeech/compare.py b/tests/modeldiffs/librispeech_deepspeech/compare.py index edcc3ba87..2a9054be2 100644 --- a/tests/modeldiffs/librispeech_deepspeech/compare.py +++ b/tests/modeldiffs/librispeech_deepspeech/compare.py @@ -103,7 +103,7 @@ def sd_transform(sd): jax_model_kwargs = dict( augmented_and_preprocessed_input_batch=jax_batch, mode=spec.ForwardPassMode.EVAL, - rng=jax.random.PRNGKey(0), + rng=jax.random.key(0), update_batch_norm=False) out_diff( diff --git a/tests/modeldiffs/librispeech_deepspeech_noresnet/compare.py b/tests/modeldiffs/librispeech_deepspeech_noresnet/compare.py index 6c00bdf69..0ac734cb9 100644 --- a/tests/modeldiffs/librispeech_deepspeech_noresnet/compare.py +++ b/tests/modeldiffs/librispeech_deepspeech_noresnet/compare.py @@ -39,7 +39,7 @@ jax_model_kwargs = dict( augmented_and_preprocessed_input_batch=jax_batch, mode=spec.ForwardPassMode.EVAL, - rng=jax.random.PRNGKey(0), + rng=jax.random.key(0), update_batch_norm=False) out_diff( diff --git a/tests/modeldiffs/librispeech_deepspeech_normaug/compare.py b/tests/modeldiffs/librispeech_deepspeech_normaug/compare.py index c68d6adf9..53c80011f 100644 --- a/tests/modeldiffs/librispeech_deepspeech_normaug/compare.py +++ b/tests/modeldiffs/librispeech_deepspeech_normaug/compare.py @@ -39,7 +39,7 @@ jax_model_kwargs = dict( augmented_and_preprocessed_input_batch=jax_batch, mode=spec.ForwardPassMode.EVAL, - rng=jax.random.PRNGKey(0), + rng=jax.random.key(0), update_batch_norm=False) out_diff( diff --git a/tests/modeldiffs/librispeech_deepspeech_tanh/compare.py b/tests/modeldiffs/librispeech_deepspeech_tanh/compare.py index 4cfdf4f21..40fed27a1 100644 --- a/tests/modeldiffs/librispeech_deepspeech_tanh/compare.py +++ b/tests/modeldiffs/librispeech_deepspeech_tanh/compare.py @@ -39,7 +39,7 @@ jax_model_kwargs = dict( augmented_and_preprocessed_input_batch=jax_batch, mode=spec.ForwardPassMode.EVAL, - rng=jax.random.PRNGKey(0), + rng=jax.random.key(0), update_batch_norm=False) out_diff( diff --git a/tests/modeldiffs/ogbg/compare.py b/tests/modeldiffs/ogbg/compare.py index 56316ba12..552f4a176 100644 --- a/tests/modeldiffs/ogbg/compare.py +++ b/tests/modeldiffs/ogbg/compare.py @@ -107,7 +107,7 @@ def sd_transform(sd): jax_model_kwargs = dict( augmented_and_preprocessed_input_batch=jax_batch, mode=spec.ForwardPassMode.EVAL, - rng=jax.random.PRNGKey(0), + rng=jax.random.key(0), update_batch_norm=False) out_diff( diff --git a/tests/modeldiffs/ogbg_gelu/compare.py b/tests/modeldiffs/ogbg_gelu/compare.py index b58bcd461..c9fcfe1be 100644 --- a/tests/modeldiffs/ogbg_gelu/compare.py +++ b/tests/modeldiffs/ogbg_gelu/compare.py @@ -107,7 +107,7 @@ def sd_transform(sd): jax_model_kwargs = dict( augmented_and_preprocessed_input_batch=jax_batch, mode=spec.ForwardPassMode.EVAL, - rng=jax.random.PRNGKey(0), + rng=jax.random.key(0), update_batch_norm=False) out_diff( diff --git a/tests/modeldiffs/ogbg_model_size/compare.py b/tests/modeldiffs/ogbg_model_size/compare.py index 62443bbb5..747d2420f 100644 --- a/tests/modeldiffs/ogbg_model_size/compare.py +++ b/tests/modeldiffs/ogbg_model_size/compare.py @@ -107,7 +107,7 @@ def sd_transform(sd): jax_model_kwargs = dict( augmented_and_preprocessed_input_batch=jax_batch, mode=spec.ForwardPassMode.EVAL, - rng=jax.random.PRNGKey(0), + rng=jax.random.key(0), update_batch_norm=False) out_diff( diff --git a/tests/modeldiffs/ogbg_silu/compare.py b/tests/modeldiffs/ogbg_silu/compare.py index 2922b7046..2e016f2d8 100644 --- a/tests/modeldiffs/ogbg_silu/compare.py +++ b/tests/modeldiffs/ogbg_silu/compare.py @@ -107,7 +107,7 @@ def sd_transform(sd): jax_model_kwargs = dict( augmented_and_preprocessed_input_batch=jax_batch, mode=spec.ForwardPassMode.EVAL, - rng=jax.random.PRNGKey(0), + rng=jax.random.key(0), update_batch_norm=False) out_diff( diff --git a/tests/modeldiffs/wmt/compare.py b/tests/modeldiffs/wmt/compare.py index 41fc5ee17..64951ad83 100644 --- a/tests/modeldiffs/wmt/compare.py +++ b/tests/modeldiffs/wmt/compare.py @@ -128,7 +128,7 @@ def sd_transform(sd): jax_model_kwargs = dict( augmented_and_preprocessed_input_batch=jax_batch, mode=spec.ForwardPassMode.EVAL, - rng=jax.random.PRNGKey(0), + rng=jax.random.key(0), update_batch_norm=False) out_diff( diff --git a/tests/modeldiffs/wmt_attention_temp/compare.py b/tests/modeldiffs/wmt_attention_temp/compare.py index 92ce4eb44..f381b12e0 100644 --- a/tests/modeldiffs/wmt_attention_temp/compare.py +++ b/tests/modeldiffs/wmt_attention_temp/compare.py @@ -128,7 +128,7 @@ def sd_transform(sd): jax_model_kwargs = dict( augmented_and_preprocessed_input_batch=jax_batch, mode=spec.ForwardPassMode.EVAL, - rng=jax.random.PRNGKey(0), + rng=jax.random.key(0), update_batch_norm=False) out_diff( diff --git a/tests/modeldiffs/wmt_glu_tanh/compare.py b/tests/modeldiffs/wmt_glu_tanh/compare.py index b8d860479..e2f63dc3e 100644 --- a/tests/modeldiffs/wmt_glu_tanh/compare.py +++ b/tests/modeldiffs/wmt_glu_tanh/compare.py @@ -128,7 +128,7 @@ def sd_transform(sd): jax_model_kwargs = dict( augmented_and_preprocessed_input_batch=jax_batch, mode=spec.ForwardPassMode.EVAL, - rng=jax.random.PRNGKey(0), + rng=jax.random.key(0), update_batch_norm=False) out_diff( diff --git a/tests/modeldiffs/wmt_post_ln/compare.py b/tests/modeldiffs/wmt_post_ln/compare.py index 3f5469d8d..0c4f65e6e 100644 --- a/tests/modeldiffs/wmt_post_ln/compare.py +++ b/tests/modeldiffs/wmt_post_ln/compare.py @@ -128,7 +128,7 @@ def sd_transform(sd): jax_model_kwargs = dict( augmented_and_preprocessed_input_batch=jax_batch, mode=spec.ForwardPassMode.EVAL, - rng=jax.random.PRNGKey(0), + rng=jax.random.key(0), update_batch_norm=False) out_diff( diff --git a/tests/test_param_shapes.py b/tests/test_param_shapes.py index 4ad56c873..82e5029c1 100644 --- a/tests/test_param_shapes.py +++ b/tests/test_param_shapes.py @@ -116,7 +116,7 @@ def get_workload(workload): pytorch_workload = PyTorchWmtWorkload() else: raise ValueError(f'Workload {workload} is not available.') - _ = jax_workload.init_model_fn(jax.random.PRNGKey(0)) + _ = jax_workload.init_model_fn(jax.random.key(0)) _ = pytorch_workload.init_model_fn([0]) return jax_workload, pytorch_workload diff --git a/tests/test_param_types.py b/tests/test_param_types.py index 7cf8f63c3..192b4a9a7 100644 --- a/tests/test_param_types.py +++ b/tests/test_param_types.py @@ -220,7 +220,7 @@ def get_workload(workload_name): pytorch_workload = PyTorchWmtWorkload() else: raise ValueError(f'Workload {workload_name} is not available.') - _ = jax_workload.init_model_fn(jax.random.PRNGKey(0)) + _ = jax_workload.init_model_fn(jax.random.key(0)) _ = pytorch_workload.init_model_fn([0]) return jax_workload, pytorch_workload diff --git a/tests/workloads/imagenet_resnet/imagenet_jax/workload_test.py b/tests/workloads/imagenet_resnet/imagenet_jax/workload_test.py index 6a85c2196..828167944 100644 --- a/tests/workloads/imagenet_resnet/imagenet_jax/workload_test.py +++ b/tests/workloads/imagenet_resnet/imagenet_jax/workload_test.py @@ -20,7 +20,7 @@ class ModelsTest(absltest.TestCase): def test_forward_pass(self): batch_size = 11 - rng = jax.random.PRNGKey(0) + rng = jax.random.key(0) rng, model_init_rng, *data_rngs = jax.random.split(rng, 4) workload = ImagenetResNetWorkload() model_params, batch_stats = workload.init_model_fn(model_init_rng)