From d45d246006fe65ba77bc566882f09cd545d3443a Mon Sep 17 00:00:00 2001 From: Austin Cullar Date: Tue, 1 Oct 2024 13:22:23 -0600 Subject: [PATCH 1/4] - Remove randomness AstroDB.create_unique_table_name - Comment table names will now 'roll forward' as new ones are created - Restricted names to 3 characters, allowing for 26^3 (17576) possible names - Removed unused AstroDB.comment_table_exists method and associated test - Modified the existing test for AstroDB.create_unique_table_name to account for new functionality - Added 2 mock classes in the new tests/astro_mocks.py file - MockSqlite3Connection - MockSqlite3Cursor - Added a fixture mock_sqlite3_connect to utilize the above mock classes to mock the database - I should now have full control over the database with this fixture --- src/astro_db.py | 71 +++++++++++++++++++++++----------- src/tests/astro_mocks.py | 32 +++++++++++++++ src/tests/test_astro_db.py | 79 ++++++++++++++++++-------------------- 3 files changed, 118 insertions(+), 64 deletions(-) create mode 100644 src/tests/astro_mocks.py diff --git a/src/astro_db.py b/src/astro_db.py index 8be017d..e57ab33 100644 --- a/src/astro_db.py +++ b/src/astro_db.py @@ -2,8 +2,6 @@ Class for managing comment/video database. """ import sqlite3 -import string -import random import pandas as pd @@ -22,36 +20,65 @@ def __init__(self, logger, db_file: str): def get_db_conn(self): return self.conn - def comment_table_exists(self, table_name: str) -> bool: + def get_next_table_name(self, last_table_name: str) -> str: """ - Check the 'Videos' table for an entry containing the provided - table name in the 'comment_table' column. + Roll the provided string forward by 'incrementing' the + letters. See below for some example transitions: + + AAA -> BAA + BAA -> CAA + ... + ZAA -> ABA + ABA -> BBA """ + # 'roll' the name forward + def next_char(c: chr) -> chr: + return chr(ord(c) + 1) + + new_name = '' + rolled = 0 + + for i in range(len(last_table_name)): + if last_table_name[i] != 'Z': + new_name += next_char(last_table_name[i]) + new_name += last_table_name[i+1:len(last_table_name)] + break + else: + rolled += 1 + new_name += 'A' + if rolled == len(last_table_name): + raise StopIteration("Limit exceeded for number of comment tables in database") - query = f"SELECT * FROM Videos WHERE comment_table='{table_name}'" - self.cursor.execute(query) - table_exists = self.cursor.fetchone() - - return bool(table_exists) + return new_name def create_unique_table_name(self) -> str: """ - Create a random table name from uppercase letters. + This function effectively implements a string odometer. The + name returned will be a 3 character string consisting of capital + letters. Each successive call will 'roll forward' the last-created + comment table name. + + Once the odometer reaches its limit of ZZZ, the next call will + result in an exception. This should only happen after the creation of + 26^3 (17576) tables, which is a limit I'm comfortable with since I doubt + we'll be tracking over one hundred YouTube videos, let alone over 17k. """ - attempts = 3 # 3 attempts to generate unique string - id_string = '' + # Get most recent comment table name by grabbing latest entry in the Videos table + self.cursor.execute("SELECT comment_table FROM Videos ORDER BY id DESC LIMIT 1") - # Generate random names until we get one that doesn't exist - while attempts > 0: - id_string = ''.join(random.choices(string.ascii_uppercase, k=12)) + last_table_name = self.cursor.fetchone() + if not last_table_name: # this is the first comment table we're creating + return 'AAA' + else: + last_table_name = last_table_name[0] - if self.comment_table_exists(id_string): - self.logger.warning('Comment table name collision!') - attempts -= 1 - else: - return id_string + self.logger.debug('last table name: {}'.format(last_table_name)) + + new_name = self.get_next_table_name(last_table_name) + + self.logger.debug('unique table name: {}'.format(new_name)) - return '' + return new_name def create_videos_table(self): """ diff --git a/src/tests/astro_mocks.py b/src/tests/astro_mocks.py new file mode 100644 index 0000000..20ac577 --- /dev/null +++ b/src/tests/astro_mocks.py @@ -0,0 +1,32 @@ +""" +This file will contain all mock classes used for testing. +""" + + +class MockSqlite3Cursor: + return_value = None + + def __init__(self, return_value): + self.return_value = return_value + + def fetchone(self): + if self.return_value: + return (self.return_value,) + else: + return None + + def execute(self, query: str): + return self.return_value + + +class MockSqlite3Connection: + return_value = None + + def set_return_value(self, return_value): + self.return_value = return_value + + def cursor(self): + return MockSqlite3Cursor(self.return_value) + + def commit(self): + return diff --git a/src/tests/test_astro_db.py b/src/tests/test_astro_db.py index a861413..74fb5a4 100644 --- a/src/tests/test_astro_db.py +++ b/src/tests/test_astro_db.py @@ -1,4 +1,5 @@ import pytest +import sqlite3 import os from unittest.mock import MagicMock @@ -6,6 +7,7 @@ # Astro modules from src.astro_db import AstroDB from src.data_collection.data_structures import VideoData +from src.tests.astro_mocks import MockSqlite3Connection test_video_data = [VideoData(video_id='e-qUSPnOlbb', channel_id='itXtJBHdZchKKjlnVrjXeCln', @@ -27,16 +29,6 @@ comment_count=76123)] -@pytest.fixture(scope='function', params=[True, False]) -def mock_comment_table_exists(request): - """ - Mock AstroDB.comment_table_exists to return True on every call. This - is used to stress AstroDB.create_unique_table_name. - """ - mock = AstroDB.comment_table_exists - mock.execute = MagicMock(return_value=request.param) - - @pytest.fixture(scope='class') def astro_db(logger): test_db_file = 'test.db' @@ -45,6 +37,23 @@ def astro_db(logger): os.remove(test_db_file) +@pytest.fixture(scope='function') +def mock_sqlite3_connect(): + # save the original connect function + sqlite3_connect_orig = sqlite3.connect + + # pass instance of our own MockSqlite3Connection class to MagicMock + sqlite_connect_mock = MockSqlite3Connection() + sqlite3.connect = MagicMock(return_value=sqlite_connect_mock) + + # yield the sqlite_connect_mock object so that the test function can + # set the return value + yield sqlite_connect_mock + + # restore the original connect function + sqlite3.connect = sqlite3_connect_orig + + class TestAstroDB: created_comment_tables = [] @@ -79,39 +88,25 @@ def test_create_comment_table_for_video(self, astro_db, video_data): self.created_comment_tables.append(comment_table_name) - def test_comment_table_exists(self, astro_db): - conn = astro_db.get_db_conn() - cursor = conn.cursor() - - for comment_table in self.created_comment_tables: - assert astro_db.comment_table_exists(comment_table) - - # get existing comment table from Videos table - for table_name in self.created_comment_tables: - cursor.execute(f"SELECT * FROM Videos WHERE comment_table='{table_name}'") - row = cursor.fetchone() - - assert row - - comment_table = row[4] # comment_table is the 5th column - assert comment_table in self.created_comment_tables - - @pytest.mark.parametrize('comment_table_exists', [True, False]) - def test_create_unique_table_name(self, astro_db, comment_table_exists): - # this method generates a random 12 character string of capital letters - # the likelihood of collision is extremely low (1 in (26^12)) - # using the mock below to simulate a name collision to exercise error path - mock = astro_db - mock.comment_table_exists = MagicMock(return_value=comment_table_exists) - name = astro_db.create_unique_table_name() - - if not comment_table_exists: # we're expecting a normal run, unlikely to have a name collision - # verify that the returned string is valid for sql table names - assert name - assert '-' not in name - assert '_' not in name + @pytest.mark.parametrize('table_names', [ + ['AAA', 'BAA'], + ['ABB', 'BBB'], + ['ZAA', 'ABA'], + ['ZZZ', ''], + [None, 'AAA'] + ]) + def test_create_unique_table_name(self, logger, mock_sqlite3_connect, table_names): + mock_sqlite3_connect.set_return_value(table_names[0]) + astro_db = AstroDB(logger, 'test2.db') + + if table_names[0] == 'ZZZ': + with pytest.raises(StopIteration) as exception: + name = astro_db.create_unique_table_name() + assert str(exception.value) == "Limit exceeded for number of comment tables in database" else: - assert not name + name = astro_db.create_unique_table_name() + assert name == table_names[1] + @pytest.mark.parametrize('video_data', test_video_data) def test_get_comment_table_for(self, astro_db, video_data): From f6ff862ef27a8e71102fb35210ff0f0f61d15c4b Mon Sep 17 00:00:00 2001 From: Austin Cullar Date: Tue, 1 Oct 2024 13:32:43 -0600 Subject: [PATCH 2/4] - Fix linting error --- src/tests/test_astro_db.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/tests/test_astro_db.py b/src/tests/test_astro_db.py index 74fb5a4..3ae4ad3 100644 --- a/src/tests/test_astro_db.py +++ b/src/tests/test_astro_db.py @@ -107,7 +107,6 @@ def test_create_unique_table_name(self, logger, mock_sqlite3_connect, table_name name = astro_db.create_unique_table_name() assert name == table_names[1] - @pytest.mark.parametrize('video_data', test_video_data) def test_get_comment_table_for(self, astro_db, video_data): # verify that AstroDB finds the comment table From 9e6aea2013ed5abb2dba27df5cd512d703ccfa21 Mon Sep 17 00:00:00 2001 From: Austin Cullar Date: Wed, 2 Oct 2024 13:03:22 -0600 Subject: [PATCH 3/4] - Add error handling to astro_db methods - Modify tests to account for new error handling - Add fixture to mock database returning None to simulate record not found - Add static method YouTubeDataAPI.valid_video_id to aid in error checking --- src/astro_db.py | 27 ++++- src/data_collection/yt_data_api.py | 18 ++++ src/tests/astro_mocks.py | 4 +- src/tests/test_astro_db.py | 155 ++++++++++++++++++++--------- 4 files changed, 150 insertions(+), 54 deletions(-) diff --git a/src/astro_db.py b/src/astro_db.py index e57ab33..3874d6c 100644 --- a/src/astro_db.py +++ b/src/astro_db.py @@ -2,9 +2,10 @@ Class for managing comment/video database. """ import sqlite3 +import string import pandas as pd - +from src.data_collection.yt_data_api import YouTubeDataAPI class AstroDB: conn = None @@ -99,6 +100,16 @@ def create_comment_table_for_video(self, video_data) -> str: """ Create a new comment table for a specific video id. """ + if not video_data: + raise ValueError('NULL video data') + + if not video_data.channel_id or not YouTubeDataAPI.valid_video_id(video_data.video_id): + raise ValueError('Invalid video data') + + if not video_data.channel_title: + # Missing the channel title is not critical, but should be investigated + self.logger.warn('Missing channel title') + table_name = self.create_unique_table_name() assert table_name, "Failed to create unique comment table in database" @@ -129,12 +140,16 @@ def get_comment_table_for(self, video_id: str) -> str: """ Given a video id, return the associated comment table, if any. """ - get_comment_table_for_video = \ + if not YouTubeDataAPI.valid_video_id(video_id): # don't waste time querying database + return '' + + get_comment_table_for_video_id = \ f"SELECT comment_table FROM Videos WHERE video_id='{video_id}'" - self.cursor.execute(get_comment_table_for_video) + self.cursor.execute(get_comment_table_for_video_id) table = self.cursor.fetchone() + if table: return table[0] else: @@ -144,6 +159,12 @@ def insert_comment_dataframe(self, video_data, dataframe: pd.DataFrame): """ Given a video ID and a dataframe, commit the dataframe to the database. """ + if not video_data: + raise ValueError('NULL video data') + + if not YouTubeDataAPI.valid_video_id(video_data.video_id): + raise ValueError('Invalid video id') + comment_table = self.get_comment_table_for(video_data.video_id) if not comment_table: self.logger.debug('Comment table for video id {} did not exist'.format(video_data.video_id)) diff --git a/src/data_collection/yt_data_api.py b/src/data_collection/yt_data_api.py index ee59c5f..528a594 100644 --- a/src/data_collection/yt_data_api.py +++ b/src/data_collection/yt_data_api.py @@ -4,6 +4,7 @@ import pandas as pd import traceback +import string from src.data_collection.data_structures import VideoData from googleapiclient.discovery import build @@ -19,6 +20,23 @@ def __init__(self, logger, api_key): self.api_key = api_key self.youtube = build('youtube', 'v3', developerKey=self.api_key) + @staticmethod + def valid_video_id(video_id: str) -> bool: + valid_tokens = (string.ascii_uppercase + + string.ascii_lowercase + + string.digits + '-' + '_') + + if video_id: + for token in video_id: + if token not in valid_tokens: + return False + + # all tokens are valid + return True + + # null video_id + return False + def parse_comment_api_response(self, response) -> pd.DataFrame: """ Parse API response for comment query. This will grab all comments and their replies, diff --git a/src/tests/astro_mocks.py b/src/tests/astro_mocks.py index 20ac577..862f49d 100644 --- a/src/tests/astro_mocks.py +++ b/src/tests/astro_mocks.py @@ -15,8 +15,8 @@ def fetchone(self): else: return None - def execute(self, query: str): - return self.return_value + def execute(self, *args): + return class MockSqlite3Connection: diff --git a/src/tests/test_astro_db.py b/src/tests/test_astro_db.py index 3ae4ad3..176a2c4 100644 --- a/src/tests/test_astro_db.py +++ b/src/tests/test_astro_db.py @@ -7,6 +7,7 @@ # Astro modules from src.astro_db import AstroDB from src.data_collection.data_structures import VideoData +from src.data_collection.yt_data_api import YouTubeDataAPI from src.tests.astro_mocks import MockSqlite3Connection test_video_data = [VideoData(video_id='e-qUSPnOlbb', @@ -26,7 +27,16 @@ channel_title='User_YT99', view_count=12345, like_count=66423, - comment_count=76123)] + comment_count=76123), + # case where an invalid video_id string is provided + VideoData(video_id='bad data', + channel_id='bad data', + channel_title='bad data'), + # empty data set case + VideoData(video_id='', + channel_id='', + channel_title=''), + None] @pytest.fixture(scope='class') @@ -54,6 +64,21 @@ def mock_sqlite3_connect(): sqlite3.connect = sqlite3_connect_orig +@pytest.fixture(scope='function') +def database_fault(mock_sqlite3_connect, logger): + """ + Force the database queries to return None + """ + valid_video_id_orig = YouTubeDataAPI.valid_video_id + + YouTubeDataAPI.valid_video_id = MagicMock(return_value=True) + mock_sqlite3_connect.set_return_value(None) + + yield AstroDB(logger, 'test2.db') + + YouTubeDataAPI.valid_video_id = valid_video_id_orig + + class TestAstroDB: created_comment_tables = [] @@ -73,20 +98,32 @@ def test_create_comment_table_for_video(self, astro_db, video_data): conn = astro_db.get_db_conn() cursor = conn.cursor() - # create entry in Videos table along with a new comment table for that video - comment_table_name = astro_db.create_comment_table_for_video(video_data) + bad_input = not video_data or \ + not video_data.channel_id or \ + not YouTubeDataAPI.valid_video_id(video_data.video_id) + + if bad_input: # expect an exception + with pytest.raises(ValueError) as exception: + comment_table_name = astro_db.create_comment_table_for_video(video_data) + if not video_data: + assert str(exception.value) == 'NULL video data' + elif not video_data.channel_id or not video_data.video_id: + assert str(exception.value) == 'Invalid video data' + else: + # create entry in Videos table along with a new comment table for that video + comment_table_name = astro_db.create_comment_table_for_video(video_data) - # verify creation of entry in Videos and the new comment table - cursor.execute(f"SELECT * FROM Videos WHERE video_id='{video_data.video_id}'") - video_table_data = cursor.fetchone() + # verify creation of entry in Videos and the new comment table + cursor.execute(f"SELECT * FROM Videos WHERE video_id='{video_data.video_id}'") + video_table_data = cursor.fetchone() - assert video_table_data - assert video_table_data[1] == video_data.channel_title - assert video_table_data[2] == video_data.channel_id - assert video_table_data[3] == video_data.video_id - assert video_table_data[4] == comment_table_name + assert video_table_data + assert video_table_data[1] == video_data.channel_title + assert video_table_data[2] == video_data.channel_id + assert video_table_data[3] == video_data.video_id + assert video_table_data[4] == comment_table_name - self.created_comment_tables.append(comment_table_name) + self.created_comment_tables.append(comment_table_name) @pytest.mark.parametrize('table_names', [ ['AAA', 'BAA'], @@ -102,56 +139,76 @@ def test_create_unique_table_name(self, logger, mock_sqlite3_connect, table_name if table_names[0] == 'ZZZ': with pytest.raises(StopIteration) as exception: name = astro_db.create_unique_table_name() - assert str(exception.value) == "Limit exceeded for number of comment tables in database" + assert str(exception.value) == 'Limit exceeded for number of comment tables in database' else: name = astro_db.create_unique_table_name() assert name == table_names[1] - @pytest.mark.parametrize('video_data', test_video_data) - def test_get_comment_table_for(self, astro_db, video_data): + @pytest.mark.parametrize('fail_database_query', [True, False]) + @pytest.mark.parametrize('video_id', [video_data.video_id for video_data in test_video_data if video_data]) + def test_get_comment_table_for(self, request, astro_db, fail_database_query, video_id): + if fail_database_query: + # force database to return None in order to test lookup failure path + astro_db = request.getfixturevalue('database_fault') + + # consider this a normal run if we have a valid video_id and no expected database failure + normal_run = YouTubeDataAPI.valid_video_id(video_id) and not fail_database_query + # verify that AstroDB finds the comment table - table_name = astro_db.get_comment_table_for(video_data.video_id) + table_name = astro_db.get_comment_table_for(video_id) - assert table_name + assert table_name if normal_run else not table_name # verify that the database agrees with AstroDB conn = astro_db.get_db_conn() cursor = conn.cursor() - cursor.execute(f"SELECT comment_table FROM Videos WHERE video_id='{video_data.video_id}'") + cursor.execute(f"SELECT comment_table FROM Videos WHERE video_id='{video_id}'") database_table = cursor.fetchone() - assert database_table - assert database_table[0] == table_name + assert database_table if normal_run else not database_table + if normal_run: + assert database_table[0] == table_name @pytest.mark.parametrize('video_data', test_video_data) def test_insert_comment_dataframe(self, astro_db, video_data, comment_dataframe): - astro_db.insert_comment_dataframe(video_data, comment_dataframe) - - conn = astro_db.get_db_conn() - cursor = conn.cursor() - - # check database for dataframe content - query = f"SELECT comment_table FROM Videos WHERE video_id='{video_data.video_id}'" - cursor.execute(query) - comment_table = cursor.fetchone() - - # there should only be one comment table - assert len(comment_table) == 1 - - comment_table = comment_table[0] - - # grab all rows from coment tables - query = f"SELECT * FROM {comment_table}" - cursor.execute(query) - comment_data = cursor.fetchall() - - # verify that the data in the table matches that in the dataframe - index = 0 - for row in comment_data: - assert row[0] == index - assert row[1] == comment_dataframe.loc[index]['comment'] - assert row[2] == comment_dataframe.loc[index]['user'] - assert row[3] == comment_dataframe.loc[index]['date'] - - index += 1 + bad_input = not video_data or \ + not YouTubeDataAPI.valid_video_id(video_data.video_id) + + if bad_input: # expect an exception + with pytest.raises(ValueError) as exception: + astro_db.insert_comment_dataframe(video_data, comment_dataframe) + if not video_data: + assert str(exception.value) == 'NULL video data' + elif not YouTubeDataAPI.valid_video_id(video_data.video_id): + assert str(exception.value) == 'Invalid video id' + else: # insert dataframe into database, verify table contents + astro_db.insert_comment_dataframe(video_data, comment_dataframe) + + conn = astro_db.get_db_conn() + cursor = conn.cursor() + + # check database for dataframe content + query = f"SELECT comment_table FROM Videos WHERE video_id='{video_data.video_id}'" + cursor.execute(query) + comment_table = cursor.fetchone() + + # there should only be one comment table + assert len(comment_table) == 1 + + comment_table = comment_table[0] + + # grab all rows from coment tables + query = f"SELECT * FROM {comment_table}" + cursor.execute(query) + comment_data = cursor.fetchall() + + # verify that the data in the table matches that in the dataframe + index = 0 + for row in comment_data: + assert row[0] == index + assert row[1] == comment_dataframe.loc[index]['comment'] + assert row[2] == comment_dataframe.loc[index]['user'] + assert row[3] == comment_dataframe.loc[index]['date'] + + index += 1 From f373d3a4848ed04ebf450e0bc9945024501d9b39 Mon Sep 17 00:00:00 2001 From: Austin Cullar Date: Wed, 2 Oct 2024 13:07:47 -0600 Subject: [PATCH 4/4] - Fix linting errors --- src/astro_db.py | 2 +- src/data_collection/yt_data_api.py | 4 ++-- src/tests/test_astro_db.py | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/astro_db.py b/src/astro_db.py index 3874d6c..4cd98b2 100644 --- a/src/astro_db.py +++ b/src/astro_db.py @@ -2,11 +2,11 @@ Class for managing comment/video database. """ import sqlite3 -import string import pandas as pd from src.data_collection.yt_data_api import YouTubeDataAPI + class AstroDB: conn = None cursor = None diff --git a/src/data_collection/yt_data_api.py b/src/data_collection/yt_data_api.py index 528a594..79db1d3 100644 --- a/src/data_collection/yt_data_api.py +++ b/src/data_collection/yt_data_api.py @@ -23,8 +23,8 @@ def __init__(self, logger, api_key): @staticmethod def valid_video_id(video_id: str) -> bool: valid_tokens = (string.ascii_uppercase + - string.ascii_lowercase + - string.digits + '-' + '_') + string.ascii_lowercase + + string.digits + '-' + '_') if video_id: for token in video_id: diff --git a/src/tests/test_astro_db.py b/src/tests/test_astro_db.py index 176a2c4..b3cc18a 100644 --- a/src/tests/test_astro_db.py +++ b/src/tests/test_astro_db.py @@ -99,8 +99,8 @@ def test_create_comment_table_for_video(self, astro_db, video_data): cursor = conn.cursor() bad_input = not video_data or \ - not video_data.channel_id or \ - not YouTubeDataAPI.valid_video_id(video_data.video_id) + not video_data.channel_id or \ + not YouTubeDataAPI.valid_video_id(video_data.video_id) if bad_input: # expect an exception with pytest.raises(ValueError) as exception: