From 8183f8138ccac89e24c553469bfad78032a9dd5a Mon Sep 17 00:00:00 2001 From: tarantula-leo <54618933+tarantula-leo@users.noreply.github.com> Date: Wed, 19 Jul 2023 20:11:35 +0800 Subject: [PATCH] [sml] add LR in jax (#246) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit [Solve] 使用 SPU 实现LR基础功能: 1. 完成了LR的基础功能(L2 正则、SGD); 2. 根据sklearn预留各类接口供后续开发使用; 3. 通过了在 spsim 和 emulator 的 unittest。 --- sml/lr/BUILD.bazel | 49 +++++++ sml/lr/simple_lr.py | 296 +++++++++++++++++++++++++++++++++++++++ sml/lr/simple_lr_emul.py | 68 +++++++++ sml/lr/simple_lr_test.py | 66 +++++++++ 4 files changed, 479 insertions(+) create mode 100644 sml/lr/BUILD.bazel create mode 100644 sml/lr/simple_lr.py create mode 100644 sml/lr/simple_lr_emul.py create mode 100644 sml/lr/simple_lr_test.py diff --git a/sml/lr/BUILD.bazel b/sml/lr/BUILD.bazel new file mode 100644 index 00000000..0d290afb --- /dev/null +++ b/sml/lr/BUILD.bazel @@ -0,0 +1,49 @@ +# Copyright 2023 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
# Copyright 2023 Ant Group Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import jax.numpy as jnp
from enum import Enum


class SigType(Enum):
    """User-facing names of the supported sigmoid approximations."""

    REAL = 'real'
    T1 = 't1'
    T3 = 't3'
    T5 = 't5'
    DF = 'df'
    SR = 'sr'
    # DO NOT use this except in hessian case.
    MIX = 'mix'


def t1_sig(x, limit: bool = True):
    '''
    First-order Taylor approximation of sigmoid: 0.5 + 0.25 * x.

    taylor series referenced from:
    https://mortendahl.github.io/2017/04/17/private-deep-learning-with-mpc/

    If `limit` is True the output is clamped into [0, 1].
    '''
    T0 = 1.0 / 2
    T1 = 1.0 / 4
    ret = T0 + x * T1
    if limit:
        return jnp.select([ret < 0, ret > 1], [0, 1], ret)
    return ret


def t3_sig(x, limit: bool = True):
    '''
    Third-order Taylor approximation of sigmoid.

    taylor series referenced from:
    https://mortendahl.github.io/2017/04/17/private-deep-learning-with-mpc/

    NOTE: unlike the other approximations, clamping is done on the *input*
    range [-2, 2]; the cubic term makes the polynomial non-monotonic
    outside that interval, so clamping on the output would be wrong.
    '''
    T3 = -1.0 / 48
    ret = t1_sig(x, False) + jnp.power(x, 3) * T3
    if limit:
        return jnp.select([x < -2, x > 2], [0, 1], ret)
    return ret


def t5_sig(x, limit: bool = True):
    '''
    Fifth-order Taylor approximation of sigmoid.

    taylor series referenced from:
    https://mortendahl.github.io/2017/04/17/private-deep-learning-with-mpc/

    If `limit` is True the output is clamped into [0, 1].
    '''
    T5 = 1.0 / 480
    ret = t3_sig(x, False) + jnp.power(x, 5) * T5
    if limit:
        return jnp.select([ret < 0, ret > 1], [0, 1], ret)
    return ret


def seg3_sig(x):
    '''
    Three-segment piecewise-linear approximation:

    f(x) = 0.5 + 0.125x  if -4 <= x <= 4
           1             if x > 4
           0             if x < -4
    '''
    return jnp.select([x < -4, x > 4], [0, 1], 0.5 + x * 0.125)


def df_sig(x):
    '''
    https://dergipark.org.tr/en/download/article-file/54559
    Dataflow implementation of sigmoid function:
    F(x) = 0.5 * ( x / ( 1 + |x| ) ) + 0.5
    df_sig has higher precision than sr_sig if x in [-2, 2]
    '''
    return 0.5 * (x / (1 + jnp.abs(x))) + 0.5


def sr_sig(x):
    '''
    https://en.wikipedia.org/wiki/Sigmoid_function#Examples
    Square Root approximation functions:
    F(x) = 0.5 * ( x / ( 1 + x^2 )^0.5 ) + 0.5
    sr_sig almost perfect fit to sigmoid if x out of range [-3,3]
    '''
    return 0.5 * (x / jnp.sqrt(1 + jnp.square(x))) + 0.5


def ls7_sig(x):
    '''Degree-7 least-squares polynomial fitting of sigmoid.'''
    return (
        5.00052959e-01
        + 2.35176260e-01 * x
        - 3.97212202e-05 * jnp.power(x, 2)
        - 1.23407424e-02 * jnp.power(x, 3)
        + 4.04588962e-06 * jnp.power(x, 4)
        + 3.94330487e-04 * jnp.power(x, 5)
        - 9.74060972e-08 * jnp.power(x, 6)
        - 4.74674505e-06 * jnp.power(x, 7)
    )


def mix_sig(x):
    '''
    mix ls7 & sr sig, use ls7 if |x| < 4 , else use sr.
    has higher precision in all input range.
    NOTICE: this method is very expensive, only use for hessian matrix.
    '''
    ls7 = ls7_sig(x)
    sr = sr_sig(x)
    return jnp.select([x < -4, x > 4], [sr, sr], ls7)


def real_sig(x):
    '''The exact sigmoid: 1 / (1 + e^-x).'''
    return 1 / (1 + jnp.exp(-x))


# Dispatch table mapping each SigType member to its implementation.
_SIG_FN = {
    SigType.REAL: real_sig,
    SigType.T1: t1_sig,
    SigType.T3: t3_sig,
    SigType.T5: t5_sig,
    SigType.DF: df_sig,
    SigType.SR: sr_sig,
    SigType.MIX: mix_sig,
}


def sigmoid(x, sig_type):
    """Evaluate the sigmoid approximation selected by ``sig_type``.

    Parameters
    ----------
    x : array-like
        Input values (any shape broadcastable by jax.numpy).
    sig_type : SigType
        Which approximation to use.

    Returns
    -------
    Array of the same shape as ``x``.

    Raises
    ------
    ValueError
        If ``sig_type`` is not a known SigType member.  (The previous
        implementation silently returned ``None`` in this case.)
    """
    try:
        fn = _SIG_FN[sig_type]
    except KeyError:
        raise ValueError(f"unsupported sig_type: {sig_type}") from None
    return fn(x)
class Penalty(Enum):
    """Regularization options; value strings follow sklearn's `penalty`."""

    NONE = 'None'
    L1 = 'l1'  # not supported
    L2 = 'l2'
    Elastic = 'elasticnet'  # not supported


class MultiClass(Enum):
    """Multi-class strategies; value strings follow sklearn's `multi_class`."""

    Ovr = 'ovr'  # binary problem
    Multy = 'multinomial'  # multi_class problem not supported


class SGDClassifier:
    """Binary logistic regression trained with mini-batch SGD.

    The interface deliberately mirrors sklearn's SGDClassifier so that
    additional options (L1/elastic penalties, class weights, multinomial)
    can be filled in later.  Implemented with jax.numpy so it can run
    inside SPU; the sigmoid is pluggable because MPC backends prefer
    cheaper approximations over the exact exponential.
    """

    def __init__(
        self,
        epochs: int,
        learning_rate: float,
        batch_size: int,
        penalty: str,
        sig_type: str,
        l2_norm: float,
        class_weight: None,
        multi_class: str,
    ):
        # parameter check.
        assert epochs > 0, f"epochs should >0"
        assert learning_rate > 0, f"learning_rate should >0"
        assert batch_size > 0, f"batch_size should >0"
        assert penalty == 'l2', "only support L2 penalty for now"
        assert penalty in [
            e.value for e in Penalty
        ], f"penalty should in {[e.value for e in Penalty]}, but got {penalty}"
        # BUGFIX: the original wrote `if penalty == Penalty.L2:`, comparing a
        # raw string to an enum member — always False, so the l2_norm check
        # was dead code.  Convert first, then compare.
        if Penalty(penalty) is Penalty.L2:
            assert l2_norm > 0, f"l2_norm should >0 if use L2 penalty"
        assert sig_type in [
            e.value for e in SigType
        ], f"sig_type should in {[e.value for e in SigType]}, but got {sig_type}"
        # `is None`, not `== None` (identity check for the None singleton).
        assert class_weight is None, f"not support class_weight for now"
        assert multi_class == 'ovr', f"only support binary problem for now"

        self._epochs = epochs
        self._learning_rate = learning_rate
        self._batch_size = batch_size
        self._l2_norm = l2_norm
        self._penalty = Penalty(penalty)
        self._sig_type = SigType(sig_type)
        self._class_weight = class_weight
        self._multi_class = MultiClass(multi_class)

        # Set by `fit`: shape (n_features + 1, 1), last row is the bias.
        self._weights = jnp.zeros(())

    def _update_weights(
        self,
        x,  # array-like, shape (n_samples, n_features)
        y,  # array-like, shape (n_samples, 1)
        w,  # array-like, shape (n_features + 1,) or (n_features + 1, 1)
        total_batch: int,
        batch_size: int,
    ) -> jnp.ndarray:
        """Run one epoch of mini-batch SGD and return the updated weights.

        Returns a jax array of shape (n_features + 1, 1); the original
        annotation said np.ndarray, but every value here is jax.numpy.
        """
        assert x.shape[0] >= total_batch * batch_size, "total batch is too large"
        num_feat = x.shape[1]
        assert w.shape[0] == num_feat + 1, "w shape is mismatch to x"
        assert len(w.shape) == 1 or w.shape[1] == 1, "w should be list or 1D array"
        w = w.reshape((w.shape[0], 1))

        for idx in range(total_batch):
            begin = idx * batch_size
            end = (idx + 1) * batch_size
            # padding one col for bias in w
            x_slice = jnp.concatenate(
                (x[begin:end, :], jnp.ones((batch_size, 1))), axis=1
            )
            y_slice = y[begin:end, :]

            pred = jnp.matmul(x_slice, w)
            pred = sigmoid(pred, self._sig_type)

            err = pred - y_slice
            grad = jnp.matmul(jnp.transpose(x_slice), err)

            if self._penalty == Penalty.L2:
                # Regularize only the feature weights: copy them and append a
                # zero row so the intercept receives no penalty gradient.
                w_with_zero_bias = jnp.resize(w, (num_feat, 1))
                w_with_zero_bias = jnp.concatenate(
                    (w_with_zero_bias, jnp.zeros((1, 1))),
                    axis=0,
                )
                grad = grad + w_with_zero_bias * self._l2_norm
            elif self._penalty == Penalty.L1:
                pass  # not supported yet
            elif self._penalty == Penalty.Elastic:
                pass  # not supported yet

            step = (self._learning_rate * grad) / batch_size

            w = w - step

        return w

    def fit(self, x, y):
        """Fit linear model with Stochastic Gradient Descent.

        Parameters
        ----------
        X : {array-like}, shape (n_samples, n_features)
            Training data.

        y : ndarray of shape (n_samples,)
            Target values.

        Returns
        -------
        self : object
            Returns an instance of self.
        """
        assert len(x.shape) == 2, f"expect x to be 2 dimension array, got {x.shape}"

        num_sample = x.shape[0]
        num_feat = x.shape[1]
        batch_size = min(self._batch_size, num_sample)
        # Trailing partial batch (if any) is dropped each epoch.
        total_batch = int(num_sample / batch_size)
        weights = jnp.zeros((num_feat + 1, 1))

        # not support class_weight for now
        if isinstance(self._class_weight, dict):
            pass
        elif self._class_weight == 'balanced':
            pass

        # do train
        for _ in range(self._epochs):
            weights = self._update_weights(
                x,
                y,
                weights,
                total_batch,
                batch_size,
            )

        self._weights = weights
        return self

    def predict_proba(self, x):
        """Probability estimates.

        Parameters
        ----------
        X : {array-like}, shape (n_samples, n_features)
            Input data for prediction.

        Returns
        -------
        ndarray of shape (n_samples, n_classes)
            Returns the probability of the sample for each class in the model,
            where classes are ordered as they are in `self.classes_`.
        """
        if self._multi_class == MultiClass.Ovr:
            num_feat = x.shape[1]
            w = self._weights
            assert w.shape[0] == num_feat + 1, f"w shape is mismatch to x={x.shape}"
            assert len(w.shape) == 1 or w.shape[1] == 1, "w should be list or 1D array"
            # BUGFIX: the original called reshape without binding the result
            # (jax arrays are immutable, so the call was a no-op).
            w = w.reshape((w.shape[0], 1))

            bias = w[-1, 0]
            # jnp.resize keeps the first num_feat entries, i.e. drops the
            # trailing bias row.
            w = jnp.resize(w, (num_feat, 1))
            pred = jnp.matmul(x, w) + bias
            pred = sigmoid(pred, self._sig_type)
            return pred
        elif self._multi_class == MultiClass.Multy:
            # not support multi_class problem for now
            pass
import os
import sys

import jax.numpy as jnp
import pandas as pd
from sklearn.datasets import load_breast_cancer
from sklearn.metrics import roc_auc_score
from sklearn.preprocessing import MinMaxScaler

# Add the library directory to the path
sys.path.append(os.path.join(os.path.dirname(__file__), '../../'))
from sml.lr.simple_lr import SGDClassifier
import sml.utils.emulation as emulation


# TODO: design the emulation framework, just like py.unittest
# all emulation action should begin with `emul_` (for reflection)
def emul_SGDClassifier(mode: emulation.Mode = emulation.Mode.MULTIPROCESS):
    """Train/predict SGDClassifier on breast-cancer data under SPU emulation.

    BUGFIX: the original signature was `(mode: emulation.Mode.MULTIPROCESS)`,
    which used an enum *value* as a type annotation and provided no default;
    this makes it a proper annotated default parameter.  Existing callers
    that pass the mode positionally are unaffected.
    """

    def proc(x, y):
        model = SGDClassifier(
            epochs=3,
            learning_rate=0.1,
            batch_size=8,
            penalty='l2',
            sig_type='sr',
            l2_norm=1.0,
            class_weight=None,
            multi_class='ovr',
        )
        return model.fit(x, y).predict_proba(x)

    try:
        # bandwidth and latency only work for docker mode
        emulator = emulation.Emulator(
            emulation.CLUSTER_ABY3_3PC, mode, bandwidth=300, latency=20
        )
        emulator.up()

        # Create dataset, rescaled to [-2, 2] where the 'sr' sigmoid
        # approximation is accurate.
        X, y = load_breast_cancer(return_X_y=True, as_frame=True)
        scalar = MinMaxScaler(feature_range=(-2, 2))
        cols = X.columns
        X = scalar.fit_transform(X)
        X = pd.DataFrame(X, columns=cols)

        # Run; X, y should be two-dimension array.
        # (The original printed the result twice; once is enough.)
        result = emulator.run(proc)(X.values, y.values.reshape(-1, 1))
        print("Predict result: ", result)
        print("ROC Score: ", roc_auc_score(y.values, result))

    finally:
        emulator.down()


if __name__ == "__main__":
    emul_SGDClassifier(emulation.Mode.MULTIPROCESS)
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import json
import jax.numpy as jnp
import pandas as pd
import sys
import os
import spu.utils.simulation as spsim
import spu.spu_pb2 as spu_pb2
from sklearn.metrics import roc_auc_score
from sklearn.datasets import load_breast_cancer
from sklearn.preprocessing import MinMaxScaler

# Add the library directory to the path
sys.path.append(os.path.join(os.path.dirname(__file__), '../../'))
from sml.lr.simple_lr import SGDClassifier


class UnitTests(unittest.TestCase):
    """Simulator-based smoke tests for the SPU SGDClassifier."""

    def test_simple(self):
        """Fit and score the classifier inside a 3-party ABY3 simulator."""
        sim = spsim.Simulator.simple(
            3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64
        )

        # Test SGDClassifier
        def proc(x, y):
            classifier = SGDClassifier(
                epochs=3,
                learning_rate=0.1,
                batch_size=8,
                penalty='l2',
                sig_type='sr',
                l2_norm=1.0,
                class_weight=None,
                multi_class='ovr',
            )
            return classifier.fit(x, y).predict_proba(x)

        # Create dataset, rescaled into the sigmoid-friendly range [-2, 2].
        features, labels = load_breast_cancer(return_X_y=True, as_frame=True)
        scaler = MinMaxScaler(feature_range=(-2, 2))
        features = pd.DataFrame(
            scaler.fit_transform(features), columns=features.columns
        )

        # Run; x and y must both be two-dimensional arrays.
        result = spsim.sim_jax(sim, proc)(
            features.values, labels.values.reshape(-1, 1)
        )
        print("Predict result: ", result)
        print("ROC Score: ", roc_auc_score(labels.values, result))


if __name__ == "__main__":
    unittest.main()