From bc7bc3178a3b182b1643e14d5784823dd87693d5 Mon Sep 17 00:00:00 2001 From: Graham Hukill Date: Fri, 30 Aug 2024 11:52:42 -0400 Subject: [PATCH] First pass of adding type hints Why these changes are being introduced: This repository pre-dated a convention to add type hints and apply mypy linting. How this addresses that need: * First pass adds type hints were easy * Focus on function argument and return types * This does NOT address all typing linting errors, preferring to untangle some custom exception linting errors in a future commit Side effects of this change: * None Relevant ticket(s): * https://mitlibraries.atlassian.net/browse/IN-1059 --- Pipfile | 2 + Pipfile.lock | 96 ++++++++++++++++++++++++++++++++++++++++- pyproject.toml | 3 +- submitter/cli.py | 12 +++--- submitter/config.py | 30 +++++++------ submitter/errors.py | 2 +- submitter/message.py | 17 +++++--- submitter/sqs.py | 23 +++++++--- submitter/submission.py | 60 ++++++++++++++++---------- 9 files changed, 187 insertions(+), 58 deletions(-) diff --git a/Pipfile b/Pipfile index 0612e18..d9d8372 100644 --- a/Pipfile +++ b/Pipfile @@ -24,6 +24,8 @@ ruff = "*" safety = "*" pre-commit = "*" mypy = "*" +types-requests = "*" +boto3-stubs = {extras = ["essential"], version = "*"} [requires] python_version = "3.12" diff --git a/Pipfile.lock b/Pipfile.lock index 79f52e4..9a3d198 100644 --- a/Pipfile.lock +++ b/Pipfile.lock @@ -1,7 +1,7 @@ { "_meta": { "hash": { - "sha256": "45ee5b77d825bd3326278f7f6a82a76b33d18dd2732d9ac0619d6e77ed0ace7b" + "sha256": "17022a654f8221c2957356352640ac68931329e76d4743409fcf9307d2d036a3" }, "pipfile-spec": 6, "requires": { @@ -321,6 +321,17 @@ "markers": "python_version >= '3.8'", "version": "==1.35.9" }, + "boto3-stubs": { + "extras": [ + "essential" + ], + "hashes": [ + "sha256:4aeffc47379b78237c1bbec0d6928ddabca3434f6fd065148905c47fe0e9b2f6", + "sha256:77754883031f49f61e2f2bd6df189dade8d3138f9fd07a7eb49f4a7e37156c7d" + ], + "markers": "python_version >= '3.8'", + "version": "==1.35.9" + }, "botocore": { "hashes": [ "sha256:92962460e4f35d139a23bca28149722030143257ee2916de442243c2464a7434", @@ -329,6 +340,14 @@ "markers": "python_version >= '3.8'", "version": "==1.35.9" }, + "botocore-stubs": { + "hashes": [ + "sha256:59ffbfcac1833990e6c586bff4fcf481ef01b4b88e6d65073d842f6aff84b92e", + "sha256:98c1d7ca0147cd88cc09f08425a2d58bd004232d211c4a4ca9febda6f3b20a85" + ], + "markers": "python_version >= '3.8'", + "version": "==1.35.9" + }, "certifi": { "hashes": [ "sha256:922820b53db7a7257ffbda3f597266d435245903d80737e34f8a45ff3e3230d8", @@ -1016,6 +1035,55 @@ "markers": "python_version >= '3.8'", "version": "==1.11.2" }, + "mypy-boto3-cloudformation": { + "hashes": [ + "sha256:0d037d9d6bdb439a84e2391ba987a4e03fcedfad0e881db1cf0f7861d275907c", + "sha256:5da07e14a206a7f0015434d1730a6a68a33167ea6746343189dd1742cfcfdb7d" + ], + "version": "==1.35.0" + }, + "mypy-boto3-dynamodb": { + "hashes": [ + "sha256:1e503c89a5aa65f2b90fc7c861d3630a21544822f30b38e67e4f52463111abb9", + "sha256:75f224d8b78f6d3126eead645aea6c0a8bc2828614f302c168de1d3dad490d11" + ], + "version": "==1.35.0" + }, + "mypy-boto3-ec2": { + "hashes": [ + "sha256:b3e17ee6082a107d7d6d7ac44062264a9fb711c5d6d9e0ce16837cda26d1be7c", + "sha256:f4cdbe524ff4039668cc168e3c6f9c68048481ab33dfb0f5d892bbf2428d1ef2" + ], + "version": "==1.35.8" + }, + "mypy-boto3-lambda": { + "hashes": [ + "sha256:2e78c12a7ba4d2d9c99b75fad58804fd99820e954ab557f14f099d6c85a882ab", + "sha256:b59e45facfc166eddb1d5c2696aa8127463455f9e439e3438494965bcd97c97d" + ], + "version": "==1.35.3" + }, + "mypy-boto3-rds": { + "hashes": [ + "sha256:8861b551854cabec2efbe40db506297e9526e1496a1e55843136df716a2b7a00", + "sha256:c252857561219ecc0a03b2d3936081d7a54a59d1caa01e69deb8cdea761dab76" + ], + "version": "==1.35.0" + }, + "mypy-boto3-s3": { + "hashes": [ + "sha256:74d8f3492eeff768ff6f69ac6d40bf68b40aa6e54ebe10a8d098fc3d24a54abf", + "sha256:f7300b559dee5435872625448becf159abe36b19cd7006dd78e0d51610312183" + ], + "version": "==1.35.2" + }, + "mypy-boto3-sqs": { + "hashes": [ + "sha256:61752f1c2bf2efa3815f64d43c25b4a39dbdbd9e472ae48aa18d7c6d2a7a6eb8", + "sha256:9fd6e622ed231c06f7542ba6f8f0eea92046cace24defa95d0d0ce04e7caee0c" + ], + "version": "==1.35.0" + }, "mypy-extensions": { "hashes": [ "sha256:4392f6c0eb8a5668a69e23d168ffa70f0be9ccfd32b5cc2d26a34ae5b844552d", @@ -1358,6 +1426,7 @@ "sha256:f753120cb8181e736c57ef7636e83f31b9c0d1722c516f7e86cf15b7aa57ff12", "sha256:ff3824dc5261f50c9b0dfb3be22b4567a6f938ccce4587b38952d85fd9e9afe4" ], + "markers": "python_version >= '3.8'", "version": "==6.0.2" }, "referencing": { @@ -1757,6 +1826,31 @@ "markers": "python_version >= '3.7'", "version": "==0.12.5" }, + "types-awscrt": { + "hashes": [ + "sha256:0839fe12f0f914d8f7d63ed777c728cb4eccc2d5d79a26e377d12b0604e7bf0e", + "sha256:84a9f4f422ec525c314fdf54c23a1e73edfbcec968560943ca2d41cfae623b38" + ], + "markers": "python_version >= '3.7' and python_version < '4.0'", + "version": "==0.21.2" + }, + "types-requests": { + "hashes": [ + "sha256:90c079ff05e549f6bf50e02e910210b98b8ff1ebdd18e19c873cd237737c1358", + "sha256:f754283e152c752e46e70942fa2a146b5bc70393522257bb85bd1ef7e019dcc3" + ], + "index": "pypi", + "markers": "python_version >= '3.8'", + "version": "==2.32.0.20240712" + }, + "types-s3transfer": { + "hashes": [ + "sha256:60167a3bfb5c536ec6cdb5818f7f9a28edca9dc3e0b5ff85ae374526fc5e576e", + "sha256:7a3fec8cd632e2b5efb665a355ef93c2a87fdd5a45b74a949f95a9e628a86356" + ], + "markers": "python_version >= '3.8'", + "version": "==0.10.2" + }, "typing-extensions": { "hashes": [ "sha256:04e5ca0351e0f3f85c6853954072df659d0d13fac324d0072316b67d7794700d", diff --git a/pyproject.toml b/pyproject.toml index e62aba6..53e63ae 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,8 @@ line-length = 90 [tool.mypy] disallow_untyped_calls = true disallow_untyped_defs = true -exclude = ["tests/"] +exclude = ["tests/", "output/"] +ignore_missing_imports = true [tool.pytest.ini_options] log_level = "INFO" diff --git a/submitter/cli.py b/submitter/cli.py index 595081b..0f014f8 100644 --- a/submitter/cli.py +++ b/submitter/cli.py @@ -15,7 +15,7 @@ @click.group() -def main(): +def main() -> None: pass @@ -24,7 +24,7 @@ def main(): "--queue", default=CONFIG.INPUT_QUEUE, help="Name of queue to process messages from" ) @click.option("--wait", default=20, help="seconds to wait for long polling. max 20") -def start(queue, wait): +def start(queue: str, wait: int) -> None: logger.info("Starting processing messages from queue %s", queue) message_loop(queue, wait) logger.info("Completed processing messages from queue %s", queue) @@ -50,7 +50,7 @@ def start(queue, wait): default="tests/fixtures/completely-fake-data.json", help="Path to json file of sample messages to load", ) -def load_sample_input_data(input_queue, output_queue, filepath): +def load_sample_input_data(input_queue: str, output_queue: str, filepath: str) -> None: logger.info(f"Loading sample data from file '{filepath}' into queue {input_queue}") count = 0 messages = generate_submission_messages_from_file(filepath, output_queue) @@ -73,7 +73,7 @@ def load_sample_input_data(input_queue, output_queue, filepath): default="tests/fixtures/completely-fake-data.json", help="Path to json file of sample messages to load", ) -def load_sample_output_data(output_queue, filepath): +def load_sample_output_data(output_queue: str, filepath: str) -> None: logger.info(f"Loading sample data from file '{filepath}' into queue {output_queue}") count = 0 messages = generate_result_messages_from_file(filepath, output_queue) @@ -85,14 +85,14 @@ def load_sample_output_data(output_queue, filepath): @main.command() @click.argument("name") -def create_queue(name): +def create_queue(name: str) -> None: """Create queue with NAME supplied as argument""" queue = create(name) logger.info(queue.url) @main.command() -def verify_dspace_connection(): +def verify_dspace_connection() -> None: client = DSpaceClient(CONFIG.DSPACE_API_URL, timeout=CONFIG.DSPACE_TIMEOUT) try: client.login(CONFIG.DSPACE_USER, CONFIG.DSPACE_PASSWORD) diff --git a/submitter/config.py b/submitter/config.py index f6e6edd..c2f18ea 100644 --- a/submitter/config.py +++ b/submitter/config.py @@ -5,7 +5,7 @@ class Config: - def __init__(self): + def __init__(self) -> None: try: self.ENV = os.environ["WORKSPACE"] except KeyError as e: @@ -15,7 +15,21 @@ def __init__(self): print(f"Configuring dspace-submission-service for env={self.ENV}") self.load_config_variables(self.ENV) - def load_config_variables(self, env: str): + def load_config_variables(self, env: str) -> None: + # default to using env vars with defaults + self.DSPACE_API_URL = os.getenv("DSPACE_API_URL") + self.DSPACE_USER = os.getenv("DSPACE_USER") + self.DSPACE_PASSWORD = os.getenv("DSPACE_PASSWORD") + self.DSPACE_TIMEOUT = float(os.getenv("DSPACE_TIMEOUT", "120.0")) + self.INPUT_QUEUE = os.getenv("INPUT_QUEUE") + self.LOG_FILTER = os.getenv("LOG_FILTER", "true").lower() + self.LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper() + self.SENTRY_DSN = os.getenv("SENTRY_DSN") + self.SKIP_PROCESSING = os.getenv("SKIP_PROCESSING", "false").lower() + self.SQS_ENDPOINT_URL = os.getenv("SQS_ENDPOINT_URL") + self.OUTPUT_QUEUES = os.getenv("OUTPUT_QUEUES", "output").split(",") + + # if testing environment, override if env == "test": self.DSPACE_API_URL = "mock://dspace.edu/rest/" self.DSPACE_USER = "test" @@ -28,15 +42,3 @@ def load_config_variables(self, env: str): self.SKIP_PROCESSING = "false" self.SQS_ENDPOINT_URL = "https://sqs.us-east-1.amazonaws.com/" self.OUTPUT_QUEUES = ["empty_result_queue"] - else: - self.DSPACE_API_URL = os.getenv("DSPACE_API_URL") - self.DSPACE_USER = os.getenv("DSPACE_USER") - self.DSPACE_PASSWORD = os.getenv("DSPACE_PASSWORD") - self.DSPACE_TIMEOUT = float(os.getenv("DSPACE_TIMEOUT", "120.0")) - self.INPUT_QUEUE = os.getenv("INPUT_QUEUE") - self.LOG_FILTER = os.getenv("LOG_FILTER", "true").lower() - self.LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper() - self.SENTRY_DSN = os.getenv("SENTRY_DSN") - self.SKIP_PROCESSING = os.getenv("SKIP_PROCESSING", "false").lower() - self.SQS_ENDPOINT_URL = os.getenv("SQS_ENDPOINT_URL") - self.OUTPUT_QUEUES = os.getenv("OUTPUT_QUEUES", "output").split(",") diff --git a/submitter/errors.py b/submitter/errors.py index 259ba51..a76b903 100644 --- a/submitter/errors.py +++ b/submitter/errors.py @@ -57,7 +57,7 @@ class BitstreamAddError(Exception): message (str): Explanation of the error """ - def __init__(self): + def __init__(self) -> None: self.message = ( "Error occurred while parsing bitstream information from files listed in " "submission message." diff --git a/submitter/message.py b/submitter/message.py index 3052ac7..26a4c80 100644 --- a/submitter/message.py +++ b/submitter/message.py @@ -1,7 +1,10 @@ +from collections.abc import Iterator import json -def generate_submission_messages_from_file(filepath, output_queue): +def generate_submission_messages_from_file( + filepath: str, output_queue: str +) -> Iterator[tuple[dict, dict]]: with open(filepath) as file: messages = json.load(file) @@ -11,7 +14,7 @@ def generate_submission_messages_from_file(filepath, output_queue): yield attributes, body -def attributes_from_json(message_json, output_queue): +def attributes_from_json(message_json: dict, output_queue: str) -> dict: attributes = { "PackageID": { "DataType": "String", @@ -29,7 +32,7 @@ def attributes_from_json(message_json, output_queue): return attributes -def body_from_json(message_json): +def body_from_json(message_json: dict) -> dict: body = { "SubmissionSystem": message_json["target system"], "CollectionHandle": message_json["collection handle"], @@ -46,7 +49,9 @@ def body_from_json(message_json): return body -def generate_result_messages_from_file(filepath, output_queue): +def generate_result_messages_from_file( + filepath: str, _output_queue: str +) -> Iterator[tuple[dict, dict]]: with open(filepath) as file: messages = json.load(file) @@ -56,7 +61,7 @@ def generate_result_messages_from_file(filepath, output_queue): yield attributes, body -def result_attributes_from_json(message_json): +def result_attributes_from_json(message_json: dict) -> dict: attributes = { "PackageID": { "DataType": "String", @@ -70,7 +75,7 @@ def result_attributes_from_json(message_json): return attributes -def result_body_from_json(message_json): +def result_body_from_json(message_json: dict) -> dict: body = { "ResultType": message_json["result"], "ItemHandle": message_json["handle"], diff --git a/submitter/sqs.py b/submitter/sqs.py index 02affbd..5580d7e 100644 --- a/submitter/sqs.py +++ b/submitter/sqs.py @@ -1,6 +1,7 @@ import hashlib import json import logging +from typing import TYPE_CHECKING import boto3 from dspace.client import DSpaceClient @@ -8,10 +9,14 @@ from submitter import CONFIG, errors from submitter.submission import Submission +if TYPE_CHECKING: + from mypy_boto3_sqs.service_resource import Message, Queue, SQSServiceResource + from mypy_boto3_sqs.type_defs import SendMessageResultTypeDef + logger = logging.getLogger(__name__) -def sqs_client(): +def sqs_client() -> "SQSServiceResource": sqs = boto3.resource( service_name="sqs", region_name=CONFIG.AWS_REGION_NAME, @@ -21,7 +26,7 @@ def sqs_client(): return sqs -def message_loop(queue, wait, visibility=30): +def message_loop(queue: str, wait: int, visibility: int = 30) -> None: logger.info("Message loop started") msgs = retrieve_messages_from_queue(queue, wait, visibility) @@ -32,7 +37,7 @@ def message_loop(queue, wait, visibility=30): logger.info("No messages available in queue %s", queue) -def process(msgs): +def process(msgs: list["Message"]) -> None: if CONFIG.SKIP_PROCESSING != "true": client = DSpaceClient(CONFIG.DSPACE_API_URL, timeout=CONFIG.DSPACE_TIMEOUT) client.login(CONFIG.DSPACE_USER, CONFIG.DSPACE_PASSWORD) @@ -70,7 +75,11 @@ def process(msgs): logger.info("Deleted message '%s' from input queue", message_id) -def retrieve_messages_from_queue(input_queue, wait, visibility=30): +def retrieve_messages_from_queue( + input_queue: str, + wait: int, + visibility: int = 30, +) -> list["Message"]: sqs = sqs_client() queue = sqs.get_queue_by_name(QueueName=input_queue) @@ -87,7 +96,9 @@ def retrieve_messages_from_queue(input_queue, wait, visibility=30): return msgs -def write_message_to_queue(attributes: dict, body: dict, output_queue: str): +def write_message_to_queue( + attributes: dict, body: dict, output_queue: str +) -> "SendMessageResultTypeDef": sqs = sqs_client() queue = sqs.get_queue_by_name(QueueName=output_queue) response = queue.send_message( @@ -97,7 +108,7 @@ def write_message_to_queue(attributes: dict, body: dict, output_queue: str): return response -def create(name): +def create(name: str) -> "Queue": sqs = sqs_client() queue = sqs.create_queue(QueueName=name) return queue diff --git a/submitter/submission.py b/submitter/submission.py index 1608707..ef519b9 100644 --- a/submitter/submission.py +++ b/submitter/submission.py @@ -1,29 +1,36 @@ +from collections.abc import Iterator import json import logging import sys import traceback from datetime import datetime +from typing import TYPE_CHECKING import dspace +from dspace.client import DSpaceClient import requests import smart_open from submitter import CONFIG, errors +if TYPE_CHECKING: + from mypy_boto3_sqs.service_resource import Message + from mypy_boto3_sqs.type_defs import MessageAttributeValueExtraOutputTypeDef + logger = logging.getLogger(__name__) class Submission: def __init__( self, - attributes, - result_queue, - result_message=None, - destination=None, - collection_handle=None, - metadata_location=None, - files=None, - ): + attributes: dict, + result_queue: str, + result_message: dict | None = None, + destination: str | None = None, + collection_handle: str | None = None, + metadata_location: str | None = None, + files: list[dict] | None = None, + ) -> None: self.destination = destination self.collection_handle = collection_handle self.metadata_location = metadata_location @@ -33,7 +40,7 @@ def __init__( self.result_queue = result_queue @classmethod - def from_message(cls, message): + def from_message(cls, message: "Message") -> "Submission": """ Create a submission with all necessary publishing data from a submit message. @@ -44,11 +51,12 @@ def from_message(cls, message): SubmitMessageInvalidResultQueueError SubmitMessageMissingAttributeError """ - result_message = None - result_queue = message.message_attributes.get("OutputQueue", {}).get( - "StringValue", None + output_queue: MessageAttributeValueExtraOutputTypeDef | dict = ( + message.message_attributes.get("OutputQueue", {}) ) - if result_queue not in CONFIG.OUTPUT_QUEUES: + result_queue = output_queue.get("StringValue") + + if not result_queue or result_queue not in CONFIG.OUTPUT_QUEUES: raise errors.SubmitMessageInvalidResultQueueError( message.message_id, result_queue ) @@ -87,7 +95,7 @@ def from_message(cls, message): result_queue=result_queue, ) - def create_item(self): + def create_item(self) -> dspace.item.Item: """Create item instance with metadata entries from submission message.""" try: logger.debug("Creating local item instance from submission message") @@ -99,13 +107,13 @@ def create_item(self): except KeyError as e: raise errors.ItemCreateError(self.metadata_location) from e - def get_metadata_entries_from_file(self): + def get_metadata_entries_from_file(self) -> Iterator[dict]: with smart_open.open(self.metadata_location) as f: metadata = json.load(f) for entry in metadata["metadata"]: yield entry - def add_bitstreams_to_item(self, item): + def add_bitstreams_to_item(self, item: dspace.item.Item) -> dspace.item.Item: """Add bitstreams to item from files in submission message.""" try: logger.debug( @@ -122,7 +130,9 @@ def add_bitstreams_to_item(self, item): except KeyError as e: raise errors.BitstreamAddError() from e - def result_error_message(self, message, dspace_response=None): + def result_error_message( + self, message: "Message", dspace_response: str | None = None + ) -> None: time = datetime.now() tb = traceback.format_exception(*sys.exc_info()) self.result_message = { @@ -133,7 +143,7 @@ def result_error_message(self, message, dspace_response=None): "ExceptionTraceback": prettify(tb), } - def result_success_message(self, item): + def result_success_message(self, item: dspace.item.Item) -> None: self.result_message = { "ResultType": "success", "ItemHandle": item.handle, @@ -150,7 +160,7 @@ def result_success_message(self, item): } ) - def submit(self, client): + def submit(self, client: DSpaceClient) -> None: """Submit a submission to DSpace as a new item with associated bitstreams. Creates a local item instance from the submission message, adds bitstream @@ -192,7 +202,11 @@ def submit(self, client): raise e -def post_item(client, item, collection_handle): +def post_item( + client: DSpaceClient, + item: dspace.item.Item, + collection_handle: str, +) -> None: """Post item with metadata to DSpace.""" try: entries = [entry.to_dict() for entry in item.metadata] @@ -208,7 +222,7 @@ def post_item(client, item, collection_handle): raise errors.ItemPostError(e, collection_handle) from e -def post_bitstreams(client, item): +def post_bitstreams(client: DSpaceClient, item: dspace.item.Item) -> None: """Post all bitstreams to an existing DSpace item.""" logger.debug( "Posting %d bitstream(s) to item '%s' in DSpace", @@ -230,7 +244,7 @@ def post_bitstreams(client, item): raise errors.BitstreamPostError(e, bitstream.name, item.handle) from e -def clean_up_partial_success(client, item): +def clean_up_partial_success(client: DSpaceClient, item: dspace.item.Item) -> None: logger.info("Item '%s' was partially posted to DSpace, cleaning up", item.handle) handle = item.handle for bitstream in item.bitstreams: @@ -242,7 +256,7 @@ def clean_up_partial_success(client, item): logger.info("Item '%s' deleted from DSpace", handle) -def prettify(traceback: list): +def prettify(traceback: list) -> list[str]: output = [] for item in traceback: lines = item.strip().split("\n")