Skip to content

Commit

Permalink
Merge branch 'main' into Feature/11_combine_profile_update_and_fi_ass…
Browse files Browse the repository at this point in the history
…ociation
  • Loading branch information
guffee23 committed Sep 26, 2023
2 parents e16b624 + 3a3d65f commit 1057704
Show file tree
Hide file tree
Showing 11 changed files with 104 additions and 61 deletions.
20 changes: 4 additions & 16 deletions src/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,24 +10,12 @@
from entities.engine import get_session
from entities.repos import institutions_repo as repo

OPEN_DOMAIN_REQUESTS = {
"/v1/admin/me": {"GET"},
"/v1/institutions": {"GET"},
"/v1/institutions/domains/allowed": {"GET"},
}


async def check_domain(request: Request, session: Annotated[AsyncSession, Depends(get_session)]) -> None:
if request_needs_domain_check(request):
if not request.user.is_authenticated:
raise HTTPException(status_code=HTTPStatus.FORBIDDEN)
if await email_domain_denied(session, request.user.email):
raise HTTPException(status_code=HTTPStatus.FORBIDDEN, detail="email domain denied")


def request_needs_domain_check(request: Request) -> bool:
path = request.scope["path"].rstrip("/")
return not (path in OPEN_DOMAIN_REQUESTS and request.scope["method"] in OPEN_DOMAIN_REQUESTS[path])
if not request.user.is_authenticated:
raise HTTPException(status_code=HTTPStatus.FORBIDDEN)
if await email_domain_denied(session, request.user.email):
raise HTTPException(status_code=HTTPStatus.FORBIDDEN, detail="email domain denied")


async def email_domain_denied(session: AsyncSession, email: str) -> bool:
Expand Down
2 changes: 2 additions & 0 deletions src/entities/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
"DeniedDomainDao",
"DeniedDomainDto",
"UserProfile",
"AuthenticatedUser",
]

from .dao import (
Expand All @@ -24,4 +25,5 @@
FinancialInsitutionDomainCreate,
DeniedDomainDto,
UserProfile,
AuthenticatedUser,
)
44 changes: 44 additions & 0 deletions src/entities/models/dto.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import List, Dict, Any, Set
from pydantic import BaseModel
from starlette.authentication import BaseUser


class FinancialInsitutionDomainBase(BaseModel):
Expand Down Expand Up @@ -46,3 +47,46 @@ class UserProfile(BaseModel):

def get_user(self) -> Dict[str, Any]:
return {"firstName": self.firstName, "lastName": self.lastName}


class AuthenticatedUser(BaseUser, BaseModel):
claims: Dict[str, Any]
name: str
username: str
email: str
id: str
institutions: List[str]

@classmethod
def from_claim(cls, claims: Dict[str, Any]) -> "AuthenticatedUser":
return cls(
claims=claims,
name=claims.get("name", ""),
username=claims.get("preferred_username", ""),
email=claims.get("email", ""),
id=claims.get("sub", ""),
institutions=cls.parse_institutions(claims.get("institutions")),
)

@classmethod
def parse_institutions(cls, institutions: List[str] | None) -> List[str]:
"""
Parse out the list of institutions returned by Keycloak
Args:
institutions(List[str]): list of full institution paths provided by keycloak,
it is possible to have nested paths, though we may not use the feature.
e.g. ["/ROOT_INSTITUTION/CHILD_INSTITUTION/GRAND_CHILD_INSTITUTION"]
Returns:
List[str]: List of cleaned up institutions.
e.g. ["GRAND_CHILD_INSTITUTION"]
"""
if institutions:
return [institution.split("/")[-1] for institution in institutions]
else:
return []

@property
def is_authenticated(self) -> bool:
return True
5 changes: 2 additions & 3 deletions src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,19 @@
import logging
import env # noqa: F401
from http import HTTPStatus
from fastapi import FastAPI, HTTPException, Request, Depends
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse
from fastapi.security import OAuth2AuthorizationCodeBearer
from fastapi.middleware.cors import CORSMiddleware
from starlette.middleware.authentication import AuthenticationMiddleware
from dependencies import check_domain

from routers import admin_router, institutions_router

from oauth2 import BearerTokenAuthBackend

log = logging.getLogger()

app = FastAPI(dependencies=[Depends(check_domain)])
app = FastAPI()


@app.exception_handler(HTTPException)
Expand Down
4 changes: 2 additions & 2 deletions src/oauth2/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__all__ = ["oauth2_admin", "BearerTokenAuthBackend", "AuthenticatedUser"]
__all__ = ["oauth2_admin", "BearerTokenAuthBackend"]

from .oauth2_admin import oauth2_admin
from .oauth2_backend import BearerTokenAuthBackend, AuthenticatedUser
from .oauth2_backend import BearerTokenAuthBackend
25 changes: 2 additions & 23 deletions src/oauth2/oauth2_backend.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import logging
from typing import Coroutine, Any, Dict, List, Tuple
from fastapi import HTTPException
from pydantic import BaseModel
from starlette.authentication import (
AuthCredentials,
AuthenticationBackend,
Expand All @@ -11,33 +10,13 @@
from fastapi.security import OAuth2AuthorizationCodeBearer
from starlette.requests import HTTPConnection

from entities.models import AuthenticatedUser

from .oauth2_admin import oauth2_admin

log = logging.getLogger(__name__)


class AuthenticatedUser(BaseUser, BaseModel):
claims: Dict[str, Any]
name: str | None
username: str | None
email: str | None
id: str | None

@classmethod
def from_claim(cls, claims: Dict[str, Any]) -> "AuthenticatedUser":
return cls(
claims=claims,
name=claims.get("name"),
username=claims.get("preferred_username"),
email=claims.get("email"),
id=claims.get("sub"),
)

@property
def is_authenticated(self) -> bool:
return True


class BearerTokenAuthBackend(AuthenticationBackend):
def __init__(self, token_bearer: OAuth2AuthorizationCodeBearer) -> None:
self.token_bearer = token_bearer
Expand Down
17 changes: 6 additions & 11 deletions src/routers/admin.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from http import HTTPStatus
from typing import Set
from fastapi import Request
from fastapi import Depends, Request
from starlette.authentication import requires
from dependencies import check_domain
from util import Router
from entities.models import UserProfile

from oauth2 import AuthenticatedUser, oauth2_admin
from entities.models import AuthenticatedUser
from oauth2 import oauth2_admin

router = Router()

Expand All @@ -16,22 +18,15 @@ async def get_me(request: Request):
return request.user


@router.put("/me/", status_code=HTTPStatus.ACCEPTED)
@router.put("/me/", status_code=HTTPStatus.ACCEPTED, dependencies=[Depends(check_domain)])
@requires("manage-account")
async def update_me(request: Request, user: UserProfile):
oauth2_admin.update_user(request.user.id, user.get_user())
if user.leis:
oauth2_admin.associate_to_lei_set(request.user.id, user.leis)


@router.put("/me/groups/", status_code=HTTPStatus.ACCEPTED)
@requires("manage-account")
async def associate_group(request: Request, groups: Set[str]):
for group in groups:
oauth2_admin.associate_to_group(request.user.id, group)


@router.put("/me/institutions/", status_code=HTTPStatus.ACCEPTED)
@router.put("/me/institutions/", status_code=HTTPStatus.ACCEPTED, dependencies=[Depends(check_domain)])
@requires("manage-account")
async def associate_lei(request: Request, leis: Set[str]):
oauth2_admin.associate_to_lei_set(request.user.id, leis)
6 changes: 3 additions & 3 deletions src/routers/institutions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from http import HTTPStatus
from oauth2 import oauth2_admin
from util import Router
from dependencies import parse_leis
from dependencies import check_domain, parse_leis
from typing import Annotated, List, Tuple
from entities.engine import get_session
from entities.repos import institutions_repo as repo
Expand Down Expand Up @@ -35,7 +35,7 @@ async def get_institutions(
return await repo.get_institutions(request.state.db_session, leis, domain, page, count)


@router.post("/", response_model=Tuple[str, FinancialInstitutionDto])
@router.post("/", response_model=Tuple[str, FinancialInstitutionDto], dependencies=[Depends(check_domain)])
@requires(["query-groups", "manage-users"])
async def create_institution(
request: Request,
Expand All @@ -58,7 +58,7 @@ async def get_institution(
return res


@router.post("/{lei}/domains/", response_model=List[FinancialInsitutionDomainDto])
@router.post("/{lei}/domains/", response_model=List[FinancialInsitutionDomainDto], dependencies=[Depends(check_domain)])
@requires(["query-groups", "manage-users"])
async def add_domains(
request: Request,
Expand Down
2 changes: 1 addition & 1 deletion tests/api/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from pytest_mock import MockerFixture
from starlette.authentication import AuthCredentials, UnauthenticatedUser

from oauth2.oauth2_backend import AuthenticatedUser
from entities.models import AuthenticatedUser


@pytest.fixture
Expand Down
20 changes: 19 additions & 1 deletion tests/api/routers/test_admin_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from pytest_mock import MockerFixture
from starlette.authentication import AuthCredentials

from oauth2.oauth2_backend import AuthenticatedUser
from entities.models import AuthenticatedUser


class TestAdminApi:
Expand All @@ -19,6 +19,24 @@ def test_get_me_authed(self, mocker: MockerFixture, app_fixture: FastAPI, authed
res = client.get("/v1/admin/me")
assert res.status_code == 200
assert res.json().get("name") == "test"
assert res.json().get("institutions") == []

def test_get_me_authed_with_institutions(self, app_fixture: FastAPI, auth_mock: Mock):
claims = {
"name": "test",
"preferred_username": "test_user",
"email": "test@local.host",
"sub": "testuser123",
"institutions": ["/TEST1LEI", "/TEST2LEI/TEST2CHILDLEI"],
}
auth_mock.return_value = (
AuthCredentials(["authenticated"]),
AuthenticatedUser.from_claim(claims),
)
client = TestClient(app_fixture)
res = client.get("/v1/admin/me")
assert res.status_code == 200
assert res.json().get("institutions") == ["TEST1LEI", "TEST2CHILDLEI"]

def test_update_me_unauthed(self, app_fixture: FastAPI, unauthed_user_mock: Mock):
client = TestClient(app_fixture)
Expand Down
20 changes: 19 additions & 1 deletion tests/api/routers/test_institutions_api.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from unittest.mock import Mock
from unittest.mock import Mock, ANY

from fastapi import FastAPI
from fastapi.testclient import TestClient
Expand Down Expand Up @@ -81,6 +81,15 @@ def test_get_institution_authed(self, mocker: MockerFixture, app_fixture: FastAP
assert res.status_code == 200
assert res.json().get("name") == "Test Bank 123"

def test_get_institution_not_exists(self, mocker: MockerFixture, app_fixture: FastAPI, authed_user_mock: Mock):
get_institution_mock = mocker.patch("entities.repos.institutions_repo.get_institution")
get_institution_mock.return_value = None
client = TestClient(app_fixture)
lei_path = "testLeiPath"
res = client.get(f"/v1/institutions/{lei_path}")
get_institution_mock.assert_called_once_with(ANY, lei_path)
assert res.status_code == 404

def test_add_domains_unauthed(self, app_fixture: FastAPI, unauthed_user_mock: Mock):
client = TestClient(app_fixture)

Expand Down Expand Up @@ -124,3 +133,12 @@ def test_add_domains_authed_with_denied_email_domain(
res = client.post(f"/v1/institutions/{lei_path}/domains/", json=[{"domain": "testDomain"}])
assert res.status_code == 403
assert "domain denied" in res.json()["detail"]

def test_check_domain_allowed(self, mocker: MockerFixture, app_fixture: FastAPI, authed_user_mock: Mock):
domain_allowed_mock = mocker.patch("entities.repos.institutions_repo.is_email_domain_allowed")
domain_allowed_mock.return_value = True
domain_to_check = "local.host"
client = TestClient(app_fixture)
res = client.get(f"/v1/institutions/domains/allowed?domain={domain_to_check}")
domain_allowed_mock.assert_called_once_with(ANY, domain_to_check)
assert res.json() is True

0 comments on commit 1057704

Please sign in to comment.