
add sml/kmeans #235

Closed · wants to merge 0 commits into from

Conversation

@oahcnauygnid (Contributor)

add sml/kmeans

@github-actions bot commented Jul 10, 2023

CLA Assistant Lite bot: All contributors have signed the CLA ✍️ ✅

@oahcnauygnid (Contributor, Author)

I have read the CLA Document and I hereby sign the CLA

@deadlywing (Contributor)

Thanks for contributing the KMeans algorithm. The correctness of this algo looks OK to me; here are some optimization tips and recommendations:

  1. You can use jax.numpy broadcasting to avoid the for-loops in the fit and predict methods.
  2. You can avoid some redundant .T ops in the fit method.
  3. I see you use a float-comparison trick (< eps) rather than the equal op directly. That may be a good idea for the ABY3 protocol, because mul is cheap there, but under CHEETAH the total cost of your implementation may be larger. (You could run some more experiments under more configs, e.g. CHEETAH, FM128, etc.)
  4. Maybe you can add some docs and comments about the details of your implementation?

You can refer to my todo notes in the code below for these recommendations.

import jax
import jax.numpy as jnp

class KMEANS:
    def __init__(self, k, data_len, max_iter=300, tol=1e-4):
        # todo: some doc about your hyper-params
        self.k = k
        self.max_iter = max_iter
        self.tol = tol
        self.init_params = jax.random.randint(
            jax.random.PRNGKey(1), shape=[self.k], minval=0, maxval=data_len
        )
        self._centers = jnp.zeros(())

    def fit(self, x):
        """Fit KMEANS.
        # todo: some doc about your implementation

        Parameters
        ----------
        x : {array-like}, shape (n_samples, n_features)
            Input data.

        Returns
        -------
        self : object
            Returns an instance of self.
        """

        centers = jnp.array([x[i] for i in self.init_params])  # (k, p)
        for i in range(self.max_iter):
            # C = []
            # for j in range(self.k):
            #     C.append(x - centers[j])
            # C = jnp.array(C)
            # todo: use broadcast
            C = x.reshape((1, x.shape[0], x.shape[1])) - centers.reshape(
                (centers.shape[0], 1, centers.shape[1])
            )

            C = jnp.argmin(jnp.sum(jnp.square(C), axis=2), axis=0)

            S = jnp.tile(C, (self.k, 1))

            # todo: simple simplify
            # ks = jnp.array([i for i in range(self.k)])
            ks = jnp.arange(self.k)

            # todo: simple simplify
            # equals_raw = (S.T - ks.T).T
            equals_raw = (S.T - ks).T

            # todo: maybe more experiments about the two implementations?
            equals = jnp.less(jnp.square(equals_raw), self.tol)
            # equals = jnp.equal(equals_raw, 0)

            # centers_raw = []
            # for j in range(self.k):
            #     centers_raw.append(jnp.multiply(x.T, equals[j]).T)
            # centers_raw = jnp.array(centers_raw)
            # todo: use broadcast
            centers_raw = x.reshape((1, x.shape[0], x.shape[1])) * equals.reshape(
                (equals.shape[0], equals.shape[1], 1)
            )

            equals_sum = jnp.sum(equals, axis=1)
            centers_sum = jnp.sum(centers_raw, axis=1)
            # todo: simple simplify
            centers = jnp.divide(centers_sum.T, equals_sum).T
            # centers = jnp.divide(centers_sum.T, equals_sum.T).T

        self._centers = centers
        return self

    def predict(self, x):
        """Result estimates.

        Parameters
        ----------
        x : {array-like}, shape (n_samples, n_features)
            Input data for prediction.

        Returns
        -------
        ndarray of shape (n_samples,)
            Cluster index assigned to each sample.
        """
        centers = self._centers
        # todo: can use broadcast, leave this for you
        y = []
        for j in range(self.k):
            y.append(x - centers[j])
        y = jnp.array(y)
        y = jnp.argmin(jnp.sum(jnp.square(y), axis=2), axis=0)
        return y

@deadlywing (Contributor)

Besides, the param k could be renamed to n_clusters to align with the parameter name in sklearn.
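A minimal sketch of that rename against the constructor above (the rest of the class stays unchanged, with self.k replaced by self.n_clusters throughout):

import jax
import jax.numpy as jnp

class KMEANS:
    def __init__(self, n_clusters, data_len, max_iter=300, tol=1e-4):
        # n_clusters replaces k, matching sklearn.cluster.KMeans' parameter name.
        self.n_clusters = n_clusters
        self.max_iter = max_iter
        self.tol = tol
        self.init_params = jax.random.randint(
            jax.random.PRNGKey(1), shape=[self.n_clusters], minval=0, maxval=data_len
        )
        self._centers = jnp.zeros(())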

@anakinxc anakinxc requested a review from deadlywing July 10, 2023 19:46

return model.fit(x).predict(x)

DATASET_CONFIG_FILE = "examples/python/conf/ds_breast_cancer_basic.json"


Can we add some commonly used test datasets for KMeans, and also verify that the results meet our expectations?

Contributor

Different tasks need different datasets. For clustering tasks, sklearn.datasets.make_blobs can generate good dummy datasets that can be used for verification.
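For instance, a minimal verification sketch along these lines (run in plaintext here; KMEANS is the class sketched earlier in this thread, and the 0.99 threshold is an illustrative choice):

import jax.numpy as jnp
from sklearn.cluster import KMeans as SkKMeans
from sklearn.datasets import make_blobs
from sklearn.metrics import rand_score

# Well-separated blobs make the expected clustering unambiguous.
X, _ = make_blobs(n_samples=1000, n_features=2, centers=2, cluster_std=0.5, random_state=0)

# Plaintext reference labels from sklearn.
sk_labels = SkKMeans(n_clusters=2, n_init=10, random_state=0).fit_predict(X)

# Labels from the KMEANS class above (plaintext run; swap in an SPU simulator for tests).
model = KMEANS(k=2, data_len=X.shape[0])
labels = model.fit(jnp.array(X)).predict(jnp.array(X))

# Cluster IDs are arbitrary, so compare partitions rather than raw label values.
assert rand_score(sk_labels, labels) > 0.99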

@oahcnauygnid (Contributor, Author)

I made revisions based on the suggestions. The comparison with eps is also meant to avoid the impact of possible small errors in some MPC schemes; I think a strict equality comparison may go wrong in some cases.

@deadlywing (Contributor)

Thanks for your revisions, the algorithm part looks good to me.
However, there are still some minor errors: in the emul and test files, your params were not updated to the newest version. Please make sure your program runs correctly before committing.

Besides, I highly recommend that you compare your output with the plaintext sklearn output, just like #240 does, in both the emul and test files. (You can use sklearn.datasets.make_blobs to generate very distinct datasets.)

Finally, you say "the comparison with eps is also used to avoid the impact of possible small errors in some MPC schemes", but I think that is not a problem here at all: the dtypes of the matrix S and of ks are both int, so SPU can compute the difference losslessly. The integer equal will call a2b, which costs a lot, whereas for a fixed-point less, msb is sufficient (at the price of one more mul). So under the ABY3 protocol the comparison with eps should do better, but if you run under CHEETAH I am not sure which way wins (mul is much more expensive under CHEETAH; maybe you can run some more experiments on this).
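For a quick check, something like this (a sketch using the Simulator and sim_jax helpers from spu.utils.simulation; enum spellings vary across SPU versions, and the shapes are illustrative):

import time

import jax.numpy as jnp
import spu
import spu.utils.simulation as spsim

def assign_equal(s, ks):
    # Exact integer equality; lowers to an expensive a2b conversion under SPU.
    return jnp.equal((s.T - ks).T, 0)

def assign_less_eps(s, ks):
    # Square-then-compare trick; one extra mul plus an msb-based less-than.
    return jnp.less(jnp.square((s.T - ks).T), 1e-4)

sim = spsim.Simulator.simple(3, spu.ProtocolKind.ABY3, spu.FieldType.FM64)
s = jnp.zeros((4, 10000), dtype=jnp.int32)  # stand-in for jnp.tile(C, (k, 1))
ks = jnp.arange(4)

for fn in (assign_equal, assign_less_eps):
    start = time.time()
    spsim.sim_jax(sim, fn)(s, ks)
    print(fn.__name__, time.time() - start)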

@oahcnauygnid (Contributor, Author) commented Jul 23, 2023

Thanks for your suggestions; I added sklearn.datasets.make_blobs.
I ran experiments on using equal vs. comparison with eps. The results (in seconds) are shown in the table below.
The two methods are similar for smaller sample sets. When the data volume is large, the advantage of comparison with eps under ABY3 appears smaller than the advantage of equal under CHEETAH, so maybe using equal is better.

n_samples   ABY3, equal   ABY3, compare with eps   CHEETAH, equal   CHEETAH, compare with eps
100         0.12          0.12                     0.12             0.12
1000        0.44          0.42                     0.44             0.41
10000       2.52          2.45                     2.54             2.47
100000      27.81         27.93                    26.01            28.05

@deadlywing (Contributor)

Thanks for the elaborate experiments! Everything else looks good to me!

But I am confused about how you ran your experiments. (To my knowledge, CHEETAH is much, much slower than ABY3, yet the timings here look similar.)

emulator.up()
n_samples = 1000
X, _ = make_blobs(n_samples=n_samples, n_features=100, centers=2)
x1, x2 = X[:n_samples // 2], X[n_samples // 2:]
Contributor

You need to load the data to the PYUs first; otherwise x1 and x2 will be treated as plaintext.

Contributor

I suggest you wrap an SPU version of make_blobs, using emulator.prepare_dataset as a reference.

Thanks!

@deadlywing (Contributor)

Hi, I recently found a problem in the earlier emul code: before running an emul program, the data needs to be loaded to the PYUs first, but emulation did not provide a corresponding API for this before. So if custom data was used, the emul part actually ran in plaintext. For details see the file below (the key point is simply calling emulator.seal):
https://github.com/secretflow/spu/blob/main/sml/linear_model/simple_sgd_emul.py

I expect that after this change, CHEETAH and ABY3 will show a clearly visible performance difference.
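Roughly like this (a minimal sketch following the linked simple_sgd_emul.py; proc here stands for your KMeans procedure, and emulator is the one already set up in your emul file):

from sklearn.datasets import make_blobs

X, _ = make_blobs(n_samples=1000, n_features=100, centers=2)
x1, x2 = X[:500], X[500:]

# Seal the plaintext arrays: load them onto PYUs and convert them to SPU shares.
# Without this step the emulated computation actually runs on plaintext.
x1, x2 = emulator.seal(x1, x2)
result = emulator.run(proc)(x1, x2)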

Thanks

@oahcnauygnid (Contributor, Author)

I modified the code; here are the new results. The two methods still look close, and ABY3 is much faster than CHEETAH, so using equal is fine.

n_samples   ABY3, equal   ABY3, compare with eps   CHEETAH, equal   CHEETAH, compare with eps
100         1.10          1.10                     11.26            11.30
200         1.60          1.48                     15.34            15.81
1000        2.02          2.01                     56.49            56.65
2000        2.94          2.93                     87.91            88.89

)
emulator.up()

(x1, x2), _ = emulator.prepare_dataset("examples/python/conf/ds_mock_clustering_basic.json")
Contributor

Our latest commit removed this API in favor of exposing the seal interface, which makes it easier for future developers to use custom datasets.

Please update your code against the latest API, thanks 🙏
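The migration is roughly the following (a sketch; the make_blobs parameters are carried over from your earlier snippet):

from sklearn.datasets import make_blobs

# Before (removed API):
# (x1, x2), _ = emulator.prepare_dataset("examples/python/conf/ds_mock_clustering_basic.json")

# After: generate the data locally, then seal it into SPU shares.
X, _ = make_blobs(n_samples=1000, n_features=100, centers=2)
x1, x2 = emulator.seal(X[:500], X[500:])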


return model.fit(x).predict(x)

DATASET_CONFIG_FILE = "examples/python/conf/ds_mock_clustering_basic.json"
Contributor

When I run this locally, I get a file-not-found error for this file; most likely the BUILD file in your conf folder does not include the newly added json file.
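Hypothetically, the fix looks something like this in the conf folder's BUILD file (the target name and the existing entries here are illustrative, not the repo's actual contents):

filegroup(
    name = "conf",
    srcs = [
        "ds_breast_cancer_basic.json",
        "ds_mock_clustering_basic.json",  # register the newly added dataset config
    ],
)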

github-actions bot locked and limited conversation to collaborators Jul 31, 2023