Skip to content

Commit

Permalink
* refutation scanner for the output
Browse files Browse the repository at this point in the history
  • Loading branch information
asofter committed Aug 12, 2023
1 parent 065d984 commit 070569b
Show file tree
Hide file tree
Showing 7 changed files with 154 additions and 4 deletions.
8 changes: 4 additions & 4 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [Unreleased]

### Added
-
- [Refutation output scanner](./docs/output_scanners/refutation.md)

### Fixed
-

### Changed
- [Anonymize prompt scanner] Using the transformer based Spacy model `en_core_web_trf` ([reference](https://microsoft.github.io/presidio/analyzer/nlp_engines/spacy_stanza/))
- [Anonymize prompt scanner] Supporting faker for applicable entities instead of placeholder (`use_faker` parameter)
- [Jailbreak prompt scanner] Updated dataset with more examples, removed duplicates
- **Anonymize prompt scanner**: Using the transformer based Spacy model `en_core_web_trf` ([reference](https://microsoft.github.io/presidio/analyzer/nlp_engines/spacy_stanza/))
- **Anonymize prompt scanner**: Supporting faker for applicable entities instead of placeholder (`use_faker` parameter)
- **Jailbreak prompt scanner**: Updated dataset with more examples, removed duplicates

### Removed
-
Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ python -m spacy download en_core_web_trf
- [Code](docs/output_scanners/code.md)
- [Deanonymize](docs/output_scanners/deanonymize.md)
- [NoRefusal](docs/output_scanners/no_refusal.md)
- [Refutation](docs/output_scanners/refutation.md)
- [Regex](docs/output_scanners/regex.md)
- [Relevance](docs/output_scanners/relevance.md)
- [Sensitive](docs/output_scanners/sensitive.md)
Expand Down
31 changes: 31 additions & 0 deletions docs/output_scanners/refutation.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# Refutation Scanner

This scanner is designed to assess if the given content contradicts or refutes a certain statement or prompt. It acts as
a tool for ensuring the consistency and correctness of language model outputs, especially in contexts where logical
contradictions can be problematic.

## Attack

When interacting with users or processing information, it's important for a language model to not provide outputs that
directly contradict the given inputs or established facts. Such contradictions can lead to confusion or misinformation.
The scanner aims to highlight such inconsistencies in the output.

## How it works

The scanner leverages pretrained natural language inference (NLI) models from HuggingFace, such
as [ynie/roberta-large-snli_mnli_fever_anli_R1_R2_R3-nli](https://huggingface.co/ynie/roberta-large-snli_mnli_fever_anli_R1_R2_R3-nli) (
and other variants), to determine the relationship between a given prompt and the generated output.

A high contradiction score indicates that the output refutes the prompt.

This calculated refutation score is then compared to a pre-set threshold. Outputs that cross this threshold are flagged
as contradictory.

## Usage

```python
from llm_guard.output_scanners import Refutation

scanner = Refutation(threshold=0.7)
sanitized_output, is_valid = scanner.scan(prompt, model_output)
```
2 changes: 2 additions & 0 deletions llm_guard/output_scanners/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from .code import Code
from .deanonymize import Deanonymize
from .no_refusal import NoRefusal
from .refutation import Refutation
from .regex import Regex
from .relevance import Relevance
from .sensitive import Sensitive
Expand All @@ -15,6 +16,7 @@
"Code",
"Deanonymize",
"NoRefusal",
"Refutation",
"Regex",
"Relevance",
"Sensitive",
Expand Down
80 changes: 80 additions & 0 deletions llm_guard/output_scanners/refutation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import logging

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from transformers import logging as transformers_logging

from .base import Scanner

_model_path = "ynie/roberta-large-snli_mnli_fever_anli_R1_R2_R3-nli"
# _model_path = "ynie/albert-xxlarge-v2-snli_mnli_fever_anli_R1_R2_R3-nli"
# _model_path = "ynie/bart-large-snli_mnli_fever_anli_R1_R2_R3-nli"
# _model_path = "ynie/electra-large-discriminator-snli_mnli_fever_anli_R1_R2_R3-nli"
# _model_path = "ynie/xlnet-large-cased-snli_mnli_fever_anli_R1_R2_R3-nli"

transformers_logging.set_verbosity_error()
log = logging.getLogger(__name__)

MAX_LENGTH = 256


class Refutation(Scanner):
"""
Refutation Class:
This class checks for refutation between a given prompt and output using a pretrained NLI model.
"""

def __init__(self, threshold=0.5):
"""
Initializes an instance of the Refutation class.
Parameters:
threshold (float): The threshold used to determine refutation. Defaults to 0.
"""

self._model = AutoModelForSequenceClassification.from_pretrained(_model_path)
self._model.eval()
self._tokenizer = AutoTokenizer.from_pretrained(_model_path)
self._threshold = threshold

def scan(self, prompt: str, output: str) -> (str, bool):
if prompt.strip() == "":
return output, True

tokenized_input_seq_pair = self._tokenizer.encode_plus(
prompt, output, max_length=MAX_LENGTH, return_token_type_ids=True, truncation=True
)

input_ids = torch.Tensor(tokenized_input_seq_pair["input_ids"]).long().unsqueeze(0)
# remember bart doesn't have 'token_type_ids', remove the line below if you are using bart.
token_type_ids = (
torch.Tensor(tokenized_input_seq_pair["token_type_ids"]).long().unsqueeze(0)
)
attention_mask = (
torch.Tensor(tokenized_input_seq_pair["attention_mask"]).long().unsqueeze(0)
)

outputs = self._model(
input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, labels=None
)

predicted_probability = torch.softmax(outputs[0], dim=1)[0].tolist() # batch_size only one

log.debug(
f"Detected entailment in the output with score: {predicted_probability[0]}, neutral: {predicted_probability[1]}, contradiction: {predicted_probability[2]}"
)

contradiction_score = predicted_probability[2]
if contradiction_score > self._threshold:
log.warning(
f"Detected refutation in the output with score: {contradiction_score}, threshold: {self._threshold}"
)

return output, False

log.debug(
f"Not refutation in the output. Max score: {contradiction_score}, threshold: {self._threshold}"
)

return output, True
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,6 @@ presidio-anonymizer==2.2.33
sentence-transformers==2.2.2
spacy==3.6.1
tiktoken==0.4.0
torch==2.0.1
transformers==4.30.0
xformers==0.0.20
35 changes: 35 additions & 0 deletions tests/output_scanners/test_refutation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import pytest

from llm_guard.output_scanners.refutation import Refutation


@pytest.mark.parametrize(
"prompt,output,expected_output,expected_valid",
[
(
"All humans are vegetarians.",
"Many people around the world eat meat and fish as part of their diet. Therefore, the statement that all "
"humans are vegetarians is incorrect.",
"Many people around the world eat meat and fish as part of their diet. Therefore, the statement that all "
"humans are vegetarians is incorrect.",
False,
), # Refutation
(
"Birds can fly",
"Birds cannot fly.",
"Birds cannot fly.",
False,
), # Refutation
(
"The sun rises in the east.",
"Every morning, people see the sun appear on the eastern horizon.",
"Every morning, people see the sun appear on the eastern horizon.",
True,
), # No Refutation
],
)
def test_scan(prompt, output, expected_output, expected_valid):
scanner = Refutation()
sanitized_output, valid = scanner.scan(prompt, output)
assert sanitized_output == expected_output
assert valid == expected_valid

0 comments on commit 070569b

Please sign in to comment.