
Commit

Merge remote-tracking branch 'origin/release/2.5.0' into chore/final_website_updates
chakravarthik27 committed Dec 9, 2024
2 parents 5324ccb + 0742714 commit b9fbae4
Showing 48 changed files with 3,533 additions and 52,169 deletions.
langtest/augmentation/utils.py (2 changes: 1 addition & 1 deletion)
@@ -2,7 +2,7 @@
from typing import List, TypedDict, Union
import os

-from pydantic import BaseModel, validator
+from pydantic.v1 import BaseModel, validator
from langtest.logger import logger

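Note: this one-line change is the Pydantic 2.x compatibility fix. Pydantic 2 still ships the legacy 1.x API under the pydantic.v1 namespace, so validator-based models keep working without a rewrite. A minimal sketch of the pattern this module relies on (the Entity model and its field are illustrative, not taken from the file):

from pydantic.v1 import BaseModel, validator

class Entity(BaseModel):
    name: str

    @validator("name")
    def name_must_not_be_blank(cls, value: str) -> str:
        # v1-style validators receive the raw field value and must return it
        if not value.strip():
            raise ValueError("name must not be blank")
        return value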
langtest/datahandler/datasource.py (202 changes: 200 additions & 2 deletions)
@@ -141,6 +141,10 @@ def __init_subclass__(cls, **kwargs):
        import pandas as pd

        dataset_cls = cls.__name__.replace("Dataset", "").lower()
+
+        if dataset_cls in ["deltalivetables"]:
+            dataset_cls = "delta_live_tables"
+
        if dataset_cls == "pandas":
            extensions = [
                i.replace("read_", "")
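This mapping exists because __init_subclass__ derives each registry key from the subclass name, so DeltaLiveTablesDataset would otherwise register as "deltalivetables" while users pass source: "delta_live_tables". A minimal sketch of the registration pattern, assuming data_sources is the class-level registry dict used elsewhere in this file:

class BaseDataset:
    data_sources = {}  # registry: source key -> dataset class

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        dataset_cls = cls.__name__.replace("Dataset", "").lower()
        # normalize multi-word class names to their snake_case source keys
        if dataset_cls in ["deltalivetables"]:
            dataset_cls = "delta_live_tables"
        BaseDataset.data_sources[dataset_cls] = cls

class DeltaLiveTablesDataset(BaseDataset):
    pass

assert BaseDataset.data_sources["delta_live_tables"] is DeltaLiveTablesDataset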
@@ -192,6 +196,7 @@ def __init__(self, file_path: Union[str, dict], task: TaskManager, **kwargs) -> None:
            raise ValueError(Errors.E025())
        self._custom_label = file_path.copy()
        self._file_path = file_path.get("data_source")
+        self.file_ext = file_path.get("source", None)
        self._size = None

        self.datasets_with_jsonl_extension = []
@@ -209,7 +214,7 @@ def __init__(self, file_path: Union[str, dict], task: TaskManager, **kwargs) -> None:
        if isinstance(self._file_path, str):
            _, self.file_ext = os.path.splitext(self._file_path)

-            if len(self.file_ext) > 0:
+            if len(self.file_ext) > 0 and "source" not in file_path:
                self.file_ext = self.file_ext.replace(".", "")
            elif "source" in file_path:
                self.file_ext = file_path["source"]
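With this change an explicit source key in the input dict now takes precedence over the file extension, so a path such as data.csv can still be routed to, say, the Spark loader. A hypothetical input illustrating the precedence:

file_path = {
    "data_source": "hdfs://cluster/datasets/reviews.csv",
    "source": "spark",  # wins over the ".csv" extension when dispatching a loader
}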
@@ -255,7 +260,7 @@ def load(self) -> List[Sample]:
            list[Sample]: Loaded text data.
        """

-        if self.file_ext in ("csv", "huggingface"):
+        if self.file_ext in ("csv", "huggingface", "spark"):
            self.init_cls = self.data_sources[self.file_ext.replace(".", "")](
                self._custom_label, task=self.task, **self.kwargs
            )
@@ -1890,3 +1895,196 @@ def renamed_extensions(self, inverted: bool = False) -> Dict[str, str]:
            "hdf5": "hdf",
        }
        return ext_map


class SparkDataset(BaseDataset):
    """Class to handle Spark datasets. Subclass of BaseDataset."""

    supported_tasks = [
        "ner",
        "text-classification",
        "question-answering",
        "summarization",
        "toxicity",
        "translation",
        "security",
        "clinical",
        "disinformation",
        "sensitivity",
        "wino-bias",
        "legal",
    ]

    def __init__(self, file_path: Union[str, dict], task: TaskManager, **kwargs) -> None:
        """
        Initializes a SparkDataset object.

        Args:
            file_path (str):
                The path to the data file.
            task (str):
                Task to be evaluated on.
            **kwargs:
        """
        from pyspark.sql import SparkSession
        import string
        import random

        super().__init__()
        self._file_path = file_path
        self.task = task

        if isinstance(file_path, dict):
            self.spark_session: SparkSession = file_path.get("spark_session", None)
        else:
            # no session supplied; one is created below
            self.spark_session = None
        self.format = kwargs.get("format", "csv")
        self.kwargs = kwargs

        if self.spark_session is None:
            random_str = "langtest_" + "".join(
                random.choices(string.ascii_lowercase, k=5)
            )
            self.spark_session = SparkSession.builder.appName(random_str).getOrCreate()
    def load_raw_data(self) -> List[Dict]:
        """
        Load data from a file into raw records.

        Returns:
            List[Dict]: the parsed file as a list of dict-like Row records.
        """
        df = self.spark_session.read.csv(self._file_path, header=True, inferSchema=True)
        data = df.collect()  # list of pyspark Row objects
        return data

    def load_data(self) -> List[Sample]:
        """
        Load data from any file and preprocess it based on the specified task.

        Returns:
            List[Sample]: A list of preprocessed data samples.
        """
        from pyspark.sql import DataFrame

        if isinstance(self._file_path.get("data_source", None), DataFrame):
            df = self._file_path.get("data_source", [])
            column_names = self._file_path.get("column_names", {})

        elif isinstance(self._file_path, dict):
            self.default_params = self._file_path
            self.format = self._file_path.get("format", "csv")
            self._file_path = self._file_path.get("data_source", self._file_path)

            # df = self.spark_session.read.csv(self._file_path, header=True, inferSchema=True)
            if hasattr(self.spark_session.read, self.format):
                df: DataFrame = getattr(self.spark_session.read, self.format)(
                    self._file_path, header=True, inferSchema=True
                )
            else:
                raise ValueError(
                    Errors.E027(format=self.format)
                    + f" for {self._file_path} is not supported."
                )

            column_names = self.default_params

        # remove the bookkeeping keys so only column mappings remain
        if isinstance(column_names, dict):
            column_names.pop("data_source", None)
            column_names.pop("source", None)
            column_names.pop("spark_session", None)
        else:
            column_names = dict()

        # generate the sample data
        data = []

        for idx, row_data in enumerate(df.toPandas().to_dict(orient="records")):
            try:
                sample = self.task.create_sample(
                    row_data,
                    **column_names,
                )
                data.append(sample)

            except Exception as e:
                logging.warning(Warnings.W005(idx=idx, row_data=row_data, e=e))
                continue

        return data

    def export_data(self, data: List[Sample], output_path: str):
        """Exports the data to the corresponding format and saves it to 'output_path'."""
        raise NotImplementedError()
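A usage sketch for the new class, assuming an existing TaskManager instance; the DataFrame, column names, and CSV path are illustrative:

from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("langtest_demo").getOrCreate()
df = spark.read.csv("reviews.csv", header=True, inferSchema=True)

dataset = SparkDataset(
    file_path={
        "data_source": df,  # a ready DataFrame takes the first branch of load_data
        "column_names": {"text": "review"},  # task-specific column mapping
        "spark_session": spark,
    },
    task=task,  # an existing TaskManager, e.g. for "text-classification"
)
samples = dataset.load_data()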


class DeltaLiveTablesDataset(BaseDataset):
    """A class to handle datasets from Delta Live Tables (DLT)."""

    supported_tasks = [
        "ner",
        "text-classification",
        "question-answering",
        "summarization",
        "toxicity",
        "translation",
        "security",
    ]

    def __init__(self, file_path: str, task: TaskManager, **kwargs) -> None:
        """
        Initializes a DeltaLiveTablesDataset object.

        Args:
            file_path (str):
                The path to the data file.
            task (str):
                Task to be evaluated on.
            **kwargs:
        """
        super().__init__()
        self._file_path = file_path
        self.task = task
        self.kwargs = kwargs

    def load_raw_data(self) -> List[Dict]:
        """
        Load data from a file into raw records.

        Returns:
            List[Dict]: the parsed file as a list of dicts.
        """
        raise NotImplementedError()

    def load_data(self) -> List[Sample]:
        """
        Load data from any file or DLT wrapper and preprocess it based on the specified task.

        Returns:
            List[Sample]: A list of preprocessed data samples.
        """
        from pyspark.sql import DataFrame

        if not isinstance(self._file_path, DataFrame):
            raise ValueError(
                "file_path should be a Spark DataFrame representing the DLT table"
            )

        df: DataFrame = self._file_path
        data = []

        for idx, row_data in enumerate(df.toPandas().to_dict(orient="records")):
            try:
                sample = self.task.create_sample(row_data)
                data.append(sample)
            except Exception as e:
                logging.warning(Warnings.W005(idx=idx, row_data=row_data, e=e))
                continue

        self.dataset_size = len(data)
        return data

    def export_data(self, data: List[Sample], output_path: str):
        """Exports the data to the corresponding format and saves it to 'output_path'."""
        raise NotImplementedError()
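A sketch of how a materialized DLT table might be handed to the class from a Databricks notebook, where spark.table(...) returns the table as a DataFrame; the table name and task are illustrative:

df = spark.table("live.cleaned_reviews")

dataset = DeltaLiveTablesDataset(file_path=df, task=task)
samples = dataset.load_data()
print(f"loaded {dataset.dataset_size} samples")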
langtest/langtest.py (44 changes: 34 additions & 10 deletions)
@@ -13,6 +13,8 @@

from pkg_resources import resource_filename

+from langtest.types import DatasetConfig, HarnessConfig, ModelConfig
+
from .tasks import TaskManager
from .augmentation import AugmentRobustness, TemplaticAugment
from .datahandler.datasource import DataFactory
@@ -90,20 +92,20 @@ class Harness:
    def __init__(
        self,
        task: Union[str, dict],
-        model: Optional[Union[list, dict]] = None,
-        data: Optional[Union[list, dict]] = None,
-        config: Optional[Union[str, dict]] = None,
+        model: Optional[Union[List[ModelConfig], ModelConfig]] = None,
+        data: Optional[Union[List[DatasetConfig], DatasetConfig]] = None,
+        config: Optional[Union[HarnessConfig, str]] = None,
        benchmarking: dict = None,
    ):
        """Initialize the Harness object.

        Args:
            task (str, optional): Task for which the model is to be evaluated.
-            model (list | dict, optional): Specifies the model to be evaluated.
+            model (ModelConfig, list | dict, optional): Specifies the model to be evaluated.
                If provided as a list, each element should be a dictionary with 'model' and 'hub' keys.
                If provided as a dictionary, it must contain 'model' and 'hub' keys when specifying a path.
-            data (dict, optional): The data to be used for evaluation.
-            config (str | dict, optional): Configuration for the tests to be performed.
+            data (DatasetConfig, dict, optional): The data to be used for evaluation.
+            config (str | HarnessConfig, optional): Configuration for the tests to be performed.

        Raises:
            ValueError: Invalid arguments.
@@ -156,11 +158,12 @@ def __init__(
            raise ValueError(Errors.E003())

        if isinstance(model, dict):
-            hub, model = model["hub"], model["model"]
+            hub, model, model_type = model["hub"], model["model"], model.get("type")
            self.hub = hub
            self._actual_model = model
        else:
            hub = None
+            model_type = None

        # loading task
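With this change a model dictionary may carry an optional type key, which is forwarded to task.model(...) as model_type. A hypothetical configuration (the model name and the value "chat" are assumptions, not taken from this diff):

harness = Harness(
    task="question-answering",
    model={"model": "gpt-4o-mini", "hub": "openai", "type": "chat"},  # "type" is optional
    data={"data_source": "BoolQ", "split": "test-tiny"},
)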

@@ -215,14 +218,14 @@ def __init__(
                hub = i["hub"]

                model_dict[model] = self.task.model(
-                    model, hub, **self._config.get("model_parameters", {})
+                    model, hub, model_type, **self._config.get("model_parameters", {})
                )

            self.model = model_dict

        else:
            self.model = self.task.model(
-                model, hub, **self._config.get("model_parameters", {})
+                model, hub, model_type, **self._config.get("model_parameters", {})
            )
        # end model selection
        formatted_config = json.dumps(self._config, indent=1)
@@ -248,7 +251,7 @@ def __repr__(self) -> str:
    def __str__(self) -> str:
        return object.__repr__(self)

-    def configure(self, config: Union[str, dict]) -> dict:
+    def configure(self, config: Union[HarnessConfig, dict, str]) -> HarnessConfig:
        """Configure the Harness with a given configuration.

        Args:
@@ -487,6 +490,15 @@ def report(
        if self._generated_results is None:
            raise RuntimeError(Errors.E011())

+        # plot the degradation analysis results
+        if (
+            "accuracy" in self._config["tests"]
+            and "degradation_analysis" in self._config["tests"]["accuracy"]
+        ):
+            from langtest.transform.accuracy import DegradationAnalysis
+
+            DegradationAnalysis.show_results()
+
        if isinstance(self._config, dict):
            self.default_min_pass_dict = self._config["tests"]["defaults"].get(
                "min_pass_rate", 0.65
@@ -730,6 +742,12 @@ def generated_results(self) -> Optional[pd.DataFrame]:
        columns = [c for c in column_order if c in generated_results_df.columns]
        generated_results_df = generated_results_df[columns]

+        if "degradation_analysis" in generated_results_df["test_type"].unique():
+            # drop the rows with test_type as 'degradation_analysis'
+            generated_results_df = generated_results_df[
+                generated_results_df["test_type"] != "degradation_analysis"
+            ]
+
        return generated_results_df.fillna("-")

    def augment(
@@ -980,6 +998,12 @@ def testcases(self, additional_cols=False) -> pd.DataFrame:
        columns = [c for c in column_order if c in testcases_df.columns]
        testcases_df = testcases_df[columns]

+        if "degradation_analysis" in testcases_df["test_type"].unique():
+            # drop the rows with test_type as 'degradation_analysis'
+            testcases_df = testcases_df[
+                testcases_df["test_type"] != "degradation_analysis"
+            ]
+
        return testcases_df.fillna("-")

    def save(self, save_dir: str, include_generated_results: bool = False) -> None: