Skip to content

Commit

Permalink
307 add username and password logic to oauth security settings (#328)
Browse files Browse the repository at this point in the history
* WIP

* Fix failing tests

* Add missing python-multipart dependency

* Rebuild docs

* Implement code suggesitions
  • Loading branch information
sternakt authored Oct 8, 2024
1 parent dab7763 commit cccb21b
Show file tree
Hide file tree
Showing 5 changed files with 204 additions and 34 deletions.
63 changes: 51 additions & 12 deletions fastagency/api/openapi/security.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
import logging
from typing import Any, ClassVar, Literal, Optional, Protocol, Union
from typing import Any, ClassVar, Literal, Optional, Protocol

import requests
from pydantic import BaseModel, model_validator
from typing_extensions import TypeAlias

# Get the logger
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

BaseSecurityType: TypeAlias = type["BaseSecurity"]


class BaseSecurity(BaseModel):
"""Base class for security classes."""
Expand Down Expand Up @@ -36,24 +40,30 @@ def accept(self, security_params: "BaseSecurityParameters") -> bool:
return isinstance(self, security_params.get_security_class())

@classmethod
def is_supported(cls, type: str, in_value: Union[str, dict[str, Any]]) -> bool:
return type == cls.type and in_value == cls.in_value
def is_supported(cls, type: str, schema_parameters: dict[str, Any]) -> bool:
return cls.type == type and cls.in_value == schema_parameters.get("in")

@classmethod
def get_security_class(cls, type: str, in_value: str) -> Optional[str]:
def get_security_class(
cls, type: str, schema_parameters: dict[str, Any]
) -> BaseSecurityType:
sub_classes = cls.__subclasses__()

for sub_class in sub_classes:
if sub_class.is_supported(type, in_value):
return sub_class.__name__
if sub_class.is_supported(type, schema_parameters):
return sub_class
else:
logger.error(
f"Unsupported type '{type}' and in_value '{in_value}' combination"
f"Unsupported type '{type}' and schema_parameters '{schema_parameters}' combination"
)
raise ValueError(
f"Unsupported type '{type}' and in_value '{in_value}' combination"
f"Unsupported type '{type}' and schema_parameters '{schema_parameters}' combination"
)

@classmethod
def get_security_parameters(cls, schema_parameters: dict[str, Any]) -> str:
return f"{cls.__name__}(name=\"{schema_parameters.get('name')}\")"


class BaseSecurityParameters(Protocol):
"""Base class for security parameters."""
Expand Down Expand Up @@ -102,6 +112,13 @@ class APIKeyQuery(BaseSecurity):
type: ClassVar[Literal["apiKey"]] = "apiKey"
in_value: ClassVar[Literal["query"]] = "query"

@classmethod
def is_supported(cls, type: str, schema_parameters: dict[str, Any]) -> bool:
return (
super().is_supported(type, schema_parameters)
and "name" in schema_parameters
)

class Parameters(BaseModel): # BaseSecurityParameters
"""API Key Query security parameters class."""

Expand Down Expand Up @@ -180,17 +197,25 @@ class OAuth2PasswordBearer(BaseSecurity):

type: ClassVar[Literal["oauth2"]] = "oauth2"
in_value: ClassVar[Literal["bearer"]] = "bearer"
token_url: str

@classmethod
def is_supported(cls, type: str, in_value: Union[str, dict[str, Any]]) -> bool:
return type == cls.type and isinstance(in_value, dict)
def is_supported(cls, type: str, schema_parameters: dict[str, Any]) -> bool:
return type == cls.type and "password" in schema_parameters.get("flows", {})

@classmethod
def get_security_parameters(cls, schema_parameters: dict[str, Any]) -> str:
name = schema_parameters.get("name")
token_url = f'{schema_parameters.get("server_url")}/{schema_parameters["flows"]["password"]["tokenUrl"]}'
return f'{cls.__name__}(name="{name}", token_url="{token_url}")'

class Parameters(BaseModel): # BaseSecurityParameters
"""OAuth2 Password Bearer security class."""

username: Optional[str] = None
password: Optional[str] = None
bearer_token: Optional[str] = None
token_url: Optional[str] = None

@model_validator(mode="before")
def check_credentials(cls, values: dict[str, Any]) -> Any: # noqa
Expand All @@ -206,15 +231,29 @@ def check_credentials(cls, values: dict[str, Any]) -> Any: # noqa

return values

def get_token(self, token_url: str) -> str:
# Get the token
request = requests.post(
token_url,
data={
"username": self.username,
"password": self.password,
},
timeout=5,
)
request.raise_for_status()
return request.json()["access_token"] # type: ignore

def apply(
self,
q_params: dict[str, Any],
body_dict: dict[str, Any],
security: BaseSecurity,
) -> None:
if not self.bearer_token:
# request token from the tokenUrl with username and password
raise NotImplementedError()
if security.token_url is None: # type: ignore
raise ValueError("Token URL is not defined")
self.bearer_token = self.get_token(security.token_url) # type: ignore

if "headers" not in body_dict:
body_dict["headers"] = {}
Expand Down
24 changes: 9 additions & 15 deletions fastagency/api/openapi/security_schema_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,27 +10,21 @@ def custom_visitor(parser: OpenAPIParser, model_path: Path) -> dict[str, object]
if "securitySchemes" not in parser.raw_obj["components"]:
return {}
security_schemes = parser.raw_obj["components"]["securitySchemes"]

# for k, v in security_schemes.items():
# security_schemes[k]["in_value"] = security_schemes[k].pop("in")
server_url = parser.raw_obj["servers"][0]["url"]

security_classes = []
security_parameters = {}
for k, v in security_schemes.items():
if "in" not in v and v["type"] == "http":
in_value = v.get("scheme", None)
if "in" not in v and v["type"] == "oauth2":
in_value = v.get("flows", None)
else:
in_value = v["in"]
v["server_url"] = server_url
security_class = BaseSecurity.get_security_class(
type=v["type"], in_value=in_value
type=v["type"], schema_parameters=v
)

security_classes.append(security_class.__name__)

security_parameters[k] = security_class.get_security_parameters(
schema_parameters=v
)
if security_class is None:
continue
security_classes.append(security_class)
name = v.get("name", None)
security_parameters[k] = f'{security_class}(name="{name}")'

return {
"security_schemes": security_schemes,
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ testing = [
"PyYAML==6.0.2",
"watchfiles==0.24.0",
"email-validator==2.2.0",
"python-multipart>=0.0.12",
]

dev = [
Expand Down
147 changes: 142 additions & 5 deletions tests/api/openapi/security/test_oauth_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import requests
from fastapi import Depends, FastAPI, HTTPException, status
from fastapi.security import OAuth2PasswordBearer as FastAPIOAuth2PasswordBearer
from fastapi.security import OAuth2PasswordRequestForm

from fastagency.api.openapi import OpenAPI
from fastagency.api.openapi.security import OAuth2PasswordBearer
Expand All @@ -20,6 +21,19 @@ def create_oauth2_fastapi_app(host: str, port: int) -> FastAPI:

oauth2_scheme = FastAPIOAuth2PasswordBearer(tokenUrl="token")

@app.post("/token")
async def login(
form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
) -> dict[str, str]:
if (
form_data.username != "user"
or form_data.password != "password" # pragma: allowlist secret
):
raise HTTPException(
status_code=400, detail="Incorrect username or password"
)
return {"access_token": "token123", "token_type": "bearer"}

@app.post("/low", summary="Low Level")
async def post_oauth(
message: str, token: Annotated[str, Depends(oauth2_scheme)]
Expand All @@ -37,9 +51,49 @@ def openapi_oauth2_schema() -> dict[str, Any]:
"openapi": "3.1.0",
"info": {"title": "OAuth2", "version": "0.1.0"},
"servers": [
{"url": "http://127.0.0.1:43465", "description": "Local development server"}
{"url": "http://127.0.0.1:60473", "description": "Local development server"}
],
"paths": {
"/token": {
"post": {
"summary": "Login",
"operationId": "login_token_post",
"requestBody": {
"content": {
"application/x-www-form-urlencoded": {
"schema": {
"$ref": "#/components/schemas/Body_login_token_post"
}
}
},
"required": True,
},
"responses": {
"200": {
"description": "Successful Response",
"content": {
"application/json": {
"schema": {
"additionalProperties": {"type": "string"},
"type": "object",
"title": "Response Login Token Post",
}
}
},
},
"422": {
"description": "Validation Error",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/HTTPValidationError"
}
}
},
},
},
}
},
"/low": {
"post": {
"summary": "Low Level",
Expand Down Expand Up @@ -78,10 +132,35 @@ def openapi_oauth2_schema() -> dict[str, Any]:
},
},
}
}
},
},
"components": {
"schemas": {
"Body_login_token_post": {
"properties": {
"grant_type": {
"anyOf": [
{"type": "string", "pattern": "password"},
{"type": "null"},
],
"title": "Grant Type",
},
"username": {"type": "string", "title": "Username"},
"password": {"type": "string", "title": "Password"},
"scope": {"type": "string", "title": "Scope", "default": ""},
"client_id": {
"anyOf": [{"type": "string"}, {"type": "null"}],
"title": "Client Id",
},
"client_secret": {
"anyOf": [{"type": "string"}, {"type": "null"}],
"title": "Client Secret",
},
},
"type": "object",
"required": ["username", "password"],
"title": "Body_login_token_post",
},
"HTTPValidationError": {
"properties": {
"detail": {
Expand Down Expand Up @@ -144,13 +223,41 @@ def test_oauth2_fastapi_app(
[(create_oauth2_fastapi_app)],
indirect=["fastapi_openapi_url"],
)
def test_generate_oauth2_client(fastapi_openapi_url: str) -> None:
api_client = OpenAPI.create(openapi_url=fastapi_openapi_url)
def test_generate_oauth2_client_token(fastapi_openapi_url: str) -> None:
api_client = OpenAPI.create(
openapi_url=fastapi_openapi_url,
)
api_client.set_security_params(
OAuth2PasswordBearer.Parameters(bearer_token="token123")
)

expected = ["post_oauth_low_post"]
expected = ["post_oauth_low_post", "login_token_post"]

functions = list(api_client._get_functions_to_register())
assert [f.__name__ for f in functions] == expected

post_oauth_f = functions[0]

response = post_oauth_f(message="message")

assert response == {"message": "message"}


@pytest.mark.parametrize(
"fastapi_openapi_url",
[(create_oauth2_fastapi_app)],
indirect=["fastapi_openapi_url"],
)
def test_generate_oauth2_client_password(fastapi_openapi_url: str) -> None:
api_client = OpenAPI.create(openapi_url=fastapi_openapi_url)
api_client.set_security_params(
OAuth2PasswordBearer.Parameters(
username="user",
password="password", # pragma: allowlist secret
)
)

expected = ["post_oauth_low_post", "login_token_post"]

functions = list(api_client._get_functions_to_register())
assert [f.__name__ for f in functions] == expected
Expand All @@ -160,3 +267,33 @@ def test_generate_oauth2_client(fastapi_openapi_url: str) -> None:
response = post_oauth_f(message="message")

assert response == {"message": "message"}


@pytest.mark.parametrize(
"fastapi_openapi_url",
[(create_oauth2_fastapi_app)],
indirect=["fastapi_openapi_url"],
)
def test_generate_oauth2_client_wrong_password(fastapi_openapi_url: str) -> None:
api_client = OpenAPI.create(openapi_url=fastapi_openapi_url)
api_client.set_security_params(
OAuth2PasswordBearer.Parameters(
username="user",
password="password123", # pragma: allowlist secret
)
)

expected = ["post_oauth_low_post", "login_token_post"]

functions = list(api_client._get_functions_to_register())
assert [f.__name__ for f in functions] == expected

post_oauth_f = functions[0]

with pytest.raises(requests.exceptions.HTTPError) as e:
post_oauth_f(message="message")

assert (
str(e.value)
== f'400 Client Error: Bad Request for url: {fastapi_openapi_url.split("/openapi.json")[0]}/token'
)
3 changes: 1 addition & 2 deletions tests/api/openapi/security/test_security.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@ def test_generate_client(secure_fastapi_url: str) -> None:
with expected_models_gen_path.open() as f:
expected_models_gen = f.readlines()[4:]

# print(actual_main_gen_txt)
assert actual_main_gen_txt == expected_main_gen_txt
assert actual_models_gen == expected_models_gen

Expand Down Expand Up @@ -106,7 +105,7 @@ def test_import_and_call_generate_client(secure_fastapi_url: str) -> None:
assert client_resp == {"api_key": api_key}


def test__get_matching_security(secure_fastapi_url: str) -> None:
def test_get_matching_security(secure_fastapi_url: str) -> None:
with tempfile.TemporaryDirectory() as temp_dir:
td = Path(temp_dir) / "gen"

Expand Down

0 comments on commit cccb21b

Please sign in to comment.