diff --git a/src/instructlab/sdg/generate_data.py b/src/instructlab/sdg/generate_data.py index f45f5027..f6c052ce 100644 --- a/src/instructlab/sdg/generate_data.py +++ b/src/instructlab/sdg/generate_data.py @@ -362,53 +362,19 @@ def read_taxonomy(*args, **kwargs): return instructlab.utils.read_taxonomy(*args, **kwargs) -def generate_data( - logger, - api_base, - tls_insecure, - model_family: str, - yaml_rules: Optional[str] = None, - output_dir: Optional[str] = None, - taxonomy: Optional[str] = None, - taxonomy_base: Optional[str] = None, - prompt_file_path: Optional[str] = None, - model_name: Optional[str] = None, - num_cpus: Optional[int] = None, - num_instructions_to_generate: Optional[int] = None, - num_prompt_instructions=2, - request_batch_size=5, - temperature=1.0, - top_p=1.0, - rouge_threshold: Optional[float] = None, - console_output=True, - api_key: Optional[str] = None, - chunk_word_count=None, - server_ctx_size=None, - tls_client_cert: Optional[str] = None, - tls_client_key: Optional[str] = None, - tls_client_passwd: Optional[str] = None, -): - seed_instruction_data = [] - machine_seed_instruction_data = [] - generate_start = time.time() +def unescape(s): + return bytes(s, "utf-8").decode("utf-8") - if not os.path.exists(output_dir): - os.mkdir(output_dir) - - # check taxonomy first then seed_tasks_path - # throw an error if both not found - # pylint: disable=broad-exception-caught,raise-missing-from - if taxonomy and os.path.exists(taxonomy): - seed_instruction_data = read_taxonomy( - logger, taxonomy, taxonomy_base, yaml_rules - ) - else: - raise SystemExit(f"Error: taxonomy ({taxonomy}) does not exist.") - prompt_template = check_prompt_file( - prompt_file_path, get_model_family(model_family, model_name) - ) - max_seed_tokens = max_seed_example_tokens(server_ctx_size, len(prompt_template)) +def _gen_test_data( + logger, + seed_instruction_data, + max_seed_tokens, + taxonomy, + chunk_word_count, + server_ctx_size, + output_file_test, +): max_seed_chars = num_chars_from_tokens(max_seed_tokens) for seed_example in seed_instruction_data: if ( @@ -426,9 +392,6 @@ def generate_data( if not seeds: raise SystemExit("Nothing to generate. Exiting.") - def unescape(s): - return bytes(s, "utf-8").decode("utf-8") - test_data = [] for seed_example in seed_instruction_data: user = seed_example["instruction"] @@ -457,6 +420,80 @@ def unescape(s): fg="red", ) raise click.exceptions.Exit(1) + # utils.jdump(test_data, os.path.join(output_dir, output_file_test)) + with open(output_file_test, "w", encoding="utf-8") as outfile: + for entry in test_data: + json.dump(entry, outfile, ensure_ascii=False) + outfile.write("\n") + + +def _gen_train_data(machine_instruction_data, output_file_train): + train_data = [] + for synth_example in machine_instruction_data: + user = synth_example["instruction"] + if len(synth_example["input"]) > 0: + user += "\n" + synth_example["input"] + train_data.append( + { + "system": utils.get_sysprompt(), + "user": unescape(user), + "assistant": unescape(synth_example["output"]), + } + ) + # utils.jdump(train_data, output_file_train) + with open(output_file_train, "w", encoding="utf-8") as outfile: + for entry in train_data: + json.dump(entry, outfile, ensure_ascii=False) + outfile.write("\n") + + +def generate_data( + logger, + api_base, + tls_insecure, + model_family: str, + yaml_rules: Optional[str] = None, + output_dir: Optional[str] = None, + taxonomy: Optional[str] = None, + taxonomy_base: Optional[str] = None, + prompt_file_path: Optional[str] = None, + model_name: Optional[str] = None, + num_cpus: Optional[int] = None, + num_instructions_to_generate: Optional[int] = None, + num_prompt_instructions=2, + request_batch_size=5, + temperature=1.0, + top_p=1.0, + rouge_threshold: Optional[float] = None, + console_output=True, + api_key: Optional[str] = None, + chunk_word_count=None, + server_ctx_size=None, + tls_client_cert: Optional[str] = None, + tls_client_key: Optional[str] = None, + tls_client_passwd: Optional[str] = None, +): + seed_instruction_data = [] + machine_seed_instruction_data = [] + generate_start = time.time() + + if not os.path.exists(output_dir): + os.mkdir(output_dir) + + # check taxonomy first then seed_tasks_path + # throw an error if both not found + # pylint: disable=broad-exception-caught,raise-missing-from + if taxonomy and os.path.exists(taxonomy): + seed_instruction_data = read_taxonomy( + logger, taxonomy, taxonomy_base, yaml_rules + ) + else: + raise SystemExit(f"Error: taxonomy ({taxonomy}) does not exist.") + + prompt_template = check_prompt_file( + prompt_file_path, get_model_family(model_family, model_name) + ) + max_seed_tokens = max_seed_example_tokens(server_ctx_size, len(prompt_template)) name = Path(model_name).stem # Just in case it is a file path date_suffix = datetime.now().replace(microsecond=0).isoformat().replace(":", "_") @@ -466,6 +503,15 @@ def unescape(s): output_file_discarded = os.path.join( output_dir, f"discarded_{name}_{date_suffix}.log" ) + _gen_test_data( + logger, + seed_instruction_data, + max_seed_tokens, + taxonomy, + chunk_word_count, + server_ctx_size, + os.path.join(output_dir, output_file_test), + ) logger.debug(f"Generating to: {os.path.join(output_dir, output_file)}") request_idx = 0 @@ -580,32 +626,9 @@ def unescape(s): f"Generated {total} instructions(discarded {discarded}), rouged {total - keep}, kept {keep} instructions" ) utils.jdump(machine_instruction_data, os.path.join(output_dir, output_file)) - train_data = [] - for synth_example in machine_instruction_data: - user = synth_example["instruction"] - if len(synth_example["input"]) > 0: - user += "\n" + synth_example["input"] - train_data.append( - { - "system": utils.get_sysprompt(), - "user": unescape(user), - "assistant": unescape(synth_example["output"]), - } - ) - # utils.jdump(train_data, os.path.join(output_dir, output_file_train)) - with open( - os.path.join(output_dir, output_file_train), "w", encoding="utf-8" - ) as outfile: - for entry in train_data: - json.dump(entry, outfile, ensure_ascii=False) - outfile.write("\n") - # utils.jdump(test_data, os.path.join(output_dir, output_file_test)) - with open( - os.path.join(output_dir, output_file_test), "w", encoding="utf-8" - ) as outfile: - for entry in test_data: - json.dump(entry, outfile, ensure_ascii=False) - outfile.write("\n") + _gen_train_data( + machine_instruction_data, os.path.join(output_dir, output_file_train) + ) progress_bar.close()