From 5c29dec4ddf3902fca68c8deb06913197f094c35 Mon Sep 17 00:00:00 2001
From: Iamhexi
Date: Sat, 21 Sep 2024 16:34:41 +0000
Subject: [PATCH] fix(ci): format with ruff

---
 knowledge_verificator/main.py | 15 ---------------
 knowledge_verificator/nli.py  |  8 +++++---
 knowledge_verificator/qg.py   | 18 +++++++++---------
 tests/test_qg.py              | 10 +++++-----
 4 files changed, 19 insertions(+), 32 deletions(-)

diff --git a/knowledge_verificator/main.py b/knowledge_verificator/main.py
index 6ebd989..88a0507 100755
--- a/knowledge_verificator/main.py
+++ b/knowledge_verificator/main.py
@@ -1,19 +1,4 @@
 """Main module with CLI definition."""
 
-<<<<<<< HEAD
-||||||| parent of 0627c06 (feat(qg): implement basic QG module)
-from knowledge_verificator.nli import infer_relation
-
-if __name__ == '__main__':
-    premise = input("Premise: ")
-    hypothesis = input("Hypothesis: ")
-    print(
-        infer_relation(
-            premise=premise,
-            hypothesis=hypothesis
-        ).value
-    )
-=======
 if __name__ == '__main__':
     print('Currently nothing happens here.')
->>>>>>> 0627c06 (feat(qg): implement basic QG module)

diff --git a/knowledge_verificator/nli.py b/knowledge_verificator/nli.py
index a47ebb1..713dbde 100644
--- a/knowledge_verificator/nli.py
+++ b/knowledge_verificator/nli.py
@@ -3,8 +3,10 @@
 from enum import Enum
 import warnings
 import logging
-from transformers import AutoTokenizer, AutoModelForSequenceClassification # type: ignore[import-untyped]
-from transformers import AutoTokenizer, AutoModelForSequenceClassification # type: ignore
+from transformers import (  # type: ignore[import-untyped]
+    AutoTokenizer,
+    AutoModelForSequenceClassification,
+)
 import torch
 
 
@@ -77,8 +79,8 @@ def infer(
             torch.Tensor(tokenized_input_seq_pair['input_ids']).long().unsqueeze(0)
         )
 
-        # remember bart doesn't have 'token_type_ids', remove the line below if you are using bart.
         token_type_ids = None
+        # `bart` model does not have `token_type_ids`.
         if self._hg_model_hub_name != self._available_models['bart']:
             token_type_ids = (
                 torch.Tensor(tokenized_input_seq_pair['token_type_ids'])

diff --git a/knowledge_verificator/qg.py b/knowledge_verificator/qg.py
index aed6ee4..5fce4ca 100644
--- a/knowledge_verificator/qg.py
+++ b/knowledge_verificator/qg.py
@@ -4,20 +4,20 @@
 from transformers import T5Tokenizer, T5ForConditionalGeneration  # type: ignore[import-untyped]
 
 
-class QuestionGeneration:
+class QuestionGeneration:  # pylint: disable=too-few-public-methods
     """Class for generating question based on supplied context."""
 
     def __init__(self) -> None:
         self.trained_model_path = (
-            "ZhangCheng/T5-Base-Fine-Tuned-for-Question-Generation"
+            'ZhangCheng/T5-Base-Fine-Tuned-for-Question-Generation'
         )
         self.trained_tokenizer_path = (
-            "ZhangCheng/T5-Base-Fine-Tuned-for-Question-Generation"
+            'ZhangCheng/T5-Base-Fine-Tuned-for-Question-Generation'
        )
         self.model = T5ForConditionalGeneration.from_pretrained(self.trained_model_path)
         self.tokenizer = T5Tokenizer.from_pretrained(self.trained_tokenizer_path)
-        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
         self.model = self.model.to(self.device)
         self.model.eval()
 
@@ -32,14 +32,14 @@ def generate(self, answer: str, context: str) -> dict[str, str]:
         Returns:
             dict[str, str]: Dictionary with a generated question, and a provided answer and context.
""" - input_text = f" {answer} {context} " - encoding = self.tokenizer.encode_plus(input_text, return_tensors="pt") - input_ids = encoding["input_ids"].to(self.device) - attention_mask = encoding["attention_mask"].to(self.device) + input_text = f' {answer} {context} ' + encoding = self.tokenizer.encode_plus(input_text, return_tensors='pt') + input_ids = encoding['input_ids'].to(self.device) + attention_mask = encoding['attention_mask'].to(self.device) outputs = self.model.generate( input_ids=input_ids, attention_mask=attention_mask ) question = self.tokenizer.decode( outputs[0], skip_special_tokens=True, clean_up_tokenization_spaces=True ) - return {"question": question, "answer": answer, "context": context} + return {'question': question, 'answer': answer, 'context': context} diff --git a/tests/test_qg.py b/tests/test_qg.py index fe5c55b..acff7a3 100644 --- a/tests/test_qg.py +++ b/tests/test_qg.py @@ -15,16 +15,16 @@ def qg(): @pytest.mark.parametrize( - "question,answer,context", - (("Where is the red apple located?", "Tree", "The red apple is on a tree."),), + 'question,answer,context', + (('Where is the red apple located?', 'Tree', 'The red apple is on a tree.'),), ) def test_basic_question_generation(question: str, answer: str, context: str, qg): """Test if generating in very simple case works as expected.""" output = qg.generate(answer=answer, context=context) expected = { - "question": question, - "answer": answer, - "context": context, + 'question': question, + 'answer': answer, + 'context': context, } assert output == expected