From 149cd0c6b8a2dd5271e037f0236a84ff2234edb8 Mon Sep 17 00:00:00 2001 From: NogaNHS Date: Wed, 18 Oct 2023 11:23:57 +0100 Subject: [PATCH] prmdr-327 bug fix --- .../search_patient_details_handler.py | 2 +- lambdas/services/pds_api_service.py | 27 ++++++++++--------- .../unit/services/test_pds_api_service.py | 20 +++++++------- 3 files changed, 25 insertions(+), 24 deletions(-) diff --git a/lambdas/handlers/search_patient_details_handler.py b/lambdas/handlers/search_patient_details_handler.py index 0196705a9..b0cf298de 100644 --- a/lambdas/handlers/search_patient_details_handler.py +++ b/lambdas/handlers/search_patient_details_handler.py @@ -23,7 +23,7 @@ def get_pds_service(): return ( PdsApiService - if not bool(os.getenv("PDS_FHIR_IS_STUBBED")) + if (os.getenv("PDS_FHIR_IS_STUBBED") == 'false') else MockPdsApiService ) diff --git a/lambdas/services/pds_api_service.py b/lambdas/services/pds_api_service.py index becd0e271..ee2092e84 100644 --- a/lambdas/services/pds_api_service.py +++ b/lambdas/services/pds_api_service.py @@ -30,6 +30,7 @@ def fetch_patient_details( nhs_number: str, ) -> PatientDetails: try: + logger.info("Using real pds service") validate_id(nhs_number) response = self.pds_request(nhs_number, retry_on_expired=True) return self.handle_response(response, nhs_number) @@ -73,19 +74,19 @@ def pds_request(self, nshNumber: str, retry_on_expired: bool): url_endpoint = endpoint + "Patient/" + nshNumber pds_response = requests.get(url=url_endpoint, headers=authorization_header) - if pds_response.status_code == 401 & retry_on_expired: return self.pds_request(nshNumber, retry_on_expired=False) return pds_response def get_new_access_token(self): + logger.info("Getting new PDS access token") try: access_token_ssm_parameter = self.get_parameters_for_new_access_token() jwt_token = self.create_jwt_token_for_new_access_token_request( access_token_ssm_parameter ) nhs_oauth_endpoint = access_token_ssm_parameter[ - SSMParameter.NHS_OAUTH_ENDPOINT + SSMParameter.NHS_OAUTH_ENDPOINT.value ] nhs_oauth_response = self.request_new_access_token( jwt_token, nhs_oauth_endpoint @@ -100,15 +101,15 @@ def get_new_access_token(self): def get_parameters_for_new_access_token(self): parameters = [ - SSMParameter.NHS_OAUTH_ENDPOINT, - SSMParameter.PDS_KID, - SSMParameter.NHS_OAUTH_KEY, - SSMParameter.PDS_API_KEY, + SSMParameter.NHS_OAUTH_ENDPOINT.value, + SSMParameter.PDS_KID.value, + SSMParameter.NHS_OAUTH_KEY.value, + SSMParameter.PDS_API_KEY.value, ] return self.ssm_service.get_ssm_parameters(parameters, with_decryption=True) def update_access_token_ssm(self, parameter_value: str): - parameter_key = SSMParameter.PDS_API_ACCESS_TOKEN + parameter_key = SSMParameter.PDS_API_ACCESS_TOKEN.value self.ssm_service.update_ssm_parameter( parameter_key=parameter_key, parameter_value=parameter_value, @@ -117,8 +118,8 @@ def update_access_token_ssm(self, parameter_value: str): def get_parameters_for_pds_api_request(self): parameters = [ - SSMParameter.PDS_API_ENDPOINT, - SSMParameter.PDS_API_ACCESS_TOKEN, + SSMParameter.PDS_API_ENDPOINT.value, + SSMParameter.PDS_API_ACCESS_TOKEN.value, ] ssm_response = self.ssm_service.get_ssm_parameters( parameters_keys=parameters, with_decryption=True @@ -129,11 +130,11 @@ def create_jwt_token_for_new_access_token_request( self, access_token_ssm_parameters ): nhs_oauth_endpoint = access_token_ssm_parameters[ - SSMParameter.NHS_OAUTH_ENDPOINT + SSMParameter.NHS_OAUTH_ENDPOINT.value ] - kid = access_token_ssm_parameters[SSMParameter.PDS_KID] - nhs_key = access_token_ssm_parameters[SSMParameter.NHS_OAUTH_KEY] - pds_key = access_token_ssm_parameters[SSMParameter.PDS_API_KEY] + kid = access_token_ssm_parameters[SSMParameter.PDS_KID.value] + nhs_key = access_token_ssm_parameters[SSMParameter.NHS_OAUTH_KEY.value] + pds_key = access_token_ssm_parameters[SSMParameter.PDS_API_KEY.value] payload = { "iss": nhs_key, "sub": nhs_key, diff --git a/lambdas/tests/unit/services/test_pds_api_service.py b/lambdas/tests/unit/services/test_pds_api_service.py index 7f9960b18..a9882173d 100644 --- a/lambdas/tests/unit/services/test_pds_api_service.py +++ b/lambdas/tests/unit/services/test_pds_api_service.py @@ -96,10 +96,10 @@ def test_request_new_token_is_call_with_correct_data(mocker): def test_create_jwt_for_new_access_token(mocker): access_token_parameters = { - SSMParameter.NHS_OAUTH_ENDPOINT: "api.endpoint/mock", - SSMParameter.PDS_KID: "test_string_pds_kid", - SSMParameter.NHS_OAUTH_KEY: "test_string_key_oauth", - SSMParameter.PDS_API_KEY: "test_string_key_pds", + SSMParameter.NHS_OAUTH_ENDPOINT.value: "api.endpoint/mock", + SSMParameter.PDS_KID.value: "test_string_pds_kid", + SSMParameter.NHS_OAUTH_KEY.value: "test_string_key_oauth", + SSMParameter.PDS_API_KEY.value: "test_string_key_pds", } expected_payload = { "iss": "test_string_key_oauth", @@ -121,7 +121,7 @@ def test_create_jwt_for_new_access_token(mocker): ) def test_get_parameters_for_pds_api_request(): - ssm_parameters_expected = (f"test_value_{SSMParameter.PDS_API_ENDPOINT}", f"test_value_{SSMParameter.PDS_API_ACCESS_TOKEN}") + ssm_parameters_expected = (f"test_value_{SSMParameter.PDS_API_ENDPOINT.value}", f"test_value_{SSMParameter.PDS_API_ACCESS_TOKEN.value}") actual = pds_service.get_parameters_for_pds_api_request() assert ssm_parameters_expected == actual @@ -130,14 +130,14 @@ def test_update_access_token_ssm(mocker): pds_service.update_access_token_ssm("test_string") - fake_ssm_service.update_ssm_parameter.assert_called_with(parameter_key=SSMParameter.PDS_API_ACCESS_TOKEN, parameter_value="test_string", parameter_type="SecureString") + fake_ssm_service.update_ssm_parameter.assert_called_with(parameter_key=SSMParameter.PDS_API_ACCESS_TOKEN.value, parameter_value="test_string", parameter_type="SecureString") def test_get_parameters_for_new_access_token(mocker): parameters = [ - SSMParameter.NHS_OAUTH_ENDPOINT, - SSMParameter.PDS_KID, - SSMParameter.NHS_OAUTH_KEY, - SSMParameter.PDS_API_KEY, + SSMParameter.NHS_OAUTH_ENDPOINT.value, + SSMParameter.PDS_KID.value, + SSMParameter.NHS_OAUTH_KEY.value, + SSMParameter.PDS_API_KEY.value, ] fake_ssm_service.get_ssm_parameters = mocker.MagicMock() pds_service.get_parameters_for_new_access_token()