From 48acc2228ef2b4b74989fd918304a31de31019c2 Mon Sep 17 00:00:00 2001
From: jiangcheng01ai
Date: Thu, 9 Nov 2023 14:11:15 +0800
Subject: [PATCH 1/3] fix sft loss problem

---
 finetune/utils/data/data_utils.py   | 25 +++++++++++++------------
 finetune/utils/data/raw_datasets.py |  7 ++++++-
 2 files changed, 19 insertions(+), 13 deletions(-)

diff --git a/finetune/utils/data/data_utils.py b/finetune/utils/data/data_utils.py
index a7e19aec..c5f5fdf2 100644
--- a/finetune/utils/data/data_utils.py
+++ b/finetune/utils/data/data_utils.py
@@ -36,7 +36,7 @@ def get_raw_dataset(dataset_name, output_path, seed, local_rank):
             and os.path.isfile(chat_path + "/data/eval.json")
         ):
             raise RuntimeError(
-                "Please check both the train.json and eval.json files in your local directory."
+                f"Please check both the train.json and eval.json files in your local directory."
             )
         return raw_datasets.LocalJsonFileDataset(
             output_path, seed, local_rank, dataset_name, chat_path
@@ -52,7 +52,7 @@ def get_raw_dataset(dataset_name, output_path, seed, local_rank):
             and os.path.isfile(chat_path + "/data/eval.jsonl")
         ):
             raise RuntimeError(
-                "Please check both the train.json and eval.json files in your local directory."
+                f"Please check both the train.json and eval.json files in your local directory."
            )
         return raw_datasets.YiDataset(
             output_path, seed, local_rank, dataset_name, chat_path
@@ -130,7 +130,7 @@ def __getitem__(self, idx):
         return {
             "input_ids": self.chosen_dataset[idx]["input_ids"],
             "attention_mask": self.chosen_dataset[idx]["attention_mask"],
-            "labels": self.chosen_dataset[idx]["input_ids"],
+            "labels": self.chosen_dataset[idx]["labels"]
         }
 
@@ -148,10 +148,9 @@ def create_dataset_split(
     if train_phase == SFT:
         for i, tmp_data in enumerate(current_dataset):
             # tokenize the text
-            chosen_sentence = raw_dataset.get_prompt_and_chosen(
-                tmp_data
-            )  # the accept response
-            if chosen_sentence is not None:
+            chosen_sentence = raw_dataset.get_prompt_and_chosen(tmp_data)  # the accept response
+            prompt_sentence = raw_dataset.get_prompt(tmp_data)
+            if chosen_sentence is not None and prompt_sentence is not None:
                 chosen_sentence += end_of_conversation_token
                 chosen_token = tokenizer(
                     chosen_sentence,
@@ -161,9 +160,11 @@
                     return_tensors="pt",
                 )
                 chosen_token["input_ids"] = chosen_token["input_ids"].squeeze(0)
-                chosen_token["attention_mask"] = chosen_token["attention_mask"].squeeze(
-                    0
-                )
+                chosen_token["attention_mask"] = chosen_token["attention_mask"].squeeze(0)
+                prompt_token = tokenizer(prompt_sentence, add_special_tokens=False)
+                prompt_token_len = min(max_seq_len, len(prompt_token["input_ids"]))
+                chosen_token["labels"] = chosen_token["input_ids"].clone()
+                chosen_token["labels"][:prompt_token_len] = -100
                 chosen_dataset.append(chosen_token)
 
     return PromptDataset(
@@ -452,7 +453,7 @@ def __init__(self, max_size, small_batch_size):
         self.max_size = max_size
         self.small_batch_size = small_batch_size
 
-    def separate(self):
+    def seperate(self):
         small_dataset = []
         for large_batch in self.dataset:
             if type(large_batch) == list or type(large_batch) == tuple:
@@ -483,7 +484,7 @@ def add(self, data):
         if len(self.dataset) < self.max_size:
             self.dataset.append(data)
             if len(self.dataset) == self.max_size:
-                return self.separate()
+                return self.seperate()
             else:
                 return None
         else:
diff --git a/finetune/utils/data/raw_datasets.py b/finetune/utils/data/raw_datasets.py
index acb27a45..45ce3ee8 100644
--- a/finetune/utils/data/raw_datasets.py
+++ b/finetune/utils/data/raw_datasets.py
@@ -134,7 +134,7 @@ def get_prompt_and_rejected(self, sample):
 class YiDataset(PromptRawDataset):
     def __init__(self, output_path, seed, local_rank, dataset_name, chat_path):
         super().__init__(output_path, seed, local_rank, dataset_name)
-        print("chat path is {}".format(chat_path))
+        print("data path is {}".format(chat_path))
         self.dataset_name = "yi"
         self.dataset_name_clean = "yi"
         self.raw_datasets = load_dataset(
@@ -154,6 +154,11 @@ def get_eval_data(self):
         if self.raw_datasets["eval"] is not None:
             return self.raw_datasets["eval"]
         return None
+
+    def get_prompt(self, sample):
+        if sample["prompt"] is not None:
+            return " " + sample["prompt"]
+        return None
 
     def get_prompt_and_chosen(self, sample):
         if sample["prompt"] is not None and sample["chosen"] is not None:

From 84cd8fe495fec1da6ff348e0c4d98c9fb9df89d7 Mon Sep 17 00:00:00 2001
From: jiangcheng01ai
Date: Thu, 9 Nov 2023 14:25:13 +0800
Subject: [PATCH 2/3] fix sft loss problem

---
 finetune/utils/data/data_utils.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/finetune/utils/data/data_utils.py b/finetune/utils/data/data_utils.py
index c5f5fdf2..a09e4e2e 100644
--- a/finetune/utils/data/data_utils.py
+++ b/finetune/utils/data/data_utils.py
@@ -36,7 +36,7 @@ def get_raw_dataset(dataset_name, output_path, seed, local_rank):
             and os.path.isfile(chat_path + "/data/eval.json")
         ):
             raise RuntimeError(
-                f"Please check both the train.json and eval.json files in your local directory."
+                "Please check both the train.json and eval.json files in your local directory."
             )
         return raw_datasets.LocalJsonFileDataset(
             output_path, seed, local_rank, dataset_name, chat_path
@@ -52,14 +52,14 @@ def get_raw_dataset(dataset_name, output_path, seed, local_rank):
             and os.path.isfile(chat_path + "/data/eval.jsonl")
         ):
             raise RuntimeError(
-                f"Please check both the train.json and eval.json files in your local directory."
+                "Please check both the train.json and eval.json files in your local directory."
             )
         return raw_datasets.YiDataset(
             output_path, seed, local_rank, dataset_name, chat_path
         )
     else:
         raise RuntimeError(
-            f"We do not have configs for dataset {dataset_name}, but you can add it by yourself in raw_datasets.py."
+            "We do not have configs for dataset {dataset_name}, but you can add it by yourself in raw_datasets.py."
         )

From a165a412e3ad61c67f7a294a10070697a12e1026 Mon Sep 17 00:00:00 2001
From: jiangcheng01ai
Date: Thu, 9 Nov 2023 14:31:43 +0800
Subject: [PATCH 3/3] fix sft loss problem

---
 finetune/utils/data/data_utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/finetune/utils/data/data_utils.py b/finetune/utils/data/data_utils.py
index a09e4e2e..641c1935 100644
--- a/finetune/utils/data/data_utils.py
+++ b/finetune/utils/data/data_utils.py
@@ -59,7 +59,7 @@ def get_raw_dataset(dataset_name, output_path, seed, local_rank):
         )
     else:
         raise RuntimeError(
-            "We do not have configs for dataset {dataset_name}, but you can add it by yourself in raw_datasets.py."
+            f"We do not have configs for dataset {dataset_name}, but you can add it by yourself in raw_datasets.py."
         )
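
Taken together, the three patches change how SFT examples are labeled: PromptDataset no longer reuses input_ids as labels. Instead, create_dataset_split builds a separate labels tensor in which every position covered by the prompt (obtained from the new YiDataset.get_prompt) is set to -100, so the cross-entropy loss is computed only on the response tokens; patches 2 and 3 only tidy up the f-string prefixes on the error messages. Below is a minimal, self-contained sketch of the same masking idea, assuming a Hugging Face tokenizer; the helper name build_sft_labels and its signature are illustrative and are not part of the repository.

def build_sft_labels(prompt, response, tokenizer, max_seq_len):
    # Tokenize prompt + response as one sequence, as the SFT path above does.
    enc = tokenizer(
        prompt + response,
        max_length=max_seq_len,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    input_ids = enc["input_ids"].squeeze(0)
    attention_mask = enc["attention_mask"].squeeze(0)
    # Start from a copy of input_ids, then blank out the prompt span.
    labels = input_ids.clone()
    prompt_ids = tokenizer(prompt, add_special_tokens=False)["input_ids"]
    prompt_len = min(max_seq_len, len(prompt_ids))
    # -100 is the default ignore_index of torch.nn.CrossEntropyLoss, so prompt
    # tokens contribute nothing to the loss; only the response is learned.
    labels[:prompt_len] = -100
    return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}

# Illustrative usage (the checkpoint name is just an example):
# from transformers import AutoTokenizer
# tok = AutoTokenizer.from_pretrained("01-ai/Yi-6B")
# sample = build_sft_labels("Human: What is 2 + 2?\nAssistant:", " 4", tok, 4096)

Note that tokenizing the prompt on its own can differ by a token or two from the prompt prefix of the full sequence for some tokenizers, so the masked length is an approximation; the patch makes the same approximation.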