Skip to content

Commit

Permalink
add sml/kmeans (#277)
Browse files Browse the repository at this point in the history
按照最新的版本增加了kmeans (pr #235)
  • Loading branch information
oahcnauygnid authored Aug 1, 2023
1 parent 7f7d7a3 commit aa326c8
Show file tree
Hide file tree
Showing 4 changed files with 227 additions and 0 deletions.
37 changes: 37 additions & 0 deletions sml/kmeans/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
load("@rules_python//python:defs.bzl", "py_binary", "py_library", "py_test")

package(default_visibility = ["//visibility:public"])

# Library target that exposes the k-means implementation in kmeans.py.
py_library(
    name = "kmeans",
    srcs = ["kmeans.py"],
    deps = [
        # NOTE(review): kmeans.py as committed imports only jax; confirm this
        # dependency is still needed.
        "//sml/utils:fxp_approx",
    ],
)

# Runnable emulation driver (kmeans_emul.py) for the k-means library.
py_binary(
    name = "kmeans_emul",
    srcs = ["kmeans_emul.py"],
    deps = [
        ":kmeans",
        "//examples/python/utils:dataset_utils", # FIXME: remove examples dependency
        "//sml/utils:emulation"
    ],
)



# Unit test target (kmeans_test.py); runs the k-means library under the SPU
# simulator dependencies listed below.
py_test(
    name = "kmeans_test",
    srcs = ["kmeans_test.py"],
    # Runtime data files made available to the test.
    data = [
        "//examples/python/conf", # FIXME: remove examples dependency
    ],
    deps = [
        ":kmeans",
        "//examples/python/utils:dataset_utils", # FIXME: remove examples dependency
        "//spu:init",
        "//spu/utils:simulation",
    ],
)
86 changes: 86 additions & 0 deletions sml/kmeans/kmeans.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
import jax
import jax.numpy as jnp

class KMEANS:
    """K-means clustering expressed with JAX array operations.

    Parameters
    ----------
    n_clusters : int
        The number of clusters to form as well as the number of
        centroids to generate.
    n_samples : int
        The number of samples. Used as the exclusive upper bound when the
        row indices of the initial centers are drawn.
    max_iter : int, default=300
        Number of center-update iterations performed by `fit`.
    tol : float, default=1e-4
        Acceptable error to consider the two to be equal. The value is
        stored on the instance, but `fit` does not read it: the loop always
        runs exactly `max_iter` iterations.

    Notes
    -----
    The initial indices are drawn independently with a fixed PRNG key, so
    two of them may coincide. A cluster that ends up without any sample
    keeps its previous center instead of being divided by a zero count.
    """

    def __init__(self, n_clusters, n_samples, max_iter=300, tol=1e-4):
        self.n_clusters = n_clusters
        self.max_iter = max_iter
        self.tol = tol
        # Row indices (into the data passed to `fit`) of the initial centers.
        self.init_params = jax.random.randint(
            jax.random.PRNGKey(1),
            shape=[self.n_clusters],
            minval=0,
            maxval=n_samples,
        )
        # 0-d placeholder; replaced by the (n_clusters, n_features) centers
        # once `fit` has run.
        self._centers = jnp.zeros(())

    @staticmethod
    def _nearest_center(x, centers):
        """Return, for each row of `x`, the index of the closest center."""
        diff = x.reshape((1, x.shape[0], x.shape[1])) - centers.reshape(
            (centers.shape[0], 1, centers.shape[1])
        )
        # Squared Euclidean distance is enough to rank the centers.
        return jnp.argmin(jnp.sum(jnp.square(diff), axis=2), axis=0)

    def fit(self, x):
        """Fit KMEANS.

        The initial centers are the rows of `x` selected by `init_params`.
        Each iteration assigns every sample to its nearest center, builds a
        boolean membership mask of shape (n_clusters, n_samples), and moves
        each center to `sum(data * mask) / sum(mask)`. All clusters are
        updated at once through broadcasting.

        Parameters
        ----------
        x : {array-like}, shape (n_samples, n_features)
            Input data.

        Returns
        -------
        self : object
            Returns an instance of self.
        """
        centers = jnp.array([x[i] for i in self.init_params])

        # Loop-invariant views, built once.
        samples = x.reshape((1, x.shape[0], x.shape[1]))
        cluster_ids = jnp.arange(self.n_clusters).reshape((self.n_clusters, 1))

        for _ in range(self.max_iter):
            labels = self._nearest_center(x, centers)

            # membership[k, i] is True when sample i is assigned to cluster k.
            membership = jnp.equal(
                labels.reshape((1, labels.shape[0])), cluster_ids
            )
            counts = jnp.sum(membership, axis=1)
            sums = jnp.sum(
                samples
                * membership.reshape(
                    (membership.shape[0], membership.shape[1], 1)
                ),
                axis=1,
            )

            # An empty cluster would give 0 / 0; divide by 1 instead and
            # keep the previous center for that cluster.
            non_empty = (counts > 0).reshape((self.n_clusters, 1))
            safe_counts = jnp.where(counts > 0, counts, 1).reshape(
                (self.n_clusters, 1)
            )
            centers = jnp.where(non_empty, sums / safe_counts, centers)

        self._centers = centers
        return self

    def predict(self, x):
        """Result estimates.

        Calculate the distance between each sample and each fitted center,
        and assign each sample to the nearest center. `fit` must have been
        called first; until then `_centers` is only a 0-d placeholder.

        Parameters
        ----------
        x : {array-like}, shape (n_samples, n_features)
            Input data for prediction.

        Returns
        -------
        ndarray of shape (n_samples)
            Index of the nearest center for each sample.
        """
        return self._nearest_center(x, self._centers)
54 changes: 54 additions & 0 deletions sml/kmeans/kmeans_emul.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import jax.numpy as jnp

# from sklearn.metrics import roc_auc_score, explained_variance_score
import sml.utils.emulation as emulation

from sml.kmeans.kmeans import KMEANS
from sklearn.datasets import make_blobs

# TODO: design the emulation framework, just like py.unittest
# all emulation action should begin with `emul_` (for reflection)
def emul_KMEANS(mode: emulation.Mode = emulation.Mode.MULTIPROCESS):
    """Run KMEANS on mock data through the emulator and print the labels.

    The same data is also clustered with scikit-learn's `KMeans`, and both
    label vectors are printed for a manual comparison.

    Parameters
    ----------
    mode : emulation.Mode, default=emulation.Mode.MULTIPROCESS
        Emulation mode handed to `emulation.Emulator`.
    """

    def proc(x1, x2):
        # Join the two feature slices, then fit and predict on the result.
        x = jnp.concatenate((x1, x2), axis=1)
        model = KMEANS(
            n_clusters=2,
            n_samples=x.shape[0],
            max_iter=10,
        )

        return model.fit(x).predict(x)

    def load_data():
        # Two blobs, split column-wise into two halves of the features.
        n_samples = 1000
        n_features = 100
        X, _ = make_blobs(n_samples=n_samples, n_features=n_features, centers=2)
        split_index = n_features // 2
        return X[:, :split_index], X[:, split_index:]

    # Build the emulator before entering `try`: if construction fails there
    # is nothing to shut down, and the `finally` clause must not hide the
    # real error behind a NameError on an unbound `emulator`.
    # bandwidth and latency only work for docker mode
    emulator = emulation.Emulator(
        "examples/python/conf/3pc.json", mode, bandwidth=300, latency=20
    )
    try:
        emulator.up()

        # load mock data
        x1, x2 = load_data()
        X = jnp.concatenate((x1, x2), axis=1)

        # mark these data to be protected in SPU
        x1, x2 = emulator.seal(x1, x2)
        result = emulator.run(proc)(x1, x2)
        print("result\n", result)

        # Compare with sklearn
        from sklearn.cluster import KMeans

        model = KMeans(n_clusters=2)
        print("sklearn:\n", model.fit(X).predict(X))
    finally:
        emulator.down()


# Script entry point: run the emulation in multi-process mode.
if __name__ == "__main__":
    emul_KMEANS(emulation.Mode.MULTIPROCESS)
50 changes: 50 additions & 0 deletions sml/kmeans/kmeans_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import unittest
import json
import jax.numpy as jnp

# from sklearn.metrics import roc_auc_score, explained_variance_score
import spu.utils.simulation as spsim
import spu.spu_pb2 as spu_pb2 # type: ignore

# TODO: unify this.
import examples.python.utils.dataset_utils as dsutil

from sml.kmeans.kmeans import KMEANS
from sklearn.datasets import make_blobs

class UnitTests(unittest.TestCase):
    """Simulator-based test for the KMEANS model."""

    def test_kmeans(self):
        """Cluster two blobs under a 3-party ABY3 simulation and check the labels.

        The labels are compared with scikit-learn's `KMeans` on the same
        data. Cluster numbering is arbitrary, so with two clusters either
        the labels or their complement must agree with the reference.
        """
        sim = spsim.Simulator.simple(
            3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64
        )

        def proc(x1, x2):
            # Join the two feature slices, then fit and predict on the result.
            x = jnp.concatenate((x1, x2), axis=1)
            model = KMEANS(
                n_clusters=2,
                n_samples=x.shape[0],
                max_iter=10,
            )

            return model.fit(x).predict(x)

        def load_data():
            n_samples = 1000
            n_features = 100
            # A fixed seed keeps the generated data, and so the test outcome,
            # reproducible between runs.
            X, _ = make_blobs(
                n_samples=n_samples,
                n_features=n_features,
                centers=2,
                random_state=0,
            )
            split_index = n_features // 2
            return X[:, :split_index], X[:, split_index:]

        x1, x2 = load_data()
        X = jnp.concatenate((x1, x2), axis=1)
        result = spsim.sim_jax(sim, proc)(x1, x2)
        print("result\n", result)

        # Compare with sklearn
        from sklearn.cluster import KMeans

        model = KMeans(n_clusters=2)
        expected = model.fit(X).predict(X)
        print("sklearn:\n", expected)

        # One label per sample.
        self.assertEqual(jnp.shape(result), (X.shape[0],))

        # Agreement up to swapping the two cluster ids.
        agreement = float(
            jnp.mean(jnp.asarray(result) == jnp.asarray(expected))
        )
        self.assertGreaterEqual(max(agreement, 1.0 - agreement), 0.99)


# Run the unittest runner when this file is executed directly.
if __name__ == "__main__":
    unittest.main()

0 comments on commit aa326c8

Please sign in to comment.