diff --git a/dsc/utilities/aws/sqs.py b/dsc/utilities/aws/sqs.py index 7b47187..9c5665e 100644 --- a/dsc/utilities/aws/sqs.py +++ b/dsc/utilities/aws/sqs.py @@ -114,8 +114,7 @@ def process_result_message(self, sqs_message: MessageTypeDef) -> tuple[str, str] Args: sqs_message: An SQS result message to be processed. """ - if not self.validate_message(sqs_message): - raise InvalidSQSMessageError + self.validate_message(sqs_message) identifier = sqs_message["MessageAttributes"]["PackageID"]["StringValue"] message_body = json.loads(str(sqs_message["Body"])) self.delete(sqs_message["ReceiptHandle"]) @@ -160,56 +159,45 @@ def send( logger.debug(f"Response from SQS queue: {response}") return response - def validate_message(self, sqs_message: MessageTypeDef) -> bool: + def validate_message(self, sqs_message: MessageTypeDef) -> None: """Validate that an SQS message is formatted as expected. Args: sqs_message: An SQS message to be evaluated. """ - valid = False if not sqs_message.get("ReceiptHandle"): - logger.exception( - f"Failed to retrieve 'ReceiptHandle' from message: {sqs_message}" - ) - elif self.validate_message_attributes( - sqs_message=sqs_message - ) and self.validate_message_body(sqs_message=sqs_message): - valid = True - return valid + message = f"Failed to retrieve 'ReceiptHandle' from message: {sqs_message}" + raise InvalidSQSMessageError(message) + self.validate_message_attributes(sqs_message=sqs_message) + self.validate_message_body(sqs_message=sqs_message) @staticmethod - def validate_message_attributes(sqs_message: MessageTypeDef) -> bool: + def validate_message_attributes(sqs_message: MessageTypeDef) -> None: """Validate that "MessageAttributes" field is formatted as expected. Args: sqs_message: An SQS message to be evaluated. """ - valid = False if ( - "MessageAttributes" in sqs_message - and any( + "MessageAttributes" not in sqs_message + or not any( field for field in sqs_message["MessageAttributes"] if "PackageID" in field ) - and sqs_message["MessageAttributes"]["PackageID"].get("StringValue") + or not sqs_message["MessageAttributes"]["PackageID"].get("StringValue") ): - valid = True - else: - logger.exception(f"Failed to parse SQS message attributes: {sqs_message}") - return valid + message = f"Failed to parse SQS message attributes: {sqs_message}" + raise InvalidSQSMessageError(message) @staticmethod - def validate_message_body(sqs_message: MessageTypeDef) -> bool: + def validate_message_body(sqs_message: MessageTypeDef) -> None: """Validate that "Body" field is formatted as expected. Args: sqs_message: An SQS message to be evaluated. """ - valid = False - if "Body" in sqs_message and json.loads(str(sqs_message["Body"])): - valid = True - else: - logger.exception(f"Failed to parse SQS message body: {sqs_message}") - return valid + if "Body" not in sqs_message or not json.loads(str(sqs_message["Body"])): + message = f"Failed to parse SQS message body: {sqs_message}" + raise InvalidSQSMessageError(message) diff --git a/tests/test_sqs.py b/tests/test_sqs.py index 906907b..f826b7b 100644 --- a/tests/test_sqs.py +++ b/tests/test_sqs.py @@ -133,31 +133,34 @@ def test_sqs_send_success( assert response["ResponseMetadata"]["HTTPStatusCode"] == HTTPStatus.OK -def test_sqs_validate_message_no_receipthandle_false( +def test_sqs_validate_message_no_receipthandle_invalid( mocked_sqs_input, sqs_client, result_message_valid ): - assert not sqs_client.validate_message(sqs_message={}) + with pytest.raises(InvalidSQSMessageError): + sqs_client.validate_message(sqs_message={}) -def test_sqs_validate_message_true(mocked_sqs_input, sqs_client, result_message_valid): - assert sqs_client.validate_message(sqs_message=result_message_valid) +def test_sqs_validate_message_valid(mocked_sqs_input, sqs_client, result_message_valid): + assert not sqs_client.validate_message(sqs_message=result_message_valid) -def test_sqs_validate_message_attributes_false(mocked_sqs_input, sqs_client): - assert not sqs_client.validate_message_attributes(sqs_message={}) +def test_sqs_validate_message_attributes_invalid(mocked_sqs_input, sqs_client): + with pytest.raises(InvalidSQSMessageError): + sqs_client.validate_message_attributes(sqs_message={}) -def test_sqs_validate_message_attributes_true( +def test_sqs_validate_message_attributes_valid( mocked_sqs_input, sqs_client, result_message_valid ): - assert sqs_client.validate_message_attributes(sqs_message=result_message_valid) + assert not sqs_client.validate_message_attributes(sqs_message=result_message_valid) -def test_sqs_validate_message_body_false(caplog, mocked_sqs_input, sqs_client): - assert not sqs_client.validate_message_body(sqs_message={None}) +def test_sqs_validate_message_body_invalid(caplog, mocked_sqs_input, sqs_client): + with pytest.raises(InvalidSQSMessageError): + sqs_client.validate_message_body(sqs_message={}) -def test_sqs_validate_message_body_true( +def test_sqs_validate_message_body_valid( mocked_sqs_input, sqs_client, result_message_valid ): - assert sqs_client.validate_message_body(sqs_message=result_message_valid) + assert not sqs_client.validate_message_body(sqs_message=result_message_valid)