-
Notifications
You must be signed in to change notification settings - Fork 1
/
mantle_beam_explanations.py
78 lines (62 loc) · 3.56 KB
/
mantle_beam_explanations.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import os
import json
import torch
import argparse
from transformers import T5Tokenizer, T5ForConditionalGeneration, T5Config
from utils.tokenize_txt import tokenize_t5_explanation_txt
from utils.mantle_utils import get_features, parse
# Select GPU when available, else CPU. NOTE: `global` is a no-op at module
# scope (it only has meaning inside a function), so the plain assignment is
# the correct form — the name is module-global either way.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def main():
    """Generate MaNtLE beam-search explanations, one per data subset.

    For each subset ``i`` in ``data/<dataset>/mantle_subsets/<model>/<i>/data.jsonl``:
    read the first JSONL record's ``samples``, normalize its label column,
    build the feature text, generate 20 beam candidates with the pretrained
    T5 model, keep the candidate selected by ``parse``, and finally write one
    explanation per line to ``mantle_beam_explanations.txt`` in the same
    model directory.

    Side effects: reads model/config files from ``pretrained_mantle/`` and
    writes the explanations file. No return value.
    """
    parser = argparse.ArgumentParser(
        description="MaNtLE Beam Search Explanations Generation",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--dataset-name', required=True, help='Dataset name')
    parser.add_argument('--model-name', choices=['lr', 'dt', 'nn', 'xgb', ''],
                        required=True, help='Model name')
    parser.add_argument('--num-subsets', type=int, default=100, required=False,
                        help='Number of subsets')
    args = parser.parse_args()

    DATASET_PATH = f'data/{args.dataset_name}/mantle_subsets/{args.model_name}'

    tokenizer = T5Tokenizer.from_pretrained('t5-large', model_max_length=1024)
    t5_config = T5Config.from_pretrained('pretrained_mantle')
    model = T5ForConditionalGeneration.from_pretrained(
        'pretrained_mantle', config=t5_config).to(device)
    # Use a context manager so the config file handle is closed
    # (was: json.load(open(...)), which leaks the handle).
    with open('pretrained_mantle/mantle_config.json', 'r') as cfg_f:
        mantle_config = json.load(cfg_f)
    model.eval()

    # Per-dataset label column and normalizer. 'adult' collapses incomes to a
    # binary '>50K' / 'not >50K'; the other datasets keep labels unchanged.
    label_transforms = {
        'adult': ('Income', lambda v: v if v == '>50K' else 'not >50K'),
        'recidivism': ('Recidivism', lambda v: v),
        'travel_insurance': ('Travel Insurance', lambda v: v),
    }
    if args.dataset_name not in label_transforms:
        raise NotImplementedError('Please implement the equivalent fn for this dataset')
    label_key, transform = label_transforms[args.dataset_name]

    output_txts = []
    for sub_idx in range(args.num_subsets):
        data_path = os.path.join(DATASET_PATH, f'{sub_idx+1}', 'data.jsonl')
        with open(data_path, 'r') as f:
            # Only the first JSONL record is used; .strip() is robust even if
            # the line has no trailing newline (the old [:-1] would chop a
            # character of JSON in that case).
            example = json.loads(f.readline().strip())['samples']
        for ex in example:
            ex[label_key] = transform(ex[label_key])
        data = list(example)
        txt = get_features(data)
        input_ids = tokenize_t5_explanation_txt(
            tokenizer, mantle_config['max_text_length'],
            txt, prompt=mantle_config['prompt_exp'], lm_adapt=False).to(device)
        # MaNtLE decodes from its first additional special token, not the
        # default decoder start token.
        start_token = tokenizer.additional_special_tokens_ids[0]
        with torch.no_grad():
            outputs = model.generate(input_ids,
                                     decoder_start_token_id=start_token,
                                     max_length=mantle_config['exp_max_text_length'],
                                     num_beams=20,
                                     num_return_sequences=20)
        eos_token_id = tokenizer.eos_token_id
        # Append an EOS sentinel so .index() always succeeds, then truncate
        # each candidate at its first EOS.
        sequences = [seq.cpu().numpy().tolist() + [eos_token_id] for seq in outputs]
        sequences = [seq[:seq.index(eos_token_id)] for seq in sequences]
        output_txt = [tokenizer.decode(seq[1:]) for seq in sequences]  # skip the start_token
        # parse() returns the index of the best-faithful candidate.
        best_exp = output_txt[parse(output_txt, data, args.dataset_name)]
        output_txts.append(best_exp)

    # Write one explanation per subset, reusing DATASET_PATH so the output
    # location stays consistent with the input location above.
    with open(os.path.join(DATASET_PATH, 'mantle_beam_explanations.txt'), 'w') as f:
        for exp in output_txts:
            f.write(exp + '\n')
# Run the generation pipeline only when executed as a script.
if __name__ == '__main__':
    main()