Commit 1621604

fix(ci): ignore untyped import from transformers package

Iamhexi committed Sep 21, 2024
1 parent 646b1b8 commit 1621604
Showing 3 changed files with 56 additions and 26 deletions.
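The `transformers` package ships no bundled type stubs or `py.typed` marker, so mypy, added to CI by the pre-commit config below, reports its imports under the `import-untyped` error code. The commit narrows the existing bare `# type: ignore` comments to `# type: ignore[import-untyped]`: a bare ignore hides every mypy error on its line, while the bracketed form hides only the named code. A minimal sketch of the difference (the `pipeline` import is illustrative, not taken from this repository):

# Bare ignore: silences every mypy error reported on this line.
from transformers import pipeline  # type: ignore

# Scoped ignore: silences only the "missing library stubs or py.typed
# marker" diagnostic; any other error on the line still surfaces.
from transformers import pipeline  # type: ignore[import-untyped]

A project-wide alternative would be a per-module mypy override (`ignore_missing_imports = true` for `transformers.*`), but the inline comment keeps the suppression visible at each use site.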
.pre-commit-config.yaml (new file: 21 additions, 0 deletions)
@@ -0,0 +1,21 @@
+# See https://pre-commit.com for more information
+# See https://pre-commit.com/hooks.html for more hooks
+repos:
+  - repo: 'https://github.com/pre-commit/pre-commit-hooks'
+    rev: v3.2.0
+    hooks:
+      - id: trailing-whitespace
+      - id: end-of-file-fixer
+      - id: check-yaml
+      - id: check-added-large-files
+  - repo: 'https://github.com/astral-sh/ruff-pre-commit'
+    rev: v0.6.6
+    hooks:
+      - id: ruff
+        args:
+          - '--fix'
+      - id: ruff-format
+  - repo: 'https://github.com/pre-commit/mirrors-mypy'
+    rev: v1.11.2
+    hooks:
+      - id: mypy
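With this file committed, `pre-commit install` registers the hooks in the local git checkout and `pre-commit run --all-files` applies them to the entire tree; the ruff-format pass is presumably what produced the purely stylistic changes (double quotes, trailing commas, line wrapping) in the two Python files below. Pinned `rev` values keep hook versions reproducible, and `pre-commit autoupdate` bumps them when desired.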
knowledge_verificator/nli.py (25 additions, 15 deletions)
@@ -3,15 +3,16 @@
 from enum import Enum
 import warnings
 import logging
-from transformers import AutoTokenizer, AutoModelForSequenceClassification  # type: ignore
+from transformers import AutoTokenizer, AutoModelForSequenceClassification  # type: ignore[import-untyped]
 import torch


 class Relation(Enum):
     """Possible relations between premise and hypothesis."""
-    ENTAILMENT = 'entailment'
-    NEUTRAL = 'neutral'
-    CONTRADICTION = 'contradiction'
+
+    ENTAILMENT = "entailment"
+    NEUTRAL = "neutral"
+    CONTRADICTION = "contradiction"


 def infer(
@@ -48,7 +49,9 @@ def infer(
     # hg_model_hub_name = "ynie/xlnet-large-cased-snli_mnli_fever_anli_R1_R2_R3-nli"

     logging.getLogger("transformers").setLevel(logging.ERROR)
-    warnings.filterwarnings('ignore', message="`clean_up_tokenization_spaces` was not set.")
+    warnings.filterwarnings(
+        "ignore", message="`clean_up_tokenization_spaces` was not set."
+    )

     tokenizer = AutoTokenizer.from_pretrained(hg_model_hub_name)
     model = AutoModelForSequenceClassification.from_pretrained(hg_model_hub_name)
@@ -61,39 +64,46 @@ def infer(
         truncation=True,
     )

-    input_ids = torch.Tensor(tokenized_input_seq_pair['input_ids']).long().unsqueeze(0)
+    input_ids = torch.Tensor(tokenized_input_seq_pair["input_ids"]).long().unsqueeze(0)
     # remember bart doesn't have 'token_type_ids', remove the line below if you are using bart.
-    token_type_ids = torch.Tensor(tokenized_input_seq_pair['token_type_ids']).long().unsqueeze(0)
-    attention_mask = torch.Tensor(tokenized_input_seq_pair['attention_mask']).long().unsqueeze(0)
+    token_type_ids = (
+        torch.Tensor(tokenized_input_seq_pair["token_type_ids"]).long().unsqueeze(0)
+    )
+    attention_mask = (
+        torch.Tensor(tokenized_input_seq_pair["attention_mask"]).long().unsqueeze(0)
+    )

     outputs = model(
         input_ids,
         attention_mask=attention_mask,
         token_type_ids=token_type_ids,
-        labels=None
+        labels=None,
     )

-    predicted_probability = torch.softmax(outputs[0], dim=1)[0].tolist()  # batch_size only one
+    predicted_probability = torch.softmax(outputs[0], dim=1)[
+        0
+    ].tolist()  # batch_size only one

     entailment = round(predicted_probability[0], precision)
     neutral = round(predicted_probability[1], precision)

     return {
-        Relation.ENTAILMENT : entailment,
-        Relation.NEUTRAL : neutral,
-        Relation.CONTRADICTION : round(1. - entailment - neutral, precision)
+        Relation.ENTAILMENT: entailment,
+        Relation.NEUTRAL: neutral,
+        Relation.CONTRADICTION: round(1.0 - entailment - neutral, precision),
     }


 def infer_relation(
     premise: str,
     hypothesis: str,
 ) -> Relation:
     """Infer the most probable type of relationship between `premise` and `hypothesis`."""
     inference = infer(premise=premise, hypothesis=hypothesis, precision=10)

-    max_probability = 0.
+    max_probability = 0.0
     most_probable = Relation.CONTRADICTION
-    for (relation, probability) in inference.items():
+    for relation, probability in inference.items():
         if probability > max_probability:
             max_probability = probability
             most_probable = relation
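A hedged usage sketch of the two functions reformatted above; the example sentences and the expected outcome are illustrative, not taken from the repository:

from knowledge_verificator.nli import Relation, infer, infer_relation

# `infer` returns a dict mapping every Relation member to its probability,
# rounded to `precision` decimal places.
scores = infer(
    premise="The red apple is on a tree.",
    hypothesis="An apple hangs from a tree.",
    precision=4,
)
print(scores[Relation.ENTAILMENT], scores[Relation.CONTRADICTION])

# `infer_relation` collapses that dict to the single most probable relation.
relation = infer_relation(
    premise="The red apple is on a tree.",
    hypothesis="An apple hangs from a tree.",
)
print(relation)  # Relation.ENTAILMENT, if the model agrees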
tests/test_qg.py (10 additions, 11 deletions)
@@ -2,9 +2,10 @@

 import pytest

-from transformers import set_seed
+from transformers import set_seed  # type: ignore[import-untyped]
 from knowledge_verificator.qg import QuestionGeneration

+
 @pytest.fixture
 def qg():
     """Provide non-deterministically initialized instance of the `QuestionGeneration` class."""
@@ -13,19 +14,17 @@ def qg():
     return question_generation


-@pytest.mark.parametrize('question,answer,context', (
-    ('Where is the red apple located?', 'Tree', 'The red apple is on a tree.'),
-))
+@pytest.mark.parametrize(
+    "question,answer,context",
+    (("Where is the red apple located?", "Tree", "The red apple is on a tree."),),
+)
 def test_basic_question_generation(question: str, answer: str, context: str, qg):
     """Test if generating in very simple case works as expected."""
-    output = qg.generate(
-        answer=answer,
-        context=context
-    )
+    output = qg.generate(answer=answer, context=context)
     expected = {
-        'question': question,
-        'answer': answer,
-        'context': context,
+        "question": question,
+        "answer": answer,
+        "context": context,
     }

     assert output == expected
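With the scoped ignore in place, the mypy hook type-checks this module without tripping over the untyped `transformers` import, and the test still runs as usual, e.g. `pytest tests/test_qg.py`.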
