Skip to content

Commit

Permalink
* new model for the BanCompetitors scanner
Browse files Browse the repository at this point in the history
  • Loading branch information
asofter committed May 10, 2024
1 parent def7023 commit 3da6a53
Show file tree
Hide file tree
Showing 9 changed files with 93 additions and 66 deletions.
2 changes: 2 additions & 0 deletions benchmarks/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def build_input_scanner(scanner_name: str, use_onnx: bool) -> InputScanner:
return input_scanners.BanCompetitors(
competitors=["Google", "Bing", "Yahoo"],
threshold=0.5,
use_onnx=use_onnx,
)

if scanner_name == "BanSubstrings":
Expand Down Expand Up @@ -88,6 +89,7 @@ def build_output_scanner(scanner_name: str, use_onnx: bool) -> OutputScanner:
return output_scanners.BanCompetitors(
competitors=["Google", "Bing", "Yahoo"],
threshold=0.5,
use_onnx=use_onnx,
)

if scanner_name == "BanSubstrings":
Expand Down
1 change: 1 addition & 0 deletions docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `BanCode` scanner was improved to trigger less false-positives.
- Improved logging to support JSON format.
- Optimizations in the `API` to reduce the latency.
- The `BanCompetitors` scanner now relies on a new model, which also supports ONNX inference.

### Removed
-
Expand Down
5 changes: 2 additions & 3 deletions docs/input_scanners/ban_competitors.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,7 @@ preference.

Models:

- [tomaarsen/span-marker-bert-small-orgs](https://huggingface.co/tomaarsen/span-marker-bert-small-orgs)
- [tomaarsen/span-marker-bert-base-orgs](https://huggingface.co/tomaarsen/span-marker-bert-base-orgs)
- [guishe/nuner-v1_orgs](https://huggingface.co/guishe/nuner-v1_orgs)

## Usage

Expand Down Expand Up @@ -53,7 +52,7 @@ sanitized_prompt, is_valid, risk_score = scanner.scan(prompt)

## Optimization Strategies

ONNX support for this scanner is currently in development ([PR](https://github.com/tomaarsen/SpanMarkerNER/pull/43)).
[Read more](../tutorials/optimization.md)

## Benchmark

Expand Down
5 changes: 2 additions & 3 deletions docs/output_scanners/ban_competitors.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,7 @@ preference.

Models:

- [tomaarsen/span-marker-bert-small-orgs](https://huggingface.co/tomaarsen/span-marker-bert-small-orgs)
- [tomaarsen/span-marker-bert-base-orgs](https://huggingface.co/tomaarsen/span-marker-bert-base-orgs)
- [guishe/nuner-v1_orgs](https://huggingface.co/guishe/nuner-v1_orgs)

## Usage

Expand Down Expand Up @@ -55,7 +54,7 @@ sanitized_output, is_valid, risk_score = scanner.scan(prompt, output)

## Optimization Strategies

ONNX support for this scanner is currently in development ([PR](https://github.com/tomaarsen/SpanMarkerNER/pull/43)).
[Read more](../tutorials/optimization.md)

## Benchmark

Expand Down
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
from __future__ import annotations

import copy
from typing import TYPE_CHECKING, Any, cast
from typing import TYPE_CHECKING, Any

from presidio_analyzer import AnalysisExplanation, EntityRecognizer, RecognizerResult
from presidio_analyzer.nlp_engine import NlpArtifacts
from transformers import TokenClassificationPipeline

from llm_guard.model import Model
from llm_guard.transformers_helpers import device, get_tokenizer, is_onnx_supported
from llm_guard.util import get_logger, lazy_load_dep, split_text_to_word_chunks
from llm_guard.transformers_helpers import get_tokenizer_and_model_for_ner
from llm_guard.util import get_logger, split_text_to_word_chunks

from .ner_mapping import BERT_BASE_NER_CONF

Expand Down Expand Up @@ -137,40 +137,7 @@ def _load_pipeline(
use_onnx: bool = False,
) -> None:
"""Initialize NER transformers_rec pipeline using the model_path provided"""
transformers = cast("transformers", lazy_load_dep("transformers"))
tf_tokenizer = get_tokenizer(self.model)

if use_onnx and is_onnx_supported() is False:
LOGGER.warning("ONNX is not supported on this machine. Using PyTorch instead of ONNX.")
use_onnx = False

if use_onnx:
optimum_onnxruntime = lazy_load_dep(
"optimum.onnxruntime",
"optimum[onnxruntime]" if device().type != "cuda" else "optimum[onnxruntime-gpu]",
)

tf_model = optimum_onnxruntime.ORTModelForTokenClassification.from_pretrained(
self.model.onnx_path,
export=False,
subfolder=self.model.onnx_subfolder,
provider=(
"CUDAExecutionProvider" if device().type == "cuda" else "CPUExecutionProvider"
),
revision=self.model.onnx_revision,
file_name=self.model.onnx_filename,
use_io_binding=True if device().type == "cuda" else False,
**self.model.kwargs,
)
LOGGER.debug("Initialized NER ONNX model", model=self.model, device=device())
else:
tf_model = transformers.AutoModelForTokenClassification.from_pretrained(
self.model.path,
subfolder=self.model.subfolder,
revision=self.model.revision,
**self.model.kwargs,
)
LOGGER.debug("Initialized NER model", model=self.model, device=device())
tf_tokenizer, tf_model = get_tokenizer_and_model_for_ner(self.model, use_onnx=use_onnx)

self.model.pipeline_kwargs["ignore_labels"] = self.ignore_labels
self.pipeline = transformers.pipeline(
Expand Down
54 changes: 32 additions & 22 deletions llm_guard/input_scanners/ban_competitors.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,24 @@
from typing import Optional, Sequence
from typing import Optional, Sequence, cast

from presidio_anonymizer.core.text_replace_builder import TextReplaceBuilder

from llm_guard.model import Model
from llm_guard.util import device, get_logger, lazy_load_dep
from llm_guard.transformers_helpers import get_tokenizer_and_model_for_ner
from llm_guard.util import get_logger, lazy_load_dep

from .base import Scanner

LOGGER = get_logger()

MODEL_BASE = Model(
"tomaarsen/span-marker-bert-base-orgs", revision="312bcdb7bc02c85ab9b8b8fe99849ca28714b29d"
)
MODEL_SMALL = Model(
"tomaarsen/span-marker-bert-small-orgs", revision="437bd92fcc2b4236b7d7402113d47920793bab46"
MODEL_V1 = Model(
path="guishe/nuner-v1_orgs",
revision="2e95454e741e5bdcbfabd6eaed5fb03a266cf043",
onnx_path="protectai/guishe-nuner-v1_orgs-onnx",
onnx_revision="20c9739f45f6b4d10ba63c62e6fa92f214a12a52",
onnx_subfolder="",
pipeline_kwargs={
"aggregation_strategy": "simple",
},
)


Expand All @@ -31,6 +36,7 @@ def __init__(
threshold: float = 0.5,
redact: bool = True,
model: Optional[Model] = None,
use_onnx: bool = False,
):
"""
Initialize BanCompetitors object.
Expand All @@ -39,40 +45,44 @@ def __init__(
competitors (Sequence[str]): List of competitors to detect.
threshold (float, optional): Threshold to determine if a competitor is present in the prompt. Default is 0.5.
redact (bool, optional): Whether to redact the competitor name. Default is True.
model (Model, optional): Model to use for named-entity recognition. Default is BASE model.
model (Model, optional): Model to use for named-entity recognition. Default is V1 model.
use_onnx (bool, optional): Whether to use ONNX instead of PyTorch for inference. Default is False.
Raises:
ValueError: If no competitors are provided.
"""
if model is None:
model = MODEL_BASE
model = MODEL_V1

self._competitors = competitors
self._threshold = threshold
self._redact = redact

span_marker = lazy_load_dep("span_marker", "span-marker")
self._ner_pipeline = span_marker.SpanMarkerModel.from_pretrained(
model.path, labels=["ORG"], **model.kwargs
tf_tokenizer, tf_model = get_tokenizer_and_model_for_ner(
model=model,
use_onnx=use_onnx,
)

if device().type == "cuda":
self._ner_pipeline = self._ner_pipeline.cuda()
transformers = cast("transformers", lazy_load_dep("transformers"))
self._ner_pipeline = transformers.pipeline(
"ner", model=tf_model, tokenizer=tf_tokenizer, **model.pipeline_kwargs
)

def scan(self, prompt: str) -> (str, bool, float):
is_detected = False
text_replace_builder = TextReplaceBuilder(original_text=prompt)
entities = self._ner_pipeline.predict(prompt)
entities = sorted(entities, key=lambda x: x["char_end_index"], reverse=True)
entities = self._ner_pipeline(prompt)
entities = sorted(entities, key=lambda x: x["end"], reverse=True)
for entity in entities:
if entity["span"] not in self._competitors:
LOGGER.debug("Entity is not a specified competitor", entity=entity["span"])
entity["word"] = entity["word"].strip()
if entity["word"] not in self._competitors:
LOGGER.debug("Entity is not a specified competitor", entity=entity["word"])
continue

if entity["score"] < self._threshold:
LOGGER.debug(
"Competitor detected but the score is below threshold",
entity=entity["span"],
entity=entity["word"],
score=entity["score"],
)
continue
Expand All @@ -82,12 +92,12 @@ def scan(self, prompt: str) -> (str, bool, float):
if self._redact:
text_replace_builder.replace_text_get_insertion_index(
"[REDACTED]",
entity["char_start_index"],
entity["char_end_index"],
entity["start"],
entity["end"],
)

LOGGER.warning(
"Competitor detected with score", entity=entity["span"], score=entity["score"]
"Competitor detected with score", entity=entity["word"], score=entity["score"]
)

if is_detected:
Expand Down
3 changes: 3 additions & 0 deletions llm_guard/output_scanners/ban_competitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def __init__(
threshold: float = 0.5,
redact: bool = True,
model: Optional[Model] = None,
use_onnx: bool = False,
):
"""
Initializes BanCompetitors object.
Expand All @@ -29,6 +30,7 @@ def __init__(
threshold (float, optional): Threshold to determine if an organization is present in the output. Default is 0.5.
redact (bool, optional): Whether to redact the organization name. Default is True.
model (Model, optional): Model to use for named-entity recognition. Default is V1 model.
use_onnx (bool, optional): Whether to use ONNX instead of PyTorch for inference. Default is False.
Raises:
ValueError: If no competitors are provided.
Expand All @@ -38,6 +40,7 @@ def __init__(
threshold=threshold,
redact=redact,
model=model,
use_onnx=use_onnx,
)

def scan(self, prompt: str, output: str) -> (str, bool, float):
Expand Down
47 changes: 47 additions & 0 deletions llm_guard/transformers_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,53 @@ def get_tokenizer_and_model_for_classification(
return tf_tokenizer, tf_model


def get_tokenizer_and_model_for_ner(
    model: Model,
    use_onnx: bool = False,
):
    """
    Load the tokenizer and the token-classification (NER) model described by ``model``.

    Args:
        model (Model): Model configuration. Its ``path``/``subfolder``/``revision`` are used
            for the PyTorch model, its ``onnx_path``/``onnx_subfolder``/``onnx_revision``/
            ``onnx_filename`` for the ONNX model, and ``kwargs`` is forwarded to
            ``from_pretrained`` in both cases.
        use_onnx (bool): Whether to use the ONNX version of the model. Defaults to False.
            If ONNX is requested but ``is_onnx_supported()`` reports it is unavailable,
            a warning is logged and the PyTorch model is loaded instead.

    Returns:
        tuple: ``(tf_tokenizer, tf_model)`` where ``tf_tokenizer`` comes from
        ``get_tokenizer(model)`` and ``tf_model`` is loaded through either
        ``AutoModelForTokenClassification`` or ``ORTModelForTokenClassification``.
    """
    tf_tokenizer = get_tokenizer(model)
    transformers = lazy_load_dep("transformers")

    if use_onnx and not is_onnx_supported():
        LOGGER.warning("ONNX is not supported on this machine. Using PyTorch instead of ONNX.")
        use_onnx = False

    # Truthiness test rather than `is False`: a falsy non-bool value (None, 0)
    # must select the PyTorch model, not fall through to the ONNX branch.
    if not use_onnx:
        tf_model = transformers.AutoModelForTokenClassification.from_pretrained(
            model.path, subfolder=model.subfolder, revision=model.revision, **model.kwargs
        )
        LOGGER.debug("Initialized NER model", model=model, device=device())

        return tf_tokenizer, tf_model

    # Resolve the device once; it selects the optimum extra, the execution
    # provider and whether IO binding is enabled.
    current_device = device()
    on_cuda = current_device.type == "cuda"

    optimum_onnxruntime = lazy_load_dep(
        "optimum.onnxruntime",
        "optimum[onnxruntime-gpu]" if on_cuda else "optimum[onnxruntime]",
    )

    tf_model = optimum_onnxruntime.ORTModelForTokenClassification.from_pretrained(
        model.onnx_path,
        export=False,
        subfolder=model.onnx_subfolder,
        provider="CUDAExecutionProvider" if on_cuda else "CPUExecutionProvider",
        revision=model.onnx_revision,
        file_name=model.onnx_filename,
        use_io_binding=on_cuda,
        **model.kwargs,
    )
    LOGGER.debug("Initialized NER ONNX model", model=model, device=current_device)

    return tf_tokenizer, tf_model


ClassificationTask = Literal["text-classification", "zero-shot-classification"]


Expand Down
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ dependencies = [
"tiktoken>=0.5,<0.7",
"torch==2.2.2",
"transformers==4.39.3",
"span-marker==1.5.0",
"structlog>=24"
]

Expand Down

0 comments on commit 3da6a53

Please sign in to comment.