
Merge remote-tracking branch 'origin/main'
asofter committed Jan 31, 2024
2 parents a3f85ce + 253975d commit a83f6cd
Showing 18 changed files with 420 additions and 59 deletions.
15 changes: 15 additions & 0 deletions .editorconfig
@@ -0,0 +1,15 @@
root = true

[*]
indent_style = space
indent_size = 4
charset = utf-8
trim_trailing_whitespace = true
insert_final_newline = true
max_line_length = 120

[*.md]
trim_trailing_whitespace = false

[*.{yml,yaml,json}]
indent_size = 2
3 changes: 2 additions & 1 deletion docs/input_scanners/anonymize.md
@@ -86,7 +86,7 @@ from llm_guard.input_scanners import Anonymize
from llm_guard.input_scanners.anonymize_helpers import BERT_LARGE_NER_CONF

scanner = Anonymize(vault, preamble="Insert before prompt", allowed_names=["John Doe"], hidden_names=["Test LLC"],
recognizer_conf=BERT_LARGE_NER_CONF)
recognizer_conf=BERT_LARGE_NER_CONF, language="en")
sanitized_prompt, is_valid, risk_score = scanner.scan(prompt)
```

@@ -97,6 +97,7 @@ sanitized_prompt, is_valid, risk_score = scanner.scan(prompt)
- `use_faker`: Substitutes eligible entities with fabricated data.
- `recognizer_conf`: Configures the recognizer used for PII detection.
- `threshold`: Sets the acceptance threshold (Default: `0`).
- `language`: Language used for PII detection and anonymization (Default: `en`).

Retrieving Original Data: To revert to the initial data, utilize the [Deanonymize](../output_scanners/deanonymize.md)
scanner.
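A minimal round-trip sketch (assuming both scanners share one `Vault`, as in this commit; the email address and the stand-in model output are illustrative):

```python
from llm_guard.input_scanners import Anonymize
from llm_guard.output_scanners import Deanonymize
from llm_guard.vault import Vault

vault = Vault()

# The input scanner stores (placeholder, original value) pairs in the vault.
input_scanner = Anonymize(vault, language="en")
sanitized_prompt, is_valid, risk_score = input_scanner.scan(
    "Please email John Doe at test@example.com"
)

# After the LLM responds, Deanonymize restores the original values
# by looking the placeholders up in the same vault.
model_output = sanitized_prompt  # stand-in for a real model response
output_scanner = Deanonymize(vault)
restored, is_valid, risk_score = output_scanner.scan(sanitized_prompt, model_output)
```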
70 changes: 55 additions & 15 deletions llm_guard/input_scanners/anonymize.py
@@ -6,6 +6,7 @@
from presidio_anonymizer.core.text_replace_builder import TextReplaceBuilder
from presidio_anonymizer.entities import PIIEntity, RecognizerResult

from llm_guard.exception import LLMGuardValidationError
from llm_guard.util import logger
from llm_guard.vault import Vault

@@ -23,6 +24,7 @@
"resources",
"sensisitive_patterns.json",
)

default_entity_types = [
"CREDIT_CARD",
"CRYPTO",
@@ -39,6 +41,8 @@
"US_SSN_RE",
]

ALL_SUPPORTED_LANGUAGES = ["en", "zh"]


class Anonymize(Scanner):
"""
@@ -61,6 +65,7 @@ def __init__(
recognizer_conf: Optional[Dict] = BERT_BASE_NER_CONF,
threshold: float = 0,
use_onnx: bool = False,
language: str = "en",
):
"""
Initialize an instance of Anonymize class.
@@ -76,13 +81,20 @@
recognizer_conf (Optional[Dict]): Configuration to recognize PII data. Default is dslim/bert-base-NER.
threshold (float): Acceptance threshold. Default is 0.
use_onnx (bool): Whether to use ONNX runtime for inference. Default is False.
language (str): Language used for PII detection and anonymization. Default is "en".
"""

if language not in ALL_SUPPORTED_LANGUAGES:
raise LLMGuardValidationError(
f"Language must be in the list of allowed: {ALL_SUPPORTED_LANGUAGES}"
)

os.environ["TOKENIZERS_PARALLELISM"] = "false" # Disables huggingface/tokenizers warning

if not entity_types:
logger.debug(f"No entity types provided, using default: {default_entity_types}")
entity_types = default_entity_types.copy()

entity_types.append("CUSTOM")

if not hidden_names:
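A quick illustration of the new language guard (a sketch; `ALL_SUPPORTED_LANGUAGES` is `["en", "zh"]` per the hunk above, and `"fr"` is chosen only as an unsupported example):

```python
from llm_guard.exception import LLMGuardValidationError
from llm_guard.input_scanners import Anonymize
from llm_guard.vault import Vault

try:
    Anonymize(Vault(), language="fr")  # not in ALL_SUPPORTED_LANGUAGES
except LLMGuardValidationError as err:
    print(err)  # Language must be in the list of allowed: ['en', 'zh']
```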
@@ -94,12 +106,19 @@
self._preamble = preamble
self._use_faker = use_faker
self._threshold = threshold
self._language = language

transformers_recognizer = get_transformers_recognizer(
recognizer_conf=recognizer_conf,
use_onnx=use_onnx,
supported_language=language,
)

transformers_recognizer = get_transformers_recognizer(recognizer_conf, use_onnx)
self._analyzer = get_analyzer(
transformers_recognizer,
Anonymize.get_regex_patterns(regex_pattern_groups_path),
hidden_names,
recognizer=transformers_recognizer,
regex_groups=Anonymize.get_regex_patterns(regex_pattern_groups_path),
custom_names=hidden_names,
supported_languages=list(set(["en", language])),
)

@staticmethod
@@ -122,9 +141,11 @@ def get_regex_patterns(json_path: str) -> List[dict]:
regex_groups.append(
{
"name": group["name"].upper(),
"expressions": group["expressions"],
"context": group["context"],
"score": group["score"],
"expressions": group.get("expressions", []),
"context": group.get("context", []),
"score": group.get("score", 0.75),
"languages": group.get("languages", ["en"]),
"reuse": group.get("reuse", False),
}
)
logger.debug(f"Loaded regex pattern for {group['name']}")
@@ -220,14 +241,15 @@ def _get_entity_placeholder(entity_type: str, index: int, use_faker: bool) -> st

@staticmethod
def _anonymize(
prompt: str, pii_entities: List[PIIEntity], use_faker: bool
prompt: str, pii_entities: List[PIIEntity], vault: Vault, use_faker: bool
) -> (str, List[tuple]):
"""
Replace detected entities in the prompt with anonymized placeholders.
Parameters:
prompt (str): Original text prompt.
pii_entities (List[PIIEntity]): List of entities detected in the prompt.
vault (Vault): A vault instance with the anonymized data stored.
use_faker (bool): Whether to use faker to generate fake data.
Returns:
@@ -236,7 +258,7 @@
"""
text_replace_builder = TextReplaceBuilder(original_text=prompt)

entity_type_counter = {}
entity_type_counter, new_entity_counter = {}, {}
for pii_entity in pii_entities:
entity_type = pii_entity.entity_type
entity_value = text_replace_builder.get_text_in_position(
@@ -247,9 +269,25 @@
entity_type_counter[entity_type] = {}

if entity_value not in entity_type_counter[entity_type]:
entity_type_counter[entity_type][entity_value] = (
len(entity_type_counter[entity_type]) + 1
)
vault_entities = [
(entity_placeholder, entity_vault_value)
for entity_placeholder, entity_vault_value in vault.get()
if entity_type in entity_placeholder
]
entity_placeholder = [
entity_placeholder
for entity_placeholder, entity_vault_value in vault_entities
if entity_vault_value == entity_value
]
if len(entity_placeholder) > 0:
entity_type_counter[entity_type][entity_value] = int(
entity_placeholder[0].split("_")[-1][:-1]
)
else:
entity_type_counter[entity_type][entity_value] = (
len(vault_entities) + new_entity_counter.get(entity_type, 0) + 1
)
new_entity_counter[entity_type] = new_entity_counter.get(entity_type, 0) + 1

results = []
sorted_pii_entities = sorted(pii_entities, reverse=True)
@@ -282,7 +320,7 @@ def scan(self, prompt: str) -> (str, bool, float):

analyzer_results = self._analyzer.analyze(
text=Anonymize.remove_single_quotes(prompt),
language="en",
language=self._language,
entities=self._entity_types,
allow_list=self._allowed_names,
score_threshold=self._threshold,
Expand All @@ -298,14 +336,16 @@ def scan(self, prompt: str) -> (str, bool, float):
merged_results = self._merge_entities_with_whitespace_between(prompt, analyzer_results)

sanitized_prompt, anonymized_results = self._anonymize(
prompt, merged_results, self._use_faker
prompt, merged_results, self._vault, self._use_faker
)

if prompt != sanitized_prompt:
logger.warning(
f"Found sensitive data in the prompt and replaced it: {merged_results}, risk score: {risk_score}"
)
self._vault.extend(anonymized_results)
for entity_placeholder, entity_value in anonymized_results:
if not self._vault.placeholder_exists(entity_placeholder):
self._vault.append((entity_placeholder, entity_value))
return self._preamble + sanitized_prompt, False, risk_score

logger.debug(f"Prompt does not have sensitive data to replace. Risk score is {risk_score}")
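The effect of the vault-aware `_anonymize` above, sketched end to end (the placeholder format is assumed from `_get_entity_placeholder`; the exact entity label depends on the active recognizers):

```python
from llm_guard.input_scanners import Anonymize
from llm_guard.vault import Vault

vault = Vault()
scanner = Anonymize(vault)

first_prompt, _, _ = scanner.scan("Write to test@example.com today")
second_prompt, _, _ = scanner.scan("Did you write to test@example.com?")

# Both prompts should carry the same placeholder (e.g. [REDACTED_EMAIL_ADDRESS_RE_1]):
# the second scan finds the value already in the vault and reuses its index
# instead of appending a duplicate entry.
print(vault.get())
```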
109 changes: 80 additions & 29 deletions llm_guard/input_scanners/anonymize_helpers/analyzer.py
@@ -1,4 +1,5 @@
from typing import Dict, Sequence
import copy
from typing import Dict, List, Sequence

import spacy
from presidio_analyzer import (
@@ -13,11 +14,16 @@
from llm_guard.exception import LLMGuardValidationError

from .ner_mapping import ALL_RECOGNIZER_CONF
from .predefined_recognizers import _get_predefined_recognizers
from .predefined_recognizers.zh import CustomPatternRecognizer
from .transformers_recognizer import TransformersRecognizer


def _add_recognizers(
registry: RecognizerRegistry, regex_groups, custom_names
registry: RecognizerRegistry,
regex_groups,
custom_names,
supported_languages: List[str] = ["en"],
) -> RecognizerRegistry:
"""
Create a RecognizerRegistry and populate it with regex patterns and custom names.
@@ -30,42 +36,79 @@ def _add_recognizers(
RecognizerRegistry: A RecognizerRegistry object loaded with regex and custom name recognizers.
"""

if len(custom_names) > 0:
registry.add_recognizer(
PatternRecognizer(supported_entity="CUSTOM", deny_list=custom_names)
)
for language in supported_languages:
# custom recognizer per language
if len(custom_names) > 0:
custom_recognizer = PatternRecognizer

if language == "zh":
custom_recognizer = CustomPatternRecognizer

registry.add_recognizer(
custom_recognizer(
supported_entity="CUSTOM",
supported_language=language,
deny_list=custom_names,
)
)

# predefined recognizers per language
for _Recognizer in _get_predefined_recognizers(language):
registry.add_recognizer(_Recognizer(supported_language=language))

for pattern_data in regex_groups:
languages = pattern_data["languages"] or ["en"]

label = pattern_data["name"]
compiled_patterns = pattern_data["expressions"]
patterns = []
for pattern in compiled_patterns:
patterns.append(Pattern(name=label, regex=pattern, score=pattern_data["score"]))
registry.add_recognizer(
PatternRecognizer(
supported_entity=label,
patterns=patterns,
context=pattern_data["context"],
)
reuse = pattern_data.get("reuse", False)

# Build a concrete list: a lazy map object would be exhausted after the
# first language in the loop below, leaving later recognizers without patterns.
patterns = [
    Pattern(name=label, regex=exp, score=pattern_data["score"])
    for exp in pattern_data.get("expressions", []) or []
]

for language in languages:
if language not in supported_languages:
continue

if reuse:
new_recognizer = copy.deepcopy(
registry.get_recognizers(language=reuse["language"], entities=[reuse["name"]])[
0
]
)
new_recognizer.supported_language = language
registry.add_recognizer(new_recognizer)
else:
registry.add_recognizer(
PatternRecognizer(
supported_entity=label,
supported_language=language,
patterns=patterns,
context=pattern_data["context"],
)
)

return registry


def _get_nlp_engine() -> NlpEngine:
# Use small spacy model, for faster inference.
if not spacy.util.is_package("en_core_web_sm"):
spacy.cli.download("en_core_web_sm")
def _get_nlp_engine(languages: List[str] = ["en"]) -> NlpEngine:
models = []

configuration = {
"nlp_engine_name": "spacy",
"models": [{"lang_code": "en", "model_name": "en_core_web_sm"}],
}
for language in languages:
if not spacy.util.is_package(f"{language}_core_web_sm"):
# Use small spacy model, for faster inference.
spacy.cli.download(f"{language}_core_web_sm")
models.append({"lang_code": language, "model_name": f"{language}_core_web_sm"})

configuration = {"nlp_engine_name": "spacy", "models": models}

return NlpEngineProvider(nlp_configuration=configuration).create_engine()


def get_transformers_recognizer(recognizer_conf: Dict, use_onnx: bool = False) -> EntityRecognizer:
def get_transformers_recognizer(
recognizer_conf: Dict, use_onnx: bool = False, supported_language: str = "en"
) -> EntityRecognizer:
if recognizer_conf not in ALL_RECOGNIZER_CONF:
raise LLMGuardValidationError(
f"Recognizer must be in the list of allowed: {ALL_RECOGNIZER_CONF}"
@@ -76,20 +119,28 @@ def get_transformers_recognizer(recognizer_conf: Dict, use_onnx: bool = False) -
transformers_recognizer = TransformersRecognizer(
model_path=model_path,
supported_entities=supported_entities,
supported_language=supported_language,
)
transformers_recognizer.load_transformer(use_onnx=use_onnx, **recognizer_conf)
return transformers_recognizer


def get_analyzer(
recognizer: EntityRecognizer, regex_groups, custom_names: Sequence[str]
recognizer: EntityRecognizer,
regex_groups,
custom_names: Sequence[str],
supported_languages: List[str] = ["en"],
) -> AnalyzerEngine:
nlp_engine = _get_nlp_engine()
nlp_engine = _get_nlp_engine(languages=supported_languages)

registry = RecognizerRegistry()
registry.load_predefined_recognizers(nlp_engine=nlp_engine)
registry = _add_recognizers(registry, regex_groups, custom_names)
registry = _add_recognizers(registry, regex_groups, custom_names, supported_languages)
registry.add_recognizer(recognizer)
registry.remove_recognizer("SpacyRecognizer")

return AnalyzerEngine(nlp_engine=nlp_engine, registry=registry, supported_languages=["en"])
return AnalyzerEngine(
nlp_engine=nlp_engine,
registry=registry,
supported_languages=supported_languages,
)
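Taken together, a sketch of how these helpers compose into a bilingual analyzer (import paths and `BERT_BASE_NER_CONF` are assumed from the file layout referenced in this commit; the spaCy models download on first use):

```python
from llm_guard.input_scanners.anonymize_helpers import BERT_BASE_NER_CONF
from llm_guard.input_scanners.anonymize_helpers.analyzer import (
    get_analyzer,
    get_transformers_recognizer,
)

# One transformers recognizer is bound to a single language...
recognizer = get_transformers_recognizer(BERT_BASE_NER_CONF, supported_language="en")

# ...while the registry also gets per-language custom and predefined
# recognizers for every supported language.
analyzer = get_analyzer(
    recognizer=recognizer,
    regex_groups=[],
    custom_names=["Test LLC"],
    supported_languages=["en", "zh"],
)

results = analyzer.analyze(text="Test LLC hired John", language="en", entities=["CUSTOM"])
```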
