diff --git a/Dockerfile b/Dockerfile index 79437d7..1765940 100644 --- a/Dockerfile +++ b/Dockerfile @@ -27,7 +27,7 @@ COPY requirements.txt . # Install the required packages RUN python3.11 -m pip install -r requirements.txt -# Install JAX with CUDA support. HPC is on CUDA 11, and JAX 0.2.25 is the latest version for that +# Install JAX with CUDA support. RUN python3.11 -m pip install --upgrade \ pip install -U "jax[cuda12]" \ optax diff --git a/config.yaml b/config.yaml index 1950f90..4f580dd 100644 --- a/config.yaml +++ b/config.yaml @@ -1,11 +1,11 @@ base: 12 -n: 1024 -emb: 32 +n: 16384 +emb: 64 lr: 0.001 -depth: 3 -heads: 4 -epochs: 400 +depth: 2 +heads: 8 +epochs: 10000 l2: 0.00001 # lambda block: vaswani -dropout: 0.15 +dropout: 0.2 gamma: 2 diff --git a/poetry.lock b/poetry.lock index 7ac26d2..0c931c3 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,5 +1,16 @@ # This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. +[[package]] +name = "absl-py" +version = "2.1.0" +description = "Abseil Python Common Libraries, see https://github.com/abseil/abseil-py." +optional = false +python-versions = ">=3.7" +files = [ + {file = "absl-py-2.1.0.tar.gz", hash = "sha256:7820790efbb316739cde8b4e19357243fc3608a152024288513dd968d7d959ff"}, + {file = "absl_py-2.1.0-py3-none-any.whl", hash = "sha256:526a04eadab8b4ee719ce68f204172ead1027549089702d99b9059f129ff1308"}, +] + [[package]] name = "altair" version = "5.3.0" @@ -175,6 +186,26 @@ files = [ {file = "charset_normalizer-3.3.2-py3-none-any.whl", hash = "sha256:3e4d1f6587322d2788836a99c69062fbb091331ec940e02d12d179c1d53e25fc"}, ] +[[package]] +name = "chex" +version = "0.1.86" +description = "Chex: Testing made fun, in JAX!" +optional = false +python-versions = ">=3.9" +files = [ + {file = "chex-0.1.86-py3-none-any.whl", hash = "sha256:251c20821092323a3d9c28e1cf80e4a58180978bec368f531949bd9847eee568"}, + {file = "chex-0.1.86.tar.gz", hash = "sha256:e8b0f96330eba4144659e1617c0f7a57b161e8cbb021e55c6d5056c7378091d1"}, +] + +[package.dependencies] +absl-py = ">=0.9.0" +jax = ">=0.4.16" +jaxlib = ">=0.1.37" +numpy = ">=1.24.1" +setuptools = {version = "*", markers = "python_version >= \"3.12\""} +toolz = ">=0.9.0" +typing-extensions = ">=4.2.0" + [[package]] name = "click" version = "8.1.7" @@ -425,6 +456,80 @@ files = [ {file = "idna-3.7.tar.gz", hash = "sha256:028ff3aadf0609c1fd278d8ea3089299412a7a8b9bd005dd08b9f8285bcb5cfc"}, ] +[[package]] +name = "jax" +version = "0.4.29" +description = "Differentiate, compile, and transform Numpy code." +optional = false +python-versions = ">=3.9" +files = [ + {file = "jax-0.4.29-py3-none-any.whl", hash = "sha256:cfdc594d133d7dfba2ec19bc10742ffaa0e8827ead29be00d3ec4215a3f7892e"}, + {file = "jax-0.4.29.tar.gz", hash = "sha256:12904571eaefddcdc8c3b8d4936482b783d5a216e99ef5adcd3522fdfb4fc186"}, +] + +[package.dependencies] +ml-dtypes = ">=0.4.0" +numpy = [ + {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, + {version = ">=1.23.2", markers = "python_version >= \"3.11\" and python_version < \"3.12\""}, +] +opt-einsum = "*" +scipy = [ + {version = ">=1.11.1", markers = "python_version >= \"3.12\""}, + {version = ">=1.9", markers = "python_version < \"3.12\""}, +] + +[package.extras] +ci = ["jaxlib (==0.4.28)"] +cpu = ["jaxlib (==0.4.29)"] +cuda = ["jaxlib (==0.4.29+cuda12.cudnn91)"] +cuda12 = ["jax-cuda12-plugin (==0.4.29)", "jaxlib (==0.4.29)", "nvidia-cublas-cu12 (>=12.1.3.1)", "nvidia-cuda-cupti-cu12 (>=12.1.105)", "nvidia-cuda-nvcc-cu12 (>=12.1.105)", "nvidia-cuda-runtime-cu12 (>=12.1.105)", "nvidia-cudnn-cu12 (>=9.0,<10.0)", "nvidia-cufft-cu12 (>=11.0.2.54)", "nvidia-cusolver-cu12 (>=11.4.5.107)", "nvidia-cusparse-cu12 (>=12.1.0.106)", "nvidia-nccl-cu12 (>=2.18.1)", "nvidia-nvjitlink-cu12 (>=12.1.105)"] +cuda12-cudnn91 = ["jaxlib (==0.4.29+cuda12.cudnn91)"] +cuda12-local = ["jaxlib (==0.4.29+cuda12.cudnn91)"] +cuda12-pip = ["jaxlib (==0.4.29+cuda12.cudnn91)", "nvidia-cublas-cu12 (>=12.1.3.1)", "nvidia-cuda-cupti-cu12 (>=12.1.105)", "nvidia-cuda-nvcc-cu12 (>=12.1.105)", "nvidia-cuda-runtime-cu12 (>=12.1.105)", "nvidia-cudnn-cu12 (>=9.0,<10.0)", "nvidia-cufft-cu12 (>=11.0.2.54)", "nvidia-cusolver-cu12 (>=11.4.5.107)", "nvidia-cusparse-cu12 (>=12.1.0.106)", "nvidia-nccl-cu12 (>=2.18.1)", "nvidia-nvjitlink-cu12 (>=12.1.105)"] +minimum-jaxlib = ["jaxlib (==0.4.27)"] +tpu = ["jaxlib (==0.4.29)", "libtpu-nightly (==0.1.dev20240609)", "requests"] + +[[package]] +name = "jaxlib" +version = "0.4.29" +description = "XLA library for JAX" +optional = false +python-versions = ">=3.9" +files = [ + {file = "jaxlib-0.4.29-cp310-cp310-macosx_10_14_x86_64.whl", hash = "sha256:60ec8aa2ba133a0615b0fce8e084c90c179c019793551641dd3da6526d036953"}, + {file = "jaxlib-0.4.29-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:adb37f9c01a0fbdf97ab4afc7b60939c1694a5c056d8224c3d292cc253a3dc55"}, + {file = "jaxlib-0.4.29-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:8b1062b804d95ddb8dbb039c48316cbaed1d6866f973ef39e03f39e452ac8a1c"}, + {file = "jaxlib-0.4.29-cp310-cp310-manylinux2014_x86_64.whl", hash = "sha256:ae21b84dd08c015bf2bab9ba97fa6a1da30f9a51c35902e23c7ffe959ad2e86c"}, + {file = "jaxlib-0.4.29-cp310-cp310-win_amd64.whl", hash = "sha256:a4993ab2f91c8ee213cacc4ed341539a7980b1b231e9dac69d68b508118dc19d"}, + {file = "jaxlib-0.4.29-cp311-cp311-macosx_10_14_x86_64.whl", hash = "sha256:1da0716c45c5b0e177d334938a09953915f8da2080fffee9366ad8a9a988f484"}, + {file = "jaxlib-0.4.29-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:5da76e760be790896f7149eccafff64b800f129a282de7f7a1edc138d56ac997"}, + {file = "jaxlib-0.4.29-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:aec76c416657e25884ee1364e98e40fcedbd2235c79691026d9babf05d850ede"}, + {file = "jaxlib-0.4.29-cp311-cp311-manylinux2014_x86_64.whl", hash = "sha256:5a313e94c3ae87f147d561bc61e923d75e18e2448ac8fbf150d526ecd24404b0"}, + {file = "jaxlib-0.4.29-cp311-cp311-win_amd64.whl", hash = "sha256:7bfcc35a2991c2973489333f5c07dbb1f5d0ec78ef889097534c2c5f0d0149a7"}, + {file = "jaxlib-0.4.29-cp312-cp312-macosx_10_14_x86_64.whl", hash = "sha256:7d7eabdb21814386cfa32e09ddaee76e374126c3709989b6de5f1e49a04f5f36"}, + {file = "jaxlib-0.4.29-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:ac11efc5eb7d0d25dc853efd68c18b204d071e78f762583dc9ebed84db272bf2"}, + {file = "jaxlib-0.4.29-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:08cb5b24f481f62ff432b0bbedc7e35c5d561dc42c1c8138bbf8514ea91ab17e"}, + {file = "jaxlib-0.4.29-cp312-cp312-manylinux2014_x86_64.whl", hash = "sha256:a7c884c5651e5d1cc0fc57f2cf0abee4223b1f38c3e8cbd00fb142e6551cfe47"}, + {file = "jaxlib-0.4.29-cp312-cp312-win_amd64.whl", hash = "sha256:91058f1606312c42621b0a9979d0f14c0db9da6341ffd714ac5eb5e3be79c59a"}, + {file = "jaxlib-0.4.29-cp39-cp39-macosx_10_14_x86_64.whl", hash = "sha256:e59afd8026f43688fd0bf4f3dbcf913157c7f144d000850aaa7a88b1228a87ab"}, + {file = "jaxlib-0.4.29-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:2801327384be3edab5f3adb38262206e488135d7c8e27d928ac3c6ffb71f7718"}, + {file = "jaxlib-0.4.29-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:dfce109290dbbd27176750b931bedbabb56a4f5e955f83eead2e3cdefe5108b8"}, + {file = "jaxlib-0.4.29-cp39-cp39-manylinux2014_x86_64.whl", hash = "sha256:84be918201a7f06b73074ed154fc3344b02aa8736597b39397e7093807dcfc3c"}, + {file = "jaxlib-0.4.29-cp39-cp39-win_amd64.whl", hash = "sha256:9b0efd3ba45a7ee03fb91a4118099ea97aa02a96e50f4d91cc910554e226c5b9"}, +] + +[package.dependencies] +ml-dtypes = ">=0.4.0" +numpy = ">=1.22" +scipy = [ + {version = ">=1.11.1", markers = "python_version >= \"3.12\""}, + {version = ">=1.9", markers = "python_version < \"3.12\""}, +] + +[package.extras] +cuda12-pip = ["nvidia-cublas-cu12 (>=12.1.3.1)", "nvidia-cuda-cupti-cu12 (>=12.1.105)", "nvidia-cuda-nvcc-cu12 (>=12.1.105)", "nvidia-cuda-runtime-cu12 (>=12.1.105)", "nvidia-cudnn-cu12 (>=9.0,<10.0)", "nvidia-cufft-cu12 (>=11.0.2.54)", "nvidia-cusolver-cu12 (>=11.4.5.107)", "nvidia-cusparse-cu12 (>=12.1.0.106)", "nvidia-nccl-cu12 (>=2.18.1)", "nvidia-nvjitlink-cu12 (>=12.1.105)"] + [[package]] name = "jinja2" version = "3.1.4" @@ -757,6 +862,41 @@ files = [ {file = "mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba"}, ] +[[package]] +name = "ml-dtypes" +version = "0.4.0" +description = "" +optional = false +python-versions = ">=3.9" +files = [ + {file = "ml_dtypes-0.4.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:93afe37f3a879d652ec9ef1fc47612388890660a2657fbb5747256c3b818fd81"}, + {file = "ml_dtypes-0.4.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2bb83fd064db43e67e67d021e547698af4c8d5c6190f2e9b1c53c09f6ff5531d"}, + {file = "ml_dtypes-0.4.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:03e7cda6ef164eed0abb31df69d2c00c3a5ab3e2610b6d4c42183a43329c72a5"}, + {file = "ml_dtypes-0.4.0-cp310-cp310-win_amd64.whl", hash = "sha256:a15d96d090aebb55ee85173d1775ae325a001aab607a76c8ea0b964ccd6b5364"}, + {file = "ml_dtypes-0.4.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:bdf689be7351cc3c95110c910c1b864002f113e682e44508910c849e144f3df1"}, + {file = "ml_dtypes-0.4.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c83e4d443962d891d51669ff241d5aaad10a8d3d37a81c5532a45419885d591c"}, + {file = "ml_dtypes-0.4.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e1e2f4237b459a63c97c2c9f449baa637d7e4c20addff6a9bac486f22432f3b6"}, + {file = "ml_dtypes-0.4.0-cp311-cp311-win_amd64.whl", hash = "sha256:75b4faf99d0711b81f393db36d210b4255fd419f6f790bc6c1b461f95ffb7a9e"}, + {file = "ml_dtypes-0.4.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:ee9f91d4c4f9959a7e1051c141dc565f39e54435618152219769e24f5e9a4d06"}, + {file = "ml_dtypes-0.4.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ad6849a2db386b38e4d54fe13eb3293464561780531a918f8ef4c8169170dd49"}, + {file = "ml_dtypes-0.4.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eaa32979ebfde3a0d7c947cafbf79edc1ec77ac05ad0780ee86c1d8df70f2259"}, + {file = "ml_dtypes-0.4.0-cp312-cp312-win_amd64.whl", hash = "sha256:3b67ec73a697c88c1122038e0de46520e48dc2ec876d42cf61bc5efe3c0b7675"}, + {file = "ml_dtypes-0.4.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:41affb38fdfe146e3db226cf2953021184d6f0c4ffab52136613e9601706e368"}, + {file = "ml_dtypes-0.4.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:43cf4356a0fe2eeac6d289018d0734e17a403bdf1fd911953c125dd0358edcc0"}, + {file = "ml_dtypes-0.4.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f1724ddcdf5edbaf615a62110af47407f1719b8d02e68ccee60683acb5f74da1"}, + {file = "ml_dtypes-0.4.0-cp39-cp39-win_amd64.whl", hash = "sha256:723af6346447268a3cf0b7356e963d80ecb5732b5279b2aa3fa4b9fc8297c85e"}, + {file = "ml_dtypes-0.4.0.tar.gz", hash = "sha256:eaf197e72f4f7176a19fe3cb8b61846b38c6757607e7bf9cd4b1d84cd3e74deb"}, +] + +[package.dependencies] +numpy = [ + {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, + {version = ">=1.23.3", markers = "python_version >= \"3.11\" and python_version < \"3.12\""}, +] + +[package.extras] +dev = ["absl-py", "pyink", "pylint (>=2.6.0)", "pytest", "pytest-xdist"] + [[package]] name = "mpmath" version = "1.3.0" @@ -844,6 +984,48 @@ files = [ matplotlib = ">=3.1,<4" sympy = ">=1.2,<2" +[[package]] +name = "opt-einsum" +version = "3.3.0" +description = "Optimizing numpys einsum function" +optional = false +python-versions = ">=3.5" +files = [ + {file = "opt_einsum-3.3.0-py3-none-any.whl", hash = "sha256:2455e59e3947d3c275477df7f5205b30635e266fe6dc300e3d9f9646bfcea147"}, + {file = "opt_einsum-3.3.0.tar.gz", hash = "sha256:59f6475f77bbc37dcf7cd748519c0ec60722e91e63ca114e68821c0c54a46549"}, +] + +[package.dependencies] +numpy = ">=1.7" + +[package.extras] +docs = ["numpydoc", "sphinx (==1.2.3)", "sphinx-rtd-theme", "sphinxcontrib-napoleon"] +tests = ["pytest", "pytest-cov", "pytest-pep8"] + +[[package]] +name = "optax" +version = "0.2.2" +description = "A gradient processing and optimisation library in JAX." +optional = false +python-versions = ">=3.9" +files = [ + {file = "optax-0.2.2-py3-none-any.whl", hash = "sha256:411c414a76aae259f4191a60b712663968741a5163ca92fc250b5d5c7d36fb57"}, + {file = "optax-0.2.2.tar.gz", hash = "sha256:f09bf790ef4b09fb9c35f79a07594c6196a719919985f542dc84b0bf97812e0e"}, +] + +[package.dependencies] +absl-py = ">=0.7.1" +chex = ">=0.1.86" +jax = ">=0.1.55" +jaxlib = ">=0.1.37" +numpy = ">=1.18.0" + +[package.extras] +docs = ["flax", "ipython (>=8.8.0)", "matplotlib (>=3.5.0)", "myst-nb (>=1.0.0)", "sphinx (>=6.0.0)", "sphinx-autodoc-typehints", "sphinx-book-theme (>=1.0.1)", "sphinx-collections (>=0.0.1)", "sphinx-gallery (>=0.14.0)", "sphinx_contributors", "sphinxcontrib-katex", "tensorflow (>=2.4.0)", "tensorflow-datasets (>=4.2.0)"] +dp-accounting = ["absl-py (>=1.0.0)", "attrs (>=21.4.0)", "mpmath (>=1.2.1)", "numpy (>=1.21.4)", "scipy (>=1.7.1)"] +examples = ["dp_accounting (>=0.4)", "flax", "tensorflow (>=2.4.0)", "tensorflow-datasets (>=4.2.0)"] +test = ["dm-tree (>=0.1.7)", "flax (>=0.5.3)"] + [[package]] name = "packaging" version = "24.0" @@ -895,8 +1077,8 @@ files = [ [package.dependencies] numpy = [ - {version = ">=1.23.2", markers = "python_version == \"3.11\""}, {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, + {version = ">=1.23.2", markers = "python_version == \"3.11\""}, ] python-dateutil = ">=2.8.2" pytz = ">=2020.1" @@ -1968,4 +2150,4 @@ watchmedo = ["PyYAML (>=3.10)"] [metadata] lock-version = "2.0" python-versions = "^3.11" -content-hash = "44b08e99f2f33906ec403ed9dfc377d57d426a3ca210c2f67a881dcf424b4c25" +content-hash = "c77b597a7bc539d5a74a54e2cb17da49b7371f5b15b0634298d1d8ec75497bd4" diff --git a/pyproject.toml b/pyproject.toml index 570471b..4a3412d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,6 +19,8 @@ numpy-hilbert-curve = "^1.0.1" wandb = "^0.17.0" seaborn = "^0.13.2" scikit-learn = "^1.5.0" +jax = "^0.4.29" +optax = "^0.2.2" [build-system] diff --git a/src/plots.py b/src/plots.py index f58fbaf..6bc9c39 100644 --- a/src/plots.py +++ b/src/plots.py @@ -28,6 +28,13 @@ plt.rcParams["font.family"] = "Monospace" +def fname_fn(conf, fname): + return ( + "_".join([f"{k}_{v}" for k, v in conf.items() if k not in ["in_d", "out_d"]]) + + f"_{fname}" + ) + + # functions def polar_plot(gold, pred, conf, fname, offset=0): # maps v to a polar plot conf = conf.__dict__ @@ -74,8 +81,8 @@ def polar_plot(gold, pred, conf, fname, offset=0): # maps v to a polar plot ] ) ax.set_xlabel(xlabel, color=ink) - if darkdetect.isLight(): - plt.savefig(f"figs/{fname}", dpi=100) + fname = fname_fn(conf, fname) + plt.savefig(f"figs/{fname}", dpi=100) def curve_plot( @@ -110,8 +117,10 @@ def curve_plot( color=ink, ) ax.legend(info["legend"], frameon=False, labelcolor=ink) + # make fname contain conf + fname = fname_fn(conf, "curves") if darkdetect.isLight(): - plt.savefig(f"figs/curves.pdf") + plt.savefig(f"figs/{fname}.pdf", dpi=100) ############################################