From 5271e166e0bb09ae30da626aadd8a1790f24793a Mon Sep 17 00:00:00 2001 From: Rio Knightley <128376976+RioKnightleyNHS@users.noreply.github.com> Date: Thu, 31 Oct 2024 13:06:11 +0000 Subject: [PATCH] PRMP 1036 - Fix PDF Intermittence --- .../helpers/requests/getLloydGeorgeRecord.ts | 10 +- lambdas/enums/lambda_error.py | 19 +- lambdas/handlers/edge_presign_handler.py | 48 ++--- lambdas/services/edge_presign_service.py | 61 +++++- lambdas/tests/unit/conftest.py | 5 +- .../unit/enums/test_edge_presign_values.py | 80 ++++---- .../handlers/test_edge_presign_handler.py | 149 ++++++++------ .../services/test_edge_presign_service.py | 182 +++++++++++------- .../decorators/handle_edge_exceptions.py | 26 +-- .../utils/decorators/validate_s3_request.py | 69 +++++++ 10 files changed, 413 insertions(+), 236 deletions(-) create mode 100644 lambdas/utils/decorators/validate_s3_request.py diff --git a/app/src/helpers/requests/getLloydGeorgeRecord.ts b/app/src/helpers/requests/getLloydGeorgeRecord.ts index f422efe96..0fb529ce0 100644 --- a/app/src/helpers/requests/getLloydGeorgeRecord.ts +++ b/app/src/helpers/requests/getLloydGeorgeRecord.ts @@ -86,12 +86,8 @@ export const pollForPresignedUrl = async ({ if (data.jobStatus === JOB_STATUS.COMPLETED && !data.presignedUrl.startsWith('https://')) { return Promise.reject({ response: { status: 500 } }); } - - return { - ...data, - presignedUrl: `${data.presignedUrl}&origin=${ - typeof window !== 'undefined' ? window.location.href : '' - }`, - }; + const result: LloydGeorgeStitchResult = data; + return result; }; + export default getLloydGeorgeRecord; diff --git a/lambdas/enums/lambda_error.py b/lambdas/enums/lambda_error.py index 7b219c10a..2f5e28b39 100644 --- a/lambdas/enums/lambda_error.py +++ b/lambdas/enums/lambda_error.py @@ -396,12 +396,29 @@ def to_str(self) -> str: """ EdgeMalformed = { "err_code": "CE_5001", - "message": "Malformed event structure or missing data", + "message": "Malformed cloudfront request", } + EdgeNoOrigin = { "err_code": "CE_5002", "message": "The request is missing an origin", } + + EdgeNoQuery = { + "err_code": "CE_5003", + "message": "The request is missing a querystring", + } + + EdgeRequiredQuery = { + "err_code": "CE_5004", + "message": "Missing required querystring values", + } + + EdgeRequiredHeaders = { + "err_code": "CE_5005", + "message": "Malformed header structure or missing data", + } + EdgeNoClient = {"err_code": "CE_4001", "message": "Document not found"} """ diff --git a/lambdas/handlers/edge_presign_handler.py b/lambdas/handlers/edge_presign_handler.py index c7bb7dfa2..3863af65b 100644 --- a/lambdas/handlers/edge_presign_handler.py +++ b/lambdas/handlers/edge_presign_handler.py @@ -1,14 +1,10 @@ -import hashlib -import json import logging -from urllib.parse import parse_qs -from enums.lambda_error import LambdaError from services.edge_presign_service import EdgePresignService from utils.decorators.handle_edge_exceptions import handle_edge_exceptions from utils.decorators.override_error_check import override_error_check from utils.decorators.set_audit_arg import set_request_context_for_logging -from utils.lambda_exceptions import CloudFrontEdgeException +from utils.decorators.validate_s3_request import validate_s3_request logger = logging.getLogger() logger.setLevel(logging.INFO) @@ -17,40 +13,18 @@ @set_request_context_for_logging @override_error_check @handle_edge_exceptions +@validate_s3_request def lambda_handler(event, context): - try: - request: dict = event["Records"][0]["cf"]["request"] - logger.info("CloudFront received S3 request", {"Result": {json.dumps(request)}}) - uri: str = request.get("uri", "") - presign_query_string: str = request.get("querystring", "") - - except (KeyError, IndexError) as e: - logger.error( - f"{str(e)}", - {"Result": {LambdaError.EdgeMalformed.to_str()}}, - ) - raise CloudFrontEdgeException(500, LambdaError.EdgeMalformed) - - s3_presign_credentials = parse_qs(presign_query_string) - origin_url = s3_presign_credentials.get("origin", [""])[0] - if not origin_url: - logger.error( - "No Origin", - {"Result": {LambdaError.EdgeNoOrigin.to_str()}}, - ) - raise CloudFrontEdgeException(500, LambdaError.EdgeNoOrigin) - - presign_string = f"{uri}?{presign_query_string}" - encoded_presign_string: str = presign_string.encode("utf-8") - presign_credentials_hash = hashlib.md5(encoded_presign_string).hexdigest() + request: dict = event["Records"][0]["cf"]["request"] + logger.info("Edge received S3 request") edge_presign_service = EdgePresignService() - edge_presign_service.attempt_url_update( - uri_hash=presign_credentials_hash, origin_url=origin_url - ) + request_values: dict = edge_presign_service.filter_request_values(request) + edge_presign_service.use_presign(request_values) - headers: dict = request.get("headers", {}) - if "authorization" in headers: - del headers["authorization"] + forwarded_request: dict = edge_presign_service.update_s3_headers( + request, request_values + ) - return request + logger.info("Edge forwarding S3 request") + return forwarded_request diff --git a/lambdas/services/edge_presign_service.py b/lambdas/services/edge_presign_service.py index f689f374a..cab057619 100644 --- a/lambdas/services/edge_presign_service.py +++ b/lambdas/services/edge_presign_service.py @@ -1,3 +1,4 @@ +import hashlib import re from botocore.exceptions import ClientError @@ -12,23 +13,37 @@ class EdgePresignService: - def __init__(self): self.dynamo_service = DynamoDBService() self.s3_service = S3Service() self.ssm_service = SSMService() self.table_name_ssm_param = "EDGE_REFERENCE_TABLE" - def attempt_url_update(self, uri_hash, origin_url) -> None: + def use_presign(self, request_values: dict): + uri: str = request_values["uri"] + querystring: str = request_values["querystring"] + domain_name: str = request_values["domain_name"] + + presign_string: str = f"{uri}?{querystring}" + encoded_presign_string: str = presign_string.encode("utf-8") + presign_credentials_hash: str = hashlib.md5(encoded_presign_string).hexdigest() + + self.attempt_presign_ingestion( + uri_hash=presign_credentials_hash, + domain_name=domain_name, + ) + + def attempt_presign_ingestion(self, uri_hash: str, domain_name: str) -> None: try: - environment = self.extract_environment_from_url(origin_url) + environment = self.filter_domain_for_env(domain_name) + logger.info(f"Environment found: {environment}") base_table_name: str = self.ssm_service.get_ssm_parameter( self.table_name_ssm_param ) formatted_table_name: str = self.extend_table_name( base_table_name, environment ) - + logger.info(f"Table: {formatted_table_name}") self.dynamo_service.update_item( table_name=formatted_table_name, key=uri_hash, @@ -40,13 +55,43 @@ def attempt_url_update(self, uri_hash, origin_url) -> None: logger.error(f"{str(e)}", {"Result": LambdaError.EdgeNoClient.to_str()}) raise CloudFrontEdgeException(400, LambdaError.EdgeNoClient) - def extract_environment_from_url(self, url: str) -> str: - match = re.search(r"https://([^.]+)\.[^.]+\.[^.]+\.[^.]+", url) + @staticmethod + def update_s3_headers(request: dict, request_values: dict): + domain_name = request_values["domain_name"] + if "authorization" in request["headers"]: + del request["headers"]["authorization"] + request["headers"]["host"] = [{"key": "Host", "value": domain_name}] + + return request + + @staticmethod + def filter_request_values(request: dict) -> dict: + try: + uri: str = request["uri"] + querystring: str = request["querystring"] + headers: dict = request["headers"] + origin: str = request.get("origin", {}) + domain_name: str = origin["s3"]["domainName"] + except KeyError as e: + logger.error(f"Missing request component: {str(e)}") + raise CloudFrontEdgeException(500, LambdaError.EdgeNoOrigin) + + return { + "uri": uri, + "querystring": querystring, + "headers": headers, + "domain_name": domain_name, + } + + @staticmethod + def filter_domain_for_env(domain_name: str) -> str: + match = re.match(r"^[^-]+(?:-[^-]+)?(?=-lloyd)", domain_name) if match: - return match.group(1) + return match.group(0) return "" - def extend_table_name(self, base_table_name, environment) -> str: + @staticmethod + def extend_table_name(base_table_name: str, environment: str) -> str: if environment: return f"{environment}_{base_table_name}" return base_table_name diff --git a/lambdas/tests/unit/conftest.py b/lambdas/tests/unit/conftest.py index 4dc7be607..ce4b25927 100644 --- a/lambdas/tests/unit/conftest.py +++ b/lambdas/tests/unit/conftest.py @@ -14,10 +14,11 @@ REGION_NAME = "eu-west-2" -MOCK_CLOUDFRONT_URL = "test-cloudfront-url.com" MOCK_TABLE_NAME = "test-table" MOCK_BUCKET = "test-s3-bucket" - +MOCK_CLOUDFRONT_URL = "test-cloudfront-url.com" +MOCKED_LG_BUCKET_ENV = "test" +MOCKED_LG_BUCKET_URL = f"{MOCKED_LG_BUCKET_ENV}-lloyd-test-test.com" MOCK_ARF_TABLE_NAME_ENV_NAME = "DOCUMENT_STORE_DYNAMODB_NAME" MOCK_ARF_BUCKET_ENV_NAME = "DOCUMENT_STORE_BUCKET_NAME" diff --git a/lambdas/tests/unit/enums/test_edge_presign_values.py b/lambdas/tests/unit/enums/test_edge_presign_values.py index 0fd2cb914..f32f73b41 100644 --- a/lambdas/tests/unit/enums/test_edge_presign_values.py +++ b/lambdas/tests/unit/enums/test_edge_presign_values.py @@ -1,58 +1,58 @@ -# test_enums.py - from enums.lambda_error import LambdaError +from tests.unit.conftest import MOCKED_LG_BUCKET_URL -ENV = "test" +MOCKED_AUTH_QUERY = ( + "X-Amz-Algorithm=algo&X-Amz-Credential=cred&X-Amz-Date=date" + "&X-Amz-Expires=3600&X-Amz-SignedHeaders=signed" + "&X-Amz-Signature=sig&X-Amz-Security-Token=token" +) +MOCKED_PARTIAL_QUERY = ( + "X-Amz-Algorithm=algo&X-Amz-Credential=cred&X-Amz-Date=date" "&X-Amz-Expires=3600" +) -TABLE_NAME = "CloudFrontEdgeReference" +MOCKED_HEADERS = { + "cloudfront-viewer-country": [{"key": "CloudFront-Viewer-Country", "value": "US"}], + "x-forwarded-for": [{"key": "X-Forwarded-For", "value": "1.2.3.4"}], + "host": [{"key": "Host", "value": MOCKED_LG_BUCKET_URL}], +} -NHS_DOMAIN = "example.gov.uk" +EXPECTED_EDGE_NO_QUERY_MESSAGE = LambdaError.EdgeNoQuery.value["message"] +EXPECTED_EDGE_NO_QUERY_ERROR_CODE = LambdaError.EdgeNoQuery.value["err_code"] +EXPECTED_EDGE_MALFORMED_QUERY_MESSAGE = LambdaError.EdgeRequiredQuery.value["message"] +EXPECTED_EDGE_MALFORMED_QUERY_ERROR_CODE = LambdaError.EdgeRequiredQuery.value[ + "err_code" +] +EXPECTED_EDGE_MALFORMED_HEADER_MESSAGE = LambdaError.EdgeRequiredHeaders.value[ + "message" +] +EXPECTED_EDGE_MALFORMED_HEADER_ERROR_CODE = LambdaError.EdgeRequiredHeaders.value[ + "err_code" +] +EXPECTED_EDGE_NO_ORIGIN_ERROR_MESSAGE = LambdaError.EdgeNoOrigin.value["message"] +EXPECTED_EDGE_NO_ORIGIN_ERROR_CODE = LambdaError.EdgeNoOrigin.value["err_code"] EXPECTED_EDGE_NO_CLIENT_ERROR_MESSAGE = LambdaError.EdgeNoClient.value["message"] - EXPECTED_EDGE_NO_CLIENT_ERROR_CODE = LambdaError.EdgeNoClient.value["err_code"] +EXPECTED_EDGE_MALFORMED_ERROR_MESSAGE = LambdaError.EdgeMalformed.value["message"] +EXPECTED_EDGE_MALFORMED_ERROR_CODE = LambdaError.EdgeMalformed.value["err_code"] -EXPECTED_DYNAMO_DB_CONDITION_EXPRESSION = ( - "attribute_not_exists(IsRequested) OR IsRequested = :false" -) -EXPECTED_DYNAMO_DB_EXPRESSION_ATTRIBUTE_VALUES = {":false": False} - -EXPECTED_SSM_PARAMETER_KEY = "EDGE_REFERENCE_TABLE" - -EXPECTED_SUCCESS_RESPONSE = None -VALID_EVENT_MODEL = { +MOCK_S3_EDGE_EVENT = { "Records": [ { "cf": { "request": { - "headers": { - "authorization": [ - {"key": "Authorization", "value": "Bearer token"} - ], - "host": [{"key": "Host", "value": NHS_DOMAIN}], - }, - "querystring": f"origin=https://test.{NHS_DOMAIN}&other=param", + "headers": MOCKED_HEADERS, + "querystring": MOCKED_AUTH_QUERY, "uri": "/some/path", - } - } - } - ] -} - -MISSING_ORIGIN_EVENT_MODEL = { - "Records": [ - { - "cf": { - "request": { - "headers": { - "authorization": [ - {"key": "Authorization", "value": "Bearer token"} - ], - "host": [{"key": "Host", "value": NHS_DOMAIN}], + "origin": { + "s3": { + "authMethod": "none", + "customHeaders": {}, + "domainName": MOCKED_LG_BUCKET_URL, + "path": "", + } }, - "querystring": "other=param", - "uri": "/some/path", } } } diff --git a/lambdas/tests/unit/handlers/test_edge_presign_handler.py b/lambdas/tests/unit/handlers/test_edge_presign_handler.py index 9a07bc911..26743c69d 100644 --- a/lambdas/tests/unit/handlers/test_edge_presign_handler.py +++ b/lambdas/tests/unit/handlers/test_edge_presign_handler.py @@ -1,96 +1,123 @@ -from unittest.mock import Mock +import copy +import json import pytest -from botocore.exceptions import ClientError -from services.edge_presign_service import EdgePresignService +from handlers.edge_presign_handler import lambda_handler +from tests.unit.conftest import MOCK_TABLE_NAME, MOCKED_LG_BUCKET_URL from tests.unit.enums.test_edge_presign_values import ( - ENV, - EXPECTED_DYNAMO_DB_CONDITION_EXPRESSION, - EXPECTED_DYNAMO_DB_EXPRESSION_ATTRIBUTE_VALUES, - EXPECTED_EDGE_NO_CLIENT_ERROR_CODE, - EXPECTED_EDGE_NO_CLIENT_ERROR_MESSAGE, - EXPECTED_SSM_PARAMETER_KEY, - MISSING_ORIGIN_EVENT_MODEL, - NHS_DOMAIN, - TABLE_NAME, - VALID_EVENT_MODEL, + EXPECTED_EDGE_MALFORMED_HEADER_ERROR_CODE, + EXPECTED_EDGE_MALFORMED_HEADER_MESSAGE, + EXPECTED_EDGE_MALFORMED_QUERY_ERROR_CODE, + EXPECTED_EDGE_MALFORMED_QUERY_MESSAGE, + EXPECTED_EDGE_NO_ORIGIN_ERROR_CODE, + EXPECTED_EDGE_NO_ORIGIN_ERROR_MESSAGE, + EXPECTED_EDGE_NO_QUERY_ERROR_CODE, + EXPECTED_EDGE_NO_QUERY_MESSAGE, + MOCK_S3_EDGE_EVENT, + MOCKED_AUTH_QUERY, + MOCKED_PARTIAL_QUERY, ) -from utils.lambda_exceptions import CloudFrontEdgeException - - -def mock_context(): - context = Mock() - context.aws_request_id = "fake_request_id" - return context @pytest.fixture def valid_event(): - return VALID_EVENT_MODEL - - -@pytest.fixture -def missing_origin_event(): - return MISSING_ORIGIN_EVENT_MODEL + return copy.deepcopy(MOCK_S3_EDGE_EVENT) @pytest.fixture def mock_edge_presign_service(mocker): mock_ssm_service = mocker.patch("services.edge_presign_service.SSMService") mock_ssm_service_instance = mock_ssm_service.return_value - mock_ssm_service_instance.get_ssm_parameter.return_value = TABLE_NAME + mock_ssm_service_instance.get_ssm_parameter.return_value = MOCK_TABLE_NAME mock_dynamo_service = mocker.patch("services.edge_presign_service.DynamoDBService") mock_dynamo_service_instance = mock_dynamo_service.return_value mock_dynamo_service_instance.update_item.return_value = None - return mock_ssm_service_instance, mock_dynamo_service_instance + mock_edge_service = mocker.patch("handlers.edge_presign_handler.EdgePresignService") + mock_edge_service_instance = mock_edge_service.return_value + mock_edge_service_instance.filter_request_values.return_value = { + "uri": "/some/path", + "querystring": MOCKED_AUTH_QUERY, + "headers": {"host": [{"key": "Host", "value": MOCKED_LG_BUCKET_URL}]}, + "domain_name": MOCKED_LG_BUCKET_URL, + } + mock_edge_service_instance.use_presign.return_value = None + mock_edge_service_instance.update_s3_headers.return_value = { + "headers": { + "host": [{"key": "Host", "value": MOCKED_LG_BUCKET_URL}], + } + } + + return mock_edge_service_instance + + +def test_lambda_handler_success(valid_event, context, mock_edge_presign_service): + response = lambda_handler(valid_event, context) + + mock_edge_presign_service.filter_request_values.assert_called_once() + mock_edge_presign_service.use_presign.assert_called_once_with( + mock_edge_presign_service.filter_request_values.return_value + ) + mock_edge_presign_service.update_s3_headers.assert_called_once_with( + valid_event["Records"][0]["cf"]["request"], + mock_edge_presign_service.filter_request_values.return_value, + ) + assert response["headers"]["host"][0]["value"] == MOCKED_LG_BUCKET_URL -def test_attempt_url_update_success(mock_edge_presign_service): - edge_service = EdgePresignService() - uri_hash = "test_uri_hash" - origin_url = f"https://{ENV}.{NHS_DOMAIN}" - edge_service.attempt_url_update(uri_hash, origin_url) +def test_lambda_handler_no_query_params(valid_event, context): + event = copy.deepcopy(valid_event) + event["Records"][0]["cf"]["request"]["querystring"] = "" - mock_ssm_service_instance = mock_edge_presign_service[0] - mock_dynamo_service_instance = mock_edge_presign_service[1] + response = lambda_handler(event, context) - mock_ssm_service_instance.get_ssm_parameter.assert_called_once_with( - EXPECTED_SSM_PARAMETER_KEY - ) + actual_status = response["status"] + actual_response = json.loads(response["body"]) - mock_dynamo_service_instance.update_item.assert_called_once_with( - table_name=f"{ENV}_{TABLE_NAME}", - key=uri_hash, - updated_fields={"IsRequested": True}, - condition_expression=EXPECTED_DYNAMO_DB_CONDITION_EXPRESSION, - expression_attribute_values=EXPECTED_DYNAMO_DB_EXPRESSION_ATTRIBUTE_VALUES, - ) + assert actual_status == 500 + assert actual_response["message"] == EXPECTED_EDGE_NO_QUERY_MESSAGE + assert actual_response["err_code"] == EXPECTED_EDGE_NO_QUERY_ERROR_CODE -def test_attempt_url_update_client_error(mock_edge_presign_service): - edge_service = EdgePresignService() +def test_lambda_handler_missing_query_params(valid_event, context): + event = copy.deepcopy(valid_event) + event["Records"][0]["cf"]["request"]["querystring"] = MOCKED_PARTIAL_QUERY - edge_service.dynamo_service.update_item.side_effect = ClientError( - error_response={"Error": {"Code": "ConditionalCheckFailedException"}}, - operation_name="UpdateItem", - ) + response = lambda_handler(event, context) + + actual_status = response["status"] + actual_response = json.loads(response["body"]) + + assert actual_status == 500 + assert actual_response["message"] == EXPECTED_EDGE_MALFORMED_QUERY_MESSAGE + assert actual_response["err_code"] == EXPECTED_EDGE_MALFORMED_QUERY_ERROR_CODE + + +def test_lambda_handler_missing_headers(valid_event, context): + event = copy.deepcopy(valid_event) + event["Records"][0]["cf"]["request"]["headers"] = {} + + response = lambda_handler(event, context) - with pytest.raises(CloudFrontEdgeException) as exc_info: - edge_service.attempt_url_update("test_uri_hash", f"https://{ENV}.{NHS_DOMAIN}") + actual_status = response["status"] + actual_response = json.loads(response["body"]) - assert exc_info.value.status_code == 400 - assert exc_info.value.message == EXPECTED_EDGE_NO_CLIENT_ERROR_MESSAGE - assert exc_info.value.err_code == EXPECTED_EDGE_NO_CLIENT_ERROR_CODE + assert actual_status == 500 + assert actual_response["message"] == EXPECTED_EDGE_MALFORMED_HEADER_MESSAGE + assert actual_response["err_code"] == EXPECTED_EDGE_MALFORMED_HEADER_ERROR_CODE -def test_attempt_url_update_invalid_origin(mock_edge_presign_service): - edge_service = EdgePresignService() +def test_lambda_handler_missing_origin(valid_event, context): + event = copy.deepcopy(valid_event) + event["Records"][0]["cf"]["request"]["origin"] = {} - result = edge_service.extract_environment_from_url("invalid_url") + response = lambda_handler(event, context) - expected_empty_result = "" + actual_status = response["status"] + actual_response = json.loads(response["body"]) - assert result == expected_empty_result + assert actual_status == 500 + assert actual_response["message"] == EXPECTED_EDGE_NO_ORIGIN_ERROR_MESSAGE + assert actual_response["err_code"] == EXPECTED_EDGE_NO_ORIGIN_ERROR_CODE diff --git a/lambdas/tests/unit/services/test_edge_presign_service.py b/lambdas/tests/unit/services/test_edge_presign_service.py index 56a09ddd4..e7c226d7f 100644 --- a/lambdas/tests/unit/services/test_edge_presign_service.py +++ b/lambdas/tests/unit/services/test_edge_presign_service.py @@ -1,99 +1,147 @@ +import hashlib +from unittest.mock import patch + import pytest from botocore.exceptions import ClientError from services.edge_presign_service import EdgePresignService +from tests.unit.conftest import ( + MOCK_TABLE_NAME, + MOCKED_LG_BUCKET_ENV, + MOCKED_LG_BUCKET_URL, +) from tests.unit.enums.test_edge_presign_values import ( - ENV, - EXPECTED_DYNAMO_DB_CONDITION_EXPRESSION, - EXPECTED_DYNAMO_DB_EXPRESSION_ATTRIBUTE_VALUES, EXPECTED_EDGE_NO_CLIENT_ERROR_CODE, EXPECTED_EDGE_NO_CLIENT_ERROR_MESSAGE, - EXPECTED_SSM_PARAMETER_KEY, - EXPECTED_SUCCESS_RESPONSE, - NHS_DOMAIN, - TABLE_NAME, + EXPECTED_EDGE_NO_ORIGIN_ERROR_CODE, + EXPECTED_EDGE_NO_ORIGIN_ERROR_MESSAGE, + MOCKED_AUTH_QUERY, ) from utils.lambda_exceptions import CloudFrontEdgeException -edge_presign_service = EdgePresignService() - @pytest.fixture -def mock_dynamo_service(mocker): - return mocker.patch.object(edge_presign_service, "dynamo_service", autospec=True) - - -@pytest.fixture -def mock_ssm_service(mocker): - return mocker.patch.object(edge_presign_service, "ssm_service", autospec=True) +@patch("services.edge_presign_service.SSMService") +@patch("services.edge_presign_service.DynamoDBService") +def edge_presign_service(mock_dynamo_service, mock_ssm_service): + mock_ssm_service.get_ssm_parameter.return_value = MOCK_TABLE_NAME + mock_dynamo_service.update_item.return_value = None + return EdgePresignService() @pytest.fixture -def valid_origin_url(): - return f"https://{ENV}.{NHS_DOMAIN}" - +def request_values(): + return { + "uri": "/some/path", + "querystring": MOCKED_AUTH_QUERY, + "domain_name": MOCKED_LG_BUCKET_URL, + } + + +def test_use_presign(edge_presign_service, request_values): + with patch.object( + edge_presign_service, "attempt_presign_ingestion" + ) as mock_attempt_presign_ingestion: + edge_presign_service.use_presign(request_values) + + expected_hash = hashlib.md5( + f"{request_values['uri']}?{request_values['querystring']}".encode("utf-8") + ).hexdigest() + mock_attempt_presign_ingestion.assert_called_once_with( + uri_hash=expected_hash, domain_name=MOCKED_LG_BUCKET_URL + ) -def test_attempt_url_update_success( - mock_dynamo_service, mock_ssm_service, valid_origin_url -): - mock_dynamo_service.update_item.return_value = None - mock_ssm_service.get_ssm_parameter.return_value = TABLE_NAME - uri_hash = "valid_hash" - response = edge_presign_service.attempt_url_update( - uri_hash=uri_hash, origin_url=valid_origin_url - ) +def test_attempt_presign_ingestion_success(edge_presign_service): + edge_presign_service.attempt_presign_ingestion("hashed_uri", MOCKED_LG_BUCKET_URL) - expected_table_name = f"{ENV}_{TABLE_NAME}" - assert response == EXPECTED_SUCCESS_RESPONSE # Success scenario returns None - mock_ssm_service.get_ssm_parameter.assert_called_once_with( - EXPECTED_SSM_PARAMETER_KEY - ) - mock_dynamo_service.update_item.assert_called_once_with( - table_name=expected_table_name, - key=uri_hash, - updated_fields={"IsRequested": True}, - condition_expression=EXPECTED_DYNAMO_DB_CONDITION_EXPRESSION, - expression_attribute_values=EXPECTED_DYNAMO_DB_EXPRESSION_ATTRIBUTE_VALUES, - ) - -def test_attempt_url_update_client_error( - mock_dynamo_service, mock_ssm_service, valid_origin_url -): - mock_dynamo_service.update_item.side_effect = ClientError( +def test_attempt_presign_ingestion_client_error(edge_presign_service): + client_error = ClientError( {"Error": {"Code": "ConditionalCheckFailedException"}}, "UpdateItem" ) - mock_ssm_service.get_ssm_parameter.return_value = TABLE_NAME - uri_hash = "valid_hash" + edge_presign_service.dynamo_service.update_item.side_effect = client_error with pytest.raises(CloudFrontEdgeException) as exc_info: - edge_presign_service.attempt_url_update( - uri_hash=uri_hash, origin_url=valid_origin_url + edge_presign_service.attempt_presign_ingestion( + "hashed_uri", MOCKED_LG_BUCKET_URL ) assert exc_info.value.status_code == 400 - assert exc_info.value.message == EXPECTED_EDGE_NO_CLIENT_ERROR_MESSAGE assert exc_info.value.err_code == EXPECTED_EDGE_NO_CLIENT_ERROR_CODE + assert exc_info.value.message == EXPECTED_EDGE_NO_CLIENT_ERROR_MESSAGE -def test_extract_environment_from_url(): - url = f"https://{ENV}.{NHS_DOMAIN}/path/to/resource" - expected_environment = ENV - actual_environment = edge_presign_service.extract_environment_from_url(url) - assert actual_environment == expected_environment +def test_update_s3_headers(edge_presign_service, request_values): + request = { + "headers": { + "host": [{"key": "Host", "value": MOCKED_LG_BUCKET_URL}], + } + } + + response = edge_presign_service.update_s3_headers(request, request_values) + assert "authorization" not in response["headers"] + assert response["headers"]["host"][0]["value"] == MOCKED_LG_BUCKET_URL + + +def test_filter_request_values_success(edge_presign_service): + request = { + "uri": "/test/uri", + "querystring": MOCKED_AUTH_QUERY, + "headers": {"host": [{"key": "Host", "value": MOCKED_LG_BUCKET_URL}]}, + "origin": {"s3": {"domainName": MOCKED_LG_BUCKET_URL}}, + } + result = edge_presign_service.filter_request_values(request) + assert result["uri"] == "/test/uri" + assert result["querystring"] == MOCKED_AUTH_QUERY + assert result["domain_name"] == MOCKED_LG_BUCKET_URL + + +def test_filter_request_values_missing_component(edge_presign_service): + request = { + "uri": "/test/uri", + "querystring": MOCKED_AUTH_QUERY, + "headers": {"host": [{"key": "Host", "value": MOCKED_LG_BUCKET_URL}]}, + } + with pytest.raises(CloudFrontEdgeException) as exc_info: + edge_presign_service.filter_request_values(request) + + assert exc_info.value.status_code == 500 + assert exc_info.value.err_code == EXPECTED_EDGE_NO_ORIGIN_ERROR_CODE + assert exc_info.value.message == EXPECTED_EDGE_NO_ORIGIN_ERROR_MESSAGE - url_invalid = f"https://{NHS_DOMAIN}/path/to/resource" - expected_empty_result = "" - actual_empty_result = edge_presign_service.extract_environment_from_url(url_invalid) - assert actual_empty_result == expected_empty_result +def test_filter_domain_for_env(edge_presign_service): + # Environments + assert ( + edge_presign_service.filter_domain_for_env("ndra-lloyd-test-test.com") == "ndra" + ) + assert ( + edge_presign_service.filter_domain_for_env("ndr-test-lloyd-test-test.com") + == "ndr-test" + ) + assert ( + edge_presign_service.filter_domain_for_env("pre-prod-lloyd-test-test.com") + == "pre-prod" + ) + # Production + assert ( + edge_presign_service.filter_domain_for_env("prod-lloyd-test-test.com") == "prod" + ) + assert edge_presign_service.filter_domain_for_env("lloyd-test-test.com") == "" + assert edge_presign_service.filter_domain_for_env("invalid.com") == "" -def test_extend_table_name(): - base_table_name = TABLE_NAME - expected_table_with_env = f"{ENV}_{base_table_name}" - actual_table_with_env = edge_presign_service.extend_table_name(base_table_name, ENV) - assert actual_table_with_env == expected_table_with_env - expected_table_no_env = base_table_name - actual_table_no_env = edge_presign_service.extend_table_name(base_table_name, "") - assert actual_table_no_env == expected_table_no_env +def test_extend_table_name(edge_presign_service): + # Environments + assert ( + edge_presign_service.extend_table_name(MOCK_TABLE_NAME, MOCKED_LG_BUCKET_ENV) + == f"{MOCKED_LG_BUCKET_ENV}_{MOCK_TABLE_NAME}" + ) + # Production + assert ( + edge_presign_service.extend_table_name(MOCK_TABLE_NAME, "") == MOCK_TABLE_NAME + ) + assert ( + edge_presign_service.extend_table_name(MOCK_TABLE_NAME, "prod") + == f"prod_{MOCK_TABLE_NAME}" + ) diff --git a/lambdas/utils/decorators/handle_edge_exceptions.py b/lambdas/utils/decorators/handle_edge_exceptions.py index 468f71512..ceda00b64 100644 --- a/lambdas/utils/decorators/handle_edge_exceptions.py +++ b/lambdas/utils/decorators/handle_edge_exceptions.py @@ -1,35 +1,35 @@ from typing import Callable +from enums.lambda_error import LambdaError from utils.audit_logging_setup import LoggingService from utils.edge_response import EdgeResponse from utils.error_response import ErrorResponse -from utils.lambda_exceptions import LambdaException +from utils.lambda_exceptions import CloudFrontEdgeException from utils.request_context import request_context logger = LoggingService(__name__) def handle_edge_exceptions(lambda_func: Callable): - """A decorator for lambda edge handler. - Catch custom Edge Exceptions or AWS ClientError that may be unhandled or raised - - Usage: - @handle_edge_exceptions - def lambda_handler(event, context): - ... - """ - def interceptor(event, context): + interaction_id: str | None = getattr(request_context, "request_id", None) try: return lambda_func(event, context) - except LambdaException as e: + except CloudFrontEdgeException as e: logger.error(str(e)) - - interaction_id = getattr(request_context, "request_id", None) return EdgeResponse( status_code=e.status_code, body=ErrorResponse(e.err_code, e.message, interaction_id).create(), methods=event.get("httpMethod", "GET"), ).create_edge_response() + except Exception as e: + logger.error(f"Unhandled exception: {str(e)}") + err_code: str = LambdaError.EdgeMalformed.value["err_code"] + message: str = LambdaError.EdgeMalformed.value["message"] + return EdgeResponse( + status_code=500, + body=ErrorResponse(err_code, message, interaction_id).create(), + methods=event.get("httpMethod", "GET"), + ).create_edge_response() return interceptor diff --git a/lambdas/utils/decorators/validate_s3_request.py b/lambdas/utils/decorators/validate_s3_request.py new file mode 100644 index 000000000..7b8f01ee3 --- /dev/null +++ b/lambdas/utils/decorators/validate_s3_request.py @@ -0,0 +1,69 @@ +import json +import logging +from functools import wraps +from urllib.parse import parse_qs + +from enums.lambda_error import LambdaError +from utils.lambda_exceptions import CloudFrontEdgeException + +logger = logging.getLogger(__name__) + +REQUIRED_QUERY_PARAMS = [ + "X-Amz-Algorithm", + "X-Amz-Credential", + "X-Amz-Date", + "X-Amz-Expires", + "X-Amz-SignedHeaders", + "X-Amz-Signature", + "X-Amz-Security-Token", +] + +REQUIRED_HEADERS = ["host", "cloudfront-viewer-country", "x-forwarded-for"] + + +def validate_s3_request(lambda_func): + @wraps(lambda_func) + def wrapper(event, context): + request: dict = event["Records"][0]["cf"]["request"] + logger.info(json.dumps(request)) + bad_request: bool = ( + "uri" not in request + or "querystring" not in request + or "headers" not in request + ) + if bad_request: + logger.error( + "Missing required request components: uri, querystring, or headers." + ) + raise CloudFrontEdgeException(500, LambdaError.EdgeMalformed) + + origin: dict = request.get("origin", {}) + if "s3" not in origin or "domainName" not in origin["s3"]: + logger.error("Missing origin domain name.") + raise CloudFrontEdgeException(500, LambdaError.EdgeNoOrigin) + + querystring: str = request["querystring"] + if not querystring: + logger.error(f"Missing query string: {querystring}") + raise CloudFrontEdgeException(500, LambdaError.EdgeNoQuery) + query_params: dict = { + query: value[0] for query, value in parse_qs(querystring).items() + } + missing_query_params: list = [ + param for param in REQUIRED_QUERY_PARAMS if param not in query_params + ] + if missing_query_params: + logger.error(f"Missing required query parameters: {missing_query_params}") + raise CloudFrontEdgeException(500, LambdaError.EdgeRequiredQuery) + + headers: dict = request["headers"] + missing_headers = [ + header for header in REQUIRED_HEADERS if header.lower() not in headers + ] + if missing_headers: + logger.error(f"Missing required headers: {missing_headers}") + raise CloudFrontEdgeException(500, LambdaError.EdgeRequiredHeaders) + + return lambda_func(event, context) + + return wrapper