clean_pretrain.py
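"""Fine-tune a sequence-classification model (a local BERT checkpoint by
default) on the IMDB sentiment data: train for a few epochs, track dev
accuracy, save the best checkpoint to ./bert2, and report test accuracy."""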
import os
import random

import numpy as np
import torch
from torch.optim import AdamW
from torch.utils.data import DataLoader

from datasets import load_dataset
from tqdm import tqdm
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    get_linear_schedule_with_warmup,
)
# Set random seeds for reproducibility
random_seed = 0
torch.manual_seed(random_seed)
np.random.seed(random_seed)
random.seed(random_seed)
# Hyperparameters
batch_size = 32
model_name_or_path = "./bert"
device = "cuda" if torch.cuda.is_available() else "cpu"
num_epochs = 3
lr = 2e-5
# Decoder-only models (GPT, OPT, BLOOM) are padded on the left;
# encoder models such as BERT are padded on the right
if any(k in model_name_or_path for k in ("gpt", "opt", "bloom")):
    padding_side = "left"
else:
    padding_side = "right"

tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, padding_side=padding_side)
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id
def collate_fn(examples):
    # Pad each batch to the length of its longest sequence
    return tokenizer.pad(examples, padding="longest", return_tensors="pt")


def tokenize_function(examples):
    # Tokenize without fixed-length padding; batches are padded in collate_fn
    return tokenizer(examples["sentence"], truncation=True, max_length=None)
# Load and tokenize the IMDB splits; "label" is renamed to "labels",
# the argument name expected by the model's forward pass
train_dataset = load_dataset("json", data_files="./data/imdb/train.json")["train"]
train_dataset = train_dataset.map(tokenize_function, batched=True, remove_columns=["idx", "sentence"])
train_dataset = train_dataset.rename_column("label", "labels")
train_dataloader = DataLoader(train_dataset, shuffle=True, collate_fn=collate_fn, batch_size=batch_size)

val_dataset = load_dataset("json", data_files="./data/imdb/dev.json")["train"]
val_dataset = val_dataset.map(tokenize_function, batched=True, remove_columns=["idx", "sentence"])
val_dataset = val_dataset.rename_column("label", "labels")
eval_dataloader = DataLoader(val_dataset, shuffle=False, collate_fn=collate_fn, batch_size=batch_size)

test_dataset = load_dataset("json", data_files="./data/imdb/test.json")["train"]
test_dataset = test_dataset.map(tokenize_function, batched=True, remove_columns=["idx", "sentence"])
test_dataset = test_dataset.rename_column("label", "labels")
test_dataloader = DataLoader(test_dataset, shuffle=False, collate_fn=collate_fn, batch_size=batch_size)
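# Note: the JSON files are assumed to contain one record per example with
# "idx", "sentence", and integer "label" fields (inferred from the columns
# used above), e.g. {"idx": 0, "sentence": "A fine little film...", "label": 1}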
model = AutoModelForSequenceClassification.from_pretrained(model_name_or_path, return_dict=True)
optimizer = AdamW(params=model.parameters(), lr=lr)
# Linear decay schedule with 6% of the total steps used for warmup
total_steps = len(train_dataloader) * num_epochs
lr_scheduler = get_linear_schedule_with_warmup(
    optimizer=optimizer, num_warmup_steps=int(0.06 * total_steps), num_training_steps=total_steps
)
model.to(device)
best_dev_acc = -1
for epoch in range(num_epochs):
    # Training
    model.train()
    for step, batch in enumerate(tqdm(train_dataloader)):
        batch = batch.to(device)
        outputs = model(**batch)
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()

    # Evaluation on the dev set
    model.eval()
    total_number = 0
    total_correct = 0
    for step, batch in enumerate(tqdm(eval_dataloader)):
        batch = batch.to(device)
        with torch.no_grad():
            outputs = model(**batch)
        predictions = outputs.logits.argmax(dim=-1)
        references = batch["labels"]
        total_correct += (predictions == references).sum().item()
        total_number += references.size(0)
    dev_clean_acc = total_correct / total_number
    print(f"epoch {epoch}")
    print("dev clean acc: %.4f" % dev_clean_acc)

    # Keep the checkpoint with the best dev accuracy
    if dev_clean_acc > best_dev_acc:
        best_dev_acc = dev_clean_acc
        os.makedirs("bert2", exist_ok=True)
        torch.save(model.state_dict(), os.path.join("bert2", "pytorch_model.bin"))
# Final evaluation on the test set; reload the best dev checkpoint so the
# reported accuracy corresponds to the model that was saved
model.load_state_dict(torch.load(os.path.join("bert2", "pytorch_model.bin"), map_location=device))
model.eval()
total_number = 0
total_correct = 0
for step, batch in enumerate(tqdm(test_dataloader)):
    batch = batch.to(device)
    with torch.no_grad():
        outputs = model(**batch)
    predictions = outputs.logits.argmax(dim=-1)
    references = batch["labels"]
    total_correct += (predictions == references).sum().item()
    total_number += references.size(0)
print("test clean acc: %.4f" % (total_correct / total_number))
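# Example invocation, assuming ./bert holds a pretrained checkpoint and the
# ./data/imdb/{train,dev,test}.json files exist:
#   python clean_pretrain.py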