fix sft loss problem #86

Merged: 4 commits, Nov 9, 2023
Changes from 1 commit
25 changes: 13 additions & 12 deletions finetune/utils/data/data_utils.py
@@ -36,7 +36,7 @@ def get_raw_dataset(dataset_name, output_path, seed, local_rank):
             and os.path.isfile(chat_path + "/data/eval.json")
         ):
             raise RuntimeError(
-                "Please check both the train.json and eval.json files in your local directory."
+                f"Please check both the train.json and eval.json files in your local directory."
             )
         return raw_datasets.LocalJsonFileDataset(
             output_path, seed, local_rank, dataset_name, chat_path
@@ -52,7 +52,7 @@ def get_raw_dataset(dataset_name, output_path, seed, local_rank):
             and os.path.isfile(chat_path + "/data/eval.jsonl")
         ):
             raise RuntimeError(
-                "Please check both the train.json and eval.json files in your local directory."
+                f"Please check both the train.json and eval.json files in your local directory."
             )
         return raw_datasets.YiDataset(
             output_path, seed, local_rank, dataset_name, chat_path
@@ -130,7 +130,7 @@ def __getitem__(self, idx):
             return {
                 "input_ids": self.chosen_dataset[idx]["input_ids"],
                 "attention_mask": self.chosen_dataset[idx]["attention_mask"],
-                "labels": self.chosen_dataset[idx]["input_ids"],
+                "labels": self.chosen_dataset[idx]["labels"]
             }


@@ -148,10 +148,9 @@ def create_dataset_split(
     if train_phase == SFT:
         for i, tmp_data in enumerate(current_dataset):
             # tokenize the text
-            chosen_sentence = raw_dataset.get_prompt_and_chosen(
-                tmp_data
-            ) # the accept response
-            if chosen_sentence is not None:
+            chosen_sentence = raw_dataset.get_prompt_and_chosen(tmp_data) # the accept response
+            prompt_sentence = raw_dataset.get_prompt(tmp_data)
+            if chosen_sentence is not None and prompt_sentence is not None:
                 chosen_sentence += end_of_conversation_token
                 chosen_token = tokenizer(
                     chosen_sentence,
@@ … @@
                     return_tensors="pt",
                 )
                 chosen_token["input_ids"] = chosen_token["input_ids"].squeeze(0)
-                chosen_token["attention_mask"] = chosen_token["attention_mask"].squeeze(
-                    0
-                )
+                chosen_token["attention_mask"] = chosen_token["attention_mask"].squeeze(0)
+                prompt_token = tokenizer(prompt_sentence, add_special_tokens=False)
+                prompt_token_len = min(max_seq_len, len(prompt_token["input_ids"]))
+                chosen_token["labels"] = chosen_token["input_ids"].clone()
+                chosen_token["labels"][:prompt_token_len] = -100
                 chosen_dataset.append(chosen_token)

     return PromptDataset(
@@ -452,7 +453,7 @@ def __init__(self, max_size, small_batch_size):
         self.max_size = max_size
         self.small_batch_size = small_batch_size

-    def separate(self):
+    def seperate(self):
         small_dataset = []
         for large_batch in self.dataset:
             if type(large_batch) == list or type(large_batch) == tuple:
@@ -483,7 +484,7 @@ def add(self, data):
         if len(self.dataset) < self.max_size:
             self.dataset.append(data)
             if len(self.dataset) == self.max_size:
-                return self.separate()
+                return self.seperate()
             else:
                 return None
         else:
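Net effect of the data_utils.py change: SFT labels are no longer a verbatim copy of input_ids. The positions covered by the prompt are set to -100, the default ignore_index of torch.nn.CrossEntropyLoss, so only the response tokens contribute to the loss. Below is a minimal standalone sketch of the same masking idea, not the repository's code; the tokenizer name and the example prompt/response are placeholders.

# Minimal sketch of the prompt-masking idea in this PR (assumptions: placeholder
# tokenizer "gpt2" and an illustrative Human/Assistant prompt format).
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")   # placeholder tokenizer
tokenizer.pad_token = tokenizer.eos_token
max_seq_len = 32

prompt = "Human: What is the capital of France?\nAssistant:"   # placeholder sample
chosen = prompt + " Paris."

chosen_token = tokenizer(
    chosen,
    max_length=max_seq_len,
    padding="max_length",
    truncation=True,
    return_tensors="pt",
)
input_ids = chosen_token["input_ids"].squeeze(0)

# Token length of the prompt alone, capped at max_seq_len, mirroring prompt_token_len in the diff.
prompt_token = tokenizer(prompt, add_special_tokens=False)
prompt_token_len = min(max_seq_len, len(prompt_token["input_ids"]))

labels = input_ids.clone()
labels[:prompt_token_len] = -100   # positions labelled -100 are skipped by CrossEntropyLoss

Before this change, labels mirrored input_ids, so the prompt tokens were also scored by the loss; masking the prompt prefix restricts supervision to the response, which is the "sft loss problem" the PR title refers to.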
7 changes: 6 additions & 1 deletion finetune/utils/data/raw_datasets.py
@@ -134,7 +134,7 @@ def get_prompt_and_rejected(self, sample):
 class YiDataset(PromptRawDataset):
     def __init__(self, output_path, seed, local_rank, dataset_name, chat_path):
         super().__init__(output_path, seed, local_rank, dataset_name)
-        print("chat path is {}".format(chat_path))
+        print("data path is {}".format(chat_path))
         self.dataset_name = "yi"
         self.dataset_name_clean = "yi"
         self.raw_datasets = load_dataset(
@@ -154,6 +154,11 @@ def get_eval_data(self):
         if self.raw_datasets["eval"] is not None:
             return self.raw_datasets["eval"]
         return None

+    def get_prompt(self, sample):
+        if sample["prompt"] is not None:
+            return " " + sample["prompt"]
+        return None
+
     def get_prompt_and_chosen(self, sample):
         if sample["prompt"] is not None and sample["chosen"] is not None:
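The new YiDataset.get_prompt sits next to the existing get_prompt_and_chosen so that create_dataset_split can measure how many leading tokens belong to the prompt. A rough sketch of how the two accessors relate is below; the sample dict is hypothetical, and the exact return of get_prompt_and_chosen is cut off in this diff, so the concatenation shown is an assumption based on similar DeepSpeed-Chat-style datasets.

# Hypothetical sample in the format YiDataset reads; the text is illustrative only.
sample = {"prompt": "Human: Say hi\nAssistant:", "chosen": " Hi!"}

# get_prompt returns the prompt with a leading space (as in the diff above);
# get_prompt_and_chosen is assumed here to return the same prefix plus the chosen response.
prompt_only = " " + sample["prompt"]
prompt_and_chosen = " " + sample["prompt"] + sample["chosen"]

# Because the prompt string is a prefix of the full string, its token length identifies
# the span that receives label -100 in create_dataset_split.
assert prompt_and_chosen.startswith(prompt_only)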