diff --git a/sml/pca/BUILD.bazel b/sml/pca/BUILD.bazel new file mode 100644 index 00000000..994ea56c --- /dev/null +++ b/sml/pca/BUILD.bazel @@ -0,0 +1,49 @@ +# Copyright 2023 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("@rules_python//python:defs.bzl", "py_binary", "py_library", "py_test") + +package(default_visibility = ["//visibility:public"]) + +py_library( + name = "simple_pca", + srcs = ["simple_pca.py"], + deps = [ + "//sml/utils:fxp_approx", + ], +) + +py_binary( + name = "simple_pca_emul", + srcs = ["simple_pca_emul.py"], + deps = [ + ":simple_pca", + "//examples/python/utils:dataset_utils", # FIXME: remove examples dependency + "//sml/utils:emulation", + ], +) + +py_test( + name = "simple_pca_test", + srcs = ["simple_pca_test.py"], + data = [ + "//examples/python/conf", # FIXME: remove examples dependency + ], + deps = [ + ":simple_pca", + "//examples/python/utils:dataset_utils", # FIXME: remove examples dependency + "//spu:init", + "//spu/utils:simulation", + ], +) diff --git a/sml/pca/simple_pca.py b/sml/pca/simple_pca.py new file mode 100644 index 00000000..dd226d91 --- /dev/null +++ b/sml/pca/simple_pca.py @@ -0,0 +1,151 @@ +# Copyright 2023 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import jax +import jax.numpy as jnp +from enum import Enum + +class Method(Enum): + PCA = 'power_iteration' + + +class SimplePCA: + def __init__( + self, + method: str, + n_components: int, + max_iter: int = 100, + ): + """A PCA estimator implemented with Power Iteration. + + Parameters + ---------- + method : str + The method to compute the principal components. + 'power_iteration' uses Power Iteration to compute the eigenvalues and eigenvectors. + + n_components : int + Number of components to keep. + + max_iter : int, default=100 + Maximum number of iterations for Power Iteration. + + References + ---------- + Power Iteration: https://en.wikipedia.org/wiki/Power_iteration + """ + # parameter check. + assert n_components > 0, f"n_components should >0" + assert method in [ + e.value for e in Method + ], f"method should in {[e.value for e in Method]}, but got {method}" + + self._n_components = n_components + self._max_iter = max_iter + self._mean = None + self._components = None + self._variances = None + self._method = Method(method) + + def fit(self, X): + """Fit the estimator to the data. + + In the 'power_iteration' method, we use the Power Iteration algorithm to compute the eigenvalues and eigenvectors. + The Power Iteration algorithm works by repeatedly multiplying a vector by the matrix to inflate the largest eigenvalue, + and then normalizing to keep numerical stability. + After finding the largest eigenvalue and eigenvector, we deflate the matrix by subtracting the outer product of the + eigenvector and itself, scaled by the eigenvalue. This leaves a matrix with the same eigenvectors, but the largest + eigenvalue is replaced by zero. + + Parameters + ---------- + X : {array-like}, shape (n_samples, n_features) + Training data. + + Returns + ------- + self : object + Returns an instance of self. + """ + assert len(X.shape) == 2, f"Expected X to be 2 dimensional array, got {X.shape}" + + self._mean = jnp.mean(X, axis=0) + X_centered = X - self._mean + + # The covariance matrix + cov_matrix = jnp.cov(X_centered, rowvar=False) + + # Initialization + components = [] + variances = [] + + for _ in range(self._n_components): + # Initialize a random vector + vec = jnp.ones((X_centered.shape[1],)) + + for _ in range(self._max_iter): # Max iterations + # Power iteration + vec = jnp.dot(cov_matrix, vec) + vec /= jnp.linalg.norm(vec) + + # Compute the corresponding eigenvalue + eigval = jnp.dot(vec.T, jnp.dot(cov_matrix, vec)) + + components.append(vec) + variances.append(eigval) + + # Remove the component from the covariance matrix + cov_matrix -= eigval * jnp.outer(vec, vec) + + self._components = jnp.column_stack(components) + self._variances = jnp.array(variances) + + return self + + def transform(self, X): + """Transform the data to the first `n_components` principal components. + + Parameters + ---------- + X : {array-like}, shape (n_samples, n_features) + Data to be transformed. + + Returns + ------- + X_transformed : array, shape (n_samples, n_components) + Transformed data. + """ + assert len(X.shape) == 2, f"Expected X to be 2 dimensional array, got {X.shape}" + + X = X - self._mean + return jnp.dot(X, self._components) + + def inverse_transform(self, X_transformed): + """Transform the data back to the original space. + + Parameters + ---------- + X_transformed : {array-like}, shape (n_samples, n_components) + Data in the transformed space. + + Returns + ------- + X_original : array, shape (n_samples, n_features) + Data in the original space. + """ + assert len(X_transformed.shape) == 2, f"Expected X_transformed to be 2 dimensional array, got {X_transformed.shape}" + + X_original = jnp.dot(X_transformed, self._components.T) + self._mean + + return X_original diff --git a/sml/pca/simple_pca_emul.py b/sml/pca/simple_pca_emul.py new file mode 100644 index 00000000..ff6e29a0 --- /dev/null +++ b/sml/pca/simple_pca_emul.py @@ -0,0 +1,100 @@ +# Copyright 2023 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import os + +import jax.numpy as jnp +import jax.random as random +import numpy as np +from sklearn.decomposition import PCA as SklearnPCA +# from sklearn.metrics import roc_auc_score, explained_variance_score + +# Add the library directory to the path +sys.path.append(os.path.join(os.path.dirname(__file__), '../../')) + +import sml.utils.emulation as emulation +from sml.pca.simple_pca import SimplePCA + + +# TODO: design the enumation framework, just like py.unittest +# all emulation action should begin with `emul_` (for reflection) +def emul_SimplePCA(mode: emulation.Mode.MULTIPROCESS): + def proc(X): + model = SimplePCA( + method='power_iteration', + n_components=2, + ) + + model.fit(X) + X_transformed = model.transform(X) + X_variances = model._variances + + return X_transformed, X_variances + + def proc_reconstruct(X): + model = SimplePCA( + method='power_iteration', + n_components=2, + ) + + model.fit(X) + X_reconstructed = model.inverse_transform(model.transform(X)) + + return X_reconstructed + + try: + # bandwidth and latency only work for docker mode + emulator = emulation.Emulator( + emulation.CLUSTER_ABY3_3PC, mode, bandwidth=300, latency=20 + ) + emulator.up() + # Create a simple dataset + X = random.normal(random.PRNGKey(0), (15, 100)) + result = emulator.run(proc)(X) + print("X_transformed_jax: ", result[0]) + print("X_transformed_jax: ", result[1]) + # The transformed data should have 2 dimensions + assert result[0].shape[1] == 2 + # The mean of the transformed data should be approximately 0 + assert jnp.allclose(jnp.mean(result[0], axis=0), 0, atol=1e-3) + + # Compare with sklearn + model = SklearnPCA(n_components=2) + model.fit(X) + X_transformed = model.transform(X) + X_variances = model.explained_variance_ + + print("X_transformed_sklearn: ", X_transformed) + print("X_variances_sklearn: ", X_variances) + + result = emulator.run(proc_reconstruct)(X) + + print("X_reconstructed_jax: ", result) + + # Compare with sklearn + model = SklearnPCA(n_components=2) + model.fit(X) + X_reconstructed = model.inverse_transform(model.transform(X)) + + print("X_reconstructed_sklearn: ", X_reconstructed) + + assert np.allclose(X_reconstructed, result, atol=1e-3) + + finally: + emulator.down() + + +if __name__ == "__main__": + emul_SimplePCA(emulation.Mode.MULTIPROCESS) diff --git a/sml/pca/simple_pca_test.py b/sml/pca/simple_pca_test.py new file mode 100644 index 00000000..66d119e4 --- /dev/null +++ b/sml/pca/simple_pca_test.py @@ -0,0 +1,103 @@ +# Copyright 2023 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import sys +import unittest + +import jax.numpy as jnp +import numpy as np +from jax import random +from sklearn.decomposition import PCA as SklearnPCA +import spu.spu_pb2 as spu_pb2 +import spu.utils.simulation as spsim + +# Add the sml directory to the path +sys.path.append(os.path.join(os.path.dirname(__file__), '../../')) + +from sml.pca.simple_pca import SimplePCA + + +class UnitTests(unittest.TestCase): + def test_simple(self): + sim = spsim.Simulator.simple( + 3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64 + ) + # Test fit_transform + def proc_transform(X): + model = SimplePCA( + method='power_iteration', + n_components=2, + ) + + model.fit(X) + X_transformed = model.transform(X) + X_variances = model._variances + + return X_transformed, X_variances + + # Create a simple dataset + X = random.normal(random.PRNGKey(0), (15, 100)) + + # Run the simulation + result = spsim.sim_jax(sim, proc_transform)(X) + + # The transformed data should have 2 dimensions + self.assertEqual(result[0].shape[1], 2) + + # The mean of the transformed data should be approximately 0 + self.assertTrue(jnp.allclose(jnp.mean(result[0], axis=0), 0, atol=1e-3)) + + X_np = np.array(X) + + # Run fit_transform using sklearn + sklearn_pca = SklearnPCA(n_components=2) + X_transformed_sklearn = sklearn_pca.fit_transform(X_np) + + # Compare the transform results + print("X_transformed_sklearn: ", X_transformed_sklearn) + print("X_transformed_jax", result[0]) + + # Compare the variance results + print("X_transformed_sklearn.explained_variance_: ", sklearn_pca.explained_variance_) + print("X_transformed_jax.explained_variance_: ", result[1]) + + # Test inverse_transform + def proc_reconstruct(X): + model = SimplePCA( + method='power_iteration', + n_components=2, + ) + + model.fit(X) + X_reconstructed = model.inverse_transform(model.transform(X)) + + return X_reconstructed + + # Run the simulation + result = spsim.sim_jax(sim, proc_reconstruct)(X) + + # Run inverse_transform using sklearn + X_reconstructed_sklearn = sklearn_pca.inverse_transform(X_transformed_sklearn) + + # Compare the results + self.assertTrue(np.allclose(X_reconstructed_sklearn, result, atol=1e-3)) + + + + + + + +if __name__ == "__main__": + unittest.main()