-
Notifications
You must be signed in to change notification settings - Fork 100
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Showing
4 changed files
with
227 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
load("@rules_python//python:defs.bzl", "py_binary", "py_library", "py_test") | ||
|
||
package(default_visibility = ["//visibility:public"]) | ||
|
||
py_library( | ||
name = "kmeans", | ||
srcs = ["kmeans.py"], | ||
deps = [ | ||
"//sml/utils:fxp_approx", | ||
], | ||
) | ||
|
||
py_binary( | ||
name = "kmeans_emul", | ||
srcs = ["kmeans_emul.py"], | ||
deps = [ | ||
":kmeans", | ||
"//examples/python/utils:dataset_utils", # FIXME: remove examples dependency | ||
"//sml/utils:emulation" | ||
], | ||
) | ||
|
||
|
||
|
||
py_test( | ||
name = "kmeans_test", | ||
srcs = ["kmeans_test.py"], | ||
data = [ | ||
"//examples/python/conf", # FIXME: remove examples dependency | ||
], | ||
deps = [ | ||
":kmeans", | ||
"//examples/python/utils:dataset_utils", # FIXME: remove examples dependency | ||
"//spu:init", | ||
"//spu/utils:simulation", | ||
], | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
import jax | ||
import jax.numpy as jnp | ||
|
||
class KMEANS: | ||
""" | ||
Parameters | ||
---------- | ||
n_clusters : int | ||
The number of clusters to form as well as the number of | ||
centroids to generate. | ||
n_samples : int | ||
The number of samples. | ||
max_iter : int, default=300 | ||
Maximum number of iterations of the k-means algorithm for a | ||
single run. | ||
tol : float, default=1e-4 | ||
Acceptable error to consider the two to be equal. | ||
""" | ||
def __init__(self, n_clusters, n_samples, max_iter=300, tol=1e-4): | ||
self.n_clusters = n_clusters | ||
self.max_iter = max_iter | ||
self.tol = tol | ||
self.init_params = jax.random.randint(jax.random.PRNGKey(1),shape=[self.n_clusters],minval=0,maxval=n_samples) | ||
self._centers = jnp.zeros(()) | ||
|
||
def fit(self, x): | ||
"""Fit KMEANS. | ||
Firstly, randomly select the initial centers. Then calculate the distance between each sample and each center, | ||
and assign each sample to the nearest center. Use an `aligned_array` to indicate the samples in a cluster, | ||
where unrelated samples will be set to 0. Once all samples are assigned, the center of each cluster will | ||
be updated to the average. The average could be got by `sum(data * aligned_array) / sum(aligned_array)`. | ||
Different clusters could use broadcast for better performance. | ||
Parameters | ||
---------- | ||
x : {array-like}, shape (n_samples, n_features) | ||
Input data. | ||
Returns | ||
------- | ||
self : object | ||
Returns an instance of self. | ||
""" | ||
|
||
centers = jnp.array([x[i] for i in self.init_params]) | ||
for _ in range(self.max_iter): | ||
C = x.reshape((1, x.shape[0], x.shape[1])) - centers.reshape((centers.shape[0], 1, centers.shape[1])) | ||
C = jnp.argmin(jnp.sum(jnp.square(C), axis=2), axis=0) | ||
|
||
S = jnp.tile(C,(self.n_clusters,1)) | ||
ks = jnp.arange(self.n_clusters) | ||
aligned_array_raw = (S.T - ks).T | ||
aligned_array = jnp.equal(aligned_array_raw, 0) | ||
|
||
centers_raw = x.reshape((1, x.shape[0], x.shape[1])) * aligned_array.reshape((aligned_array.shape[0],aligned_array.shape[1],1)) | ||
equals_sum = jnp.sum(aligned_array,axis=1) | ||
centers_sum = jnp.sum(centers_raw, axis=1) | ||
centers = jnp.divide(centers_sum.T, equals_sum).T | ||
|
||
self._centers = centers | ||
return self | ||
|
||
def predict(self, x): | ||
"""Result estimates. | ||
Calculate the distance between each sample and each center, | ||
and assign each sample to the nearest center. | ||
Parameters | ||
---------- | ||
x : {array-like}, shape (n_samples, n_features) | ||
Input data for prediction. | ||
Returns | ||
------- | ||
ndarray of shape (n_samples) | ||
Returns the result of the sample for each class in the model. | ||
""" | ||
centers = self._centers | ||
y = x.reshape((1, x.shape[0], x.shape[1])) - centers.reshape((centers.shape[0], 1, centers.shape[1])) | ||
y = jnp.argmin(jnp.sum(jnp.square(y), axis=2), axis=0) | ||
return y |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
import jax.numpy as jnp | ||
|
||
# from sklearn.metrics import roc_auc_score, explained_variance_score | ||
import sml.utils.emulation as emulation | ||
|
||
from sml.kmeans.kmeans import KMEANS | ||
from sklearn.datasets import make_blobs | ||
|
||
# TODO: design the enumation framework, just like py.unittest | ||
# all emulation action should begin with `emul_` (for reflection) | ||
def emul_KMEANS(mode: emulation.Mode.MULTIPROCESS): | ||
def proc(x1, x2): | ||
x = jnp.concatenate((x1, x2), axis=1) | ||
model = KMEANS( | ||
n_clusters=2, | ||
n_samples=x.shape[0], | ||
max_iter=10 | ||
) | ||
|
||
return model.fit(x).predict(x) | ||
|
||
def load_data(): | ||
n_samples = 1000 | ||
n_features = 100 | ||
X, _ = make_blobs(n_samples=n_samples,n_features=n_features,centers=2) | ||
split_index = n_features//2 | ||
return X[:, :split_index], X[:, split_index:] | ||
|
||
try: | ||
# bandwidth and latency only work for docker mode | ||
emulator = emulation.Emulator( | ||
"examples/python/conf/3pc.json", mode, bandwidth=300, latency=20 | ||
) | ||
emulator.up() | ||
|
||
# load mock data | ||
x1, x2 = load_data() | ||
X = jnp.concatenate((x1, x2), axis=1) | ||
|
||
# mark these data to be protected in SPU | ||
x1, x2 = emulator.seal(x1, x2) | ||
result = emulator.run(proc)(x1, x2) | ||
print("result\n",result) | ||
|
||
# Compare with sklearn | ||
from sklearn.cluster import KMeans | ||
model = KMeans(n_clusters=2) | ||
print("sklearn:\n",model.fit(X).predict(X)) | ||
finally: | ||
emulator.down() | ||
|
||
|
||
if __name__ == "__main__": | ||
emul_KMEANS(emulation.Mode.MULTIPROCESS) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
import unittest | ||
import json | ||
import jax.numpy as jnp | ||
|
||
# from sklearn.metrics import roc_auc_score, explained_variance_score | ||
import spu.utils.simulation as spsim | ||
import spu.spu_pb2 as spu_pb2 # type: ignore | ||
|
||
# TODO: unify this. | ||
import examples.python.utils.dataset_utils as dsutil | ||
|
||
from sml.kmeans.kmeans import KMEANS | ||
from sklearn.datasets import make_blobs | ||
|
||
class UnitTests(unittest.TestCase): | ||
def test_kmeans(self): | ||
sim = spsim.Simulator.simple( | ||
3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64 | ||
) | ||
|
||
def proc(x1, x2): | ||
x = jnp.concatenate((x1, x2), axis=1) | ||
model = KMEANS( | ||
n_clusters=2, | ||
n_samples=x.shape[0], | ||
max_iter=10 | ||
) | ||
|
||
return model.fit(x).predict(x) | ||
|
||
def load_data(): | ||
n_samples = 1000 | ||
n_features = 100 | ||
X, _ = make_blobs(n_samples=n_samples,n_features=n_features,centers=2) | ||
split_index = n_features//2 | ||
return X[:, :split_index], X[:, split_index:] | ||
|
||
x1, x2 = load_data() | ||
X = jnp.concatenate((x1, x2), axis=1) | ||
result = spsim.sim_jax(sim, proc)(x1, x2) | ||
print("result\n",result) | ||
|
||
# Compare with sklearn | ||
from sklearn.cluster import KMeans | ||
model = KMeans(n_clusters=2) | ||
print("sklearn:\n",model.fit(X).predict(X)) | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |