From 366eab28fafebed59f50b9f75c08ac828e0abe92 Mon Sep 17 00:00:00 2001 From: SamTov Date: Wed, 24 Apr 2024 08:22:47 +0200 Subject: [PATCH 1/7] Fix requirements, clean up small test issues --- CI/unit_tests/utils/test_matrix_utils.py | 2 +- dev-requirements.txt | 10 +++++ requirements.txt | 41 +++++++-------------- znnl/models/jax_model.py | 3 +- znnl/training_strategies/simple_training.py | 2 +- 5 files changed, 27 insertions(+), 31 deletions(-) create mode 100644 dev-requirements.txt diff --git a/CI/unit_tests/utils/test_matrix_utils.py b/CI/unit_tests/utils/test_matrix_utils.py index c307d22..5fd9a92 100644 --- a/CI/unit_tests/utils/test_matrix_utils.py +++ b/CI/unit_tests/utils/test_matrix_utils.py @@ -53,7 +53,7 @@ def test_unscaled_eigenvalues(self): values, vectors = compute_eigensystem(matrix, normalize=False) - assert_array_equal(np.real(values), [1, 1]) + assert_array_equal(np.real(values), [1.0, 1.0]) def test_scaled_eigenvalues(self): """ diff --git a/dev-requirements.txt b/dev-requirements.txt new file mode 100644 index 0000000..89184ee --- /dev/null +++ b/dev-requirements.txt @@ -0,0 +1,10 @@ +isort>=5.13.2 +black>=24.4.0 +sphinx>=7.3.7 +sphinx_copybutton>=0.5.2 +sphinx_rtd_theme>=2.0.0 +nbsphinx>=0.9.3 +pytest>=8.1.1 +numpydoc>=1.7.0 +flake8>=7.0.0 +pre_commit>=3.7.0 diff --git a/requirements.txt b/requirements.txt index d31d588..dc383fa 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,28 +1,15 @@ -numpy -matplotlib -sphinx -flake8 -black -ipython -numpydoc -optax -sphinx_copybutton -sphinx_rtd_theme -nbsphinx -tensorflow_probability -scipy -scikit-learn -# Temp fix of version of jax and jaxlib until the next release -jax<=0.4.25 -jaxlib<=0.4.25 -plotly -flax -tqdm -pandas +numpy>=1.26.4 +matplotlib>=3.8.4 +optax>=0.2.2 +tensorflow_probability>0.24.0 +scipy>=1.13.0 +scikit-learn>=1.4.2 +plotly>=5.21.0 +flax>=0.8.2 +tqdm>4.66.2 +pandas>=2.2.2 neural-tangents>=0.6.5 -tensorflow-datasets -isort -tensorflow -pyyaml -jupyter -transformers \ No newline at end of file +tensorflow-datasets>=4.9.4 +tensorflow>=2.16.1 +jupyter>=1.0.0 +transformers>=4.40.0 diff --git a/znnl/models/jax_model.py b/znnl/models/jax_model.py index 9dee179..17bb3fb 100644 --- a/znnl/models/jax_model.py +++ b/znnl/models/jax_model.py @@ -25,7 +25,6 @@ ------- """ -from functools import partial from typing import Any, Callable, Optional, Sequence, Union import jax @@ -123,7 +122,7 @@ def __init__( # Prepare NTK calculation self.empirical_ntk = nt.batch( - nt.empirical_ntk_fn(f=self._ntk_apply_fn, trace_axes=trace_axes), + nt.empirical_ntk_fn(f=jax.jit(self._ntk_apply_fn), trace_axes=trace_axes), batch_size=ntk_batch_size, store_on_device=store_on_device, ) diff --git a/znnl/training_strategies/simple_training.py b/znnl/training_strategies/simple_training.py index 325bf7e..c0675e8 100644 --- a/znnl/training_strategies/simple_training.py +++ b/znnl/training_strategies/simple_training.py @@ -369,7 +369,7 @@ def train_model( state = self.model.model_state loading_bar = trange( - 1, epochs + 1, ncols=100, unit="batch", disable=self.disable_loading_bar + 0, epochs, ncols=100, unit="batch", disable=self.disable_loading_bar ) train_losses = [] From b200f7597af7ab3a689d44a904ececf425c82993 Mon Sep 17 00:00:00 2001 From: SamTov Date: Wed, 24 Apr 2024 08:28:31 +0200 Subject: [PATCH 2/7] Fix CI runner config --- .github/workflows/doc.yml | 6 +++--- .github/workflows/pytest.yml | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/.github/workflows/doc.yml b/.github/workflows/doc.yml index ba5fb3b..9d14631 100644 --- a/.github/workflows/doc.yml +++ b/.github/workflows/doc.yml @@ -15,10 +15,10 @@ jobs: - name: Install dependencies run: | sudo apt install pandoc - pip3 install -r requirements.txt + pip install -r dev-requirements.txt + pip install -r requirements.txt pip install . - pip3 install h5py --upgrade --no-dependencies - pip3 install cached-property + - name: Build documentation run: | cd docs diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index 12c5d1b..272ecf1 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -22,8 +22,8 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - python -m pip install pytest - if [ -f requirements.txt ]; then pip install -r requirements.txt; fi + pip install -r dev-requirements.txt + pip install -r requirements.txt - name: Install package run: | pip install . From 66e720f6eda27a0fa5df7be06e9bb6cb54890fef Mon Sep 17 00:00:00 2001 From: SamTov Date: Wed, 24 Apr 2024 08:30:39 +0200 Subject: [PATCH 3/7] Fix false requirement boundary --- requirements.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/requirements.txt b/requirements.txt index dc383fa..1ae9027 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,12 +1,12 @@ numpy>=1.26.4 matplotlib>=3.8.4 optax>=0.2.2 -tensorflow_probability>0.24.0 +tensorflow_probability>=0.24.0 scipy>=1.13.0 scikit-learn>=1.4.2 plotly>=5.21.0 flax>=0.8.2 -tqdm>4.66.2 +tqdm>=4.66.2 pandas>=2.2.2 neural-tangents>=0.6.5 tensorflow-datasets>=4.9.4 From 3151b04dd529d11b2447507b2059a7601fcd3c70 Mon Sep 17 00:00:00 2001 From: SamTov Date: Wed, 24 Apr 2024 08:42:25 +0200 Subject: [PATCH 4/7] Update python version --- .github/workflows/black.yml | 2 +- .github/workflows/doc.yml | 2 +- .github/workflows/flake8.yml | 2 +- .github/workflows/isort.yml | 2 +- .github/workflows/nbtest.yml | 2 +- .github/workflows/pytest.yml | 2 +- 6 files changed, 6 insertions(+), 6 deletions(-) diff --git a/.github/workflows/black.yml b/.github/workflows/black.yml index be1ae39..24af6fa 100644 --- a/.github/workflows/black.yml +++ b/.github/workflows/black.yml @@ -9,6 +9,6 @@ jobs: steps: - uses: actions/checkout@v2 with: - python-version: '3.10' + python-version: '3.11' - name: Black Check uses: psf/black@22.8.0 diff --git a/.github/workflows/doc.yml b/.github/workflows/doc.yml index 9d14631..5d84512 100644 --- a/.github/workflows/doc.yml +++ b/.github/workflows/doc.yml @@ -11,7 +11,7 @@ jobs: - name: Setup Python environment uses: actions/setup-python@v2 with: - python-version: '3.10' + python-version: '3.11' - name: Install dependencies run: | sudo apt install pandoc diff --git a/.github/workflows/flake8.yml b/.github/workflows/flake8.yml index bb71a80..1e4a252 100644 --- a/.github/workflows/flake8.yml +++ b/.github/workflows/flake8.yml @@ -9,7 +9,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: [ "3.10" ] + python-version: [ "3.11" ] steps: - uses: actions/checkout@v2 diff --git a/.github/workflows/isort.yml b/.github/workflows/isort.yml index 7a3d1e9..0a56f07 100644 --- a/.github/workflows/isort.yml +++ b/.github/workflows/isort.yml @@ -10,7 +10,7 @@ jobs: - uses: actions/checkout@v2 - uses: actions/setup-python@v2 with: - python-version: '3.10' + python-version: '3.11' - name: Install isort run: | pip install isort==5.10.1 diff --git a/.github/workflows/nbtest.yml b/.github/workflows/nbtest.yml index a482867..26370ff 100644 --- a/.github/workflows/nbtest.yml +++ b/.github/workflows/nbtest.yml @@ -12,7 +12,7 @@ jobs: - uses: actions/checkout@v2 - uses: actions/setup-python@v2 with: - python-version: "3.10" + python-version: "3.11" - name: Install dev requirements run: | pip3 install nbmake diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index 272ecf1..cc6aff5 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -11,7 +11,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.10"] + python-version: ["3.11"] steps: - uses: actions/checkout@v2 From f670486e45e5d1d387ac7aa66686dd401bf3ba0d Mon Sep 17 00:00:00 2001 From: SamTov Date: Wed, 24 Apr 2024 13:46:19 +0200 Subject: [PATCH 5/7] add jax back --- requirements.txt | 2 ++ 1 file changed, 2 insertions(+) diff --git a/requirements.txt b/requirements.txt index 1ae9027..71bb1d3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,3 +13,5 @@ tensorflow-datasets>=4.9.4 tensorflow>=2.16.1 jupyter>=1.0.0 transformers>=4.40.0 +jax>=0.4.26 +jaxlib>=0.4.26 \ No newline at end of file From afb849f0963ff90ec0c8d9fdddc778430ae638d9 Mon Sep 17 00:00:00 2001 From: SamTov Date: Tue, 7 May 2024 11:28:29 +0200 Subject: [PATCH 6/7] Remove HF test --- ...ace_flax_model.py => _test_huggingface_flax_model.py} | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) rename CI/unit_tests/models/{test_huggingface_flax_model.py => _test_huggingface_flax_model.py} (93%) diff --git a/CI/unit_tests/models/test_huggingface_flax_model.py b/CI/unit_tests/models/_test_huggingface_flax_model.py similarity index 93% rename from CI/unit_tests/models/test_huggingface_flax_model.py rename to CI/unit_tests/models/_test_huggingface_flax_model.py index 7ef3867..7d77063 100644 --- a/CI/unit_tests/models/test_huggingface_flax_model.py +++ b/CI/unit_tests/models/_test_huggingface_flax_model.py @@ -46,7 +46,6 @@ def setup_class(cls): Create a model and data for the tests. The resnet config has a 1 dimensional input and a 2 dimensional output. """ - resnet_config = ResNetConfig( num_channels=2, embedding_size=64, @@ -88,3 +87,11 @@ def test_infinite_failure(self): """ with pytest.raises(NotImplementedError): self.model.compute_ntk(self.x, infinite=True) + + +if __name__ == "__main__": + test_class = TestFlaxHFModule() + test_class.setup_class() + + # test_class.test_infinite_failure() + test_class.test_ntk_shape() From c5e1170613ec80b29d8b1b0acd2b5801520a8104 Mon Sep 17 00:00:00 2001 From: SamTov Date: Tue, 7 May 2024 11:33:36 +0200 Subject: [PATCH 7/7] remove jit --- znnl/models/jax_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/znnl/models/jax_model.py b/znnl/models/jax_model.py index 17bb3fb..f6764e8 100644 --- a/znnl/models/jax_model.py +++ b/znnl/models/jax_model.py @@ -122,7 +122,7 @@ def __init__( # Prepare NTK calculation self.empirical_ntk = nt.batch( - nt.empirical_ntk_fn(f=jax.jit(self._ntk_apply_fn), trace_axes=trace_axes), + nt.empirical_ntk_fn(f=self._ntk_apply_fn, trace_axes=trace_axes), batch_size=ntk_batch_size, store_on_device=store_on_device, )