Wrap the progress-logged steps of the DiarizationLM unsloth scripts in
colortimelog.timeblock context managers.

Each print() progress message in 2_export.py, 3_batch_inference.py and
4_eval.py is replaced by a `with colortimelog.timeblock(<same message>):`
block around the statements that followed it, and colortimelog is added to
requirements.txt.

NOTE(review): this patch was recovered from a whitespace-collapsed copy.
Indentation and blank context lines were reconstructed to agree with the
hunk line counts; verify them against the target files before applying.
The `index` lines are carried over unchanged and may not match the edited
post-image.

diff --git a/DiarizationLM/requirements.txt b/DiarizationLM/requirements.txt
index c8dcf6f..e3d4815 100644
--- a/DiarizationLM/requirements.txt
+++ b/DiarizationLM/requirements.txt
@@ -5,3 +5,4 @@ absl-py
 openai
 datasets
 tqdm
+colortimelog
diff --git a/DiarizationLM/unsloth/2_export.py b/DiarizationLM/unsloth/2_export.py
index 4a7601d..c661f79 100644
--- a/DiarizationLM/unsloth/2_export.py
+++ b/DiarizationLM/unsloth/2_export.py
@@ -3,6 +3,7 @@ import os
 
 import config
 from unsloth import FastLanguageModel
+import colortimelog
 
 
 def export_models(
@@ -15,28 +16,28 @@ def export_models(
   checkpoint_path = os.path.join(
       config.MODEL_ID, f"checkpoint-{config.CHECKPOINT}"
   )
-  print(f"Loading model from {checkpoint_path}...")
-  model, tokenizer = FastLanguageModel.from_pretrained(
-      model_name=checkpoint_path,
-      max_seq_length=config.MAX_SEQ_LENGTH,
-      dtype=None,
-      load_in_4bit=True,
-  )
+  with colortimelog.timeblock(f"Loading model from {checkpoint_path}..."):
+    model, tokenizer = FastLanguageModel.from_pretrained(
+        model_name=checkpoint_path,
+        max_seq_length=config.MAX_SEQ_LENGTH,
+        dtype=None,
+        load_in_4bit=True,
+    )
 
   if save_lora:
-    print("Saving LoRA model...")
-    model.save_pretrained(
-        os.path.join(config.MODEL_ID, "lora_model")
-    )  # Local saving
-    tokenizer.save_pretrained(os.path.join(config.MODEL_ID, "lora_model"))
+    with colortimelog.timeblock("Saving LoRA model..."):
+      model.save_pretrained(
+          os.path.join(config.MODEL_ID, "lora_model")
+      )  # Local saving
+      tokenizer.save_pretrained(os.path.join(config.MODEL_ID, "lora_model"))
 
   if save_16bit:
-    print("Saving 16bit model...")
-    model.save_pretrained_merged(
-        os.path.join(config.MODEL_ID, "model"),
-        tokenizer,
-        save_method="merged_16bit",
-    )
+    with colortimelog.timeblock("Saving 16bit model..."):
+      model.save_pretrained_merged(
+          os.path.join(config.MODEL_ID, "model"),
+          tokenizer,
+          save_method="merged_16bit",
+      )
 
 
 if __name__ == "__main__":
diff --git a/DiarizationLM/unsloth/3_batch_inference.py b/DiarizationLM/unsloth/3_batch_inference.py
index 89f1c66..f0120c6 100644
--- a/DiarizationLM/unsloth/3_batch_inference.py
+++ b/DiarizationLM/unsloth/3_batch_inference.py
@@ -4,6 +4,7 @@
 import os
 
 import config
+import colortimelog
 import tqdm
 from diarizationlm import utils
 from unsloth import FastLanguageModel
@@ -61,12 +62,12 @@ def run_inference(input_file: str, output_dir: str) -> None:
 
 if __name__ == "__main__":
   for eval_dataset in config.EVAL_INPUTS:
-    print("Running inference on:", eval_dataset)
-    eval_input = config.EVAL_INPUTS[eval_dataset]
-    output_dir = os.path.join(
-        config.MODEL_ID,
-        "decoded",
-        f"checkpoint-{config.CHECKPOINT}",
-        eval_dataset,
-    )
-    run_inference(input_file=eval_input, output_dir=output_dir)
+    with colortimelog.timeblock(f"Running inference on: {eval_dataset}"):
+      eval_input = config.EVAL_INPUTS[eval_dataset]
+      output_dir = os.path.join(
+          config.MODEL_ID,
+          "decoded",
+          f"checkpoint-{config.CHECKPOINT}",
+          eval_dataset,
+      )
+      run_inference(input_file=eval_input, output_dir=output_dir)
diff --git a/DiarizationLM/unsloth/4_eval.py b/DiarizationLM/unsloth/4_eval.py
index e811c5d..e357aa4 100644
--- a/DiarizationLM/unsloth/4_eval.py
+++ b/DiarizationLM/unsloth/4_eval.py
@@ -5,6 +5,7 @@
 
 
 import config
+import colortimelog
 from diarizationlm import utils
 from diarizationlm import metrics
 
@@ -51,14 +52,14 @@ def evaluate(input_file: str, output_file: str) -> None:
 
 if __name__ == "__main__":
   for eval_dataset in config.EVAL_INPUTS:
-    print("Evaluating:", eval_dataset)
-    output_dir = os.path.join(config.MODEL_ID,
-                              "decoded",
-                              f"checkpoint-{config.CHECKPOINT}",
-                              eval_dataset)
-    postprocess(
-        input_file=os.path.join(output_dir, "final.json"),
-        output_file=os.path.join(output_dir, "postprocessed.json"))
-    evaluate(
-        input_file=os.path.join(output_dir, "postprocessed.json"),
-        output_file=os.path.join(output_dir, "metrics.json"))
+    with colortimelog.timeblock(f"Evaluating: {eval_dataset}"):
+      output_dir = os.path.join(config.MODEL_ID,
+                                "decoded",
+                                f"checkpoint-{config.CHECKPOINT}",
+                                eval_dataset)
+      postprocess(
+          input_file=os.path.join(output_dir, "final.json"),
+          output_file=os.path.join(output_dir, "postprocessed.json"))
+      evaluate(
+          input_file=os.path.join(output_dir, "postprocessed.json"),
+          output_file=os.path.join(output_dir, "metrics.json"))