Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix requirements, clean up small test issues #117

Merged
merged 8 commits into from
May 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/black.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,6 @@ jobs:
steps:
- uses: actions/checkout@v2
with:
python-version: '3.10'
python-version: '3.11'
- name: Black Check
uses: psf/black@22.8.0
8 changes: 4 additions & 4 deletions .github/workflows/doc.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,14 @@ jobs:
- name: Setup Python environment
uses: actions/setup-python@v2
with:
python-version: '3.10'
python-version: '3.11'
- name: Install dependencies
run: |
sudo apt install pandoc
pip3 install -r requirements.txt
pip install -r dev-requirements.txt
pip install -r requirements.txt
pip install .
pip3 install h5py --upgrade --no-dependencies
pip3 install cached-property

- name: Build documentation
run: |
cd docs
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/flake8.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: [ "3.10" ]
python-version: [ "3.11" ]

steps:
- uses: actions/checkout@v2
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/isort.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ jobs:
- uses: actions/checkout@v2
- uses: actions/setup-python@v2
with:
python-version: '3.10'
python-version: '3.11'
- name: Install isort
run: |
pip install isort==5.10.1
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/nbtest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ jobs:
- uses: actions/checkout@v2
- uses: actions/setup-python@v2
with:
python-version: "3.10"
python-version: "3.11"
- name: Install dev requirements
run: |
pip3 install nbmake
Expand Down
6 changes: 3 additions & 3 deletions .github/workflows/pytest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ["3.10"]
python-version: ["3.11"]

steps:
- uses: actions/checkout@v2
Expand All @@ -22,8 +22,8 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install pytest
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
pip install -r dev-requirements.txt
pip install -r requirements.txt
- name: Install package
run: |
pip install .
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ def setup_class(cls):
Create a model and data for the tests.
The resnet config has a 1 dimensional input and a 2 dimensional output.
"""

resnet_config = ResNetConfig(
num_channels=2,
embedding_size=64,
Expand Down Expand Up @@ -88,3 +87,11 @@ def test_infinite_failure(self):
"""
with pytest.raises(NotImplementedError):
self.model.compute_ntk(self.x, infinite=True)


if __name__ == "__main__":
test_class = TestFlaxHFModule()
test_class.setup_class()

# test_class.test_infinite_failure()
test_class.test_ntk_shape()
2 changes: 1 addition & 1 deletion CI/unit_tests/utils/test_matrix_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def test_unscaled_eigenvalues(self):

values, vectors = compute_eigensystem(matrix, normalize=False)

assert_array_equal(np.real(values), [1, 1])
assert_array_equal(np.real(values), [1.0, 1.0])

def test_scaled_eigenvalues(self):
"""
Expand Down
10 changes: 10 additions & 0 deletions dev-requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
isort>=5.13.2
black>=24.4.0
sphinx>=7.3.7
sphinx_copybutton>=0.5.2
sphinx_rtd_theme>=2.0.0
nbsphinx>=0.9.3
pytest>=8.1.1
numpydoc>=1.7.0
flake8>=7.0.0
pre_commit>=3.7.0
43 changes: 16 additions & 27 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,28 +1,17 @@
numpy
matplotlib
sphinx
flake8
black
ipython
numpydoc
optax
sphinx_copybutton
sphinx_rtd_theme
nbsphinx
tensorflow_probability
scipy
scikit-learn
# Temp fix of version of jax and jaxlib until the next release
jax<=0.4.25
jaxlib<=0.4.25
plotly
flax
tqdm
pandas
numpy>=1.26.4
matplotlib>=3.8.4
optax>=0.2.2
tensorflow_probability>=0.24.0
scipy>=1.13.0
scikit-learn>=1.4.2
plotly>=5.21.0
flax>=0.8.2
tqdm>=4.66.2
pandas>=2.2.2
neural-tangents>=0.6.5
tensorflow-datasets
isort
tensorflow
pyyaml
jupyter
transformers
tensorflow-datasets>=4.9.4
tensorflow>=2.16.1
jupyter>=1.0.0
transformers>=4.40.0
jax>=0.4.26
jaxlib>=0.4.26
1 change: 0 additions & 1 deletion znnl/models/jax_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
-------
"""

from functools import partial
from typing import Any, Callable, Optional, Sequence, Union

import jax
Expand Down
2 changes: 1 addition & 1 deletion znnl/training_strategies/simple_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,7 @@ def train_model(
state = self.model.model_state

loading_bar = trange(
1, epochs + 1, ncols=100, unit="batch", disable=self.disable_loading_bar
0, epochs, ncols=100, unit="batch", disable=self.disable_loading_bar
SamTov marked this conversation as resolved.
Show resolved Hide resolved
)

train_losses = []
Expand Down
Loading