diff --git a/algorithmic_efficiency/checkpoint_utils.py b/algorithmic_efficiency/checkpoint_utils.py index 29c1a821e..04dad0eb7 100644 --- a/algorithmic_efficiency/checkpoint_utils.py +++ b/algorithmic_efficiency/checkpoint_utils.py @@ -231,7 +231,7 @@ def save_checkpoint(framework: str, target=checkpoint_state, step=global_step, overwrite=True, - keep=np.Inf if save_intermediate_checkpoints else 1) + keep=np.inf if save_intermediate_checkpoints else 1) else: if not save_intermediate_checkpoints: checkpoint_files = gfile.glob( diff --git a/algorithmic_efficiency/random_utils.py b/algorithmic_efficiency/random_utils.py index 31317047e..93dc263bd 100644 --- a/algorithmic_efficiency/random_utils.py +++ b/algorithmic_efficiency/random_utils.py @@ -18,7 +18,7 @@ # Annoyingly, RandomState(seed) requires seed to be in [0, 2 ** 32 - 1] (an # unsigned int), while RandomState.randint only accepts and returns signed ints. -MAX_UINT32 = 2**31 +MAX_UINT32 = 2**32-1 MIN_UINT32 = 0 SeedType = Union[int, list, np.ndarray] @@ -26,11 +26,11 @@ def _signed_to_unsigned(seed: SeedType) -> SeedType: if isinstance(seed, int): - return seed % 2**32 + return seed % MAX_UINT32 if isinstance(seed, list): - return [s % 2**32 for s in seed] + return [s % MAX_UINT32 for s in seed] if isinstance(seed, np.ndarray): - return np.array([s % 2**32 for s in seed.tolist()]) + return np.array([s % MAX_UINT32 for s in seed.tolist()]) def _fold_in(seed: SeedType, data: Any) -> List[Union[SeedType, Any]]: @@ -75,5 +75,5 @@ def split(seed: SeedType, num: int = 2) -> SeedType: def PRNGKey(seed: SeedType) -> SeedType: # pylint: disable=invalid-name if FLAGS.framework == 'jax': _check_jax_install() - return jax_rng.PRNGKey(seed) + return jax_rng.key(seed) return _PRNGKey(seed) diff --git a/setup.cfg b/setup.cfg index 078b694b8..6e6a1c957 100644 --- a/setup.cfg +++ b/setup.cfg @@ -39,9 +39,9 @@ install_requires = # Pin to avoid unpinned install in dependencies that requires Python>=3.9. networkx==3.2.1 docker==7.1.0 - numpy>=1.26.4 + numpy>=2.1.3 pandas==2.2.3 - tensorflow==2.17.0 + tensorflow==2.18.0 tensorflow-datasets==4.9.7 tensorflow-addons==0.23.0 gputil==1.4.0 @@ -100,12 +100,12 @@ ogbg = librispeech_conformer = sentencepiece==0.2.0 - tensorflow-text==2.17.0 + tensorflow-text==2.18.0 pydub==0.25.1 wmt = sentencepiece==0.2.0 - tensorflow-text==2.17.0 + tensorflow-text==2.18.0 sacrebleu==2.4.3 # Frameworks #