diff --git a/finetune/utils/data/data_utils.py b/finetune/utils/data/data_utils.py
index a7e19aec..c5f5fdf2 100644
--- a/finetune/utils/data/data_utils.py
+++ b/finetune/utils/data/data_utils.py
@@ -36,7 +36,7 @@ def get_raw_dataset(dataset_name, output_path, seed, local_rank):
             and os.path.isfile(chat_path + "/data/eval.json")
         ):
             raise RuntimeError(
-                f"Please check both the train.json and eval.json files in your local directory."
+                "Please check both the train.json and eval.json files in your local directory."
             )
         return raw_datasets.LocalJsonFileDataset(
             output_path, seed, local_rank, dataset_name, chat_path
@@ -52,7 +52,7 @@ def get_raw_dataset(dataset_name, output_path, seed, local_rank):
             and os.path.isfile(chat_path + "/data/eval.jsonl")
         ):
             raise RuntimeError(
-                f"Please check both the train.json and eval.json files in your local directory."
+                "Please check both the train.json and eval.json files in your local directory."
             )
         return raw_datasets.YiDataset(
             output_path, seed, local_rank, dataset_name, chat_path
@@ -130,7 +130,7 @@ def __getitem__(self, idx):
             return {
                 "input_ids": self.chosen_dataset[idx]["input_ids"],
                 "attention_mask": self.chosen_dataset[idx]["attention_mask"],
-                "labels": self.chosen_dataset[idx]["input_ids"],
+                "labels": self.chosen_dataset[idx]["labels"],
             }
@@ -148,10 +148,9 @@ def create_dataset_split(
     if train_phase == SFT:
         for i, tmp_data in enumerate(current_dataset):
             # tokenize the text
-            chosen_sentence = raw_dataset.get_prompt_and_chosen(
-                tmp_data
-            )  # the accept response
-            if chosen_sentence is not None:
+            chosen_sentence = raw_dataset.get_prompt_and_chosen(tmp_data)  # the accept response
+            prompt_sentence = raw_dataset.get_prompt(tmp_data)
+            if chosen_sentence is not None and prompt_sentence is not None:
                 chosen_sentence += end_of_conversation_token
                 chosen_token = tokenizer(
                     chosen_sentence,
@@ -161,9 +160,11 @@ def create_dataset_split(
                     return_tensors="pt",
                 )
                 chosen_token["input_ids"] = chosen_token["input_ids"].squeeze(0)
-                chosen_token["attention_mask"] = chosen_token["attention_mask"].squeeze(
-                    0
-                )
+                chosen_token["attention_mask"] = chosen_token["attention_mask"].squeeze(0)
+                prompt_token = tokenizer(prompt_sentence, add_special_tokens=False)
+                prompt_token_len = min(max_seq_len, len(prompt_token["input_ids"]))
+                chosen_token["labels"] = chosen_token["input_ids"].clone()
+                chosen_token["labels"][:prompt_token_len] = -100
                 chosen_dataset.append(chosen_token)
 
     return PromptDataset(
@@ -452,7 +453,7 @@ def __init__(self, max_size, small_batch_size):
         self.max_size = max_size
         self.small_batch_size = small_batch_size
 
-    def seperate(self):
+    def separate(self):
         small_dataset = []
         for large_batch in self.dataset:
             if type(large_batch) == list or type(large_batch) == tuple:
@@ -483,7 +484,7 @@ def add(self, data):
         if len(self.dataset) < self.max_size:
             self.dataset.append(data)
             if len(self.dataset) == self.max_size:
-                return self.seperate()
+                return self.separate()
             else:
                 return None
         else:
diff --git a/finetune/utils/data/raw_datasets.py b/finetune/utils/data/raw_datasets.py
index acb27a45..45ce3ee8 100644
--- a/finetune/utils/data/raw_datasets.py
+++ b/finetune/utils/data/raw_datasets.py
@@ -134,7 +134,7 @@ def get_prompt_and_rejected(self, sample):
 class YiDataset(PromptRawDataset):
     def __init__(self, output_path, seed, local_rank, dataset_name, chat_path):
         super().__init__(output_path, seed, local_rank, dataset_name)
-        print("chat path is {}".format(chat_path))
+        print("data path is {}".format(chat_path))
         self.dataset_name = "yi"
         self.dataset_name_clean = "yi"
         self.raw_datasets = load_dataset(
@@ -154,6 +154,11 @@ def get_eval_data(self):
         if self.raw_datasets["eval"] is not None:
             return self.raw_datasets["eval"]
         return None
+
+    def get_prompt(self, sample):
+        if sample["prompt"] is not None:
+            return " " + sample["prompt"]
+        return None
 
     def get_prompt_and_chosen(self, sample):
         if sample["prompt"] is not None and sample["chosen"] is not None:
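
For reviewers, here is a minimal standalone sketch (not part of the diff) of the label-masking behaviour this change introduces in `create_dataset_split`: the prompt span of each chosen sequence is overwritten with `-100` in `labels`, which is the default `ignore_index` of `torch.nn.CrossEntropyLoss`, so only the response tokens contribute to the SFT loss. The helper name and the toy token ids below are illustrative only; the real code tokenizes `prompt_sentence` with the Hugging Face tokenizer.

```python
import torch

IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def mask_prompt_labels(prompt_ids, full_ids, max_seq_len):
    """Clone the full sequence as labels and mask the prompt prefix,
    mirroring the logic added to create_dataset_split in this diff."""
    labels = full_ids.clone()
    prompt_len = min(max_seq_len, len(prompt_ids))
    labels[:prompt_len] = IGNORE_INDEX
    return labels


# Toy ids: 1-4 stand in for the prompt, 5-7 for the chosen response.
prompt_ids = torch.tensor([1, 2, 3, 4])
full_ids = torch.tensor([1, 2, 3, 4, 5, 6, 7])
print(mask_prompt_labels(prompt_ids, full_ids, max_seq_len=16))
# tensor([-100, -100, -100, -100,    5,    6,    7])
```

The leading space returned by the new `get_prompt` presumably keeps its tokenization aligned with the prompt prefix of the full `prompt + chosen` string; if those two tokenizations diverged, the masked span could shift into the response tokens.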