-
Notifications
You must be signed in to change notification settings - Fork 149
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
7 changed files
with
154 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
# Refutation Scanner | ||
|
||
This scanner assesses whether a model's output contradicts or refutes the prompt it was given. It is a tool for
checking the consistency and correctness of language model outputs, especially in contexts where logical
contradictions can be problematic.
|
||
## Attack | ||
|
||
When interacting with users or processing information, it's important for a language model to not provide outputs that | ||
directly contradict the given inputs or established facts. Such contradictions can lead to confusion or misinformation. | ||
The scanner aims to highlight such inconsistencies in the output. | ||
|
||
## How it works | ||
|
||
The scanner uses a pretrained natural language inference (NLI) model from HuggingFace, by default
[ynie/roberta-large-snli_mnli_fever_anli_R1_R2_R3-nli](https://huggingface.co/ynie/roberta-large-snli_mnli_fever_anli_R1_R2_R3-nli)
(other variants from the same family can be substituted), to determine the relationship between a given prompt and the generated output.
|
||
A high contradiction score indicates that the output refutes the prompt. | ||
|
||
The contradiction score is then compared to a pre-set threshold (0.5 by default). Outputs whose score is strictly
greater than the threshold are flagged as contradictory, i.e. the scanner returns `is_valid = False`.
|
||
## Usage | ||
|
||
```python | ||
from llm_guard.output_scanners import Refutation | ||
|
||
scanner = Refutation(threshold=0.7) | ||
sanitized_output, is_valid = scanner.scan(prompt, model_output) | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
import logging | ||
|
||
import torch | ||
from transformers import AutoModelForSequenceClassification, AutoTokenizer | ||
from transformers import logging as transformers_logging | ||
|
||
from .base import Scanner | ||
|
||
# Hugging Face model id of the NLI checkpoint; passed to `from_pretrained` for
# both the model and the tokenizer when a scanner instance is created.
_model_path = "ynie/roberta-large-snli_mnli_fever_anli_R1_R2_R3-nli"
# Alternative checkpoints from the same family; uncomment one to swap the model.
# _model_path = "ynie/albert-xxlarge-v2-snli_mnli_fever_anli_R1_R2_R3-nli"
# _model_path = "ynie/bart-large-snli_mnli_fever_anli_R1_R2_R3-nli"
# _model_path = "ynie/electra-large-discriminator-snli_mnli_fever_anli_R1_R2_R3-nli"
# _model_path = "ynie/xlnet-large-cased-snli_mnli_fever_anli_R1_R2_R3-nli"

# NOTE: executed at import time, so it changes the transformers library's log
# verbosity for the whole process, not only for this scanner.
transformers_logging.set_verbosity_error()
log = logging.getLogger(__name__)

# Passed as `max_length` to the tokenizer together with `truncation=True`, so
# longer prompt/output pairs are cut down to this many tokens.
MAX_LENGTH = 256
|
||
|
||
class Refutation(Scanner):
    """
    Output scanner that flags model outputs contradicting the prompt.

    The prompt and the output are fed as a sentence pair to a pretrained natural
    language inference (NLI) model. When the probability assigned to the
    contradiction class is strictly greater than the configured threshold, the
    output is reported as invalid. The output text itself is never modified.
    """

    def __init__(self, threshold=0.5):
        """
        Initializes an instance of the Refutation class.

        Loads the NLI model and its tokenizer from `_model_path` and puts the
        model in evaluation mode.

        Parameters:
            threshold (float): Contradiction probability above which an output is
                flagged as a refutation. Defaults to 0.5.
        """

        self._model = AutoModelForSequenceClassification.from_pretrained(_model_path)
        self._model.eval()
        self._tokenizer = AutoTokenizer.from_pretrained(_model_path)
        self._threshold = threshold

    def scan(self, prompt: str, output: str) -> (str, bool):
        """
        Checks whether `output` contradicts `prompt`.

        Parameters:
            prompt (str): The text the model was asked; used as the first sequence
                of the NLI pair.
            output (str): The model's answer; used as the second sequence.

        Returns:
            (str, bool): The unmodified output and a validity flag. The flag is
                False when the contradiction score is above the threshold, and
                True otherwise (including when the prompt is blank).
        """
        # Nothing to compare against, so the output cannot refute anything.
        if prompt.strip() == "":
            return output, True

        encoded_pair = self._tokenizer.encode_plus(
            prompt, output, max_length=MAX_LENGTH, return_token_type_ids=True, truncation=True
        )

        # Build integer tensors directly (no float round-trip) and add the batch
        # dimension expected by the model.
        input_ids = torch.tensor(encoded_pair["input_ids"], dtype=torch.long).unsqueeze(0)
        # remember bart doesn't have 'token_type_ids', remove the line below if you are using bart.
        token_type_ids = torch.tensor(
            encoded_pair["token_type_ids"], dtype=torch.long
        ).unsqueeze(0)
        attention_mask = torch.tensor(
            encoded_pair["attention_mask"], dtype=torch.long
        ).unsqueeze(0)

        # Inference only: skip autograd bookkeeping to save memory and time.
        with torch.no_grad():
            outputs = self._model(
                input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
                labels=None,
            )

        predicted_probability = torch.softmax(outputs[0], dim=1)[0].tolist()  # batch_size only one

        # NOTE(review): indices 0/1/2 are read as entailment/neutral/contradiction;
        # confirm against the checkpoint's label mapping before swapping `_model_path`.
        log.debug(
            f"Detected entailment in the output with score: {predicted_probability[0]}, neutral: {predicted_probability[1]}, contradiction: {predicted_probability[2]}"
        )

        contradiction_score = predicted_probability[2]
        if contradiction_score > self._threshold:
            log.warning(
                f"Detected refutation in the output with score: {contradiction_score}, threshold: {self._threshold}"
            )

            return output, False

        log.debug(
            f"Not refutation in the output. Max score: {contradiction_score}, threshold: {self._threshold}"
        )

        return output, True
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
import pytest | ||
|
||
from llm_guard.output_scanners.refutation import Refutation | ||
|
||
|
||
# (prompt, output, expected_valid) triples. The expected output is not listed
# separately because every case expects the output text back unchanged.
_SCAN_CASES = [
    (
        "All humans are vegetarians.",
        "Many people around the world eat meat and fish as part of their diet. Therefore, the statement that all "
        "humans are vegetarians is incorrect.",
        False,
    ),  # Refutation
    (
        "Birds can fly",
        "Birds cannot fly.",
        False,
    ),  # Refutation
    (
        "The sun rises in the east.",
        "Every morning, people see the sun appear on the eastern horizon.",
        True,
    ),  # No Refutation
]


@pytest.mark.parametrize(
    "prompt,output,expected_output,expected_valid",
    [(case_prompt, case_output, case_output, case_valid) for case_prompt, case_output, case_valid in _SCAN_CASES],
)
def test_scan(prompt, output, expected_output, expected_valid):
    """Check that the scanner returns the output unchanged with the expected validity flag."""
    refutation_scanner = Refutation()
    returned_text, returned_flag = refutation_scanner.scan(prompt, output)
    assert returned_text == expected_output
    assert returned_flag == expected_valid