diff --git a/.github/workflows/draft-pdf.yml b/.github/workflows/draft-pdf.yml deleted file mode 100644 index 51a805e..0000000 --- a/.github/workflows/draft-pdf.yml +++ /dev/null @@ -1,23 +0,0 @@ -on: [push] - -jobs: - paper: - runs-on: ubuntu-latest - name: Paper Draft - steps: - - name: Checkout - uses: actions/checkout@v4 - - name: Build draft PDF - uses: openjournals/openjournals-draft-action@master - with: - journal: joss - # This should be the path to the paper within your repo. - paper-path: paper/paper.md - - name: Upload - uses: actions/upload-artifact@v1 - with: - name: paper - # This is the output path where Pandoc will write the compiled - # PDF. Note, this should be the same directory as the input - # paper.md - path: paper/paper.pdf \ No newline at end of file diff --git a/README.md b/README.md index 2843405..b7e1e5b 100644 --- a/README.md +++ b/README.md @@ -18,6 +18,9 @@ The `ikpls` software package provides fast and efficient tools for PLS (Partial Least Squares) modeling. This package is designed to help researchers and practitioners handle PLS modeling faster than previously possible - particularly on large datasets. +## Citation +If you use the `ikpls` software package for your work, please cite [this Journal of Open Source Software article](https://joss.theoj.org/papers/10.21105/joss.06533). If you use the fast cross-validation algorithm implemented in `ikpls.fast_cross_validation.numpy_ikpls`, please also cite [this arXiv preprint](https://arxiv.org/abs/2401.13185). + ## Unlock the Power of Fast and Stable Partial Least Squares Modeling with IKPLS Dive into cutting-edge Python implementations of the IKPLS (Improved Kernel Partial Least Squares) Algorithms #1 and #2 [[1]](#references) for CPUs, GPUs, and TPUs. IKPLS is both fast [[2]](#references) and numerically stable [[3]](#references) making it optimal for PLS modeling. @@ -49,9 +52,10 @@ and scaling can be enabled or disabled independently from eachother and for X an by setting the parameters `center_X`, `center_Y`, `scale_X`, and `scale_Y`, respectively. In addition to correctly handling (column-wise) centering and scaling, the fast cross-validation algorithm **correctly handles row-wise preprocessing** -such as (row-wise) centering and scaling of the X and Y input matrices, -convolution, or other preprocessing. Row-wise preprocessing can safely be -applied before passing the data to the fast cross-validation algorithm. +that operates independently on each sample such as (row-wise) centering and scaling +of the X and Y input matrices, convolution, or other preprocessing. Row-wise +preprocessing can safely be applied before passing the data to the fast +cross-validation algorithm. ## Prerequisites diff --git a/ikpls/__init__.py b/ikpls/__init__.py index b3f9ac7..b7e1990 100644 --- a/ikpls/__init__.py +++ b/ikpls/__init__.py @@ -1 +1 @@ -__version__ = "1.2.4" +__version__ = "1.2.5" diff --git a/ikpls/fast_cross_validation/numpy_ikpls.py b/ikpls/fast_cross_validation/numpy_ikpls.py index 9c4208c..d5b651c 100644 --- a/ikpls/fast_cross_validation/numpy_ikpls.py +++ b/ikpls/fast_cross_validation/numpy_ikpls.py @@ -84,7 +84,7 @@ def __init__( scale_X: bool = True, scale_Y: bool = True, algorithm: int = 1, - dtype: np.float_ = np.float64, + dtype: np.floating = np.float64, ) -> None: self.center_X = center_X self.center_Y = center_Y @@ -139,27 +139,27 @@ def _stateless_fit( validation_indices: npt.NDArray[np.int_], ) -> Union[ tuple[ - npt.NDArray[np.float_], - npt.NDArray[np.float_], - npt.NDArray[np.float_], - npt.NDArray[np.float_], - npt.NDArray[np.float_], - npt.NDArray[np.float_], - npt.NDArray[np.float_], - npt.NDArray[np.float_], - npt.NDArray[np.float_], - npt.NDArray[np.float_], + npt.NDArray[np.floating], + npt.NDArray[np.floating], + npt.NDArray[np.floating], + npt.NDArray[np.floating], + npt.NDArray[np.floating], + npt.NDArray[np.floating], + npt.NDArray[np.floating], + npt.NDArray[np.floating], + npt.NDArray[np.floating], + npt.NDArray[np.floating], ], tuple[ - npt.NDArray[np.float_], - npt.NDArray[np.float_], - npt.NDArray[np.float_], - npt.NDArray[np.float_], - npt.NDArray[np.float_], - npt.NDArray[np.float_], - npt.NDArray[np.float_], - npt.NDArray[np.float_], - npt.NDArray[np.float_], + npt.NDArray[np.floating], + npt.NDArray[np.floating], + npt.NDArray[np.floating], + npt.NDArray[np.floating], + npt.NDArray[np.floating], + npt.NDArray[np.floating], + npt.NDArray[np.floating], + npt.NDArray[np.floating], + npt.NDArray[np.floating], ], ]: """ @@ -433,13 +433,13 @@ def _stateless_fit( def _stateless_predict( self, indices: npt.NDArray[np.int_], - B: npt.NDArray[np.float_], - training_X_mean: npt.NDArray[np.float_], - training_Y_mean: npt.NDArray[np.float_], - training_X_std: npt.NDArray[np.float_], - training_Y_std: npt.NDArray[np.float_], + B: npt.NDArray[np.floating], + training_X_mean: npt.NDArray[np.floating], + training_Y_mean: npt.NDArray[np.floating], + training_X_std: npt.NDArray[np.floating], + training_Y_std: npt.NDArray[np.floating], n_components: Union[None, int] = None, - ) -> npt.NDArray[np.float_]: + ) -> npt.NDArray[np.floating]: """ Predicts with Improved Kernel PLS Algorithm #1 on `X` with `B` using `n_components` components. If `n_components` is None, then predictions are @@ -503,7 +503,7 @@ def _stateless_fit_predict_eval( self, validation_indices: npt.NDArray[np.int_], metric_function: Callable[ - [npt.NDArray[np.float_], npt.NDArray[np.float_]], Any + [npt.NDArray[np.floating], npt.NDArray[np.floating]], Any ], ) -> Any: """ diff --git a/ikpls/jax_ikpls_base.py b/ikpls/jax_ikpls_base.py index f645dfd..39ded28 100644 --- a/ikpls/jax_ikpls_base.py +++ b/ikpls/jax_ikpls_base.py @@ -115,7 +115,7 @@ def __init__( self.X_std = None self.Y_std = None - def _weight_warning(self, arg: Tuple[npt.NDArray[np.int_], npt.NDArray[np.float_]]): + def _weight_warning(self, arg: Tuple[npt.NDArray[np.int_], npt.NDArray[np.floating]]): """ Display a warning message if the weight is close to zero. diff --git a/ikpls/numpy_ikpls.py b/ikpls/numpy_ikpls.py index 0ff7109..679577a 100644 --- a/ikpls/numpy_ikpls.py +++ b/ikpls/numpy_ikpls.py @@ -79,7 +79,7 @@ def __init__( scale_X: bool = True, scale_Y: bool = True, copy: bool = True, - dtype: np.float_ = np.float64, + dtype: np.floating = np.float64, ) -> None: self.algorithm = algorithm self.center_X = center_X @@ -300,7 +300,7 @@ def fit(self, X: npt.ArrayLike, Y: npt.ArrayLike, A: int) -> None: def predict( self, X: npt.ArrayLike, n_components: Union[None, int] = None - ) -> npt.NDArray[np.float_]: + ) -> npt.NDArray[np.floating]: """ Predicts with Improved Kernel PLS Algorithm #1 on `X` with `B` using `n_components` components. If `n_components` is None, then predictions are diff --git a/pyproject.toml b/pyproject.toml index 00b72b1..3c33fe9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "ikpls" -version = "1.2.4" +version = "1.2.5" description = "Improved Kernel PLS and Fast Cross-Validation." authors = ["Sm00thix "] maintainers = ["Sm00thix "] @@ -11,7 +11,7 @@ repository = "https://github.com/Sm00thix/IKPLS" [tool.poetry.dependencies] python = ">=3.9, <3.13" -numpy = "^1.26.3" +numpy = ">=1.26.3" jax = "^0.4.20" jaxlib = "^0.4.20" scikit-learn = "^1.5.0" diff --git a/tests/load_data.py b/tests/load_data.py index f1b6c4d..8572d9e 100644 --- a/tests/load_data.py +++ b/tests/load_data.py @@ -64,7 +64,7 @@ def load_spectra(): resp_byte_array = resp.read() byte_contents = io.BytesIO(resp_byte_array) npz_arr = np.load(byte_contents) - spectra = np.row_stack([npz_arr[k] for k in npz_arr.keys()]) + spectra = np.vstack([npz_arr[k] for k in npz_arr.keys()]) spectra = spectra.astype(np.float64) spectra = -np.log10(spectra) return spectra diff --git a/tests/test_ikpls.py b/tests/test_ikpls.py index 98e7395..cae8f47 100644 --- a/tests/test_ikpls.py +++ b/tests/test_ikpls.py @@ -47,14 +47,14 @@ class TestClass: csv : DataFrame The CSV data containing target values. - raw_spectra : NDArray[float] + raw_spectra : npt.NDArray[np.float64] The raw spectral data. """ csv = load_data.load_csv() raw_spectra = load_data.load_spectra() - def load_X(self) -> npt.NDArray[np.float_]: + def load_X(self) -> npt.NDArray[np.float64]: """ Description ----------- @@ -62,12 +62,12 @@ def load_X(self) -> npt.NDArray[np.float_]: Returns ------- - npt.NDArray[np.float_] + npt.NDArray[np.float64] The raw spectral data. """ return np.copy(self.raw_spectra) - def load_Y(self, values: list[str]) -> npt.NDArray[np.float_]: + def load_Y(self, values: list[str]) -> npt.NDArray[np.float64]: """ Description ----------- @@ -80,7 +80,7 @@ def load_Y(self, values: list[str]) -> npt.NDArray[np.float_]: Returns ------- - NDArray[float] + npt.NDArray[np.float64] Target values as a NumPy array. """ target_values = self.csv[values].to_numpy() @@ -895,8 +895,8 @@ def test_pls_1(self) -> None: jax_pls_alg_2=jax_pls_alg_2, diff_jax_pls_alg_1=diff_jax_pls_alg_1, diff_jax_pls_alg_2=diff_jax_pls_alg_2, - atol=1e-8, - rtol=6e-5, + atol=3e-8, + rtol=2e-4, ) self.check_predictions( @@ -963,8 +963,8 @@ def test_pls_1(self) -> None: jax_pls_alg_2=jax_pls_alg_2, diff_jax_pls_alg_1=diff_jax_pls_alg_1, diff_jax_pls_alg_2=diff_jax_pls_alg_2, - atol=1e-8, - rtol=6e-5, + atol=3e-8, + rtol=2e-4, ) self.check_predictions( @@ -3178,26 +3178,26 @@ def test_fast_cross_val_pls_1(self): splits = self.load_Y(["split"]) assert Y.shape[1] == 1 self.check_fast_cross_val_pls( - X, Y, splits, center=False, scale=False, atol=0, rtol=1e-8 + X, Y, splits, center=False, scale=False, atol=0, rtol=1e-7 ) self.check_fast_cross_val_pls( - X, Y, splits, center=True, scale=False, atol=0, rtol=1e-8 + X, Y, splits, center=True, scale=False, atol=0, rtol=1e-7 ) self.check_fast_cross_val_pls( - X, Y, splits, center=True, scale=True, atol=0, rtol=1e-8 + X, Y, splits, center=True, scale=True, atol=0, rtol=1e-7 ) # Remove the singleton dimension and check that the predictions are consistent. Y = Y.squeeze() assert Y.ndim == 1 self.check_fast_cross_val_pls( - X, Y, splits, center=False, scale=False, atol=0, rtol=1e-8 + X, Y, splits, center=False, scale=False, atol=0, rtol=1e-7 ) self.check_fast_cross_val_pls( - X, Y, splits, center=True, scale=False, atol=0, rtol=1e-8 + X, Y, splits, center=True, scale=False, atol=0, rtol=1e-7 ) self.check_fast_cross_val_pls( - X, Y, splits, center=True, scale=True, atol=0, rtol=1e-8 + X, Y, splits, center=True, scale=True, atol=0, rtol=1e-7 ) # JAX will issue a warning if os.fork() is called as JAX is incompatible with @@ -3243,13 +3243,13 @@ def test_fast_cross_val_pls_2_m_less_k(self): assert Y.shape[1] > 1 assert Y.shape[1] < X.shape[1] self.check_fast_cross_val_pls( - X, Y, splits, center=False, scale=False, atol=0, rtol=1e-7 + X, Y, splits, center=False, scale=False, atol=0, rtol=1e-6 ) self.check_fast_cross_val_pls( - X, Y, splits, center=True, scale=False, atol=0, rtol=1e-7 + X, Y, splits, center=True, scale=False, atol=0, rtol=1e-6 ) self.check_fast_cross_val_pls( - X, Y, splits, center=True, scale=True, atol=0, rtol=1e-7 + X, Y, splits, center=True, scale=True, atol=0, rtol=1e-6 ) # JAX will issue a warning if os.fork() is called as JAX is incompatible with @@ -3296,13 +3296,13 @@ def test_fast_cross_val_pls_2_m_eq_k(self): assert Y.shape[1] > 1 assert Y.shape[1] == X.shape[1] self.check_fast_cross_val_pls( - X, Y, splits, center=False, scale=False, atol=0, rtol=1e-8 + X, Y, splits, center=False, scale=False, atol=0, rtol=2e-8 ) self.check_fast_cross_val_pls( - X, Y, splits, center=True, scale=False, atol=0, rtol=1e-8 + X, Y, splits, center=True, scale=False, atol=0, rtol=2e-8 ) self.check_fast_cross_val_pls( - X, Y, splits, center=True, scale=True, atol=0, rtol=1e-8 + X, Y, splits, center=True, scale=True, atol=0, rtol=2e-8 ) # JAX will issue a warning if os.fork() is called as JAX is incompatible with @@ -3448,13 +3448,13 @@ def test_fast_cross_val_pls_2_m_less_k_loocv(self): assert Y.shape[1] > 1 assert Y.shape[1] < X.shape[1] self.check_fast_cross_val_pls( - X, Y, splits, center=False, scale=False, atol=2e-6, rtol=1e-8 + X, Y, splits, center=False, scale=False, atol=1e-4, rtol=2e-8 ) self.check_fast_cross_val_pls( - X, Y, splits, center=True, scale=False, atol=5e-6, rtol=1e-8 + X, Y, splits, center=True, scale=False, atol=1e-4, rtol=2e-8 ) self.check_fast_cross_val_pls( - X, Y, splits, center=True, scale=True, atol=3e-6, rtol=1e-8 + X, Y, splits, center=True, scale=True, atol=1e-4, rtol=2e-8 ) def test_fast_cross_val_pls_2_m_eq_k_loocv(self): @@ -3494,10 +3494,10 @@ def test_fast_cross_val_pls_2_m_eq_k_loocv(self): assert Y.shape[1] > 1 assert Y.shape[1] == X.shape[1] self.check_fast_cross_val_pls( - X, Y, splits, center=False, scale=False, atol=1e-7, rtol=1e-8 + X, Y, splits, center=False, scale=False, atol=2e-7, rtol=1e-8 ) self.check_fast_cross_val_pls( - X, Y, splits, center=True, scale=False, atol=1e-7, rtol=1e-8 + X, Y, splits, center=True, scale=False, atol=2e-7, rtol=1e-8 ) self.check_fast_cross_val_pls( X, Y, splits, center=True, scale=True, atol=1e-7, rtol=1e-8 @@ -4111,7 +4111,7 @@ def test_center_scale_combinations_pls_2_m_eq_k(self): splits = self.load_Y(["split"]) # Contains 3 splits of different sizes assert Y.shape[1] > 1 assert Y.shape[1] == X.shape[1] - self.check_center_scale_combinations(X, Y, splits, atol=0, rtol=1e-8) + self.check_center_scale_combinations(X, Y, splits, atol=0, rtol=3e-8) # JAX will issue a warning if os.fork() is called as JAX is incompatible with # multi-threaded code. os.fork() is called by the other cross-validation