Skip to content

Commit

Permalink
Change structure of corpus module
Browse files Browse the repository at this point in the history
  • Loading branch information
gunthercox committed Oct 6, 2018
1 parent 486b25f commit f4019b4
Show file tree
Hide file tree
Showing 4 changed files with 107 additions and 128 deletions.
5 changes: 0 additions & 5 deletions chatterbot_corpus/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,8 @@
"""
A machine readable multilingual dialog corpus.
"""
from .corpus import Corpus

__version__ = '1.1.4'
__author__ = 'Gunther Cox'
__email__ = 'gunthercx@gmail.com'
__url__ = 'https://github.com/gunthercox/chatterbot-corpus'

__all__ = (
'Corpus',
)
107 changes: 52 additions & 55 deletions chatterbot_corpus/corpus.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@


DIALOG_MAXIMUM_CHARACTER_LENGTH = 400
CURRENT_DIRECTORY = os.path.dirname(os.path.abspath(__file__))
DATA_DIRECTORY = os.path.join(CURRENT_DIRECTORY, 'data')


class CorpusObject(list):
Expand All @@ -26,76 +28,71 @@ def __init__(self, *args, **kwargs):
self.categories = []


class Corpus(object):
def get_file_path(dotted_path, extension='json'):
"""
Reads a dotted file path and returns the file path.
"""
# If the operating system's file path separator character is in the string
if os.sep in dotted_path or '/' in dotted_path:
# Assume the path is a valid file path
return dotted_path

def __init__(self):
current_directory = os.path.dirname(os.path.abspath(__file__))
self.data_directory = os.path.join(current_directory, 'data')
parts = dotted_path.split('.')
if parts[0] == 'chatterbot':
parts.pop(0)
parts[0] = DATA_DIRECTORY

def get_file_path(self, dotted_path, extension='json'):
"""
Reads a dotted file path and returns the file path.
"""
corpus_path = os.path.join(*parts)

# If the operating system's file path separator character is in the string
if os.sep in dotted_path or '/' in dotted_path:
# Assume the path is a valid file path
return dotted_path
if os.path.exists(corpus_path + '.{}'.format(extension)):
corpus_path += '.{}'.format(extension)

parts = dotted_path.split('.')
if parts[0] == 'chatterbot':
parts.pop(0)
parts[0] = self.data_directory
return corpus_path

corpus_path = os.path.join(*parts)

if os.path.exists(corpus_path + '.{}'.format(extension)):
corpus_path += '.{}'.format(extension)
def read_corpus(file_name):
"""
Read and return the data from a corpus YAML file.
"""
with io.open(file_name, encoding='utf-8') as data_file:
return yaml.load(data_file)

return corpus_path

def read_corpus(self, file_name):
"""
Read and return the data from a corpus YAML file.
"""
with io.open(file_name, encoding='utf-8') as data_file:
data = yaml.load(data_file)
return data
def list_corpus_files(dotted_path):
"""
Return a list of file paths to each data file in
the specified corpus.
"""
CORPUS_EXTENSION = 'yml'

def list_corpus_files(self, dotted_path):
"""
Return a list of file paths to each data file in
the specified corpus.
"""
CORPUS_EXTENSION = 'yml'
corpus_path = get_file_path(dotted_path, extension=CORPUS_EXTENSION)
paths = []

corpus_path = self.get_file_path(dotted_path, extension=CORPUS_EXTENSION)
paths = []
if os.path.isdir(corpus_path):
paths = glob.glob(corpus_path + '/**/*.' + CORPUS_EXTENSION, recursive=True)
else:
paths.append(corpus_path)

if os.path.isdir(corpus_path):
paths = glob.glob(corpus_path + '/**/*.' + CORPUS_EXTENSION, recursive=True)
else:
paths.append(corpus_path)
paths.sort()
return paths

paths.sort()
return paths

def load_corpus(self, dotted_path):
"""
Return the data contained within a specified corpus.
"""
data_file_paths = self.list_corpus_files(dotted_path)
def load_corpus(dotted_path):
"""
Return the data contained within a specified corpus.
"""
data_file_paths = list_corpus_files(dotted_path)

corpora = []
corpora = []

for file_path in data_file_paths:
corpus = CorpusObject()
corpus_data = self.read_corpus(file_path)
for file_path in data_file_paths:
corpus = CorpusObject()
corpus_data = read_corpus(file_path)

conversations = corpus_data.get('conversations', [])
corpus.categories = corpus_data.get('categories', [])
corpus.extend(conversations)
conversations = corpus_data.get('conversations', [])
corpus.categories = corpus_data.get('categories', [])
corpus.extend(conversations)

corpora.append(corpus)
corpora.append(corpus)

return corpora
return corpora
113 changes: 52 additions & 61 deletions tests/test_corpus.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,137 +4,128 @@
import os
import io
from unittest import TestCase
from chatterbot_corpus.corpus import Corpus
from chatterbot_corpus import corpus


class CorpusUtilsTestCase(TestCase):

def setUp(self):
self.corpus = Corpus()

def test_get_file_path(self):
"""
Test that a dotted path is properly converted to a file address.
"""
path = self.corpus.get_file_path('chatterbot.corpus.english')
path = corpus.get_file_path('chatterbot.corpus.english')
self.assertIn(
os.path.join('chatterbot_corpus', 'data', 'english'),
path
)

def test_read_english_corpus(self):
corpus_path = os.path.join(
self.corpus.data_directory,
corpus.DATA_DIRECTORY,
'english', 'conversations.yml'
)
data = self.corpus.read_corpus(corpus_path)
data = corpus.read_corpus(corpus_path)
self.assertIn('conversations', data)

def test_list_english_corpus_files(self):
data_files = self.corpus.list_corpus_files('chatterbot.corpus.english')
data_files = corpus.list_corpus_files('chatterbot.corpus.english')

self.assertIn('.yml', data_files[0])

def test_load_english_corpus(self):
corpus = self.corpus.load_corpus('chatterbot.corpus.english.greetings')
corpus_data = corpus.load_corpus('chatterbot.corpus.english.greetings')

self.assertEqual(len(corpus), 1)
self.assertIn(['Hi', 'Hello'], corpus[0])
self.assertEqual(len(corpus_data), 1)
self.assertIn(['Hi', 'Hello'], corpus_data[0])

def test_load_english_corpus_categories(self):
corpus = self.corpus.load_corpus('chatterbot.corpus.english.greetings')
corpus_data = corpus.load_corpus('chatterbot.corpus.english.greetings')

self.assertEqual(len(corpus), 1)
self.assertEqual(len(corpus_data), 1)

# Test that each conversation gets labeled with the correct category
for conversation in corpus:
for conversation in corpus_data:
self.assertIn('greetings', conversation.categories)


class CorpusLoadingTestCase(TestCase):

def setUp(self):
self.corpus = Corpus()

def test_load_corpus_chinese(self):
corpus = self.corpus.load_corpus('chatterbot.corpus.chinese')
corpus_data = corpus.load_corpus('chatterbot.corpus.chinese')

self.assertTrue(len(corpus))
self.assertTrue(len(corpus_data))

def test_load_corpus_traditional_chinese(self):
corpus = self.corpus.load_corpus('chatterbot.corpus.tchinese')
corpus_data = corpus.load_corpus('chatterbot.corpus.tchinese')

self.assertTrue(len(corpus))
self.assertTrue(len(corpus_data))

def test_load_corpus_english(self):
corpus = self.corpus.load_corpus('chatterbot.corpus.english')
corpus_data = corpus.load_corpus('chatterbot.corpus.english')

self.assertTrue(len(corpus))
self.assertTrue(len(corpus_data))

def test_load_corpus_french(self):
corpus = self.corpus.load_corpus('chatterbot.corpus.french')
corpus_data = corpus.load_corpus('chatterbot.corpus.french')

self.assertTrue(len(corpus))
self.assertTrue(len(corpus_data))

def test_load_corpus_german(self):
corpus = self.corpus.load_corpus('chatterbot.corpus.german')
corpus_data = corpus.load_corpus('chatterbot.corpus.german')

self.assertTrue(len(corpus))
self.assertTrue(len(corpus_data))

def test_load_corpus_hindi(self):
corpus = self.corpus.load_corpus('chatterbot.corpus.hindi')
corpus_data = corpus.load_corpus('chatterbot.corpus.hindi')

self.assertTrue(len(corpus))
self.assertTrue(len(corpus_data))

def test_load_corpus_indonesia(self):
corpus = self.corpus.load_corpus('chatterbot.corpus.indonesia')
corpus_data = corpus.load_corpus('chatterbot.corpus.indonesia')

self.assertTrue(len(corpus))
self.assertTrue(len(corpus_data))

def test_load_corpus_italian(self):
corpus = self.corpus.load_corpus('chatterbot.corpus.italian')
corpus_data = corpus.load_corpus('chatterbot.corpus.italian')

self.assertTrue(len(corpus))
self.assertTrue(len(corpus_data))

def test_load_corpus_marathi(self):
corpus = self.corpus.load_corpus('chatterbot.corpus.marathi')
corpus_data = corpus.load_corpus('chatterbot.corpus.marathi')

self.assertTrue(len(corpus))
self.assertTrue(len(corpus_data))

def test_load_corpus_portuguese(self):
corpus = self.corpus.load_corpus('chatterbot.corpus.portuguese')
corpus_data = corpus.load_corpus('chatterbot.corpus.portuguese')

self.assertTrue(len(corpus))
self.assertTrue(len(corpus_data))

def test_load_corpus_russian(self):
corpus = self.corpus.load_corpus('chatterbot.corpus.russian')
corpus_data = corpus.load_corpus('chatterbot.corpus.russian')

self.assertTrue(len(corpus))
self.assertTrue(len(corpus_data))

def test_load_corpus_spanish(self):
corpus = self.corpus.load_corpus('chatterbot.corpus.spanish')
corpus_data = corpus.load_corpus('chatterbot.corpus.spanish')

self.assertTrue(len(corpus))
self.assertTrue(len(corpus_data))

def test_load_corpus_telugu(self):
corpus = self.corpus.load_corpus('chatterbot.corpus.telugu')
corpus_data = corpus.load_corpus('chatterbot.corpus.telugu')

self.assertTrue(len(corpus))
self.assertTrue(len(corpus_data))

def test_load_corpus(self):
"""
Test loading the entire corpus of languages.
"""
corpus = self.corpus.load_corpus('chatterbot.corpus')
corpus_data = corpus.load_corpus('chatterbot.corpus')

self.assertTrue(len(corpus))
self.assertTrue(len(corpus_data))


class CorpusFilePathTestCase(TestCase):

def setUp(self):
self.corpus = Corpus()

def test_load_corpus_file(self):
"""
Test that a file path can be specified for a corpus.
Expand All @@ -149,14 +140,14 @@ def test_load_corpus_file(self):
test_corpus.write(yml_data)

# Load the content from the corpus
corpus = self.corpus.load_corpus(file_path)
corpus_data = corpus.load_corpus(file_path)

# Remove the test file
if os.path.exists(file_path):
os.remove(file_path)

self.assertEqual(len(corpus), 1)
self.assertEqual(len(corpus[0]), 2)
self.assertEqual(len(corpus_data), 1)
self.assertEqual(len(corpus_data[0]), 2)

def test_load_corpus_file_non_existent(self):
"""
Expand All @@ -166,25 +157,25 @@ def test_load_corpus_file_non_existent(self):

self.assertFalse(os.path.exists(file_path))
with self.assertRaises(IOError):
corpus = self.corpus.load_corpus(file_path)
corpus.load_corpus(file_path)

def test_load_corpus_english_greetings(self):
file_path = os.path.join(self.corpus.data_directory, 'english', 'greetings.yml')
file_path = os.path.join(corpus.DATA_DIRECTORY, 'english', 'greetings.yml')

corpus = self.corpus.load_corpus(file_path)
corpus_data = corpus.load_corpus(file_path)

self.assertEqual(len(corpus), 1)
self.assertEqual(len(corpus_data), 1)

def test_load_corpus_english(self):
file_path = os.path.join(self.corpus.data_directory, 'english')
file_path = os.path.join(corpus.DATA_DIRECTORY, 'english')

corpus = self.corpus.load_corpus(file_path)
corpus_data = corpus.load_corpus(file_path)

self.assertGreater(len(corpus), 1)
self.assertGreater(len(corpus_data), 1)

def test_load_corpus_english_trailing_slash(self):
file_path = os.path.join(self.corpus.data_directory, 'english') + '/'
file_path = os.path.join(corpus.DATA_DIRECTORY, 'english') + '/'

corpus = self.corpus.load_corpus(file_path)
corpus_data = corpus.load_corpus(file_path)

self.assertGreater(len(corpus), 1)
self.assertGreater(len(corpus_data), 1)
Loading

0 comments on commit f4019b4

Please sign in to comment.