diff --git a/docs/changelog.md b/docs/changelog.md index 5df5d7f8..9d848983 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -14,10 +14,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `InvisibleText` scanner to allow control characters like `\n`, `\t`, etc. ### Changed -- +- **[Breaking]**: Introducing `Model` object for better customization of the models. ### Removed -- +- `model_kwargs` and `pipeline_kwargs` as they are part of the `Model` object. ## [0.3.10] - 2024-03-14 diff --git a/llm_guard/input_scanners/anonymize.py b/llm_guard/input_scanners/anonymize.py index d9fc386a..e057463e 100644 --- a/llm_guard/input_scanners/anonymize.py +++ b/llm_guard/input_scanners/anonymize.py @@ -56,12 +56,10 @@ def __init__( preamble: str = "", regex_patterns: Optional[List[Dict]] = None, use_faker: bool = False, - recognizer_conf: Optional[Dict] = DEBERTA_AI4PRIVACY_v2_CONF, + recognizer_conf: Optional[Dict] = None, threshold: float = 0.5, use_onnx: bool = False, language: str = "en", - model_kwargs: Optional[Dict] = None, - pipeline_kwargs: Optional[Dict] = None, ): """ Initialize an instance of Anonymize class. @@ -78,8 +76,6 @@ def __init__( threshold (float): Acceptance threshold. Default is 0. use_onnx (bool): Whether to use ONNX runtime for inference. Default is False. language (str): Language of the anonymize detect. Default is "en". - model_kwargs (Optional[Dict]): Keyword arguments passed to the model. - pipeline_kwargs (Optional[Dict]): Keyword arguments passed to the pipeline. """ if language not in ALL_SUPPORTED_LANGUAGES: @@ -108,12 +104,13 @@ def __init__( self._threshold = threshold self._language = language + if not recognizer_conf: + recognizer_conf = DEBERTA_AI4PRIVACY_v2_CONF + transformers_recognizer = get_transformers_recognizer( recognizer_conf=recognizer_conf, use_onnx=use_onnx, supported_language=language, - model_kwargs=model_kwargs, - pipeline_kwargs=pipeline_kwargs, ) self._analyzer = get_analyzer( diff --git a/llm_guard/input_scanners/anonymize_helpers/analyzer.py b/llm_guard/input_scanners/anonymize_helpers/analyzer.py index 09064aa4..ba0f081d 100644 --- a/llm_guard/input_scanners/anonymize_helpers/analyzer.py +++ b/llm_guard/input_scanners/anonymize_helpers/analyzer.py @@ -1,5 +1,5 @@ import copy -from typing import Dict, List, Optional, Sequence +from typing import Dict, List, Sequence import spacy from presidio_analyzer import ( @@ -109,8 +109,6 @@ def get_transformers_recognizer( recognizer_conf: Dict, use_onnx: bool = False, supported_language: str = "en", - model_kwargs: Optional[Dict] = None, - pipeline_kwargs: Optional[Dict] = None, ) -> EntityRecognizer: """ This function loads a transformers recognizer given a recognizer configuration. @@ -119,20 +117,16 @@ def get_transformers_recognizer( recognizer_conf (Dict): Configuration to recognize PII data. use_onnx (bool): Whether to use the ONNX version of the model. Default is False. supported_language (str): The language to use for the recognizer. Default is "en". - model_kwargs (Optional[Dict]): Keyword arguments passed to the model. - pipeline_kwargs (Optional[Dict]): Keyword arguments passed to the pipeline. """ - model_path = recognizer_conf.get("DEFAULT_MODEL_PATH") + model = recognizer_conf.get("DEFAULT_MODEL") supported_entities = recognizer_conf.get("PRESIDIO_SUPPORTED_ENTITIES") transformers_recognizer = TransformersRecognizer( - model_path=model_path, + model=model, supported_entities=supported_entities, supported_language=supported_language, ) transformers_recognizer.load_transformer( use_onnx=use_onnx, - model_kwargs=model_kwargs, - pipeline_kwargs=pipeline_kwargs, **recognizer_conf, ) return transformers_recognizer diff --git a/llm_guard/input_scanners/anonymize_helpers/ner_mapping.py b/llm_guard/input_scanners/anonymize_helpers/ner_mapping.py index 628e960a..33d82408 100644 --- a/llm_guard/input_scanners/anonymize_helpers/ner_mapping.py +++ b/llm_guard/input_scanners/anonymize_helpers/ner_mapping.py @@ -1,11 +1,16 @@ +from llm_guard.model import Model + BERT_BASE_NER_CONF = { "PRESIDIO_SUPPORTED_ENTITIES": [ "LOCATION", "PERSON", "ORGANIZATION", ], - "DEFAULT_MODEL_PATH": "dslim/bert-base-NER", - "ONNX_MODEL_PATH": "dslim/bert-base-NER", + "DEFAULT_MODEL": Model( + path="dslim/bert-base-NER", + onnx_path="dslim/bert-base-NER", + onnx_subfolder="onnx", + ), "LABELS_TO_IGNORE": ["O", "CARDINAL"], "DEFAULT_EXPLANATION": "Identified as {} by the dslim/bert-base-NER NER model", "SUB_WORD_AGGREGATION": "simple", @@ -33,8 +38,11 @@ "PERSON", "ORGANIZATION", ], - "DEFAULT_MODEL_PATH": "dslim/bert-large-NER", - "ONNX_MODEL_PATH": "dslim/bert-large-NER", + "DEFAULT_MODEL": Model( + path="dslim/bert-large-NER", + onnx_path="dslim/bert-large-NER", + onnx_subfolder="onnx", + ), "LABELS_TO_IGNORE": ["O", "CARDINAL"], "DEFAULT_EXPLANATION": "Identified as {} by the dslim/bert-large-NER NER model", "SUB_WORD_AGGREGATION": "simple", @@ -62,8 +70,10 @@ "PERSON", "ORGANIZATION", ], - "DEFAULT_MODEL_PATH": "gyr66/bert-base-chinese-finetuned-ner", - "ONNX_MODEL_PATH": "ProtectAI/gyr66-bert-base-chinese-finetuned-ner-onnx", + "DEFAULT_MODEL": Model( + path="gyr66/bert-base-chinese-finetuned-ner", + onnx_path="ProtectAI/gyr66-bert-base-chinese-finetuned-ner-onnx", + ), "LABELS_TO_IGNORE": ["O", "CARDINAL"], "DEFAULT_EXPLANATION": "Identified as {} by the gyr66/bert-base-chinese-finetuned-ner NER model", "SUB_WORD_AGGREGATION": "simple", @@ -99,8 +109,11 @@ "IP_ADDRESS", "URL", ], - "DEFAULT_MODEL_PATH": "Isotonic/distilbert_finetuned_ai4privacy_v2", - "ONNX_MODEL_PATH": "Isotonic/distilbert_finetuned_ai4privacy_v2", + "DEFAULT_MODEL": Model( + path="Isotonic/distilbert_finetuned_ai4privacy_v2", + onnx_path="Isotonic/distilbert_finetuned_ai4privacy_v2", + subfolder="onnx", + ), "LABELS_TO_IGNORE": ["O", "CARDINAL"], "DEFAULT_EXPLANATION": "Identified as {} by the Isotonic/distilbert_finetuned_ai4privacy_v2 NER model", "SUB_WORD_AGGREGATION": "simple", @@ -186,8 +199,11 @@ "IP_ADDRESS", "URL", ], - "DEFAULT_MODEL_PATH": "Isotonic/deberta-v3-base_finetuned_ai4privacy_v2", - "ONNX_MODEL_PATH": "Isotonic/deberta-v3-base_finetuned_ai4privacy_v2", + "DEFAULT_MODEL": Model( + path="Isotonic/deberta-v3-base_finetuned_ai4privacy_v2", + onnx_path="Isotonic/deberta-v3-base_finetuned_ai4privacy_v2", + subfolder="onnx", + ), "LABELS_TO_IGNORE": ["O", "CARDINAL"], "DEFAULT_EXPLANATION": "Identified as {} by the Isotonic/deberta-v3-base_finetuned_ai4privacy_v2 NER model", "SUB_WORD_AGGREGATION": "simple", diff --git a/llm_guard/input_scanners/anonymize_helpers/transformers_recognizer.py b/llm_guard/input_scanners/anonymize_helpers/transformers_recognizer.py index f02b0dca..f622f158 100644 --- a/llm_guard/input_scanners/anonymize_helpers/transformers_recognizer.py +++ b/llm_guard/input_scanners/anonymize_helpers/transformers_recognizer.py @@ -1,10 +1,11 @@ import copy -from typing import Dict, List, Optional +from typing import List, Optional from presidio_analyzer import AnalysisExplanation, EntityRecognizer, RecognizerResult from presidio_analyzer.nlp_engine import NlpArtifacts from transformers import TokenClassificationPipeline +from llm_guard.model import Model from llm_guard.transformers_helpers import device, get_tokenizer, is_onnx_supported from llm_guard.util import get_logger, lazy_load_dep, split_text_to_word_chunks @@ -52,7 +53,7 @@ def load(self) -> None: def __init__( self, - model_path: Optional[str] = None, + model: Model, pipeline: Optional[TokenClassificationPipeline] = None, supported_entities: Optional[List[str]] = None, supported_language: str = "en", @@ -61,10 +62,10 @@ def __init__( supported_entities = BERT_BASE_NER_CONF["PRESIDIO_SUPPORTED_ENTITIES"] super().__init__( supported_entities=supported_entities, - name=f"Transformers model {model_path}", + name=f"Transformers model {model.path}", ) - self.model_path = model_path + self.model = model self.pipeline = pipeline self.is_loaded = False @@ -77,24 +78,17 @@ def __init__( self.chunk_length = None self.id_entity_name = None self.id_score_reduction = None - self.onnx_model_path = None self.supported_language = supported_language def load_transformer( self, use_onnx: bool = False, - model_kwargs: Optional[Dict] = None, - pipeline_kwargs: Optional[Dict] = None, **kwargs, ) -> None: """Load external configuration parameters and set default values. :param use_onnx: flag to use ONNX optimized model :type use_onnx: bool, optional - :param model_kwargs: define default values for model attributes - :type model_kwargs: Optional[Dict], optional - :param pipeline_kwargs: define default values for pipeline attributes - :type pipeline_kwargs: Optional[Dict], optional :param kwargs: define default values for class attributes and modify pipeline behavior **DATASET_TO_PRESIDIO_MAPPING (dict) - defines mapping entity strings from dataset format to Presidio format **MODEL_TO_PRESIDIO_MAPPING (dict) - defines mapping entity strings from chosen model format to Presidio format @@ -118,66 +112,57 @@ def load_transformer( self.chunk_length = kwargs.get("CHUNK_SIZE", 600) self.id_entity_name = kwargs.get("ID_ENTITY_NAME", "ID") self.id_score_reduction = kwargs.get("ID_SCORE_REDUCTION", 0.5) - self.onnx_model_path = kwargs.get("ONNX_MODEL_PATH", None) if not self.pipeline: - if not self.model_path: - self.model_path = "dslim/bert-base-NER" - self.onnx_model_path = "optimum/bert-base-NER" + if not self.model: + self.model = Model( + path="dslim/bert-base-NER", + onnx_path="dslim/bert-base-NER", + subfolder="onnx", + ) LOGGER.warning( - "Both 'model' and 'model_path' arguments are None. Using default", - model_path=self.model_path, + "'model' argument is None. Using default", + model=self.model, ) - self._load_pipeline( - use_onnx=use_onnx, model_kwargs=model_kwargs, pipeline_kwargs=pipeline_kwargs - ) + self._load_pipeline( + use_onnx=use_onnx, + ) def _load_pipeline( self, use_onnx: bool = False, - model_kwargs: Optional[Dict] = None, - pipeline_kwargs: Optional[Dict] = None, ) -> None: """Initialize NER transformers_rec pipeline using the model_path provided""" - model = self.model_path - onnx_model = self.onnx_model_path - pipeline_kwargs = pipeline_kwargs or {} - model_kwargs = model_kwargs or {} - transformers = lazy_load_dep("transformers") - tf_tokenizer = get_tokenizer(model, **model_kwargs) + tf_tokenizer = get_tokenizer(self.model) if use_onnx and is_onnx_supported() is False: LOGGER.warning("ONNX is not supported on this machine. Using PyTorch instead of ONNX.") use_onnx = False if use_onnx: - subfolder = "onnx" if onnx_model == model else "" - if onnx_model is not None: - model = onnx_model - optimum_onnxruntime = lazy_load_dep( "optimum.onnxruntime", "optimum[onnxruntime]" if device().type != "cuda" else "optimum[onnxruntime-gpu]", ) tf_tokenizer.model_input_names = ["input_ids", "attention_mask"] tf_model = optimum_onnxruntime.ORTModelForTokenClassification.from_pretrained( - model, - export=onnx_model is None, - subfolder=subfolder, + self.model.onnx_path, + export=False, + subfolder=self.model.onnx_subfolder, provider="CUDAExecutionProvider" if device().type == "cuda" else "CPUExecutionProvider", use_io_binding=True if device().type == "cuda" else False, - **model_kwargs, + **self.model.kwargs, ) - LOGGER.debug("Initialized NER ONNX model", model=model, device=device()) + LOGGER.debug("Initialized NER ONNX model", model=self.model, device=device()) else: tf_model = transformers.AutoModelForTokenClassification.from_pretrained( - model, **model_kwargs + self.model.path, subfolder=self.model.subfolder, **self.model.kwargs ) - LOGGER.debug("Initialized NER model", model=model, device=device()) + LOGGER.debug("Initialized NER model", model=self.model, device=device()) self.pipeline = transformers.pipeline( "ner", @@ -189,7 +174,7 @@ def _load_pipeline( aggregation_strategy=self.aggregation_mechanism, framework="pt", ignore_labels=self.ignore_labels, - **pipeline_kwargs, + **self.model.pipeline_kwargs, ) self.is_loaded = True diff --git a/llm_guard/input_scanners/ban_competitors.py b/llm_guard/input_scanners/ban_competitors.py index 523f84f5..a79f604e 100644 --- a/llm_guard/input_scanners/ban_competitors.py +++ b/llm_guard/input_scanners/ban_competitors.py @@ -1,15 +1,16 @@ -from typing import Dict, Optional, Sequence +from typing import Optional, Sequence from presidio_anonymizer.core.text_replace_builder import TextReplaceBuilder +from llm_guard.model import Model from llm_guard.util import device, get_logger, lazy_load_dep from .base import Scanner LOGGER = get_logger() -MODEL_BASE = "tomaarsen/span-marker-bert-base-orgs" -MODEL_SMALL = "tomaarsen/span-marker-bert-small-orgs" +MODEL_BASE = Model("tomaarsen/span-marker-bert-base-orgs") +MODEL_SMALL = Model("tomaarsen/span-marker-bert-small-orgs") class BanCompetitors(Scanner): @@ -25,8 +26,7 @@ def __init__( *, threshold: float = 0.5, redact: bool = True, - model: Optional[str] = None, - model_kwargs: Optional[Dict] = None, + model: Optional[Model] = None, ): """ Initialize BanCompetitors object. @@ -35,8 +35,7 @@ def __init__( competitors (Sequence[str]): List of competitors to detect. threshold (float, optional): Threshold to determine if a competitor is present in the prompt. Default is 0.5. redact (bool, optional): Whether to redact the competitor name. Default is True. - model (str, optional): Model to use for named-entity recognition. Default is BASE model. - model_kwargs (Dict, optional): Keyword arguments passed to the model. + model (Model, optional): Model to use for named-entity recognition. Default is BASE model. Raises: ValueError: If no topics are provided. @@ -50,7 +49,7 @@ def __init__( span_marker = lazy_load_dep("span_marker", "span-marker") self._ner_pipeline = span_marker.SpanMarkerModel.from_pretrained( - model, labels=["ORG"], **(model_kwargs or {}) + model.path, labels=["ORG"], **model.kwargs ) if device().type == "cuda": diff --git a/llm_guard/input_scanners/ban_topics.py b/llm_guard/input_scanners/ban_topics.py index 3d642d74..f24392b0 100644 --- a/llm_guard/input_scanners/ban_topics.py +++ b/llm_guard/input_scanners/ban_topics.py @@ -1,5 +1,6 @@ -from typing import Dict, Optional, Sequence +from typing import Optional, Sequence +from llm_guard.model import Model from llm_guard.transformers_helpers import get_tokenizer_and_model_for_classification, pipeline from llm_guard.util import get_logger @@ -9,14 +10,38 @@ # This model was trained on a mixture of 33 datasets and 389 classes reformatted in the universal NLI format. # The model is English only. You can also use it for multilingual zeroshot classification by first machine translating texts to English. -MODEL_LARGE = "MoritzLaurer/deberta-v3-large-zeroshot-v1.1-all-33" +MODEL_LARGE = Model( + path="MoritzLaurer/deberta-v3-large-zeroshot-v1.1-all-33", + onnx_path="MoritzLaurer/deberta-v3-large-zeroshot-v1.1-all-33", + onnx_subfolder="onnx", + pipeline_kwargs={ + "max_length": 512, + "truncation": True, + }, +) # This is essentially the same as its larger sister MoritzLaurer/deberta-v3-large-zeroshot-v1.1-all-33 only that it's smaller. # Use it if you need more speed. The model is English-only. -MODEL_BASE = "MoritzLaurer/deberta-v3-base-zeroshot-v1.1-all-33" +MODEL_BASE = Model( + path="MoritzLaurer/deberta-v3-base-zeroshot-v1.1-all-33", + onnx_path="MoritzLaurer/deberta-v3-base-zeroshot-v1.1-all-33", + onnx_subfolder="onnx", + pipeline_kwargs={ + "max_length": 512, + "truncation": True, + }, +) # Same as above, just smaller/faster. -MODEL_XSMALL = "MoritzLaurer/deberta-v3-xsmall-zeroshot-v1.1-all-33" +MODEL_XSMALL = Model( + path="MoritzLaurer/deberta-v3-xsmall-zeroshot-v1.1-all-33", + onnx_path="MoritzLaurer/deberta-v3-xsmall-zeroshot-v1.1-all-33", + onnx_subfolder="onnx", + pipeline_kwargs={ + "max_length": 512, + "truncation": True, + }, +) class BanTopics(Scanner): @@ -31,10 +56,8 @@ def __init__( topics: Sequence[str], *, threshold: float = 0.6, - model: Optional[str] = None, + model: Optional[Model] = None, use_onnx: bool = False, - model_kwargs: Optional[Dict] = None, - pipeline_kwargs: Optional[Dict] = None, ): """ Initialize BanTopics object. @@ -42,39 +65,28 @@ def __init__( Parameters: topics (Sequence[str]): List of topics to ban. threshold (float, optional): Threshold to determine if a topic is present in the prompt. Default is 0.75. - model (Dict, optional): Model to use for zero-shot classification. Default is deberta-v3-base-zeroshot-v1. + model (Model, optional): Model to use for zero-shot classification. Default is deberta-v3-base-zeroshot-v1. use_onnx (bool, optional): Whether to use ONNX for inference. Default is False. - model_kwargs (Dict, optional): Keyword arguments passed to the model. - pipeline_kwargs (Dict, optional): Keyword arguments passed to the pipeline. Raises: ValueError: If no topics are provided. """ - if model is None: - model = MODEL_BASE - self._topics = topics self._threshold = threshold - default_pipeline_kwargs = { - "max_length": 512, - "truncation": True, - } - if pipeline_kwargs is None: - pipeline_kwargs = {} - - pipeline_kwargs = {**default_pipeline_kwargs, **pipeline_kwargs} - model_kwargs = model_kwargs or {} + if model is None: + model = MODEL_BASE tf_tokenizer, tf_model = get_tokenizer_and_model_for_classification( - model=model, onnx_model=model, use_onnx=use_onnx, **model_kwargs + model=model, + use_onnx=use_onnx, ) self._classifier = pipeline( task="zero-shot-classification", model=tf_model, tokenizer=tf_tokenizer, - **pipeline_kwargs, + **model.pipeline_kwargs, ) def scan(self, prompt: str) -> (str, bool, float): diff --git a/llm_guard/input_scanners/code.py b/llm_guard/input_scanners/code.py index a24b9587..ae844515 100644 --- a/llm_guard/input_scanners/code.py +++ b/llm_guard/input_scanners/code.py @@ -2,6 +2,7 @@ from typing import List, Optional, Sequence from llm_guard.exception import LLMGuardValidationError +from llm_guard.model import Model from llm_guard.transformers_helpers import get_tokenizer_and_model_for_classification, pipeline from llm_guard.util import calculate_risk_score, get_logger @@ -9,7 +10,12 @@ LOGGER = get_logger() -default_model_path = "philomath-1209/programming-language-identification" +DEFAULT_MODEL = Model( + path="philomath-1209/programming-language-identification", + onnx_path="philomath-1209/programming-language-identification-onnx", + onnx_subfolder="onnx", + pipeline_kwargs={"truncation": True}, +) SUPPORTED_LANGUAGES = [ "ARM Assembly", @@ -53,24 +59,20 @@ def __init__( self, languages: Sequence[str], *, - model_path: str = default_model_path, + model: Optional[Model] = None, is_blocked: bool = True, threshold: float = 0.5, use_onnx: bool = False, - model_kwargs: Optional[dict] = None, - pipeline_kwargs: Optional[dict] = None, ): """ Initializes Code with the allowed and denied languages. Parameters: - model_path (str): The path to the model to use for language detection. + model (Model, optional): The model to use for language detection. languages (Sequence[str]): The list of programming languages to allow or deny. is_blocked (bool): Whether the languages are blocked or allowed. Default is True. threshold (float): The threshold for the risk score. Default is 0.5. use_onnx (bool): Whether to use ONNX for inference. Default is False. - model_kwargs (dict, optional): Keyword arguments passed to the model. - pipeline_kwargs (dict, optional): Keyword arguments passed to the pipeline. Raises: LLMGuardValidationError: If the languages are not a subset of SUPPORTED_LANGUAGES. @@ -82,24 +84,19 @@ def __init__( self._is_blocked = is_blocked self._threshold = threshold - default_pipeline_kwargs = { - "truncation": True, - } - if pipeline_kwargs is None: - pipeline_kwargs = {} - - pipeline_kwargs = {**default_pipeline_kwargs, **pipeline_kwargs} - model_kwargs = model_kwargs or {} + if model is None: + model = DEFAULT_MODEL tf_tokenizer, tf_model = get_tokenizer_and_model_for_classification( - model=model_path, onnx_model=model_path, use_onnx=use_onnx, **model_kwargs + model=model, + use_onnx=use_onnx, ) self._pipeline = pipeline( task="text-classification", model=tf_model, tokenizer=tf_tokenizer, - **pipeline_kwargs, + **model.pipeline_kwargs, ) self._fenced_code_regex = re.compile(r"```(?:[a-zA-Z0-9]*\n)?(.*?)```", re.DOTALL) diff --git a/llm_guard/input_scanners/gibberish.py b/llm_guard/input_scanners/gibberish.py index 6154b522..504818c7 100644 --- a/llm_guard/input_scanners/gibberish.py +++ b/llm_guard/input_scanners/gibberish.py @@ -1,6 +1,7 @@ from enum import Enum -from typing import Dict, List, Optional, Union +from typing import List, Optional, Union +from llm_guard.model import Model from llm_guard.transformers_helpers import get_tokenizer_and_model_for_classification, pipeline from llm_guard.util import calculate_risk_score, get_logger, split_text_by_sentences @@ -8,7 +9,12 @@ LOGGER = get_logger() -default_model_path = "madhurjindal/autonlp-Gibberish-Detector-492513457" +DEFAULT_MODEL = Model( + path="madhurjindal/autonlp-Gibberish-Detector-492513457", + onnx_path="madhurjindal/autonlp-Gibberish-Detector-492513457", + onnx_subfolder="onnx", + pipeline_kwargs={"truncation": True}, +) class MatchType(Enum): @@ -30,23 +36,19 @@ class Gibberish(Scanner): def __init__( self, *, - model_path: str = default_model_path, + model: Optional[Model] = None, threshold: float = 0.7, match_type: Union[MatchType, str] = MatchType.FULL, use_onnx: bool = False, - model_kwargs: Optional[Dict] = None, - pipeline_kwargs: Optional[Dict] = None, ): """ Initializes the Gibberish scanner with a probability threshold for gibberish detection. Parameters: - model_path (str): The path to the model. + model (Model, optional): The model object. threshold (float): The probability threshold for gibberish detection. Default is 0.7. match_type (MatchType): Whether to match the full text or individual sentences. Default is MatchType.FULL. use_onnx (bool): Whether to use ONNX instead of PyTorch for inference. - model_kwargs (dict): Keyword arguments passed to the model. - pipeline_kwargs (dict): Keyword arguments passed to the pipeline. """ if isinstance(match_type, str): match_type = MatchType(match_type) @@ -54,24 +56,19 @@ def __init__( self._threshold = threshold self._match_type = match_type - default_pipeline_kwargs = { - "truncation": True, - } - if pipeline_kwargs is None: - pipeline_kwargs = {} - - pipeline_kwargs = {**default_pipeline_kwargs, **pipeline_kwargs} - model_kwargs = model_kwargs or {} + if model is None: + model = DEFAULT_MODEL tf_tokenizer, tf_model = get_tokenizer_and_model_for_classification( - model=model_path, onnx_model=model_path, use_onnx=use_onnx, **model_kwargs + model=model, + use_onnx=use_onnx, ) self._classifier = pipeline( task="text-classification", model=tf_model, tokenizer=tf_tokenizer, - **pipeline_kwargs, + **model.pipeline_kwargs, ) def scan(self, prompt: str) -> (str, bool, float): diff --git a/llm_guard/input_scanners/language.py b/llm_guard/input_scanners/language.py index 1b4312dc..4be34b4b 100644 --- a/llm_guard/input_scanners/language.py +++ b/llm_guard/input_scanners/language.py @@ -1,6 +1,7 @@ from enum import Enum -from typing import Dict, List, Optional, Sequence, Union +from typing import List, Optional, Sequence, Union +from llm_guard.model import Model from llm_guard.transformers_helpers import get_tokenizer_and_model_for_classification, pipeline from llm_guard.util import calculate_risk_score, get_logger, split_text_by_sentences @@ -8,9 +9,14 @@ LOGGER = get_logger() -default_model_path = ( - "papluca/xlm-roberta-base-language-detection", - "ProtectAI/xlm-roberta-base-language-detection-onnx", +DEFAULT_MODEL = Model( + path="papluca/xlm-roberta-base-language-detection", + onnx_path="ProtectAI/xlm-roberta-base-language-detection-onnx", + pipeline_kwargs={ + "max_length": 512, + "truncation": True, + "top_k": None, + }, ) @@ -37,23 +43,20 @@ def __init__( self, valid_languages: Sequence[str], *, - model_path: Optional[str] = None, + model: Optional[Model] = None, threshold: float = 0.6, match_type: Union[MatchType, str] = MatchType.FULL, use_onnx: bool = False, - model_kwargs: Optional[Dict] = None, - pipeline_kwargs: Optional[Dict] = None, ): """ Initializes the Language scanner with a list of valid languages. Parameters: + model (Model, optional): A Model object containing the path to the model and its ONNX equivalent. valid_languages (Sequence[str]): A list of valid language codes in ISO 639-1. threshold (float): Minimum confidence score. match_type (MatchType): Whether to match the full text or individual sentences. Default is MatchType.FULL. use_onnx (bool): Whether to use ONNX for inference. Default is False. - model_kwargs (Dict): Keyword arguments passed to the model. - pipeline_kwargs (Dict): Keyword arguments passed to the pipeline. """ if isinstance(match_type, str): match_type = MatchType(match_type) @@ -62,31 +65,19 @@ def __init__( self._valid_languages = valid_languages self._match_type = match_type - default_pipeline_kwargs = { - "max_length": 512, - "truncation": True, - "top_k": None, - } - if pipeline_kwargs is None: - pipeline_kwargs = {} - - pipeline_kwargs = {**default_pipeline_kwargs, **pipeline_kwargs} - model_kwargs = model_kwargs or {} - - onnx_model_path = model_path - if model_path is None: - model_path = default_model_path[0] - onnx_model_path = default_model_path[1] + if model is None: + model = DEFAULT_MODEL tf_tokenizer, tf_model = get_tokenizer_and_model_for_classification( - model=model_path, onnx_model=onnx_model_path, use_onnx=use_onnx, **model_kwargs + model=model, + use_onnx=use_onnx, ) self._pipeline = pipeline( task="text-classification", model=tf_model, tokenizer=tf_tokenizer, - **pipeline_kwargs, + **model.pipeline_kwargs, ) def scan(self, prompt: str) -> (str, bool, float): diff --git a/llm_guard/input_scanners/prompt_injection.py b/llm_guard/input_scanners/prompt_injection.py index a0156c11..cf0558a7 100644 --- a/llm_guard/input_scanners/prompt_injection.py +++ b/llm_guard/input_scanners/prompt_injection.py @@ -1,6 +1,7 @@ from enum import Enum -from typing import Dict, List, Optional, Union +from typing import List, Optional, Union +from llm_guard.model import Model from llm_guard.transformers_helpers import get_tokenizer_and_model_for_classification, pipeline from llm_guard.util import calculate_risk_score, get_logger, split_text_by_sentences @@ -9,11 +10,15 @@ LOGGER = get_logger() # This model is proprietary but open source. -MODEL_LAIYER = { - "path": "ProtectAI/deberta-v3-base-prompt-injection", - "label": "INJECTION", - "max_length": 512, -} +DEFAULT_MODEL = Model( + path="ProtectAI/deberta-v3-base-prompt-injection", + onnx_path="ProtectAI/deberta-v3-base-prompt-injection", + onnx_subfolder="onnx", + pipeline_kwargs={ + "max_length": 512, + "truncation": True, + }, +) class MatchType(Enum): @@ -36,29 +41,25 @@ class PromptInjection(Scanner): def __init__( self, *, - model: Optional[Dict] = None, + model: Optional[Model] = None, threshold: float = 0.9, match_type: Union[MatchType, str] = MatchType.FULL, use_onnx: bool = False, - model_kwargs: Optional[Dict] = None, - pipeline_kwargs: Optional[Dict] = None, ): """ Initializes PromptInjection with a threshold. Parameters: - model (Dict, optional): Chosen model to classify prompt. Default is Laiyer's one. + model (Model, optional): Chosen model to classify prompt. Default is Laiyer's one. threshold (float): Threshold for the injection score. Default is 0.9. match_type (MatchType): Whether to match the full text or individual sentences. Default is MatchType.FULL. use_onnx (bool): Whether to use ONNX for inference. Defaults to False. - model_kwargs (Dict, optional): Keyword arguments passed to the model. - pipeline_kwargs (Dict, optional): Keyword arguments passed to the pipeline. Raises: ValueError: If non-existent models were provided. """ if model is None: - model = MODEL_LAIYER + model = DEFAULT_MODEL if isinstance(match_type, str): match_type = MatchType(match_type) @@ -67,25 +68,16 @@ def __init__( self._match_type = match_type self._model = model - default_pipeline_kwargs = { - "max_length": model["max_length"], - "truncation": True, - } - if pipeline_kwargs is None: - pipeline_kwargs = {} - - pipeline_kwargs = {**default_pipeline_kwargs, **pipeline_kwargs} - model_kwargs = model_kwargs or {} - tf_tokenizer, tf_model = get_tokenizer_and_model_for_classification( - model=model["path"], onnx_model=model["path"], use_onnx=use_onnx, **model_kwargs + model=model, + use_onnx=use_onnx, ) self._pipeline = pipeline( task="text-classification", model=tf_model, tokenizer=tf_tokenizer, - **pipeline_kwargs, + **model.pipeline_kwargs, ) def scan(self, prompt: str) -> (str, bool, float): @@ -96,7 +88,7 @@ def scan(self, prompt: str) -> (str, bool, float): results_all = self._pipeline(self._match_type.get_inputs(prompt)) for result in results_all: injection_score = round( - result["score"] if result["label"] == self._model["label"] else 1 - result["score"], + result["score"] if result["label"] == "INJECTION" else 1 - result["score"], 2, ) diff --git a/llm_guard/input_scanners/toxicity.py b/llm_guard/input_scanners/toxicity.py index d24c3b50..16a3e87b 100644 --- a/llm_guard/input_scanners/toxicity.py +++ b/llm_guard/input_scanners/toxicity.py @@ -1,6 +1,7 @@ from enum import Enum -from typing import Dict, List, Optional, Union +from typing import List, Optional, Union +from llm_guard.model import Model from llm_guard.transformers_helpers import get_tokenizer_and_model_for_classification, pipeline from llm_guard.util import calculate_risk_score, get_logger, split_text_by_sentences @@ -8,10 +9,17 @@ LOGGER = get_logger() -_model_path = ( - "unitary/unbiased-toxic-roberta", - "ProtectAI/unbiased-toxic-roberta-onnx", # ONNX model +DEFAULT_MODEL = Model( + path="unitary/unbiased-toxic-roberta", + onnx_path="ProtectAI/unbiased-toxic-roberta-onnx", + pipeline_kwargs={ + "padding": "max_length", + "top_k": None, + "function_to_apply": "sigmoid", + "truncation": True, + }, ) + _toxic_labels = [ "toxicity", "severe_toxicity", @@ -45,23 +53,19 @@ class Toxicity(Scanner): def __init__( self, *, - model_path: Optional[str] = None, + model: Optional[Model] = None, threshold: float = 0.5, match_type: Union[MatchType, str] = MatchType.FULL, use_onnx: bool = False, - model_kwargs: Optional[Dict] = None, - pipeline_kwargs: Optional[Dict] = None, ): """ Initializes Toxicity with a threshold for toxicity. Parameters: - model_path (str, optional): Path to the model. Default is None. + model (Model, optional): Path to the model. Default is None. threshold (float): Threshold for toxicity. Default is 0.5. match_type (MatchType): Whether to match the full text or individual sentences. Default is MatchType.FULL. use_onnx (bool): Whether to use ONNX for inference. Default is False. - model_kwargs (Dict, optional): Keyword arguments passed to the model. - pipeline_kwargs (Dict, optional): Keyword arguments passed to the pipeline. """ if isinstance(match_type, str): match_type = MatchType(match_type) @@ -69,32 +73,19 @@ def __init__( self._threshold = threshold self._match_type = match_type - default_pipeline_kwargs = { - "padding": "max_length", - "top_k": None, - "function_to_apply": "sigmoid", - "truncation": True, - } - if pipeline_kwargs is None: - pipeline_kwargs = {} - - pipeline_kwargs = {**default_pipeline_kwargs, **pipeline_kwargs} - model_kwargs = model_kwargs or {} - - onnx_model_path = model_path - if model_path is None: - model_path = _model_path[0] - onnx_model_path = _model_path[1] + if model is None: + model = DEFAULT_MODEL tf_tokenizer, tf_model = get_tokenizer_and_model_for_classification( - model=model_path, onnx_model=onnx_model_path, use_onnx=use_onnx, **model_kwargs + model=model, + use_onnx=use_onnx, ) self._pipeline = pipeline( task="text-classification", model=tf_model, tokenizer=tf_tokenizer, - **pipeline_kwargs, + **model.pipeline_kwargs, ) def scan(self, prompt: str) -> (str, bool, float): diff --git a/llm_guard/model.py b/llm_guard/model.py new file mode 100644 index 00000000..2c825064 --- /dev/null +++ b/llm_guard/model.py @@ -0,0 +1,28 @@ +import dataclasses +from typing import Dict, Optional + + +@dataclasses.dataclass +class Model: + """ + Dataclass to store model information. + + Attributes: + path (str): Path to the model. + subfolder (str): Subfolder in the model path. + onnx_path (str, optional): Path to the ONNX model. + onnx_subfolder (str): Subfolder in the ONNX model path. + kwargs (Dict, optional): Keyword arguments passed to the model (transformers). + pipeline_kwargs (Dict, optional): Keyword arguments passed to the pipeline (transformers). + """ + + path: str + subfolder: str = "" + onnx_path: Optional[str] = None + onnx_subfolder: str = "" + onnx_filename: str = "model.onnx" + kwargs: Dict = dataclasses.field(default_factory=dict) + pipeline_kwargs: Dict = dataclasses.field(default_factory=dict) + + def __str__(self): + return self.path diff --git a/llm_guard/output_scanners/ban_competitors.py b/llm_guard/output_scanners/ban_competitors.py index 6606f9ec..2b57b55a 100644 --- a/llm_guard/output_scanners/ban_competitors.py +++ b/llm_guard/output_scanners/ban_competitors.py @@ -1,6 +1,7 @@ -from typing import Dict, Optional, Sequence +from typing import Optional, Sequence from llm_guard.input_scanners.ban_competitors import BanCompetitors as InputBanCompetitors +from llm_guard.model import Model from .base import Scanner @@ -18,8 +19,7 @@ def __init__( *, threshold: float = 0.5, redact: bool = True, - model: Optional[str] = None, - model_kwargs: Optional[Dict] = None, + model: Optional[Model] = None, ): """ Initializes BanCompetitors object. @@ -28,8 +28,7 @@ def __init__( competitors (Sequence[str]): List of competitors to ban. threshold (float, optional): Threshold to determine if an organization is present in the output. Default is 0.5. redact (bool, optional): Whether to redact the organization name. Default is True. - model (str, optional): Model to use for named-entity recognition. Default is BASE model. - model_kwargs (Dict, optional): Keyword arguments passed to the model. + model (Model, optional): Model to use for named-entity recognition. Default is BASE model. Raises: ValueError: If no competitors are provided. @@ -39,7 +38,6 @@ def __init__( threshold=threshold, redact=redact, model=model, - model_kwargs=model_kwargs, ) def scan(self, prompt: str, output: str) -> (str, bool, float): diff --git a/llm_guard/output_scanners/ban_topics.py b/llm_guard/output_scanners/ban_topics.py index abc198a0..474bdbd6 100644 --- a/llm_guard/output_scanners/ban_topics.py +++ b/llm_guard/output_scanners/ban_topics.py @@ -1,6 +1,7 @@ -from typing import Dict, Optional, Sequence +from typing import Optional, Sequence from llm_guard.input_scanners.ban_topics import BanTopics as InputBanTopics +from llm_guard.model import Model from .base import Scanner @@ -17,10 +18,8 @@ def __init__( topics: Sequence[str], *, threshold: float = 0.75, - model: Optional[str] = None, + model: Optional[Model] = None, use_onnx: bool = False, - model_kwargs: Optional[Dict] = None, - pipeline_kwargs: Optional[Dict] = None, ): """ Initializes BanTopics with a list of topics and a probability threshold. @@ -29,10 +28,8 @@ def __init__( topics (Sequence[str]): The list of topics to be banned from the text. threshold (float): The minimum probability required for a topic to be considered present in the text. Default is 0.75. - model (Dict, optional): The name of the zero-shot-classification model to be used. Default is MODEL_BASE. + model (Model, optional): The name of the zero-shot-classification model to be used. Default is MODEL_BASE. use_onnx (bool, optional): Whether to use ONNX for inference. Default is False. - model_kwargs (Dict, optional): Keyword arguments passed to the model. - pipeline_kwargs (Dict, optional): Keyword arguments passed to the pipeline. Raises: ValueError: If no topics are provided. @@ -42,8 +39,6 @@ def __init__( threshold=threshold, model=model, use_onnx=use_onnx, - model_kwargs=model_kwargs, - pipeline_kwargs=pipeline_kwargs, ) def scan(self, prompt: str, output: str) -> (str, bool, float): diff --git a/llm_guard/output_scanners/bias.py b/llm_guard/output_scanners/bias.py index 543db1eb..84e2e0e5 100644 --- a/llm_guard/output_scanners/bias.py +++ b/llm_guard/output_scanners/bias.py @@ -1,6 +1,7 @@ from enum import Enum -from typing import Dict, List, Optional, Union +from typing import List, Optional, Union +from llm_guard.model import Model from llm_guard.transformers_helpers import get_tokenizer_and_model_for_classification, pipeline from llm_guard.util import calculate_risk_score, get_logger, split_text_by_sentences @@ -8,9 +9,10 @@ LOGGER = get_logger() -_model_path = ( - "valurank/distilroberta-bias", - "ProtectAI/distilroberta-bias-onnx", # ONNX model +DEFAULT_MODEL = Model( + path="valurank/distilroberta-bias", + onnx_path="ProtectAI/distilroberta-bias-onnx", + pipeline_kwargs={"truncation": True}, ) @@ -33,23 +35,19 @@ class Bias(Scanner): def __init__( self, *, - model_path: Optional[str] = None, + model: Optional[Model] = None, threshold: float = 0.7, match_type: Union[MatchType, str] = MatchType.FULL, use_onnx: bool = False, - model_kwargs: Optional[Dict] = None, - pipeline_kwargs: Optional[Dict] = None, ): """ Initializes the Bias scanner with a probability threshold for bias detection. Parameters: - model_path (str): The model path to use for bias detection. + model (str): The model path to use for bias detection. threshold (float): The threshold above which a text is considered biased. Default is 0.7. match_type (MatchType): Whether to match the full text or individual sentences. Default is MatchType.FULL. use_onnx (bool): Whether to use ONNX instead of PyTorch for inference. - model_kwargs (Dict, optional): Keyword arguments passed to the model. - pipeline_kwargs (Dict, optional): Keyword arguments passed to the pipeline. """ if isinstance(match_type, str): match_type = MatchType(match_type) @@ -57,29 +55,19 @@ def __init__( self._threshold = threshold self._match_type = match_type - default_pipeline_kwargs = { - "truncation": True, - } - if pipeline_kwargs is None: - pipeline_kwargs = {} - - pipeline_kwargs = {**default_pipeline_kwargs, **pipeline_kwargs} - model_kwargs = model_kwargs or {} - - onnx_model_path = model_path - if model_path is None: - model_path = _model_path[0] - onnx_model_path = _model_path[1] + if model is None: + model = DEFAULT_MODEL tf_tokenizer, tf_model = get_tokenizer_and_model_for_classification( - model=model_path, onnx_model=onnx_model_path, use_onnx=use_onnx, **model_kwargs + model=model, + use_onnx=use_onnx, ) self._classifier = pipeline( task="text-classification", model=tf_model, tokenizer=tf_tokenizer, - **pipeline_kwargs, + **model.pipeline_kwargs, ) def scan(self, prompt: str, output: str) -> (str, bool, float): diff --git a/llm_guard/output_scanners/code.py b/llm_guard/output_scanners/code.py index e2f63e23..a27367ac 100644 --- a/llm_guard/output_scanners/code.py +++ b/llm_guard/output_scanners/code.py @@ -1,7 +1,7 @@ from typing import Dict, Optional, Sequence from llm_guard.input_scanners.code import Code as InputCode -from llm_guard.input_scanners.code import default_model_path +from llm_guard.model import Model from .base import Scanner @@ -18,24 +18,21 @@ def __init__( self, languages: Sequence[str], *, - model_path: str = default_model_path, + model: Optional[Model] = None, is_blocked: bool = True, threshold: float = 0.5, use_onnx: bool = False, - model_kwargs: Optional[Dict] = None, pipeline_kwargs: Optional[Dict] = None, ): """ Initializes an instance of the Code class. Parameters: - model_path (str): The path to the model to use for language detection. + model (Model, optional): The model to use for language detection. languages (Sequence[str]): The list of programming languages to allow or deny. is_blocked (bool): Whether the languages are blocked or allowed. Default is True. threshold (float): The threshold for the model output to be considered valid. Default is 0.5. use_onnx (bool): Whether to use ONNX for inference. Default is False. - model_kwargs (dict, optional): Keyword arguments passed to the model. - pipeline_kwargs (dict, optional): Keyword arguments passed to the pipeline. Raises: ValueError: If both 'allowed' and 'denied' lists are provided or if both are empty. @@ -43,12 +40,10 @@ def __init__( self._scanner = InputCode( languages, - model_path=model_path, + model=model, is_blocked=is_blocked, threshold=threshold, use_onnx=use_onnx, - model_kwargs=model_kwargs, - pipeline_kwargs=pipeline_kwargs, ) def scan(self, prompt: str, output: str) -> (str, bool, float): diff --git a/llm_guard/output_scanners/factual_consistency.py b/llm_guard/output_scanners/factual_consistency.py index 0cfd9549..ae709a02 100644 --- a/llm_guard/output_scanners/factual_consistency.py +++ b/llm_guard/output_scanners/factual_consistency.py @@ -1,6 +1,7 @@ from typing import Dict, Optional from llm_guard.input_scanners.ban_topics import MODEL_BASE +from llm_guard.model import Model from llm_guard.transformers_helpers import get_tokenizer_and_model_for_classification from llm_guard.util import device, get_logger, lazy_load_dep @@ -20,7 +21,7 @@ class FactualConsistency(Scanner): def __init__( self, *, - model_path: Optional[str] = None, + model: Optional[Model] = None, minimum_score=0.5, use_onnx=False, model_kwargs: Optional[Dict] = None, @@ -29,7 +30,7 @@ def __init__( Initializes an instance of the Refutation class. Parameters: - model_path (str): Path to the model. Defaults to None. + model (Model, optional): The model to use for entailment checking. Defaults to None. minimum_score (float): The minimum entailment score for the output to be considered valid. Defaults to 0.5. use_onnx (bool): Whether to use the ONNX version of the model. Defaults to False. model_kwargs (Dict, optional): Keyword arguments passed to the model. @@ -37,14 +38,12 @@ def __init__( self._minimum_score = minimum_score - if model_path is None: - model_path = MODEL_BASE + if model is None: + model = MODEL_BASE self._tokenizer, self._model = get_tokenizer_and_model_for_classification( - model=model_path, - onnx_model=model_path, + model=model, use_onnx=use_onnx, - **(model_kwargs or {}), ) self._model = self._model.to(device()) diff --git a/llm_guard/output_scanners/gibberish.py b/llm_guard/output_scanners/gibberish.py index aba9a2de..579c6d99 100644 --- a/llm_guard/output_scanners/gibberish.py +++ b/llm_guard/output_scanners/gibberish.py @@ -1,7 +1,8 @@ -from typing import Dict, Optional, Union +from typing import Optional, Union from llm_guard.input_scanners.gibberish import Gibberish as InputGibberish -from llm_guard.input_scanners.gibberish import MatchType, default_model_path +from llm_guard.input_scanners.gibberish import MatchType +from llm_guard.model import Model from .base import Scanner @@ -14,32 +15,26 @@ class Gibberish(Scanner): def __init__( self, *, - model_path: str = default_model_path, + model: Optional[Model] = None, threshold: float = 0.7, match_type: Union[MatchType, str] = MatchType.FULL, use_onnx: bool = False, - model_kwargs: Optional[Dict] = None, - pipeline_kwargs: Optional[Dict] = None, ): """ Initializes the Gibberish scanner with a probability threshold for gibberish detection. Parameters: - model_path (str): The path to the model. + model (Model, optional): The model used. threshold (float): The probability threshold for gibberish detection. Default is 0.7. match_type (MatchType): Whether to match the full text or individual sentences. Default is MatchType.FULL. use_onnx (bool): Whether to use ONNX instead of PyTorch for inference. - model_kwargs (Dict): Keyword arguments passed to the model. - pipeline_kwargs (Dict): Keyword arguments passed to the pipeline. """ self._scanner = InputGibberish( - model_path=model_path, + model=model, threshold=threshold, match_type=match_type, use_onnx=use_onnx, - model_kwargs=model_kwargs, - pipeline_kwargs=pipeline_kwargs, ) def scan(self, prompt: str, output: str) -> (str, bool, float): diff --git a/llm_guard/output_scanners/language.py b/llm_guard/output_scanners/language.py index c13db52b..03fdec03 100644 --- a/llm_guard/output_scanners/language.py +++ b/llm_guard/output_scanners/language.py @@ -1,7 +1,8 @@ -from typing import Dict, Optional, Sequence, Union +from typing import Optional, Sequence, Union from llm_guard.input_scanners.language import Language as InputLanguage from llm_guard.input_scanners.language import MatchType +from llm_guard.model import Model from .base import Scanner @@ -16,34 +17,28 @@ def __init__( self, valid_languages: Sequence[str], *, - model_path: Optional[str] = None, + model: Optional[Model] = None, threshold: float = 0.7, match_type: Union[MatchType, str] = MatchType.FULL, use_onnx: bool = False, - model_kwargs: Optional[Dict] = None, - pipeline_kwargs: Optional[Dict] = None, ): """ Initializes the Language scanner with a list of valid languages. Parameters: - model_path (str): The model path to use for inference. + model (Model, optional): A Model object containing the path to the model and its ONNX equivalent. valid_languages (Sequence[str]): A list of valid language codes. threshold (float): Minimum confidence score. match_type (MatchType): Whether to match the full text or individual sentences. Default is MatchType.FULL. use_onnx (bool): Whether to use ONNX for inference. Default is False. - model_kwargs (Dict): Keyword arguments passed to the model. - pipeline_kwargs (Dict): Keyword arguments passed to the pipeline. """ self._scanner = InputLanguage( valid_languages, - model_path=model_path, + model=model, threshold=threshold, match_type=match_type, use_onnx=use_onnx, - model_kwargs=model_kwargs, - pipeline_kwargs=pipeline_kwargs, ) def scan(self, prompt: str, output: str) -> (str, bool, float): diff --git a/llm_guard/output_scanners/language_same.py b/llm_guard/output_scanners/language_same.py index dc03c38d..796ce2d4 100644 --- a/llm_guard/output_scanners/language_same.py +++ b/llm_guard/output_scanners/language_same.py @@ -1,6 +1,7 @@ -from typing import Dict, Optional +from typing import Optional -from llm_guard.input_scanners.language import default_model_path +from llm_guard.input_scanners.language import DEFAULT_MODEL +from llm_guard.model import Model from llm_guard.transformers_helpers import get_tokenizer_and_model_for_classification, pipeline from llm_guard.util import get_logger @@ -17,50 +18,34 @@ class LanguageSame(Scanner): def __init__( self, *, - model_path: Optional[str] = None, + model: Optional[Model] = None, threshold: float = 0.1, use_onnx: bool = False, - model_kwargs: Optional[Dict] = None, - pipeline_kwargs: Optional[Dict] = None, ): """ Initializes the LanguageSame scanner. Parameters: - model_path (str): The path to the model. Default is None. + model (Model, optional): Model to be used for scanning. Default is None. threshold (float): Minimum confidence score use_onnx (bool): Whether to use ONNX for inference. Default is False. - model_kwargs (Dict, optional): Keyword arguments passed to the model. - pipeline_kwargs (Dict, optional): Keyword arguments passed to the pipeline. """ self._threshold = threshold - default_pipeline_kwargs = { - "max_length": 512, - "truncation": True, - "top_k": None, - } - if pipeline_kwargs is None: - pipeline_kwargs = {} - - pipeline_kwargs = {**default_pipeline_kwargs, **pipeline_kwargs} - model_kwargs = model_kwargs or {} - - onnx_model_path = model_path - if model_path is None: - model_path = default_model_path[0] - onnx_model_path = default_model_path[1] + if model is None: + model = DEFAULT_MODEL tf_tokenizer, tf_model = get_tokenizer_and_model_for_classification( - model=model_path, onnx_model=onnx_model_path, use_onnx=use_onnx, **model_kwargs + model=model, + use_onnx=use_onnx, ) self._pipeline = pipeline( task="text-classification", model=tf_model, tokenizer=tf_tokenizer, - **pipeline_kwargs, + **model.pipeline_kwargs, ) def scan(self, prompt: str, output: str) -> (str, bool, float): diff --git a/llm_guard/output_scanners/malicious_urls.py b/llm_guard/output_scanners/malicious_urls.py index c52becc3..a6fca61f 100644 --- a/llm_guard/output_scanners/malicious_urls.py +++ b/llm_guard/output_scanners/malicious_urls.py @@ -1,14 +1,20 @@ -from typing import Dict, Optional +from typing import Optional +from llm_guard.model import Model from llm_guard.transformers_helpers import get_tokenizer_and_model_for_classification, pipeline from llm_guard.util import calculate_risk_score, extract_urls, get_logger from .base import Scanner LOGGER = get_logger() -_model_path = ( - "DunnBC22/codebert-base-Malicious_URLs", - "ProtectAI/codebert-base-Malicious_URLs-onnx", # ONNX version +DEFAULT_MODEL = Model( + path="DunnBC22/codebert-base-Malicious_URLs", + onnx_path="ProtectAI/codebert-base-Malicious_URLs-onnx", + pipeline_kwargs={ + "max_length": 512, + "truncation": True, + "top_k": None, + }, ) _malicious_labels = [ @@ -30,50 +36,34 @@ class MaliciousURLs(Scanner): def __init__( self, *, - model_path: Optional[str] = None, + model: Optional[Model] = None, threshold=0.5, use_onnx: bool = False, - model_kwargs: Optional[Dict] = None, - pipeline_kwargs: Optional[Dict] = None, ): """ Initializes an instance of the MaliciousURLs class. Parameters: - model_path (str): The model path to use for malicious URL detection. + model (Model, optional): The model to use for malicious URL detection. threshold (float): The threshold used to determine if the website is malicious. Defaults to 0.5. use_onnx (bool): Whether to use the ONNX version of the model. Defaults to False. - model_kwargs (Dict, optional): Keyword arguments passed to the model. - pipeline_kwargs (Dict, optional): Keyword arguments passed to the pipeline. """ self._threshold = threshold - default_pipeline_kwargs = { - "max_length": 512, - "truncation": True, - "top_k": None, - } - if pipeline_kwargs is None: - pipeline_kwargs = {} - - pipeline_kwargs = {**default_pipeline_kwargs, **pipeline_kwargs} - model_kwargs = model_kwargs or {} - - onnx_model_path = model_path - if model_path is None: - model_path = _model_path[0] - onnx_model_path = _model_path[1] + if model is None: + model = DEFAULT_MODEL tf_tokenizer, tf_model = get_tokenizer_and_model_for_classification( - model=model_path, onnx_model=onnx_model_path, use_onnx=use_onnx, **model_kwargs + model=model, + use_onnx=use_onnx, ) self._classifier = pipeline( task="text-classification", model=tf_model, tokenizer=tf_tokenizer, - **pipeline_kwargs, + **model.pipeline_kwargs, ) def scan(self, prompt: str, output: str) -> (str, bool, float): diff --git a/llm_guard/output_scanners/no_refusal.py b/llm_guard/output_scanners/no_refusal.py index 1c42c22e..3eafab5c 100644 --- a/llm_guard/output_scanners/no_refusal.py +++ b/llm_guard/output_scanners/no_refusal.py @@ -1,6 +1,7 @@ from enum import Enum -from typing import Dict, List, Optional, Union +from typing import List, Optional, Union +from llm_guard.model import Model from llm_guard.transformers_helpers import get_tokenizer_and_model_for_classification, pipeline from llm_guard.util import calculate_risk_score, get_logger, split_text_by_sentences @@ -8,7 +9,15 @@ LOGGER = get_logger() -_model_path = "ProtectAI/distilroberta-base-rejection-v1" +DEFAULT_MODEL = Model( + path="ProtectAI/distilroberta-base-rejection-v1", + onnx_path="ProtectAI/distilroberta-base-rejection-v1-onnx", + onnx_subfolder="onnx", + pipeline_kwargs={ + "max_length": 512, + "truncation": True, + }, +) class MatchType(Enum): @@ -32,23 +41,19 @@ class NoRefusal(Scanner): def __init__( self, *, - model_path: str = _model_path, + model: Optional[Model] = None, threshold: float = 0.75, match_type: Union[MatchType, str] = MatchType.FULL, use_onnx: bool = False, - model_kwargs: Optional[Dict] = None, - pipeline_kwargs: Optional[Dict] = None, ): """ Initializes an instance of the NoRefusal class. Parameters: - model_path (str): The model path to use for scanning. + model (Model, optional): The model to use for refusal detection. threshold (float): The similarity threshold to consider an output as refusal. match_type (MatchType): Whether to match the full text or individual sentences. Default is MatchType.FULL. use_onnx (bool): Whether to use the ONNX version of the model. Defaults to False. - model_kwargs (Dict, optional): Keyword arguments passed to the model. - pipeline_kwargs (Dict, optional): Keyword arguments passed to the pipeline. """ if isinstance(match_type, str): @@ -57,25 +62,19 @@ def __init__( self._threshold = threshold self._match_type = match_type - default_pipeline_kwargs = { - "max_length": 512, - "truncation": True, - } - if pipeline_kwargs is None: - pipeline_kwargs = {} - - pipeline_kwargs = {**default_pipeline_kwargs, **pipeline_kwargs} - model_kwargs = model_kwargs or {} + if model is None: + model = DEFAULT_MODEL tf_tokenizer, tf_model = get_tokenizer_and_model_for_classification( - model=model_path, onnx_model=model_path, use_onnx=use_onnx, **model_kwargs + model=model, + use_onnx=use_onnx, ) self._pipeline = pipeline( task="text-classification", model=tf_model, tokenizer=tf_tokenizer, - **pipeline_kwargs, + **model.pipeline_kwargs, ) def scan(self, prompt: str, output: str) -> (str, bool, float): diff --git a/llm_guard/output_scanners/relevance.py b/llm_guard/output_scanners/relevance.py index e4673c77..3c1fa87d 100644 --- a/llm_guard/output_scanners/relevance.py +++ b/llm_guard/output_scanners/relevance.py @@ -1,5 +1,6 @@ -from typing import Dict, Optional +from typing import Optional +from llm_guard.model import Model from llm_guard.transformers_helpers import get_tokenizer, is_onnx_supported from llm_guard.util import device, get_logger, lazy_load_dep @@ -7,17 +8,17 @@ LOGGER = get_logger() -MODEL_EN_BGE_BASE = ( - "BAAI/bge-base-en-v1.5", - "zeroshot/bge-base-en-v1.5-quant", # Quantized and converted to ONNX version of BGE base +MODEL_EN_BGE_BASE = Model( + path="BAAI/bge-base-en-v1.5", + onnx_path="neuralmagic/bge-base-en-v1.5-quant", # Quantized and converted to ONNX version of BGE base ) -MODEL_EN_BGE_LARGE = ( - "BAAI/bge-large-en-v1.5", - "zeroshot/bge-large-en-v1.5-quant", # Quantized and converted to ONNX version of BGE large +MODEL_EN_BGE_LARGE = Model( + path="BAAI/bge-large-en-v1.5", + onnx_path="neuralmagic/bge-large-en-v1.5-quant", # Quantized and converted to ONNX version of BGE large ) -MODEL_EN_BGE_SMALL = ( - "BAAI/bge-small-en-v1.5", - "zeroshot/bge-small-en-v1.5-quant", # Quantized and converted to ONNX version of BGE small +MODEL_EN_BGE_SMALL = Model( + path="BAAI/bge-small-en-v1.5", + onnx_path="neuralmagic/bge-small-en-v1.5-quant", # Quantized and converted to ONNX version of BGE small ) torch = lazy_load_dep("torch") @@ -37,27 +38,22 @@ def __init__( self, *, threshold: float = 0.5, - model_path: Optional[str] = None, + model: Optional[Model] = None, use_onnx: bool = False, - model_kwargs: Optional[Dict] = None, ): """ Initializes an instance of the Relevance class. Parameters: threshold (float): The minimum similarity score to compare prompt and output. - model_path (str, optional): Model for calculating embeddings. Default is `BAAI/bge-base-en-v1.5`. + model (Model, optional): Model for calculating embeddings. Default is `BAAI/bge-base-en-v1.5`. use_onnx (bool): Whether to use the ONNX version of the model. Defaults to False. - model_kwargs (Dict, optional): Keyword arguments passed to the model. """ self._threshold = threshold - model_kwargs = model_kwargs or {} - onnx_model_path = model_path - if model_path is None: - model_path = MODEL_EN_BGE_BASE[0] - onnx_model_path = MODEL_EN_BGE_BASE[1] + if model is None: + model = MODEL_EN_BGE_BASE self.pooling_method = "cls" self.normalize_embeddings = True @@ -67,30 +63,30 @@ def __init__( use_onnx = False if use_onnx: - model_path = onnx_model_path optimum_onnxruntime = lazy_load_dep( "optimum.onnxruntime", "optimum[onnxruntime-gpu]" if device().type == "cuda" else "optimum[onnxruntime]", ) self._model = optimum_onnxruntime.ORTModelForFeatureExtraction.from_pretrained( - model_path, + model.onnx_path, export=False, + subfolder=model.onnx_subfolder, provider="CUDAExecutionProvider" if device().type == "cuda" else "CPUExecutionProvider", use_io_binding=True if device().type == "cuda" else False, - **model_kwargs, + **model.kwargs, ) - LOGGER.debug("Initialized ONNX model", model=model_path, device=device()) + LOGGER.debug("Initialized ONNX model", model=model, device=device()) else: transformers = lazy_load_dep("transformers") - self._model = transformers.AutoModel.from_pretrained(model_path, **model_kwargs).to( - device() - ) - LOGGER.debug("Initialized model", model=model_path, device=device()) + self._model = transformers.AutoModel.from_pretrained( + model.path, subfolder=model.subfolder, **model.kwargs + ).to(device()) + LOGGER.debug("Initialized model", model=model, device=device()) self._model.eval() - self._tokenizer = get_tokenizer(model_path, **model_kwargs) + self._tokenizer = get_tokenizer(model) def pooling(self, last_hidden_state: torch.Tensor, attention_mask: torch.Tensor = None): if self.pooling_method == "cls": diff --git a/llm_guard/output_scanners/sensitive.py b/llm_guard/output_scanners/sensitive.py index c769a292..f3b58b52 100644 --- a/llm_guard/output_scanners/sensitive.py +++ b/llm_guard/output_scanners/sensitive.py @@ -30,11 +30,9 @@ def __init__( entity_types: Optional[Sequence[str]] = None, regex_patterns: Optional[List[Dict]] = None, redact: bool = False, - recognizer_conf: Optional[Dict] = DEBERTA_AI4PRIVACY_v2_CONF, + recognizer_conf: Optional[Dict] = None, threshold: float = 0.5, use_onnx: bool = False, - model_kwargs: Optional[Dict] = None, - pipeline_kwargs: Optional[Dict] = None, ): """ Initializes an instance of the Sensitive class. @@ -47,8 +45,6 @@ def __init__( recognizer_conf (Optional[Dict]): Configuration to recognize PII data. Default is Ai4Privacy DeBERTa. threshold (float): Acceptance threshold. Default is 0. use_onnx (bool): Use ONNX model for inference. Default is False. - model_kwargs (Optional[Dict]): Keyword arguments passed to the model. - pipeline_kwargs (Optional[Dict]): Keyword arguments passed to the pipeline. """ if not entity_types: LOGGER.debug( @@ -61,11 +57,12 @@ def __init__( self._redact = redact self._threshold = threshold + if not recognizer_conf: + recognizer_conf = DEBERTA_AI4PRIVACY_v2_CONF + transformers_recognizer = get_transformers_recognizer( recognizer_conf=recognizer_conf, use_onnx=use_onnx, - model_kwargs=model_kwargs, - pipeline_kwargs=pipeline_kwargs, ) self._analyzer = get_analyzer( transformers_recognizer, get_regex_patterns(regex_patterns), [] diff --git a/llm_guard/output_scanners/toxicity.py b/llm_guard/output_scanners/toxicity.py index 6a9374bd..77d80e79 100644 --- a/llm_guard/output_scanners/toxicity.py +++ b/llm_guard/output_scanners/toxicity.py @@ -1,7 +1,8 @@ -from typing import Dict, Optional, Union +from typing import Optional, Union from llm_guard.input_scanners.toxicity import MatchType from llm_guard.input_scanners.toxicity import Toxicity as InputToxicity +from llm_guard.model import Model from .base import Scanner @@ -16,32 +17,26 @@ class Toxicity(Scanner): def __init__( self, *, - model_path: Optional[str] = None, + model: Optional[Model] = None, threshold: float = 0.7, match_type: Union[MatchType, str] = MatchType.FULL, use_onnx: bool = False, - model_kwargs: Optional[Dict] = None, - pipeline_kwargs: Optional[Dict] = None, ): """ Initializes an instance of the Toxicity class. Parameters: - model_path (str, optional): The path to the model. Defaults to None. + model (Model, optional): The path to the model. Defaults to None. threshold (float): The threshold used to determine toxicity. Defaults to 0.7. match_type (MatchType): Whether to match the full text or individual sentences. Defaults to MatchType.FULL. use_onnx (bool): Whether to use ONNX for inference. Defaults to False. - model_kwargs (Optional[Dict]): Optional keyword arguments for the model. - pipeline_kwargs (Optional[Dict]): Optional keyword arguments for the pipeline. """ self._scanner = InputToxicity( - model_path=model_path, + model=model, threshold=threshold, match_type=match_type, use_onnx=use_onnx, - model_kwargs=model_kwargs, - pipeline_kwargs=pipeline_kwargs, ) def scan(self, prompt: str, output: str) -> (str, bool, float): diff --git a/llm_guard/transformers_helpers.py b/llm_guard/transformers_helpers.py index 07326fbf..6b1f9534 100644 --- a/llm_guard/transformers_helpers.py +++ b/llm_guard/transformers_helpers.py @@ -1,6 +1,6 @@ import importlib from functools import lru_cache -from typing import Literal, Optional, Union, get_args +from typing import Literal, Union, get_args from transformers import ( PreTrainedModel, @@ -10,22 +10,22 @@ ) from .exception import LLMGuardValidationError +from .model import Model from .util import device, get_logger, lazy_load_dep LOGGER = get_logger() -@lru_cache(maxsize=None) # Set maxsize=None for an unbounded cache -def get_tokenizer(model_identifier: str, **kwargs): +def get_tokenizer(model: Model, **kwargs): """ This function loads a tokenizer given a model identifier and caches it. Subsequent calls with the same model_identifier will return the cached tokenizer. Args: - model_identifier (str): The model identifier to load the tokenizer for. + model (Model): The model to load the tokenizer for. """ transformers = lazy_load_dep("transformers") - tokenizer = transformers.AutoTokenizer.from_pretrained(model_identifier, **kwargs) + tokenizer = transformers.AutoTokenizer.from_pretrained(model.path, **kwargs) return tokenizer @@ -44,18 +44,18 @@ def is_onnx_supported() -> bool: def _ort_model_for_sequence_classification( - model: str, export: bool = False, subfolder: str = "", **kwargs + model: Model, ): if device().type == "cuda": optimum_onnxruntime = lazy_load_dep("optimum.onnxruntime", "optimum[onnxruntime-gpu]") tf_model = optimum_onnxruntime.ORTModelForSequenceClassification.from_pretrained( model, - export=export, - subfolder=subfolder, - file_name="model.onnx", + export=model.onnx_path is None, + file_name=model.onnx_filename, provider="CUDAExecutionProvider", use_io_binding=True, - **kwargs, + subfolder=model.onnx_subfolder, + **model.kwargs, ) LOGGER.debug("Initialized classification ONNX model", model=model, device=device()) @@ -65,10 +65,10 @@ def _ort_model_for_sequence_classification( optimum_onnxruntime = lazy_load_dep("optimum.onnxruntime", "optimum[onnxruntime]") tf_model = optimum_onnxruntime.ORTModelForSequenceClassification.from_pretrained( model, - export=export, - subfolder=subfolder, - file_name="model.onnx", - **kwargs, + export=model.onnx_path is None, + file_name=model.onnx_filename, + subfolder=model.onnx_subfolder, + **model.kwargs, ) LOGGER.debug("Initialized classification ONNX model", model=model, device=device()) @@ -76,7 +76,8 @@ def _ort_model_for_sequence_classification( def get_tokenizer_and_model_for_classification( - model: str, onnx_model: Optional[str] = None, use_onnx: bool = False, **kwargs + model: Model, + use_onnx: bool = False, ): """ This function loads a tokenizer and model given a model identifier and caches them. @@ -84,36 +85,30 @@ def get_tokenizer_and_model_for_classification( Args: model (str): The model identifier to load the tokenizer and model for. - onnx_model (Optional[str]): The model identifier to load the ONNX model for. Defaults to None. use_onnx (bool): Whether to use the ONNX version of the model. Defaults to False. - **kwargs: Keyword arguments to pass to the tokenizer and model. """ - tf_tokenizer = get_tokenizer(model, **kwargs) + tf_tokenizer = get_tokenizer(model, **model.kwargs) transformers = lazy_load_dep("transformers") - if kwargs.get("max_length", None) is None: - kwargs["max_length"] = tf_tokenizer.model_max_length + if model.kwargs.get("max_length", None) is None: + model.kwargs["max_length"] = tf_tokenizer.model_max_length if use_onnx and is_onnx_supported() is False: LOGGER.warning("ONNX is not supported on this machine. Using PyTorch instead of ONNX.") use_onnx = False if use_onnx is False: - tf_model = transformers.AutoModelForSequenceClassification.from_pretrained(model, **kwargs) + tf_model = transformers.AutoModelForSequenceClassification.from_pretrained( + model.path, subfolder=model.subfolder, **model.kwargs + ) LOGGER.debug("Initialized classification model", model=model, device=device()) return tf_tokenizer, tf_model - subfolder = "onnx" if onnx_model == model else "" - if onnx_model is not None: - model = onnx_model - # Hack for some models tf_tokenizer.model_input_names = ["input_ids", "attention_mask"] - tf_model = _ort_model_for_sequence_classification( - model, export=onnx_model is None, subfolder=subfolder, **kwargs - ) + tf_model = _ort_model_for_sequence_classification(model) return tf_tokenizer, tf_model diff --git a/pyproject.toml b/pyproject.toml index f0373a40..633208ec 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,7 +32,7 @@ dependencies = [ "sentencepiece==0.2.0", "tiktoken>=0.5,<0.7", "torch==2.0.1", - "transformers==4.38.2", + "transformers==4.39.0", "span-marker==1.5.0", "structlog>=24" ] diff --git a/tests/input_scanners/test_anonymize.py b/tests/input_scanners/test_anonymize.py index 05b3f92f..f249647d 100644 --- a/tests/input_scanners/test_anonymize.py +++ b/tests/input_scanners/test_anonymize.py @@ -206,7 +206,9 @@ def test_scan(settings, prompt, expected_prompt, expected_valid, expected_score) "settings,prompt,expected_prompt,expected_valid,expected_score", [ ( - {}, + { + "recognizer_conf": BERT_ZH_NER_CONF, + }, "只是一个简单的提示", "只是一个简单的提示", True, @@ -246,7 +248,9 @@ def test_scan(settings, prompt, expected_prompt, expected_valid, expected_score) 1.0, ), # Exposed all but only with PERSON and hidden_names. ( - {}, + { + "recognizer_conf": BERT_ZH_NER_CONF, + }, "", "", True,