inference.py
from datetime import datetime
import argparse
import json
import os
import random

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from transformers import AutoTokenizer
import transformers.models.qwen2  # kept from the original script; presumably only relevant for Qwen2-based checkpoints
from peft import AutoPeftModelForSequenceClassification
from sklearn.metrics import ndcg_score

from tools.order_metrics import *
os.environ["TOKENIZERS_PARALLELISM"] = "false"
seed = 0
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
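# The cuDNN flags above trade some throughput for run-to-run reproducibility of the evaluation.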
def NDCG_k(predictions, labels, k=20):
    # NDCG@k is only computed when at least k samples are available.
    if len(predictions) < k:
        return -1  # or handle as preferred
    return ndcg_score([labels], [predictions], k=k)
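# Usage sketch (hypothetical values, not from the original script):
#   NDCG_k(list(range(25)), list(range(25)), k=20)  # perfect ordering -> 1.0
#   NDCG_k([0.2, 0.9], [0, 1], k=20)                # fewer than k items -> -1 sentinel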
def save_args_to_json(args, file_path):
    args_dict = vars(args)
    with open(file_path, 'w') as f:
        json.dump(args_dict, f, indent=4)
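# Usage sketch (the output filename is a placeholder, not part of the original script):
#   save_args_to_json(args, os.path.join(args.runs_dir, 'args.json'))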
def get_args():
    parser = argparse.ArgumentParser(description="Evaluate a LoRA-adapted transformer model on the paper-impact regression task.")
    # Dataset and evaluation configuration
    parser.add_argument('--max_length', type=int, default=1024, help='Maximum length of the tokenized input sequences')
    parser.add_argument('--batch_size', type=int, default=20, help='Batch size for evaluation')
    parser.add_argument('--data_path', type=str, default='ScImpactPredict/data/Data_TNCSI_S_OA_AuthorCite_8242_fix1.csv', help='Path to the dataset CSV file')
    parser.add_argument('--checkpoint', type=str, default='llama3_weight', help='Base model checkpoint path')
    parser.add_argument('--weight_dir', type=str, default='runs/Jul12_14-54-26_gpu22', help='Directory holding the fine-tuned adapter weights and tokenizer')
    parser.add_argument('--loss_func', type=str, default='bce', choices=['bce', 'mse', 'l1'])
    parser.add_argument('--prompt_style', type=int, default=0)
    parser.add_argument('--test_ratio', type=float, default=1.0)
    # parser.add_argument('--model_save_path', type=str, help='Path to save the trained models')
    parser.add_argument('--device', type=str, default='cuda', help='Device to run the model on (cuda or cpu)')
    parser.add_argument('--num_labels', type=int, default=1, help='Number of labels for sequence classification')
    parser.add_argument('--load_in_8bit', type=bool, default=True, help='Whether to load the model in 8-bit for efficiency')
    # LoRA configuration
    parser.add_argument('--lora_r', type=int, default=16, help='Rank of LoRA layers')
    parser.add_argument('--lora_alpha', type=int, default=32, help='Scaling factor for LoRA layers')
    parser.add_argument('--lora_dropout', type=float, default=0.05, help='Dropout rate for LoRA layers')
    parser.add_argument('--lora_bias', type=str, default='none', help='Bias mode for LoRA layers')
    parser.add_argument('--target_modules', type=str, default='q_proj,v_proj', help='Comma-separated list of transformer modules to apply LoRA to')
    default_tb_dir = datetime.now().strftime("%m-%d-%H-%M")
    parser.add_argument('--runs_dir', type=str, default=os.path.join('ScImpactPredict/inference', default_tb_dir), help='Directory for storing TensorBoard logs')
    return parser.parse_args()
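# Example invocation (sketch; the paths are just the argparse defaults above):
#   accelerate launch inference.py \
#       --weight_dir runs/Jul12_14-54-26_gpu22 \
#       --data_path ScImpactPredict/data/Data_TNCSI_S_OA_AuthorCite_8242_fix1.csv \
#       --batch_size 20 --loss_func mse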
from accelerate import Accelerator
accelerator = Accelerator()
args = get_args()
from offcial_train import TextDataset
args.eff_gpus = int(torch.cuda.device_count() * args.batch_size)  # effective global batch size across all visible GPUs
writer = SummaryWriter(args.runs_dir)
tokenizer = AutoTokenizer.from_pretrained(args.weight_dir)
tokenizer.pad_token = tokenizer.eos_token
device_map={'':torch.cuda.current_device()}
model = AutoPeftModelForSequenceClassification.from_pretrained(args.weight_dir, num_labels=args.num_labels, load_in_8bit=args.load_in_8bit,device_map=device_map,)
model.config.pad_token_id = model.config.eos_token_id
model.loss_func = args.loss_func
# model.score.load_state_dict(torch.load(os.path.join(args.weight_dir,'score.pt')))
# model = model.merge_and_unload()
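# Printing the score-head weights is presumably a sanity check that the regression head was
# restored from the adapter checkpoint rather than left randomly initialized.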
print(model.score.weight)
full_data = pd.read_csv(args.data_path)
dataset = TextDataset(full_data, tokenizer, max_length=args.max_length,prompt_style=args.prompt_style)
test_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False)
print(f'Test Dataloader has {len(test_loader)} batches in total')
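# Each batch is assumed to follow the usual tokenizer layout produced by offcial_train.TextDataset
# (an assumption, not verified here): {'input_ids': [B, max_length], 'attention_mask': [B, max_length],
# 'labels': [B]} with float labels as the regression target.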
total_val_mse = 0.0
total_val_mae = 0.0
all_pred = []
all_GT = []
model.eval()
model,test_loader = accelerator.prepare(model,test_loader)
with torch.no_grad():
    for batch in tqdm(test_loader):
        outputs = model(**batch)
        predictions = outputs["logits"]
        labels = batch["labels"]
        all_GT += labels
        all_pred += predictions.squeeze(1)
        # Per-batch regression errors, accumulated and averaged over batches below.
        mse = nn.MSELoss()(predictions.squeeze(1), labels).item()
        mae = nn.L1Loss()(predictions.squeeze(1), labels).item()
        total_val_mse += mse
        total_val_mae += mae
avg_mse_loss = total_val_mse / len(test_loader)
avg_mae_loss = total_val_mae / len(test_loader)
all_pred = torch.Tensor(all_pred).squeeze()
all_GT = torch.Tensor(all_GT).squeeze()
print(args.weight_dir)
print(avg_mse_loss, avg_mae_loss)
print(NDCG_k(all_pred,all_GT))
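# Optional sanity check (sketch): recompute the errors over the full prediction vector; this matches
# the per-batch averages above only when every batch has the same size.
#   overall_mse = torch.mean((all_pred - all_GT) ** 2).item()
#   overall_mae = torch.mean(torch.abs(all_pred - all_GT)).item()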
writer.close()