-
Notifications
You must be signed in to change notification settings - Fork 1
/
test.py
55 lines (43 loc) · 1.72 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
from utils import get_dataset, get_model, get_tokenizer
from datasets import load_metric
from tqdm import tqdm
def test_model(model, tokenizer, max_len=512, batch_size=50, device="cuda"):
    """Generate continuations for the test split and score them with ROUGE.

    Each test example is split in half by token count. The first half is
    decoded and used as the prompt. The whole decoded example is kept as the
    reference. Prompts are encoded and passed to ``model.generate`` in
    batches. The decoded generations and references are then scored by
    ``compute_rouge``.

    Args:
        model: Object exposing ``generate`` (presumably a Hugging Face causal
            LM). It must already live on ``device``.
        tokenizer: Tokenizer used to build the dataset and to encode/decode
            prompts. It must support batched calls with ``padding=True``.
        max_len: Token limit passed to ``get_dataset``. It is also the
            ``max_length`` for generation, which bounds prompt + continuation
            together.
        batch_size: Number of prompts per ``generate`` call.
        device: Device the encoded prompt tensors are moved to.

    Returns:
        The score dict produced by ``compute_rouge``.
    """
    test_data = get_dataset(tokenizer, max_len=max_len, Test=True)

    model_inputs = []
    references = []
    for data_point in tqdm(test_data):
        input_ids = data_point["input_ids"]
        half_index = len(input_ids) // 2
        model_inputs.append(tokenizer.decode(input_ids[:half_index]))
        # Strip special tokens so the reference is plain text. This matches
        # how the predictions are decoded below.
        references.append(tokenizer.decode(input_ids, skip_special_tokens=True))

    predictions = []
    for start in tqdm(range(0, len(model_inputs), batch_size)):
        encoded = tokenizer(
            model_inputs[start : start + batch_size],
            padding=True,
            return_tensors="pt",
        )
        encoded = {name: tensor.to(device) for name, tensor in encoded.items()}
        generated = model.generate(
            **encoded, max_length=max_len, pad_token_id=tokenizer.eos_token_id
        )
        # Batched prompts are padded with the EOS id. Without
        # skip_special_tokens, those pad/EOS markers would be decoded as
        # literal text and counted by ROUGE.
        predictions.extend(
            tokenizer.batch_decode(generated, skip_special_tokens=True)
        )

    return compute_rouge(predictions, references)
def compute_rouge(predictions, references):
    """Score the second half of each prediction against its reference.

    The first half of every string is dropped before scoring, so the shared
    prompt text does not inflate the overlap. The cut point is computed per
    string, by character count. Previously it used half the *number* of
    predictions, which bore no relation to each string's length.

    Args:
        predictions: Generated strings.
        references: Reference strings, aligned index-by-index with
            ``predictions``.

    Returns:
        A dict with keys ``"rouge1"``, ``"rouge2"`` and ``"rougeL"``. Each
        value is element 1 of the metric's aggregate result. For the
        ``datasets`` ROUGE metric this is presumably the ``mid`` score of the
        (low, mid, high) AggregateScore. Confirm against the installed
        version.
    """
    rouge = load_metric("rouge")
    scores = rouge.compute(
        predictions=[pred[len(pred) // 2 :] for pred in predictions],
        references=[ref[len(ref) // 2 :] for ref in references],
    )
    return {key: scores[key][1] for key in ("rouge1", "rouge2", "rougeL")}
if __name__ == "__main__":
    # Evaluate the "model9" checkpoint on the test split. The guard stops an
    # import of this module (e.g. pytest collecting test.py) from loading the
    # model and running the full evaluation.
    tokenizer = get_tokenizer()
    model = get_model(tokenizer, checkpoint="model9")
    print(test_model(model.cuda(), tokenizer))