diff --git a/langtest/datahandler/datasource.py b/langtest/datahandler/datasource.py
index 868a4152a..1d89303ae 100644
--- a/langtest/datahandler/datasource.py
+++ b/langtest/datahandler/datasource.py
@@ -957,6 +957,8 @@ def _import_data(self, file_name, **kwargs) -> List[Sample]:
                 import ast

                 i["transformations"] = ast.literal_eval(temp)
+            elif "transformations" in i:
+                i.pop("transformations")
             sample = self.task.get_sample_class(**i)
             samples.append(sample)
diff --git a/langtest/datahandler/format.py b/langtest/datahandler/format.py
index 0755108f0..621fe34e0 100644
--- a/langtest/datahandler/format.py
+++ b/langtest/datahandler/format.py
@@ -108,9 +108,18 @@ def to_csv(sample: SequenceClassificationSample) -> Tuple[str, str]:
             Tuple[str, str]: Row formatted as a list of strings.
         """
-        if sample.test_case:
-            return [sample.test_case, sample.expected_results.predictions[0].label]
-        return [sample.original, sample.expected_results.predictions[0].label]
+        predictions = sample.expected_results.predictions
+        multi_label = sample.expected_results.multi_label
+
+        if multi_label:
+            return [
+                sample.test_case or sample.original,
+                [elt.label for elt in predictions] if predictions else [],
+            ]
+        else:
+            if sample.test_case:
+                return [sample.test_case, predictions[0].label]
+            return [sample.original, predictions[0].label]


 class NEROutputFormatter(BaseFormatter):
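With the `to_csv` change above, single-label and multi-label samples now serialize to different row shapes: the multi-label branch writes the whole label list into the label cell. A minimal sketch of the two shapes, using stand-in dataclasses rather than langtest's own `SequenceLabel`/`SequenceClassificationOutput` (the `to_csv_row` helper is hypothetical; field names mirror the patch):

```python
from dataclasses import dataclass
from typing import List


@dataclass
class SequenceLabel:  # stand-in for langtest's SequenceLabel
    label: str
    score: float = 1.0


@dataclass
class Output:  # stand-in for SequenceClassificationOutput
    predictions: List[SequenceLabel]
    multi_label: bool = False


def to_csv_row(text: str, results: Output) -> list:
    """Hypothetical helper mirroring the patched formatter's branching."""
    predictions = results.predictions
    if results.multi_label:
        # multi-label rows carry the whole label list (possibly empty)
        return [text, [p.label for p in predictions] if predictions else []]
    # single-label rows keep the old shape: text plus one label
    return [text, predictions[0].label]


print(to_csv_row("great movie", Output([SequenceLabel("positive")])))
# -> ['great movie', 'positive']
print(to_csv_row("an action comedy",
                 Output([SequenceLabel("action"), SequenceLabel("comedy")],
                        multi_label=True)))
# -> ['an action comedy', ['action', 'comedy']]
```

Writing a raw list into a CSV cell stringifies it on disk, which is why `create_sample` in `langtest/tasks/task.py` (below) parses label cells back through `ast.literal_eval`.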
""" - prediction_metadata = self.model.fullAnnotate(text)[0][self.output_col][ - 0 - ].metadata - prediction = [{"label": x, "score": y} for x, y in prediction_metadata.items()] + prediction_metadata = self.model.fullAnnotate(text)[0][self.output_col] + prediction = [] + + if len(prediction_metadata) > 0: + prediction_metadata = prediction_metadata[0].metadata + prediction = [ + {"label": x, "score": y} for x, y in prediction_metadata.items() + ] + + if self.multi_label_classifier: + prediction = [x for x in prediction if float(x["score"]) > self.threshold] + + return SequenceClassificationOutput( + text=text, + predictions=prediction, + multi_label=self.multi_label_classifier, + ) if not return_all_scores: prediction = [max(prediction, key=lambda x: x["score"])] diff --git a/langtest/tasks/task.py b/langtest/tasks/task.py index 035725bb8..93af99114 100644 --- a/langtest/tasks/task.py +++ b/langtest/tasks/task.py @@ -1,3 +1,4 @@ +import ast import re from abc import ABC, abstractmethod from typing import Union @@ -267,17 +268,28 @@ def create_sample( row_data: dict, feature_column="text", target_column: Union[samples.SequenceLabel, str] = "label", + multi_label: bool = False, + *args, + **kwargs, ) -> samples.SequenceClassificationSample: """Create a sample.""" keys = list(row_data.keys()) # auto-detect the default column names from the row_data column_mapper = cls.column_mapping(keys, [feature_column, target_column]) + # is multi-label classification + # if "multi_label" in kwargs: + # multi_label = kwargs.get("multi_label", False) + # kwargs.pop("multi_label") + labels = row_data.get(column_mapper[target_column]) if isinstance(labels, samples.SequenceLabel): labels = [labels] - elif isinstance(labels, list): + elif isinstance(labels, list) or isinstance(labels, str): + labels = ast.literal_eval(labels) + if not isinstance(labels, list): + labels = [labels] labels = [ samples.SequenceLabel(label=label, score=1.0) if isinstance(label, str) @@ -289,7 +301,9 @@ def create_sample( return samples.SequenceClassificationSample( original=row_data[column_mapper[feature_column]], - expected_results=samples.SequenceClassificationOutput(predictions=labels), + expected_results=samples.SequenceClassificationOutput( + predictions=labels, multi_label=multi_label + ), ) diff --git a/langtest/utils/custom_types/output.py b/langtest/utils/custom_types/output.py index bcd1e4cf0..6961e4b0f 100644 --- a/langtest/utils/custom_types/output.py +++ b/langtest/utils/custom_types/output.py @@ -8,6 +8,7 @@ class SequenceClassificationOutput(BaseModel): """Output model for text classification tasks.""" predictions: List[SequenceLabel] + multi_label: bool = False def to_str_list(self) -> str: """Convert the output into list of strings. @@ -15,18 +16,27 @@ def to_str_list(self) -> str: Returns: List[str]: predictions in form of a list of strings. 
""" - return ",".join([x.label for x in self.predictions]) + return ", ".join([x.label for x in self.predictions]) - def __str__(self): + def __str__(self) -> str: """String representation""" labels = {elt.label: elt.score for elt in self.predictions} return f"SequenceClassificationOutput(predictions={labels})" - def __eq__(self, other): + def __eq__(self, other: "SequenceClassificationOutput") -> bool: """Equality comparison method.""" - top_class = max(self.predictions, key=lambda x: x.score).label - other_top_class = max(other.predictions, key=lambda x: x.score).label - return top_class == other_top_class + + if self.multi_label: + # get all labels + self_labels = {elt.label for elt in self.predictions} + other_labels = {elt.label for elt in other.predictions} + return set(self_labels) == set(other_labels) + elif len(self.predictions) == 0 and len(other.predictions) == 0: + return True + else: + top_class = max(self.predictions, key=lambda x: x.score).label + other_top_class = max(other.predictions, key=lambda x: x.score).label + return top_class == other_top_class class MinScoreOutput(BaseModel):