Skip to content

Commit

Permalink
Use colortimelog in DiarizationLM
Browse files Browse the repository at this point in the history
  • Loading branch information
wq2012 committed Sep 20, 2024
1 parent 82b9bdf commit 9259a56
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 38 deletions.
1 change: 1 addition & 0 deletions DiarizationLM/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ absl-py
openai
datasets
tqdm
colortimelog
37 changes: 19 additions & 18 deletions DiarizationLM/unsloth/2_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os
import config
from unsloth import FastLanguageModel
import colortimelog


def export_models(
Expand All @@ -15,28 +16,28 @@ def export_models(
checkpoint_path = os.path.join(
config.MODEL_ID, f"checkpoint-{config.CHECKPOINT}"
)
print(f"Loading model from {checkpoint_path}...")
model, tokenizer = FastLanguageModel.from_pretrained(
model_name=checkpoint_path,
max_seq_length=config.MAX_SEQ_LENGTH,
dtype=None,
load_in_4bit=True,
)
with colortimelog.timeblock(f"Loading model from {checkpoint_path}..."):
model, tokenizer = FastLanguageModel.from_pretrained(
model_name=checkpoint_path,
max_seq_length=config.MAX_SEQ_LENGTH,
dtype=None,
load_in_4bit=True,
)

if save_lora:
print("Saving LoRA model...")
model.save_pretrained(
os.path.join(config.MODEL_ID, "lora_model")
) # Local saving
tokenizer.save_pretrained(os.path.join(config.MODEL_ID, "lora_model"))
with colortimelog.timeblock("Saving LoRA model..."):
model.save_pretrained(
os.path.join(config.MODEL_ID, "lora_model")
) # Local saving
tokenizer.save_pretrained(os.path.join(config.MODEL_ID, "lora_model"))

if save_16bit:
print("Saving 16bit model...")
model.save_pretrained_merged(
os.path.join(config.MODEL_ID, "model"),
tokenizer,
save_method="merged_16bit",
)
with colortimelog.timeblock("Saving 16bit model...")
model.save_pretrained_merged(
os.path.join(config.MODEL_ID, "model"),
tokenizer,
save_method="merged_16bit",
)


if __name__ == "__main__":
Expand Down
19 changes: 10 additions & 9 deletions DiarizationLM/unsloth/3_batch_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import os

import config
import colortimelog
import tqdm
from diarizationlm import utils
from unsloth import FastLanguageModel
Expand Down Expand Up @@ -61,12 +62,12 @@ def run_inference(input_file: str, output_dir: str) -> None:

if __name__ == "__main__":
for eval_dataset in config.EVAL_INPUTS:
print("Running inference on:", eval_dataset)
eval_input = config.EVAL_INPUTS[eval_dataset]
output_dir = os.path.join(
config.MODEL_ID,
"decoded",
f"checkpoint-{config.CHECKPOINT}",
eval_dataset,
)
run_inference(input_file=eval_input, output_dir=output_dir)
with colortimelog.timeblock("Running inference on: " + eval_dataset):
eval_input = config.EVAL_INPUTS[eval_dataset]
output_dir = os.path.join(
config.MODEL_ID,
"decoded",
f"checkpoint-{config.CHECKPOINT}",
eval_dataset,
)
run_inference(input_file=eval_input, output_dir=output_dir)
23 changes: 12 additions & 11 deletions DiarizationLM/unsloth/4_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import config

import colortimelog
from diarizationlm import utils
from diarizationlm import metrics

Expand Down Expand Up @@ -51,14 +52,14 @@ def evaluate(input_file: str, output_file: str) -> None:

if __name__ == "__main__":
for eval_dataset in config.EVAL_INPUTS:
print("Evaluating:", eval_dataset)
output_dir = os.path.join(config.MODEL_ID,
"decoded",
f"checkpoint-{config.CHECKPOINT}",
eval_dataset)
postprocess(
input_file=os.path.join(output_dir, "final.json"),
output_file=os.path.join(output_dir, "postprocessed.json"))
evaluate(
input_file=os.path.join(output_dir, "postprocessed.json"),
output_file=os.path.join(output_dir, "metrics.json"))
with colortimelog.timeblock("Evaluating: " + eval_dataset):
output_dir = os.path.join(config.MODEL_ID,
"decoded",
f"checkpoint-{config.CHECKPOINT}",
eval_dataset)
postprocess(
input_file=os.path.join(output_dir, "final.json"),
output_file=os.path.join(output_dir, "postprocessed.json"))
evaluate(
input_file=os.path.join(output_dir, "postprocessed.json"),
output_file=os.path.join(output_dir, "metrics.json"))

0 comments on commit 9259a56

Please sign in to comment.