Skip to content

Commit

Permalink
feat(answer_chooser): implement basic chooser with AnswerChooser
Browse files Browse the repository at this point in the history
  • Loading branch information
Iamhexi committed Sep 21, 2024
1 parent bf7f221 commit cfd37ac
Show file tree
Hide file tree
Showing 4 changed files with 236 additions and 28 deletions.
85 changes: 85 additions & 0 deletions knowledge_verificator/answer_chooser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
"""Module with AnswerChooser, which finds a best candidate for an answer in a paragraph."""

import random
import nltk
from nltk.corpus import wordnet

# FIXME: Write docstrings.


class AnswerChooser:
def __init__(self) -> None:
nltk.download('wordnet')
nltk.download('stopwords')
nltk.download('punkt_tab')

def remove_stopwords(self, text: str) -> str:
"""
Remove stopwords from a string using NLTK.
"""
# Get the stopwords from the NLTK stopwords corpus
stopwords = set(nltk.corpus.stopwords.words('english'))

# Tokenize the text into words
words = nltk.word_tokenize(text)

# Remove the stopwords from the words
filtered_words = [
word for word in words if word.lower() not in stopwords
]

# Join the filtered words back into a string
cleaned_text = ' '.join(filtered_words)

# Return the cleaned text
return cleaned_text

def santize(self, word: str) -> str:
"""Convert to lowercase and remove any punctuation mark."""
word = word.strip()
word = word.lower()
to_remove = ('.', ',', '?', '!', '-', '_', '/', '(', ')', "'")
for punctuation_mark in to_remove:
word = word.replace(punctuation_mark, '')
return word

def find_part_of_speech(self, word: str) -> str:
"""
Determine the part of speech of a word using WordNet.
"""
# Look up the word in WordNet
word = self.santize(word=word)
synsets = wordnet.synsets(word)

# If the word is not found, return 'n/a'
if not synsets:
return 'n/a'

# Get the first synset and determine the part of speech
synset = synsets[0]
pos = synset.pos()

# Map WordNet POS tags to more common POS tags
if pos == 'a':
return 'adjective'
elif pos == 'n':
return 'noun'
elif pos == 'r':
return 'adverb'
elif pos == 'v':
return 'verb'
else:
return 'n/a'

def choose_answer(self, paragraph: str) -> str:
paragraph = self.remove_stopwords(paragraph)
words = paragraph.split(' ')
# FIXME: Refactor not to use three times the same function...
words = [
self.santize(word)
for word in words
if self.santize(word)
and self.find_part_of_speech(self.santize(word)) == 'n/a'
]

return random.choice(words)

Check warning on line 85 in knowledge_verificator/answer_chooser.py

View check run for this annotation

Codacy Production / Codacy Static Code Analysis

knowledge_verificator/answer_chooser.py#L85

Standard pseudo-random generators are not suitable for security/cryptographic purposes.
20 changes: 12 additions & 8 deletions knowledge_verificator/main.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,32 @@
"""Main module with CLI definition."""

import random
import logging

from logging import Logger

from knowledge_verificator.answer_chooser import AnswerChooser
from knowledge_verificator.nli import NaturalLanguageInference, Relation
from knowledge_verificator.qg import QuestionGeneration


if __name__ == '__main__':
# TODO: Take logger level from CLI or config file.
logger = Logger('main_logger')
# TODO: Take logger level from CLI or config file.
# Set logging to standard output (handle 2.) stream.
logging_handler = logging.StreamHandler()
logging_handler.setLevel(logging.DEBUG)
logger.addHandler(logging_handler)

chooser = AnswerChooser()
qg_module = QuestionGeneration()

paragraph = input('Enter a paragraph you would like to learn: ')
logger.debug('Loaded the following paragraph:\n %s', paragraph)

qg_module = QuestionGeneration()

# Answer is a randomly choosen word.
words = paragraph.split(' ')
chosen_answer = random.choice(words)
chosen_answer = chooser.choose_answer(paragraph=paragraph)
# words = paragraph.split(' ')
# chosen_answer = random.choice(words)
logger.debug(
'The `%s` has been chosen as the answer, based on which the question will be generated.',
chosen_answer,
Expand All @@ -40,7 +42,9 @@
'Question Generation module has supplied the question: %s', question
)

user_answer = input(f'Answer the question. {question}')
user_answer = input(
f'Answer the question with full sentence. {question} \n Your answer.: '
)

nli_module = NaturalLanguageInference()
relation = nli_module.infer_relation(
Expand All @@ -53,6 +57,6 @@
case Relation.CONTRADICTION:
feedback = f'wrong. Correct answer is {chosen_answer}'
case Relation.NEUTRAL:
feedback = 'not directly answer the posed question.'
feedback = 'not directly associated with the posed question.'

print(f'Your answer is {feedback}')
157 changes: 137 additions & 20 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ transformers = "^4.44.2"
torch = "^2.4.1"
pylint-pytest = "^1.1.8"
sentencepiece = "^0.2.0"
nltk = "^3.9.1"
rich = "^13.8.1"

[tool.poetry.group.test]

Expand Down

0 comments on commit cfd37ac

Please sign in to comment.