Skip to content

Commit

Permalink
feat: allow Problem subclasses to define static response headers
Browse files Browse the repository at this point in the history
  • Loading branch information
edaniszewski committed Nov 17, 2020
1 parent 359852f commit a69ee34
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 2 deletions.
25 changes: 24 additions & 1 deletion examples/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,41 @@

from fastapi import FastAPI

from fastapi_rfc7807.middleware import register
from fastapi_rfc7807.middleware import register, Problem


app = FastAPI()
register(app)


class AuthenticationError(Problem):
"""An example of how to create a custom subclass of the Problem error.
This class also defines additional headers which should be sent with
the error response.
"""

headers = {
'WWW-Authenticate': 'Bearer',
}

def __init__(self, msg: str) -> None:
super(AuthenticationError, self).__init__(
status=401,
detail=msg,
)


@app.get('/')
async def root():
return {'message': 'Hello World'}


@app.get('/auth')
async def custom():
raise AuthenticationError('user is unauthenticated')


@app.get('/error')
async def error():
raise ValueError('something went wrong')
Expand Down
13 changes: 12 additions & 1 deletion fastapi_rfc7807/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
import asyncio
import http
import json
from typing import Any, Awaitable, Callable, Dict, Optional, Sequence, Union
from typing import (Any, Awaitable, Callable, Dict, Mapping, Optional,
Sequence, Union)

from fastapi import FastAPI
from fastapi.exceptions import RequestValidationError
Expand All @@ -31,6 +32,13 @@ def __init__(self, *args, debug: bool = False, **kwargs) -> None:
self.debug: bool = debug
super(ProblemResponse, self).__init__(*args, **kwargs)

def init_headers(self, headers: Mapping[str, str] = None) -> None:
h = dict(headers) if headers else {}
if hasattr(self, 'problem') and self.problem.headers:
h.update(self.problem.headers)

super(ProblemResponse, self).init_headers(h)

def render(self, content: Any) -> bytes:
"""Render the provided content as an RFC-7807 Problem JSON-serialized bytes."""
if isinstance(content, Problem):
Expand All @@ -57,6 +65,7 @@ def render(self, content: Any) -> bytes:
# the status code of the Problem.
self.status_code = p.status

self.problem = p
return p.to_bytes()


Expand All @@ -76,6 +85,8 @@ class Problem(Exception):
more granular control over how/when values are set.
"""

headers: Dict[str, str] = {}

def __init__(
self,
type: Optional[str] = None,
Expand Down
20 changes: 20 additions & 0 deletions tests/test_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,26 @@ def test_render_other(self):
'content': "['some', 'other', 'data']",
}

def test_render_with_headers(self):
class CustomErr(middleware.Problem):
headers = {
'custom-header': 'testing',
}

resp = middleware.ProblemResponse(
CustomErr(),
)

assert resp.media_type == 'application/problem+json'
assert resp.debug is False
assert resp.status_code == 500
assert resp.headers['custom-header'] == 'testing'
assert json.loads(resp.body) == {
'type': 'about:blank',
'status': 500,
'title': 'Internal Server Error',
}


class TestProblem:

Expand Down

0 comments on commit a69ee34

Please sign in to comment.