diff --git a/dalm/eval/eval_rag.py b/dalm/eval/eval_rag.py
index 0b9b949..faf8285 100644
--- a/dalm/eval/eval_rag.py
+++ b/dalm/eval/eval_rag.py
@@ -11,7 +11,9 @@
     PreTrainedTokenizer,
 )
 
+from dalm.eval.eval_results import EvalResults
 from dalm.eval.utils import (
+    calc_eval_results,
     construct_search_index,
     evaluate_retriever_on_batch,
     get_passage_embeddings,
@@ -180,7 +182,7 @@ def evaluate_rag(
     top_k: int = 10,
     evaluate_generator: bool = True,
     retriever_is_autoregressive: bool = False,
-) -> None:
+) -> EvalResults:
     """Runs rag evaluation. See `dalm eval-rag --help for details on params"""
     test_dataset = load_dataset(dataset_or_path)
     selected_torch_dtype: Final[torch.dtype] = torch.float16 if torch_dtype == "float16" else torch.bfloat16
@@ -254,8 +256,9 @@ def evaluate_rag(
         generated_answers_for_eval.extend(batch_answers)
 
     if not evaluate_generator:
-        print_eval_results(len(processed_datasets), batch_precision, batch_recall, total_hit)
-        return
+        eval_results = calc_eval_results(len(processed_datasets), batch_precision, batch_recall, total_hit)
+        print_eval_results(eval_results)
+        return eval_results
 
     # TODO: imperative style code, refactor in future but works for now
     # If there are any leftover batches to query
@@ -275,9 +278,11 @@ def evaluate_rag(
             if generated_answer_string == answer:
                 total_em_hit += 1
 
-    print_eval_results(len(processed_datasets), batch_precision, batch_recall, total_hit)
+    eval_results = calc_eval_results(len(processed_datasets), batch_precision, batch_recall, total_hit)
+    print_eval_results(eval_results)
     print("Generator evaluation:")
     print("Exact match:", total_em_hit / len(processed_datasets))
+    return eval_results
 
 
 def main() -> None:
diff --git a/dalm/eval/eval_results.py b/dalm/eval/eval_results.py
new file mode 100644
index 0000000..05a0691
--- /dev/null
+++ b/dalm/eval/eval_results.py
@@ -0,0 +1,8 @@
+from pydantic import BaseModel
+
+
+class EvalResults(BaseModel):
+    total_examples: int
+    recall: float
+    precision: float
+    hit_rate: float
diff --git a/dalm/eval/eval_retriever_only.py b/dalm/eval/eval_retriever_only.py
index 6cbb151..bdd26e4 100644
--- a/dalm/eval/eval_retriever_only.py
+++ b/dalm/eval/eval_retriever_only.py
@@ -10,6 +10,8 @@
 from datasets import Dataset
 from torch.utils.data import DataLoader
 
+from dalm.eval.eval_results import EvalResults
+
 from dalm.eval.utils import (
 
     construct_search_index,
@@ -18,6 +20,7 @@
     get_passage_embeddings,
     evaluate_retriever_on_batch,
     print_eval_results,
+    calc_eval_results,
 )
 from dalm.models.retriever_only_base_model import AutoModelForSentenceEmbedding
 from dalm.utils import load_dataset
@@ -112,7 +115,7 @@ def evaluate_retriever(
     torch_dtype: Literal["float16", "bfloat16"] = "float16",
     top_k: int = 10,
     is_autoregressive: bool = False,
-) -> None:
+) -> EvalResults:
     """Runs rag evaluation. See `dalm eval-retriever --help for details on params"""
     test_dataset = load_dataset(dataset_or_path)
     selected_torch_dtype: Final[torch.dtype] = torch.float16 if torch_dtype == "float16" else torch.bfloat16
@@ -170,7 +173,9 @@ def evaluate_retriever(
         batch_recall.extend(_batch_recall)
         total_hit += _total_hit
 
-    print_eval_results(len(processed_datasets), batch_precision, batch_recall, total_hit)
+    eval_results = calc_eval_results(len(processed_datasets), batch_precision, batch_recall, total_hit)
+    print_eval_results(eval_results)
+    return eval_results
 
 
 def main() -> None:
diff --git a/dalm/eval/utils.py b/dalm/eval/utils.py
index 8a51aa3..3e7eff1 100644
--- a/dalm/eval/utils.py
+++ b/dalm/eval/utils.py
@@ -10,6 +10,8 @@
 from tqdm.auto import tqdm
 from transformers import PreTrainedTokenizer, default_data_collator
 
+from dalm.eval.eval_results import EvalResults
+
 logger = logging.getLogger(__name__)
 
 
@@ -270,18 +272,24 @@ def evaluate_retriever_on_batch(
     return batch_precision, batch_recall, total_hit, top_passages
 
 
-def print_eval_results(
+def calc_eval_results(
     total_examples: int,
     precisions: list[float],
     recalls: list[float],
     total_hit: int,
-) -> None:
+) -> EvalResults:
     precision = sum(precisions) / total_examples
     recall = sum(recalls) / total_examples
     hit_rate = total_hit / float(total_examples)
+    return EvalResults(total_examples=total_examples, recall=recall, precision=precision, hit_rate=hit_rate)
+
+
+def print_eval_results(
+    eval_results: EvalResults,
+) -> None:
     logger.info("Retriever results:")
-    logger.info(f"Recall: {recall}")
-    logger.info(f"Precision: {precision}")
-    logger.info(f"Hit Rate: {hit_rate}")
+    logger.info(f"Recall: {eval_results.recall}")
+    logger.info(f"Precision: {eval_results.precision}")
+    logger.info(f"Hit Rate: {eval_results.hit_rate}")
     logger.info("*************")
 
diff --git a/pyproject.toml b/pyproject.toml
index 56831bc..0fbf96b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -27,6 +27,7 @@ dependencies = [
     "diffusers",
     "bitsandbytes",
     "typer>=0.9.0,<1.0",
+    "pydantic==1.10.9", # Sync w/ other platform components
 ]
 
 [project.scripts]
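Usage sketch for the new return type: with this patch, evaluate_rag and evaluate_retriever return an EvalResults pydantic model instead of None, so callers can consume the metrics programmatically. The snippet below constructs the model directly with made-up numbers purely to show the fields and serialization; it assumes the pinned pydantic v1 API (1.10.9), where .json() is the serialization method.

    from dalm.eval.eval_results import EvalResults

    # Illustrative values only; in practice this object is returned by
    # evaluate_retriever(...) or evaluate_rag(...).
    results = EvalResults(total_examples=200, recall=0.85, precision=0.42, hit_rate=0.91)

    print(results.hit_rate)
    print(results.json())  # pydantic v1 serialization (pydantic v2 would use model_dump_json)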