diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 0000000..930df99
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,21 @@
+# See https://pre-commit.com for more information
+# See https://pre-commit.com/hooks.html for more hooks
+repos:
+  - repo: 'https://github.com/pre-commit/pre-commit-hooks'
+    rev: v3.2.0
+    hooks:
+      - id: trailing-whitespace
+      - id: end-of-file-fixer
+      - id: check-yaml
+      - id: check-added-large-files
+  - repo: 'https://github.com/astral-sh/ruff-pre-commit'
+    rev: v0.6.6
+    hooks:
+      - id: ruff
+        args:
+          - '--fix'
+      - id: ruff-format
+  - repo: 'https://github.com/pre-commit/mirrors-mypy'
+    rev: v1.11.2
+    hooks:
+      - id: mypy
diff --git a/knowledge_verificator/nli.py b/knowledge_verificator/nli.py
index 28c541f..95abcb1 100644
--- a/knowledge_verificator/nli.py
+++ b/knowledge_verificator/nli.py
@@ -3,15 +3,16 @@
 from enum import Enum
 import warnings
 import logging
-from transformers import AutoTokenizer, AutoModelForSequenceClassification # type: ignore
+from transformers import AutoTokenizer, AutoModelForSequenceClassification  # type: ignore[import-untyped]
 import torch
 
 
 class Relation(Enum):
     """Possible relations between premise and hypothesis."""
-    ENTAILMENT = 'entailment'
-    NEUTRAL = 'neutral'
-    CONTRADICTION = 'contradiction'
+
+    ENTAILMENT = "entailment"
+    NEUTRAL = "neutral"
+    CONTRADICTION = "contradiction"
 
 
 def infer(
@@ -48,7 +49,9 @@ def infer(
 
     # hg_model_hub_name = "ynie/xlnet-large-cased-snli_mnli_fever_anli_R1_R2_R3-nli"
     logging.getLogger("transformers").setLevel(logging.ERROR)
-    warnings.filterwarnings('ignore', message="`clean_up_tokenization_spaces` was not set.")
+    warnings.filterwarnings(
+        "ignore", message="`clean_up_tokenization_spaces` was not set."
+    )
 
     tokenizer = AutoTokenizer.from_pretrained(hg_model_hub_name)
     model = AutoModelForSequenceClassification.from_pretrained(hg_model_hub_name)
@@ -61,29 +64,36 @@ def infer(
         truncation=True,
     )
 
-    input_ids = torch.Tensor(tokenized_input_seq_pair['input_ids']).long().unsqueeze(0)
+    input_ids = torch.Tensor(tokenized_input_seq_pair["input_ids"]).long().unsqueeze(0)
     # remember bart doesn't have 'token_type_ids', remove the line below if you are using bart.
-    token_type_ids = torch.Tensor(tokenized_input_seq_pair['token_type_ids']).long().unsqueeze(0)
-    attention_mask = torch.Tensor(tokenized_input_seq_pair['attention_mask']).long().unsqueeze(0)
+    token_type_ids = (
+        torch.Tensor(tokenized_input_seq_pair["token_type_ids"]).long().unsqueeze(0)
+    )
+    attention_mask = (
+        torch.Tensor(tokenized_input_seq_pair["attention_mask"]).long().unsqueeze(0)
+    )
 
     outputs = model(
         input_ids,
         attention_mask=attention_mask,
         token_type_ids=token_type_ids,
-        labels=None
+        labels=None,
     )
 
-    predicted_probability = torch.softmax(outputs[0], dim=1)[0].tolist() # batch_size only one
+    predicted_probability = torch.softmax(outputs[0], dim=1)[
+        0
+    ].tolist()  # batch_size only one
 
     entailment = round(predicted_probability[0], precision)
     neutral = round(predicted_probability[1], precision)
 
     return {
-        Relation.ENTAILMENT : entailment,
-        Relation.NEUTRAL : neutral,
-        Relation.CONTRADICTION : round(1. - entailment - neutral, precision)
+        Relation.ENTAILMENT: entailment,
+        Relation.NEUTRAL: neutral,
+        Relation.CONTRADICTION: round(1.0 - entailment - neutral, precision),
     }
 
+
 def infer_relation(
     premise: str,
     hypothesis: str,
@@ -91,9 +101,9 @@
     """Infer the most probable type of relationship between `premise` and `hypothesis`."""
     inference = infer(premise=premise, hypothesis=hypothesis, precision=10)
 
-    max_probability = 0.
+    max_probability = 0.0
     most_probable = Relation.CONTRADICTION
-    for (relation, probability) in inference.items():
+    for relation, probability in inference.items():
         if probability > max_probability:
             max_probability = probability
             most_probable = relation
diff --git a/tests/test_qg.py b/tests/test_qg.py
index aef44f4..fe5c55b 100644
--- a/tests/test_qg.py
+++ b/tests/test_qg.py
@@ -2,9 +2,10 @@
 
 import pytest
 
-from transformers import set_seed
+from transformers import set_seed  # type: ignore[import-untyped]
 from knowledge_verificator.qg import QuestionGeneration
 
+
 @pytest.fixture
 def qg():
     """Provide non-deterministically initialized instance of the `QuestionGeneration` class."""
@@ -13,19 +14,17 @@
     return question_generation
 
 
-@pytest.mark.parametrize('question,answer,context', (
-    ('Where is the red apple located?', 'Tree', 'The red apple is on a tree.'),
-))
+@pytest.mark.parametrize(
+    "question,answer,context",
+    (("Where is the red apple located?", "Tree", "The red apple is on a tree."),),
+)
 def test_basic_question_generation(question: str, answer: str, context: str, qg):
     """Test if generating in very simple case works as expected."""
-    output = qg.generate(
-        answer=answer,
-        context=context
-    )
+    output = qg.generate(answer=answer, context=context)
     expected = {
-        'question': question,
-        'answer': answer,
-        'context': context,
+        "question": question,
+        "answer": answer,
+        "context": context,
     }
     assert output == expected
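
With this config in place, the hooks can be enabled locally with `pre-commit install` (after `pip install pre-commit`) and run against the whole tree with `pre-commit run --all-files`. pre-commit checks each hook repository out at its pinned `rev`, so ruff v0.6.6, mypy v1.11.2, and the basic pre-commit-hooks run at exactly those versions for every contributor.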
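
For context on the `nli.py` changes, a minimal usage sketch of the two entry points touched by this diff. This is an illustration only: the premise/hypothesis strings are made up, the first call downloads the checkpoint named by `hg_model_hub_name` inside `infer`, and the two-argument call to `infer_relation` assumes its remaining parameters (not shown in this hunk) have defaults.

    from knowledge_verificator.nli import Relation, infer, infer_relation

    # Full distribution over the three relations; `infer` rounds ENTAILMENT
    # and NEUTRAL to `precision` digits and derives CONTRADICTION as
    # 1.0 - entailment - neutral.
    distribution = infer(
        premise="A man is playing a guitar on stage.",  # hypothetical example
        hypothesis="A person is performing music.",
        precision=2,
    )
    for relation, probability in distribution.items():
        print(relation.value, probability)

    # Single most probable relation.
    print(infer_relation(
        premise="A man is playing a guitar on stage.",
        hypothesis="A person is performing music.",
    ))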