From 7f51e561dd801fd8407a8bbd8ab1323d431fccbb Mon Sep 17 00:00:00 2001 From: hacker-jerry <57279550+hacker-jerry@users.noreply.github.com> Date: Thu, 13 Jul 2023 22:44:20 +0800 Subject: [PATCH 1/4] [sml] add PCA in jax --- sml/pca/BUILD.bazel | 49 +++++++++++++++++++ sml/pca/simple_pca.py | 97 ++++++++++++++++++++++++++++++++++++++ sml/pca/simple_pca_emul.py | 76 +++++++++++++++++++++++++++++ sml/pca/simple_pca_test.py | 80 +++++++++++++++++++++++++++++++ 4 files changed, 302 insertions(+) create mode 100644 sml/pca/BUILD.bazel create mode 100644 sml/pca/simple_pca.py create mode 100644 sml/pca/simple_pca_emul.py create mode 100644 sml/pca/simple_pca_test.py diff --git a/sml/pca/BUILD.bazel b/sml/pca/BUILD.bazel new file mode 100644 index 00000000..994ea56c --- /dev/null +++ b/sml/pca/BUILD.bazel @@ -0,0 +1,49 @@ +# Copyright 2023 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +load("@rules_python//python:defs.bzl", "py_binary", "py_library", "py_test") + +package(default_visibility = ["//visibility:public"]) + +py_library( + name = "simple_pca", + srcs = ["simple_pca.py"], + deps = [ + "//sml/utils:fxp_approx", + ], +) + +py_binary( + name = "simple_pca_emul", + srcs = ["simple_pca_emul.py"], + deps = [ + ":simple_pca", + "//examples/python/utils:dataset_utils", # FIXME: remove examples dependency + "//sml/utils:emulation", + ], +) + +py_test( + name = "simple_pca_test", + srcs = ["simple_pca_test.py"], + data = [ + "//examples/python/conf", # FIXME: remove examples dependency + ], + deps = [ + ":simple_pca", + "//examples/python/utils:dataset_utils", # FIXME: remove examples dependency + "//spu:init", + "//spu/utils:simulation", + ], +) diff --git a/sml/pca/simple_pca.py b/sml/pca/simple_pca.py new file mode 100644 index 00000000..0fd8091d --- /dev/null +++ b/sml/pca/simple_pca.py @@ -0,0 +1,97 @@ +# Copyright 2023 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import jax +import jax.numpy as jnp +from enum import Enum + +class Method(Enum): + PCA = 'full' + + +class SimplePCA: + def __init__( + self, + method: str, + n_components: int, + ): + # parameter check. 
+ assert n_components > 0, f"n_components should >0" + assert method in [ + e.value for e in Method + ], f"method should in {[e.value for e in Method]}, but got {method}" + + self._n_components = n_components + self._mean = None + self._components = None + self._variances = None + self._method = Method(method) + + def fit(self, X): + """Fit the estimator to the data. + + Parameters + ---------- + X : {array-like}, shape (n_samples, n_features) + Training data. + + Returns + ------- + self : object + Returns an instance of self. + """ + assert len(X.shape) == 2, f"Expected X to be 2 dimensional array, got {X.shape}" + + self._mean = jnp.mean(X, axis=0) + X = X - self._mean + + # The covariance matrix + cov_matrix = jnp.cov(X, rowvar=False) + + # Cholesky decomposition + L = jnp.linalg.cholesky(cov_matrix) + + # QR decomposition on L + q, r = jnp.linalg.qr(L) + + # We get eigenvalues from r + eigvals = jnp.square(jnp.diag(r)) # Take square of diagonal elements + + # Get indices of the largest eigenvalues + idx = jnp.argsort(eigvals)[::-1][:self._n_components] + + # Get the sorted eigenvectors using indices + self._components = q[:, idx] + + # Save the variances of the principal components + self._variances = eigvals[idx] + return self + + def transform(self, X): + """Transform the data to the first `n_components` principal components. + + Parameters + ---------- + X : {array-like}, shape (n_samples, n_features) + Data to be transformed. + + Returns + ------- + X_transformed : array, shape (n_samples, n_components) + Transformed data. + """ + assert len(X.shape) == 2, f"Expected X to be 2 dimensional array, got {X.shape}" + + X = X - self._mean + return jnp.dot(X, self._components) diff --git a/sml/pca/simple_pca_emul.py b/sml/pca/simple_pca_emul.py new file mode 100644 index 00000000..511f7fa9 --- /dev/null +++ b/sml/pca/simple_pca_emul.py @@ -0,0 +1,76 @@ +# Copyright 2023 Ant Group Co., Ltd. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import sys +import os + +import jax.numpy as jnp +import jax.random as random +import numpy as np +from sklearn.decomposition import PCA as SklearnPCA +# from sklearn.metrics import roc_auc_score, explained_variance_score + +# Add the library directory to the path +sys.path.append(os.path.join(os.path.dirname(__file__), '../../')) + +import sml.utils.emulation as emulation +from sml.pca.simple_pca import SimplePCA + + +# TODO: design the emulation framework, just like py.unittest +# all emulation action should begin with `emul_` (for reflection) +def emul_SimplePCA(mode: emulation.Mode.MULTIPROCESS): + def proc(X): + model = SimplePCA( + method='full', + n_components=2, + ) + + model.fit(X) + X_transformed = model.transform(X) + X_variances = model._variances + + return X_transformed, X_variances + + try: + # bandwidth and latency only work for docker mode + emulator = emulation.Emulator( + emulation.CLUSTER_ABY3_3PC, mode, bandwidth=300, latency=20 + ) + emulator.up() + # Create a simple dataset + X = random.normal(random.PRNGKey(0), (15, 5)) + result = emulator.run(proc)(X) + print("X_transformed_jax: ", result[0]) + print("X_transformed_jax: ", result[1]) + # The transformed data should have 2 dimensions + assert result[0].shape[1] == 2 + # The mean of the transformed data should be approximately 0 + assert jnp.allclose(jnp.mean(result[0], axis=0), 0, atol=1e-4) + + # Compare with sklearn + model = 
SklearnPCA(n_components=2) + model.fit(X) + X_transformed = model.transform(X) + X_variances = model.explained_variance_ + + print("X_transformed_sklearn: ", X_transformed) + print("X_variances_sklearn: ", X_variances) + + finally: + emulator.down() + + +if __name__ == "__main__": + emul_SimplePCA(emulation.Mode.MULTIPROCESS) diff --git a/sml/pca/simple_pca_test.py b/sml/pca/simple_pca_test.py new file mode 100644 index 00000000..df5486bb --- /dev/null +++ b/sml/pca/simple_pca_test.py @@ -0,0 +1,80 @@ +# Copyright 2023 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import os +import sys +import unittest + +import jax.numpy as jnp +import numpy as np +from jax import random +from sklearn.decomposition import PCA as SklearnPCA +import spu.spu_pb2 as spu_pb2 +import spu.utils.simulation as spsim + +# Add the sml directory to the path +sys.path.append(os.path.join(os.path.dirname(__file__), '../../')) + +from sml.pca.simple_pca import SimplePCA + + +class UnitTests(unittest.TestCase): + def test_simple(self): + sim = spsim.Simulator.simple( + 3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64 + ) + + def proc(X): + model = SimplePCA( + method='full', + n_components=2, + ) + + model.fit(X) + X_transformed = model.transform(X) + X_variances = model._variances + + return X_transformed, X_variances + + # Create a simple dataset + X = random.normal(random.PRNGKey(0), (15, 5)) + + # Run the simulation + result = spsim.sim_jax(sim, proc)(X) + + # The transformed data should have 2 dimensions + self.assertEqual(result[0].shape[1], 2) + + # The mean of the transformed data should be approximately 0 + self.assertTrue(jnp.allclose(jnp.mean(result[0], axis=0), 0, atol=1e-4)) + + X_np = np.array(X) + + # Run PCA using sklearn + sklearn_pca = SklearnPCA(n_components=2) + X_transformed_sklearn = sklearn_pca.fit_transform(X_np) + + # Compare the results + # Note: the signs of the components can be different between different PCA implementations, + # so we need to compare the absolute values + print("X_transformed_sklearn: ", X_transformed_sklearn) + print("X_transformed_jax", result[0]) + + # compare variance + print("X_transformed_sklearn.explained_variance_: ", sklearn_pca.explained_variance_) + print("X_transformed_jax.explained_variance_: ", result[1]) + + + +if __name__ == "__main__": + unittest.main() From f973eb81277c343a0d039b0abd155039703e1f08 Mon Sep 17 00:00:00 2001 From: hacker-jerry <57279550+hacker-jerry@users.noreply.github.com> Date: Fri, 14 Jul 2023 13:50:41 +0800 Subject: [PATCH 2/4] [sml] add PCA in jax --- 
sml/pca/simple_pca.py | 68 +++++++++++++++++++++++--------------- sml/pca/simple_pca_emul.py | 24 ++++++++++++++ sml/pca/simple_pca_test.py | 39 +++++++++++++++++----- 3 files changed, 97 insertions(+), 34 deletions(-) diff --git a/sml/pca/simple_pca.py b/sml/pca/simple_pca.py index 0fd8091d..e1a41ad3 100644 --- a/sml/pca/simple_pca.py +++ b/sml/pca/simple_pca.py @@ -39,43 +39,40 @@ def __init__( self._method = Method(method) def fit(self, X): - """Fit the estimator to the data. - - Parameters - ---------- - X : {array-like}, shape (n_samples, n_features) - Training data. - - Returns - ------- - self : object - Returns an instance of self. - """ + """Fit the estimator to the data.""" assert len(X.shape) == 2, f"Expected X to be 2 dimensional array, got {X.shape}" self._mean = jnp.mean(X, axis=0) - X = X - self._mean + X_centered = X - self._mean # The covariance matrix - cov_matrix = jnp.cov(X, rowvar=False) + cov_matrix = jnp.cov(X_centered, rowvar=False) + + # Initialization + components = [] + variances = [] + + for _ in range(self._n_components): + # Initialize a random vector + vec = jnp.ones((X_centered.shape[1],)) - # Cholesky decomposition - L = jnp.linalg.cholesky(cov_matrix) + for _ in range(100): # Max iterations + # Power iteration + vec = jnp.dot(cov_matrix, vec) + vec /= jnp.linalg.norm(vec) - # QR decomposition on L - q, r = jnp.linalg.qr(L) + # Compute the corresponding eigenvalue + eigval = jnp.dot(vec.T, jnp.dot(cov_matrix, vec)) - # We get eigenvalues from r - eigvals = jnp.square(jnp.diag(r)) # Take square of diagonal elements + components.append(vec) + variances.append(eigval) - # Get indices of the largest eigenvalues - idx = jnp.argsort(eigvals)[::-1][:self._n_components] + # Remove the component from the covariance matrix + cov_matrix -= eigval * jnp.outer(vec, vec) - # Get the sorted eigenvectors using indices - self._components = q[:, idx] + self._components = jnp.column_stack(components) + self._variances = jnp.array(variances) - # Save 
the variances of the principal components - self._variances = eigvals[idx] return self def transform(self, X): @@ -95,3 +92,22 @@ def transform(self, X): X = X - self._mean return jnp.dot(X, self._components) + + def inverse_transform(self, X_transformed): + """Transform the data back to the original space. + + Parameters + ---------- + X_transformed : {array-like}, shape (n_samples, n_components) + Data in the transformed space. + + Returns + ------- + X_original : array, shape (n_samples, n_features) + Data in the original space. + """ + assert len(X_transformed.shape) == 2, f"Expected X_transformed to be 2 dimensional array, got {X_transformed.shape}" + + X_original = jnp.dot(X_transformed, self._components.T) + self._mean + + return X_original diff --git a/sml/pca/simple_pca_emul.py b/sml/pca/simple_pca_emul.py index 511f7fa9..81dc9d9e 100644 --- a/sml/pca/simple_pca_emul.py +++ b/sml/pca/simple_pca_emul.py @@ -42,6 +42,17 @@ def proc(X): X_variances = model._variances return X_transformed, X_variances + + def proc_reconstruct(X): + model = SimplePCA( + method='full', + n_components=2, + ) + + model.fit(X) + X_reconstructed = model.inverse_transform(model.transform(X)) + + return X_reconstructed try: # bandwidth and latency only work for docker mode @@ -68,6 +79,19 @@ def proc(X): print("X_transformed_sklearn: ", X_transformed) print("X_variances_sklearn: ", X_variances) + result = emulator.run(proc_reconstruct)(X) + + print("X_reconstructed_jax: ", result) + + # Compare with sklearn + model = SklearnPCA(n_components=2) + model.fit(X) + X_reconstructed = model.inverse_transform(model.transform(X)) + + print("X_reconstructed_sklearn: ", X_reconstructed) + + assert np.allclose(X_reconstructed, result, atol=1e-4) + finally: emulator.down() diff --git a/sml/pca/simple_pca_test.py b/sml/pca/simple_pca_test.py index df5486bb..1c879755 100644 --- a/sml/pca/simple_pca_test.py +++ b/sml/pca/simple_pca_test.py @@ -33,8 +33,8 @@ def test_simple(self): sim = 
spsim.Simulator.simple( 3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64 ) - - def proc(X): + # Test fit_transform + def proc_transform(X): model = SimplePCA( method='full', n_components=2, @@ -50,7 +50,7 @@ def proc(X): X = random.normal(random.PRNGKey(0), (15, 5)) # Run the simulation - result = spsim.sim_jax(sim, proc)(X) + result = spsim.sim_jax(sim, proc_transform)(X) # The transformed data should have 2 dimensions self.assertEqual(result[0].shape[1], 2) @@ -60,20 +60,43 @@ def proc(X): X_np = np.array(X) - # Run PCA using sklearn + # Run fit_transform using sklearn sklearn_pca = SklearnPCA(n_components=2) X_transformed_sklearn = sklearn_pca.fit_transform(X_np) - # Compare the results - # Note: the signs of the components can be different between different PCA implementations, - # so we need to compare the absolute values + # Compare the transform results print("X_transformed_sklearn: ", X_transformed_sklearn) print("X_transformed_jax", result[0]) - # compare variance + # Compare the variance results print("X_transformed_sklearn.explained_variance_: ", sklearn_pca.explained_variance_) print("X_transformed_jax.explained_variance_: ", result[1]) + # Test inverse_transform + def proc_reconstruct(X): + model = SimplePCA( + method='full', + n_components=2, + ) + + model.fit(X) + X_reconstructed = model.inverse_transform(model.transform(X)) + + return X_reconstructed + + # Run the simulation + result = spsim.sim_jax(sim, proc_reconstruct)(X) + + # Run inverse_transform using sklearn + X_reconstructed_sklearn = sklearn_pca.inverse_transform(X_transformed_sklearn) + + # Compare the results + self.assertTrue(np.allclose(X_reconstructed_sklearn, result, atol=1e-4)) + + + + + if __name__ == "__main__": From 9eeed172b0f471d1aa857c3304ca848ef44683fa Mon Sep 17 00:00:00 2001 From: hacker-jerry <57279550+hacker-jerry@users.noreply.github.com> Date: Fri, 14 Jul 2023 14:51:31 +0800 Subject: [PATCH 3/4] [sml] add PCA in jax --- sml/pca/simple_pca_emul.py | 6 +++--- 
sml/pca/simple_pca_test.py | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/sml/pca/simple_pca_emul.py b/sml/pca/simple_pca_emul.py index 81dc9d9e..0deee872 100644 --- a/sml/pca/simple_pca_emul.py +++ b/sml/pca/simple_pca_emul.py @@ -61,14 +61,14 @@ def proc_reconstruct(X): ) emulator.up() # Create a simple dataset - X = random.normal(random.PRNGKey(0), (15, 5)) + X = random.normal(random.PRNGKey(0), (15, 100)) result = emulator.run(proc)(X) print("X_transformed_jax: ", result[0]) print("X_transformed_jax: ", result[1]) # The transformed data should have 2 dimensions assert result[0].shape[1] == 2 # The mean of the transformed data should be approximately 0 - assert jnp.allclose(jnp.mean(result[0], axis=0), 0, atol=1e-4) + assert jnp.allclose(jnp.mean(result[0], axis=0), 0, atol=1e-3) # Compare with sklearn model = SklearnPCA(n_components=2) @@ -90,7 +90,7 @@ def proc_reconstruct(X): print("X_reconstructed_sklearn: ", X_reconstructed) - assert np.allclose(X_reconstructed, result, atol=1e-4) + assert np.allclose(X_reconstructed, result, atol=1e-3) finally: emulator.down() diff --git a/sml/pca/simple_pca_test.py b/sml/pca/simple_pca_test.py index 1c879755..8191b900 100644 --- a/sml/pca/simple_pca_test.py +++ b/sml/pca/simple_pca_test.py @@ -47,7 +47,7 @@ def proc_transform(X): return X_transformed, X_variances # Create a simple dataset - X = random.normal(random.PRNGKey(0), (15, 5)) + X = random.normal(random.PRNGKey(0), (15, 100)) # Run the simulation result = spsim.sim_jax(sim, proc_transform)(X) @@ -56,7 +56,7 @@ def proc_transform(X): self.assertEqual(result[0].shape[1], 2) # The mean of the transformed data should be approximately 0 - self.assertTrue(jnp.allclose(jnp.mean(result[0], axis=0), 0, atol=1e-4)) + self.assertTrue(jnp.allclose(jnp.mean(result[0], axis=0), 0, atol=1e-3)) X_np = np.array(X) @@ -91,7 +91,7 @@ def proc_reconstruct(X): X_reconstructed_sklearn = sklearn_pca.inverse_transform(X_transformed_sklearn) # Compare the 
results - self.assertTrue(np.allclose(X_reconstructed_sklearn, result, atol=1e-4)) + self.assertTrue(np.allclose(X_reconstructed_sklearn, result, atol=1e-3)) From 00db30ca91d3bc20f385781a923e1700ba238d44 Mon Sep 17 00:00:00 2001 From: hacker-jerry <57279550+hacker-jerry@users.noreply.github.com> Date: Fri, 14 Jul 2023 16:10:51 +0800 Subject: [PATCH 4/4] [sml] add PCA in jax --- sml/pca/simple_pca.py | 44 +++++++++++++++++++++++++++++++++++--- sml/pca/simple_pca_emul.py | 4 ++-- sml/pca/simple_pca_test.py | 4 ++-- 3 files changed, 45 insertions(+), 7 deletions(-) diff --git a/sml/pca/simple_pca.py b/sml/pca/simple_pca.py index e1a41ad3..dd226d91 100644 --- a/sml/pca/simple_pca.py +++ b/sml/pca/simple_pca.py @@ -17,7 +17,7 @@ from enum import Enum class Method(Enum): - PCA = 'full' + PCA = 'power_iteration' class SimplePCA: @@ -25,7 +25,26 @@ def __init__( self, method: str, n_components: int, + max_iter: int = 100, ): + """A PCA estimator implemented with Power Iteration. + + Parameters + ---------- + method : str + The method to compute the principal components. + 'power_iteration' uses Power Iteration to compute the eigenvalues and eigenvectors. + + n_components : int + Number of components to keep. + + max_iter : int, default=100 + Maximum number of iterations for Power Iteration. + + References + ---------- + Power Iteration: https://en.wikipedia.org/wiki/Power_iteration + """ # parameter check. assert n_components > 0, f"n_components should >0" assert method in [ @@ -33,13 +52,32 @@ def __init__( ], f"method should in {[e.value for e in Method]}, but got {method}" self._n_components = n_components + self._max_iter = max_iter self._mean = None self._components = None self._variances = None self._method = Method(method) def fit(self, X): - """Fit the estimator to the data.""" + """Fit the estimator to the data. + + In the 'power_iteration' method, we use the Power Iteration algorithm to compute the eigenvalues and eigenvectors. 
+ The Power Iteration algorithm works by repeatedly multiplying a vector by the matrix to inflate the largest eigenvalue, + and then normalizing to keep numerical stability. + After finding the largest eigenvalue and eigenvector, we deflate the matrix by subtracting the outer product of the + eigenvector and itself, scaled by the eigenvalue. This leaves a matrix with the same eigenvectors, but the largest + eigenvalue is replaced by zero. + + Parameters + ---------- + X : {array-like}, shape (n_samples, n_features) + Training data. + + Returns + ------- + self : object + Returns an instance of self. + """ assert len(X.shape) == 2, f"Expected X to be 2 dimensional array, got {X.shape}" self._mean = jnp.mean(X, axis=0) @@ -56,7 +94,7 @@ def fit(self, X): # Initialize a random vector vec = jnp.ones((X_centered.shape[1],)) - for _ in range(100): # Max iterations + for _ in range(self._max_iter): # Max iterations # Power iteration vec = jnp.dot(cov_matrix, vec) vec /= jnp.linalg.norm(vec) diff --git a/sml/pca/simple_pca_emul.py b/sml/pca/simple_pca_emul.py index 0deee872..ff6e29a0 100644 --- a/sml/pca/simple_pca_emul.py +++ b/sml/pca/simple_pca_emul.py @@ -33,7 +33,7 @@ def emul_SimplePCA(mode: emulation.Mode.MULTIPROCESS): def proc(X): model = SimplePCA( - method='full', + method='power_iteration', n_components=2, ) @@ -45,7 +45,7 @@ def proc(X): def proc_reconstruct(X): model = SimplePCA( - method='full', + method='power_iteration', n_components=2, ) diff --git a/sml/pca/simple_pca_test.py b/sml/pca/simple_pca_test.py index 8191b900..66d119e4 100644 --- a/sml/pca/simple_pca_test.py +++ b/sml/pca/simple_pca_test.py @@ -36,7 +36,7 @@ def test_simple(self): # Test fit_transform def proc_transform(X): model = SimplePCA( - method='full', + method='power_iteration', n_components=2, ) @@ -75,7 +75,7 @@ def proc_transform(X): # Test inverse_transform def proc_reconstruct(X): model = SimplePCA( - method='full', + method='power_iteration', n_components=2, )