class KMEANS:
    """K-means clustering expressed with JAX array operations.

    Parameters
    ----------
    n_clusters : int
        The number of clusters to form as well as the number of
        centroids to generate.

    n_samples : int
        The number of samples. Used as the exclusive upper bound when the
        row indices of the initial centers are drawn.

    max_iter : int, default=300
        Number of assignment/update iterations performed by ``fit``.

    tol : float, default=1e-4
        Acceptable error to consider the two to be equal. The value is
        stored on the instance but ``fit`` does not read it: the loop
        always runs exactly ``max_iter`` iterations.
    """

    def __init__(self, n_clusters, n_samples, max_iter=300, tol=1e-4):
        self.n_clusters = n_clusters
        self.max_iter = max_iter
        self.tol = tol
        # Row indices of the initial centers. The PRNG key is fixed, so the
        # same indices are produced for the same (n_clusters, n_samples).
        # NOTE: randint samples with replacement, so indices may repeat;
        # fit() tolerates the empty clusters this can cause.
        self.init_params = jax.random.randint(
            jax.random.PRNGKey(1),
            shape=[self.n_clusters],
            minval=0,
            maxval=n_samples,
        )
        # Placeholder until fit() stores the real (n_clusters, n_features)
        # array; predict() must not be called before fit().
        self._centers = jnp.zeros(())

    @staticmethod
    def _nearest_center(x, centers):
        """Return, for every row of ``x``, the index of the closest center.

        ``x`` is (n_samples, n_features) and ``centers`` is
        (n_clusters, n_features); closeness is squared Euclidean distance.
        """
        # Broadcast to (n_clusters, n_samples, n_features).
        diff = x[None, :, :] - centers[:, None, :]
        return jnp.argmin(jnp.sum(jnp.square(diff), axis=2), axis=0)

    def fit(self, x):
        """Fit KMEANS.

        The initial centers are the rows of ``x`` selected by
        ``init_params``. Each iteration assigns every sample to its nearest
        center, builds a boolean membership mask of shape
        (n_clusters, n_samples), and replaces each center by the mean of its
        members, i.e. ``sum(data * mask) / sum(mask)``, computed for all
        clusters at once.

        A cluster that receives no samples keeps its previous center instead
        of being divided by zero.

        Parameters
        ----------
        x : {array-like}, shape (n_samples, n_features)
            Input data.

        Returns
        -------
        self : object
            Returns an instance of self.
        """
        x = jnp.asarray(x)
        centers = jnp.array([x[i] for i in self.init_params])
        cluster_ids = jnp.arange(self.n_clusters)

        for _ in range(self.max_iter):
            labels = self._nearest_center(x, centers)

            # membership[k, i] is True when sample i belongs to cluster k.
            membership = jnp.equal(labels[None, :], cluster_ids[:, None])
            counts = jnp.sum(membership, axis=1)

            # (n_clusters, n_samples) @ (n_samples, n_features) gives the
            # per-cluster feature sums without materialising a 3-D product.
            sums = jnp.matmul(membership.astype(x.dtype), x)

            # Select instead of branching so the update stays traceable:
            # empty clusters divide by 1 and then fall back to the old center.
            is_empty = jnp.equal(counts, 0)
            safe_counts = jnp.where(is_empty, 1, counts)
            centers = jnp.where(
                is_empty[:, None], centers, sums / safe_counts[:, None]
            )

        self._centers = centers
        return self

    def predict(self, x):
        """Result estimates.

        Calculate the distance between each sample and each fitted center,
        and assign each sample to the nearest center. ``fit`` must have been
        called first.

        Parameters
        ----------
        x : {array-like}, shape (n_samples, n_features)
            Input data for prediction.

        Returns
        -------
        ndarray of shape (n_samples)
            Index of the nearest center for each sample.
        """
        return self._nearest_center(jnp.asarray(x), self._centers)
class UnitTests(unittest.TestCase):
    """Simulation test for the KMEANS estimator."""

    def test_kmeans(self):
        """Fit KMEANS in a simulated 3-party ABY3 runtime and check it against sklearn.

        Cluster ids are arbitrary, so both label vectors are first rewritten
        as "is in the same cluster as sample 0" before being compared.
        """
        sim = spsim.Simulator.simple(
            3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64
        )

        def proc(x1, x2):
            # The two vertical feature slices are joined back into one matrix.
            x = jnp.concatenate((x1, x2), axis=1)
            model = KMEANS(n_clusters=2, n_samples=x.shape[0], max_iter=10)
            return model.fit(x).predict(x)

        def load_data():
            n_samples = 1000
            n_features = 100
            # Fixed seed so the test data is the same on every run.
            X, _ = make_blobs(
                n_samples=n_samples,
                n_features=n_features,
                centers=2,
                random_state=0,
            )
            split_index = n_features // 2
            return X[:, :split_index], X[:, split_index:]

        x1, x2 = load_data()
        X = jnp.concatenate((x1, x2), axis=1)
        result = jnp.asarray(spsim.sim_jax(sim, proc)(x1, x2))

        # Compare with sklearn
        from sklearn.cluster import KMeans

        reference = KMeans(n_clusters=2, n_init=10, random_state=0)
        expected = jnp.asarray(reference.fit(X).predict(X))

        self.assertEqual(result.shape, expected.shape)
        self.assertTrue(
            bool(
                jnp.array_equal(
                    jnp.equal(result, result[0]),
                    jnp.equal(expected, expected[0]),
                )
            ),
            "KMEANS partition differs from the sklearn reference",
        )