Skip to content

Commit

Permalink
- Simplify error checking in astro_db.py
Browse files Browse the repository at this point in the history
    - should not be verifying correct youtube data
- YouTubeDataAPI now takes a URL instead of the video id
    - no more static methods for validating video ids/urls
- Tests modified to account for new error checking
- Removed unnecessary test in test_astro_db.py
  • Loading branch information
AustinCullar committed Oct 26, 2024
1 parent 5241c0a commit 2d37241
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 112 deletions.
17 changes: 1 addition & 16 deletions src/astro.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,6 @@
from rich_argparse import ArgumentDefaultsRichHelpFormatter


def extract_video_id_from_url(url: str) -> str:
    """
    Extract the video ID from the provided URL.

    The ID follows the substring 'v=' in the URL, so split on that
    substring and keep only the portion before any subsequent query
    parameter (delimited by '&').

    :param url: full YouTube video URL
    :returns: the extracted video ID
    :raises ValueError: if the URL contains no 'v=' marker or the
        extracted ID fails validation
    """
    # Guard first: url.split('v=')[1] would raise an opaque IndexError
    # when 'v=' is absent; raise the ValueError callers expect instead.
    if 'v=' not in url:
        raise ValueError('Invalid video URL provided')

    # Keep only the ID itself, dropping trailing parameters like '&t=42s'
    # which would otherwise fail validation on valid URLs.
    video_id = url.split('v=')[1].split('&')[0]

    if not YouTubeDataAPI.valid_video_id(video_id):
        raise ValueError('Invalid video URL provided')

    return video_id


def parse_args(astro_theme):
"""
Argument parsing logic. Returns the arguments parsed from the CLI
Expand Down Expand Up @@ -57,7 +43,6 @@ def main():

# parse arguments
args = parse_args(astro_theme)
video_id = extract_video_id_from_url(args.youtube_url)

# load environment variables
load_dotenv()
Expand All @@ -76,7 +61,7 @@ def main():

# collect metadata for provided video
youtube = YouTubeDataAPI(logger, api_key, log_json)
video_data = youtube.get_video_metadata(video_id)
video_data = youtube.get_video_metadata(args.youtube_url)

logger.print_video_data(video_data)

Expand Down
33 changes: 13 additions & 20 deletions src/astro_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import sqlite3

import pandas as pd
from src.data_collection.yt_data_api import YouTubeDataAPI
from src.data_collection.data_structures import VideoData


Expand All @@ -28,7 +27,9 @@ def __merge_comment_data(self, comment_table: str, new_dataframe: pd.DataFrame):
append any new comments to the comment table.
"""
# pull comments from local database
db_dataframe = pd.read_sql(f"SELECT * FROM '{comment_table}'", self.conn)
db_dataframe = pd.read_sql(f'SELECT * FROM {comment_table}', self.conn)
if db_dataframe is None:
raise LookupError(f'Failed to pull data from comment table: {comment_table}')

# check for comments made nonvisible since our last check
nonvisible_comments = self.__get_nonvisible_comments(old=db_dataframe, new=new_dataframe)
Expand Down Expand Up @@ -113,10 +114,9 @@ def __create_comment_table_for_video(self, video_data) -> str:
"""
self.logger.debug('Creating comment table for new video...')

if not video_data:
raise ValueError('NULL video data')

if not video_data.channel_id or not YouTubeDataAPI.valid_video_id(video_data.video_id):
if not video_data or \
not video_data.channel_id or \
not video_data.video_id:
raise ValueError('Invalid video data')

if not video_data.channel_title:
Expand Down Expand Up @@ -164,14 +164,7 @@ def __get_comment_table_for(self, video_id: str) -> str:
"""
self.logger.debug(f'Searching for comment table for video ID: {video_id}')

if not YouTubeDataAPI.valid_video_id(video_id): # don't waste time querying database
return ''

get_comment_table_for_video_id = \
f"SELECT comment_table FROM Videos WHERE video_id='{video_id}'"

self.cursor.execute(get_comment_table_for_video_id)

self.cursor.execute(f"SELECT comment_table FROM Videos WHERE video_id='{video_id}'")
table = self.cursor.fetchone()

if table:
Expand Down Expand Up @@ -224,21 +217,21 @@ def insert_comment_dataframe(self, video_data, dataframe: pd.DataFrame):
"""
self.logger.debug('Inserting new comment dataframe...')

if not video_data:
raise ValueError('NULL video data')
if not video_data or not video_data.video_id:
raise ValueError('Invalid video data')

if not YouTubeDataAPI.valid_video_id(video_data.video_id):
raise ValueError('Invalid video id')
if dataframe is None:
raise ValueError('Cannot insert NULL dataframe')

comment_table = self.__get_comment_table_for(video_data.video_id)
if comment_table:
self.logger.debug('Compare & merge local comment data with new data')
self.logger.debug('Merging new comment data with local database...')
return self.__merge_comment_data(comment_table, dataframe)
else:
self.logger.debug(f'Comment table for video id {video_data.video_id} did not exist - creating it now')
comment_table = self.__create_comment_table_for_video(video_data)

dataframe.to_sql(comment_table, self.conn, index=False, if_exists='append')
dataframe.to_sql(comment_table, self.conn, index=False, if_exists='replace')

self.conn.commit()

Expand Down
43 changes: 23 additions & 20 deletions src/data_collection/yt_data_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,24 +22,7 @@ def __init__(self, logger, api_key, log_json=False):
self.log_json = log_json
self.youtube = build('youtube', 'v3', developerKey=self.api_key)

@staticmethod
def valid_video_id(video_id: str) -> bool:
valid_tokens = (string.ascii_uppercase +
string.ascii_lowercase +
string.digits + '-' + '_')

if video_id:
for token in video_id:
if token not in valid_tokens:
return False

# all tokens are valid
return True

# null video_id
return False

def parse_comment_api_response(self, response, comment_dataframe) -> pd.DataFrame:
def __parse_comment_api_response(self, response, comment_dataframe) -> pd.DataFrame:
"""
Parse API response for comment query. This will grab all comments and their replies,
storing the resulting data in a dataframe.
Expand Down Expand Up @@ -85,6 +68,25 @@ def parse_comment_api_response(self, response, comment_dataframe) -> pd.DataFram

return df, comment_count

def __extract_video_id_from_url(self, url: str) -> str:
"""
Grab the video ID from the provided URL. The ID will come after
the substring 'v=' in the URL, so I just split the string on that
substring and return the latter half.
"""
video_id = url.split('v=')[1]

# validate extracted video id
valid_tokens = (string.ascii_uppercase +
string.ascii_lowercase +
string.digits + '-' + '_')

for token in video_id:
if token not in valid_tokens:
raise ValueError('Invalid video URL provided')

return video_id

def get_comments(self, video_data) -> pd.DataFrame:
"""
Collect and store comment information in a dataframe. Collected
Expand Down Expand Up @@ -120,7 +122,7 @@ def get_comments(self, video_data) -> pd.DataFrame:
with self.logger.log_file_only():
self.logger.info(json.dumps(response, indent=4))

comment_dataframe, comments_added = self.parse_comment_api_response(response, comment_dataframe)
comment_dataframe, comments_added = self.__parse_comment_api_response(response, comment_dataframe)
if 'nextPageToken' in response: # there are more comments to fetch
page_token = response['nextPageToken']
else:
Expand All @@ -141,13 +143,14 @@ def get_comments(self, video_data) -> pd.DataFrame:

return comment_dataframe

def get_video_metadata(self, video_id: str) -> VideoData:
def get_video_metadata(self, url: str) -> VideoData:
"""
Collect video information provided a video ID.
Return all data in a VideoData class for easy access.
"""
self.logger.debug('Collecting video metadata...')

video_id = self.__extract_video_id_from_url(url)
return_data = VideoData()

request = self.youtube.videos().list(
Expand Down
74 changes: 19 additions & 55 deletions src/tests/test_astro_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
# Astro modules
from src.astro_db import AstroDB
from src.tests.test_objects import test_video_data
from src.data_collection.yt_data_api import YouTubeDataAPI
from src.tests.astro_mocks import MockSqlite3Connection


Expand Down Expand Up @@ -41,14 +40,8 @@ def database_fault(mock_sqlite3_connect, logger):
"""
Force the database queries to return None
"""
valid_video_id_orig = YouTubeDataAPI.valid_video_id

YouTubeDataAPI.valid_video_id = MagicMock(return_value=True)
mock_sqlite3_connect.set_return_value(None)

yield AstroDB(logger, 'test2.db')

YouTubeDataAPI.valid_video_id = valid_video_id_orig
return AstroDB(logger, 'test2.db')


class TestAstroDB:
Expand All @@ -71,19 +64,13 @@ def __get_table_row_count(self, conn, table_name):
return row_count

def __insert_dataframe_exception(self, astro_db, comment_dataframe, video_data) -> bool:
bad_input = not video_data or \
not YouTubeDataAPI.valid_video_id(video_data.video_id)

if not bad_input:
if video_data and video_data.video_id:
return False

# expect an exception
with pytest.raises(ValueError) as exception:
astro_db.insert_comment_dataframe(video_data, comment_dataframe)
if not video_data:
assert str(exception.value) == 'NULL video data'
elif not YouTubeDataAPI.valid_video_id(video_data.video_id):
assert str(exception.value) == 'Invalid video id'
assert str(exception.value) == 'Invalid video id'

return True

Expand All @@ -105,15 +92,12 @@ def test_create_comment_table_for_video(self, astro_db, video_data):

bad_input = not video_data or \
not video_data.channel_id or \
not YouTubeDataAPI.valid_video_id(video_data.video_id)
not video_data.video_id

if bad_input: # expect an exception
with pytest.raises(ValueError) as exception:
comment_table_name = astro_db._AstroDB__create_comment_table_for_video(video_data)
if not video_data:
assert str(exception.value) == 'NULL video data'
elif not video_data.channel_id or not video_data.video_id:
assert str(exception.value) == 'Invalid video data'
assert str(exception.value) == 'Invalid video data'
else:
# create entry in Videos table along with a new comment table for that video
comment_table_name = astro_db._AstroDB__create_comment_table_for_video(video_data)
Expand Down Expand Up @@ -153,32 +137,6 @@ def test_create_unique_table_name(self, logger, mock_sqlite3_connect, table_name
name = astro_db._AstroDB__create_unique_table_name()
assert name == table_names[1]

@pytest.mark.parametrize('fail_database_query', [True, False])
@pytest.mark.parametrize('video_id', [video_data.video_id for video_data in test_video_data if video_data])
def test_get_comment_table_for(self, request, astro_db, fail_database_query, video_id):
    """
    Verify that AstroDB's private __get_comment_table_for() returns a
    table name exactly when the video id is valid and the database
    query succeeds, and that its answer matches a direct query of the
    Videos table.
    """
    if fail_database_query:
        # force database to return None in order to test lookup failure path
        astro_db = request.getfixturevalue('database_fault')

    # consider this a normal run if we have a valid video_id and no expected database failure
    normal_run = YouTubeDataAPI.valid_video_id(video_id) and not fail_database_query

    # verify that AstroDB finds the comment table
    # (name-mangled access to the private method under test)
    table_name = astro_db._AstroDB__get_comment_table_for(video_id)

    # a table name is expected only on a normal run; any failure path
    # must yield a falsy result (empty string / None)
    assert table_name if normal_run else not table_name

    # verify that the database agrees with AstroDB
    conn = astro_db.get_db_conn()
    cursor = conn.cursor()

    cursor.execute(f"SELECT comment_table FROM Videos WHERE video_id='{video_id}'")
    database_table = cursor.fetchone()

    assert database_table if normal_run else not database_table
    if normal_run:
        # the name AstroDB reported must be the one stored in the Videos table
        assert database_table[0] == table_name

@pytest.mark.parametrize('video_data', test_video_data)
def test_insert_comment_dataframe(self, astro_db, video_data, comment_dataframe):
if not self.__insert_dataframe_exception(astro_db, comment_dataframe, video_data):
Expand Down Expand Up @@ -219,18 +177,24 @@ def test_get_video_data(self, astro_db, video_data):
conn = astro_db.get_db_conn()
cursor = conn.cursor()

if YouTubeDataAPI.valid_video_id(video_data.video_id):
if not video_data.video_id:
with pytest.raises(ValueError) as exception:
db_video_data = astro_db.get_video_data(video_data.video_id)
assert str(exception.value) == 'Invalid video id'
else:
db_video_data = astro_db.get_video_data(video_data.video_id)

cursor.execute(f"SELECT * from Videos WHERE video_id='{video_data.video_id}'")
db_entry = cursor.fetchone()

assert db_entry
assert db_entry[1] == video_data.channel_title
assert db_entry[2] == video_data.channel_id
assert db_entry[3] == video_data.video_id
assert db_entry[4] == video_data.view_count
assert db_entry[5] == video_data.like_count
assert db_entry[6] == video_data.comment_count
assert db_entry[7] == video_data.filtered_comment_count
assert db_entry[1] == db_video_data.channel_title
assert db_entry[2] == db_video_data.channel_id
assert db_entry[3] == db_video_data.video_id
assert db_entry[4] == db_video_data.view_count
assert db_entry[5] == db_video_data.like_count
assert db_entry[6] == db_video_data.comment_count
assert db_entry[7] == db_video_data.filtered_comment_count

@pytest.mark.parametrize('video_data', [test_video_data[0]])
def test_new_comment_detection(self, astro_db, comment_dataframe, video_data):
Expand Down
2 changes: 1 addition & 1 deletion src/tests/test_yt_data_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def test_get_video_metadata(
viewCount=viewCount,
commentCount=commentCount)

video_data = youtube.get_video_metadata('video_id')
video_data = youtube.get_video_metadata('youtube.com/test/v=videoid')

assert video_data.channel_id == channelId
assert video_data.channel_title == channelTitle
Expand Down

0 comments on commit 2d37241

Please sign in to comment.