Skip to content

Commit

Permalink
Authenticate with CIR using google ID token (#1288)
Browse files Browse the repository at this point in the history

---------

Co-authored-by: Mebin Abraham <35296336+MebinAbraham@users.noreply.github.com>
Co-authored-by: Mebin Abraham <mebin95+work@gmail.com>
  • Loading branch information
3 people authored Jan 18, 2024
1 parent 67ce919 commit bc3d5f6
Show file tree
Hide file tree
Showing 8 changed files with 83 additions and 11 deletions.
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,10 @@ The following env variables can be used
| ONS_URL | `https://www.ons.gov.uk` | The URL of the ONS website where static content is sourced, e.g. accessibility info |
| SDS_API_BASE_URL | | The base URL of the SDS API used for fetching supplementary data |
| CIR_API_BASE_URL | | The base URL of the CIR API used for fetching collection instruments |
| OIDC_TOKEN_BACKEND | gcp | The backend to use when fetching the Open ID Connect token |
| OIDC_TOKEN_LEEWAY_IN_SECONDS | 300 | The leeway to use when validating OIDC tokens |
| SDS_OAUTH2_CLIENT_ID | | The OAuth2 Client ID used when setting up IAP on the SDS |
| CIR_OAUTH2_CLIENT_ID | | The OAuth2 Client ID used when setting up IAP on the CIR |

The following env variables can be used when running tests

Expand Down
9 changes: 2 additions & 7 deletions app/services/supplementary_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
from structlog import get_logger

from app.keys import KEY_PURPOSE_SDS
from app.oidc.oidc import OIDCCredentialsService
from app.settings import SDS_OAUTH2_CLIENT_ID
from app.utilities.credentials import fetch_and_apply_oidc_credentials
from app.utilities.request_session import get_retryable_session
from app.utilities.supplementary_data_parser import validate_supplementary_data_v1

Expand Down Expand Up @@ -66,13 +66,8 @@ def get_supplementary_data_v1(
backoff_factor=SUPPLEMENTARY_DATA_REQUEST_BACKOFF_FACTOR,
)

# Type ignore: oidc_credentials_service is a singleton of this application
oidc_credentials_service: OIDCCredentialsService = current_app.eq["oidc_credentials_service"] # type: ignore
# Type ignore: SDS_OAUTH2_CLIENT_ID is an env var which must exist as it is verified in setup.py
credentials = oidc_credentials_service.get_credentials(
iap_client_id=SDS_OAUTH2_CLIENT_ID # type: ignore
)
credentials.apply(headers=session.headers)
fetch_and_apply_oidc_credentials(session=session, client_id=SDS_OAUTH2_CLIENT_ID) # type: ignore

try:
response = session.get(
Expand Down
1 change: 1 addition & 0 deletions app/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@ def utcoffset_or_fail(date_value, key):
OIDC_TOKEN_LEEWAY_IN_SECONDS = int(os.getenv("OIDC_TOKEN_LEEWAY_IN_SECONDS", "300"))

SDS_OAUTH2_CLIENT_ID = os.getenv("SDS_OAUTH2_CLIENT_ID")
CIR_OAUTH2_CLIENT_ID = os.getenv("CIR_OAUTH2_CLIENT_ID")

ACCOUNT_SERVICE_BASE_URL = os.getenv(
"ACCOUNT_SERVICE_BASE_URL", "https://surveys.ons.gov.uk"
Expand Down
7 changes: 5 additions & 2 deletions app/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,15 +387,18 @@ def setup_task_client(application):


def setup_oidc(application):
def sds_client_id_exists():
def client_ids_exist():
if not application.config.get("SDS_OAUTH2_CLIENT_ID"):
raise MissingEnvironmentVariable("Setting SDS_OAUTH2_CLIENT_ID Missing")

if not application.config.get("CIR_OAUTH2_CLIENT_ID"):
raise MissingEnvironmentVariable("Setting CIR_OAUTH2_CLIENT_ID Missing")

if not (oidc_token_backend := application.config.get("OIDC_TOKEN_BACKEND")):
raise MissingEnvironmentVariable("Setting OIDC_TOKEN_BACKEND Missing")

if oidc_token_backend == "gcp":
sds_client_id_exists()
client_ids_exist()
application.eq["oidc_credentials_service"] = OIDCCredentialsServiceGCP()

elif oidc_token_backend == "local":
Expand Down
12 changes: 12 additions & 0 deletions app/utilities/credentials.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import requests
from flask import current_app

from app.oidc.oidc import OIDCCredentialsService


def fetch_and_apply_oidc_credentials(session: requests.Session, client_id: str) -> None:
# Type ignore: oidc_credentials_service is a singleton of this application
oidc_credentials_service: OIDCCredentialsService = current_app.eq["oidc_credentials_service"] # type: ignore

credentials = oidc_credentials_service.get_credentials(iap_client_id=client_id)
credentials.apply(headers=session.headers)
12 changes: 10 additions & 2 deletions app/utilities/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
DEFAULT_LANGUAGE_CODE,
QuestionnaireSchema,
)
from app.settings import CIR_OAUTH2_CLIENT_ID
from app.utilities.credentials import fetch_and_apply_oidc_credentials
from app.utilities.json import json_load, json_loads
from app.utilities.request_session import get_retryable_session

Expand Down Expand Up @@ -141,7 +143,7 @@ def load_schema_from_instrument_id(
) -> QuestionnaireSchema:
parameters = {"guid": cir_instrument_id}
cir_url = f"{current_app.config['CIR_API_BASE_URL']}{CIR_RETRIEVE_COLLECTION_INSTRUMENT_URL}?{urlencode(parameters)}"
return load_schema_from_url(url=cir_url, language_code=language_code)
return load_schema_from_url(url=cir_url, language_code=language_code, is_cir=True)


@lru_cache(maxsize=None)
Expand Down Expand Up @@ -194,7 +196,9 @@ def _load_schema_file(schema_name: str, language_code: str) -> Any:


@lru_cache(maxsize=None)
def load_schema_from_url(url: str, *, language_code: str | None) -> QuestionnaireSchema:
def load_schema_from_url(
url: str, *, language_code: str | None, is_cir: bool = False
) -> QuestionnaireSchema:
"""
Fetches a schema from the provided url.
The caller is responsible for including any required query parameters in the url
Expand All @@ -214,6 +218,10 @@ def load_schema_from_url(url: str, *, language_code: str | None) -> Questionnair
backoff_factor=SCHEMA_REQUEST_BACKOFF_FACTOR,
)

if is_cir:
# Type ignore: CIR_OAUTH2_CLIENT_ID is an env var which must exist as it is verified in setup.py
fetch_and_apply_oidc_credentials(session=session, client_id=CIR_OAUTH2_CLIENT_ID) # type: ignore

try:
req = session.get(url, timeout=SCHEMA_REQUEST_TIMEOUT)
except RequestException as exc:
Expand Down
35 changes: 35 additions & 0 deletions tests/app/utilities/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from urllib3.connectionpool import HTTPConnectionPool
from urllib3.response import HTTPResponse

from app.oidc.gcp_oidc import OIDCCredentialsServiceGCP
from app.questionnaire import QuestionnaireSchema
from app.setup import create_app
from app.utilities.schema import (
Expand Down Expand Up @@ -374,3 +375,37 @@ def test_load_schema_from_url_max_retries(mocker):

assert str(exc.value) == "schema request failed"
assert mocked_make_request.call_count == 3


@responses.activate
def test_load_schema_from_metadata_cir_with_gcp_authentication(
app, metadata_with_cir_instrument_id, mocker
):
load_schema_from_url.cache_clear()
mock_schema = QuestionnaireSchema({}, language_code="cy")

mock_oidc_service = Mock(spec=OIDCCredentialsServiceGCP)
mocker.patch.dict(
"app.services.supplementary_data.current_app.eq",
{"oidc_credentials_service": mock_oidc_service},
)

responses.add(
responses.GET,
f"{TEST_CIR_URL}{CIR_RETRIEVE_COLLECTION_INSTRUMENT_URL}",
json=mock_schema.json,
status=200,
)

with app.app_context():
app.config["CIR_API_BASE_URL"] = TEST_CIR_URL
loaded_schema = load_schema_from_metadata(
metadata=metadata_with_cir_instrument_id, language_code="cy"
)

mock_oidc_service.get_credentials.assert_called_once_with(
iap_client_id=app.config["CIR_OAUTH2_CLIENT_ID"]
)

assert loaded_schema.json == mock_schema.json
assert loaded_schema.language_code == mock_schema.language_code
14 changes: 14 additions & 0 deletions tests/integration/test_app_create.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,7 @@ def test_setup_oidc_service_gcp(self):
# Given
self._setting_overrides["OIDC_TOKEN_BACKEND"] = "gcp"
self._setting_overrides["SDS_OAUTH2_CLIENT_ID"] = "1234567890"
self._setting_overrides["CIR_OAUTH2_CLIENT_ID"] = "1234567890"

# When
application = create_app(self._setting_overrides)
Expand Down Expand Up @@ -449,3 +450,16 @@ def test_sds_oauth_2_client_id_missing_raises_exception(self):

# Then
assert "Setting SDS_OAUTH2_CLIENT_ID Missing" in str(ex.exception)

def test_cir_oauth_2_client_id_missing_raises_exception(self):
# Given
self._setting_overrides["OIDC_TOKEN_BACKEND"] = "gcp"
self._setting_overrides["SDS_OAUTH2_CLIENT_ID"] = "123456789"
self._setting_overrides["CIR_OAUTH2_CLIENT_ID"] = ""

# When
with self.assertRaises(MissingEnvironmentVariable) as ex:
create_app(self._setting_overrides)

# Then
assert "Setting CIR_OAUTH2_CLIENT_ID Missing" in str(ex.exception)

0 comments on commit bc3d5f6

Please sign in to comment.