Changes to random utils #625

Merged Feb 7, 2024 (19 commits)
Changes from 18 commits
37 changes: 30 additions & 7 deletions algorithmic_efficiency/random_utils.py
@@ -1,6 +1,6 @@
"""Proxy functions in front of the Jax RNG API or a compatible Numpy RNG API."""

-from typing import Any, List, Union
+from typing import Union

from absl import flags
from absl import logging
@@ -21,6 +21,12 @@
MAX_INT32 = 2**31
MIN_INT32 = -MAX_INT32

+# SALT constants
+_SALT1 = np.random.RandomState(seed=5).randint(
+    MIN_INT32, MAX_INT32, dtype=np.int32)
+_SALT2 = np.random.RandomState(seed=6).randint(
+    MIN_INT32, MAX_INT32, dtype=np.int32)

SeedType = Union[int, list, np.ndarray]


@@ -33,15 +39,19 @@ def _signed_to_unsigned(seed: SeedType) -> SeedType:
return np.array([s + 2**32 if s < 0 else s for s in seed.tolist()])


-def _fold_in(seed: SeedType, data: Any) -> List[Union[SeedType, Any]]:
-  rng = np.random.RandomState(seed=_signed_to_unsigned(seed))
-  new_seed = rng.randint(MIN_INT32, MAX_INT32, dtype=np.int32)
-  return [new_seed, data]
+def _fold_in(seed: SeedType, data: int) -> SeedType:
+  a = np.random.RandomState(seed=_signed_to_unsigned(seed ^ _SALT1)).randint(
+      MIN_INT32, MAX_INT32, dtype=np.int32)
+  b = np.random.RandomState(seed=_signed_to_unsigned(data ^ _SALT2)).randint(
+      MIN_INT32, MAX_INT32, dtype=np.int32)
+  c = np.random.RandomState(seed=_signed_to_unsigned(a ^ b)).randint(
+      MIN_INT32, MAX_INT32, dtype=np.int32)
+  return c
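The rewritten _fold_in replaces the old [new_seed, data] return value with a single scalar seed: the parent seed and the folded-in integer are each mixed with an independent salt, hashed through a seeded RandomState draw, and the XOR of the two results is hashed once more. A rough standalone sketch of the scheme, with hypothetical salt values standing in for the fixed RandomState(seed=5) and RandomState(seed=6) draws above:

import numpy as np

MAX_INT32 = 2**31
MIN_INT32 = -MAX_INT32

def _hash32(x):
  # One salted draw: map x into [0, 2**32) and take a single int32 from a
  # RandomState seeded with it (stand-in for _signed_to_unsigned + randint).
  return np.random.RandomState(seed=x % 2**32).randint(
      MIN_INT32, MAX_INT32, dtype=np.int32)

seed, data = 1234, 7
a = _hash32(seed ^ 0x5EED5A17)       # hypothetical salt
b = _hash32(data ^ 0x0FACADE5)       # hypothetical salt
new_seed = _hash32(int(a) ^ int(b))  # deterministic scalar, decorrelated from seed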


def _split(seed: SeedType, num: int = 2) -> SeedType:
  rng = np.random.RandomState(seed=_signed_to_unsigned(seed))
-  return rng.randint(MIN_INT32, MAX_INT32, dtype=np.int32, size=[num, 2])
+  return rng.randint(MIN_INT32, MAX_INT32, dtype=np.int32, size=[num])
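_split now returns a flat array of num scalar int32 seeds instead of num Jax-style key pairs of shape [num, 2]; this is what lets the call sites below consume rng directly rather than rng[0]. A hedged illustration, assuming the module imports cleanly without Jax installed:

from algorithmic_efficiency import random_utils

subkeys = random_utils._split(1234, num=3)                # shape (3,), dtype int32
data_rng, model_rng, opt_rng = (int(s) for s in subkeys)  # one scalar seed each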


def _PRNGKey(seed: SeedType) -> SeedType: # pylint: disable=invalid-name
@@ -58,7 +68,13 @@ def _check_jax_install() -> None:
'--framework=pytorch to use the Numpy version instead.')


-def fold_in(seed: SeedType, data: Any) -> List[Union[SeedType, Any]]:
+def _bits(seed: SeedType) -> int:
+  rng = np.random.RandomState(_signed_to_unsigned(seed))
+  b = rng.bytes(4)
+  return int.from_bytes(b, byteorder='little')
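_bits gives a deterministic, non-negative view of a (possibly negative) int32 key by drawing four bytes from a RandomState seeded with it and reading them as an unsigned little-endian 32-bit integer. A minimal standalone sketch of that property (bits_of is a hypothetical helper; the module's version goes through _signed_to_unsigned):

import numpy as np

def bits_of(seed):
  # Stand-in for _bits: 4 bytes from a seeded RandomState, read little-endian.
  b = np.random.RandomState(seed % 2**32).bytes(4)
  return int.from_bytes(b, byteorder='little')

assert 0 <= bits_of(-7) < 2**32
assert bits_of(-7) == bits_of(-7)  # deterministic per seed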


+def fold_in(seed: SeedType, data: int) -> SeedType:
  if FLAGS.framework == 'jax':
    _check_jax_install()
    return jax_rng.fold_in(seed, data)
@@ -77,3 +93,10 @@ def PRNGKey(seed: SeedType) -> SeedType: # pylint: disable=invalid-name
    _check_jax_install()
    return jax_rng.PRNGKey(seed)
  return _PRNGKey(seed)


+def bits(seed: SeedType) -> int:
+  if FLAGS.framework == 'jax':
+    _check_jax_install()
+    return jax_rng.bits(seed)
+  return _bits(seed)
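The public fold_in, bits, and PRNGKey wrappers keep the existing dispatch: under --framework=jax they defer to jax.random, otherwise they fall back to the Numpy helpers above, so callers can treat a key as an opaque scalar in either case. A hedged sketch of the Numpy path (flag parsing elided; values illustrative):

from algorithmic_efficiency import random_utils as prng

# Assumes --framework=pytorch, i.e. the Numpy backend, where keys are int32 scalars.
key = prng.fold_in(1234, 7)   # fresh scalar key mixing 1234 and 7
seed = prng.bits(key)         # non-negative int in [0, 2**32)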
@@ -65,7 +65,7 @@ def _build_dataset(
}
if split == 'eval_train':
train_indices = indices_split['train']
-random.Random(data_rng[0]).shuffle(train_indices)
+random.Random(data_rng).shuffle(train_indices)
indices_split['eval_train'] = train_indices[:self.num_eval_train_examples]
if split in indices_split:
dataset = torch.utils.data.Subset(dataset, indices_split[split])
@@ -111,7 +111,7 @@ def init_model_fn(
self._model.reset_parameters()
return self._model, None

-torch.random.manual_seed(rng[0])
+torch.random.manual_seed(rng)
self._model = resnet18(num_classes=self._num_classes)
self._param_shapes = param_utils.pytorch_param_shapes(self._model)
self._param_types = param_utils.pytorch_param_types(self._param_shapes)
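The same mechanical change repeats across the PyTorch workloads below: a Numpy-backend key is now a scalar int32 rather than a length-2 array, so it is handed to torch.random.manual_seed and random.Random directly instead of indexing element 0. A minimal sketch of the updated call pattern (12345 stands in for the scalar key the workload receives):

import random
import torch

rng = 12345                          # stand-in for the scalar int32 key
torch.random.manual_seed(rng)        # previously: torch.random.manual_seed(rng[0])
indices = list(range(10))
random.Random(rng).shuffle(indices)  # previously: random.Random(data_rng[0]).shuffle(...)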
@@ -72,7 +72,7 @@ def init_model_fn(
aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState:
"""Only dropout is used."""
del aux_dropout_rate
-torch.random.manual_seed(rng[0])
+torch.random.manual_seed(rng)
# Disable cudnn benchmark to avoid OOM errors.
torch.backends.cudnn.benchmark = False
if self.use_resnet:
@@ -113,7 +113,7 @@ def init_model_fn(
dropout_rate: Optional[float] = None,
aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState:
del aux_dropout_rate
-torch.random.manual_seed(rng[0])
+torch.random.manual_seed(rng)
model = UNet(
num_pool_layers=self.num_pool_layers,
num_channels=self.num_channels,
@@ -103,7 +103,7 @@ def _build_dataset(

if split == 'eval_train':
indices = list(range(self.num_train_examples))
-random.Random(data_rng[0]).shuffle(indices)
+random.Random(data_rng).shuffle(indices)
dataset = torch.utils.data.Subset(dataset,
indices[:self.num_eval_train_examples])

@@ -147,7 +147,7 @@ def init_model_fn(
"""Dropout is unused."""
del dropout_rate
del aux_dropout_rate
-torch.random.manual_seed(rng[0])
+torch.random.manual_seed(rng)

if self.use_silu and self.use_gelu:
raise RuntimeError('Cannot use both GELU and SiLU activations.')
@@ -30,7 +30,7 @@ def init_model_fn(
dropout_rate: Optional[float] = None,
aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState:
del aux_dropout_rate
-torch.random.manual_seed(rng[0])
+torch.random.manual_seed(rng)
model = models.ViT(
dropout_rate=dropout_rate,
num_classes=self._num_classes,
@@ -58,7 +58,7 @@ def init_model_fn(
Here we use dropout_rate as residual_dropout_rate, and aux_dropout_rate as
input_dropout_rate.
"""
-torch.random.manual_seed(rng[0])
+torch.random.manual_seed(rng)
# Configure torch backends to avoid OOM errors.
torch.backends.cudnn.benchmark = False
torch.backends.cuda.enable_flash_sdp(False)
@@ -32,7 +32,7 @@ def init_model_fn(
Here we use dropout_rate as feed_forward_dropout_rate, and aux_dropout_rate
as input_dropout_rate.
"""
-torch.random.manual_seed(rng[0])
+torch.random.manual_seed(rng)
model = DeepspeechEncoderDecoder(
DeepspeechConfig(
feed_forward_dropout_rate=dropout_rate,
@@ -133,7 +133,7 @@ def init_model_fn(
self._model.reset_parameters()
return self._model, None

-torch.random.manual_seed(rng[0])
+torch.random.manual_seed(rng)
self._model = _Model()
self._param_shapes = param_utils.pytorch_param_shapes(self._model)
self._param_types = param_utils.pytorch_param_types(self._param_shapes)
2 changes: 1 addition & 1 deletion algorithmic_efficiency/workloads/mnist/workload.py
@@ -55,7 +55,7 @@ def _build_mnist_dataset(

if shuffle:
ds = ds.repeat()
-ds = ds.shuffle(16 * global_batch_size, seed=data_rng[0])
+ds = ds.shuffle(16 * global_batch_size, seed=prng.bits(data_rng))
ds = ds.batch(global_batch_size, drop_remainder=is_train)

if repeat_final_dataset:
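tf.data.Dataset.shuffle expects a plain integer seed, and data_rng may now be either a scalar Numpy key or a Jax key, so prng.bits(data_rng) is used to derive one in both cases. A hedged equivalent for the Numpy path (flag parsing elided; 12345 and the batch size are stand-ins):

import tensorflow as tf
from algorithmic_efficiency import random_utils as prng

shuffle_seed = prng.bits(12345)  # non-negative int derived from the key
ds = tf.data.Dataset.range(1000).repeat().shuffle(16 * 128, seed=shuffle_seed)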
@@ -143,7 +143,7 @@ def init_model_fn(
aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState:
"""aux_dropout_rate is unused."""
del aux_dropout_rate
-torch.random.manual_seed(rng[0])
+torch.random.manual_seed(rng)
model = GNN(
num_outputs=self._num_outputs,
dropout_rate=dropout_rate,
@@ -171,7 +171,7 @@ def init_model_fn(
dropout_rate: Optional[float] = None,
aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState:
"""aux_dropout_rate is used as attention_dropout_rate."""
-torch.random.manual_seed(rng[0])
+torch.random.manual_seed(rng)

if self.activation == 'relu':
activation = F.relu
2 changes: 1 addition & 1 deletion scoring/run_workloads.py
@@ -148,7 +148,7 @@ def main(_):
# For each runnable workload check if there are any containers running and if not launch next container command
for workload in workloads:
run_key = prng.fold_in(rng_subkey, hash(workload))
-run_seed = run_key[0]  # arbitrary
+run_seed = prng.bits(run_key)
base_workload_name = get_base_workload_name(workload)
wait_until_container_not_running()
os.system(
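Because fold_in now returns a single scalar key (not a [new_seed, data] list) and bits exposes it as a non-negative integer, the per-workload run seed is derived without indexing. A hedged sketch mirroring the loop above (note that Python's hash() is only stable within a process unless PYTHONHASHSEED is fixed):

from algorithmic_efficiency import random_utils as prng

rng_subkey = 1234                    # stand-in for the key derived in main()
for workload in ['mnist', 'ogbg']:   # stand-in workload names
  run_key = prng.fold_in(rng_subkey, hash(workload))
  run_seed = prng.bits(run_key)      # non-negative int forwarded to the launched run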
2 changes: 1 addition & 1 deletion tests/test_param_shapes.py
@@ -113,7 +113,7 @@ def get_workload(workload):
else:
raise ValueError(f'Workload {workload} is not available.')
_ = jax_workload.init_model_fn(jax.random.PRNGKey(0))
-_ = pytorch_workload.init_model_fn([0])
+_ = pytorch_workload.init_model_fn(0)
return jax_workload, pytorch_workload


2 changes: 1 addition & 1 deletion tests/test_param_types.py
@@ -221,7 +221,7 @@ def get_workload(workload_name):
else:
raise ValueError(f'Workload {workload_name} is not available.')
_ = jax_workload.init_model_fn(jax.random.PRNGKey(0))
-_ = pytorch_workload.init_model_fn([0])
+_ = pytorch_workload.init_model_fn(0)
return jax_workload, pytorch_workload

