Skip to content

Commit

Permalink
fix: overflow error resolved and PRNGKey to key
Browse files Browse the repository at this point in the history
  • Loading branch information
init-22 committed Nov 16, 2024
1 parent d603ce9 commit be68f8c
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 10 deletions.
2 changes: 1 addition & 1 deletion algorithmic_efficiency/checkpoint_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ def save_checkpoint(framework: str,
target=checkpoint_state,
step=global_step,
overwrite=True,
keep=np.Inf if save_intermediate_checkpoints else 1)
keep=np.inf if save_intermediate_checkpoints else 1)
else:
if not save_intermediate_checkpoints:
checkpoint_files = gfile.glob(
Expand Down
10 changes: 5 additions & 5 deletions algorithmic_efficiency/random_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,19 @@

# Annoyingly, RandomState(seed) requires seed to be in [0, 2 ** 32 - 1] (an
# unsigned int), while RandomState.randint only accepts and returns signed ints.
MAX_UINT32 = 2**31
MAX_UINT32 = 2**32-1
MIN_UINT32 = 0

SeedType = Union[int, list, np.ndarray]


def _signed_to_unsigned(seed: SeedType) -> SeedType:
if isinstance(seed, int):
return seed % 2**32
return seed % MAX_UINT32
if isinstance(seed, list):
return [s % 2**32 for s in seed]
return [s % MAX_UINT32 for s in seed]
if isinstance(seed, np.ndarray):
return np.array([s % 2**32 for s in seed.tolist()])
return np.array([s % MAX_UINT32 for s in seed.tolist()])


def _fold_in(seed: SeedType, data: Any) -> List[Union[SeedType, Any]]:
Expand Down Expand Up @@ -75,5 +75,5 @@ def split(seed: SeedType, num: int = 2) -> SeedType:
def PRNGKey(seed: SeedType) -> SeedType: # pylint: disable=invalid-name
if FLAGS.framework == 'jax':
_check_jax_install()
return jax_rng.PRNGKey(seed)
return jax_rng.key(seed)
return _PRNGKey(seed)
8 changes: 4 additions & 4 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@ install_requires =
# Pin to avoid unpinned install in dependencies that requires Python>=3.9.
networkx==3.2.1
docker==7.1.0
numpy>=1.26.4
numpy>=2.1.3
pandas==2.2.3
tensorflow==2.17.0
tensorflow==2.18.0
tensorflow-datasets==4.9.7
tensorflow-addons==0.23.0
gputil==1.4.0
Expand Down Expand Up @@ -100,12 +100,12 @@ ogbg =

librispeech_conformer =
sentencepiece==0.2.0
tensorflow-text==2.17.0
tensorflow-text==2.18.0
pydub==0.25.1

wmt =
sentencepiece==0.2.0
tensorflow-text==2.17.0
tensorflow-text==2.18.0
sacrebleu==2.4.3

# Frameworks #
Expand Down

0 comments on commit be68f8c

Please sign in to comment.