Commit 8aa0ee5
backup before cleaning
xrsrke committed Jul 29, 2024
1 parent 03b43ea commit 8aa0ee5
Showing 9 changed files with 1,104 additions and 529 deletions.
5 changes: 5 additions & 0 deletions .gitignore
@@ -170,3 +170,8 @@ examples/infinite-context-length/*.txt
 *.arrow
 *.txt
 *.out
+debug/
+llama3_8b_generated_from_brrr/
+llama3_8b_generated_from_brrr_tp_8/
+llama3_8b_generated_from_brrr_with_pp_4/
+examples/infinite-context-length/configs/exp59/
461 changes: 461 additions & 0 deletions examples/infinite-context-length/data/exp34/debug_eval_dataset.ipynb

Large diffs are not rendered by default.

471 changes: 425 additions & 46 deletions examples/infinite-context-length/data/exp34/debug_finetune_data.ipynb

Large diffs are not rendered by default.

72 changes: 55 additions & 17 deletions examples/infinite-context-length/generate_data.py
@@ -7,7 +7,7 @@
 from datasets import Dataset, load_dataset
 from transformers import AutoTokenizer
 
-PROMPT = "{} {}. \n\n{}"
+PROMPT = "{} {}. {}"
 
 
 def get_keys_in_train_set(dataset):
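The PROMPT template's three slots take the soft prompt, the haystack context, and the retrieval question; this commit drops the "\n\n" separator so the question follows the context on the same run of text. A minimal sketch of what the formatted string looks like, with toy values that are not from the repo:

PROMPT = "{} {}. {}"

soft_prompt = "There is a pass key hidden inside a lot of irrelevant text."
context = "The grass is green. The pass key is 1234. The sky is blue."
question = "What is the pass key?"

# The old template "{} {}. \n\n{}" placed a blank line before the question;
# the new one keeps prompt, context, and question contiguous.
print(PROMPT.format(soft_prompt, context, question))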
@@ -98,12 +98,13 @@ def generate_needle_in_haystack_test(
     target_cut_length = context_length - token_length(tokenizer, PROMPT.format(soft_prompt, 1, retrieval_question)) - 1
 
     context = read_context_files(tokenizer, soft_prompt, retrieval_question, target_cut_length)
+    context = f"\n Passage: {context}"
 
     # Insert the needle into the context at the specified depth percent
     context_with_needle = insert_needle_with_depth(needle_prompt, context, depth_percent, target_cut_length, tokenizer)
 
     # Generate the prompt using the context with the needle
-    prompt = f"{soft_prompt} {context_with_needle}. \n\n{retrieval_question}"
+    prompt = f"{soft_prompt} {context_with_needle}. {retrieval_question}"
 
     assert str(needle) in context_with_needle, f"depth_percent: {depth_percent}"
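insert_needle_with_depth is defined elsewhere in this file and not shown in the diff. For orientation, here is a hedged sketch of the usual token-level depth insertion, a reconstruction under common needle-in-a-haystack conventions rather than the repo's exact code:

def insert_needle_with_depth_sketch(needle_prompt, context, depth_percent, target_cut_length, tokenizer):
    # Truncate the haystack to the token budget, then splice the needle in
    # at depth_percent of the way through the token stream.
    tokens = tokenizer.encode(context, add_special_tokens=False)[:target_cut_length]
    insert_at = int(len(tokens) * depth_percent / 100)
    needle_tokens = tokenizer.encode(needle_prompt, add_special_tokens=False)
    spliced = tokens[:insert_at] + needle_tokens + tokens[insert_at:]
    return tokenizer.decode(spliced)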

@@ -136,33 +137,54 @@ def get_args():
parser.add_argument("--id", type=int, required=True)
parser.add_argument("--haystack_dir", type=str, default="./haystack_txt/")
parser.add_argument("--is_push_to_hub", type=bool, default=False)
parser.add_argument("--is_exact_context_length", type=int, default=1) # 1 is True, 0 is False
parser.add_argument("--is_padding", type=int, default=1) # 1 is True, 0 is False
parser.add_argument("--is_eval", type=int, default=1) # 1 is True, 0 is False
parser.add_argument("--is_exact_context_length", type=str, required=True) # 1 is True, 0 is False
parser.add_argument("--is_padding", type=str, required=True) # 1 is True, 0 is False
parser.add_argument("--is_eval", type=str, required=True) # 1 is True, 0 is False
parser.add_argument("--check_key_in_dataset", type=str, default=None) # 1 is True, 0 is False
parser.add_argument("--save_path", type=str, required=True) # 1 is True, 0 is False
parser.add_argument("--num_shots", type=int, default=0) # 1 is True, 0 is False
parser.add_argument("--num_digits", type=int, default=4) # 1 is True, 0 is False
return parser.parse_args()
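The three is_* flags are now free-form strings that get validated later with asserts. A sketch of an equivalent, slightly stricter declaration using argparse's choices parameter, which rejects anything other than "yes"/"no" at parse time (an alternative, not what the commit does):

import argparse

parser = argparse.ArgumentParser()
# choices makes argparse itself reject bad values, so the later
# assert args.is_padding in ["yes", "no"] checks become redundant.
parser.add_argument("--is_padding", required=True, choices=["yes", "no"])
parser.add_argument("--is_eval", required=True, choices=["yes", "no"])
parser.add_argument("--is_exact_context_length", required=True, choices=["yes", "no"])

args = parser.parse_args(["--is_padding", "yes", "--is_eval", "no", "--is_exact_context_length", "yes"])
is_padding = args.is_padding == "yes"  # same truth table as the script's "no" comparison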


+def generate_random_number(num_digits):
+    import numpy as np
+
+    lower_bound = 10 ** (num_digits - 1)
+    upper_bound = 10**num_digits - 1
+    return np.random.randint(lower_bound, upper_bound)
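One detail worth flagging: np.random.randint samples the half-open interval [low, high), so with upper_bound = 10**num_digits - 1 the largest key ever produced is 10**num_digits - 2 (9998 for four digits, never 9999). If the full range is intended, an inclusive variant would be:

import numpy as np

def generate_random_number_inclusive(num_digits):
    lower_bound = 10 ** (num_digits - 1)
    # np.random.randint excludes the upper bound, so pass 10**num_digits
    # to make 10**num_digits - 1 (e.g. 9999) reachable.
    return np.random.randint(lower_bound, 10**num_digits)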


 if __name__ == "__main__":
+    PROMPT_SHOT_KEYS = [312, 415, 34]
     PROMPT_SHOT_1 = "\n Passage: The gatecode is 312. Remember it. 312 is the gatecode. The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. What is the gatecode? The gatecode is 312. \n"
+    PROMPT_SHOT_2 = "\n Passage: The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. The vault key is 415. Remember it. 415 is the vault key. The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. What is the vault key? The vault key is 415. \n"
+    PROMPT_SHOT_3 = "\n Passage: The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. The grass is green. The sky is blue. The sun is yellow. Here we go. The encrypted message is 34. Remember it. 34 is the encrypted message. There and back again. What is the encrypted message? The encrypted message is 34. \n"
 
     args = get_args()
 
+    for key, value in vars(args).items():
+        print(f"{key}: {value}")
+
+    assert args.is_padding in ["yes", "no"]
+    assert args.is_eval in ["yes", "no"]
+    assert args.is_exact_context_length in ["yes", "no"]

     context_length = args.context_length
     depth_percent = args.depth_percent
     tokenizer_path = args.tokenizer_path
     num_prompts = args.num_prompts
+    num_digits = args.num_digits
     id = args.id
     haystack_dir = args.haystack_dir
     is_push_to_hub = args.is_push_to_hub
-    is_exact_context_length = False if args.is_exact_context_length == 0 else True
-    is_padding = False if args.is_padding == 0 else True
-    is_eval = False if args.is_eval == 0 else True
+    # is_exact_context_length = args.is_exact_context_length
+    is_exact_context_length = False if args.is_exact_context_length == "no" else True
+    is_padding = False if args.is_padding == "no" else True
+    is_eval = False if args.is_eval == "no" else True
     check_key_in_dataset = args.check_key_in_dataset
     save_path = args.save_path
     num_shots = args.num_shots
 
     assert save_path is not None

@@ -179,36 +201,50 @@ def get_args():
         gen_context_length = context_length
 
     # NOTE: depth_percent + 1 to avoid 0
-    RANGE = 500
-    start_range = 30 * (depth_percent + 1) * id
-    end_range = start_range + RANGE
+    # RANGE = 500
+    # start_range = 30 * (depth_percent + 1) * id
+    # end_range = start_range + RANGE
 
     print(
         f"Generating prompts for context length: {gen_context_length} (original {context_length}) and depth percent: {depth_percent} and id: {id} \n"
     )
-    print(f"start_range: {start_range}, end_range: {end_range} \n")
+    # print(f"start_range: {start_range}, end_range: {end_range} \n")

     tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
 
     def generate_dataset():
-        soft_prompt = "There is a pass key hidden inside a lot of irrelevant text. Find it and memorize them. I will quiz you about what is the pass key later on."
+        soft_prompt = "There is a pass key hidden inside a lot of irrelevant text. Find it and memorize them. I will quiz you about what is the pass key later on. \n"
 
-        if num_shots > 0:
+        if num_shots >= 1:
             soft_prompt += PROMPT_SHOT_1
 
+        if num_shots >= 2:
+            soft_prompt += PROMPT_SHOT_2
+
+        if num_shots >= 3:
+            soft_prompt += PROMPT_SHOT_3
+
         dataset_dict = {
             "id": [],
             "prompt": [],
             "answer": [],
             "context_length": [],
-            "num_tokens": [],
+            "token_count": [],
+            "num_shots": [],
+            "num_digits": [],
             "depth_percent": [],
         }
         generated_ids = set()
         generated_pass_keys = set()
+        generated_pass_keys.update(PROMPT_SHOT_KEYS)
 
         for i in range(num_prompts):
            print(f"generating prompt {i} \n")
 
            while True:
-                pass_key = random.randint(start_range, end_range)
+                # pass_key = random.randint(start_range, end_range)
+                pass_key = generate_random_number(num_digits)
                if pass_key not in generated_pass_keys:
                    if check_key_in_dataset is not None:
                        if str(pass_key) in eval_keys:
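A capacity note on this uniqueness loop: with num_digits=4 there are at most 8,999 distinct keys in range (fewer after reserving the few-shot keys and any eval keys), so num_prompts must stay well below that or the while True loop will spin indefinitely. A hypothetical guard one could add up front, not part of the commit:

def assert_enough_keys(num_digits, num_prompts, num_reserved=3):
    # Keys land in [10**(n-1), 10**n - 2] because np.random.randint
    # excludes its upper bound; subtract keys reserved for the few shots.
    available = 9 * 10 ** (num_digits - 1) - 1 - num_reserved
    assert num_prompts <= available, (
        f"requested {num_prompts} unique keys but only {available} are available"
    )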
@@ -245,7 +281,9 @@ def generate_dataset():
dataset_dict["prompt"].append(prompt)
dataset_dict["answer"].append(pass_key)
dataset_dict["context_length"].append(context_length)
dataset_dict["num_tokens"].append(token_length(tokenizer, prompt))
dataset_dict["num_shots"].append(num_shots)
dataset_dict["token_count"].append(token_length(tokenizer, prompt))
dataset_dict["num_digits"].append(num_digits)
dataset_dict["depth_percent"].append(depth_percent)

dataset = Dataset.from_dict(dataset_dict)
@@ -256,7 +294,7 @@

     # Save the dataset to disk
     dataset.save_to_disk(
-        f"{save_path}/needle_finetune_data_and_{context_length}_ctx_and_depth_{depth_percent}_and_id_{id}"
+        f"{save_path}/needle_eval_data_and_{context_length}_ctx_and_depth_{depth_percent}_and_id_{id}_and_num_shots_{num_shots}"
     )
     # if is_push_to_hub:
     #     dataset.push_to_hub("nanotron/llama3-16k-passkey-retrieval-eval")
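The shard now saves under needle_eval_data rather than needle_finetune_data and embeds num_shots in the directory name. A minimal sketch of reading one shard back with datasets.load_from_disk; the path here is an example following the new naming scheme, not a real artifact:

from datasets import load_from_disk

# Example path matching the new save_to_disk pattern.
ds = load_from_disk("./out/needle_eval_data_and_32768_ctx_and_depth_50_and_id_0_and_num_shots_0")
print(ds[0]["answer"], ds[0]["token_count"], ds[0]["num_shots"])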
