From da49e555674143f18a0b5eda878d20988bb9cc3e Mon Sep 17 00:00:00 2001 From: Robert David Stein Date: Wed, 15 May 2024 10:22:03 -0700 Subject: [PATCH] RB for WINTER (#893) --- .github/workflows/continuous_integration.yml | 24 + mirar/paths.py | 7 +- mirar/pipelines/winter/blocks.py | 19 +- mirar/pipelines/winter/generator/__init__.py | 1 + .../pipelines/winter/generator/candidates.py | 5 +- mirar/pipelines/winter/generator/realbogus.py | 33 ++ mirar/pipelines/winter/models/_candidates.py | 4 +- .../sources/machine_learning/__init__.py | 5 + .../sources/machine_learning/pytorch.py | 132 +++++ poetry.lock | 505 ++++++++++++++++-- pyproject.toml | 2 + 11 files changed, 673 insertions(+), 64 deletions(-) create mode 100644 mirar/pipelines/winter/generator/realbogus.py create mode 100644 mirar/processors/sources/machine_learning/__init__.py create mode 100644 mirar/processors/sources/machine_learning/pytorch.py diff --git a/.github/workflows/continuous_integration.yml b/.github/workflows/continuous_integration.yml index 60728680a..7676be5be 100644 --- a/.github/workflows/continuous_integration.yml +++ b/.github/workflows/continuous_integration.yml @@ -31,6 +31,27 @@ jobs: # Steps represent a sequence of tasks that will be executed as part of the job steps: # Checks-out your repository under $GITHUB_WORKSPACE, so your job can access it + - name: Print disk space + run: df -h + + - name: Free Disk Space (Ubuntu) + uses: jlumbroso/free-disk-space@main + with: + # this might remove tools that are actually needed, + # if set to "true" but frees about 6 GB + tool-cache: false + + # all of these default to true, but feel free to set to + # "false" if necessary for your workflow + android: true + dotnet: true + haskell: true + large-packages: false + docker-images: true + swap-storage: true + + - name: Print disk space + run: df -h - uses: actions/checkout@v3 @@ -89,6 +110,9 @@ jobs: make -C q3c sudo make -C q3c install + - name: Print disk space + run: df -h + # First make sure the doc tests are up to date - name: Run doc tests shell: bash -el {0} diff --git a/mirar/paths.py b/mirar/paths.py index ca94a319c..daa5bd135 100644 --- a/mirar/paths.py +++ b/mirar/paths.py @@ -45,7 +45,7 @@ "The raw data directory will need to be specified manually for path function." f"The raw directory is being set to {default_dir}." ) - logger.warning(warning) + logger.info(warning) base_raw_dir: Path = default_dir else: base_raw_dir = Path(_base_raw_dir) @@ -58,7 +58,7 @@ f"Run 'export OUTPUT_DATA_DIR=/path/to/data' to set this. " f"The output directory is being set to {default_dir}." ) - logger.warning(warning) + logger.info(warning) base_output_dir = default_dir else: base_output_dir = Path(_base_output_dir) @@ -70,6 +70,9 @@ RAW_IMG_SUB_DIR = "raw" CAL_OUTPUT_SUB_DIR = "calibration" +ml_models_dir = base_output_dir.joinpath("ml_models") +ml_models_dir.mkdir(exist_ok=True) + def raw_img_dir( sub_dir: str = "", raw_dir: Path = base_raw_dir, img_sub_dir: str = RAW_IMG_SUB_DIR diff --git a/mirar/pipelines/winter/blocks.py b/mirar/pipelines/winter/blocks.py index adf3b133a..8b5889341 100644 --- a/mirar/pipelines/winter/blocks.py +++ b/mirar/pipelines/winter/blocks.py @@ -5,6 +5,8 @@ # pylint: disable=duplicate-code import os +from winterrb.model import WINTERNet + from mirar.catalog.kowalski import PS1, TMASS, Gaia, GaiaBright, PS1SGSc from mirar.downloader.get_test_data import get_test_data_dir from mirar.paths import ( @@ -39,6 +41,7 @@ ) from mirar.pipelines.winter.constants import NXSPLIT, NYSPLIT from mirar.pipelines.winter.generator import ( + apply_rb_to_table, mask_stamps_around_bright_stars, select_winter_sky_flat_images, winter_anet_sextractor_config_path_generator, @@ -127,10 +130,12 @@ CustomSourceTableModifier, ForcedPhotometryDetector, SourceBatcher, + SourceDebatcher, SourceLoader, SourceWriter, ZOGYSourceDetector, ) +from mirar.processors.sources.machine_learning import Pytorch from mirar.processors.split import SUB_ID_KEY, SplitImage, SwarpImageSplitter from mirar.processors.utils import ( CustomImageBatchModifier, @@ -691,9 +696,21 @@ load_sources = [ SourceLoader(input_dir_name="candidates"), + SourceBatcher(BASE_NAME_KEY), +] + +ml_classify = [ + Pytorch( + model=WINTERNet(), + model_weights_url="https://github.com/winter-telescope/winterrb/raw/" + "v1.0.0/models/winterrb_v1_0_0_weights.pth", + apply_to_table=apply_rb_to_table, + ), + HeaderEditor(edit_keys="rbversion", values="v1.0.0"), ] crossmatch_candidates = [ + SourceDebatcher(), XMatch(catalog=TMASS(num_sources=3, search_radius_arcmin=0.5)), XMatch(catalog=PS1(num_sources=3, search_radius_arcmin=0.5)), XMatch(catalog=PS1SGSc(num_sources=3, search_radius_arcmin=0.5)), @@ -786,7 +803,7 @@ avro_export = avro_write + avro_broadcast -process_candidates = crossmatch_candidates + name_candidates + avro_write +process_candidates = ml_classify + crossmatch_candidates + name_candidates + avro_write load_avro = [SourceLoader(input_dir_name="preavro")] diff --git a/mirar/pipelines/winter/generator/__init__.py b/mirar/pipelines/winter/generator/__init__.py index 78924e0fa..9946278ab 100644 --- a/mirar/pipelines/winter/generator/__init__.py +++ b/mirar/pipelines/winter/generator/__init__.py @@ -22,6 +22,7 @@ winter_ref_photometric_catalogs_purifier, winter_reference_phot_calibrator, ) +from mirar.pipelines.winter.generator.realbogus import apply_rb_to_table from mirar.pipelines.winter.generator.reduce import ( mask_stamps_around_bright_stars, select_winter_dome_flats_images, diff --git a/mirar/pipelines/winter/generator/candidates.py b/mirar/pipelines/winter/generator/candidates.py index 8cefa2c3e..34898d1cf 100644 --- a/mirar/pipelines/winter/generator/candidates.py +++ b/mirar/pipelines/winter/generator/candidates.py @@ -282,13 +282,14 @@ def winter_candidate_quality_filterer(source_table: SourceBatch) -> SourceBatch: mask = ( (src_df["nbad"] < 2) & (src_df["ndethist"] > 0) - # & (src_df["chipsf"] < 3.0) - & (src_df["sumrat"] > 0.7) + & ((src_df["rb"] > 0.5) | pd.isnull(src_df["rb"])) + & (src_df["sumrat"] > 0.6) & (src_df["fwhm"] < 10.0) & (src_df["magdiff"] < 1.6) & (src_df["magdiff"] > -1.0) & (src_df["mindtoedge"] > 50.0) & (src_df["isdiffpos"]) + & ((src_df["sgscore1"] < 0.5) | pd.isnull(src_df["sgscore1"])) ) filtered_df = src_df[mask].reset_index(drop=True) diff --git a/mirar/pipelines/winter/generator/realbogus.py b/mirar/pipelines/winter/generator/realbogus.py new file mode 100644 index 000000000..4d48d10de --- /dev/null +++ b/mirar/pipelines/winter/generator/realbogus.py @@ -0,0 +1,33 @@ +""" +Functions to apply rbscore +""" + +import numpy as np +import pandas as pd +import torch +from torch import nn +from winterrb.utils import make_triplet + + +def apply_rb_to_table(model: nn.Module, table: pd.DataFrame) -> pd.DataFrame: + """ + Apply the realbogus score to a table of sources + + :param model: Pytorch model + :param table: Table of sources + :return: Table of sources with realbogus score + """ + + rb_scores = [] + + for _, row in table.iterrows(): + triplet = make_triplet(row, normalize=True) + triplet_reshaped = np.transpose(np.expand_dims(triplet, axis=0), (0, 3, 1, 2)) + with torch.no_grad(): + outputs = model(torch.from_numpy(triplet_reshaped)) + + rb_scores.append(float(outputs[0])) + + table["rb"] = rb_scores + + return table diff --git a/mirar/pipelines/winter/models/_candidates.py b/mirar/pipelines/winter/models/_candidates.py index 9e09d5cad..6fd67235b 100644 --- a/mirar/pipelines/winter/models/_candidates.py +++ b/mirar/pipelines/winter/models/_candidates.py @@ -143,7 +143,7 @@ class CandidatesTable(WinterBase): # pylint: disable=too-few-public-methods # Real/bogus properties rb = Column(Float, nullable=True) - rbversion = Column(Float, nullable=True) + rbversion = Column(VARCHAR(10), nullable=True) # Solar system properties @@ -284,7 +284,7 @@ class Candidate(BaseDB): scorr: float = Field(ge=0) rb: float | None = Field(ge=0, default=None) - rbversion: float | None = Field(ge=0, default=None) + rbversion: str | None = Field(default=None, max_length=10) ssdistnr: float | None = Field(ge=0, default=None) ssmagnr: float | None = Field(ge=0, default=None) diff --git a/mirar/processors/sources/machine_learning/__init__.py b/mirar/processors/sources/machine_learning/__init__.py new file mode 100644 index 000000000..f2e02c438 --- /dev/null +++ b/mirar/processors/sources/machine_learning/__init__.py @@ -0,0 +1,5 @@ +""" +Module for machine learning models +""" + +from mirar.processors.sources.machine_learning.pytorch import Pytorch diff --git a/mirar/processors/sources/machine_learning/pytorch.py b/mirar/processors/sources/machine_learning/pytorch.py new file mode 100644 index 000000000..294810cd1 --- /dev/null +++ b/mirar/processors/sources/machine_learning/pytorch.py @@ -0,0 +1,132 @@ +""" +Module with classes to use apply an ML score from pytorch +""" + +import logging +from pathlib import Path +from typing import Callable + +import pandas as pd +import requests +import torch +from torch import nn + +from mirar.data import SourceBatch +from mirar.paths import ml_models_dir +from mirar.processors.base_processor import BaseSourceProcessor + +logger = logging.getLogger(__name__) + + +class Pytorch(BaseSourceProcessor): + """ + Class to apply a pytorch model to a source table + """ + + base_key = "pytorch" + + def __init__( + self, + model: nn.Module, + model_weights_url: str, + apply_to_table: Callable[[nn.Module, pd.DataFrame], pd.DataFrame], + ): + super().__init__() + self._model = model + self.model_weights_url = model_weights_url + self.model_name = Path(self.model_weights_url).name + self.apply_to_table = apply_to_table + + self.model = None + + def __str__(self) -> str: + return f"Processor to use Pytorch model {self.model_name} to score sources" + + def get_ml_path(self) -> Path: + """ + Get the path to the ML model + + :return: Path to the ML model + """ + return ml_models_dir.joinpath(self.model_name) + + def download_model(self): + """ + Download the ML model + """ + + url = self.model_weights_url + local_path = self.get_ml_path() + + logger.info( + f"Downloading model {self.model_name} " f"from {url} to {local_path}" + ) + + with requests.get(url, stream=True, timeout=120.0) as r: + r.raise_for_status() + with open(local_path, "wb") as f: + for chunk in r.iter_content(chunk_size=8192): + # If you have chunk encoded response uncomment if + # and set chunk_size parameter to None. + # if chunk: + f.write(chunk) + + if not local_path.exists(): + err = f"Model {self.model_name} not downloaded" + logger.error(err) + raise FileNotFoundError(err) + + @staticmethod + def load_model(path): + """ + Function to load a pytorch model dict from a path + + :param path: Path to the model + :return: Pytorch model dict + """ + if not path.exists(): + err = f"Model {path} not found" + logger.error(err) + raise FileNotFoundError(err) + + if path.suffix in [".pth", ".pt"]: + return torch.load(path) + + raise ValueError(f"Unknown model type {path.suffix}") + + def get_model(self): + """ + Load the ML model weights. Download it if it doesn't exist. + + :return: ML model + """ + + if self.model is None: + + model = self._model + + local_path = self.get_ml_path() + + if not local_path.exists(): + self.download_model() + + model.load_state_dict(torch.load(local_path)) + model.eval() + + self.model = model + + return self.model + + def _apply_to_sources( + self, + batch: SourceBatch, + ) -> SourceBatch: + + model = self.get_model() + + for source_table in batch: + sources = source_table.get_data() + new = self.apply_to_table(model, sources) + source_table.set_data(new) + + return batch diff --git a/poetry.lock b/poetry.lock index 1db60d0ab..a2a2d4c9c 100644 --- a/poetry.lock +++ b/poetry.lock @@ -133,13 +133,13 @@ test = ["dateparser (==1.*)", "pre-commit", "pytest", "pytest-cov", "pytest-mock [[package]] name = "astroid" -version = "3.1.0" +version = "3.2.0" description = "An abstract syntax tree for Python with inference support." optional = false python-versions = ">=3.8.0" files = [ - {file = "astroid-3.1.0-py3-none-any.whl", hash = "sha256:951798f922990137ac090c53af473db7ab4e70c770e6d7fae0cec59f74411819"}, - {file = "astroid-3.1.0.tar.gz", hash = "sha256:ac248253bfa4bd924a0de213707e7ebeeb3138abeb48d798784ead1e56d419d4"}, + {file = "astroid-3.2.0-py3-none-any.whl", hash = "sha256:16ee8ca5c75ac828783028cc1f967777f0e507c6886a295ad143e0f405b975a2"}, + {file = "astroid-3.2.0.tar.gz", hash = "sha256:f7f829f8506ade59f1b3c6c93d8fac5b1ebc721685fa9af23e9794daf1d450a3"}, ] [package.dependencies] @@ -215,13 +215,13 @@ typing = ["typing-extensions (>=4.0.0)"] [[package]] name = "astropy-iers-data" -version = "0.2024.5.6.0.29.28" +version = "0.2024.5.13.0.30.12" description = "IERS Earth Rotation and Leap Second tables for the astropy core package" optional = false python-versions = ">=3.8" files = [ - {file = "astropy_iers_data-0.2024.5.6.0.29.28-py3-none-any.whl", hash = "sha256:ad992722705af68b4fe3b79983c05c4bbed01b3ed31466fe6f574ddfb986b620"}, - {file = "astropy_iers_data-0.2024.5.6.0.29.28.tar.gz", hash = "sha256:7f02a20d4bc72c22533123734cc48d694cdd32adc471c5d6659218d788b611da"}, + {file = "astropy_iers_data-0.2024.5.13.0.30.12-py3-none-any.whl", hash = "sha256:ee6c25ceaaf1d0e88171e74f70bef6927584f923e3920eb537d0c1a72169ff65"}, + {file = "astropy_iers_data-0.2024.5.13.0.30.12.tar.gz", hash = "sha256:774894a1bc4cd9f36437354a85d2016539bbf02dd01013d95ca4ac0a1e5643ef"}, ] [package.extras] @@ -1473,6 +1473,41 @@ files = [ {file = "fqdn-1.5.1.tar.gz", hash = "sha256:105ed3677e767fb5ca086a0c1f4bb66ebc3c100be518f0e0d755d9eae164d89f"}, ] +[[package]] +name = "fsspec" +version = "2024.3.1" +description = "File-system specification" +optional = false +python-versions = ">=3.8" +files = [ + {file = "fsspec-2024.3.1-py3-none-any.whl", hash = "sha256:918d18d41bf73f0e2b261824baeb1b124bcf771767e3a26425cd7dec3332f512"}, + {file = "fsspec-2024.3.1.tar.gz", hash = "sha256:f39780e282d7d117ffb42bb96992f8a90795e4d0fb0f661a70ca39fe9c43ded9"}, +] + +[package.extras] +abfs = ["adlfs"] +adl = ["adlfs"] +arrow = ["pyarrow (>=1)"] +dask = ["dask", "distributed"] +devel = ["pytest", "pytest-cov"] +dropbox = ["dropbox", "dropboxdrivefs", "requests"] +full = ["adlfs", "aiohttp (!=4.0.0a0,!=4.0.0a1)", "dask", "distributed", "dropbox", "dropboxdrivefs", "fusepy", "gcsfs", "libarchive-c", "ocifs", "panel", "paramiko", "pyarrow (>=1)", "pygit2", "requests", "s3fs", "smbprotocol", "tqdm"] +fuse = ["fusepy"] +gcs = ["gcsfs"] +git = ["pygit2"] +github = ["requests"] +gs = ["gcsfs"] +gui = ["panel"] +hdfs = ["pyarrow (>=1)"] +http = ["aiohttp (!=4.0.0a0,!=4.0.0a1)"] +libarchive = ["libarchive-c"] +oci = ["ocifs"] +s3 = ["s3fs"] +sftp = ["paramiko"] +smb = ["smbprotocol"] +ssh = ["paramiko"] +tqdm = ["tqdm"] + [[package]] name = "greenlet" version = "3.0.3" @@ -1912,6 +1947,17 @@ MarkupSafe = ">=2.0" [package.extras] i18n = ["Babel (>=2.7)"] +[[package]] +name = "joblib" +version = "1.4.2" +description = "Lightweight pipelining with Python functions" +optional = false +python-versions = ">=3.8" +files = [ + {file = "joblib-1.4.2-py3-none-any.whl", hash = "sha256:06d478d5674cbc267e7496a410ee875abd68e4340feff4490bcb7afb88060ae6"}, + {file = "joblib-1.4.2.tar.gz", hash = "sha256:2382c5816b2636fbd20a09e0f4e9dad4736765fdfb7dca582943b9c1366b3f0e"}, +] + [[package]] name = "json5" version = "0.9.25" @@ -2239,13 +2285,13 @@ files = [ [[package]] name = "keyring" -version = "25.2.0" +version = "25.2.1" description = "Store and access your passwords safely." optional = false python-versions = ">=3.8" files = [ - {file = "keyring-25.2.0-py3-none-any.whl", hash = "sha256:19f17d40335444aab84b19a0d16a77ec0758a9c384e3446ae2ed8bd6d53b67a5"}, - {file = "keyring-25.2.0.tar.gz", hash = "sha256:7045f367268ce42dba44745050164b431e46f6e92f99ef2937dfadaef368d8cf"}, + {file = "keyring-25.2.1-py3-none-any.whl", hash = "sha256:2458681cdefc0dbc0b7eb6cf75d0b98e59f9ad9b2d4edd319d18f68bdca95e50"}, + {file = "keyring-25.2.1.tar.gz", hash = "sha256:daaffd42dbda25ddafb1ad5fec4024e5bbcfe424597ca1ca452b299861e49f1b"}, ] [package.dependencies] @@ -2620,6 +2666,23 @@ files = [ {file = "more_itertools-10.2.0-py3-none-any.whl", hash = "sha256:686b06abe565edfab151cb8fd385a05651e1fdf8f0a14191e4439283421f8684"}, ] +[[package]] +name = "mpmath" +version = "1.3.0" +description = "Python library for arbitrary-precision floating-point arithmetic" +optional = false +python-versions = "*" +files = [ + {file = "mpmath-1.3.0-py3-none-any.whl", hash = "sha256:a0b2b9fe80bbcd81a6647ff13108738cfb482d481d826cc0e02f5b35e5c88d2c"}, + {file = "mpmath-1.3.0.tar.gz", hash = "sha256:7a28eb2a9774d00c7bc92411c19a89209d5da7c4c9a9e227be8330a23a25b91f"}, +] + +[package.extras] +develop = ["codecov", "pycodestyle", "pytest (>=4.6)", "pytest-cov", "wheel"] +docs = ["sphinx"] +gmpy = ["gmpy2 (>=2.1.0a4)"] +tests = ["pytest (>=4.6)"] + [[package]] name = "mypy" version = "1.10.0" @@ -2886,6 +2949,147 @@ files = [ {file = "numpy-1.26.4.tar.gz", hash = "sha256:2a02aba9ed12e4ac4eb3ea9421c420301a0c6460d9830d74a9df87efa4912010"}, ] +[[package]] +name = "nvidia-cublas-cu12" +version = "12.1.3.1" +description = "CUBLAS native runtime libraries" +optional = false +python-versions = ">=3" +files = [ + {file = "nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl", hash = "sha256:ee53ccca76a6fc08fb9701aa95b6ceb242cdaab118c3bb152af4e579af792728"}, + {file = "nvidia_cublas_cu12-12.1.3.1-py3-none-win_amd64.whl", hash = "sha256:2b964d60e8cf11b5e1073d179d85fa340c120e99b3067558f3cf98dd69d02906"}, +] + +[[package]] +name = "nvidia-cuda-cupti-cu12" +version = "12.1.105" +description = "CUDA profiling tools runtime libs." +optional = false +python-versions = ">=3" +files = [ + {file = "nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:e54fde3983165c624cb79254ae9818a456eb6e87a7fd4d56a2352c24ee542d7e"}, + {file = "nvidia_cuda_cupti_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:bea8236d13a0ac7190bd2919c3e8e6ce1e402104276e6f9694479e48bb0eb2a4"}, +] + +[[package]] +name = "nvidia-cuda-nvrtc-cu12" +version = "12.1.105" +description = "NVRTC native runtime libraries" +optional = false +python-versions = ">=3" +files = [ + {file = "nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:339b385f50c309763ca65456ec75e17bbefcbbf2893f462cb8b90584cd27a1c2"}, + {file = "nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:0a98a522d9ff138b96c010a65e145dc1b4850e9ecb75a0172371793752fd46ed"}, +] + +[[package]] +name = "nvidia-cuda-runtime-cu12" +version = "12.1.105" +description = "CUDA Runtime native Libraries" +optional = false +python-versions = ">=3" +files = [ + {file = "nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:6e258468ddf5796e25f1dc591a31029fa317d97a0a94ed93468fc86301d61e40"}, + {file = "nvidia_cuda_runtime_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:dfb46ef84d73fababab44cf03e3b83f80700d27ca300e537f85f636fac474344"}, +] + +[[package]] +name = "nvidia-cudnn-cu12" +version = "8.9.2.26" +description = "cuDNN runtime libraries" +optional = false +python-versions = ">=3" +files = [ + {file = "nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl", hash = "sha256:5ccb288774fdfb07a7e7025ffec286971c06d8d7b4fb162525334616d7629ff9"}, +] + +[package.dependencies] +nvidia-cublas-cu12 = "*" + +[[package]] +name = "nvidia-cufft-cu12" +version = "11.0.2.54" +description = "CUFFT native runtime libraries" +optional = false +python-versions = ">=3" +files = [ + {file = "nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl", hash = "sha256:794e3948a1aa71fd817c3775866943936774d1c14e7628c74f6f7417224cdf56"}, + {file = "nvidia_cufft_cu12-11.0.2.54-py3-none-win_amd64.whl", hash = "sha256:d9ac353f78ff89951da4af698f80870b1534ed69993f10a4cf1d96f21357e253"}, +] + +[[package]] +name = "nvidia-curand-cu12" +version = "10.3.2.106" +description = "CURAND native runtime libraries" +optional = false +python-versions = ">=3" +files = [ + {file = "nvidia_curand_cu12-10.3.2.106-py3-none-manylinux1_x86_64.whl", hash = "sha256:9d264c5036dde4e64f1de8c50ae753237c12e0b1348738169cd0f8a536c0e1e0"}, + {file = "nvidia_curand_cu12-10.3.2.106-py3-none-win_amd64.whl", hash = "sha256:75b6b0c574c0037839121317e17fd01f8a69fd2ef8e25853d826fec30bdba74a"}, +] + +[[package]] +name = "nvidia-cusolver-cu12" +version = "11.4.5.107" +description = "CUDA solver native runtime libraries" +optional = false +python-versions = ">=3" +files = [ + {file = "nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl", hash = "sha256:8a7ec542f0412294b15072fa7dab71d31334014a69f953004ea7a118206fe0dd"}, + {file = "nvidia_cusolver_cu12-11.4.5.107-py3-none-win_amd64.whl", hash = "sha256:74e0c3a24c78612192a74fcd90dd117f1cf21dea4822e66d89e8ea80e3cd2da5"}, +] + +[package.dependencies] +nvidia-cublas-cu12 = "*" +nvidia-cusparse-cu12 = "*" +nvidia-nvjitlink-cu12 = "*" + +[[package]] +name = "nvidia-cusparse-cu12" +version = "12.1.0.106" +description = "CUSPARSE native runtime libraries" +optional = false +python-versions = ">=3" +files = [ + {file = "nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl", hash = "sha256:f3b50f42cf363f86ab21f720998517a659a48131e8d538dc02f8768237bd884c"}, + {file = "nvidia_cusparse_cu12-12.1.0.106-py3-none-win_amd64.whl", hash = "sha256:b798237e81b9719373e8fae8d4f091b70a0cf09d9d85c95a557e11df2d8e9a5a"}, +] + +[package.dependencies] +nvidia-nvjitlink-cu12 = "*" + +[[package]] +name = "nvidia-nccl-cu12" +version = "2.19.3" +description = "NVIDIA Collective Communication Library (NCCL) Runtime" +optional = false +python-versions = ">=3" +files = [ + {file = "nvidia_nccl_cu12-2.19.3-py3-none-manylinux1_x86_64.whl", hash = "sha256:a9734707a2c96443331c1e48c717024aa6678a0e2a4cb66b2c364d18cee6b48d"}, +] + +[[package]] +name = "nvidia-nvjitlink-cu12" +version = "12.4.127" +description = "Nvidia JIT LTO Library" +optional = false +python-versions = ">=3" +files = [ + {file = "nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:06b3b9b25bf3f8af351d664978ca26a16d2c5127dbd53c0497e28d1fb9611d57"}, + {file = "nvidia_nvjitlink_cu12-12.4.127-py3-none-win_amd64.whl", hash = "sha256:fd9020c501d27d135f983c6d3e244b197a7ccad769e34df53a42e276b0e25fa1"}, +] + +[[package]] +name = "nvidia-nvtx-cu12" +version = "12.1.105" +description = "NVIDIA Tools Extension" +optional = false +python-versions = ">=3" +files = [ + {file = "nvidia_nvtx_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:dc21cf308ca5691e7c04d962e213f8a4aa9bbfa23d95412f452254c2caeb09e5"}, + {file = "nvidia_nvtx_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:65f4d98982b31b60026e0e6de73fbdfc09d08a96f4656dd3665ca616a11e1e82"}, +] + [[package]] name = "overrides" version = "7.7.0" @@ -2980,13 +3184,13 @@ xml = ["lxml (>=4.9.2)"] [[package]] name = "pandas-stubs" -version = "2.2.1.240316" +version = "2.2.2.240514" description = "Type annotations for pandas" optional = false python-versions = ">=3.9" files = [ - {file = "pandas_stubs-2.2.1.240316-py3-none-any.whl", hash = "sha256:0126a26451a37cb893ea62357ca87ba3d181bd999ec8ba2ca5602e20207d6682"}, - {file = "pandas_stubs-2.2.1.240316.tar.gz", hash = "sha256:236a4f812fb6b1922e9607ff09e427f6d8540c421c9e5a40e3e4ddf7adac7f05"}, + {file = "pandas_stubs-2.2.2.240514-py3-none-any.whl", hash = "sha256:5d6f64d45a98bc94152a0f76fa648e598cd2b9ba72302fd34602479f0c391a53"}, + {file = "pandas_stubs-2.2.2.240514.tar.gz", hash = "sha256:85b20da44a62c80eb8389bcf4cbfe31cce1cafa8cca4bf1fc75ec45892e72ce8"}, ] [package.dependencies] @@ -3405,47 +3609,47 @@ tests = ["pytest"] [[package]] name = "pyarrow" -version = "16.0.0" +version = "16.1.0" description = "Python library for Apache Arrow" optional = false python-versions = ">=3.8" files = [ - {file = "pyarrow-16.0.0-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:22a1fdb1254e5095d629e29cd1ea98ed04b4bbfd8e42cc670a6b639ccc208b60"}, - {file = "pyarrow-16.0.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:574a00260a4ed9d118a14770edbd440b848fcae5a3024128be9d0274dbcaf858"}, - {file = "pyarrow-16.0.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c0815d0ddb733b8c1b53a05827a91f1b8bde6240f3b20bf9ba5d650eb9b89cdf"}, - {file = "pyarrow-16.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:df0080339387b5d30de31e0a149c0c11a827a10c82f0c67d9afae3981d1aabb7"}, - {file = "pyarrow-16.0.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:edf38cce0bf0dcf726e074159c60516447e4474904c0033f018c1f33d7dac6c5"}, - {file = "pyarrow-16.0.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:91d28f9a40f1264eab2af7905a4d95320ac2f287891e9c8b0035f264fe3c3a4b"}, - {file = "pyarrow-16.0.0-cp310-cp310-win_amd64.whl", hash = "sha256:99af421ee451a78884d7faea23816c429e263bd3618b22d38e7992c9ce2a7ad9"}, - {file = "pyarrow-16.0.0-cp311-cp311-macosx_10_15_x86_64.whl", hash = "sha256:d22d0941e6c7bafddf5f4c0662e46f2075850f1c044bf1a03150dd9e189427ce"}, - {file = "pyarrow-16.0.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:266ddb7e823f03733c15adc8b5078db2df6980f9aa93d6bb57ece615df4e0ba7"}, - {file = "pyarrow-16.0.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5cc23090224b6594f5a92d26ad47465af47c1d9c079dd4a0061ae39551889efe"}, - {file = "pyarrow-16.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:56850a0afe9ef37249d5387355449c0f94d12ff7994af88f16803a26d38f2016"}, - {file = "pyarrow-16.0.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:705db70d3e2293c2f6f8e84874b5b775f690465798f66e94bb2c07bab0a6bb55"}, - {file = "pyarrow-16.0.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:5448564754c154997bc09e95a44b81b9e31ae918a86c0fcb35c4aa4922756f55"}, - {file = "pyarrow-16.0.0-cp311-cp311-win_amd64.whl", hash = "sha256:729f7b262aa620c9df8b9967db96c1575e4cfc8c25d078a06968e527b8d6ec05"}, - {file = "pyarrow-16.0.0-cp312-cp312-macosx_10_15_x86_64.whl", hash = "sha256:fb8065dbc0d051bf2ae2453af0484d99a43135cadabacf0af588a3be81fbbb9b"}, - {file = "pyarrow-16.0.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:20ce707d9aa390593ea93218b19d0eadab56390311cb87aad32c9a869b0e958c"}, - {file = "pyarrow-16.0.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5823275c8addbbb50cd4e6a6839952682a33255b447277e37a6f518d6972f4e1"}, - {file = "pyarrow-16.0.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1ab8b9050752b16a8b53fcd9853bf07d8daf19093533e990085168f40c64d978"}, - {file = "pyarrow-16.0.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:42e56557bc7c5c10d3e42c3b32f6cff649a29d637e8f4e8b311d334cc4326730"}, - {file = "pyarrow-16.0.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:2a7abdee4a4a7cfa239e2e8d721224c4b34ffe69a0ca7981354fe03c1328789b"}, - {file = "pyarrow-16.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:ef2f309b68396bcc5a354106741d333494d6a0d3e1951271849787109f0229a6"}, - {file = "pyarrow-16.0.0-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:ed66e5217b4526fa3585b5e39b0b82f501b88a10d36bd0d2a4d8aa7b5a48e2df"}, - {file = "pyarrow-16.0.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:cc8814310486f2a73c661ba8354540f17eef51e1b6dd090b93e3419d3a097b3a"}, - {file = "pyarrow-16.0.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3c2f5e239db7ed43e0ad2baf46a6465f89c824cc703f38ef0fde927d8e0955f7"}, - {file = "pyarrow-16.0.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f293e92d1db251447cb028ae12f7bc47526e4649c3a9924c8376cab4ad6b98bd"}, - {file = "pyarrow-16.0.0-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:dd9334a07b6dc21afe0857aa31842365a62eca664e415a3f9536e3a8bb832c07"}, - {file = "pyarrow-16.0.0-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:d91073d1e2fef2c121154680e2ba7e35ecf8d4969cc0af1fa6f14a8675858159"}, - {file = "pyarrow-16.0.0-cp38-cp38-win_amd64.whl", hash = "sha256:71d52561cd7aefd22cf52538f262850b0cc9e4ec50af2aaa601da3a16ef48877"}, - {file = "pyarrow-16.0.0-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:b93c9a50b965ee0bf4fef65e53b758a7e8dcc0c2d86cebcc037aaaf1b306ecc0"}, - {file = "pyarrow-16.0.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:d831690844706e374c455fba2fb8cfcb7b797bfe53ceda4b54334316e1ac4fa4"}, - {file = "pyarrow-16.0.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:35692ce8ad0b8c666aa60f83950957096d92f2a9d8d7deda93fb835e6053307e"}, - {file = "pyarrow-16.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9dd3151d098e56f16a8389c1247137f9e4c22720b01c6f3aa6dec29a99b74d80"}, - {file = "pyarrow-16.0.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:bd40467bdb3cbaf2044ed7a6f7f251c8f941c8b31275aaaf88e746c4f3ca4a7a"}, - {file = "pyarrow-16.0.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:00a1dcb22ad4ceb8af87f7bd30cc3354788776c417f493089e0a0af981bc8d80"}, - {file = "pyarrow-16.0.0-cp39-cp39-win_amd64.whl", hash = "sha256:fda9a7cebd1b1d46c97b511f60f73a5b766a6de4c5236f144f41a5d5afec1f35"}, - {file = "pyarrow-16.0.0.tar.gz", hash = "sha256:59bb1f1edbbf4114c72415f039f1359f1a57d166a331c3229788ccbfbb31689a"}, + {file = "pyarrow-16.1.0-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:17e23b9a65a70cc733d8b738baa6ad3722298fa0c81d88f63ff94bf25eaa77b9"}, + {file = "pyarrow-16.1.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:4740cc41e2ba5d641071d0ab5e9ef9b5e6e8c7611351a5cb7c1d175eaf43674a"}, + {file = "pyarrow-16.1.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:98100e0268d04e0eec47b73f20b39c45b4006f3c4233719c3848aa27a03c1aef"}, + {file = "pyarrow-16.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f68f409e7b283c085f2da014f9ef81e885d90dcd733bd648cfba3ef265961848"}, + {file = "pyarrow-16.1.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:a8914cd176f448e09746037b0c6b3a9d7688cef451ec5735094055116857580c"}, + {file = "pyarrow-16.1.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:48be160782c0556156d91adbdd5a4a7e719f8d407cb46ae3bb4eaee09b3111bd"}, + {file = "pyarrow-16.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:9cf389d444b0f41d9fe1444b70650fea31e9d52cfcb5f818b7888b91b586efff"}, + {file = "pyarrow-16.1.0-cp311-cp311-macosx_10_15_x86_64.whl", hash = "sha256:d0ebea336b535b37eee9eee31761813086d33ed06de9ab6fc6aaa0bace7b250c"}, + {file = "pyarrow-16.1.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2e73cfc4a99e796727919c5541c65bb88b973377501e39b9842ea71401ca6c1c"}, + {file = "pyarrow-16.1.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bf9251264247ecfe93e5f5a0cd43b8ae834f1e61d1abca22da55b20c788417f6"}, + {file = "pyarrow-16.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ddf5aace92d520d3d2a20031d8b0ec27b4395cab9f74e07cc95edf42a5cc0147"}, + {file = "pyarrow-16.1.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:25233642583bf658f629eb230b9bb79d9af4d9f9229890b3c878699c82f7d11e"}, + {file = "pyarrow-16.1.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:a33a64576fddfbec0a44112eaf844c20853647ca833e9a647bfae0582b2ff94b"}, + {file = "pyarrow-16.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:185d121b50836379fe012753cf15c4ba9638bda9645183ab36246923875f8d1b"}, + {file = "pyarrow-16.1.0-cp312-cp312-macosx_10_15_x86_64.whl", hash = "sha256:2e51ca1d6ed7f2e9d5c3c83decf27b0d17bb207a7dea986e8dc3e24f80ff7d6f"}, + {file = "pyarrow-16.1.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:06ebccb6f8cb7357de85f60d5da50e83507954af617d7b05f48af1621d331c9a"}, + {file = "pyarrow-16.1.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b04707f1979815f5e49824ce52d1dceb46e2f12909a48a6a753fe7cafbc44a0c"}, + {file = "pyarrow-16.1.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0d32000693deff8dc5df444b032b5985a48592c0697cb6e3071a5d59888714e2"}, + {file = "pyarrow-16.1.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:8785bb10d5d6fd5e15d718ee1d1f914fe768bf8b4d1e5e9bf253de8a26cb1628"}, + {file = "pyarrow-16.1.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:e1369af39587b794873b8a307cc6623a3b1194e69399af0efd05bb202195a5a7"}, + {file = "pyarrow-16.1.0-cp312-cp312-win_amd64.whl", hash = "sha256:febde33305f1498f6df85e8020bca496d0e9ebf2093bab9e0f65e2b4ae2b3444"}, + {file = "pyarrow-16.1.0-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:b5f5705ab977947a43ac83b52ade3b881eb6e95fcc02d76f501d549a210ba77f"}, + {file = "pyarrow-16.1.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:0d27bf89dfc2576f6206e9cd6cf7a107c9c06dc13d53bbc25b0bd4556f19cf5f"}, + {file = "pyarrow-16.1.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0d07de3ee730647a600037bc1d7b7994067ed64d0eba797ac74b2bc77384f4c2"}, + {file = "pyarrow-16.1.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fbef391b63f708e103df99fbaa3acf9f671d77a183a07546ba2f2c297b361e83"}, + {file = "pyarrow-16.1.0-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:19741c4dbbbc986d38856ee7ddfdd6a00fc3b0fc2d928795b95410d38bb97d15"}, + {file = "pyarrow-16.1.0-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:f2c5fb249caa17b94e2b9278b36a05ce03d3180e6da0c4c3b3ce5b2788f30eed"}, + {file = "pyarrow-16.1.0-cp38-cp38-win_amd64.whl", hash = "sha256:e6b6d3cd35fbb93b70ade1336022cc1147b95ec6af7d36906ca7fe432eb09710"}, + {file = "pyarrow-16.1.0-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:18da9b76a36a954665ccca8aa6bd9f46c1145f79c0bb8f4f244f5f8e799bca55"}, + {file = "pyarrow-16.1.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:99f7549779b6e434467d2aa43ab2b7224dd9e41bdde486020bae198978c9e05e"}, + {file = "pyarrow-16.1.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f07fdffe4fd5b15f5ec15c8b64584868d063bc22b86b46c9695624ca3505b7b4"}, + {file = "pyarrow-16.1.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ddfe389a08ea374972bd4065d5f25d14e36b43ebc22fc75f7b951f24378bf0b5"}, + {file = "pyarrow-16.1.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:3b20bd67c94b3a2ea0a749d2a5712fc845a69cb5d52e78e6449bbd295611f3aa"}, + {file = "pyarrow-16.1.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:ba8ac20693c0bb0bf4b238751d4409e62852004a8cf031c73b0e0962b03e45e3"}, + {file = "pyarrow-16.1.0-cp39-cp39-win_amd64.whl", hash = "sha256:31a1851751433d89a986616015841977e0a188662fcffd1a5677453f1df2de0a"}, + {file = "pyarrow-16.1.0.tar.gz", hash = "sha256:15fbb22ea96d11f0b5768504a3f961edab25eaf4197c341720c4a387f6c60315"}, ] [package.dependencies] @@ -3688,17 +3892,17 @@ windows-terminal = ["colorama (>=0.4.6)"] [[package]] name = "pylint" -version = "3.1.0" +version = "3.2.0" description = "python code static checker" optional = false python-versions = ">=3.8.0" files = [ - {file = "pylint-3.1.0-py3-none-any.whl", hash = "sha256:507a5b60953874766d8a366e8e8c7af63e058b26345cfcb5f91f89d987fd6b74"}, - {file = "pylint-3.1.0.tar.gz", hash = "sha256:6a69beb4a6f63debebaab0a3477ecd0f559aa726af4954fc948c51f7a2549e23"}, + {file = "pylint-3.2.0-py3-none-any.whl", hash = "sha256:9f20c05398520474dac03d7abb21ab93181f91d4c110e1e0b32bc0d016c34fa4"}, + {file = "pylint-3.2.0.tar.gz", hash = "sha256:ad8baf17c8ea5502f23ae38d7c1b7ec78bd865ce34af9a0b986282e2611a8ff2"}, ] [package.dependencies] -astroid = ">=3.1.0,<=3.2.0-dev0" +astroid = ">=3.2.0,<=3.3.0-dev0" colorama = {version = ">=0.4.5", markers = "sys_platform == \"win32\""} dill = [ {version = ">=0.2", markers = "python_version < \"3.11\""}, @@ -4350,6 +4554,48 @@ files = [ {file = "rpds_py-0.18.1.tar.gz", hash = "sha256:dc48b479d540770c811fbd1eb9ba2bb66951863e448efec2e2c102625328e92f"}, ] +[[package]] +name = "scikit-learn" +version = "1.4.2" +description = "A set of python modules for machine learning and data mining" +optional = false +python-versions = ">=3.9" +files = [ + {file = "scikit-learn-1.4.2.tar.gz", hash = "sha256:daa1c471d95bad080c6e44b4946c9390a4842adc3082572c20e4f8884e39e959"}, + {file = "scikit_learn-1.4.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:8539a41b3d6d1af82eb629f9c57f37428ff1481c1e34dddb3b9d7af8ede67ac5"}, + {file = "scikit_learn-1.4.2-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:68b8404841f944a4a1459b07198fa2edd41a82f189b44f3e1d55c104dbc2e40c"}, + {file = "scikit_learn-1.4.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:81bf5d8bbe87643103334032dd82f7419bc8c8d02a763643a6b9a5c7288c5054"}, + {file = "scikit_learn-1.4.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:36f0ea5d0f693cb247a073d21a4123bdf4172e470e6d163c12b74cbb1536cf38"}, + {file = "scikit_learn-1.4.2-cp310-cp310-win_amd64.whl", hash = "sha256:87440e2e188c87db80ea4023440923dccbd56fbc2d557b18ced00fef79da0727"}, + {file = "scikit_learn-1.4.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:45dee87ac5309bb82e3ea633955030df9bbcb8d2cdb30383c6cd483691c546cc"}, + {file = "scikit_learn-1.4.2-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:1d0b25d9c651fd050555aadd57431b53d4cf664e749069da77f3d52c5ad14b3b"}, + {file = "scikit_learn-1.4.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b0203c368058ab92efc6168a1507d388d41469c873e96ec220ca8e74079bf62e"}, + {file = "scikit_learn-1.4.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:44c62f2b124848a28fd695db5bc4da019287abf390bfce602ddc8aa1ec186aae"}, + {file = "scikit_learn-1.4.2-cp311-cp311-win_amd64.whl", hash = "sha256:5cd7b524115499b18b63f0c96f4224eb885564937a0b3477531b2b63ce331904"}, + {file = "scikit_learn-1.4.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:90378e1747949f90c8f385898fff35d73193dfcaec3dd75d6b542f90c4e89755"}, + {file = "scikit_learn-1.4.2-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:ff4effe5a1d4e8fed260a83a163f7dbf4f6087b54528d8880bab1d1377bd78be"}, + {file = "scikit_learn-1.4.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:671e2f0c3f2c15409dae4f282a3a619601fa824d2c820e5b608d9d775f91780c"}, + {file = "scikit_learn-1.4.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d36d0bc983336bbc1be22f9b686b50c964f593c8a9a913a792442af9bf4f5e68"}, + {file = "scikit_learn-1.4.2-cp312-cp312-win_amd64.whl", hash = "sha256:d762070980c17ba3e9a4a1e043ba0518ce4c55152032f1af0ca6f39b376b5928"}, + {file = "scikit_learn-1.4.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:d9993d5e78a8148b1d0fdf5b15ed92452af5581734129998c26f481c46586d68"}, + {file = "scikit_learn-1.4.2-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:426d258fddac674fdf33f3cb2d54d26f49406e2599dbf9a32b4d1696091d4256"}, + {file = "scikit_learn-1.4.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5460a1a5b043ae5ae4596b3126a4ec33ccba1b51e7ca2c5d36dac2169f62ab1d"}, + {file = "scikit_learn-1.4.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:49d64ef6cb8c093d883e5a36c4766548d974898d378e395ba41a806d0e824db8"}, + {file = "scikit_learn-1.4.2-cp39-cp39-win_amd64.whl", hash = "sha256:c97a50b05c194be9146d61fe87dbf8eac62b203d9e87a3ccc6ae9aed2dfaf361"}, +] + +[package.dependencies] +joblib = ">=1.2.0" +numpy = ">=1.19.5" +scipy = ">=1.6.0" +threadpoolctl = ">=2.0.0" + +[package.extras] +benchmark = ["matplotlib (>=3.3.4)", "memory-profiler (>=0.57.0)", "pandas (>=1.1.5)"] +docs = ["Pillow (>=7.1.2)", "matplotlib (>=3.3.4)", "memory-profiler (>=0.57.0)", "numpydoc (>=1.2.0)", "pandas (>=1.1.5)", "plotly (>=5.14.0)", "pooch (>=1.6.0)", "scikit-image (>=0.17.2)", "seaborn (>=0.9.0)", "sphinx (>=6.0.0)", "sphinx-copybutton (>=0.5.2)", "sphinx-gallery (>=0.15.0)", "sphinx-prompt (>=1.3.0)", "sphinxext-opengraph (>=0.4.2)"] +examples = ["matplotlib (>=3.3.4)", "pandas (>=1.1.5)", "plotly (>=5.14.0)", "pooch (>=1.6.0)", "scikit-image (>=0.17.2)", "seaborn (>=0.9.0)"] +tests = ["black (>=23.3.0)", "matplotlib (>=3.3.4)", "mypy (>=1.3)", "numpydoc (>=1.2.0)", "pandas (>=1.1.5)", "polars (>=0.19.12)", "pooch (>=1.6.0)", "pyamg (>=4.0.0)", "pyarrow (>=12.0.0)", "pytest (>=7.1.2)", "pytest-cov (>=2.9.0)", "ruff (>=0.0.272)", "scikit-image (>=0.17.2)"] + [[package]] name = "scipy" version = "1.13.0" @@ -4798,6 +5044,20 @@ pure-eval = "*" [package.extras] tests = ["cython", "littleutils", "pygments", "pytest", "typeguard"] +[[package]] +name = "sympy" +version = "1.12" +description = "Computer algebra system (CAS) in Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "sympy-1.12-py3-none-any.whl", hash = "sha256:c3588cd4295d0c0f603d0f2ae780587e64e2efeedb3521e46b9bb1d08d184fa5"}, + {file = "sympy-1.12.tar.gz", hash = "sha256:ebf595c8dac3e0fdc4152c51878b498396ec7f30e7a914d6071e674d49420fb8"}, +] + +[package.dependencies] +mpmath = ">=0.19" + [[package]] name = "terminado" version = "0.18.1" @@ -4819,6 +5079,17 @@ docs = ["myst-parser", "pydata-sphinx-theme", "sphinx"] test = ["pre-commit", "pytest (>=7.0)", "pytest-timeout"] typing = ["mypy (>=1.6,<2.0)", "traitlets (>=5.11.1)"] +[[package]] +name = "threadpoolctl" +version = "3.5.0" +description = "threadpoolctl" +optional = false +python-versions = ">=3.8" +files = [ + {file = "threadpoolctl-3.5.0-py3-none-any.whl", hash = "sha256:56c1e26c150397e58c4926da8eeee87533b1e32bef131bd4bf6a2f45f3185467"}, + {file = "threadpoolctl-3.5.0.tar.gz", hash = "sha256:082433502dd922bf738de0d8bcc4fdcbf0979ff44c42bd40f5af8a282f6fa107"}, +] + [[package]] name = "tinycss2" version = "1.3.0" @@ -4870,6 +5141,75 @@ files = [ {file = "tomlkit-0.12.5.tar.gz", hash = "sha256:eef34fba39834d4d6b73c9ba7f3e4d1c417a4e56f89a7e96e090dd0d24b8fb3c"}, ] +[[package]] +name = "torch" +version = "2.2.2" +description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration" +optional = false +python-versions = ">=3.8.0" +files = [ + {file = "torch-2.2.2-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:bc889d311a855dd2dfd164daf8cc903a6b7273a747189cebafdd89106e4ad585"}, + {file = "torch-2.2.2-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:15dffa4cc3261fa73d02f0ed25f5fa49ecc9e12bf1ae0a4c1e7a88bbfaad9030"}, + {file = "torch-2.2.2-cp310-cp310-win_amd64.whl", hash = "sha256:11e8fe261233aeabd67696d6b993eeb0896faa175c6b41b9a6c9f0334bdad1c5"}, + {file = "torch-2.2.2-cp310-none-macosx_10_9_x86_64.whl", hash = "sha256:b2e2200b245bd9f263a0d41b6a2dab69c4aca635a01b30cca78064b0ef5b109e"}, + {file = "torch-2.2.2-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:877b3e6593b5e00b35bbe111b7057464e76a7dd186a287280d941b564b0563c2"}, + {file = "torch-2.2.2-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:ad4c03b786e074f46606f4151c0a1e3740268bcf29fbd2fdf6666d66341c1dcb"}, + {file = "torch-2.2.2-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:32827fa1fbe5da8851686256b4cd94cc7b11be962862c2293811c94eea9457bf"}, + {file = "torch-2.2.2-cp311-cp311-win_amd64.whl", hash = "sha256:f9ef0a648310435511e76905f9b89612e45ef2c8b023bee294f5e6f7e73a3e7c"}, + {file = "torch-2.2.2-cp311-none-macosx_10_9_x86_64.whl", hash = "sha256:95b9b44f3bcebd8b6cd8d37ec802048c872d9c567ba52c894bba90863a439059"}, + {file = "torch-2.2.2-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:49aa4126ede714c5aeef7ae92969b4b0bbe67f19665106463c39f22e0a1860d1"}, + {file = "torch-2.2.2-cp312-cp312-manylinux1_x86_64.whl", hash = "sha256:cf12cdb66c9c940227ad647bc9cf5dba7e8640772ae10dfe7569a0c1e2a28aca"}, + {file = "torch-2.2.2-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:89ddac2a8c1fb6569b90890955de0c34e1724f87431cacff4c1979b5f769203c"}, + {file = "torch-2.2.2-cp312-cp312-win_amd64.whl", hash = "sha256:451331406b760f4b1ab298ddd536486ab3cfb1312614cfe0532133535be60bea"}, + {file = "torch-2.2.2-cp312-none-macosx_10_9_x86_64.whl", hash = "sha256:eb4d6e9d3663e26cd27dc3ad266b34445a16b54908e74725adb241aa56987533"}, + {file = "torch-2.2.2-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:bf9558da7d2bf7463390b3b2a61a6a3dbb0b45b161ee1dd5ec640bf579d479fc"}, + {file = "torch-2.2.2-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:cd2bf7697c9e95fb5d97cc1d525486d8cf11a084c6af1345c2c2c22a6b0029d0"}, + {file = "torch-2.2.2-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:b421448d194496e1114d87a8b8d6506bce949544e513742b097e2ab8f7efef32"}, + {file = "torch-2.2.2-cp38-cp38-win_amd64.whl", hash = "sha256:3dbcd563a9b792161640c0cffe17e3270d85e8f4243b1f1ed19cca43d28d235b"}, + {file = "torch-2.2.2-cp38-none-macosx_10_9_x86_64.whl", hash = "sha256:31f4310210e7dda49f1fb52b0ec9e59382cfcb938693f6d5378f25b43d7c1d29"}, + {file = "torch-2.2.2-cp38-none-macosx_11_0_arm64.whl", hash = "sha256:c795feb7e8ce2e0ef63f75f8e1ab52e7fd5e1a4d7d0c31367ade1e3de35c9e95"}, + {file = "torch-2.2.2-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:a6e5770d68158d07456bfcb5318b173886f579fdfbf747543901ce718ea94782"}, + {file = "torch-2.2.2-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:67dcd726edff108e2cd6c51ff0e416fd260c869904de95750e80051358680d24"}, + {file = "torch-2.2.2-cp39-cp39-win_amd64.whl", hash = "sha256:539d5ef6c4ce15bd3bd47a7b4a6e7c10d49d4d21c0baaa87c7d2ef8698632dfb"}, + {file = "torch-2.2.2-cp39-none-macosx_10_9_x86_64.whl", hash = "sha256:dff696de90d6f6d1e8200e9892861fd4677306d0ef604cb18f2134186f719f82"}, + {file = "torch-2.2.2-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:3a4dd910663fd7a124c056c878a52c2b0be4a5a424188058fe97109d4436ee42"}, +] + +[package.dependencies] +filelock = "*" +fsspec = "*" +jinja2 = "*" +networkx = "*" +nvidia-cublas-cu12 = {version = "12.1.3.1", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cuda-cupti-cu12 = {version = "12.1.105", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cuda-nvrtc-cu12 = {version = "12.1.105", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cuda-runtime-cu12 = {version = "12.1.105", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cudnn-cu12 = {version = "8.9.2.26", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cufft-cu12 = {version = "11.0.2.54", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-curand-cu12 = {version = "10.3.2.106", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cusolver-cu12 = {version = "11.4.5.107", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cusparse-cu12 = {version = "12.1.0.106", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-nccl-cu12 = {version = "2.19.3", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-nvtx-cu12 = {version = "12.1.105", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +sympy = "*" +triton = {version = "2.2.0", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version < \"3.12\""} +typing-extensions = ">=4.8.0" + +[package.extras] +opt-einsum = ["opt-einsum (>=3.3)"] +optree = ["optree (>=0.9.1)"] + +[[package]] +name = "torchinfo" +version = "1.8.0" +description = "Model summary in PyTorch, based off of the original torchsummary." +optional = false +python-versions = ">=3.7" +files = [ + {file = "torchinfo-1.8.0-py3-none-any.whl", hash = "sha256:2e911c2918603f945c26ff21a3a838d12709223dc4ccf243407bce8b6e897b46"}, + {file = "torchinfo-1.8.0.tar.gz", hash = "sha256:72e94b0e9a3e64dc583a8e5b7940b8938a1ac0f033f795457f27e6f4e7afa2e9"}, +] + [[package]] name = "tornado" version = "6.4" @@ -4925,6 +5265,29 @@ files = [ docs = ["myst-parser", "pydata-sphinx-theme", "sphinx"] test = ["argcomplete (>=3.0.3)", "mypy (>=1.7.0)", "pre-commit", "pytest (>=7.0,<8.2)", "pytest-mock", "pytest-mypy-testing"] +[[package]] +name = "triton" +version = "2.2.0" +description = "A language and compiler for custom Deep Learning operations" +optional = false +python-versions = "*" +files = [ + {file = "triton-2.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a2294514340cfe4e8f4f9e5c66c702744c4a117d25e618bd08469d0bfed1e2e5"}, + {file = "triton-2.2.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:da58a152bddb62cafa9a857dd2bc1f886dbf9f9c90a2b5da82157cd2b34392b0"}, + {file = "triton-2.2.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0af58716e721460a61886668b205963dc4d1e4ac20508cc3f623aef0d70283d5"}, + {file = "triton-2.2.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e8fe46d3ab94a8103e291bd44c741cc294b91d1d81c1a2888254cbf7ff846dab"}, + {file = "triton-2.2.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b8ce26093e539d727e7cf6f6f0d932b1ab0574dc02567e684377630d86723ace"}, + {file = "triton-2.2.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:227cc6f357c5efcb357f3867ac2a8e7ecea2298cd4606a8ba1e931d1d5a947df"}, +] + +[package.dependencies] +filelock = "*" + +[package.extras] +build = ["cmake (>=3.20)", "lit"] +tests = ["autopep8", "flake8", "isort", "numpy", "pytest", "scipy (>=1.7.1)", "torch"] +tutorials = ["matplotlib", "pandas", "tabulate", "torch"] + [[package]] name = "types-python-dateutil" version = "2.9.0.20240316" @@ -5027,13 +5390,13 @@ zstd = ["zstandard (>=0.18.0)"] [[package]] name = "virtualenv" -version = "20.26.1" +version = "20.26.2" description = "Virtual Python Environment builder" optional = false python-versions = ">=3.7" files = [ - {file = "virtualenv-20.26.1-py3-none-any.whl", hash = "sha256:7aa9982a728ae5892558bff6a2839c00b9ed145523ece2274fad6f414690ae75"}, - {file = "virtualenv-20.26.1.tar.gz", hash = "sha256:604bfdceaeece392802e6ae48e69cec49168b9c5f4a44e483963f9242eb0e78b"}, + {file = "virtualenv-20.26.2-py3-none-any.whl", hash = "sha256:a624db5e94f01ad993d476b9ee5346fdf7b9de43ccaee0e0197012dc838a0e9b"}, + {file = "virtualenv-20.26.2.tar.gz", hash = "sha256:82bf0f4eebbb78d36ddaee0283d43fe5736b53880b8a8cdcd37390a07ac3741c"}, ] [package.dependencies] @@ -5160,6 +5523,34 @@ files = [ {file = "widgetsnbextension-4.0.10.tar.gz", hash = "sha256:64196c5ff3b9a9183a8e699a4227fb0b7002f252c814098e66c4d1cd0644688f"}, ] +[[package]] +name = "winterrb" +version = "1.0.0" +description = "" +optional = false +python-versions = ">=3.10" +files = [ + {file = "winterrb-1.0.0-py3-none-any.whl", hash = "sha256:699839ad9c8ebede59df5b58f75e80197bb599748de8e48c160bc011730ee881"}, + {file = "winterrb-1.0.0.tar.gz", hash = "sha256:a6ec6f5d775d436e3c0f4966f545c745cb7c31dad80f1fa178b2fca187c58355"}, +] + +[package.dependencies] +astropy = "*" +fastavro = "*" +ipykernel = "*" +jupyter = "*" +matplotlib = "*" +numpy = "*" +pandas = "*" +scikit-learn = "*" +scipy = "*" +torch = "*" +torchinfo = "*" +tqdm = "*" + +[package.extras] +dev = ["black", "isort"] + [[package]] name = "wintertoo" version = "1.5.1" @@ -5207,4 +5598,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = ">=3.10.0,<3.12" -content-hash = "68bfa5fb97c02da99900557f92dfb1db803c7afb97500ed478854a7d6d507afe" +content-hash = "1b3360b3c5181e023110c7749a878a1563a60d5b6f8ad526377466a1d7752c67" diff --git a/pyproject.toml b/pyproject.toml index 433f2c622..b4451aa6d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,6 +46,8 @@ wintertoo = "^1.5.0" scipy = "^1.12.0" python-dotenv = "^1.0.1" pyarrow = ">=15.0.2,<17.0.0" +torch = "=2.2.2" +winterrb = "^1.0.0" [tool.poetry.group.docs.dependencies] sphinx = "^7.0.1"