-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathextra_fgm.py
32 lines (26 loc) · 1.07 KB
/
extra_fgm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
# -*- coding: utf-8 -*-
# @Time : 2021/7/14 4:42 PM
# @Author : Bubble
# @FileName: extra_fgm.py
import torch
# Fast Gradient Method
class FGM:
    """Fast Gradient Method (FGM) adversarial training helper.

    Perturbs the embedding weights of a model along the gradient direction
    (Miyato et al., "Adversarial Training Methods for Semi-Supervised Text
    Classification"). Typical usage per training step::

        fgm = FGM(model)
        loss.backward()          # populate grads
        fgm.attack()             # perturb embeddings in place
        loss_adv = model(...).loss
        loss_adv.backward()      # accumulate adversarial grads
        fgm.restore()            # put original embeddings back
    """

    def __init__(self, model):
        # model: any torch.nn.Module; only parameters whose name contains
        # the embedding-name substring are ever touched.
        self.model = model
        # Maps parameter name -> clone of its original data, filled by
        # attack() and emptied by restore().
        self.backup = {}

    def attack(self, epsilon=0.1, emb_name='word_embeddings'):
        """Add an in-place L2-normalized gradient perturbation to embeddings.

        Args:
            epsilon: magnitude of the perturbation (scales the unit-norm
                gradient direction).
            emb_name: substring identifying the embedding parameters in
                ``model.named_parameters()``; change it to match your model.
        """
        for name, param in self.model.named_parameters():
            if param.requires_grad and emb_name in name:
                # Skip params with no gradient (e.g. backward() not yet
                # called); the original code crashed on torch.norm(None)
                # after already storing a backup, corrupting state.
                if param.grad is None:
                    continue
                self.backup[name] = param.data.clone()
                norm = torch.norm(param.grad)
                # Guard against zero or NaN gradient norm to avoid
                # division by zero / NaN perturbations.
                if norm != 0 and not torch.isnan(norm):
                    r_at = epsilon * param.grad / norm
                    param.data.add_(r_at)

    def restore(self, emb_name='word_embeddings'):
        """Restore the original embedding weights saved by :meth:`attack`.

        Args:
            emb_name: same substring that was passed to :meth:`attack`.
        """
        for name, param in self.model.named_parameters():
            # Only restore params that were actually backed up; a matching
            # param with no gradient was never perturbed, so a hard assert
            # here would wrongly fail.
            if param.requires_grad and emb_name in name and name in self.backup:
                param.data = self.backup[name]
        self.backup = {}