# test.py
import argparse
import os

import pandas as pd
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import BartConfig, BartForConditionalGeneration

import dataset  # local module providing the Dataset classes
# Synchronous CUDA kernel launches make device-side errors easier to trace.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
torch.backends.cudnn.benchmark = True
# fmt: off
parser = argparse.ArgumentParser(prog="train", description="BART for price prediction")
g = parser.add_argument_group("Common Parameters")
g.add_argument("--method", type=str, choices=["default", "pretrain", "finetuning"], default="default", help="training method")
g.add_argument("--pretrained-ckpt-path", type=str, help="pretrained BART model path or name")
g.add_argument("--batch-size", type=int, default=4, help="training batch size")
g.add_argument("--valid-batch-size", type=int, default=256, help="validation batch size")
g.add_argument("--epochs", type=int, default=10, help="the number of training epochs")
g.add_argument("--max-learning-rate", type=float, default=2e-4, help="max learning rate")
g.add_argument("--min-learning-rate", type=float, default=1e-5, help="min learning rate")
g.add_argument("--warmup-rate", type=float, default=0.05, help="warmup step rate")
g.add_argument("--max-seq-len", type=int, default=60, help="dialogue max sequence length")
g.add_argument("--pred-max-seq-len", type=int, default=64, help="summary max sequence length")
g.add_argument("--all-dropout", type=float, help="override all dropout values")
g.add_argument("--logging-interval", type=int, default=100, help="logging interval")
g.add_argument("--evaluate-interval", type=int, default=500, help="validation interval")
g.add_argument("--masking-rate", type=float, default=0.3, help="pretrain parameter (only used with the `pretrain` method)")
# fmt: on

def accuracy_function(real, pred):
    """Token-level accuracy over non-pad positions (pad token id is 0)."""
    matches = torch.eq(real, pred)
    mask = torch.logical_not(torch.eq(real, 0))  # ignore padding in both counts
    matches = torch.logical_and(mask, matches)
    return torch.sum(matches) / torch.sum(mask)
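
# A minimal sanity check for accuracy_function (illustrative values, not part
# of the original script):
#   real = torch.tensor([2, 3, 0]); pred = torch.tensor([2, 1, 0])
#   matches -> [True, False, True]; non-pad mask -> [True, True, False]
#   masked matches -> [True, False, False], so accuracy = 1 / 2 = 0.5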

def main(args: argparse.Namespace):
    # For pretraining, the dataset would be built as:
    # train_dataset = dataset.PretrainDataset(dataframe=data, max_seq_len=args.max_seq_len)
    data = pd.read_pickle("test2.pkl")
    train_dataset = dataset.FinetuningDataset(dataframe=data, max_seq_len=args.max_seq_len)
    train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=args.batch_size)
    # Apply a uniform dropout override to every dropout field when requested.
    override_args = (
        {
            "dropout": args.all_dropout,
            "attention_dropout": args.all_dropout,
            "activation_dropout": args.all_dropout,
            "classifier_dropout": args.all_dropout,
        }
        if args.all_dropout is not None  # a bare `if args.all_dropout` would skip a 0.0 override
        else {}
    )
    model = BartForConditionalGeneration(
        BartConfig.from_pretrained("default.json", **override_args)
    ).to(device)
    model.load_state_dict(torch.load("finetuning_30_epoch.ckpt", map_location=device))
    model.eval()  # disable dropout during generation
    tqdm_dataset = tqdm(train_dataloader)
    total_acc = 0.0
    for batch, batch_item in enumerate(tqdm_dataset):
        with torch.no_grad():  # inference only, so skip gradient tracking
            res_ids = model.generate(
                input_ids=batch_item["input_ids"].to(device),
                min_length=61,
                max_length=61,
                num_beams=5,
                eos_token_id=3,
            )
        # Drop the final label token so the references align with the generated ids.
        accuracy = accuracy_function(batch_item["labels"][:, :-1].to(device), res_ids)
        # Debug output: inputs, generated ids, references, and the Minimum column.
        print(batch_item["input_ids"])
        print(res_ids)
        print(batch_item["labels"][:, :-1])
        print(batch_item["Minimum"])
        total_acc += accuracy.item()
        tqdm_dataset.set_postfix({"Total ACC": "{:06f}".format(total_acc / (batch + 1))})


if __name__ == "__main__":
    exit(main(parser.parse_args()))