Merge pull request #1096 from JohnSnowLabs/feature/add-support-for-the-multi-label-classification-model

Feature/add support for the multi label classification model
chakravarthik27 authored Sep 3, 2024
2 parents cc821c9 + 258a0f7 commit 23eb0c3
Showing 5 changed files with 70 additions and 15 deletions.
2 changes: 2 additions & 0 deletions langtest/datahandler/datasource.py
@@ -957,6 +957,8 @@ def _import_data(self, file_name, **kwargs) -> List[Sample]:
            import ast

            i["transformations"] = ast.literal_eval(temp)
+         elif "transformations" in i:
+             i.pop("transformations")
        sample = self.task.get_sample_class(**i)
        samples.append(sample)
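For context on what this hunk handles: CSV loaders hand every cell back as a string, so a column holding a Python list (like "transformations") arrives stringified and must be rebuilt with ast.literal_eval before the Sample constructor sees it; anything else is now dropped rather than forwarded. A minimal, self-contained sketch of that behavior (the row dict and the surrounding condition are hypothetical, not the exact logic in _import_data):

import ast

row = {"text": "hello world", "transformations": "['lowercase', 'add_typo']"}

temp = row.get("transformations")
if isinstance(temp, str) and temp.startswith("["):
    # "['lowercase', 'add_typo']" -> ['lowercase', 'add_typo']
    row["transformations"] = ast.literal_eval(temp)
elif "transformations" in row:
    # metadata that cannot be parsed is removed rather than forwarded
    row.pop("transformations")

print(row)  # {'text': 'hello world', 'transformations': ['lowercase', 'add_typo']}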
15 changes: 12 additions & 3 deletions langtest/datahandler/format.py
@@ -108,9 +108,18 @@ def to_csv(sample: SequenceClassificationSample) -> Tuple[str, str]:
            Tuple[str, str]:
                Row formatted as a list of strings.
        """
-         if sample.test_case:
-             return [sample.test_case, sample.expected_results.predictions[0].label]
-         return [sample.original, sample.expected_results.predictions[0].label]
+         predictions = sample.expected_results.predictions
+         multi_label = sample.expected_results.multi_label
+
+         if multi_label:
+             return [
+                 sample.test_case or sample.original,
+                 [elt.label for elt in predictions] if predictions else [],
+             ]
+         else:
+             if sample.test_case:
+                 return [sample.test_case, sample.expected_results.predictions[0].label]
+             return [sample.original, sample.expected_results.predictions[0].label]


class NEROutputFormatter(BaseFormatter):
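The reshaped to_csv above emits two different row shapes: single-label rows carry one label string, while multi-label rows carry a list of label strings (empty when there are no predictions). A hedged sketch of both shapes, using plain dataclasses as stand-ins for the langtest sample types:

from dataclasses import dataclass, field
from typing import List, Optional

@dataclass
class Label:
    label: str

@dataclass
class Results:
    predictions: List[Label] = field(default_factory=list)
    multi_label: bool = False

@dataclass
class Sample:
    original: str
    test_case: Optional[str] = None
    expected_results: Results = field(default_factory=Results)

def to_csv(sample: Sample) -> list:
    predictions = sample.expected_results.predictions
    if sample.expected_results.multi_label:
        return [
            sample.test_case or sample.original,
            [elt.label for elt in predictions] if predictions else [],
        ]
    text = sample.test_case or sample.original
    return [text, predictions[0].label]

single = Sample("great movie", expected_results=Results([Label("POS")]))
multi = Sample(
    "match report", expected_results=Results([Label("sports"), Label("news")], multi_label=True)
)
print(to_csv(single))  # ['great movie', 'POS']
print(to_csv(multi))   # ['match report', ['sports', 'news']]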
28 changes: 24 additions & 4 deletions langtest/modelhandler/jsl_modelhandler.py
@@ -42,6 +42,7 @@
    XlmRoBertaForSequenceClassification,
    XlnetForSequenceClassification,
    MarianTransformer,
+     MultiClassifierDLModel,
)
from sparknlp.base import LightPipeline
from sparknlp.pretrained import PretrainedPipeline
@@ -63,6 +64,7 @@

SUPPORTED_SPARKNLP_CLASSIFERS.extend(
    [
+         MultiClassifierDLModel,
        ClassifierDLModel,
        SentimentDLModel,
        AlbertForSequenceClassification,
@@ -409,6 +411,7 @@ def __init__(
        super().__init__(model)

        _classifier = None
+         self.multi_label_classifier = False
        for annotator in self.model.stages:
            if self.is_classifier(annotator):
                _classifier = annotator
@@ -417,6 +420,10 @@
        if _classifier is None:
            raise ValueError(Errors.E040(var="classifier"))

+         if isinstance(_classifier, MultiClassifierDLModel):
+             self.multi_label_classifier = True
+             self.threshold = _classifier.getThreshold()
+
        self.output_col = _classifier.getOutputCol()
        self.classes = _classifier.getClasses()
        self.model = LightPipeline(self.model)
@@ -442,10 +449,23 @@ def predict(
        Returns:
            SequenceClassificationOutput: Classification output from SparkNLP LightPipeline.
        """
-         prediction_metadata = self.model.fullAnnotate(text)[0][self.output_col][
-             0
-         ].metadata
-         prediction = [{"label": x, "score": y} for x, y in prediction_metadata.items()]
+         prediction_metadata = self.model.fullAnnotate(text)[0][self.output_col]
+         prediction = []
+
+         if len(prediction_metadata) > 0:
+             prediction_metadata = prediction_metadata[0].metadata
+             prediction = [
+                 {"label": x, "score": y} for x, y in prediction_metadata.items()
+             ]
+
+         if self.multi_label_classifier:
+             prediction = [x for x in prediction if float(x["score"]) > self.threshold]
+
+             return SequenceClassificationOutput(
+                 text=text,
+                 predictions=prediction,
+                 multi_label=self.multi_label_classifier,
+             )

        if not return_all_scores:
            prediction = [max(prediction, key=lambda x: x["score"])]
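The multi-label branch in predict keeps every class whose score clears the model's threshold instead of taking the arg-max, then returns early with multi_label=True. A self-contained sketch of that filtering (the metadata dict is a hypothetical stand-in for what LightPipeline.fullAnnotate returns; Spark NLP serializes scores as strings, hence the float() casts):

# class -> score mapping, as it would appear in annotation metadata
prediction_metadata = {"sports": "0.91", "politics": "0.12", "news": "0.55"}
threshold = 0.5  # what _classifier.getThreshold() might return

prediction = [{"label": x, "score": y} for x, y in prediction_metadata.items()]

# Multi-label path: keep everything above the threshold.
multi = [x for x in prediction if float(x["score"]) > threshold]
print(sorted(x["label"] for x in multi))  # ['news', 'sports']

# Single-label path: arg-max over the scores.
top = max(prediction, key=lambda x: float(x["score"]))
print(top["label"])  # 'sports'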
18 changes: 16 additions & 2 deletions langtest/tasks/task.py
@@ -1,3 +1,4 @@
+ import ast
import re
from abc import ABC, abstractmethod
from typing import Union
@@ -267,17 +268,28 @@ def create_sample(
        row_data: dict,
        feature_column="text",
        target_column: Union[samples.SequenceLabel, str] = "label",
+         multi_label: bool = False,
        *args,
        **kwargs,
    ) -> samples.SequenceClassificationSample:
        """Create a sample."""
        keys = list(row_data.keys())
        # auto-detect the default column names from the row_data
        column_mapper = cls.column_mapping(keys, [feature_column, target_column])
+
+         # is multi-label classification
+         # if "multi_label" in kwargs:
+         #     multi_label = kwargs.get("multi_label", False)
+         #     kwargs.pop("multi_label")
+
        labels = row_data.get(column_mapper[target_column])

        if isinstance(labels, samples.SequenceLabel):
            labels = [labels]
-         elif isinstance(labels, list):
+         elif isinstance(labels, list) or isinstance(labels, str):
+             labels = ast.literal_eval(labels)
+             if not isinstance(labels, list):
+                 labels = [labels]
            labels = [
                samples.SequenceLabel(label=label, score=1.0)
                if isinstance(label, str)
@@ -289,7 +301,9 @@

        return samples.SequenceClassificationSample(
            original=row_data[column_mapper[feature_column]],
-             expected_results=samples.SequenceClassificationOutput(predictions=labels),
+             expected_results=samples.SequenceClassificationOutput(
+                 predictions=labels, multi_label=multi_label
+             ),
        )


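The widened elif in create_sample routes stringified label lists through ast.literal_eval so a CSV cell like "['sports', 'news']" becomes a real list before being wrapped in SequenceLabel objects. Note that ast.literal_eval only accepts strings and raises ValueError on a bare word like "sports", so this sketch adds a type guard and a fallback that the diff above does not spell out, purely for illustration:

import ast

def normalize_labels(labels):
    if isinstance(labels, str):
        try:
            labels = ast.literal_eval(labels)  # "['sports', 'news']" -> list
        except (ValueError, SyntaxError):
            pass  # a bare label like "sports" stays a plain string
    if not isinstance(labels, list):
        labels = [labels]
    return labels

print(normalize_labels("['sports', 'news']"))  # ['sports', 'news']
print(normalize_labels("sports"))              # ['sports']
print(normalize_labels(["news"]))              # ['news']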
22 changes: 16 additions & 6 deletions langtest/utils/custom_types/output.py
@@ -8,25 +8,35 @@ class SequenceClassificationOutput(BaseModel):
"""Output model for text classification tasks."""

predictions: List[SequenceLabel]
multi_label: bool = False

def to_str_list(self) -> str:
"""Convert the output into list of strings.
Returns:
List[str]: predictions in form of a list of strings.
"""
return ",".join([x.label for x in self.predictions])
return ", ".join([x.label for x in self.predictions])

-     def __str__(self):
+     def __str__(self) -> str:
"""String representation"""
labels = {elt.label: elt.score for elt in self.predictions}
return f"SequenceClassificationOutput(predictions={labels})"

-     def __eq__(self, other):
+     def __eq__(self, other: "SequenceClassificationOutput") -> bool:
        """Equality comparison method."""
-         top_class = max(self.predictions, key=lambda x: x.score).label
-         other_top_class = max(other.predictions, key=lambda x: x.score).label
-         return top_class == other_top_class
+
+         if self.multi_label:
+             # get all labels
+             self_labels = {elt.label for elt in self.predictions}
+             other_labels = {elt.label for elt in other.predictions}
+             return set(self_labels) == set(other_labels)
+         elif len(self.predictions) == 0 and len(other.predictions) == 0:
+             return True
+         else:
+             top_class = max(self.predictions, key=lambda x: x.score).label
+             other_top_class = max(other.predictions, key=lambda x: x.score).label
+             return top_class == other_top_class


class MinScoreOutput(BaseModel):
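With multi_label set, equality on SequenceClassificationOutput compares label sets, ignoring order and scores; single-label outputs still compare by top class, and two empty outputs now compare equal. A runnable sketch of those semantics, using a plain dataclass in place of the pydantic model:

from dataclasses import dataclass, field
from typing import List

@dataclass
class Label:
    label: str
    score: float = 1.0

@dataclass
class Output:
    predictions: List[Label] = field(default_factory=list)
    multi_label: bool = False

    # dataclasses keep a user-defined __eq__ instead of generating one
    def __eq__(self, other: "Output") -> bool:
        if self.multi_label:
            # order and scores are ignored; only the label sets must match
            return {p.label for p in self.predictions} == {p.label for p in other.predictions}
        if not self.predictions and not other.predictions:
            return True
        top = max(self.predictions, key=lambda p: p.score).label
        other_top = max(other.predictions, key=lambda p: p.score).label
        return top == other_top

a = Output([Label("news", 0.4), Label("sports", 0.9)], multi_label=True)
b = Output([Label("sports", 0.2), Label("news", 0.7)], multi_label=True)
print(a == b)  # True: same label set despite different order and scores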
