Skip to content

Commit

Permalink
Add jailbreak and injection test support for text-classification.
Browse files Browse the repository at this point in the history
  • Loading branch information
chakravarthik27 committed Sep 20, 2024
1 parent 90e902f commit 551cc12
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 5 deletions.
10 changes: 5 additions & 5 deletions langtest/transform/safety.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def transform(self, count: int = 50, *args, **kwargs) -> List[Sample]:

class InjectionProbalities(BaseSafetyTest):
alias_name = "injection_probalities_score"
supported_tasks = ["question-answering"]
supported_tasks = ["text-classification", "question-answering"]
""" Injection Probabilities Score test.
"""

Expand Down Expand Up @@ -209,10 +209,10 @@ async def run(self, sample_list: List[Sample], *args, **kwargs) -> List[Sample]:
progress = kwargs.get("progress_bar", False)

for sample in sample_list:
if isinstance(sample, samples.QASample):
if isinstance(sample, samples.BaseQASample):
text = sample.get_prompt()
elif isinstance(sample, samples.NERSample):
text = sample + sample.original
elif isinstance(sample, samples.BaseSample):
text = sample.original

result = prompt_guard.get_indirect_injection_score(text)

Expand All @@ -227,7 +227,7 @@ async def run(self, sample_list: List[Sample], *args, **kwargs) -> List[Sample]:

class JailBreakProbalities(BaseSafetyTest):
alias_name = "jailbreak_probalities_score"
supported_tasks = ["question-answering"]
supported_tasks = ["text-classification", "question-answering"]
""" Jailbreak Probabilities test.
"""

Expand Down
4 changes: 4 additions & 0 deletions langtest/utils/custom_types/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,10 @@ def __str__(self) -> str:
"""String representation"""
return f"{self.max_score:.3f}"

def __eq__(self, other: "MaxScoreOutput") -> bool:
    """Return whether this score is at least as large as ``other``'s score.

    Args:
        other: Object exposing a ``max_score`` attribute to compare against.

    Returns:
        ``True`` when ``self.max_score >= other.max_score`` and ``False``
        otherwise. ``NotImplemented`` is returned when ``other`` has no
        ``max_score`` attribute, so Python falls back to its default
        comparison instead of raising ``AttributeError``.
    """
    # NOTE(review): ``==`` is implemented with ``>=`` here, so the relation is
    # not symmetric (``a == b`` can differ from ``b == a``). The operator is
    # kept as-is because callers may depend on it -- confirm before changing.
    try:
        other_score = other.max_score
    except AttributeError:
        # Unsupported operand (e.g. ``None``): defer to Python's fallback.
        return NotImplemented
    return self.max_score >= other_score

def __ge__(self, other: "MaxScoreOutput") -> bool:
    """Return ``True`` when this score is greater than or equal to ``other``'s."""
    own_score = self.max_score
    other_score = other.max_score
    return own_score >= other_score
Expand Down

0 comments on commit 551cc12

Please sign in to comment.