From 070569b6aebdcf65b869dbe3a71567a25846146e Mon Sep 17 00:00:00 2001 From: asofter Date: Sat, 12 Aug 2023 13:24:28 +0200 Subject: [PATCH] * refutation scanner for the output --- CHANGELOG.md | 8 +-- README.md | 1 + docs/output_scanners/refutation.md | 31 +++++++++ llm_guard/output_scanners/__init__.py | 2 + llm_guard/output_scanners/refutation.py | 80 ++++++++++++++++++++++++ requirements.txt | 1 + tests/output_scanners/test_refutation.py | 35 +++++++++++ 7 files changed, 154 insertions(+), 4 deletions(-) create mode 100644 docs/output_scanners/refutation.md create mode 100644 llm_guard/output_scanners/refutation.py create mode 100644 tests/output_scanners/test_refutation.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 9c0b18f5..d5464264 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,15 +8,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] ### Added -- +- [Refutation output scanner](./docs/output_scanners/refutation.md) ### Fixed - ### Changed -- [Anonymize prompt scanner] Using the transformer based Spacy model `en_core_web_trf` ([reference](https://microsoft.github.io/presidio/analyzer/nlp_engines/spacy_stanza/)) -- [Anonymize prompt scanner] Supporting faker for applicable entities instead of placeholder (`use_faker` parameter) -- [Jailbreak prompt scanner] Updated dataset with more examples, removed duplicates +- **Anonymize prompt scanner**: Using the transformer based Spacy model `en_core_web_trf` ([reference](https://microsoft.github.io/presidio/analyzer/nlp_engines/spacy_stanza/)) +- **Anonymize prompt scanner**: Supporting faker for applicable entities instead of placeholder (`use_faker` parameter) +- **Jailbreak prompt scanner**: Updated dataset with more examples, removed duplicates ### Removed - diff --git a/README.md b/README.md index 4ad10479..bbd39aae 100644 --- a/README.md +++ b/README.md @@ -61,6 +61,7 @@ python -m spacy download en_core_web_trf - [Code](docs/output_scanners/code.md) - 
[Deanonymize](docs/output_scanners/deanonymize.md) - [NoRefusal](docs/output_scanners/no_refusal.md) +- [Refutation](docs/output_scanners/refutation.md) - [Regex](docs/output_scanners/regex.md) - [Relevance](docs/output_scanners/relevance.md) - [Sensitive](docs/output_scanners/sensitive.md) diff --git a/docs/output_scanners/refutation.md b/docs/output_scanners/refutation.md new file mode 100644 index 00000000..0702b1ba --- /dev/null +++ b/docs/output_scanners/refutation.md @@ -0,0 +1,31 @@ +# Refutation Scanner + +This scanner is designed to assess if the given content contradicts or refutes a certain statement or prompt. It acts as +a tool for ensuring the consistency and correctness of language model outputs, especially in contexts where logical +contradictions can be problematic. + +## Attack + +When interacting with users or processing information, it's important for a language model to not provide outputs that +directly contradict the given inputs or established facts. Such contradictions can lead to confusion or misinformation. +The scanner aims to highlight such inconsistencies in the output. + +## How it works + +The scanner leverages pretrained natural language inference (NLI) models from HuggingFace, such +as [ynie/roberta-large-snli_mnli_fever_anli_R1_R2_R3-nli](https://huggingface.co/ynie/roberta-large-snli_mnli_fever_anli_R1_R2_R3-nli) ( +and other variants), to determine the relationship between a given prompt and the generated output. + +A high contradiction score indicates that the output refutes the prompt. + +This calculated refutation score is then compared to a pre-set threshold. Outputs that cross this threshold are flagged +as contradictory. 
+
+## Usage
+
+```python
+from llm_guard.output_scanners import Refutation
+
+scanner = Refutation(threshold=0.7)
+sanitized_output, is_valid = scanner.scan(prompt, model_output)
+```
diff --git a/llm_guard/output_scanners/__init__.py b/llm_guard/output_scanners/__init__.py
index 9b4dfd30..573ccb1c 100644
--- a/llm_guard/output_scanners/__init__.py
+++ b/llm_guard/output_scanners/__init__.py
@@ -4,6 +4,7 @@
 from .code import Code
 from .deanonymize import Deanonymize
 from .no_refusal import NoRefusal
+from .refutation import Refutation
 from .regex import Regex
 from .relevance import Relevance
 from .sensitive import Sensitive
@@ -15,6 +16,7 @@
     "Code",
     "Deanonymize",
     "NoRefusal",
+    "Refutation",
     "Regex",
     "Relevance",
     "Sensitive",
diff --git a/llm_guard/output_scanners/refutation.py b/llm_guard/output_scanners/refutation.py
new file mode 100644
index 00000000..a3713ffd
--- /dev/null
+++ b/llm_guard/output_scanners/refutation.py
@@ -0,0 +1,106 @@
+import logging
+from typing import Tuple
+
+import torch
+from transformers import AutoModelForSequenceClassification, AutoTokenizer
+from transformers import logging as transformers_logging
+
+from .base import Scanner
+
+_model_path = "ynie/roberta-large-snli_mnli_fever_anli_R1_R2_R3-nli"
+# Alternative NLI checkpoints:
+# _model_path = "ynie/albert-xxlarge-v2-snli_mnli_fever_anli_R1_R2_R3-nli"
+# _model_path = "ynie/bart-large-snli_mnli_fever_anli_R1_R2_R3-nli"
+# _model_path = "ynie/electra-large-discriminator-snli_mnli_fever_anli_R1_R2_R3-nli"
+# _model_path = "ynie/xlnet-large-cased-snli_mnli_fever_anli_R1_R2_R3-nli"
+
+transformers_logging.set_verbosity_error()
+log = logging.getLogger(__name__)
+
+# Maximum length, in tokens, of the encoded (prompt, output) pair; longer input is truncated.
+MAX_LENGTH = 256
+
+
+class Refutation(Scanner):
+    """
+    Refutation Class:
+
+    This class checks for refutation between a given prompt and output using a pretrained NLI model.
+    """
+
+    def __init__(self, threshold=0.5):
+        """
+        Initializes an instance of the Refutation class.
+
+        Parameters:
+            threshold (float): The contradiction score above which the output is flagged
+                as a refutation. Defaults to 0.5.
+        """
+
+        self._model = AutoModelForSequenceClassification.from_pretrained(_model_path)
+        self._model.eval()
+        self._tokenizer = AutoTokenizer.from_pretrained(_model_path)
+        self._threshold = threshold
+
+    def scan(self, prompt: str, output: str) -> Tuple[str, bool]:
+        """
+        Checks whether the output contradicts the prompt.
+
+        Parameters:
+            prompt (str): The prompt; an empty or whitespace-only prompt is always valid.
+            output (str): The model output to check against the prompt.
+
+        Returns:
+            Tuple[str, bool]: The unmodified output and a flag that is False when the
+                contradiction score exceeds the threshold, True otherwise.
+        """
+
+        if prompt.strip() == "":
+            return output, True
+
+        # bart doesn't have 'token_type_ids': drop them here and in the model call when using bart.
+        encoded = self._tokenizer.encode_plus(
+            prompt,
+            output,
+            max_length=MAX_LENGTH,
+            return_token_type_ids=True,
+            truncation=True,
+            return_tensors="pt",
+        )
+
+        # Inference only, so there is no need to track gradients.
+        with torch.no_grad():
+            outputs = self._model(
+                encoded["input_ids"],
+                attention_mask=encoded["attention_mask"],
+                token_type_ids=encoded["token_type_ids"],
+                labels=None,
+            )
+
+        # Batch of one pair; indices 0/1/2 are read as entailment/neutral/contradiction.
+        probabilities = torch.softmax(outputs[0], dim=1)[0].tolist()
+        entailment_score, neutral_score, contradiction_score = probabilities
+
+        log.debug(
+            "Detected entailment in the output with score: %s, neutral: %s, contradiction: %s",
+            entailment_score,
+            neutral_score,
+            contradiction_score,
+        )
+
+        if contradiction_score > self._threshold:
+            log.warning(
+                "Detected refutation in the output with score: %s, threshold: %s",
+                contradiction_score,
+                self._threshold,
+            )
+
+            return output, False
+
+        log.debug(
+            "Not refutation in the output. Max score: %s, threshold: %s",
+            contradiction_score,
+            self._threshold,
+        )
+
+        return output, True
diff --git a/requirements.txt b/requirements.txt
index 0d33ac44..d48940c7 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -5,5 +5,6 @@ presidio-anonymizer==2.2.33
 sentence-transformers==2.2.2
 spacy==3.6.1
 tiktoken==0.4.0
+torch==2.0.1
 transformers==4.30.0
 xformers==0.0.20
diff --git a/tests/output_scanners/test_refutation.py b/tests/output_scanners/test_refutation.py
new file mode 100644
index 00000000..7a00aec1
--- /dev/null
+++ b/tests/output_scanners/test_refutation.py
@@ -0,0 +1,35 @@
+import pytest
+
+from llm_guard.output_scanners.refutation import Refutation
+
+
+@pytest.mark.parametrize(
+    "prompt,output,expected_output,expected_valid",
+    [
+        (
+            "All humans are vegetarians.",
+            "Many people around the world eat meat and fish as part of their diet. Therefore, the statement that all "
+            "humans are vegetarians is incorrect.",
+            "Many people around the world eat meat and fish as part of their diet. Therefore, the statement that all "
+            "humans are vegetarians is incorrect.",
+            False,
+        ),  # Refutation
+        (
+            "Birds can fly",
+            "Birds cannot fly.",
+            "Birds cannot fly.",
+            False,
+        ),  # Refutation
+        (
+            "The sun rises in the east.",
+            "Every morning, people see the sun appear on the eastern horizon.",
+            "Every morning, people see the sun appear on the eastern horizon.",
+            True,
+        ),  # No Refutation
+    ],
+)
+def test_scan(prompt, output, expected_output, expected_valid):
+    scanner = Refutation()
+    sanitized_output, valid = scanner.scan(prompt, output)
+    assert sanitized_output == expected_output
+    assert valid == expected_valid