From 551cc120ca39f3a14d3ef22592d233425c22ec06 Mon Sep 17 00:00:00 2001 From: Kalyan Chakravarthy Date: Fri, 20 Sep 2024 22:18:25 +0530 Subject: [PATCH] jailbreak and injection tests supports for text-classification. --- langtest/transform/safety.py | 10 +++++----- langtest/utils/custom_types/output.py | 4 ++++ 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/langtest/transform/safety.py b/langtest/transform/safety.py index 5eb028b66..08caecaa9 100644 --- a/langtest/transform/safety.py +++ b/langtest/transform/safety.py @@ -178,7 +178,7 @@ def transform(self, count: int = 50, *args, **kwargs) -> List[Sample]: class InjectionProbalities(BaseSafetyTest): alias_name = "injection_probalities_score" - supported_tasks = ["question-answering"] + supported_tasks = ["text-classification", "question-answering"] """ Injection Probabilities Score test. """ @@ -209,10 +209,10 @@ async def run(self, sample_list: List[Sample], *args, **kwargs) -> List[Sample]: progress = kwargs.get("progress_bar", False) for sample in sample_list: - if isinstance(sample, samples.QASample): + if isinstance(sample, samples.BaseQASample): text = sample.get_prompt() - elif isinstance(sample, samples.NERSample): - text = sample + sample.original + elif isinstance(sample, samples.BaseSample): + text = sample.original result = prompt_guard.get_indirect_injection_score(text) @@ -227,7 +227,7 @@ async def run(self, sample_list: List[Sample], *args, **kwargs) -> List[Sample]: class JailBreakProbalities(BaseSafetyTest): alias_name = "jailbreak_probalities_score" - supported_tasks = ["question-answering"] + supported_tasks = ["text-classification", "question-answering"] """ Jailbreak Probabilities test. """ diff --git a/langtest/utils/custom_types/output.py b/langtest/utils/custom_types/output.py index da3c0d5f9..619a71fcb 100644 --- a/langtest/utils/custom_types/output.py +++ b/langtest/utils/custom_types/output.py @@ -80,6 +80,10 @@ def __str__(self) -> str: """String representation""" return f"{self.max_score:.3f}" + def __eq__(self, other: "MaxScoreOutput") -> bool: + """Greater than comparison method.""" + return self.max_score >= other.max_score + def __ge__(self, other: "MaxScoreOutput") -> bool: """Greater than comparison method.""" return self.max_score >= other.max_score