Skip to content

Commit

Permalink
[sml] add PCA in jax
Browse files Browse the repository at this point in the history
  • Loading branch information
hacker-jerry committed Jul 14, 2023
1 parent 9eeed17 commit 00db30c
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 7 deletions.
44 changes: 41 additions & 3 deletions sml/pca/simple_pca.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,29 +17,67 @@
from enum import Enum

class Method(Enum):
PCA = 'full'
PCA = 'power_iteration'


class SimplePCA:
def __init__(
    self,
    method: str,
    n_components: int,
    max_iter: int = 100,
):
    """A PCA estimator implemented with Power Iteration.

    Parameters
    ----------
    method : str
        The method to compute the principal components. Must equal the
        value of one of the ``Method`` enum members.
        'power_iteration' uses Power Iteration to compute the eigenvalues
        and eigenvectors.
    n_components : int
        Number of components to keep. Must be positive.
    max_iter : int, default=100
        Maximum number of iterations for Power Iteration. Must be positive.

    Raises
    ------
    AssertionError
        If ``n_components`` or ``max_iter`` is not positive, or if
        ``method`` is not a supported value.

    References
    ----------
    Power Iteration: https://en.wikipedia.org/wiki/Power_iteration
    """
    # Parameter checks. Asserts are kept (rather than raising ValueError) so
    # the exception type seen by existing callers does not change.
    assert n_components > 0, "n_components should >0"
    # With max_iter <= 0 the power-iteration loop in fit() never executes and
    # the un-normalized starting vector would be used as the eigenvector.
    assert max_iter > 0, f"max_iter should >0, but got {max_iter}"
    supported_methods = [e.value for e in Method]
    assert (
        method in supported_methods
    ), f"method should in {supported_methods}, but got {method}"

    self._n_components = n_components
    self._max_iter = max_iter
    # Fitted state. fit() assigns _mean; _components and _variances are
    # presumably filled in there as well (not visible here, to be confirmed).
    self._mean = None
    self._components = None
    self._variances = None
    self._method = Method(method)

def fit(self, X):
"""Fit the estimator to the data."""
"""Fit the estimator to the data.
In the 'power_iteration' method, we use the Power Iteration algorithm to compute the eigenvalues and eigenvectors.
The Power Iteration algorithm works by repeatedly multiplying a vector by the matrix to inflate the largest eigenvalue,
and then normalizing to keep numerical stability.
After finding the largest eigenvalue and eigenvector, we deflate the matrix by subtracting the outer product of the
eigenvector and itself, scaled by the eigenvalue. This leaves a matrix with the same eigenvectors, but the largest
eigenvalue is replaced by zero.
Parameters
----------
X : {array-like}, shape (n_samples, n_features)
Training data.
Returns
-------
self : object
Returns an instance of self.
"""
assert len(X.shape) == 2, f"Expected X to be 2 dimensional array, got {X.shape}"

self._mean = jnp.mean(X, axis=0)
Expand All @@ -56,7 +94,7 @@ def fit(self, X):
# Initialize the starting vector (deterministic all-ones, not random)
vec = jnp.ones((X_centered.shape[1],))

for _ in range(100): # Max iterations
for _ in range(self._max_iter): # Max iterations
# Power iteration
vec = jnp.dot(cov_matrix, vec)
vec /= jnp.linalg.norm(vec)
Expand Down
4 changes: 2 additions & 2 deletions sml/pca/simple_pca_emul.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
def emul_SimplePCA(mode: emulation.Mode.MULTIPROCESS):
def proc(X):
model = SimplePCA(
method='full',
method='power_iteration',
n_components=2,
)

Expand All @@ -45,7 +45,7 @@ def proc(X):

def proc_reconstruct(X):
model = SimplePCA(
method='full',
method='power_iteration',
n_components=2,
)

Expand Down
4 changes: 2 additions & 2 deletions sml/pca/simple_pca_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def test_simple(self):
# Test fit_transform
def proc_transform(X):
model = SimplePCA(
method='full',
method='power_iteration',
n_components=2,
)

Expand Down Expand Up @@ -75,7 +75,7 @@ def proc_transform(X):
# Test inverse_transform
def proc_reconstruct(X):
model = SimplePCA(
method='full',
method='power_iteration',
n_components=2,
)

Expand Down

0 comments on commit 00db30c

Please sign in to comment.