add sml/kmeans #235
CLA Assistant Lite bot: All contributors have signed the CLA ✍️ ✅
I have read the CLA Document and I hereby sign the CLA
Thanks for contributing the KMeans algorithm. The algorithm looks correct to me; here are some optimization tips and recommendations. You can refer to my `# todo` notes for the details.

```python
import jax
import jax.numpy as jnp


class KMEANS:
    def __init__(self, k, data_len, max_iter=300, tol=1e-4):
        # todo: some doc about your hyper-param
        self.k = k
        self.max_iter = max_iter
        self.tol = tol
        self.init_params = jax.random.randint(jax.random.PRNGKey(1), shape=[self.k], minval=0, maxval=data_len)
        self._centers = jnp.zeros(())

    def fit(self, x):
        """Fit KMEANS.

        # todo: some doc about your implementation

        Parameters
        ----------
        x : {array-like}, shape (n_samples, n_features)
            Input data.

        Returns
        -------
        self : object
            Returns an instance of self.
        """
        centers = jnp.array([x[i] for i in self.init_params])  # k,p
        for i in range(self.max_iter):
            # print(f"iter: {i+1:3d}/{self.max_iter}\r")
            # C = []
            # for j in range(self.k):
            #     C.append(x - centers[j])
            # C = jnp.array(C)
            # todo: use broadcast
            C = x.reshape((1, x.shape[0], x.shape[1])) - centers.reshape((centers.shape[0], 1, centers.shape[1]))
            C = jnp.argmin(jnp.sum(jnp.square(C), axis=2), axis=0)
            S = jnp.tile(C, (self.k, 1))
            # todo: simplify
            # ks = jnp.array([i for i in range(self.k)])
            ks = jnp.arange(self.k)
            # todo: simplify
            # equals_raw = (S.T - ks.T).T
            equals_raw = (S.T - ks).T
            # todo: maybe more experiments comparing the two implementations?
            equals = jnp.less(jnp.square(equals_raw), self.tol)
            # equals = jnp.equal(equals_raw, 0)
            # centers_raw = []
            # for j in range(self.k):
            #     centers_raw.append(jnp.multiply(x.T, equals[j]).T)
            #     # centers_raw.append(jnp.multiply(x.T, equals[j].T).T)
            # centers_raw = jnp.array(centers_raw)
            # todo: use broadcast
            centers_raw = x.reshape((1, x.shape[0], x.shape[1])) * equals.reshape((equals.shape[0], equals.shape[1], 1))
            equals_sum = jnp.sum(equals, axis=1)
            centers_sum = jnp.sum(centers_raw, axis=1)
            # todo: simplify
            centers = jnp.divide(centers_sum.T, equals_sum).T
            # centers = jnp.divide(centers_sum.T, equals_sum.T).T
        self._centers = centers
        return self

    def predict(self, x):
        """Assign each sample to its nearest cluster center.

        Parameters
        ----------
        x : {array-like}, shape (n_samples, n_features)
            Input data for prediction.

        Returns
        -------
        ndarray of shape (n_samples,)
            Index of the nearest cluster center for each sample.
        """
        centers = self._centers
        # todo: can use broadcast, leave this for you
        y = []
        for j in range(self.k):
            y.append(x - centers[j])
        y = jnp.array(y)
        y = jnp.argmin(jnp.sum(jnp.square(y), axis=2), axis=0)
        return y
```

Besides, the param
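The remaining `todo` in `predict` (replace the Python loop with broadcasting) could follow the same pattern as the `C = x.reshape(...) - centers.reshape(...)` line in `fit`. Below is a minimal plaintext sketch in NumPy (`jax.numpy` exposes the same calls); the function names `predict_broadcast` and `predict_loop` and the toy data are hypothetical, for illustration only:

```python
import numpy as np  # jax.numpy provides the same calls used here


def predict_broadcast(x, centers):
    """Nearest-center labels without a Python loop (hypothetical helper).

    x: (n_samples, n_features); centers: (k, n_features).
    """
    # (1, n, p) - (k, 1, p) -> (k, n, p): every sample minus every center
    diff = x[None, :, :] - centers[:, None, :]
    # Squared Euclidean distances, shape (k, n)
    dist = np.sum(np.square(diff), axis=2)
    # Index of the closest center for each sample, shape (n,)
    return np.argmin(dist, axis=0)


def predict_loop(x, centers):
    # Loop-based version from the review, kept for comparison
    y = np.array([x - centers[j] for j in range(centers.shape[0])])
    return np.argmin(np.sum(np.square(y), axis=2), axis=0)


x = np.array([[0.0, 0.0], [0.1, 0.2], [5.0, 5.0], [4.9, 5.2]])
centers = np.array([[0.0, 0.0], [5.0, 5.0]])
print(predict_broadcast(x, centers))  # [0 0 1 1]
```

Like the reshape-based code in `fit`, this materializes a (k, n_samples, n_features) intermediate, so the loop is traded for memory.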
sml/kmeans/kmeans_test.py (outdated)

```python
return model.fit(x).predict(x)

DATASET_CONFIG_FILE = "examples/python/conf/ds_breast_cancer_basic.json"
```
Can we add a commonly used test dataset for KMeans, and also verify that the results meet our expectations?
Different tasks need different datasets. For a clustering task, `sklearn.datasets.make_blobs` can generate a good dummy dataset that can be used for verification.
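A plaintext sketch of this kind of verification, assuming scikit-learn is installed: it generates blobs with `make_blobs`, runs a hypothetical plaintext k-means (`plaintext_kmeans`, which uses a deterministic farthest-first initialization chosen here for reproducibility, not the PR's random initialization), and compares its labels with both the ground truth and `sklearn.cluster.KMeans`. Because cluster ids are only defined up to a permutation, the comparison uses `adjusted_rand_score` rather than raw equality. In the real test, `ours` would come from the SPU model instead.

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs
from sklearn.metrics import adjusted_rand_score


def farthest_first_init(x, k):
    """Deterministic init (an assumption of this sketch, not the PR's random init)."""
    idx = [0]
    d = np.sum((x - x[0]) ** 2, axis=1)
    for _ in range(1, k):
        i = int(np.argmax(d))  # point farthest from all chosen centers so far
        idx.append(i)
        d = np.minimum(d, np.sum((x - x[i]) ** 2, axis=1))
    return x[idx]


def plaintext_kmeans(x, k, max_iter=300):
    """Hypothetical plaintext Lloyd's k-means standing in for the model under test."""
    centers = farthest_first_init(x, k)
    for _ in range(max_iter):
        # (k, n) squared distances via broadcasting, then nearest-center labels
        labels = np.argmin(np.sum((x[None, :, :] - centers[:, None, :]) ** 2, axis=2), axis=0)
        onehot = labels[None, :] == np.arange(k)[:, None]  # (k, n) cluster membership
        counts = onehot.sum(axis=1, keepdims=True)          # (k, 1) cluster sizes
        # Mean of each cluster; keep the old center if a cluster is empty
        new_centers = np.where(counts > 0, onehot @ x / np.maximum(counts, 1), centers)
        if np.allclose(new_centers, centers):
            break
        centers = new_centers
    return labels


# Same data shape as the emulation snippet in this PR: 1000 samples, 100 features, 2 blobs
X, y_true = make_blobs(n_samples=1000, n_features=100, centers=2, random_state=0)
ours = plaintext_kmeans(X, k=2)
ref = KMeans(n_clusters=2, n_init=10, random_state=0).fit_predict(X)

print("ARI vs. ground truth:", adjusted_rand_score(y_true, ours))
print("ARI vs. sklearn:", adjusted_rand_score(ref, ours))
```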
I have made revisions based on the suggestions. I also compare against eps (rather than testing strict equality) to avoid the impact of small numerical errors that may occur in some MPC schemes; I think a strict equality comparison may give wrong results in some cases.
Thanks for your revisions; the algorithm part looks good to me. I also highly recommend comparing your output with the plaintext sklearn output in both the emul and test files, as #240 does. (You can use
Finally, you say "The comparison with eps is also used to avoid the impact of possible small errors in some MPC schemes", but I don't think this is a problem at all. You can find the dtype of the matrix
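The reviewer's point about the dtype can be illustrated in plain NumPy (`jax.numpy` behaves the same for these calls). Assuming, as the reviewer appears to suggest, that the labels `C` come straight out of `argmin` and therefore have an integer dtype, the one-hot mask can use exact equality instead of the `tol`-based comparison. The sketch checks that both masks agree on a made-up example and that the same mask gives the center update as a masked mean. It does not model SPU's secret-shared encoding; how a given protocol treats integer comparisons is an assumption to confirm against SPU itself.

```python
import numpy as np  # jax.numpy provides the same calls

# Made-up (k=2, n=3) squared-distance matrix and the matching 2-D points
dist = np.array([[0.1, 3.0, 2.5],
                 [1.0, 0.2, 0.4]])
x = np.array([[0.0, 0.0], [4.0, 4.0], [6.0, 6.0]])
k = dist.shape[0]

labels = np.argmin(dist, axis=0)  # integer cluster ids
print(labels.dtype.kind)          # i

ks = np.arange(k)
# Exact comparison: labels and ks are both integer arrays
exact = labels[None, :] == ks[:, None]                      # (k, n) bool
# Tolerance-based comparison from the PR, for reference
approx = np.square(labels[None, :] - ks[:, None]) < 1e-4
assert (exact == approx).all()

# The same mask yields the new centers as a masked mean, shape (k, p)
counts = exact.sum(axis=1)
centers = (exact @ x) / counts[:, None]
print(centers.tolist())  # [[0.0, 0.0], [5.0, 5.0]]
```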
Thanks for your suggestions; I added
Thanks for the elaborate experiments! Everything else looks good to me. But I am confused about how you ran your experiments: in my understanding, cheetah is much slower than aby3, yet the results look similar.
sml/kmeans/kmeans_emul.py (outdated)

```python
emulator.up()
n_samples = 1000
X, _ = make_blobs(n_samples=n_samples, n_features=100, centers=2)
x1, x2 = X[:n_samples//2], X[n_samples//2:]
```
The data needs to be loaded onto the PYU first; otherwise x1 and x2 will be treated as plaintext.
I suggest referring to `emulator.prepare_dataset` and wrapping an SPU version of `make_blobs`.
Thanks!
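Hi, I recently found a problem in the earlier emul code: before running an emul program, the data needs to be loaded onto the PYU first, but the emulation did not previously provide a corresponding API. So if custom data was used, the emul part actually ran entirely in plaintext. For details, see: (it mainly comes down to calling
I expect that after this change there will be a clear performance difference between cheetah and aby3. Thanks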
I modified the code, and here are the new results. The two methods still look close, and aby3 is much faster than cheetah, using
sml/kmeans/kmeans_emul.py (outdated)

```python
)
emulator.up()

(x1, x2), _ = emulator.prepare_dataset("examples/python/conf/ds_mock_clustering_basic.json")
```
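Our latest commit removed this API and instead exposes the `seal` interface, so that developers can use custom datasets more easily.
Could you please update your code to the latest API? Thanks 🙏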
sml/kmeans/kmeans_test.py (outdated)

```python
return model.fit(x).predict(x)

DATASET_CONFIG_FILE = "examples/python/conf/ds_mock_clustering_basic.json"
```
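When I run it locally, I get an error saying this file cannot be found. It's probably because the BUILD file in the conf folder doesn't include the newly added JSON file.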
38e60eb to 7f7d7a3
add sml/kmeans