Return evaluation results to callers (#71)
* Add EvalResults struct and return it

* Sync pydantic version w platform
tleyden authored Oct 26, 2023
1 parent f07a80d commit e6c3d29
Showing 5 changed files with 38 additions and 11 deletions.
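With this change, callers can consume the retriever metrics programmatically instead of parsing the logged output. A minimal sketch under assumptions: the dataset path below is hypothetical, dataset_or_path is taken to be a keyword parameter (it appears in the function body but not in the signature lines shown here), and other required arguments are omitted for brevity.

    from dalm.eval.eval_results import EvalResults
    from dalm.eval.eval_retriever_only import evaluate_retriever

    # Sketch only: arguments not visible in this diff are omitted.
    results: EvalResults = evaluate_retriever(
        dataset_or_path="eval_pairs.csv",  # hypothetical dataset path
        top_k=10,
    )
    print(results.precision, results.recall, results.hit_rate)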
13 changes: 9 additions & 4 deletions dalm/eval/eval_rag.py
@@ -11,7 +11,9 @@
PreTrainedTokenizer,
)

from dalm.eval.eval_results import EvalResults
from dalm.eval.utils import (
calc_eval_results,
construct_search_index,
evaluate_retriever_on_batch,
get_passage_embeddings,
@@ -180,7 +182,7 @@ def evaluate_rag(
top_k: int = 10,
evaluate_generator: bool = True,
retriever_is_autoregressive: bool = False,
) -> None:
) -> EvalResults:
"""Runs rag evaluation. See `dalm eval-rag --help for details on params"""
test_dataset = load_dataset(dataset_or_path)
selected_torch_dtype: Final[torch.dtype] = torch.float16 if torch_dtype == "float16" else torch.bfloat16
@@ -254,8 +256,9 @@ def evaluate_rag(
generated_answers_for_eval.extend(batch_answers)

if not evaluate_generator:
print_eval_results(len(processed_datasets), batch_precision, batch_recall, total_hit)
return
eval_results = calc_eval_results(len(processed_datasets), batch_precision, batch_recall, total_hit)
print_eval_results(eval_results)
return eval_results

# TODO: imperative style code, refactor in future but works for now
# If there are any leftover batches to query
@@ -275,9 +278,11 @@ def evaluate_rag(
if generated_answer_string == answer:
total_em_hit += 1

print_eval_results(len(processed_datasets), batch_precision, batch_recall, total_hit)
eval_results = calc_eval_results(len(processed_datasets), batch_precision, batch_recall, total_hit)
print_eval_results(eval_results)
print("Generator evaluation:")
print("Exact match:", total_em_hit / len(processed_datasets))
return eval_results


def main() -> None:
8 changes: 8 additions & 0 deletions dalm/eval/eval_results.py
@@ -0,0 +1,8 @@
from pydantic import BaseModel


class EvalResults(BaseModel):
total_examples: int
recall: float
precision: float
hit_rate: float
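EvalResults is a plain pydantic v1 model, so callers get field validation and straightforward serialization. An illustrative sketch with made-up numbers:

    from dalm.eval.eval_results import EvalResults

    results = EvalResults(total_examples=200, recall=0.91, precision=0.87, hit_rate=0.93)
    print(results.dict())  # pydantic v1 serialization, matching the 1.10.9 pin below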
9 changes: 7 additions & 2 deletions dalm/eval/eval_retriever_only.py
@@ -10,6 +10,8 @@

from datasets import Dataset
from torch.utils.data import DataLoader
from dalm.eval.eval_results import EvalResults


from dalm.eval.utils import (
construct_search_index,
@@ -18,6 +20,7 @@
get_passage_embeddings,
evaluate_retriever_on_batch,
print_eval_results,
calc_eval_results,
)
from dalm.models.retriever_only_base_model import AutoModelForSentenceEmbedding
from dalm.utils import load_dataset
@@ -112,7 +115,7 @@ def evaluate_retriever(
torch_dtype: Literal["float16", "bfloat16"] = "float16",
top_k: int = 10,
is_autoregressive: bool = False,
) -> None:
) -> EvalResults:
"""Runs rag evaluation. See `dalm eval-retriever --help for details on params"""
test_dataset = load_dataset(dataset_or_path)
selected_torch_dtype: Final[torch.dtype] = torch.float16 if torch_dtype == "float16" else torch.bfloat16
@@ -170,7 +173,9 @@ def evaluate_retriever(
batch_recall.extend(_batch_recall)
total_hit += _total_hit

print_eval_results(len(processed_datasets), batch_precision, batch_recall, total_hit)
eval_results = calc_eval_results(len(processed_datasets), batch_precision, batch_recall, total_hit)
print_eval_results(eval_results)
return eval_results


def main() -> None:
18 changes: 13 additions & 5 deletions dalm/eval/utils.py
@@ -10,6 +10,8 @@
from tqdm.auto import tqdm
from transformers import PreTrainedTokenizer, default_data_collator

from dalm.eval.eval_results import EvalResults

logger = logging.getLogger(__name__)


@@ -270,18 +272,24 @@ def evaluate_retriever_on_batch(
return batch_precision, batch_recall, total_hit, top_passages


def print_eval_results(
def calc_eval_results(
total_examples: int,
precisions: list[float],
recalls: list[float],
total_hit: int,
) -> None:
) -> EvalResults:
precision = sum(precisions) / total_examples
recall = sum(recalls) / total_examples
hit_rate = total_hit / float(total_examples)

return EvalResults(total_examples=total_examples, recall=recall, precision=precision, hit_rate=hit_rate)


def print_eval_results(
eval_results: EvalResults,
) -> None:
logger.info("Retriever results:")
logger.info(f"Recall: {recall}")
logger.info(f"Precision: {precision}")
logger.info(f"Hit Rate: {hit_rate}")
logger.info(f"Recall: {eval_results.recall}")
logger.info(f"Precision: {eval_results.precision}")
logger.info(f"Hit Rate: {eval_results.hit_rate}")
logger.info("*************")
1 change: 1 addition & 0 deletions pyproject.toml
@@ -27,6 +27,7 @@ dependencies = [
"diffusers",
"bitsandbytes",
"typer>=0.9.0,<1.0",
"pydantic==1.10.9", # Sync w/ other platform components
]

[project.scripts]
