fix(ci): format with ruff
Iamhexi committed Sep 21, 2024
1 parent 0f8e67a commit 5c29dec
Showing 4 changed files with 19 additions and 32 deletions.
15 changes: 0 additions & 15 deletions knowledge_verificator/main.py
@@ -1,19 +1,4 @@
"""Main module with CLI definition."""
<<<<<<< HEAD
||||||| parent of 0627c06 (feat(qg): implement basic QG module)
from knowledge_verificator.nli import infer_relation

if __name__ == '__main__':
premise = input("Premise: ")
hypothesis = input("Hypothesis: ")
print(
infer_relation(
premise=premise,
hypothesis=hypothesis
).value
)
=======

if __name__ == '__main__':
print('Currently nothing happens here.')
>>>>>>> 0627c06 (feat(qg): implement basic QG module)
8 changes: 5 additions & 3 deletions knowledge_verificator/nli.py
@@ -3,8 +3,10 @@
 from enum import Enum
 import warnings
 import logging
-from transformers import AutoTokenizer, AutoModelForSequenceClassification  # type: ignore[import-untyped]
-from transformers import AutoTokenizer, AutoModelForSequenceClassification  # type: ignore
+from transformers import (  # type: ignore[import-untyped]
+    AutoTokenizer,
+    AutoModelForSequenceClassification,
+)
 import torch
 
 
@@ -77,8 +79,8 @@ def infer(
             torch.Tensor(tokenized_input_seq_pair['input_ids']).long().unsqueeze(0)
         )
 
-        # remember bart doesn't have 'token_type_ids', remove the line below if you are using bart.
         token_type_ids = None
+        # `bart` model does not have `token_type_ids`.
         if self._hg_model_hub_name != self._available_models['bart']:
             token_type_ids = (
                 torch.Tensor(tokenized_input_seq_pair['token_type_ids'])
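
For context on the comment change above: some tokenizers return no `token_type_ids` at all (the code attributes this to the `bart` model), so that tensor is built only for models that provide it. Below is a minimal sketch of an equivalent guard keyed off the tokenizer output instead of the model name; the checkpoint name is hypothetical and the snippet is not the repository's code.

import torch
from transformers import AutoTokenizer

# Hypothetical checkpoint name, for illustration only.
tokenizer = AutoTokenizer.from_pretrained('some-nli-checkpoint')

pair = tokenizer.encode_plus('A premise.', 'A hypothesis.', truncation=True)
input_ids = torch.Tensor(pair['input_ids']).long().unsqueeze(0)

# Build `token_type_ids` only if the tokenizer actually produced them;
# BART-style tokenizers omit this field, as the updated comment notes.
token_type_ids = None
if 'token_type_ids' in pair:
    token_type_ids = torch.Tensor(pair['token_type_ids']).long().unsqueeze(0)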
18 changes: 9 additions & 9 deletions knowledge_verificator/qg.py
@@ -4,20 +4,20 @@
 from transformers import T5Tokenizer, T5ForConditionalGeneration  # type: ignore[import-untyped]
 
 
-class QuestionGeneration:
+class QuestionGeneration:  # pylint: disable=too-few-public-methods
     """Class for generating question based on supplied context."""
 
     def __init__(self) -> None:
         self.trained_model_path = (
-            "ZhangCheng/T5-Base-Fine-Tuned-for-Question-Generation"
+            'ZhangCheng/T5-Base-Fine-Tuned-for-Question-Generation'
         )
         self.trained_tokenizer_path = (
-            "ZhangCheng/T5-Base-Fine-Tuned-for-Question-Generation"
+            'ZhangCheng/T5-Base-Fine-Tuned-for-Question-Generation'
         )
 
         self.model = T5ForConditionalGeneration.from_pretrained(self.trained_model_path)
         self.tokenizer = T5Tokenizer.from_pretrained(self.trained_tokenizer_path)
-        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
         self.model = self.model.to(self.device)
         self.model.eval()
 
@@ -32,14 +32,14 @@ def generate(self, answer: str, context: str) -> dict[str, str]:
         Returns:
             dict[str, str]: Dictionary with a generated question, and a provided answer and context.
         """
-        input_text = f"<answer> {answer} <context> {context} "
-        encoding = self.tokenizer.encode_plus(input_text, return_tensors="pt")
-        input_ids = encoding["input_ids"].to(self.device)
-        attention_mask = encoding["attention_mask"].to(self.device)
+        input_text = f'<answer> {answer} <context> {context} '
+        encoding = self.tokenizer.encode_plus(input_text, return_tensors='pt')
+        input_ids = encoding['input_ids'].to(self.device)
+        attention_mask = encoding['attention_mask'].to(self.device)
         outputs = self.model.generate(
             input_ids=input_ids, attention_mask=attention_mask
         )
         question = self.tokenizer.decode(
             outputs[0], skip_special_tokens=True, clean_up_tokenization_spaces=True
         )
-        return {"question": question, "answer": answer, "context": context}
+        return {'question': question, 'answer': answer, 'context': context}
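
As a usage note, a minimal sketch of how the class above can be called (the wrapper below is illustrative and not part of this commit; the test file that follows exercises the same interface):

from knowledge_verificator.qg import QuestionGeneration

# Loads the fine-tuned T5 checkpoint (downloaded on first use), then generates
# a question whose expected answer is `answer`, grounded in `context`.
qg = QuestionGeneration()
result = qg.generate(
    answer='Tree',
    context='The red apple is on a tree.',
)
print(result['question'])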
10 changes: 5 additions & 5 deletions tests/test_qg.py
@@ -15,16 +15,16 @@ def qg():
 
 
 @pytest.mark.parametrize(
-    "question,answer,context",
-    (("Where is the red apple located?", "Tree", "The red apple is on a tree."),),
+    'question,answer,context',
+    (('Where is the red apple located?', 'Tree', 'The red apple is on a tree.'),),
 )
 def test_basic_question_generation(question: str, answer: str, context: str, qg):
     """Test if generating in very simple case works as expected."""
     output = qg.generate(answer=answer, context=context)
     expected = {
-        "question": question,
-        "answer": answer,
-        "context": context,
+        'question': question,
+        'answer': answer,
+        'context': context,
     }
 
     assert output == expected
