diff --git a/.github/workflows/python-quality.yml b/.github/workflows/python-quality.yml
index 039c146ed1..cb66dcea17 100644
--- a/.github/workflows/python-quality.yml
+++ b/.github/workflows/python-quality.yml
@@ -42,7 +42,7 @@ jobs:
       - run: .venv/bin/python utils/check_static_imports.py
       - run: .venv/bin/python utils/generate_async_inference_client.py
       - run: .venv/bin/python utils/generate_inference_types.py
-      - run: .venv/bin/python utils/generate_task_parameters.py
+      - run: .venv/bin/python utils/check_task_parameters.py

       # Run type checking at least on huggingface_hub root file to check all modules
       # that can be lazy-loaded actually exist.
diff --git a/Makefile b/Makefile
index c00fc30d4e..ee9bb21a66 100644
--- a/Makefile
+++ b/Makefile
@@ -23,11 +23,11 @@ style:

 inference_check:
 	python utils/generate_inference_types.py
-	python utils/generate_task_parameters.py
+	python utils/check_task_parameters.py

 inference_update:
 	python utils/generate_inference_types.py --update
-	python utils/generate_task_parameters.py --update
+	python utils/check_task_parameters.py --update


 repocard:
diff --git a/src/huggingface_hub/__init__.py b/src/huggingface_hub/__init__.py
index d6f719a997..a97263a7f9 100644
--- a/src/huggingface_hub/__init__.py
+++ b/src/huggingface_hub/__init__.py
@@ -292,8 +292,10 @@
         "ChatCompletionInputFunctionDefinition",
         "ChatCompletionInputFunctionName",
         "ChatCompletionInputGrammarType",
+        "ChatCompletionInputGrammarTypeType",
         "ChatCompletionInputMessage",
         "ChatCompletionInputMessageChunk",
+        "ChatCompletionInputMessageChunkType",
         "ChatCompletionInputStreamOptions",
         "ChatCompletionInputToolType",
         "ChatCompletionInputURL",
@@ -322,6 +324,7 @@
         "DocumentQuestionAnsweringOutputElement",
         "DocumentQuestionAnsweringParameters",
         "FeatureExtractionInput",
+        "FeatureExtractionInputTruncationDirection",
         "FillMaskInput",
         "FillMaskOutputElement",
         "FillMaskParameters",
@@ -332,6 +335,7 @@
         "ImageSegmentationInput",
         "ImageSegmentationOutputElement",
         "ImageSegmentationParameters",
+        "ImageSegmentationSubtask",
         "ImageToImageInput",
         "ImageToImageOutput",
         "ImageToImageParameters",
@@ -354,12 +358,14 @@
         "SummarizationInput",
         "SummarizationOutput",
         "SummarizationParameters",
+        "SummarizationTruncationStrategy",
         "TableQuestionAnsweringInput",
         "TableQuestionAnsweringInputData",
         "TableQuestionAnsweringOutputElement",
         "Text2TextGenerationInput",
         "Text2TextGenerationOutput",
         "Text2TextGenerationParameters",
+        "Text2TextGenerationTruncationStrategy",
         "TextClassificationInput",
         "TextClassificationOutputElement",
         "TextClassificationOutputTransform",
@@ -370,6 +376,7 @@
         "TextGenerationOutput",
         "TextGenerationOutputBestOfSequence",
         "TextGenerationOutputDetails",
+        "TextGenerationOutputFinishReason",
         "TextGenerationOutputPrefillToken",
         "TextGenerationOutputToken",
         "TextGenerationStreamOutput",
@@ -389,6 +396,7 @@
         "TextToSpeechInput",
         "TextToSpeechOutput",
         "TextToSpeechParameters",
+        "TokenClassificationAggregationStrategy",
         "TokenClassificationInput",
         "TokenClassificationOutputElement",
         "TokenClassificationParameters",
@@ -396,6 +404,8 @@
         "TranslationInput",
         "TranslationOutput",
         "TranslationParameters",
+        "TranslationTruncationStrategy",
+        "TypeEnum",
         "VideoClassificationInput",
         "VideoClassificationOutputElement",
         "VideoClassificationOutputTransform",
@@ -812,8 +822,10 @@ def __dir__():
         ChatCompletionInputFunctionDefinition,  # noqa: F401
         ChatCompletionInputFunctionName,  # noqa: F401
         ChatCompletionInputGrammarType,  # noqa: F401
+        ChatCompletionInputGrammarTypeType,  # noqa: F401
         ChatCompletionInputMessage,  # noqa:
F401 ChatCompletionInputMessageChunk, # noqa: F401 + ChatCompletionInputMessageChunkType, # noqa: F401 ChatCompletionInputStreamOptions, # noqa: F401 ChatCompletionInputToolType, # noqa: F401 ChatCompletionInputURL, # noqa: F401 @@ -842,6 +854,7 @@ def __dir__(): DocumentQuestionAnsweringOutputElement, # noqa: F401 DocumentQuestionAnsweringParameters, # noqa: F401 FeatureExtractionInput, # noqa: F401 + FeatureExtractionInputTruncationDirection, # noqa: F401 FillMaskInput, # noqa: F401 FillMaskOutputElement, # noqa: F401 FillMaskParameters, # noqa: F401 @@ -852,6 +865,7 @@ def __dir__(): ImageSegmentationInput, # noqa: F401 ImageSegmentationOutputElement, # noqa: F401 ImageSegmentationParameters, # noqa: F401 + ImageSegmentationSubtask, # noqa: F401 ImageToImageInput, # noqa: F401 ImageToImageOutput, # noqa: F401 ImageToImageParameters, # noqa: F401 @@ -874,12 +888,14 @@ def __dir__(): SummarizationInput, # noqa: F401 SummarizationOutput, # noqa: F401 SummarizationParameters, # noqa: F401 + SummarizationTruncationStrategy, # noqa: F401 TableQuestionAnsweringInput, # noqa: F401 TableQuestionAnsweringInputData, # noqa: F401 TableQuestionAnsweringOutputElement, # noqa: F401 Text2TextGenerationInput, # noqa: F401 Text2TextGenerationOutput, # noqa: F401 Text2TextGenerationParameters, # noqa: F401 + Text2TextGenerationTruncationStrategy, # noqa: F401 TextClassificationInput, # noqa: F401 TextClassificationOutputElement, # noqa: F401 TextClassificationOutputTransform, # noqa: F401 @@ -890,6 +906,7 @@ def __dir__(): TextGenerationOutput, # noqa: F401 TextGenerationOutputBestOfSequence, # noqa: F401 TextGenerationOutputDetails, # noqa: F401 + TextGenerationOutputFinishReason, # noqa: F401 TextGenerationOutputPrefillToken, # noqa: F401 TextGenerationOutputToken, # noqa: F401 TextGenerationStreamOutput, # noqa: F401 @@ -909,6 +926,7 @@ def __dir__(): TextToSpeechInput, # noqa: F401 TextToSpeechOutput, # noqa: F401 TextToSpeechParameters, # noqa: F401 + TokenClassificationAggregationStrategy, # noqa: F401 TokenClassificationInput, # noqa: F401 TokenClassificationOutputElement, # noqa: F401 TokenClassificationParameters, # noqa: F401 @@ -916,6 +934,8 @@ def __dir__(): TranslationInput, # noqa: F401 TranslationOutput, # noqa: F401 TranslationParameters, # noqa: F401 + TranslationTruncationStrategy, # noqa: F401 + TypeEnum, # noqa: F401 VideoClassificationInput, # noqa: F401 VideoClassificationOutputElement, # noqa: F401 VideoClassificationOutputTransform, # noqa: F401 diff --git a/src/huggingface_hub/inference/_client.py b/src/huggingface_hub/inference/_client.py index 89beb847e1..d8581b0ca8 100644 --- a/src/huggingface_hub/inference/_client.py +++ b/src/huggingface_hub/inference/_client.py @@ -76,11 +76,14 @@ DocumentQuestionAnsweringOutputElement, FillMaskOutputElement, ImageClassificationOutputElement, + ImageClassificationOutputTransform, ImageSegmentationOutputElement, + ImageSegmentationSubtask, ImageToTextOutput, ObjectDetectionOutputElement, QuestionAnsweringOutputElement, SummarizationOutput, + SummarizationTruncationStrategy, TableQuestionAnsweringOutputElement, TextClassificationOutputElement, TextClassificationOutputTransform, @@ -89,9 +92,11 @@ TextGenerationStreamOutput, TextToImageTargetSize, TextToSpeechEarlyStoppingEnum, + TokenClassificationAggregationStrategy, TokenClassificationOutputElement, ToolElement, TranslationOutput, + TranslationTruncationStrategy, VisualQuestionAnsweringOutputElement, ZeroShotClassificationOutputElement, ZeroShotImageClassificationOutputElement, @@ -941,28 
+946,25 @@ def document_question_answering( a deployed Inference Endpoint. If not provided, the default recommended document question answering model will be used. Defaults to None. doc_stride (`int`, *optional*): - If the words in the document are too long to fit with the question for the model, it will - be split in several chunks with some overlap. This argument controls the size of that - overlap. + If the words in the document are too long to fit with the question for the model, it will be split in + several chunks with some overlap. This argument controls the size of that overlap. handle_impossible_answer (`bool`, *optional*): - Whether to accept impossible as an answer. + Whether to accept impossible as an answer lang (`str`, *optional*): - Language to use while running OCR. + Language to use while running OCR. Defaults to english. max_answer_len (`int`, *optional*): - The maximum length of predicted answers (e.g., only answers with a shorter length are - considered). + The maximum length of predicted answers (e.g., only answers with a shorter length are considered). max_question_len (`int`, *optional*): The maximum length of the question after tokenization. It will be truncated if needed. max_seq_len (`int`, *optional*): - The maximum length of the total sentence (context + question) in tokens of each chunk - passed to the model. The context will be split in several chunks (using doc_stride as - overlap) if needed. + The maximum length of the total sentence (context + question) in tokens of each chunk passed to the + model. The context will be split in several chunks (using doc_stride as overlap) if needed. top_k (`int`, *optional*): - The number of answers to return (will be chosen by order of likelihood). Can return less - than top_k answers if there are not enough options available within the context. - word_boxes (`List[Union[List[float], str]]`, *optional*): - A list of words and bounding boxes (normalized 0->1000). If provided, the inference will - skip the OCR step and use the provided bounding boxes instead. + The number of answers to return (will be chosen by order of likelihood). Can return less than top_k + answers if there are not enough options available within the context. + word_boxes (`List[Union[List[float], str`, *optional*): + A list of words and bounding boxes (normalized 0->1000). If provided, the inference will skip the OCR + step and use the provided bounding boxes instead. Returns: `List[DocumentQuestionAnsweringOutputElement]`: a list of [`DocumentQuestionAnsweringOutputElement`] items containing the predicted label, associated probability, word ids, and page number. @@ -1079,11 +1081,10 @@ def fill_mask( model (`str`, *optional*): The model to use for the fill mask task. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed Inference Endpoint. If not provided, the default recommended fill mask model will be used. - targets (`List[str]`, *optional*): - When passed, the model will limit the scores to the passed targets instead of looking up - in the whole vocabulary. If the provided targets are not in the model vocab, they will be - tokenized and the first resulting token will be used (with a warning, and that might be - slower). + targets (`List[str`, *optional*): + When passed, the model will limit the scores to the passed targets instead of looking up in the whole + vocabulary. 
If the provided targets are not in the model vocab, they will be tokenized and the first + resulting token will be used (with a warning, and that might be slower). top_k (`int`, *optional*): When passed, overrides the number of predictions to return. Returns: @@ -1117,7 +1118,7 @@ def image_classification( image: ContentT, *, model: Optional[str] = None, - function_to_apply: Optional[Literal["sigmoid", "softmax", "none"]] = None, + function_to_apply: Optional["ImageClassificationOutputTransform"] = None, top_k: Optional[int] = None, ) -> List[ImageClassificationOutputElement]: """ @@ -1129,8 +1130,8 @@ def image_classification( model (`str`, *optional*): The model to use for image classification. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed Inference Endpoint. If not provided, the default recommended model for image classification will be used. - function_to_apply (`Literal["sigmoid", "softmax", "none"]`, *optional*): - The function to apply to the output scores. + function_to_apply (`"ImageClassificationOutputTransform"`, *optional*): + The function to apply to the output. top_k (`int`, *optional*): When specified, limits the output to the top K most probable classes. Returns: @@ -1162,7 +1163,7 @@ def image_segmentation( model: Optional[str] = None, mask_threshold: Optional[float] = None, overlap_mask_area_threshold: Optional[float] = None, - subtask: Optional[Literal["instance", "panoptic", "semantic"]] = None, + subtask: Optional["ImageSegmentationSubtask"] = None, threshold: Optional[float] = None, ) -> List[ImageSegmentationOutputElement]: """ @@ -1184,7 +1185,7 @@ def image_segmentation( Threshold to use when turning the predicted masks into binary values. overlap_mask_area_threshold (`float`, *optional*): Mask overlap threshold to eliminate small, disconnected segments. - subtask (`Literal["instance", "panoptic", "semantic"]`, *optional*): + subtask (`"ImageSegmentationSubtask"`, *optional*): Segmentation task to be performed, depending on model capabilities. threshold (`float`, *optional*): Probability threshold to filter out predicted masks. @@ -1483,26 +1484,24 @@ def question_answering( The model to use for the question answering task. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed Inference Endpoint. align_to_words (`bool`, *optional*): - Attempts to align the answer to real words. Improves quality on space separated - languages. Might hurt on non-space-separated languages (like Japanese or Chinese). + Attempts to align the answer to real words. Improves quality on space separated languages. Might hurt + on non-space-separated languages (like Japanese or Chinese) doc_stride (`int`, *optional*): - If the context is too long to fit with the question for the model, it will be split in - several chunks with some overlap. This argument controls the size of that overlap. + If the context is too long to fit with the question for the model, it will be split in several chunks + with some overlap. This argument controls the size of that overlap. handle_impossible_answer (`bool`, *optional*): Whether to accept impossible as an answer. max_answer_len (`int`, *optional*): - The maximum length of predicted answers (e.g., only answers with a shorter length are - considered). + The maximum length of predicted answers (e.g., only answers with a shorter length are considered). max_question_len (`int`, *optional*): The maximum length of the question after tokenization. It will be truncated if needed. 
max_seq_len (`int`, *optional*): - The maximum length of the total sentence (context + question) in tokens of each chunk - passed to the model. The context will be split in several chunks (using docStride as - overlap) if needed. + The maximum length of the total sentence (context + question) in tokens of each chunk passed to the + model. The context will be split in several chunks (using docStride as overlap) if needed. top_k (`int`, *optional*): - The number of answers to return (will be chosen by order of likelihood). Note that we - return less than topk answers if there are not enough options available within the - context. + The number of answers to return (will be chosen by order of likelihood). Note that we return less than + topk answers if there are not enough options available within the context. + Returns: Union[`QuestionAnsweringOutputElement`, List[`QuestionAnsweringOutputElement`]]: When top_k is 1 or not provided, it returns a single `QuestionAnsweringOutputElement`. @@ -1604,7 +1603,7 @@ def summarization( model: Optional[str] = None, clean_up_tokenization_spaces: Optional[bool] = None, generate_parameters: Optional[Dict[str, Any]] = None, - truncation: Optional[Literal["do_not_truncate", "longest_first", "only_first", "only_second"]] = None, + truncation: Optional["SummarizationTruncationStrategy"] = None, ) -> SummarizationOutput: """ Generate a summary of a given text using a specified model. @@ -1622,7 +1621,7 @@ def summarization( Whether to clean up the potential extra spaces in the text output. generate_parameters (`Dict[str, Any]`, *optional*): Additional parametrization of the text generation algorithm. - truncation (`Literal["do_not_truncate", "longest_first", "only_first", "only_second"]`, *optional*): + truncation (`"SummarizationTruncationStrategy"`, *optional*): The truncation strategy to use. Returns: [`SummarizationOutput`]: The generated summary text. @@ -2349,10 +2348,10 @@ def text_to_image( self, prompt: str, *, - negative_prompt: Optional[str] = None, + negative_prompt: Optional[List[str]] = None, height: Optional[float] = None, width: Optional[float] = None, - num_inference_steps: Optional[float] = None, + num_inference_steps: Optional[int] = None, guidance_scale: Optional[float] = None, model: Optional[str] = None, scheduler: Optional[str] = None, @@ -2372,8 +2371,8 @@ def text_to_image( Args: prompt (`str`): The prompt to generate an image from. - negative_prompt (`str`, *optional*): - An optional negative prompt for the image generation. + negative_prompt (`List[str`, *optional*): + One or several prompt to guide what NOT to include in image generation. height (`float`, *optional*): The height in pixels of the image to generate. width (`float`, *optional*): @@ -2382,8 +2381,8 @@ def text_to_image( The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. guidance_scale (`float`, *optional*): - Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. + A higher guidance scale value encourages the model to generate images closely linked to the text + prompt, but values too high may cause saturation and other artifacts. model (`str`, *optional*): The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed Inference Endpoint. If not provided, the default recommended text-to-image model will be used. 
@@ -2473,18 +2472,25 @@ def text_to_speech( early_stopping (`Union[bool, "TextToSpeechEarlyStoppingEnum"`, *optional*): Controls the stopping condition for beam-based methods. epsilon_cutoff (`float`, *optional*): - If set to float strictly between 0 and 1, only tokens with a conditional probability - greater than epsilon_cutoff will be sampled. In the paper, suggested values range from - 3e-4 to 9e-4, depending on the size of the model. See [Truncation Sampling as Language - Model Desmoothing](https://hf.co/papers/2210.15191) for more details. + If set to float strictly between 0 and 1, only tokens with a conditional probability greater than + epsilon_cutoff will be sampled. In the paper, suggested values range from 3e-4 to 9e-4, depending on + the size of the model. See [Truncation Sampling as Language Model + Desmoothing](https://hf.co/papers/2210.15191) for more details. eta_cutoff (`float`, *optional*): - Eta sampling is a hybrid of locally typical sampling and epsilon sampling. If set to + Eta sampling is a hybrid of locally typical sampling and epsilon sampling. If set to float strictly + between 0 and 1, a token is only considered if it is greater than either eta_cutoff or sqrt(eta_cutoff) + * exp(-entropy(softmax(next_token_logits))). The latter term is intuitively the expected next token + probability, scaled by sqrt(eta_cutoff). In the paper, suggested values range from 3e-4 to 2e-3, + depending on the size of the model. See [Truncation Sampling as Language Model + Desmoothing](https://hf.co/papers/2210.15191) for more details. float strictly between 0 and 1, a token is only considered if it is greater than either - eta_cutoff or sqrt(eta_cutoff) * exp(-entropy(softmax(next_token_logits))). The latter - term is intuitively the expected next token probability, scaled by sqrt(eta_cutoff). In - the paper, suggested values range from 3e-4 to 2e-3, depending on the size of the model. - See [Truncation Sampling as Language Model Desmoothing](https://hf.co/papers/2210.15191) - for more details. + eta_cutoff (`float`, *optional*): + Eta sampling is a hybrid of locally typical sampling and epsilon sampling. If set to float strictly + between 0 and 1, a token is only considered if it is greater than either eta_cutoff or sqrt(eta_cutoff) + * exp(-entropy(softmax(next_token_logits))). The latter term is intuitively the expected next token + probability, scaled by sqrt(eta_cutoff). In the paper, suggested values range from 3e-4 to 2e-3, + depending on the size of the model. See [Truncation Sampling as Language Model + Desmoothing](https://hf.co/papers/2210.15191) for more details. max_length (`int`, *optional*): The maximum length (in tokens) of the generated text, including the input. max_new_tokens (`int`, *optional*): @@ -2494,20 +2500,19 @@ def text_to_speech( min_new_tokens (`int`, *optional*): The minimum number of tokens to generate. Takes precedence over maxLength. num_beam_groups (`int`, *optional*): - Number of groups to divide num_beams into in order to ensure diversity among different - groups of beams. See [this paper](https://hf.co/papers/1610.02424) for more details. + Number of groups to divide num_beams into in order to ensure diversity among different groups of beams. + See [this paper](https://hf.co/papers/1610.02424) for more details. num_beams (`int`, *optional*): Number of beams to use for beam search. penalty_alpha (`float`, *optional*): - The value balances the model confidence and the degeneration penalty in contrastive - search decoding. 
+ The value balances the model confidence and the degeneration penalty in contrastive search decoding. temperature (`float`, *optional*): The value used to modulate the next token probabilities. top_k (`int`, *optional*): The number of highest probability vocabulary tokens to keep for top-k-filtering. top_p (`float`, *optional*): - If set to float < 1, only the smallest set of most probable tokens with probabilities - that add up to top_p or higher are kept for generation. + If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to + top_p or higher are kept for generation. typical_p (`float`, *optional*): Local typicality measures how similar the conditional probability of predicting a target token next is to the expected conditional probability of predicting a random token next, given the partial text @@ -2563,7 +2568,7 @@ def token_classification( text: str, *, model: Optional[str] = None, - aggregation_strategy: Optional[Literal["none", "simple", "first", "average", "max"]] = None, + aggregation_strategy: Optional["TokenClassificationAggregationStrategy"] = None, ignore_labels: Optional[List[str]] = None, stride: Optional[int] = None, ) -> List[TokenClassificationOutputElement]: @@ -2578,10 +2583,10 @@ def token_classification( The model to use for the token classification task. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed Inference Endpoint. If not provided, the default recommended token classification model will be used. Defaults to None. - aggregation_strategy (`Literal["none", "simple", "first", "average", "max"]`, *optional*): - The strategy used to fuse tokens based on model predictions. - ignore_labels (`List[str]`, *optional*): - A list of labels to ignore. + aggregation_strategy (`"TokenClassificationAggregationStrategy"`, *optional*): + The strategy used to fuse tokens based on model predictions + ignore_labels (`List[str`, *optional*): + A list of labels to ignore stride (`int`, *optional*): The number of overlapping tokens between chunks when splitting the input text. @@ -2639,7 +2644,7 @@ def translation( src_lang: Optional[str] = None, tgt_lang: Optional[str] = None, clean_up_tokenization_spaces: Optional[bool] = None, - truncation: Optional[Literal["do_not_truncate", "longest_first", "only_first", "only_second"]] = None, + truncation: Optional["TranslationTruncationStrategy"] = None, generate_parameters: Optional[Dict[str, Any]] = None, ) -> TranslationOutput: """ @@ -2663,7 +2668,7 @@ def translation( Target language to translate to. Required for models that can translate to multiple languages. clean_up_tokenization_spaces (`bool`, *optional*): Whether to clean up the potential extra spaces in the text output. - truncation (`Literal["do_not_truncate", "longest_first", "only_first", "only_second"]`, *optional*): + truncation (`"TranslationTruncationStrategy"`, *optional*): The truncation strategy to use. generate_parameters (`Dict[str, Any]`, *optional*): Additional parametrization of the text generation algorithm. @@ -2686,13 +2691,13 @@ def translation( >>> client.translation("My name is Wolfgang and I live in Berlin") 'Mein Name ist Wolfgang und ich lebe in Berlin.' 
>>> client.translation("My name is Wolfgang and I live in Berlin", model="Helsinki-NLP/opus-mt-en-fr") - TranslationOutput(translation_text='Je m\'appelle Wolfgang et je vis Ć  Berlin.') + TranslationOutput(translation_text='Je m'appelle Wolfgang et je vis Ć  Berlin.') ``` Specifying languages: ```py >>> client.translation("My name is Sarah Jessica Parker but you can call me Jessica", model="facebook/mbart-large-50-many-to-many-mmt", src_lang="en_XX", tgt_lang="fr_XX") - "Mon nom est Sarah Jessica Parker mais vous pouvez m\'appeler Jessica" + "Mon nom est Sarah Jessica Parker mais vous pouvez m'appeler Jessica" ``` """ # Throw error if only one of `src_lang` and `tgt_lang` was given @@ -2733,9 +2738,8 @@ def visual_question_answering( a deployed Inference Endpoint. If not provided, the default recommended visual question answering model will be used. Defaults to None. top_k (`int`, *optional*): - The number of answers to return (will be chosen by order of likelihood). Note that we - return less than topk answers if there are not enough options available within the - context. + The number of answers to return (will be chosen by order of likelihood). Note that we return less than + topk answers if there are not enough options available within the context. Returns: `List[VisualQuestionAnsweringOutputElement]`: a list of [`VisualQuestionAnsweringOutputElement`] items containing the predicted label and associated probability. @@ -2770,7 +2774,7 @@ def zero_shot_classification( text: str, labels: List[str], *, - multi_label: bool = False, + multi_label: Optional[bool] = False, hypothesis_template: Optional[str] = None, model: Optional[str] = None, ) -> List[ZeroShotClassificationOutputElement]: @@ -2782,14 +2786,13 @@ def zero_shot_classification( The input text to classify. labels (`List[str]`): List of strings. Each string is the verbalization of a possible label for the input text. - multi_label (`bool`): - Boolean. If True, the probability for each label is evaluated independently and multiple labels can have a probability close to 1 simultaneously or all probabilities can be close to 0. - If False, the labels are considered mutually exclusive and the probability over all labels always sums to 1. Defaults to False. + multi_label (`bool`, *optional*): + Whether multiple candidate labels can be true. If false, the scores are normalized such that the sum of + the label likelihoods for each sequence is 1. If true, the labels are considered independent and + probabilities are normalized for each candidate. hypothesis_template (`str`, *optional*): - A template sentence string with curly brackets to which the label strings are added. The label strings are added at the position of the curly brackets "{}". - Zero-shot classifiers are based on NLI models, which evaluate if a hypothesis is entailed in another text or not. - For example, with hypothesis_template="This text is about {}." and labels=["economics", "politics"], the system internally creates the two hypotheses "This text is about economics." and "This text is about politics.". - The model then evaluates for both hypotheses if they are entailed in the provided `text` or not. + The sentence used in conjunction with candidateLabels to attempt the text classification by replacing + the placeholder with the candidate labels. model (`str`, *optional*): The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed Inference Endpoint. This parameter overrides the model defined at the instance level. 
If not provided, the default recommended zero-shot classification model will be used. @@ -2887,8 +2890,8 @@ def zero_shot_image_classification( The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed Inference Endpoint. This parameter overrides the model defined at the instance level. If not provided, the default recommended zero-shot image classification model will be used. hypothesis_template (`str`, *optional*): - The sentence used in conjunction with `labels` to attempt the text classification by replacing the - placeholder with the candidate labels. + The sentence used in conjunction with candidateLabels to attempt the text classification by replacing + the placeholder with the candidate labels. Returns: `List[ZeroShotImageClassificationOutputElement]`: List of [`ZeroShotImageClassificationOutputElement`] items containing the predicted labels and their confidence. diff --git a/src/huggingface_hub/inference/_generated/_async_client.py b/src/huggingface_hub/inference/_generated/_async_client.py index 7d3ec48d61..0b9dec1489 100644 --- a/src/huggingface_hub/inference/_generated/_async_client.py +++ b/src/huggingface_hub/inference/_generated/_async_client.py @@ -62,11 +62,14 @@ DocumentQuestionAnsweringOutputElement, FillMaskOutputElement, ImageClassificationOutputElement, + ImageClassificationOutputTransform, ImageSegmentationOutputElement, + ImageSegmentationSubtask, ImageToTextOutput, ObjectDetectionOutputElement, QuestionAnsweringOutputElement, SummarizationOutput, + SummarizationTruncationStrategy, TableQuestionAnsweringOutputElement, TextClassificationOutputElement, TextClassificationOutputTransform, @@ -75,9 +78,11 @@ TextGenerationStreamOutput, TextToImageTargetSize, TextToSpeechEarlyStoppingEnum, + TokenClassificationAggregationStrategy, TokenClassificationOutputElement, ToolElement, TranslationOutput, + TranslationTruncationStrategy, VisualQuestionAnsweringOutputElement, ZeroShotClassificationOutputElement, ZeroShotImageClassificationOutputElement, @@ -983,28 +988,25 @@ async def document_question_answering( a deployed Inference Endpoint. If not provided, the default recommended document question answering model will be used. Defaults to None. doc_stride (`int`, *optional*): - If the words in the document are too long to fit with the question for the model, it will - be split in several chunks with some overlap. This argument controls the size of that - overlap. + If the words in the document are too long to fit with the question for the model, it will be split in + several chunks with some overlap. This argument controls the size of that overlap. handle_impossible_answer (`bool`, *optional*): - Whether to accept impossible as an answer. + Whether to accept impossible as an answer lang (`str`, *optional*): - Language to use while running OCR. + Language to use while running OCR. Defaults to english. max_answer_len (`int`, *optional*): - The maximum length of predicted answers (e.g., only answers with a shorter length are - considered). + The maximum length of predicted answers (e.g., only answers with a shorter length are considered). max_question_len (`int`, *optional*): The maximum length of the question after tokenization. It will be truncated if needed. max_seq_len (`int`, *optional*): - The maximum length of the total sentence (context + question) in tokens of each chunk - passed to the model. The context will be split in several chunks (using doc_stride as - overlap) if needed. 
+ The maximum length of the total sentence (context + question) in tokens of each chunk passed to the + model. The context will be split in several chunks (using doc_stride as overlap) if needed. top_k (`int`, *optional*): - The number of answers to return (will be chosen by order of likelihood). Can return less - than top_k answers if there are not enough options available within the context. - word_boxes (`List[Union[List[float], str]]`, *optional*): - A list of words and bounding boxes (normalized 0->1000). If provided, the inference will - skip the OCR step and use the provided bounding boxes instead. + The number of answers to return (will be chosen by order of likelihood). Can return less than top_k + answers if there are not enough options available within the context. + word_boxes (`List[Union[List[float], str`, *optional*): + A list of words and bounding boxes (normalized 0->1000). If provided, the inference will skip the OCR + step and use the provided bounding boxes instead. Returns: `List[DocumentQuestionAnsweringOutputElement]`: a list of [`DocumentQuestionAnsweringOutputElement`] items containing the predicted label, associated probability, word ids, and page number. @@ -1123,11 +1125,10 @@ async def fill_mask( model (`str`, *optional*): The model to use for the fill mask task. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed Inference Endpoint. If not provided, the default recommended fill mask model will be used. - targets (`List[str]`, *optional*): - When passed, the model will limit the scores to the passed targets instead of looking up - in the whole vocabulary. If the provided targets are not in the model vocab, they will be - tokenized and the first resulting token will be used (with a warning, and that might be - slower). + targets (`List[str`, *optional*): + When passed, the model will limit the scores to the passed targets instead of looking up in the whole + vocabulary. If the provided targets are not in the model vocab, they will be tokenized and the first + resulting token will be used (with a warning, and that might be slower). top_k (`int`, *optional*): When passed, overrides the number of predictions to return. Returns: @@ -1162,7 +1163,7 @@ async def image_classification( image: ContentT, *, model: Optional[str] = None, - function_to_apply: Optional[Literal["sigmoid", "softmax", "none"]] = None, + function_to_apply: Optional["ImageClassificationOutputTransform"] = None, top_k: Optional[int] = None, ) -> List[ImageClassificationOutputElement]: """ @@ -1174,8 +1175,8 @@ async def image_classification( model (`str`, *optional*): The model to use for image classification. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed Inference Endpoint. If not provided, the default recommended model for image classification will be used. - function_to_apply (`Literal["sigmoid", "softmax", "none"]`, *optional*): - The function to apply to the output scores. + function_to_apply (`"ImageClassificationOutputTransform"`, *optional*): + The function to apply to the output. top_k (`int`, *optional*): When specified, limits the output to the top K most probable classes. 
Returns: @@ -1208,7 +1209,7 @@ async def image_segmentation( model: Optional[str] = None, mask_threshold: Optional[float] = None, overlap_mask_area_threshold: Optional[float] = None, - subtask: Optional[Literal["instance", "panoptic", "semantic"]] = None, + subtask: Optional["ImageSegmentationSubtask"] = None, threshold: Optional[float] = None, ) -> List[ImageSegmentationOutputElement]: """ @@ -1230,7 +1231,7 @@ async def image_segmentation( Threshold to use when turning the predicted masks into binary values. overlap_mask_area_threshold (`float`, *optional*): Mask overlap threshold to eliminate small, disconnected segments. - subtask (`Literal["instance", "panoptic", "semantic"]`, *optional*): + subtask (`"ImageSegmentationSubtask"`, *optional*): Segmentation task to be performed, depending on model capabilities. threshold (`float`, *optional*): Probability threshold to filter out predicted masks. @@ -1539,26 +1540,24 @@ async def question_answering( The model to use for the question answering task. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed Inference Endpoint. align_to_words (`bool`, *optional*): - Attempts to align the answer to real words. Improves quality on space separated - languages. Might hurt on non-space-separated languages (like Japanese or Chinese). + Attempts to align the answer to real words. Improves quality on space separated languages. Might hurt + on non-space-separated languages (like Japanese or Chinese) doc_stride (`int`, *optional*): - If the context is too long to fit with the question for the model, it will be split in - several chunks with some overlap. This argument controls the size of that overlap. + If the context is too long to fit with the question for the model, it will be split in several chunks + with some overlap. This argument controls the size of that overlap. handle_impossible_answer (`bool`, *optional*): Whether to accept impossible as an answer. max_answer_len (`int`, *optional*): - The maximum length of predicted answers (e.g., only answers with a shorter length are - considered). + The maximum length of predicted answers (e.g., only answers with a shorter length are considered). max_question_len (`int`, *optional*): The maximum length of the question after tokenization. It will be truncated if needed. max_seq_len (`int`, *optional*): - The maximum length of the total sentence (context + question) in tokens of each chunk - passed to the model. The context will be split in several chunks (using docStride as - overlap) if needed. + The maximum length of the total sentence (context + question) in tokens of each chunk passed to the + model. The context will be split in several chunks (using docStride as overlap) if needed. top_k (`int`, *optional*): - The number of answers to return (will be chosen by order of likelihood). Note that we - return less than topk answers if there are not enough options available within the - context. + The number of answers to return (will be chosen by order of likelihood). Note that we return less than + topk answers if there are not enough options available within the context. + Returns: Union[`QuestionAnsweringOutputElement`, List[`QuestionAnsweringOutputElement`]]: When top_k is 1 or not provided, it returns a single `QuestionAnsweringOutputElement`. 
@@ -1662,7 +1661,7 @@ async def summarization( model: Optional[str] = None, clean_up_tokenization_spaces: Optional[bool] = None, generate_parameters: Optional[Dict[str, Any]] = None, - truncation: Optional[Literal["do_not_truncate", "longest_first", "only_first", "only_second"]] = None, + truncation: Optional["SummarizationTruncationStrategy"] = None, ) -> SummarizationOutput: """ Generate a summary of a given text using a specified model. @@ -1680,7 +1679,7 @@ async def summarization( Whether to clean up the potential extra spaces in the text output. generate_parameters (`Dict[str, Any]`, *optional*): Additional parametrization of the text generation algorithm. - truncation (`Literal["do_not_truncate", "longest_first", "only_first", "only_second"]`, *optional*): + truncation (`"SummarizationTruncationStrategy"`, *optional*): The truncation strategy to use. Returns: [`SummarizationOutput`]: The generated summary text. @@ -2413,10 +2412,10 @@ async def text_to_image( self, prompt: str, *, - negative_prompt: Optional[str] = None, + negative_prompt: Optional[List[str]] = None, height: Optional[float] = None, width: Optional[float] = None, - num_inference_steps: Optional[float] = None, + num_inference_steps: Optional[int] = None, guidance_scale: Optional[float] = None, model: Optional[str] = None, scheduler: Optional[str] = None, @@ -2436,8 +2435,8 @@ async def text_to_image( Args: prompt (`str`): The prompt to generate an image from. - negative_prompt (`str`, *optional*): - An optional negative prompt for the image generation. + negative_prompt (`List[str`, *optional*): + One or several prompt to guide what NOT to include in image generation. height (`float`, *optional*): The height in pixels of the image to generate. width (`float`, *optional*): @@ -2446,8 +2445,8 @@ async def text_to_image( The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. guidance_scale (`float`, *optional*): - Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. + A higher guidance scale value encourages the model to generate images closely linked to the text + prompt, but values too high may cause saturation and other artifacts. model (`str`, *optional*): The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed Inference Endpoint. If not provided, the default recommended text-to-image model will be used. @@ -2538,18 +2537,25 @@ async def text_to_speech( early_stopping (`Union[bool, "TextToSpeechEarlyStoppingEnum"`, *optional*): Controls the stopping condition for beam-based methods. epsilon_cutoff (`float`, *optional*): - If set to float strictly between 0 and 1, only tokens with a conditional probability - greater than epsilon_cutoff will be sampled. In the paper, suggested values range from - 3e-4 to 9e-4, depending on the size of the model. See [Truncation Sampling as Language - Model Desmoothing](https://hf.co/papers/2210.15191) for more details. + If set to float strictly between 0 and 1, only tokens with a conditional probability greater than + epsilon_cutoff will be sampled. In the paper, suggested values range from 3e-4 to 9e-4, depending on + the size of the model. See [Truncation Sampling as Language Model + Desmoothing](https://hf.co/papers/2210.15191) for more details. eta_cutoff (`float`, *optional*): - Eta sampling is a hybrid of locally typical sampling and epsilon sampling. 
If set to + Eta sampling is a hybrid of locally typical sampling and epsilon sampling. If set to float strictly + between 0 and 1, a token is only considered if it is greater than either eta_cutoff or sqrt(eta_cutoff) + * exp(-entropy(softmax(next_token_logits))). The latter term is intuitively the expected next token + probability, scaled by sqrt(eta_cutoff). In the paper, suggested values range from 3e-4 to 2e-3, + depending on the size of the model. See [Truncation Sampling as Language Model + Desmoothing](https://hf.co/papers/2210.15191) for more details. float strictly between 0 and 1, a token is only considered if it is greater than either - eta_cutoff or sqrt(eta_cutoff) * exp(-entropy(softmax(next_token_logits))). The latter - term is intuitively the expected next token probability, scaled by sqrt(eta_cutoff). In - the paper, suggested values range from 3e-4 to 2e-3, depending on the size of the model. - See [Truncation Sampling as Language Model Desmoothing](https://hf.co/papers/2210.15191) - for more details. + eta_cutoff (`float`, *optional*): + Eta sampling is a hybrid of locally typical sampling and epsilon sampling. If set to float strictly + between 0 and 1, a token is only considered if it is greater than either eta_cutoff or sqrt(eta_cutoff) + * exp(-entropy(softmax(next_token_logits))). The latter term is intuitively the expected next token + probability, scaled by sqrt(eta_cutoff). In the paper, suggested values range from 3e-4 to 2e-3, + depending on the size of the model. See [Truncation Sampling as Language Model + Desmoothing](https://hf.co/papers/2210.15191) for more details. max_length (`int`, *optional*): The maximum length (in tokens) of the generated text, including the input. max_new_tokens (`int`, *optional*): @@ -2559,20 +2565,19 @@ async def text_to_speech( min_new_tokens (`int`, *optional*): The minimum number of tokens to generate. Takes precedence over maxLength. num_beam_groups (`int`, *optional*): - Number of groups to divide num_beams into in order to ensure diversity among different - groups of beams. See [this paper](https://hf.co/papers/1610.02424) for more details. + Number of groups to divide num_beams into in order to ensure diversity among different groups of beams. + See [this paper](https://hf.co/papers/1610.02424) for more details. num_beams (`int`, *optional*): Number of beams to use for beam search. penalty_alpha (`float`, *optional*): - The value balances the model confidence and the degeneration penalty in contrastive - search decoding. + The value balances the model confidence and the degeneration penalty in contrastive search decoding. temperature (`float`, *optional*): The value used to modulate the next token probabilities. top_k (`int`, *optional*): The number of highest probability vocabulary tokens to keep for top-k-filtering. top_p (`float`, *optional*): - If set to float < 1, only the smallest set of most probable tokens with probabilities - that add up to top_p or higher are kept for generation. + If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to + top_p or higher are kept for generation. 
typical_p (`float`, *optional*): Local typicality measures how similar the conditional probability of predicting a target token next is to the expected conditional probability of predicting a random token next, given the partial text @@ -2629,7 +2634,7 @@ async def token_classification( text: str, *, model: Optional[str] = None, - aggregation_strategy: Optional[Literal["none", "simple", "first", "average", "max"]] = None, + aggregation_strategy: Optional["TokenClassificationAggregationStrategy"] = None, ignore_labels: Optional[List[str]] = None, stride: Optional[int] = None, ) -> List[TokenClassificationOutputElement]: @@ -2644,10 +2649,10 @@ async def token_classification( The model to use for the token classification task. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed Inference Endpoint. If not provided, the default recommended token classification model will be used. Defaults to None. - aggregation_strategy (`Literal["none", "simple", "first", "average", "max"]`, *optional*): - The strategy used to fuse tokens based on model predictions. - ignore_labels (`List[str]`, *optional*): - A list of labels to ignore. + aggregation_strategy (`"TokenClassificationAggregationStrategy"`, *optional*): + The strategy used to fuse tokens based on model predictions + ignore_labels (`List[str`, *optional*): + A list of labels to ignore stride (`int`, *optional*): The number of overlapping tokens between chunks when splitting the input text. @@ -2706,7 +2711,7 @@ async def translation( src_lang: Optional[str] = None, tgt_lang: Optional[str] = None, clean_up_tokenization_spaces: Optional[bool] = None, - truncation: Optional[Literal["do_not_truncate", "longest_first", "only_first", "only_second"]] = None, + truncation: Optional["TranslationTruncationStrategy"] = None, generate_parameters: Optional[Dict[str, Any]] = None, ) -> TranslationOutput: """ @@ -2730,7 +2735,7 @@ async def translation( Target language to translate to. Required for models that can translate to multiple languages. clean_up_tokenization_spaces (`bool`, *optional*): Whether to clean up the potential extra spaces in the text output. - truncation (`Literal["do_not_truncate", "longest_first", "only_first", "only_second"]`, *optional*): + truncation (`"TranslationTruncationStrategy"`, *optional*): The truncation strategy to use. generate_parameters (`Dict[str, Any]`, *optional*): Additional parametrization of the text generation algorithm. @@ -2754,13 +2759,13 @@ async def translation( >>> await client.translation("My name is Wolfgang and I live in Berlin") 'Mein Name ist Wolfgang und ich lebe in Berlin.' >>> await client.translation("My name is Wolfgang and I live in Berlin", model="Helsinki-NLP/opus-mt-en-fr") - TranslationOutput(translation_text='Je m\'appelle Wolfgang et je vis Ć  Berlin.') + TranslationOutput(translation_text='Je m'appelle Wolfgang et je vis Ć  Berlin.') ``` Specifying languages: ```py >>> client.translation("My name is Sarah Jessica Parker but you can call me Jessica", model="facebook/mbart-large-50-many-to-many-mmt", src_lang="en_XX", tgt_lang="fr_XX") - "Mon nom est Sarah Jessica Parker mais vous pouvez m\'appeler Jessica" + "Mon nom est Sarah Jessica Parker mais vous pouvez m'appeler Jessica" ``` """ # Throw error if only one of `src_lang` and `tgt_lang` was given @@ -2801,9 +2806,8 @@ async def visual_question_answering( a deployed Inference Endpoint. If not provided, the default recommended visual question answering model will be used. Defaults to None. 
top_k (`int`, *optional*): - The number of answers to return (will be chosen by order of likelihood). Note that we - return less than topk answers if there are not enough options available within the - context. + The number of answers to return (will be chosen by order of likelihood). Note that we return less than + topk answers if there are not enough options available within the context. Returns: `List[VisualQuestionAnsweringOutputElement]`: a list of [`VisualQuestionAnsweringOutputElement`] items containing the predicted label and associated probability. @@ -2839,7 +2843,7 @@ async def zero_shot_classification( text: str, labels: List[str], *, - multi_label: bool = False, + multi_label: Optional[bool] = False, hypothesis_template: Optional[str] = None, model: Optional[str] = None, ) -> List[ZeroShotClassificationOutputElement]: @@ -2851,14 +2855,13 @@ async def zero_shot_classification( The input text to classify. labels (`List[str]`): List of strings. Each string is the verbalization of a possible label for the input text. - multi_label (`bool`): - Boolean. If True, the probability for each label is evaluated independently and multiple labels can have a probability close to 1 simultaneously or all probabilities can be close to 0. - If False, the labels are considered mutually exclusive and the probability over all labels always sums to 1. Defaults to False. + multi_label (`bool`, *optional*): + Whether multiple candidate labels can be true. If false, the scores are normalized such that the sum of + the label likelihoods for each sequence is 1. If true, the labels are considered independent and + probabilities are normalized for each candidate. hypothesis_template (`str`, *optional*): - A template sentence string with curly brackets to which the label strings are added. The label strings are added at the position of the curly brackets "{}". - Zero-shot classifiers are based on NLI models, which evaluate if a hypothesis is entailed in another text or not. - For example, with hypothesis_template="This text is about {}." and labels=["economics", "politics"], the system internally creates the two hypotheses "This text is about economics." and "This text is about politics.". - The model then evaluates for both hypotheses if they are entailed in the provided `text` or not. + The sentence used in conjunction with candidateLabels to attempt the text classification by replacing + the placeholder with the candidate labels. model (`str`, *optional*): The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed Inference Endpoint. This parameter overrides the model defined at the instance level. If not provided, the default recommended zero-shot classification model will be used. @@ -2958,8 +2961,8 @@ async def zero_shot_image_classification( The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed Inference Endpoint. This parameter overrides the model defined at the instance level. If not provided, the default recommended zero-shot image classification model will be used. hypothesis_template (`str`, *optional*): - The sentence used in conjunction with `labels` to attempt the text classification by replacing the - placeholder with the candidate labels. + The sentence used in conjunction with candidateLabels to attempt the text classification by replacing + the placeholder with the candidate labels. 
Returns: `List[ZeroShotImageClassificationOutputElement]`: List of [`ZeroShotImageClassificationOutputElement`] items containing the predicted labels and their confidence. diff --git a/src/huggingface_hub/inference/_generated/types/__init__.py b/src/huggingface_hub/inference/_generated/types/__init__.py index caa46d05fc..d59bae0ba3 100644 --- a/src/huggingface_hub/inference/_generated/types/__init__.py +++ b/src/huggingface_hub/inference/_generated/types/__init__.py @@ -24,8 +24,10 @@ ChatCompletionInputFunctionDefinition, ChatCompletionInputFunctionName, ChatCompletionInputGrammarType, + ChatCompletionInputGrammarTypeType, ChatCompletionInputMessage, ChatCompletionInputMessageChunk, + ChatCompletionInputMessageChunkType, ChatCompletionInputStreamOptions, ChatCompletionInputToolType, ChatCompletionInputURL, @@ -56,7 +58,7 @@ DocumentQuestionAnsweringOutputElement, DocumentQuestionAnsweringParameters, ) -from .feature_extraction import FeatureExtractionInput +from .feature_extraction import FeatureExtractionInput, FeatureExtractionInputTruncationDirection from .fill_mask import FillMaskInput, FillMaskOutputElement, FillMaskParameters from .image_classification import ( ImageClassificationInput, @@ -64,7 +66,12 @@ ImageClassificationOutputTransform, ImageClassificationParameters, ) -from .image_segmentation import ImageSegmentationInput, ImageSegmentationOutputElement, ImageSegmentationParameters +from .image_segmentation import ( + ImageSegmentationInput, + ImageSegmentationOutputElement, + ImageSegmentationParameters, + ImageSegmentationSubtask, +) from .image_to_image import ImageToImageInput, ImageToImageOutput, ImageToImageParameters, ImageToImageTargetSize from .image_to_text import ( ImageToTextEarlyStoppingEnum, @@ -86,13 +93,23 @@ QuestionAnsweringParameters, ) from .sentence_similarity import SentenceSimilarityInput, SentenceSimilarityInputData -from .summarization import SummarizationInput, SummarizationOutput, SummarizationParameters +from .summarization import ( + SummarizationInput, + SummarizationOutput, + SummarizationParameters, + SummarizationTruncationStrategy, +) from .table_question_answering import ( TableQuestionAnsweringInput, TableQuestionAnsweringInputData, TableQuestionAnsweringOutputElement, ) -from .text2text_generation import Text2TextGenerationInput, Text2TextGenerationOutput, Text2TextGenerationParameters +from .text2text_generation import ( + Text2TextGenerationInput, + Text2TextGenerationOutput, + Text2TextGenerationParameters, + Text2TextGenerationTruncationStrategy, +) from .text_classification import ( TextClassificationInput, TextClassificationOutputElement, @@ -106,11 +123,13 @@ TextGenerationOutput, TextGenerationOutputBestOfSequence, TextGenerationOutputDetails, + TextGenerationOutputFinishReason, TextGenerationOutputPrefillToken, TextGenerationOutputToken, TextGenerationStreamOutput, TextGenerationStreamOutputStreamDetails, TextGenerationStreamOutputToken, + TypeEnum, ) from .text_to_audio import ( TextToAudioEarlyStoppingEnum, @@ -128,11 +147,12 @@ TextToSpeechParameters, ) from .token_classification import ( + TokenClassificationAggregationStrategy, TokenClassificationInput, TokenClassificationOutputElement, TokenClassificationParameters, ) -from .translation import TranslationInput, TranslationOutput, TranslationParameters +from .translation import TranslationInput, TranslationOutput, TranslationParameters, TranslationTruncationStrategy from .video_classification import ( VideoClassificationInput, VideoClassificationOutputElement, diff --git 
a/utils/check_task_parameters.py b/utils/check_task_parameters.py new file mode 100644 index 0000000000..80e2586a55 --- /dev/null +++ b/utils/check_task_parameters.py @@ -0,0 +1,795 @@ +# coding=utf-8 +# Copyright 2024-present, the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +Utility script to check and update the InferenceClient task methods arguments and docstrings +based on the tasks input parameters. + +What this script does: +- [x] detect missing parameters in method signature +- [x] add missing parameters to methods signature +- [x] detect missing parameters in method docstrings +- [x] add missing parameters to methods docstrings +- [x] detect outdated parameters in method signature +- [x] update outdated parameters in method signature +- [x] detect outdated parameters in method docstrings +- [x] update outdated parameters in method docstrings +- [ ] detect when parameter not used in method implementation +- [ ] update method implementation when parameter not used +Related resources: +- https://github.com/huggingface/huggingface_hub/issues/2063 +- https://github.com/huggingface/huggingface_hub/issues/2557 +- https://github.com/huggingface/huggingface_hub/pull/2561 +""" + +import argparse +import builtins +import inspect +import re +import textwrap +from collections import defaultdict +from pathlib import Path +from typing import Any, Dict, List, NoReturn, Optional, Set + +import libcst as cst +from helpers import format_source_code +from libcst.codemod import CodemodContext +from libcst.codemod.visitors import GatherImportsVisitor + +from huggingface_hub import InferenceClient + + +# Paths to project files +BASE_DIR = Path(__file__).parents[1] / "src" / "huggingface_hub" +INFERENCE_TYPES_PATH = BASE_DIR / "inference" / "_generated" / "types" +INFERENCE_CLIENT_FILE = BASE_DIR / "inference" / "_client.py" + +DEFAULT_MODULE = "huggingface_hub.inference._generated.types" + + +# Temporary solution to skip tasks where there is no Parameters dataclass or the schema needs to be updated +TASKS_TO_SKIP = [ + "chat_completion", + "text_generation", + "depth_estimation", + "audio_to_audio", + "feature_extraction", + "sentence_similarity", + "table_question_answering", + "automatic_speech_recognition", + "image_to_text", + "image_to_image", +] + +PARAMETERS_DATACLASS_REGEX = re.compile( + r""" + ^@dataclass + \nclass\s(\w+Parameters)\(BaseInferenceType\): + """, + re.VERBOSE | re.MULTILINE, +) +CORE_PARAMETERS = { + "model", # Model identifier + "text", # Text input + "image", # Image input + "audio", # Audio input + "inputs", # Generic inputs + "input", # Generic input + "prompt", # For generation tasks + "question", # For QA tasks + "context", # For QA tasks + "labels", # For classification tasks +} + +#### NODE VISITORS (READING THE CODE) + + +class DataclassFieldCollector(cst.CSTVisitor): + """A visitor that collects fields (parameters) from a dataclass.""" + + def __init__(self, dataclass_name: str): + self.dataclass_name = dataclass_name + self.parameters: 
Dict[str, Dict[str, str]] = {} + + def visit_ClassDef(self, node: cst.ClassDef) -> None: + """Visit class definitions to find the target dataclass.""" + + if node.name.value == self.dataclass_name: + body_statements = node.body.body + for index, field in enumerate(body_statements): + # Check if the statement is a simple statement (like a variable declaration) + if isinstance(field, cst.SimpleStatementLine): + for stmt in field.body: + # Check if it's an annotated assignment (typical for dataclass fields) + if isinstance(stmt, cst.AnnAssign) and isinstance(stmt.target, cst.Name): + param_name = stmt.target.value + param_type = cst.Module([]).code_for_node(stmt.annotation.annotation) + docstring = self._extract_docstring(body_statements, index) + self.parameters[param_name] = { + "type": param_type, + "docstring": docstring, + } + + @staticmethod + def _extract_docstring( + body_statements: List[cst.CSTNode], + field_index: int, + ) -> str: + """Extract the docstring following a field definition.""" + if field_index + 1 < len(body_statements): + # Check if the next statement is a simple statement (like a string) + next_stmt = body_statements[field_index + 1] + if isinstance(next_stmt, cst.SimpleStatementLine): + for stmt in next_stmt.body: + # Check if the statement is a string expression (potential docstring) + if isinstance(stmt, cst.Expr) and isinstance(stmt.value, cst.SimpleString): + return stmt.value.evaluated_value.strip() + # No docstring found or there's no statement after the field + return "" + + +class ModulesCollector(cst.CSTVisitor): + """Visitor that maps type names to their defining modules.""" + + def __init__(self): + self.type_to_module = {} + + def visit_ClassDef(self, node: cst.ClassDef): + """Map class definitions to the current module.""" + self.type_to_module[node.name.value] = DEFAULT_MODULE + + def visit_ImportFrom(self, node: cst.ImportFrom): + """Map imported types to their modules.""" + if node.module: + module_name = node.module.value + for alias in node.names: + self.type_to_module[alias.name.value] = module_name + + +class MethodArgumentsCollector(cst.CSTVisitor): + """Collects parameter types and docstrings from a method.""" + + def __init__(self, method_name: str): + self.method_name = method_name + self.parameters: Dict[str, Dict[str, str]] = {} + + def visit_FunctionDef(self, node: cst.FunctionDef) -> None: + if node.name.value != self.method_name: + return + # Extract docstring + docstring = self._extract_docstring(node) + param_docs = self._parse_docstring_params(docstring) + # Collect parameters + for param in node.params.params + node.params.kwonly_params: + if param.name.value == "self" or param.name.value in CORE_PARAMETERS: + continue + param_type = cst.Module([]).code_for_node(param.annotation.annotation) if param.annotation else "Any" + self.parameters[param.name.value] = {"type": param_type, "docstring": param_docs.get(param.name.value, "")} + + def _extract_docstring(self, node: cst.FunctionDef) -> str: + """Extract docstring from function node.""" + if ( + isinstance(node.body.body[0], cst.SimpleStatementLine) + and isinstance(node.body.body[0].body[0], cst.Expr) + and isinstance(node.body.body[0].body[0].value, cst.SimpleString) + ): + return node.body.body[0].body[0].value.evaluated_value + return "" + + def _parse_docstring_params(self, docstring: str) -> Dict[str, str]: + """Parse parameter descriptions from docstring.""" + param_docs = {} + lines = docstring.split("\n") + + # Find Args section + args_idx = next((i for i, line in 
enumerate(lines) if line.strip().lower() == "args:"), None) + if args_idx is None: + return param_docs + # Parse parameter descriptions + current_param = None + current_desc = [] + for line in lines[args_idx + 1 :]: + stripped_line = line.strip() + if not stripped_line or stripped_line.lower() in ("returns:", "raises:", "example:", "examples:"): + break + + if stripped_line.endswith(":"): # Parameter line + if current_param: + param_docs[current_param] = " ".join(current_desc) + current_desc = [] + # Extract only the parameter name before the first space or parenthesis + current_param = re.split(r"\s|\(", stripped_line[:-1], 1)[0].strip() + else: # Description line + current_desc.append(stripped_line) + if current_param: # Save last parameter + param_docs[current_param] = " ".join(current_desc) + return param_docs + + +#### TREE TRANSFORMERS (UPDATING THE CODE) + + +class AddImports(cst.CSTTransformer): + """Transformer that adds import statements to the module.""" + + def __init__(self, imports_to_add: List[cst.BaseStatement]): + self.imports_to_add = imports_to_add + self.added = False + + def leave_Module( + self, + original_node: cst.Module, + updated_node: cst.Module, + ) -> cst.Module: + """Insert the import statements into the module.""" + # If imports were already added, don't add them again + if self.added: + return updated_node + insertion_index = 0 + # Find the index where to insert the imports: make sure the imports are inserted before any code and after all imports (not necessary, we can remove/simplify this part) + for idx, stmt in enumerate(updated_node.body): + if not isinstance(stmt, cst.SimpleStatementLine): + insertion_index = idx + break + elif not isinstance(stmt.body[0], (cst.Import, cst.ImportFrom)): + insertion_index = idx + break + # Insert the imports + new_body = ( + list(updated_node.body[:insertion_index]) + + list(self.imports_to_add) + + list(updated_node.body[insertion_index:]) + ) + self.added = True + return updated_node.with_changes(body=new_body) + + +class UpdateParameters(cst.CSTTransformer): + """Updates a method's parameters, types, and docstrings.""" + + def __init__(self, method_name: str, param_updates: Dict[str, Dict[str, str]]): + self.method_name = method_name + self.param_updates = param_updates + self.found_method = False # Flag to check if the method is found + + def leave_FunctionDef( + self, + original_node: cst.FunctionDef, + updated_node: cst.FunctionDef, + ) -> cst.FunctionDef: + # Only proceed if the current function is the target method + if original_node.name.value != self.method_name: + return updated_node + self.found_method = True # Set the flag as the method is found + # Update the parameters and docstring of the method + new_params = self._update_parameters(updated_node.params) + updated_body = self._update_docstring(updated_node.body) + # Return the updated function definition + return updated_node.with_changes(params=new_params, body=updated_body) + + def _update_parameters(self, params: cst.Parameters) -> cst.Parameters: + """Update parameter types and add new parameters.""" + new_params = list(params.params) # Copy regular parameters (e.g., 'self') + new_kwonly_params = [] + # Collect existing parameter names to avoid duplicates + existing_params = {p.name.value for p in params.params + params.kwonly_params} + # Update existing keyword-only parameters + for param in params.kwonly_params: + param_name = param.name.value + if param_name in self.param_updates: + # Update the type annotation for the parameter + new_annotation = 
cst.Annotation( + annotation=cst.parse_expression(self.param_updates[param_name]["type"]) + ) + new_kwonly_params.append(param.with_changes(annotation=new_annotation)) + else: + # Keep the parameter as is if no update is needed + new_kwonly_params.append(param) + # Add new parameters that are not already present + for param_name, param_info in self.param_updates.items(): + if param_name not in existing_params: + # Create a new parameter with the provided type and a default value of None + annotation = cst.Annotation(annotation=cst.parse_expression(param_info["type"])) + new_param = cst.Param( + name=cst.Name(param_name), + annotation=annotation, + default=cst.Name("None"), + ) + new_kwonly_params.append(new_param) + # Return the updated parameters object with new and updated parameters + return params.with_changes(params=new_params, kwonly_params=new_kwonly_params) + + def _update_docstring(self, body: cst.IndentedBlock) -> cst.IndentedBlock: + """Update parameter descriptions in the docstring.""" + # Check if the first statement is a docstring + if not ( + isinstance(body.body[0], cst.SimpleStatementLine) + and isinstance(body.body[0].body[0], cst.Expr) + and isinstance(body.body[0].body[0].value, cst.SimpleString) + ): + # Return the body unchanged if no docstring is found + return body + + docstring_expr = body.body[0].body[0] + docstring = docstring_expr.value.evaluated_value # Get the docstring content + # Update the docstring content with new and updated parameters + updated_docstring = self._update_docstring_content(docstring) + new_docstring = cst.SimpleString(f'"""{updated_docstring}"""') + # Replace the old docstring with the updated one + new_body = [body.body[0].with_changes(body=[docstring_expr.with_changes(value=new_docstring)])] + list( + body.body[1:] + ) + # Return the updated function body + return body.with_changes(body=new_body) + + def _update_docstring_content(self, docstring: str) -> str: + """Update parameter descriptions in the docstring content.""" + # Split parameters into new and updated ones based on their status + new_params = {name: info for name, info in self.param_updates.items() if info["status"] == "new"} + update_params = { + name: info for name, info in self.param_updates.items() if info["status"] in ("update_type", "update_doc") + } + # Split the docstring into lines for processing + docstring_lines = docstring.split("\n") + # Find or create the "Args:" section and compute indentation levels + args_index = next((i for i, line in enumerate(docstring_lines) if line.strip().lower() == "args:"), None) + if args_index is None: + # If 'Args:' section is not found, insert it before 'Returns:' or at the end + insertion_index = next( + ( + i + for i, line in enumerate(docstring_lines) + if line.strip().lower() in ("returns:", "raises:", "examples:", "example:") + ), + len(docstring_lines), + ) + docstring_lines.insert(insertion_index, "Args:") + args_index = insertion_index # Update the args_index with the new section + base_indent = docstring_lines[args_index][: -len(docstring_lines[args_index].lstrip())] + param_indent = base_indent + " " # Indentation for parameter lines + desc_indent = param_indent + " " # Indentation for description lines + # Update existing parameters in the docstring + if update_params: + docstring_lines = self._process_existing_params( + docstring_lines, update_params, args_index, param_indent, desc_indent + ) + # Add new parameters to the docstring + if new_params: + docstring_lines = self._add_new_params(docstring_lines, new_params, 
args_index, param_indent, desc_indent) + # Join the docstring lines back into a single string + return "\n".join(docstring_lines) + + def _format_param_docstring( + self, + param_name: str, + param_info: Dict[str, str], + param_indent: str, + desc_indent: str, + ) -> List[str]: + """Format the docstring lines for a single parameter.""" + # Extract and format the parameter type + param_type = param_info["type"].replace("Optional[", "").rstrip("]") + optional_str = "*optional*" if "Optional[" in param_info["type"] else "" + # Create the parameter line with type and optionality + param_line = f"{param_indent}{param_name} (`{param_type}`, {optional_str}):" + # Get and clean up the parameter description + param_desc = (param_info.get("docstring") or "").strip() + param_desc = " ".join(param_desc.split()) + if param_desc: + # Wrap the description text to maintain line width and indentation + wrapped_desc = textwrap.fill( + param_desc, + width=119, + initial_indent=desc_indent, + subsequent_indent=desc_indent, + ) + return [param_line, wrapped_desc] + else: + # Return only the parameter line if there's no description + return [param_line] + + def _process_existing_params( + self, + docstring_lines: List[str], + params_to_update: Dict[str, Dict[str, str]], + args_index: int, + param_indent: str, + desc_indent: str, + ) -> List[str]: + """Update existing parameters in the docstring.""" + i = args_index + 1 # Start after the 'Args:' section + while i < len(docstring_lines): + line = docstring_lines[i] + stripped_line = line.strip() + if not stripped_line: + # Skip empty lines + i += 1 + continue + if stripped_line.lower() in ("returns:", "raises:", "example:", "examples:"): + # Stop processing if another section starts + break + if stripped_line.endswith(":"): + # Check if the line is a parameter line + param_line = stripped_line + param_name = param_line.strip().split()[0] # Extract parameter name + if param_name in params_to_update: + # Get the updated parameter info + param_info = params_to_update[param_name] + # Format the new parameter docstring + param_doc_lines = self._format_param_docstring(param_name, param_info, param_indent, desc_indent) + # Find the end of the current parameter's description + start_idx = i + end_idx = i + 1 + while end_idx < len(docstring_lines): + next_line = docstring_lines[end_idx] + # Next parameter or section starts or another section starts or empty line + if ( + (next_line.strip().endswith(":") and not next_line.startswith(desc_indent)) + or next_line.lower() in ("returns:", "raises:", "example:", "examples:") + or not next_line + ): + break + end_idx += 1 + # Insert new param docs and preserve the rest of the docstring + docstring_lines = ( + docstring_lines[:start_idx] # Keep everything before + + param_doc_lines # Insert new parameter docs + + docstring_lines[end_idx:] # Keep everything after + ) + i = start_idx + len(param_doc_lines) # Update index to after inserted lines + i += 1 + else: + i += 1 # Move to the next line if not a parameter line + return docstring_lines + + def _add_new_params( + self, + docstring_lines: List[str], + new_params: Dict[str, Dict[str, str]], + args_index: int, + param_indent: str, + desc_indent: str, + ) -> List[str]: + """Add new parameters to the docstring.""" + # Find the insertion point after existing parameters + insertion_index = args_index + 1 + empty_line_index = None + while insertion_index < len(docstring_lines): + line = docstring_lines[insertion_index] + stripped_line = line.strip() + # Track empty line at the end of 
Args section + if not stripped_line: + if empty_line_index is None: # Remember first empty line + empty_line_index = insertion_index + insertion_index += 1 + continue + if stripped_line.lower() in ("returns:", "raises:", "example:", "examples:"): + break + empty_line_index = None # Reset if we find more content + if stripped_line.endswith(":") and not line.startswith(desc_indent.strip()): + insertion_index += 1 + else: + insertion_index += 1 + + # If we found an empty line at the end of the Args section, insert before it + if empty_line_index is not None: + insertion_index = empty_line_index + # Prepare the new parameter documentation lines + param_docs = [] + for param_name, param_info in new_params.items(): + param_doc_lines = self._format_param_docstring(param_name, param_info, param_indent, desc_indent) + param_docs.extend(param_doc_lines) + # Insert the new parameters into the docstring + docstring_lines[insertion_index:insertion_index] = param_docs + return docstring_lines + + +#### UTILS + + +def _check_parameters( + inference_client_module: cst.Module, + parameters_module: cst.Module, + method_name: str, + parameter_type_name: str, +) -> Dict[str, Dict[str, Any]]: + """ + Check for missing parameters and outdated types/docstrings. + + Args: + inference_client_module: Module containing the InferenceClient + parameters_module: Module containing the parameters dataclass + method_name: Name of the method to check + parameter_type_name: Name of the parameters dataclass + + Returns: + Dict mapping parameter names to their updates: + {param_name: { + "type": str, # Type annotation + "docstring": str, # Parameter documentation + "status": "new"|"update_type"|"update_doc" # Whether parameter is new or needs update + }} + """ + # Get parameters from the dataclass + params_collector = DataclassFieldCollector(parameter_type_name) + parameters_module.visit(params_collector) + dataclass_params = params_collector.parameters + # Get existing parameters from the method + method_collector = MethodArgumentsCollector(method_name) + inference_client_module.visit(method_collector) + existing_params = method_collector.parameters + + updates = {} + # Check for new and updated parameters + for param_name, param_info in dataclass_params.items(): + if param_name not in existing_params: + # New parameter + updates[param_name] = {**param_info, "status": "new"} + else: + # Check for type/docstring changes + current = existing_params[param_name] + normalized_current_doc = _normalize_docstring(current["docstring"]) + normalized_new_doc = _normalize_docstring(param_info["docstring"]) + if current["type"] != param_info["type"]: + updates[param_name] = {**param_info, "status": "update_type"} + if normalized_current_doc != normalized_new_doc: + updates[param_name] = {**param_info, "status": "update_doc"} + return updates + + +def _update_parameters( + module: cst.Module, + method_name: str, + param_updates: Dict[str, Dict[str, str]], +) -> cst.Module: + """ + Update method parameters, types and docstrings. 
+ + Args: + module: The module to update + method_name: Name of the method to update + param_updates: Dictionary of parameter updates with their type and docstring + Format: {param_name: {"type": str, "docstring": str, "status": "new"|"update_type"|"update_doc"}} + + Returns: + Updated module + """ + transformer = UpdateParameters(method_name, param_updates) + return module.visit(transformer) + + +def _get_imports_to_add( + parameters: Dict[str, Dict[str, str]], + parameters_module: cst.Module, + inference_client_module: cst.Module, +) -> Dict[str, List[str]]: + """ + Get the needed imports for missing parameters. + + Args: + parameters (Dict[str, Dict[str, str]]): Dictionary of parameters with their type and docstring. + eg: {"function_to_apply": {"type": "ClassificationOutputTransform", "docstring": "Function to apply to the input."}} + parameters_module (cst.Module): The module where the parameters are defined. + inference_client_module (cst.Module): The module of the inference client. + + Returns: + Dict[str, List[str]]: A dictionary mapping modules to list of types to import. + eg: {"huggingface_hub.inference._generated.types": ["ClassificationOutputTransform"]} + """ + # Collect all type names from parameter annotations + types_to_import = set() + for param_info in parameters.values(): + types_to_import.update(_collect_type_hints_from_annotation(param_info["type"])) + # Gather existing imports in the inference client module + context = CodemodContext() + gather_visitor = GatherImportsVisitor(context) + inference_client_module.visit(gather_visitor) + # Map types to their defining modules in the parameters module + module_collector = ModulesCollector() + parameters_module.visit(module_collector) + # Determine which imports are needed + + needed_imports = {} + for type_name in types_to_import: + types_to_modules = module_collector.type_to_module + module = types_to_modules.get(type_name, DEFAULT_MODULE) + # Maybe no need to check that since the code formatter will handle duplicate imports? + if module not in gather_visitor.object_mapping or type_name not in gather_visitor.object_mapping[module]: + needed_imports.setdefault(module, []).append(type_name) + return needed_imports + + +def _generate_import_statements(import_dict: Dict[str, List[str]]) -> str: + """ + Generate import statements from a dictionary of needed imports. + + Args: + import_dict (Dict[str, List[str]]): Dictionary mapping modules to list of types to import. + eg: {"typing": ["List", "Dict"], "huggingface_hub.inference._generated.types": ["ClassificationOutputTransform"]} + + Returns: + str: The import statements as a string. + """ + import_statements = [] + for module, imports in import_dict.items(): + if imports: + import_list = ", ".join(imports) + import_statements.append(f"from {module} import {import_list}") + else: + import_statements.append(f"import {module}") + return "\n".join(import_statements) + + +def _normalize_docstring(docstring: str) -> str: + """Normalize a docstring by removing extra whitespace, newlines and indentation.""" + # Split into lines, strip whitespace from each line, and join back + return " ".join(line.strip() for line in docstring.split("\n")).strip() + + +# TODO: Needs to be improved, maybe using `typing.get_type_hints` instead (we gonna need to access the method though)? +def _collect_type_hints_from_annotation(annotation_str: str) -> Set[str]: + """ + Collect type hints from an annotation string. + + Args: + annotation_str (str): The annotation string. 
+ + Returns: + Set[str]: A set of type hints. + """ + type_string = annotation_str.replace(" ", "") + builtin_types = {d for d in dir(builtins) if isinstance(getattr(builtins, d), type)} + types = re.findall(r"\w+|'[^']+'|\"[^\"]+\"", type_string) + extracted_types = {t.strip("\"'") for t in types if t.strip("\"'") not in builtin_types} + return extracted_types + + +def _get_parameter_type_name(method_name: str) -> Optional[str]: + file_path = INFERENCE_TYPES_PATH / f"{method_name}.py" + if not file_path.is_file(): + print(f"File not found: {file_path}") + return None + + content = file_path.read_text(encoding="utf-8") + match = PARAMETERS_DATACLASS_REGEX.search(content) + + return match.group(1) if match else None + + +def _parse_module_from_file(filepath: Path) -> Optional[cst.Module]: + try: + code = filepath.read_text(encoding="utf-8") + return cst.parse_module(code) + except FileNotFoundError: + print(f"File not found: {filepath}") + except cst.ParserSyntaxError as e: + print(f"Syntax error while parsing {filepath}: {e}") + return None + + +def _check_and_update_parameters( + method_params: Dict[str, str], + update: bool, +) -> NoReturn: + """ + Check if task methods have missing parameters and update the InferenceClient source code if needed. + """ + merged_imports = defaultdict(set) + logs = [] + inference_client_filename = INFERENCE_CLIENT_FILE + # Read and parse the inference client module + inference_client_module = _parse_module_from_file(inference_client_filename) + modified_module = inference_client_module + has_changes = False + + for method_name, parameter_type_name in method_params.items(): + parameters_filename = INFERENCE_TYPES_PATH / f"{method_name}.py" + parameters_module = _parse_module_from_file(parameters_filename) + + # Check for missing parameters + updates = _check_parameters( + modified_module, + parameters_module, + method_name, + parameter_type_name, + ) + + if not updates: + continue + + if update: + ## Get missing imports to add + needed_imports = _get_imports_to_add(updates, parameters_module, modified_module) + for module, imports_to_add in needed_imports.items(): + merged_imports[module].update(imports_to_add) + modified_module = _update_parameters(modified_module, method_name, updates) + has_changes = True + else: + logs.append(f"\nšŸ”§ Updates needed in method `{method_name}`:") + new_params = [p for p, i in updates.items() if i["status"] == "new"] + updated_params = { + p: "type" if i["status"] == "update_type" else "docstring" + for p, i in updates.items() + if i["status"] in ("update_type", "update_doc") + } + if new_params: + for param in sorted(new_params): + logs.append(f" ā€¢ {param} (missing)") + + if updated_params: + for param, update_type in sorted(updated_params.items()): + logs.append(f" ā€¢ {param} (outdated {update_type})") + + if has_changes: + if merged_imports: + import_statements = _generate_import_statements(merged_imports) + imports_to_add = cst.parse_module(import_statements).body + # Update inference client module with the missing imports + modified_module = modified_module.visit(AddImports(imports_to_add)) + # Format the updated source code + formatted_source_code = format_source_code(modified_module.code) + INFERENCE_CLIENT_FILE.write_text(formatted_source_code) + + if len(logs) > 0: + for log in logs: + print(log) + print( + "āŒ Mismatch between parameters defined in task methods signatures in " + "`./src/huggingface_hub/inference/_client.py` and parameters defined in " +
"`./src/huggingface_hub/inference/_generated/types.py \n" + "Please run `make inference_update` or `python utils/generate_task_parameters.py --update" + ) + exit(1) + else: + if update: + print( + "āœ… InferenceClient source code has been updated in" + " `./src/huggingface_hub/inference/_client.py`.\n Please make sure the changes are" + " accurate and commit them." + ) + else: + print("āœ… All good!") + exit(0) + + +def update_inference_client(update: bool): + print(f"šŸ™ˆ Skipping the following tasks: {TASKS_TO_SKIP}") + # Get all tasks from the ./src/huggingface_hub/inference/_generated/types/ + tasks = set() + for file in INFERENCE_TYPES_PATH.glob("*.py"): + if file.stem not in TASKS_TO_SKIP: + tasks.add(file.stem) + + # Construct a mapping between method names and their parameters dataclass names + method_params = {} + for method_name, _ in inspect.getmembers(InferenceClient, predicate=inspect.isfunction): + if method_name.startswith("_") or method_name not in tasks: + continue + parameter_type_name = _get_parameter_type_name(method_name) + if parameter_type_name is not None: + method_params[method_name] = parameter_type_name + _check_and_update_parameters(method_params, update=update) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--update", + action="store_true", + help=("Whether to update `./src/huggingface_hub/inference/_client.py` if parameters are missing."), + ) + args = parser.parse_args() + update_inference_client(update=args.update) diff --git a/utils/generate_inference_types.py b/utils/generate_inference_types.py index 5f9675d60a..944cb16d71 100644 --- a/utils/generate_inference_types.py +++ b/utils/generate_inference_types.py @@ -222,13 +222,9 @@ def _list_dataclasses(content: str) -> List[str]: return INHERITED_DATACLASS_REGEX.findall(content) -def _list_shared_aliases(content: str) -> List[str]: - """List all shared class aliases defined in the module.""" - all_aliases = TYPE_ALIAS_REGEX.findall(content) - shared_class_pattern = r"(\w+(?:" + "|".join(re.escape(cls) for cls in SHARED_CLASSES) + r"))$" - shared_class_regex = re.compile(shared_class_pattern) - aliases = [alias_class for alias_class, _ in all_aliases if shared_class_regex.search(alias_class)] - return aliases +def _list_type_aliases(content: str) -> List[str]: + """List all type aliases defined in the module.""" + return [alias_class for alias_class, _ in TYPE_ALIAS_REGEX.findall(content)] def fix_inference_classes(content: str, module_name: str) -> str: @@ -293,7 +289,7 @@ def check_inference_types(update: bool) -> NoReturn: fixed_content = fix_inference_classes(content, module_name=file.stem) formatted_content = format_source_code(fixed_content) dataclasses[file.stem] = _list_dataclasses(formatted_content) - aliases[file.stem] = _list_shared_aliases(formatted_content) + aliases[file.stem] = _list_type_aliases(formatted_content) check_and_update_file_content(file, formatted_content, update) all_classes = {module: dataclasses[module] + aliases[module] for module in dataclasses.keys()} diff --git a/utils/generate_task_parameters.py b/utils/generate_task_parameters.py deleted file mode 100644 index fa80c2f7a0..0000000000 --- a/utils/generate_task_parameters.py +++ /dev/null @@ -1,548 +0,0 @@ -# coding=utf-8 -# Copyright 2024-present, the HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Utility script to check and update the InferenceClient task methods arguments and docstrings -based on the tasks input parameters. - -What this script does: -- [x] detect missing parameters in method signature -- [x] add missing parameters to methods signature -- [ ] detect outdated parameters in method signature -- [ ] update outdated parameters in method signature - -- [x] detect missing parameters in method docstrings -- [x] add missing parameters to methods docstrings -- [ ] detect outdated parameters in method docstrings -- [ ] update outdated parameters in method docstrings - -- [ ] detect when parameter not used in method implementation -- [ ] update method implementation when parameter not used -Related resources: -- https://github.com/huggingface/huggingface_hub/issues/2063 -- https://github.com/huggingface/huggingface_hub/issues/2557 -- https://github.com/huggingface/huggingface_hub/pull/2561 -""" - -import argparse -import builtins -import inspect -import re -import textwrap -from collections import defaultdict -from pathlib import Path -from typing import Dict, List, NoReturn, Optional, Set - -import libcst as cst -from helpers import format_source_code -from libcst.codemod import CodemodContext -from libcst.codemod.visitors import GatherImportsVisitor - -from huggingface_hub.inference._client import InferenceClient - - -# Paths to project files -BASE_DIR = Path(__file__).parents[1] / "src" / "huggingface_hub" -INFERENCE_TYPES_PATH = BASE_DIR / "inference" / "_generated" / "types" -INFERENCE_CLIENT_FILE = BASE_DIR / "inference" / "_client.py" - -DEFAULT_MODULE = "huggingface_hub.inference._generated.types" - - -# Temporary solution to skip tasks where there is no Parameters dataclass or the schema needs to be updated -TASKS_TO_SKIP = [ - "chat_completion", - "depth_estimation", - "audio_to_audio", - "feature_extraction", - "sentence_similarity", - "table_question_answering", - "automatic_speech_recognition", - "image_to_text", - "image_to_image", -] - -PARAMETERS_DATACLASS_REGEX = re.compile( - r""" - ^@dataclass - \nclass\s(\w+Parameters)\(BaseInferenceType\): - """, - re.VERBOSE | re.MULTILINE, -) - -#### NODE VISITORS - - -class DataclassFieldCollector(cst.CSTVisitor): - """A visitor that collects fields (parameters) from a dataclass.""" - - def __init__(self, dataclass_name: str): - self.dataclass_name = dataclass_name - self.parameters: Dict[str, Dict[str, str]] = {} - - def visit_ClassDef(self, node: cst.ClassDef) -> None: - """Visit class definitions to find the target dataclass.""" - - if node.name.value == self.dataclass_name: - body_statements = node.body.body - for index, field in enumerate(body_statements): - # Check if the statement is a simple statement (like a variable declaration) - if isinstance(field, cst.SimpleStatementLine): - for stmt in field.body: - # Check if it's an annotated assignment (typical for dataclass fields) - if isinstance(stmt, cst.AnnAssign) and isinstance(stmt.target, cst.Name): - param_name = stmt.target.value - param_type = cst.Module([]).code_for_node(stmt.annotation.annotation) - docstring = 
self._extract_docstring(body_statements, index) - self.parameters[param_name] = { - "type": param_type, - "docstring": docstring, - } - - @staticmethod - def _extract_docstring(body_statements: List[cst.CSTNode], field_index: int) -> str: - """Extract the docstring following a field definition.""" - if field_index + 1 < len(body_statements): - # Check if the next statement is a simple statement (like a string) - next_stmt = body_statements[field_index + 1] - if isinstance(next_stmt, cst.SimpleStatementLine): - for stmt in next_stmt.body: - # Check if the statement is a string expression (potential docstring) - if isinstance(stmt, cst.Expr) and isinstance(stmt.value, cst.SimpleString): - return stmt.value.evaluated_value.strip() - # No docstring found or there's no statement after the field - return "" - - -class ModulesCollector(cst.CSTVisitor): - """Visitor that maps type names to their defining modules.""" - - def __init__(self): - self.type_to_module = {} - - def visit_ClassDef(self, node: cst.ClassDef): - """Map class definitions to the current module.""" - self.type_to_module[node.name.value] = DEFAULT_MODULE - - def visit_ImportFrom(self, node: cst.ImportFrom): - """Map imported types to their modules.""" - if node.module: - module_name = node.module.value - for alias in node.names: - self.type_to_module[alias.name.value] = module_name - - -class ArgumentsCollector(cst.CSTVisitor): - """Collects existing argument names from a method.""" - - def __init__(self, method_name: str): - self.method_name = method_name - self.existing_args: Set[str] = set() - - def visit_FunctionDef(self, node: cst.FunctionDef) -> None: - if node.name.value == self.method_name: - self.existing_args.update( - param.name.value - for param in node.params.params + node.params.kwonly_params - if param.name.value != "self" - ) - - -#### TREE TRANSFORMERS - - -class AddParameters(cst.CSTTransformer): - """Updates a method by adding missing parameters and updating the docstring.""" - - def __init__(self, method_name: str, missing_params: Dict[str, Dict[str, str]]): - self.method_name = method_name - self.missing_params = missing_params - self.found_method = False - - def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef: - if original_node.name.value == self.method_name: - self.found_method = True - new_params = self._update_parameters(updated_node.params) - updated_body = self._update_docstring(updated_node.body) - return updated_node.with_changes(params=new_params, body=updated_body) - return updated_node - - def _update_parameters(self, params: cst.Parameters) -> cst.Parameters: - new_kwonly_params = list(params.kwonly_params) - existing_args = {param.name.value for param in params.params + params.kwonly_params} - - for param_name, param_info in self.missing_params.items(): - if param_name not in existing_args: - annotation = cst.Annotation(annotation=cst.parse_expression(param_info["type"])) - new_param = cst.Param( - name=cst.Name(param_name), - annotation=annotation, - default=cst.Name("None"), - ) - new_kwonly_params.append(new_param) - - return params.with_changes(kwonly_params=new_kwonly_params) - - def _update_docstring(self, body: cst.IndentedBlock) -> cst.IndentedBlock: - if not isinstance(body.body[0], cst.SimpleStatementLine) or not isinstance(body.body[0].body[0], cst.Expr): - return body - - docstring_expr = body.body[0].body[0] - if not isinstance(docstring_expr.value, cst.SimpleString): - return body - - docstring = 
docstring_expr.value.evaluated_value - updated_docstring = self._update_docstring_content(docstring) - new_docstring = cst.SimpleString(f'"""{updated_docstring}"""') - new_body = [body.body[0].with_changes(body=[docstring_expr.with_changes(value=new_docstring)])] + list( - body.body[1:] - ) - return body.with_changes(body=new_body) - - def _update_docstring_content(self, docstring: str) -> str: - docstring_lines = docstring.split("\n") - - # Step 1: find the right insertion index - args_index = next((i for i, line in enumerate(docstring_lines) if line.strip().lower() == "args:"), None) - # If there is no "Args:" section, insert it after the first section that is not empty and not a sub-section - if args_index is None: - insertion_index = next( - ( - i - for i, line in enumerate(docstring_lines) - if line.strip().lower() in ("returns:", "raises:", "examples:", "example:") - ), - len(docstring_lines), - ) - docstring_lines.insert(insertion_index, "Args:") - args_index = insertion_index - insertion_index += 1 - else: - # Find the next section (in this order: Returns, Raises, Example(s)) - next_section_index = next( - ( - i - for i, line in enumerate(docstring_lines) - if line.strip().lower() in ("returns:", "raises:", "example:", "examples:") - ), - None, - ) - if next_section_index is not None: - # If there's a blank line before "Returns:", insert before that blank line - if next_section_index > 0 and docstring_lines[next_section_index - 1].strip() == "": - insertion_index = next_section_index - 1 - else: - # If there's no blank line, insert at the "Returns:" line and add a blank line after insertion - insertion_index = next_section_index - docstring_lines.insert(insertion_index, "") - else: - # If there's no next section, insert at the end - insertion_index = len(docstring_lines) - - # Step 2: format the parameter docstring - # Calculate the base indentation - base_indentation = docstring_lines[args_index][ - : len(docstring_lines[args_index]) - len(docstring_lines[args_index].lstrip()) - ] - param_indentation = base_indentation + " " # Indent parameters under "Args:" - description_indentation = param_indentation + " " # Indent descriptions under parameter names - - param_docs = [] - for param_name, param_info in self.missing_params.items(): - param_type_str = param_info["type"].replace("Optional[", "").rstrip("]") - optional_str = "*optional*" if "Optional[" in param_info["type"] else "" - param_docstring = (param_info.get("docstring") or "").strip() - - # Clean up the docstring to remove extra spaces - param_docstring = " ".join(param_docstring.split()) - - # Prepare the parameter line - param_line = f"{param_indentation}{param_name} (`{param_type_str}`, {optional_str}):" - - # Wrap the parameter docstring - wrapped_description = textwrap.fill( - param_docstring, - width=119, - initial_indent=description_indentation, - subsequent_indent=description_indentation, - ) - - # Combine parameter line and description - if param_docstring: - param_doc = f"{param_line}\n{wrapped_description}" - else: - param_doc = param_line - - param_docs.append(param_doc) - - # Step 3: insert the new parameter docs into the docstring - docstring_lines[insertion_index:insertion_index] = param_docs - return "\n".join(docstring_lines) - - -class AddImports(cst.CSTTransformer): - """Transformer that adds import statements to the module.""" - - def __init__(self, imports_to_add: List[cst.BaseStatement]): - self.imports_to_add = imports_to_add - self.added = False - - def leave_Module(self, original_node: cst.Module, 
updated_node: cst.Module) -> cst.Module: - """Insert the import statements into the module.""" - # If imports were already added, don't add them again - if self.added: - return updated_node - insertion_index = 0 - # Find the index where to insert the imports: make sure the imports are inserted before any code and after all imports (not necessary, we can remove/simplify this part) - for idx, stmt in enumerate(updated_node.body): - if not isinstance(stmt, cst.SimpleStatementLine): - insertion_index = idx - break - elif not isinstance(stmt.body[0], (cst.Import, cst.ImportFrom)): - insertion_index = idx - break - # Insert the imports - new_body = ( - list(updated_node.body[:insertion_index]) - + list(self.imports_to_add) - + list(updated_node.body[insertion_index:]) - ) - self.added = True - return updated_node.with_changes(body=new_body) - - -#### UTILS - - -def check_missing_parameters( - inference_client_module: cst.Module, - parameters_module: cst.Module, - method_name: str, - parameter_type_name: str, -) -> Dict[str, Dict[str, str]]: - # Get parameters from the parameters module - params_collector = DataclassFieldCollector(parameter_type_name) - parameters_module.visit(params_collector) - parameters = params_collector.parameters - - # Get existing arguments from the method - method_argument_collector = ArgumentsCollector(method_name) - inference_client_module.visit(method_argument_collector) - existing_args = method_argument_collector.existing_args - missing_params = {k: v for k, v in parameters.items() if k not in existing_args} - return missing_params - - -def get_imports_to_add( - parameters: Dict[str, Dict[str, str]], - parameters_module: cst.Module, - inference_client_module: cst.Module, -) -> Dict[str, List[str]]: - """ - Get the needed imports for missing parameters. - - Args: - parameters (Dict[str, Dict[str, str]]): Dictionary of parameters with their type and docstring. - eg: {"function_to_apply": {"type": "ClassificationOutputTransform", "docstring": "Function to apply to the input."}} - parameters_module (cst.Module): The module where the parameters are defined. - inference_client_module (cst.Module): The module of the inference client. - - Returns: - Dict[str, List[str]]: A dictionary mapping modules to list of types to import. - eg: {"huggingface_hub.inference._generated.types": ["ClassificationOutputTransform"]} - """ - # Collect all type names from parameter annotations - types_to_import = set() - for param_info in parameters.values(): - types_to_import.update(_collect_type_hints_from_annotation(param_info["type"])) - - # Gather existing imports in the inference client module - context = CodemodContext() - gather_visitor = GatherImportsVisitor(context) - inference_client_module.visit(gather_visitor) - - # Map types to their defining modules in the parameters module - module_collector = ModulesCollector() - parameters_module.visit(module_collector) - - # Determine which imports are needed - needed_imports = {} - for type_name in types_to_import: - types_to_modules = module_collector.type_to_module - module = types_to_modules.get(type_name, DEFAULT_MODULE) - # Maybe no need to check that since the code formatter will handle duplicate imports? 
- if module not in gather_visitor.object_mapping or type_name not in gather_visitor.object_mapping[module]: - needed_imports.setdefault(module, []).append(type_name) - return needed_imports - - -def _generate_import_statements(import_dict: Dict[str, List[str]]) -> str: - """ - Generate import statements from a dictionary of needed imports. - - Args: - import_dict (Dict[str, List[str]]): Dictionary mapping modules to list of types to import. - eg: {"typing": ["List", "Dict"], "huggingface_hub.inference._generated.types": ["ClassificationOutputTransform"]} - - Returns: - str: The import statements as a string. - """ - import_statements = [] - for module, imports in import_dict.items(): - if imports: - import_list = ", ".join(imports) - import_statements.append(f"from {module} import {import_list}") - else: - import_statements.append(f"import {module}") - return "\n".join(import_statements) - - -# TODO: Needs to be improved, maybe using `typing.get_type_hints` instead (we gonna need to access the method though)? -def _collect_type_hints_from_annotation(annotation_str: str) -> Set[str]: - """ - Collect type hints from an annotation string. - - Args: - annotation_str (str): The annotation string. - - Returns: - Set[str]: A set of type hints. - """ - type_string = annotation_str.replace(" ", "") - builtin_types = {d for d in dir(builtins) if isinstance(getattr(builtins, d), type)} - types = re.findall(r"\w+|'[^']+'|\"[^\"]+\"", type_string) - extracted_types = {t.strip("\"'") for t in types if t.strip("\"'") not in builtin_types} - return extracted_types - - -def _get_parameter_type_name(method_name: str) -> Optional[str]: - file_path = INFERENCE_TYPES_PATH / f"{method_name}.py" - if not file_path.is_file(): - print(f"File not found: {file_path}") - return None - - content = file_path.read_text(encoding="utf-8") - match = PARAMETERS_DATACLASS_REGEX.search(content) - - return match.group(1) if match else None - - -def _parse_module_from_file(filepath: Path) -> Optional[cst.Module]: - try: - code = filepath.read_text(encoding="utf-8") - return cst.parse_module(code) - except FileNotFoundError: - print(f"File not found: {filepath}") - except cst.ParserSyntaxError as e: - print(f"Syntax error while parsing {filepath}: {e}") - return None - - -def _check_parameters(method_params: Dict[str, str], update: bool) -> NoReturn: - """ - Check if task methods have missing parameters and update the InferenceClient source code if needed. - - Args: - method_params (Dict[str, str]): Dictionary mapping method names to their parameters dataclass names. - update (bool): Whether to update the InferenceClient source code if missing parameters are found. 
- """ - merged_imports = defaultdict(set) - logs = [] - inference_client_filename = INFERENCE_CLIENT_FILE - # Read and parse the inference client module - inference_client_module = _parse_module_from_file(inference_client_filename) - modified_module = inference_client_module - has_changes = False - for method_name, parameter_type_name in method_params.items(): - parameters_filename = INFERENCE_TYPES_PATH / f"{method_name}.py" - - # Read and parse the parameters module - parameters_module = _parse_module_from_file(parameters_filename) - - # Check if the method has missing parameters - missing_params = check_missing_parameters(modified_module, parameters_module, method_name, parameter_type_name) - if not missing_params: - continue - if update: - ## Get missing imports to add - needed_imports = get_imports_to_add(missing_params, parameters_module, modified_module) - for module, imports_to_add in needed_imports.items(): - merged_imports[module].update(imports_to_add) - # Update method parameters and docstring - modified_module = modified_module.visit(AddParameters(method_name, missing_params)) - has_changes = True - else: - logs.append(f"āŒ Missing parameters found in `{method_name}`.") - - if has_changes: - if merged_imports: - import_statements = _generate_import_statements(merged_imports) - imports_to_add = cst.parse_module(import_statements).body - # Update inference client module with the missing imports - modified_module = modified_module.visit(AddImports(imports_to_add)) - # Format the updated source code - formatted_source_code = format_source_code(modified_module.code) - INFERENCE_CLIENT_FILE.write_text(formatted_source_code) - - if len(logs) > 0: - for log in logs: - print(log) - print( - "āŒ Mismatch between between parameters defined in tasks methods signature in " - "`./src/huggingface_hub/inference/_client.py` and parameters defined in " - "`./src/huggingface_hub/inference/_generated/types.py \n" - "Please run `make inference_update` or `python utils/generate_task_parameters.py --update" - ) - exit(1) - else: - if update: - print( - "āœ… InferenceClient source code has been updated in" - " `./src/huggingface_hub/inference/_client.py`.\n Please make sure the changes are" - " accurate and commit them." - ) - else: - print("āœ… All good!") - exit(0) - - -def update_inference_client(update: bool): - print(f"šŸ™ˆ Skipping the following tasks: {TASKS_TO_SKIP}") - # Get all tasks from the ./src/huggingface_hub/inference/_generated/types/ - tasks = set() - for file in INFERENCE_TYPES_PATH.glob("*.py"): - if file.stem not in TASKS_TO_SKIP: - tasks.add(file.stem) - - # Construct a mapping between method names and their parameters dataclass names - method_params = {} - for method_name, _ in inspect.getmembers(InferenceClient, predicate=inspect.isfunction): - if method_name.startswith("_") or method_name not in tasks: - continue - parameter_type_name = _get_parameter_type_name(method_name) - if parameter_type_name is not None: - method_params[method_name] = parameter_type_name - _check_parameters(method_params, update=update) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument( - "--update", - action="store_true", - help=("Whether to update `./src/huggingface_hub/inference/_client.py` if parameters are missing."), - ) - args = parser.parse_args() - update_inference_client(update=args.update)