Skip to content

Commit

Permalink
prmdr-327 bug fix
Browse files Browse the repository at this point in the history
  • Loading branch information
NogaNHS committed Oct 18, 2023
1 parent 77027df commit 149cd0c
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 24 deletions.
2 changes: 1 addition & 1 deletion lambdas/handlers/search_patient_details_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
def get_pds_service():
return (
PdsApiService
if not bool(os.getenv("PDS_FHIR_IS_STUBBED"))
if (os.getenv("PDS_FHIR_IS_STUBBED") == 'false')
else MockPdsApiService
)

Expand Down
27 changes: 14 additions & 13 deletions lambdas/services/pds_api_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def fetch_patient_details(
nhs_number: str,
) -> PatientDetails:
try:
logger.info("Using real pds service")
validate_id(nhs_number)
response = self.pds_request(nhs_number, retry_on_expired=True)
return self.handle_response(response, nhs_number)
Expand Down Expand Up @@ -73,19 +74,19 @@ def pds_request(self, nshNumber: str, retry_on_expired: bool):

url_endpoint = endpoint + "Patient/" + nshNumber
pds_response = requests.get(url=url_endpoint, headers=authorization_header)

if pds_response.status_code == 401 & retry_on_expired:
return self.pds_request(nshNumber, retry_on_expired=False)
return pds_response

def get_new_access_token(self):
logger.info("Getting new PDS access token")
try:
access_token_ssm_parameter = self.get_parameters_for_new_access_token()
jwt_token = self.create_jwt_token_for_new_access_token_request(
access_token_ssm_parameter
)
nhs_oauth_endpoint = access_token_ssm_parameter[
SSMParameter.NHS_OAUTH_ENDPOINT
SSMParameter.NHS_OAUTH_ENDPOINT.value
]
nhs_oauth_response = self.request_new_access_token(
jwt_token, nhs_oauth_endpoint
Expand All @@ -100,15 +101,15 @@ def get_new_access_token(self):

def get_parameters_for_new_access_token(self):
parameters = [
SSMParameter.NHS_OAUTH_ENDPOINT,
SSMParameter.PDS_KID,
SSMParameter.NHS_OAUTH_KEY,
SSMParameter.PDS_API_KEY,
SSMParameter.NHS_OAUTH_ENDPOINT.value,
SSMParameter.PDS_KID.value,
SSMParameter.NHS_OAUTH_KEY.value,
SSMParameter.PDS_API_KEY.value,
]
return self.ssm_service.get_ssm_parameters(parameters, with_decryption=True)

def update_access_token_ssm(self, parameter_value: str):
parameter_key = SSMParameter.PDS_API_ACCESS_TOKEN
parameter_key = SSMParameter.PDS_API_ACCESS_TOKEN.value
self.ssm_service.update_ssm_parameter(
parameter_key=parameter_key,
parameter_value=parameter_value,
Expand All @@ -117,8 +118,8 @@ def update_access_token_ssm(self, parameter_value: str):

def get_parameters_for_pds_api_request(self):
parameters = [
SSMParameter.PDS_API_ENDPOINT,
SSMParameter.PDS_API_ACCESS_TOKEN,
SSMParameter.PDS_API_ENDPOINT.value,
SSMParameter.PDS_API_ACCESS_TOKEN.value,
]
ssm_response = self.ssm_service.get_ssm_parameters(
parameters_keys=parameters, with_decryption=True
Expand All @@ -129,11 +130,11 @@ def create_jwt_token_for_new_access_token_request(
self, access_token_ssm_parameters
):
nhs_oauth_endpoint = access_token_ssm_parameters[
SSMParameter.NHS_OAUTH_ENDPOINT
SSMParameter.NHS_OAUTH_ENDPOINT.value
]
kid = access_token_ssm_parameters[SSMParameter.PDS_KID]
nhs_key = access_token_ssm_parameters[SSMParameter.NHS_OAUTH_KEY]
pds_key = access_token_ssm_parameters[SSMParameter.PDS_API_KEY]
kid = access_token_ssm_parameters[SSMParameter.PDS_KID.value]
nhs_key = access_token_ssm_parameters[SSMParameter.NHS_OAUTH_KEY.value]
pds_key = access_token_ssm_parameters[SSMParameter.PDS_API_KEY.value]
payload = {
"iss": nhs_key,
"sub": nhs_key,
Expand Down
20 changes: 10 additions & 10 deletions lambdas/tests/unit/services/test_pds_api_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,10 +96,10 @@ def test_request_new_token_is_call_with_correct_data(mocker):

def test_create_jwt_for_new_access_token(mocker):
access_token_parameters = {
SSMParameter.NHS_OAUTH_ENDPOINT: "api.endpoint/mock",
SSMParameter.PDS_KID: "test_string_pds_kid",
SSMParameter.NHS_OAUTH_KEY: "test_string_key_oauth",
SSMParameter.PDS_API_KEY: "test_string_key_pds",
SSMParameter.NHS_OAUTH_ENDPOINT.value: "api.endpoint/mock",
SSMParameter.PDS_KID.value: "test_string_pds_kid",
SSMParameter.NHS_OAUTH_KEY.value: "test_string_key_oauth",
SSMParameter.PDS_API_KEY.value: "test_string_key_pds",
}
expected_payload = {
"iss": "test_string_key_oauth",
Expand All @@ -121,7 +121,7 @@ def test_create_jwt_for_new_access_token(mocker):
)

def test_get_parameters_for_pds_api_request():
ssm_parameters_expected = (f"test_value_{SSMParameter.PDS_API_ENDPOINT}", f"test_value_{SSMParameter.PDS_API_ACCESS_TOKEN}")
ssm_parameters_expected = (f"test_value_{SSMParameter.PDS_API_ENDPOINT.value}", f"test_value_{SSMParameter.PDS_API_ACCESS_TOKEN.value}")
actual = pds_service.get_parameters_for_pds_api_request()
assert ssm_parameters_expected == actual

Expand All @@ -130,14 +130,14 @@ def test_update_access_token_ssm(mocker):

pds_service.update_access_token_ssm("test_string")

fake_ssm_service.update_ssm_parameter.assert_called_with(parameter_key=SSMParameter.PDS_API_ACCESS_TOKEN, parameter_value="test_string", parameter_type="SecureString")
fake_ssm_service.update_ssm_parameter.assert_called_with(parameter_key=SSMParameter.PDS_API_ACCESS_TOKEN.value, parameter_value="test_string", parameter_type="SecureString")

def test_get_parameters_for_new_access_token(mocker):
parameters = [
SSMParameter.NHS_OAUTH_ENDPOINT,
SSMParameter.PDS_KID,
SSMParameter.NHS_OAUTH_KEY,
SSMParameter.PDS_API_KEY,
SSMParameter.NHS_OAUTH_ENDPOINT.value,
SSMParameter.PDS_KID.value,
SSMParameter.NHS_OAUTH_KEY.value,
SSMParameter.PDS_API_KEY.value,
]
fake_ssm_service.get_ssm_parameters = mocker.MagicMock()
pds_service.get_parameters_for_new_access_token()
Expand Down

0 comments on commit 149cd0c

Please sign in to comment.