diff --git a/fastagency/api/openapi/security.py b/fastagency/api/openapi/security.py index 5f8b40d6..dd41bc89 100644 --- a/fastagency/api/openapi/security.py +++ b/fastagency/api/openapi/security.py @@ -1,12 +1,16 @@ import logging -from typing import Any, ClassVar, Literal, Optional, Protocol, Union +from typing import Any, ClassVar, Literal, Optional, Protocol +import requests from pydantic import BaseModel, model_validator +from typing_extensions import TypeAlias # Get the logger logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) +BaseSecurityType: TypeAlias = type["BaseSecurity"] + class BaseSecurity(BaseModel): """Base class for security classes.""" @@ -36,24 +40,30 @@ def accept(self, security_params: "BaseSecurityParameters") -> bool: return isinstance(self, security_params.get_security_class()) @classmethod - def is_supported(cls, type: str, in_value: Union[str, dict[str, Any]]) -> bool: - return type == cls.type and in_value == cls.in_value + def is_supported(cls, type: str, schema_parameters: dict[str, Any]) -> bool: + return cls.type == type and cls.in_value == schema_parameters.get("in") @classmethod - def get_security_class(cls, type: str, in_value: str) -> Optional[str]: + def get_security_class( + cls, type: str, schema_parameters: dict[str, Any] + ) -> BaseSecurityType: sub_classes = cls.__subclasses__() for sub_class in sub_classes: - if sub_class.is_supported(type, in_value): - return sub_class.__name__ + if sub_class.is_supported(type, schema_parameters): + return sub_class else: logger.error( - f"Unsupported type '{type}' and in_value '{in_value}' combination" + f"Unsupported type '{type}' and schema_parameters '{schema_parameters}' combination" ) raise ValueError( - f"Unsupported type '{type}' and in_value '{in_value}' combination" + f"Unsupported type '{type}' and schema_parameters '{schema_parameters}' combination" ) + @classmethod + def get_security_parameters(cls, schema_parameters: dict[str, Any]) -> str: + return f"{cls.__name__}(name=\"{schema_parameters.get('name')}\")" + class BaseSecurityParameters(Protocol): """Base class for security parameters.""" @@ -102,6 +112,13 @@ class APIKeyQuery(BaseSecurity): type: ClassVar[Literal["apiKey"]] = "apiKey" in_value: ClassVar[Literal["query"]] = "query" + @classmethod + def is_supported(cls, type: str, schema_parameters: dict[str, Any]) -> bool: + return ( + super().is_supported(type, schema_parameters) + and "name" in schema_parameters + ) + class Parameters(BaseModel): # BaseSecurityParameters """API Key Query security parameters class.""" @@ -180,10 +197,17 @@ class OAuth2PasswordBearer(BaseSecurity): type: ClassVar[Literal["oauth2"]] = "oauth2" in_value: ClassVar[Literal["bearer"]] = "bearer" + token_url: str @classmethod - def is_supported(cls, type: str, in_value: Union[str, dict[str, Any]]) -> bool: - return type == cls.type and isinstance(in_value, dict) + def is_supported(cls, type: str, schema_parameters: dict[str, Any]) -> bool: + return type == cls.type and "password" in schema_parameters.get("flows", {}) + + @classmethod + def get_security_parameters(cls, schema_parameters: dict[str, Any]) -> str: + name = schema_parameters.get("name") + token_url = f'{schema_parameters.get("server_url")}/{schema_parameters["flows"]["password"]["tokenUrl"]}' + return f'{cls.__name__}(name="{name}", token_url="{token_url}")' class Parameters(BaseModel): # BaseSecurityParameters """OAuth2 Password Bearer security class.""" @@ -191,6 +215,7 @@ class Parameters(BaseModel): # BaseSecurityParameters username: Optional[str] = None password: Optional[str] = None bearer_token: Optional[str] = None + token_url: Optional[str] = None @model_validator(mode="before") def check_credentials(cls, values: dict[str, Any]) -> Any: # noqa @@ -206,6 +231,19 @@ def check_credentials(cls, values: dict[str, Any]) -> Any: # noqa return values + def get_token(self, token_url: str) -> str: + # Get the token + request = requests.post( + token_url, + data={ + "username": self.username, + "password": self.password, + }, + timeout=5, + ) + request.raise_for_status() + return request.json()["access_token"] # type: ignore + def apply( self, q_params: dict[str, Any], @@ -213,8 +251,9 @@ def apply( security: BaseSecurity, ) -> None: if not self.bearer_token: - # request token from the tokenUrl with username and password - raise NotImplementedError() + if security.token_url is None: # type: ignore + raise ValueError("Token URL is not defined") + self.bearer_token = self.get_token(security.token_url) # type: ignore if "headers" not in body_dict: body_dict["headers"] = {} diff --git a/fastagency/api/openapi/security_schema_visitor.py b/fastagency/api/openapi/security_schema_visitor.py index 581b9fba..c2ff6990 100644 --- a/fastagency/api/openapi/security_schema_visitor.py +++ b/fastagency/api/openapi/security_schema_visitor.py @@ -10,27 +10,21 @@ def custom_visitor(parser: OpenAPIParser, model_path: Path) -> dict[str, object] if "securitySchemes" not in parser.raw_obj["components"]: return {} security_schemes = parser.raw_obj["components"]["securitySchemes"] - - # for k, v in security_schemes.items(): - # security_schemes[k]["in_value"] = security_schemes[k].pop("in") + server_url = parser.raw_obj["servers"][0]["url"] security_classes = [] security_parameters = {} for k, v in security_schemes.items(): - if "in" not in v and v["type"] == "http": - in_value = v.get("scheme", None) - if "in" not in v and v["type"] == "oauth2": - in_value = v.get("flows", None) - else: - in_value = v["in"] + v["server_url"] = server_url security_class = BaseSecurity.get_security_class( - type=v["type"], in_value=in_value + type=v["type"], schema_parameters=v + ) + + security_classes.append(security_class.__name__) + + security_parameters[k] = security_class.get_security_parameters( + schema_parameters=v ) - if security_class is None: - continue - security_classes.append(security_class) - name = v.get("name", None) - security_parameters[k] = f'{security_class}(name="{name}")' return { "security_schemes": security_schemes, diff --git a/pyproject.toml b/pyproject.toml index 97b2ac10..b7293671 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -137,6 +137,7 @@ testing = [ "PyYAML==6.0.2", "watchfiles==0.24.0", "email-validator==2.2.0", + "python-multipart>=0.0.12", ] dev = [ diff --git a/tests/api/openapi/security/test_oauth_client.py b/tests/api/openapi/security/test_oauth_client.py index bad1c456..90e9baa3 100644 --- a/tests/api/openapi/security/test_oauth_client.py +++ b/tests/api/openapi/security/test_oauth_client.py @@ -5,6 +5,7 @@ import requests from fastapi import Depends, FastAPI, HTTPException, status from fastapi.security import OAuth2PasswordBearer as FastAPIOAuth2PasswordBearer +from fastapi.security import OAuth2PasswordRequestForm from fastagency.api.openapi import OpenAPI from fastagency.api.openapi.security import OAuth2PasswordBearer @@ -20,6 +21,19 @@ def create_oauth2_fastapi_app(host: str, port: int) -> FastAPI: oauth2_scheme = FastAPIOAuth2PasswordBearer(tokenUrl="token") + @app.post("/token") + async def login( + form_data: Annotated[OAuth2PasswordRequestForm, Depends()], + ) -> dict[str, str]: + if ( + form_data.username != "user" + or form_data.password != "password" # pragma: allowlist secret + ): + raise HTTPException( + status_code=400, detail="Incorrect username or password" + ) + return {"access_token": "token123", "token_type": "bearer"} + @app.post("/low", summary="Low Level") async def post_oauth( message: str, token: Annotated[str, Depends(oauth2_scheme)] @@ -37,9 +51,49 @@ def openapi_oauth2_schema() -> dict[str, Any]: "openapi": "3.1.0", "info": {"title": "OAuth2", "version": "0.1.0"}, "servers": [ - {"url": "http://127.0.0.1:43465", "description": "Local development server"} + {"url": "http://127.0.0.1:60473", "description": "Local development server"} ], "paths": { + "/token": { + "post": { + "summary": "Login", + "operationId": "login_token_post", + "requestBody": { + "content": { + "application/x-www-form-urlencoded": { + "schema": { + "$ref": "#/components/schemas/Body_login_token_post" + } + } + }, + "required": True, + }, + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "additionalProperties": {"type": "string"}, + "type": "object", + "title": "Response Login Token Post", + } + } + }, + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + }, + }, + } + }, "/low": { "post": { "summary": "Low Level", @@ -78,10 +132,35 @@ def openapi_oauth2_schema() -> dict[str, Any]: }, }, } - } + }, }, "components": { "schemas": { + "Body_login_token_post": { + "properties": { + "grant_type": { + "anyOf": [ + {"type": "string", "pattern": "password"}, + {"type": "null"}, + ], + "title": "Grant Type", + }, + "username": {"type": "string", "title": "Username"}, + "password": {"type": "string", "title": "Password"}, + "scope": {"type": "string", "title": "Scope", "default": ""}, + "client_id": { + "anyOf": [{"type": "string"}, {"type": "null"}], + "title": "Client Id", + }, + "client_secret": { + "anyOf": [{"type": "string"}, {"type": "null"}], + "title": "Client Secret", + }, + }, + "type": "object", + "required": ["username", "password"], + "title": "Body_login_token_post", + }, "HTTPValidationError": { "properties": { "detail": { @@ -144,13 +223,41 @@ def test_oauth2_fastapi_app( [(create_oauth2_fastapi_app)], indirect=["fastapi_openapi_url"], ) -def test_generate_oauth2_client(fastapi_openapi_url: str) -> None: - api_client = OpenAPI.create(openapi_url=fastapi_openapi_url) +def test_generate_oauth2_client_token(fastapi_openapi_url: str) -> None: + api_client = OpenAPI.create( + openapi_url=fastapi_openapi_url, + ) api_client.set_security_params( OAuth2PasswordBearer.Parameters(bearer_token="token123") ) - expected = ["post_oauth_low_post"] + expected = ["post_oauth_low_post", "login_token_post"] + + functions = list(api_client._get_functions_to_register()) + assert [f.__name__ for f in functions] == expected + + post_oauth_f = functions[0] + + response = post_oauth_f(message="message") + + assert response == {"message": "message"} + + +@pytest.mark.parametrize( + "fastapi_openapi_url", + [(create_oauth2_fastapi_app)], + indirect=["fastapi_openapi_url"], +) +def test_generate_oauth2_client_password(fastapi_openapi_url: str) -> None: + api_client = OpenAPI.create(openapi_url=fastapi_openapi_url) + api_client.set_security_params( + OAuth2PasswordBearer.Parameters( + username="user", + password="password", # pragma: allowlist secret + ) + ) + + expected = ["post_oauth_low_post", "login_token_post"] functions = list(api_client._get_functions_to_register()) assert [f.__name__ for f in functions] == expected @@ -160,3 +267,33 @@ def test_generate_oauth2_client(fastapi_openapi_url: str) -> None: response = post_oauth_f(message="message") assert response == {"message": "message"} + + +@pytest.mark.parametrize( + "fastapi_openapi_url", + [(create_oauth2_fastapi_app)], + indirect=["fastapi_openapi_url"], +) +def test_generate_oauth2_client_wrong_password(fastapi_openapi_url: str) -> None: + api_client = OpenAPI.create(openapi_url=fastapi_openapi_url) + api_client.set_security_params( + OAuth2PasswordBearer.Parameters( + username="user", + password="password123", # pragma: allowlist secret + ) + ) + + expected = ["post_oauth_low_post", "login_token_post"] + + functions = list(api_client._get_functions_to_register()) + assert [f.__name__ for f in functions] == expected + + post_oauth_f = functions[0] + + with pytest.raises(requests.exceptions.HTTPError) as e: + post_oauth_f(message="message") + + assert ( + str(e.value) + == f'400 Client Error: Bad Request for url: {fastapi_openapi_url.split("/openapi.json")[0]}/token' + ) diff --git a/tests/api/openapi/security/test_security.py b/tests/api/openapi/security/test_security.py index 73283955..31a32f20 100644 --- a/tests/api/openapi/security/test_security.py +++ b/tests/api/openapi/security/test_security.py @@ -57,7 +57,6 @@ def test_generate_client(secure_fastapi_url: str) -> None: with expected_models_gen_path.open() as f: expected_models_gen = f.readlines()[4:] - # print(actual_main_gen_txt) assert actual_main_gen_txt == expected_main_gen_txt assert actual_models_gen == expected_models_gen @@ -106,7 +105,7 @@ def test_import_and_call_generate_client(secure_fastapi_url: str) -> None: assert client_resp == {"api_key": api_key} -def test__get_matching_security(secure_fastapi_url: str) -> None: +def test_get_matching_security(secure_fastapi_url: str) -> None: with tempfile.TemporaryDirectory() as temp_dir: td = Path(temp_dir) / "gen"