From dc408e88a174e0126fa1a08be04a82ac4f516cb7 Mon Sep 17 00:00:00 2001 From: Kalyan Chakravarthy Date: Wed, 9 Oct 2024 14:50:42 +0530 Subject: [PATCH] updated the docs. --- langtest/transform/accuracy.py | 11 +++++++++-- langtest/utils/util_metrics.py | 16 +++++++++++++--- 2 files changed, 22 insertions(+), 5 deletions(-) diff --git a/langtest/transform/accuracy.py b/langtest/transform/accuracy.py index b456b40f3..7563d3999 100644 --- a/langtest/transform/accuracy.py +++ b/langtest/transform/accuracy.py @@ -2,7 +2,7 @@ from collections import defaultdict import pandas as pd from abc import ABC, abstractmethod -from typing import Any, DefaultDict, Dict, List, Type +from typing import Any, DefaultDict, Dict, List, Type, Union from langtest.modelhandler.modelhandler import ModelAPI from langtest.transform.base import ITests @@ -1213,9 +1213,16 @@ async def run( return [] @staticmethod - def preprocess(y_true, y_pred): + def preprocess(y_true: List, y_pred: List): """ Preprocesses the input data for the degradation analysis. + + Args: + y_true (List): The true labels. + y_pred (List): The predicted labels. + + Returns: + Tuple[pd.Series, pd.Series]: The preprocessed true and predicted labels. """ if isinstance(y_true, list): diff --git a/langtest/utils/util_metrics.py b/langtest/utils/util_metrics.py index e55d038ab..0dfb3ef9a 100644 --- a/langtest/utils/util_metrics.py +++ b/langtest/utils/util_metrics.py @@ -407,9 +407,19 @@ def calculate_f1_score_multi_label( def combine_labels(labels: List[str]) -> List[str]: """ - Combines labels for degradation analysis. - input labels: ["B-ORG", "I-ORG", "B-PER", "I-PER"] - output labels: ["ORG", "PER"] + Combine labels by removing the BIO tags and keeping only the entity type. + + Args: + labels (List[str]): List of strings or a string. Labels can be in the format of BIO tags. + Example: ["B-ORG", "I-ORG", "B-PER", "I-PER"] + + Raises: + ValueError: If the input is not a list or a string. + + + Returns: + labels (List[str]): List of entity types without the BIO tags. + """ try: output_list = []