From a69ee342c2fa7e0099949ce410591893b3824091 Mon Sep 17 00:00:00 2001 From: Erick Daniszewski Date: Tue, 17 Nov 2020 11:42:38 -0500 Subject: [PATCH] feat: allow Problem subclasses to define static response headers --- examples/basic.py | 25 ++++++++++++++++++++++++- fastapi_rfc7807/middleware.py | 13 ++++++++++++- tests/test_middleware.py | 20 ++++++++++++++++++++ 3 files changed, 56 insertions(+), 2 deletions(-) diff --git a/examples/basic.py b/examples/basic.py index 60cb25e..5655a47 100644 --- a/examples/basic.py +++ b/examples/basic.py @@ -7,18 +7,41 @@ from fastapi import FastAPI -from fastapi_rfc7807.middleware import register +from fastapi_rfc7807.middleware import register, Problem app = FastAPI() register(app) +class AuthenticationError(Problem): + """An example of how to create a custom subclass of the Problem error. + + This class also defines additional headers which should be sent with + the error response. + """ + + headers = { + 'WWW-Authenticate': 'Bearer', + } + + def __init__(self, msg: str) -> None: + super(AuthenticationError, self).__init__( + status=401, + detail=msg, + ) + + @app.get('/') async def root(): return {'message': 'Hello World'} +@app.get('/auth') +async def custom(): + raise AuthenticationError('user is unauthenticated') + + @app.get('/error') async def error(): raise ValueError('something went wrong') diff --git a/fastapi_rfc7807/middleware.py b/fastapi_rfc7807/middleware.py index 624c855..22715f8 100644 --- a/fastapi_rfc7807/middleware.py +++ b/fastapi_rfc7807/middleware.py @@ -6,7 +6,8 @@ import asyncio import http import json -from typing import Any, Awaitable, Callable, Dict, Optional, Sequence, Union +from typing import (Any, Awaitable, Callable, Dict, Mapping, Optional, + Sequence, Union) from fastapi import FastAPI from fastapi.exceptions import RequestValidationError @@ -31,6 +32,13 @@ def __init__(self, *args, debug: bool = False, **kwargs) -> None: self.debug: bool = debug super(ProblemResponse, self).__init__(*args, **kwargs) + def init_headers(self, headers: Mapping[str, str] = None) -> None: + h = dict(headers) if headers else {} + if hasattr(self, 'problem') and self.problem.headers: + h.update(self.problem.headers) + + super(ProblemResponse, self).init_headers(h) + def render(self, content: Any) -> bytes: """Render the provided content as an RFC-7807 Problem JSON-serialized bytes.""" if isinstance(content, Problem): @@ -57,6 +65,7 @@ def render(self, content: Any) -> bytes: # the status code of the Problem. self.status_code = p.status + self.problem = p return p.to_bytes() @@ -76,6 +85,8 @@ class Problem(Exception): more granular control over how/when values are set. """ + headers: Dict[str, str] = {} + def __init__( self, type: Optional[str] = None, diff --git a/tests/test_middleware.py b/tests/test_middleware.py index 4e5bac1..3866e36 100644 --- a/tests/test_middleware.py +++ b/tests/test_middleware.py @@ -115,6 +115,26 @@ def test_render_other(self): 'content': "['some', 'other', 'data']", } + def test_render_with_headers(self): + class CustomErr(middleware.Problem): + headers = { + 'custom-header': 'testing', + } + + resp = middleware.ProblemResponse( + CustomErr(), + ) + + assert resp.media_type == 'application/problem+json' + assert resp.debug is False + assert resp.status_code == 500 + assert resp.headers['custom-header'] == 'testing' + assert json.loads(resp.body) == { + 'type': 'about:blank', + 'status': 500, + 'title': 'Internal Server Error', + } + class TestProblem: