
Commit

fix: make it work (작동되게 수정)
kooqooo committed Feb 18, 2024
1 parent a77710c commit 0aa3a5a
Showing 2 changed files with 22 additions and 17 deletions.
9 changes: 5 additions & 4 deletions code/inference_dpr.py
@@ -22,6 +22,7 @@
AutoConfig,
AutoModelForQuestionAnswering,
AutoTokenizer,
+ BertTokenizerFast,
DataCollatorWithPadding,
EvalPrediction,
HfArgumentParser,
@@ -94,7 +95,7 @@ def main():
# If True: run passage retrieval
if data_args.eval_retrieval:
datasets = run_dense_retrieval(
- tokenizer.tokenize, datasets, training_args, data_args,
+ datasets, training_args, data_args,
)

# eval or predict mrc model
@@ -106,7 +107,7 @@ def run_dense_retrieval(
datasets: DatasetDict,
training_args: TrainingArguments,
data_args: DataTrainingArguments,
- tokenizer = AutoTokenizer.from_pretrained('klue/bert-base', padding="max_length", truncation=True, return_tensors="pt"),
+ tokenizer = BertTokenizerFast.from_pretrained('klue/bert-base', padding="max_length", truncation=True, return_tensors="pt"),
data_path: str = "./data",
context_path: str = "wikipedia_documents.json",
p_encoder = BertEncoder.from_pretrained('klue/bert-base'),
@@ -122,8 +123,8 @@
output_dir="dense_retireval",
evaluation_strategy="epoch",
learning_rate=1e-5,
- per_device_train_batch_size=32,
- per_device_eval_batch_size=32,
+ per_device_train_batch_size=8,
+ per_device_eval_batch_size=8,
num_train_epochs=5,
weight_decay=0.01
)
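In short, code/inference_dpr.py now imports BertTokenizerFast, stops passing tokenizer.tokenize into run_dense_retrieval from main(), builds the default tokenizer inside the function signature instead, and lowers the per-device batch sizes from 32 to 8. A minimal sketch of the resulting call pattern, assuming the rest of the pipeline is unchanged (the stub below is illustrative, not the repository code):

from transformers import BertTokenizerFast

# Sketch only: mirrors the new default-tokenizer behaviour of run_dense_retrieval.
# Model name and tokenizer kwargs come from the diff; the body is elided.
def run_dense_retrieval(datasets, training_args, data_args,
                        tokenizer=None,
                        data_path="./data",
                        context_path="wikipedia_documents.json"):
    if tokenizer is None:
        # Built at call time here; the committed code builds it as a default argument.
        tokenizer = BertTokenizerFast.from_pretrained("klue/bert-base")
    ...  # run dense retrieval over the wiki corpus and return the augmented datasets

# main() now forwards only the datasets and the argument objects:
# datasets = run_dense_retrieval(datasets, training_args, data_args)

One design note: a from_pretrained call used as a default argument is evaluated once, at function definition time (module import); deferring it inside the function, as sketched, avoids loading the tokenizer when retrieval is skipped.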
30 changes: 17 additions & 13 deletions code/retrieval_dpr.py
@@ -47,10 +47,12 @@ def __init__(self, args, num_neg, tokenizer, p_encoder, q_encoder, num_sample: i
'''
self.args = args
self.data_path = data_path
- self.dataset = load_from_disk(os.path.join(self.data_path, 'train_dataset'))
- self.train_dataset = self.dataset['train'] if num_sample is -1 else self.dataset['train'][:num_sample]

+ self.dataset = load_from_disk(self.data_path+'/train_dataset')
+ self.train_dataset = self.dataset['train'] if num_sample == -1 else self.dataset['train'][:num_sample]
self.valid_dataset = self.dataset['validation']
- testdata = load_from_disk(os.path.join(self.data_path), 'test_dataset')
+ testdata = load_from_disk(self.data_path+'/test_dataset')

self.test_dataset = testdata['validation']
del testdata
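The original test-set call was broken: the closing parenthesis of os.path.join sat before 'test_dataset', so the folder name was passed to load_from_disk as a separate argument instead of being joined into the path. The commit fixes this with plain string concatenation; an equivalent fix that keeps os.path.join would be (sketch, same paths as the diff):

import os
from datasets import load_from_disk

# Sketch: the os.path.join variant of the committed fix (equivalent for the
# default data_path of "./data").
dataset = load_from_disk(os.path.join(data_path, "train_dataset"))
testdata = load_from_disk(os.path.join(data_path, "test_dataset"))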

@@ -71,8 +73,8 @@ def __init__(self, args, num_neg, tokenizer, p_encoder, q_encoder, num_sample: i
self.q_encoder = q_encoder
self.pwd = os.getcwd()
self.save_path = os.path.join(self.pwd, 'models/dpr')

- self.prepare_in_batch_negative(num_neg=num_neg)
+ print('save_path :', self.save_path)
+ self.prepare_in_batch_negative(dataset=self.train_dataset, num_neg=num_neg)


def prepare_in_batch_negative(self, dataset=None, num_neg=3, tokenizer=None):
@@ -81,7 +83,9 @@ def prepare_in_batch_negative(self, dataset=None, num_neg=3, tokenizer=None):
dataset = self.dataset
dataset = concatenate_datasets([dataset["train"].flatten_indices(),
dataset["validation"].flatten_indices()])

+ # print(dataset)
+ # print(dataset['context'])
+ # print(dataset['features'])
if tokenizer is None:
tokenizer = self.tokenizer

@@ -101,7 +105,7 @@ def prepare_in_batch_negative(self, dataset=None, num_neg=3, tokenizer=None):
p_with_neg.append(c)
p_with_neg.extend(p_neg)
break

+ # 2. Build the (Question, Passage) dataset
q_seqs = tokenizer(dataset['question'], padding="max_length", truncation=True, return_tensors='pt')
p_seqs = tokenizer(p_with_neg, padding="max_length", truncation=True, return_tensors='pt')
@@ -206,13 +210,13 @@ def train(self, args=None, override=False, num_pre_batch=0):
with tqdm(self.train_dataloader, unit="batch") as tepoch:
for batch in tepoch:

- p_encoder.train()
- q_encoder.train()
+ self.p_encoder.train()
+ self.q_encoder.train()

targets = torch.zeros(batch_size).long()  # because the positive examples are all placed first
targets = targets.to(args.device)

- if num_pre_batch is not 0: # Pre-batch
+ if num_pre_batch != 0: # Pre-batch
p_inputs = {
'input_ids': batch[0].to(args.device),
'attention_mask': batch[1].to(args.device),
Expand All @@ -234,7 +238,7 @@ def train(self, args=None, override=False, num_pre_batch=0):

p_outputs = self.p_encoder(**p_inputs) # (batch_size*(num_neg+1), emb_dim)
q_outputs = self.q_encoder(**q_inputs) # (batch_size*, emb_dim)
- if num_pre_batch is not 0: # Pre-batch negative sampling
+ if num_pre_batch != 0: # Pre-batch negative sampling
temp = p_outputs.clone().detach()
p_outputs = torch.cat((p_outputs, *p_queue), dim=0)
p_queue.append(temp)
@@ -491,8 +495,8 @@ def forward(
output_dir="dense_retireval",
evaluation_strategy="epoch",
learning_rate=2e-5,
- per_device_train_batch_size=8,
- per_device_eval_batch_size=8,
+ per_device_train_batch_size=4,
+ per_device_eval_batch_size=4,
num_train_epochs=5,
weight_decay=0.01
)
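Beyond the path and attribute fixes, the train() hunks replace the identity checks "is not 0" with value checks "!= 0" (comparing an int to a literal with "is" is unreliable and emits a SyntaxWarning on Python 3.8+). For context, a minimal sketch of the pre-batch negative queue that this branch guards, assuming encoders that return (batch_size, emb_dim) tensors; the helper name below is illustrative, not from the repository:

from collections import deque

import torch
import torch.nn.functional as F

# Sketch of pre-batch negative sampling: passage embeddings cached from the
# previous num_pre_batch steps are appended as extra negatives for this batch.
def score_with_pre_batch(q_outputs, p_outputs, p_queue, num_pre_batch):
    if num_pre_batch != 0:
        cached = p_outputs.clone().detach()      # keep a copy of the current batch
        if len(p_queue) > 0:
            p_outputs = torch.cat((p_outputs, *p_queue), dim=0)
        p_queue.append(cached)                   # deque(maxlen=num_pre_batch) drops the oldest
    sim_scores = torch.matmul(q_outputs, p_outputs.transpose(0, 1))
    return F.log_softmax(sim_scores, dim=1)

# Usage sketch:
# p_queue = deque(maxlen=num_pre_batch)
# scores = score_with_pre_batch(q_out, p_out, p_queue, num_pre_batch)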
