From 89db9e284bf0a48016a8bafe104e289063a05c54 Mon Sep 17 00:00:00 2001 From: Jerseyshin <50386604+Jerseyshin@users.noreply.github.com> Date: Mon, 8 Apr 2024 15:08:20 +0800 Subject: [PATCH 1/6] Update marvell.py: fix ordering --- fedlearner/privacy/splitnn/marvell.py | 1 + 1 file changed, 1 insertion(+) diff --git a/fedlearner/privacy/splitnn/marvell.py b/fedlearner/privacy/splitnn/marvell.py index 2810f15b8..e40c53ac1 100644 --- a/fedlearner/privacy/splitnn/marvell.py +++ b/fedlearner/privacy/splitnn/marvell.py @@ -546,6 +546,7 @@ def solve_small_pos(u, v, d, g_norm_square, p, P, lam10=None, lam11=None, lam21= elif lam11: ordering = [1, 0, 2] else: + ordering = [1, 2, 0] while True: if i % 3 == ordering[0]: # fix lam21 D = np.float32(P - p * (d - np.float32(1.0)) * lam21) From 8d73573e04a66e11aea5e8eaf5c2e84f3e777403 Mon Sep 17 00:00:00 2001 From: Jerseyshin <50386604+Jerseyshin@users.noreply.github.com> Date: Mon, 24 Jun 2024 16:40:12 +0800 Subject: [PATCH 2/6] Create readme.md --- fedlearner/privacy/splitnn/readme.md | 83 ++++++++++++++++++++++++++++ 1 file changed, 83 insertions(+) create mode 100644 fedlearner/privacy/splitnn/readme.md diff --git a/fedlearner/privacy/splitnn/readme.md b/fedlearner/privacy/splitnn/readme.md new file mode 100644 index 000000000..fe4ed91e5 --- /dev/null +++ b/fedlearner/privacy/splitnn/readme.md @@ -0,0 +1,83 @@ +# Fedlearner标签保护参数说明 + +## embedding保护 + +--using_embedding_protection : bool型,是否开启embedding保护(discorloss),True为开启 + +--discorloss_weight : float型,若开启embedding保护,设置embedding保护大小,值越大embedding保护效果越强,相应的对准确率影响越大,推荐设置范围在[0.001, 0.05] + +样例: + +```python +from fedlearner.privacy.splitnn.discorloss import DisCorLoss + +if args.using_embedding_protection: + + discorloss = DisCorLoss().tf_distance_cor(act1_f, y, False) + + #act1_f为另一方的前传激活值,y为标签,False表示不输出debug信息 + + discorloss = tf.math.reduce_mean(discorloss) + + loss += float(args.discorloss_weight) * discorloss + + #在原来的loss上添加discorloss +``` + +## gradient保护 +
+--using_marvell_protection : bool型,是否开启gradient保护(Marvell),True为开启 + +--sumkl_threshold : float型,若开启gradient保护,设置gradient保护大小,值越小保护效果越强,相应的对准确率影响越大,推荐设置范围在[0.1, 4.0] + +样例: +```python +train_op = model.minimize(optimizer, loss, global_step=global_step, \ + + marvell_protection=args.using_marvell_protection, \ + + marvell_threshold=float(args.sumkl_threshold), labels=y) + +#model.minimize中使用参数marvell_protection和marvell_threshold并传入labels +``` +## fedpass保护 + +--using_fedpass: bool型,是否开启FedPass,True为开启 + +--fedpass_mean: fedpass的密钥的均值,默认值为50.0 + +--fedpass_scale: fedpass的密钥的标准差,默认值为5.0 + +样例: +```python +dense_logits = fedpass(32, dense_activations, mean=float(args.fedpass_mean), scale=float(args.fedpass_scale)) +``` +## embedding攻击 + +--using_emb_attack : bool型,是否开启embedding攻击,True为开启 + +样例: +```python +from fedlearner.privacy.splitnn.emb_attack import emb_attack_auc + +if args.using_emb_attack: + + #传入另一方的前传激活值act1_f和标签y + + emb_auc = emb_attack_auc(act1_f, y) +``` + +## gradient攻击 + +--using_norm_attack : bool型,是否开启norm攻击,True为开启 + +样例: +```python +from fedlearner.privacy.splitnn.norm_attack import norm_attack_auc + +if args.using_norm_attack: + + #传入loss,另一方的前传激活值act1_f,model.minimize使用的参数gate_gradients以及标签y以及marvell参数 + + norm_auc = norm_attack_auc(loss=loss, var_list=[act1_f], gate_gradients=tf.train.Optimizer.GATE_OP, y=y, marvell_protection=args.using_marvell_protection, sumkl_threshold=args.sumkl_threshold) +``` From ac821d3e870c49fe5eeeeba7380797447bae7733 Mon Sep 17 00:00:00 2001 From: Jerseyshin <50386604+Jerseyshin@users.noreply.github.com> Date: Mon, 24 Jun 2024 16:40:52 +0800 Subject: [PATCH 3/6] Update norm_attack.py --- fedlearner/privacy/splitnn/norm_attack.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/fedlearner/privacy/splitnn/norm_attack.py b/fedlearner/privacy/splitnn/norm_attack.py index 2739704aa..9bc0c6495 100644 --- a/fedlearner/privacy/splitnn/norm_attack.py +++ b/fedlearner/privacy/splitnn/norm_attack.py @@ 
-1,17 +1,22 @@ import tensorflow.compat.v1 as tf +from fedlearner.privacy.splitnn.marvell import KL_gradient_perturb # Norm Attack见论文:https://arxiv.org/pdf/2102.08504.pdf -def get_norm_pred(loss, var_list, gate_gradients): +def get_norm_pred(loss, var_list, gate_gradients, marvell_protection, sumkl_threshold, y=None): # 获取gradient g = tf.gradients(loss, var_list, gate_gradients=gate_gradients)[0] + if marvell_protection: + g = KL_gradient_perturb(g, y, float(sumkl_threshold)) # 计算gradient二范数,label=0和label=1的gradient二范数会存在差异 norm_pred = tf.math.sigmoid(tf.norm(g, ord=2, axis=1)) return norm_pred -def norm_attack_auc(loss, var_list, gate_gradients, y): - norm_pred = get_norm_pred(loss, var_list, gate_gradients) +def norm_attack_auc(loss, var_list, gate_gradients, y, marvell_protection, sumkl_threshold): + norm_pred = get_norm_pred(loss, var_list, gate_gradients, marvell_protection, sumkl_threshold, y) norm_pred = tf.reshape(norm_pred, y.shape) + sum_pred = tf.reduce_sum(norm_pred) + norm_pred = norm_pred / sum_pred # 计算norm attack auc _, norm_auc = tf.metrics.auc(y, norm_pred) return norm_auc From 4787ef2fe27c0c8495fbd55ba2e5a5bdc28ed6ee Mon Sep 17 00:00:00 2001 From: Jerseyshin <50386604+Jerseyshin@users.noreply.github.com> Date: Mon, 24 Jun 2024 16:41:18 +0800 Subject: [PATCH 4/6] Create fedpass.py --- fedlearner/privacy/splitnn/fedpass.py | 37 +++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) create mode 100644 fedlearner/privacy/splitnn/fedpass.py diff --git a/fedlearner/privacy/splitnn/fedpass.py b/fedlearner/privacy/splitnn/fedpass.py new file mode 100644 index 000000000..5171ee455 --- /dev/null +++ b/fedlearner/privacy/splitnn/fedpass.py @@ -0,0 +1,37 @@ +import tensorflow.compat.v1 as tf + + +def scale_transform(s_scalekey): + """ Average the transformed key over axis 0 and reshape the result to [1, s_c]. """ + _, s_c = tf.shape(s_scalekey)[0], tf.shape(s_scalekey)[1] + s_scale = tf.reduce_mean(s_scalekey, axis=0) + s_scale = tf.reshape(s_scale, [1, s_c]) + return s_scale + +def fedpass(hidden_feature, x, mean, scale): 
+ # hidden_feature: number of units of the hidden Dense layer + # x: input tensor + # mean, scale: mean and standard deviation of the random keys (scale is passed as stddev) + + # define the layers + dense = tf.keras.layers.Dense(hidden_feature, use_bias=False, activation=None) + encode = tf.keras.layers.Dense(hidden_feature // 4, use_bias=False, activation=None) + decode = tf.keras.layers.Dense(hidden_feature, use_bias=False, activation=None) + + # sample the random keys with the same shape as x + newshape = tf.shape(x) + skey = tf.random.normal(newshape, mean=mean, stddev=scale, dtype=x.dtype) + bkey = tf.random.normal(newshape, mean=mean, stddev=scale, dtype=x.dtype) + # pass the keys through the dense layer and compute the scale factors + s_scalekey = dense(skey) + b_scalekey = dense(bkey) + + + s_scale = scale_transform(s_scalekey) + b_scale = scale_transform(b_scalekey) + + s_scale = tf.reshape(decode(tf.nn.leaky_relu(encode(s_scale))), [1, hidden_feature]) + b_scale = tf.reshape(decode(tf.nn.leaky_relu(encode(b_scale))), [1, hidden_feature]) + x = dense(x) + x = tf.tanh(s_scale) * x + tf.tanh(b_scale) + return x From 27f9f785bdaf6d5f9e811191cfbd1b027b7bcf11 Mon Sep 17 00:00:00 2001 From: Jerseyshin <50386604+Jerseyshin@users.noreply.github.com> Date: Mon, 24 Jun 2024 16:42:27 +0800 Subject: [PATCH 5/6] Update trainer_worker.py --- fedlearner/trainer/trainer_worker.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/fedlearner/trainer/trainer_worker.py b/fedlearner/trainer/trainer_worker.py index caa5bf292..1fc8c51c0 100644 --- a/fedlearner/trainer/trainer_worker.py +++ b/fedlearner/trainer/trainer_worker.py @@ -202,6 +202,18 @@ def create_argument_parser(): type=str, default='0.25', help='Marvell sumKL threshold.') + parser.add_argument('--using_fedpass', + type=str_as_bool, + default='False', + help='Whether use fedpass protection.') + parser.add_argument('--fedpass_mean', + type=str, + default='50.0', + help='FedPass secretkey mean.') + parser.add_argument('--fedpass_scale', + type=str, + default='5.0', + help='FedPass secretkey scale.') parser.add_argument('--using_emb_attack', type=str_as_bool, default='False', From 
24a7e5414a753a99b999a150db0f7b56412a5c72 Mon Sep 17 00:00:00 2001 From: Jerseyshin <50386604+Jerseyshin@users.noreply.github.com> Date: Wed, 26 Jun 2024 10:54:59 +0800 Subject: [PATCH 6/6] Update norm_attack.py remove sigmoid in norm attack --- fedlearner/privacy/splitnn/norm_attack.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fedlearner/privacy/splitnn/norm_attack.py b/fedlearner/privacy/splitnn/norm_attack.py index 9bc0c6495..d1b208338 100644 --- a/fedlearner/privacy/splitnn/norm_attack.py +++ b/fedlearner/privacy/splitnn/norm_attack.py @@ -9,7 +9,7 @@ def get_norm_pred(loss, var_list, gate_gradients, marvell_protection, sumkl_thre if marvell_protection: g = KL_gradient_perturb(g, y, float(sumkl_threshold)) # 计算gradient二范数,label=0和label=1的gradient二范数会存在差异 - norm_pred = tf.math.sigmoid(tf.norm(g, ord=2, axis=1)) + norm_pred = tf.norm(g, ord=2, axis=1) return norm_pred def norm_attack_auc(loss, var_list, gate_gradients, y, marvell_protection, sumkl_threshold):