Skip to content

Commit

Permalink
feat: Implement API endpoint for querying all entities.
Browse files Browse the repository at this point in the history
  • Loading branch information
PromiseFru committed Sep 20, 2024
1 parent 03b6f34 commit 2b179a0
Show file tree
Hide file tree
Showing 4 changed files with 300 additions and 2 deletions.
45 changes: 45 additions & 0 deletions migrations/update_entity_created_date.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
"""Script for updating entity 'date_created' for V2 migrated users."""

import logging
from tqdm import tqdm
from src.schemas.usersinfo import UsersInfos
from src.entity import find_entity

logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger("migration.script")


def fetch_verified_users_infos_data():
"""Fetch verified user information from the UsersInfos table."""
return UsersInfos.select().where(UsersInfos.status == "verified")


def update_created_date(user_info):
"""Update the date_created field for the given user entity."""
phone_number_hash = user_info.full_phone_number
user_created_date = user_info.createdAt

try:
entity = find_entity(phone_number_hash=phone_number_hash)
if entity:
entity.date_created = user_created_date
entity.save()
except Exception as e:
logger.exception("Error updating user: %s - %s", user_info.userId, str(e))


def run():
"""Main function to process all verified users and update the date_created."""
users_infos_data = fetch_verified_users_infos_data()

total = users_infos_data.count()
with tqdm(total=total, desc="Updating", unit="users") as pbar:
for user_info in users_infos_data:
update_created_date(user_info)
pbar.update(1)


if __name__ == "__main__":
run()
9 changes: 7 additions & 2 deletions server.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from utils.SSL import isSSL

from settings import Configurations

ssl_cert = Configurations.SSL_CERTIFICATE
ssl_port = Configurations.SSL_PORT
ssl_key = Configurations.SSL_KEY
Expand All @@ -20,6 +21,7 @@
from flask_cors import CORS

from src.api_v2 import v2
from src.api_v3 import v3_blueprint

from SwobThirdPartyPlatforms import base_dir

Expand All @@ -32,13 +34,16 @@
)

app.register_blueprint(v2, url_prefix="/v2")
app.register_blueprint(v3_blueprint)


@app.route('/public/<path:path>')
@app.route("/public/<path:path>")
def send_report(path):
platform_name = path.split("-")[0]
logo_path = os.path.join(base_dir, platform_name)
return send_from_directory(logo_path, path)


checkSSL = isSSL(path_crt_file=ssl_cert, path_key_file=ssl_key, path_pem_file=ssl_pem)

if __name__ == "__main__":
Expand All @@ -63,4 +68,4 @@ def send_report(path):
app.run(host=api_host, port=ssl_port, ssl_context=context)
else:
app.logger.info("Running on un-secure port: %s" % api_port)
app.run(host=api_host, port=api_port)
app.run(host=api_host, port=api_port)
189 changes: 189 additions & 0 deletions src/api_v3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
"""API V3 Blueprint"""

from datetime import datetime
import calendar

from flask import Blueprint, request, jsonify
from flask_cors import CORS
from werkzeug.exceptions import BadRequest, NotFound

from phonenumbers import geocoder

from src.db import connect
from src.entity import fetch_all_entities
from src.utils import decrypt_and_decode
from base_logger import get_logger

v3_blueprint = Blueprint("v3", __name__, url_prefix="/v3")
CORS(v3_blueprint)

database = connect()

logger = get_logger(__name__)


def set_security_headers(response):
"""Set security headers for each response."""
security_headers = {
"Strict-Transport-Security": "max-age=63072000; includeSubdomains",
"X-Content-Type-Options": "nosniff",
"Content-Security-Policy": "script-src 'self'; object-src 'self'",
"Referrer-Policy": "strict-origin-when-cross-origin",
"Cache-Control": "no-cache",
"Permissions-Policy": (
"accelerometer=(), ambient-light-sensor=(), autoplay=(), battery=(), camera=(), "
"clipboard-read=(), clipboard-write=(), cross-origin-isolated=(), display-capture=(), "
"document-domain=(), encrypted-media=(), execution-while-not-rendered=(), "
"execution-while-out-of-viewport=(), fullscreen=(), gamepad=(), geolocation=(), "
"gyroscope=(), magnetometer=(), microphone=(), midi=(), navigation-override=(), "
"payment=(), picture-in-picture=(), publickey-credentials-get=(), screen-wake-lock=(), "
"speaker=(), speaker-selection=(), sync-xhr=(), usb=(), web-share=(), "
"xr-spatial-tracking=()"
),
}

for header, value in security_headers.items():
response.headers[header] = value

return response


@v3_blueprint.before_request
def _db_connect():
"""Connect to the database before processing the request."""
database.connect()


@v3_blueprint.teardown_request
def _db_close(response):
"""Close the database connection after processing the request."""
database.close()
return response


@v3_blueprint.after_request
def after_request(response):
"""Set security headers after each request."""
response = set_security_headers(response)
return response


def fetch_entities_by_month(result, start, end):
"""Fetch entities grouped by month."""
new_start = datetime.min
new_end = datetime(
end.year, end.month, calendar.monthrange(end.year, end.month)[1], 23, 59, 59
)

entities = fetch_all_entities(date_range=(new_start, new_end))

for entity in entities:
entity_date_created = entity.date_created
month_name = calendar.month_name[entity_date_created.month]

result = update_result_by_time(result, entity_date_created, month_name)
result = update_countries(result, entity)

result["total_users"] = len(entities)
result["total_countries"] = len(result["countries"])

return result


def fetch_entities_by_day(result, start, end):
"""Fetch entities grouped by day."""
new_start = datetime.min
new_end = datetime(end.year, end.month, end.day, 23, 59, 59)

entities = fetch_all_entities(date_range=(new_start, new_end))

for entity in entities:
entity_date_created = entity.date_created
day_name = entity_date_created.strftime("%c")

result = update_result_by_time(result, entity_date_created, day_name)
result = update_countries(result, entity)

result["total_users"] = len(entities)
result["total_countries"] = len(result["countries"])

return result


def update_result_by_time(result, entity_date_created, time_name):
"""Helper to update the result dictionary with time-based data."""
year = str(entity_date_created.year)

if not result.get(year):
result[year] = []

if any(time_name in x for x in result[year]):
for x in result[year]:
if x[0] == time_name:
x[1] += 1
else:
result[year].append([time_name, 1])

return result


def update_countries(result, entity):
"""Helper to update the result dictionary with country-based data."""
region_code = decrypt_and_decode(entity.country_code)
country_name = geocoder._region_display_name(region_code, "en")

if any(country_name in x for x in result["countries"]):
for x in result["countries"]:
if x[0] == country_name and x[1] == region_code:
x[2] += 1
else:
result["countries"].append([country_name, region_code, 1])

return result


@v3_blueprint.route("/entities", methods=["GET"])
def get_entities_analysis():
"""Retrieve analysis of entities."""
start = request.args.get("start")
end = request.args.get("end")
_format = request.args.get("format", "month")

if not start or not end:
raise BadRequest("Invalid input parameters. Provide 'start', 'end'.")

start = datetime.strptime(start, "%Y-%m-%d").date()
end = datetime.strptime(end, "%Y-%m-%d").date()

if start > end:
raise BadRequest("'start' date cannot be after 'end' date.")

result = {"total_users": 0, "total_countries": 0, "countries": []}

if _format == "month":
result = fetch_entities_by_month(result, start, end)
elif _format == "day":
result = fetch_entities_by_day(result, start, end)
else:
raise BadRequest("Invalid format. Expected 'month' or 'day'.")

logger.info("Successfully fetched entities data.")
return jsonify(result), 200


@v3_blueprint.errorhandler(BadRequest)
@v3_blueprint.errorhandler(NotFound)
def handle_bad_request_error(error):
"""Handle BadRequest errors."""
logger.error(error.description)
return jsonify({"error": error.description}), error.code


@v3_blueprint.errorhandler(Exception)
def handle_generic_error(error):
"""Handle generic errors."""
logger.exception(error)
return (
jsonify({"error": "Oops! Something went wrong. Please try again later."}),
500,
)
59 changes: 59 additions & 0 deletions src/entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,3 +57,62 @@ def find_entity(**search_criteria):
except DoesNotExist:
logger.debug("Entity is not found...")
return None


def fetch_all_entities(
filters=None, date_range=None, truncate_by=None, return_json=False
):
"""
Fetch all entities with optional filters, date range, and date truncation.
Args:
filters (dict, optional): A dictionary where the keys are field names
and the values are the conditions/criteria to filter the entities by.
Defaults to None.
date_range (tuple, optional): A tuple containing (start_date, end_date) to
filter the 'date_created' field by. Dates should be datetime objects.
truncate_by (str, optional): Specify if the 'date_created' should be
truncated by 'day' or 'month'. If provided, date filtering will apply
to truncated dates.
return_json (bool, optional): If True, return the results as a list of dicts.
If False (default), return the results as a list of token objects.
Returns:
list: A list of all entities matching the filter criteria.
"""
filters = filters or {}
logger.debug(
"Fetching all entities with filters: %s, date_range: %s, truncate_by: %s",
filters,
date_range,
truncate_by,
)

with database.connection_context():
query = Entity.select()
conditions = []

for field, value in filters.items():
conditions.append(getattr(Entity, field) == value)

if date_range:
start_date, end_date = date_range
if truncate_by == "day":
conditions.append(
Entity.date_created.truncate("day").between(start_date, end_date)
)
elif truncate_by == "month":
conditions.append(
Entity.date_created.truncate("month").between(start_date, end_date)
)
else:
conditions.append(Entity.date_created.between(start_date, end_date))

if conditions:
query = query.where(*conditions)

total_records = query.count()

entities = list(query.dicts()) if return_json else list(query.execute())
logger.debug("Found %s entities", total_records)
return entities

0 comments on commit 2b179a0

Please sign in to comment.