diff --git a/algorithmic_efficiency/random_utils.py b/algorithmic_efficiency/random_utils.py index 68e9a9cfe..37fc08e5f 100644 --- a/algorithmic_efficiency/random_utils.py +++ b/algorithmic_efficiency/random_utils.py @@ -1,6 +1,6 @@ """Proxy functions in front of the Jax RNG API or a compatible Numpy RNG API.""" -from typing import Any, List, Union +from typing import Union from absl import flags from absl import logging @@ -21,6 +21,12 @@ MAX_INT32 = 2**31 MIN_INT32 = -MAX_INT32 +# SALT constants +_SALT1 = np.random.RandomState(seed=5).randint( + MIN_INT32, MAX_INT32, dtype=np.int32) +_SALT2 = np.random.RandomState(seed=6).randint( + MIN_INT32, MAX_INT32, dtype=np.int32) + SeedType = Union[int, list, np.ndarray] @@ -33,15 +39,19 @@ def _signed_to_unsigned(seed: SeedType) -> SeedType: return np.array([s + 2**32 if s < 0 else s for s in seed.tolist()]) -def _fold_in(seed: SeedType, data: Any) -> List[Union[SeedType, Any]]: - rng = np.random.RandomState(seed=_signed_to_unsigned(seed)) - new_seed = rng.randint(MIN_INT32, MAX_INT32, dtype=np.int32) - return [new_seed, data] +def _fold_in(seed: SeedType, data: int) -> SeedType: + a = np.random.RandomState(seed=_signed_to_unsigned(seed ^ _SALT1)).randint( + MIN_INT32, MAX_INT32, dtype=np.int32) + b = np.random.RandomState(seed=_signed_to_unsigned(data ^ _SALT2)).randint( + MIN_INT32, MAX_INT32, dtype=np.int32) + c = np.random.RandomState(seed=_signed_to_unsigned(a ^ b)).randint( + MIN_INT32, MAX_INT32, dtype=np.int32) + return c def _split(seed: SeedType, num: int = 2) -> SeedType: rng = np.random.RandomState(seed=_signed_to_unsigned(seed)) - return rng.randint(MIN_INT32, MAX_INT32, dtype=np.int32, size=[num, 2]) + return rng.randint(MIN_INT32, MAX_INT32, dtype=np.int32, size=[num]) def _PRNGKey(seed: SeedType) -> SeedType: # pylint: disable=invalid-name @@ -58,7 +68,11 @@ def _check_jax_install() -> None: '--framework=pytorch to use the Numpy version instead.') -def fold_in(seed: SeedType, data: Any) -> List[Union[SeedType, Any]]: +def _bits(seed: SeedType) -> int: + return seed + + +def fold_in(seed: SeedType, data: int) -> SeedType: if FLAGS.framework == 'jax': _check_jax_install() return jax_rng.fold_in(seed, data) @@ -77,3 +91,10 @@ def PRNGKey(seed: SeedType) -> SeedType: # pylint: disable=invalid-name _check_jax_install() return jax_rng.PRNGKey(seed) return _PRNGKey(seed) + + +def bits(seed: SeedType) -> int: + if FLAGS.framework == 'jax': + _check_jax_install() + return jax_rng.bits(seed) + return _bits(seed) diff --git a/algorithmic_efficiency/workloads/cifar/cifar_pytorch/workload.py b/algorithmic_efficiency/workloads/cifar/cifar_pytorch/workload.py index af86c212e..e2d655e9b 100644 --- a/algorithmic_efficiency/workloads/cifar/cifar_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/cifar/cifar_pytorch/workload.py @@ -65,7 +65,7 @@ def _build_dataset( } if split == 'eval_train': train_indices = indices_split['train'] - random.Random(data_rng[0]).shuffle(train_indices) + random.Random(data_rng).shuffle(train_indices) indices_split['eval_train'] = train_indices[:self.num_eval_train_examples] if split in indices_split: dataset = torch.utils.data.Subset(dataset, indices_split[split]) @@ -111,7 +111,7 @@ def init_model_fn( self._model.reset_parameters() return self._model, None - torch.random.manual_seed(rng[0]) + torch.random.manual_seed(rng) self._model = resnet18(num_classes=self._num_classes) self._param_shapes = param_utils.pytorch_param_shapes(self._model) self._param_types = param_utils.pytorch_param_types(self._param_shapes) diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py index 85bb602d1..bff5fa837 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py @@ -72,7 +72,7 @@ def init_model_fn( aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: """Only dropout is used.""" del aux_dropout_rate - torch.random.manual_seed(rng[0]) + torch.random.manual_seed(rng) # Disable cudnn benchmark to avoid OOM errors. torch.backends.cudnn.benchmark = False if self.use_resnet: diff --git a/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/workload.py b/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/workload.py index 74f6aa13d..0ad1b3eeb 100644 --- a/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/workload.py @@ -113,7 +113,7 @@ def init_model_fn( dropout_rate: Optional[float] = None, aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: del aux_dropout_rate - torch.random.manual_seed(rng[0]) + torch.random.manual_seed(rng) model = UNet( num_pool_layers=self.num_pool_layers, num_channels=self.num_channels, diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_pytorch/workload.py b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_pytorch/workload.py index 6727054c9..ba2012644 100644 --- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_pytorch/workload.py @@ -103,7 +103,7 @@ def _build_dataset( if split == 'eval_train': indices = list(range(self.num_train_examples)) - random.Random(data_rng[0]).shuffle(indices) + random.Random(data_rng).shuffle(indices) dataset = torch.utils.data.Subset(dataset, indices[:self.num_eval_train_examples]) @@ -147,7 +147,7 @@ def init_model_fn( """Dropout is unused.""" del dropout_rate del aux_dropout_rate - torch.random.manual_seed(rng[0]) + torch.random.manual_seed(rng) if self.use_silu and self.use_gelu: raise RuntimeError('Cannot use both GELU and SiLU activations.') diff --git a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/workload.py b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/workload.py index e672e8d22..aec3f1aaf 100644 --- a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/workload.py @@ -30,7 +30,7 @@ def init_model_fn( dropout_rate: Optional[float] = None, aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: del aux_dropout_rate - torch.random.manual_seed(rng[0]) + torch.random.manual_seed(rng) model = models.ViT( dropout_rate=dropout_rate, num_classes=self._num_classes, diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py index 20f27b150..9f0a6f841 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py @@ -58,7 +58,7 @@ def init_model_fn( Here we use dropout_rate as residual_dropout_rate, and aux_dropout_rate as input_dropout_rate. """ - torch.random.manual_seed(rng[0]) + torch.random.manual_seed(rng) # Configure torch backends to avoid OOM errors. torch.backends.cudnn.benchmark = False torch.backends.cuda.enable_flash_sdp(False) diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py index bcdd78fb5..c968b528d 100644 --- a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py @@ -32,7 +32,7 @@ def init_model_fn( Here we use dropout_rate as feed_forward_dropout_rate, and aux_dropout_rate as input_dropout_rate. """ - torch.random.manual_seed(rng[0]) + torch.random.manual_seed(rng) model = DeepspeechEncoderDecoder( DeepspeechConfig( feed_forward_dropout_rate=dropout_rate, diff --git a/algorithmic_efficiency/workloads/mnist/mnist_pytorch/workload.py b/algorithmic_efficiency/workloads/mnist/mnist_pytorch/workload.py index e638df078..a60e6040e 100644 --- a/algorithmic_efficiency/workloads/mnist/mnist_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/mnist/mnist_pytorch/workload.py @@ -133,7 +133,7 @@ def init_model_fn( self._model.reset_parameters() return self._model, None - torch.random.manual_seed(rng[0]) + torch.random.manual_seed(rng) self._model = _Model() self._param_shapes = param_utils.pytorch_param_shapes(self._model) self._param_types = param_utils.pytorch_param_types(self._param_shapes) diff --git a/algorithmic_efficiency/workloads/mnist/workload.py b/algorithmic_efficiency/workloads/mnist/workload.py index dcc195170..2b593948c 100644 --- a/algorithmic_efficiency/workloads/mnist/workload.py +++ b/algorithmic_efficiency/workloads/mnist/workload.py @@ -55,7 +55,7 @@ def _build_mnist_dataset( if shuffle: ds = ds.repeat() - ds = ds.shuffle(16 * global_batch_size, seed=data_rng[0]) + ds = ds.shuffle(16 * global_batch_size, seed=prng.bits(data_rng)) ds = ds.batch(global_batch_size, drop_remainder=is_train) if repeat_final_dataset: diff --git a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py index d4817226d..84a445c4b 100644 --- a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py @@ -143,7 +143,7 @@ def init_model_fn( aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: """aux_dropout_rate is unused.""" del aux_dropout_rate - torch.random.manual_seed(rng[0]) + torch.random.manual_seed(rng) model = GNN( num_outputs=self._num_outputs, dropout_rate=dropout_rate, diff --git a/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py b/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py index 9f6d817f4..9ee959a4f 100644 --- a/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py @@ -171,7 +171,7 @@ def init_model_fn( dropout_rate: Optional[float] = None, aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: """aux_dropout_rate is used as attention_dropout_rate.""" - torch.random.manual_seed(rng[0]) + torch.random.manual_seed(rng) if self.activation == 'relu': activation = F.relu diff --git a/scoring/run_workloads.py b/scoring/run_workloads.py index 077ce8d4f..83d7a5f65 100644 --- a/scoring/run_workloads.py +++ b/scoring/run_workloads.py @@ -148,7 +148,7 @@ def main(_): # For each runnable workload check if there are any containers running and if not launch next container command for workload in workloads: run_key = prng.fold_in(rng_subkey, hash(workload)) - run_seed = run_key[0] # arbitrary + run_seed = prng.bits(run_key) base_workload_name = get_base_workload_name(workload) wait_until_container_not_running() os.system( diff --git a/tests/test_param_shapes.py b/tests/test_param_shapes.py index b67625213..afa752cb5 100644 --- a/tests/test_param_shapes.py +++ b/tests/test_param_shapes.py @@ -113,7 +113,7 @@ def get_workload(workload): else: raise ValueError(f'Workload {workload} is not available.') _ = jax_workload.init_model_fn(jax.random.PRNGKey(0)) - _ = pytorch_workload.init_model_fn([0]) + _ = pytorch_workload.init_model_fn(0) return jax_workload, pytorch_workload diff --git a/tests/test_param_types.py b/tests/test_param_types.py index 7cf8f63c3..639c7372d 100644 --- a/tests/test_param_types.py +++ b/tests/test_param_types.py @@ -221,7 +221,7 @@ def get_workload(workload_name): else: raise ValueError(f'Workload {workload_name} is not available.') _ = jax_workload.init_model_fn(jax.random.PRNGKey(0)) - _ = pytorch_workload.init_model_fn([0]) + _ = pytorch_workload.init_model_fn(0) return jax_workload, pytorch_workload