Skip to content
This repository has been archived by the owner on Nov 19, 2023. It is now read-only.

Commit

Permalink
feat: adding support for request validation
Browse files Browse the repository at this point in the history
  • Loading branch information
maticardenas committed Oct 14, 2023
1 parent 08d9449 commit 4b5f6fb
Show file tree
Hide file tree
Showing 14 changed files with 381 additions and 37 deletions.
12 changes: 9 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ from openapi_tester import SchemaTester
schema_tester = SchemaTester(schema_file_path="./schemas/publishedSpecs.yaml")
```

Once you've instantiated a tester, you can use it to test responses:
Once you've instantiated a tester, you can use it to test responses and request bodies:

```python
from openapi_tester.schema_tester import SchemaTester
Expand All @@ -53,6 +53,12 @@ def test_response_documentation(client):
response = client.get('api/v1/test/1')
assert response.status_code == 200
schema_tester.validate_response(response=response)


def test_request_documentation(client):
response = client.get('api/v1/test/1')
assert response.status_code == 200
schema_tester.validate_request(response=response)
```

If you are using the Django testing framework, you can create a base `APITestCase` that incorporates schema validation:
Expand Down Expand Up @@ -188,11 +194,11 @@ In case of issues with the schema itself, the validator will raise the appropria

The library includes an `OpenAPIClient`, which extends Django REST framework's
[`APIClient` class](https://www.django-rest-framework.org/api-guide/testing/#apiclient).
If you wish to validate each response against OpenAPI schema when writing
If you wish to validate each request and response against OpenAPI schema when writing
unit tests - `OpenAPIClient` is what you need!

To use `OpenAPIClient` simply pass `SchemaTester` instance that should be used
to validate responses and then use it like regular Django testing client:
to validate requests and responses and then use it like regular Django testing client:

```python
schema_tester = SchemaTester()
Expand Down
6 changes: 6 additions & 0 deletions openapi_tester/clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,15 @@ def __init__(
def request(self, **kwargs) -> Response: # type: ignore[override]
"""Validate fetched response against given OpenAPI schema."""
response = super().request(**kwargs)
if self._is_successful_response(response):
self.schema_tester.validate_request(response)
self.schema_tester.validate_response(response)
return response

@staticmethod
def _is_successful_response(response: Response) -> bool:
return response.status_code < 400

@staticmethod
def _schema_tester_factory() -> SchemaTester:
"""Factory of default ``SchemaTester`` instances."""
Expand Down
12 changes: 6 additions & 6 deletions openapi_tester/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@
INVALID_PATTERN_ERROR = "String pattern is not valid regex: {pattern}"
VALIDATE_ENUM_ERROR = "Expected: a member of the enum {enum}\n\nReceived: {received}"
VALIDATE_TYPE_ERROR = 'Expected: {article} "{type}" type value\n\nReceived: {received}'
VALIDATE_MULTIPLE_OF_ERROR = "The response value {data} should be a multiple of {multiple}"
VALIDATE_MINIMUM_ERROR = "The response value {data} is lower than the specified minimum of {minimum}"
VALIDATE_MAXIMUM_ERROR = "The response value {data} exceeds the maximum allowed value of {maximum}"
VALIDATE_MULTIPLE_OF_ERROR = "The value {data} should be a multiple of {multiple}"
VALIDATE_MINIMUM_ERROR = "The value {data} is lower than the specified minimum of {minimum}"
VALIDATE_MAXIMUM_ERROR = "The value {data} exceeds the maximum allowed value of {maximum}"
VALIDATE_MIN_LENGTH_ERROR = 'The length of "{data}" is shorter than the specified minimum length of {min_length}'
VALIDATE_MAX_LENGTH_ERROR = 'The length of "{data}" exceeds the specified maximum length of {max_length}'
VALIDATE_MIN_ARRAY_LENGTH_ERROR = (
Expand All @@ -32,9 +32,9 @@
)
VALIDATE_UNIQUE_ITEMS_ERROR = "The array {data} must contain unique items only"
VALIDATE_NONE_ERROR = "Received a null value for a non-nullable schema object"
VALIDATE_MISSING_RESPONSE_KEY_ERROR = 'The following property is missing in the response data: "{missing_key}"'
VALIDATE_EXCESS_RESPONSE_KEY_ERROR = (
'The following property was found in the response, but is missing from the schema definition: "{excess_key}"'
VALIDATE_MISSING_KEY_ERROR = 'The following property is missing in the {http_message} data: "{missing_key}"'
VALIDATE_EXCESS_KEY_ERROR = (
'The following property was found in the {http_message}, but is missing from the schema definition: "{excess_key}"'
)
VALIDATE_WRITE_ONLY_RESPONSE_KEY_ERROR = (
'The following property was found in the response, but is documented as being "writeOnly": "{write_only_key}"'
Expand Down
2 changes: 2 additions & 0 deletions openapi_tester/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ def set_schema(self, schema: dict) -> None:
"""
de_referenced_schema = self.de_reference_schema(schema)
self.validate_schema(de_referenced_schema)

self.schema = self.normalize_schema_paths(de_referenced_schema)

@cached_property
Expand Down Expand Up @@ -245,6 +246,7 @@ class StaticSchemaLoader(BaseSchemaLoader):

def __init__(self, path: str, field_key_map: dict[str, str] | None = None):
super().__init__(field_key_map=field_key_map)

self.path = path if not isinstance(path, pathlib.PosixPath) else str(path)

def load_schema(self) -> dict[str, Any]:
Expand Down
107 changes: 100 additions & 7 deletions openapi_tester/schema_tester.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
INIT_ERROR,
UNDOCUMENTED_SCHEMA_SECTION_ERROR,
VALIDATE_ANY_OF_ERROR,
VALIDATE_EXCESS_RESPONSE_KEY_ERROR,
VALIDATE_MISSING_RESPONSE_KEY_ERROR,
VALIDATE_EXCESS_KEY_ERROR,
VALIDATE_MISSING_KEY_ERROR,
VALIDATE_NONE_ERROR,
VALIDATE_ONE_OF_ERROR,
VALIDATE_WRITE_ONLY_RESPONSE_KEY_ERROR,
Expand Down Expand Up @@ -135,6 +135,7 @@ def get_response_schema_section(self, response: Response) -> dict[str, Any]:
:return dict
"""
schema = self.loader.get_schema()

response_method = response.request["REQUEST_METHOD"].lower() # type: ignore
parameterized_path, _ = self.loader.resolve_path(
response.request["PATH_INFO"], method=response_method # type: ignore
Expand Down Expand Up @@ -198,6 +199,61 @@ def get_response_schema_section(self, response: Response) -> dict[str, Any]:
)
return {}

def get_request_body_schema_section(self, request: dict[str, Any]) -> dict[str, Any]:
"""
Fetches the request section of a schema.
:param response: DRF Request Instance
:return dict
"""
schema = self.loader.get_schema()
request_method = request["REQUEST_METHOD"].lower()

parameterized_path, _ = self.loader.resolve_path(request["PATH_INFO"], method=request_method)
paths_object = self.get_key_value(schema, "paths")

route_object = self.get_key_value(
paths_object,
parameterized_path,
f"\n\nUndocumented route {parameterized_path}.\n\nDocumented routes: " + "\n\t• ".join(paths_object.keys()),
)

method_object = self.get_key_value(
route_object,
request_method,
(
f"\n\nUndocumented method: {request_method}.\n\nDocumented methods: "
f"{[method.lower() for method in route_object.keys() if method.lower() != 'parameters']}."
),
)

if all(key in request for key in ["CONTENT_LENGTH", "CONTENT_TYPE", "wsgi.input"]):
if request["CONTENT_TYPE"] != "application/json":
return {}

request_body_object = self.get_key_value(
method_object,
"requestBody",
f"\n\nNo request body documented for method: {request_method}, path: {parameterized_path}",
)
content_object = self.get_key_value(
request_body_object,
"content",
f"\n\nNo content documented for method: {request_method}, path: {parameterized_path}",
)
json_object = self.get_key_value(
content_object,
r"^application\/.*json$",
(
"\n\nNo `application/json` requests documented for method: "
f"{request_method}, path: {parameterized_path}"
),
use_regex=True,
)
return self.get_key_value(json_object, "schema")

return {}

def handle_one_of(self, schema_section: dict, data: Any, reference: str, **kwargs: Any) -> None:
matches = 0
passed_schema_section_formats = set()
Expand Down Expand Up @@ -226,6 +282,9 @@ def handle_any_of(self, schema_section: dict, data: Any, reference: str, **kwarg
continue
raise DocumentationError(f"{VALIDATE_ANY_OF_ERROR}\n\nReference: {reference}.anyOf")

def is_openapi_schema(self) -> bool:
return self.loader.get_schema().get("openapi") is not None

@staticmethod
def test_is_nullable(schema_item: dict) -> bool:
"""
Expand Down Expand Up @@ -338,6 +397,7 @@ def test_openapi_object(
reference: str,
case_tester: Callable[[str], None] | None = None,
ignore_case: list[str] | None = None,
http_message: str = "response",
) -> None:
"""
1. Validate that casing is correct for both response and schema
Expand All @@ -358,16 +418,17 @@ def test_openapi_object(
self.test_key_casing(key, case_tester, ignore_case)
if key in required_keys and key not in response_keys:
raise DocumentationError(
f"{VALIDATE_MISSING_RESPONSE_KEY_ERROR.format(missing_key=key)}\n\nReference: {reference}."
f"object:key:{key}\n\nHint: Remove the key from your"
" OpenAPI docs, or include it in your API response"
f"{VALIDATE_MISSING_KEY_ERROR.format(missing_key=key, http_message=http_message)}\n\nReference:"
f" {reference}.object:key:{key}\n\nHint: Remove the key from your OpenAPI docs, or include it in"
" your API response"
)
for key in response_keys:
self.test_key_casing(key, case_tester, ignore_case)
if key not in properties and not additional_properties_allowed:
raise DocumentationError(
f"{VALIDATE_EXCESS_RESPONSE_KEY_ERROR.format(excess_key=key)}\n\nReference: {reference}.object:key:"
f"{key}\n\nHint: Remove the key from your API response, or include it in your OpenAPI docs"
f"{VALIDATE_EXCESS_KEY_ERROR.format(excess_key=key, http_message=http_message)}\n\nReference:"
f" {reference}.object:key:{key}\n\nHint: Remove the key from your API response, or include it in"
" your OpenAPI docs"
)
if key in write_only_properties:
raise DocumentationError(
Expand Down Expand Up @@ -403,6 +464,37 @@ def test_openapi_array(self, schema_section: dict[str, Any], data: dict, referen
**kwargs,
)

def validate_request(
self,
response: Response,
case_tester: Callable[[str], None] | None = None,
ignore_case: list[str] | None = None,
validators: list[Callable[[dict[str, Any], Any], str | None]] | None = None,
) -> None:
"""
Verifies that an OpenAPI schema definition matches an API request body.
:param request: The HTTP request
:param case_tester: Optional Callable that checks a string's casing
:param ignore_case: Optional list of keys to ignore in case testing
:param validators: Optional list of validator functions
:param **kwargs: Request keyword arguments
:raises: ``openapi_tester.exceptions.DocumentationError`` for inconsistencies in the API response and schema.
``openapi_tester.exceptions.CaseError`` for case errors.
"""
if self.is_openapi_schema():
# TODO: Implement for other schema types
request_body_schema = self.get_request_body_schema_section(response.request) # type: ignore
if request_body_schema:
self.test_schema_section(
schema_section=request_body_schema,
data=response.renderer_context["request"].data, # type: ignore
case_tester=case_tester or self.case_tester,
ignore_case=ignore_case,
validators=validators,
http_message="request",
)

def validate_response(
self,
response: Response,
Expand All @@ -427,4 +519,5 @@ def validate_response(
case_tester=case_tester or self.case_tester,
ignore_case=ignore_case,
validators=validators,
http_message="response",
)
5 changes: 5 additions & 0 deletions test_project/api/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@ class Meta:
vehicle_type = serializers.CharField(max_length=10)


class PetsSerializer(serializers.Serializer):
name = serializers.CharField(max_length=254)
tag = serializers.CharField(max_length=254, required=False)


class ItemSerializer(serializers.Serializer):
item_type = serializers.CharField(max_length=10)

Expand Down
7 changes: 7 additions & 0 deletions test_project/api/views/pets.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from rest_framework.status import HTTP_200_OK
from rest_framework.views import APIView

from test_project.api.serializers import PetsSerializer

if TYPE_CHECKING:
from rest_framework.request import Request

Expand All @@ -14,3 +16,8 @@ class Pet(APIView):
def get(self, request: Request, petId: int) -> Response:
pet = {"name": "doggie", "category": {"id": 1, "name": "Dogs"}, "photoUrls": [], "status": "available"}
return Response(pet, HTTP_200_OK)

def post(self, request) -> Response:
serializer = PetsSerializer(data=request.data)
serializer.is_valid(raise_exception=True)
return Response({"id": 1, "name": request.data["name"]}, 201)
1 change: 1 addition & 0 deletions test_project/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
path("api/<str:version>/snake-case/", SnakeCasedResponse.as_view()),
# ^trailing slash is here on purpose
path("api/<str:version>/router_generated/", include(router.urls)),
path("api/pets", Pet.as_view(), name="get-pets"),
re_path(r"api/pet/(?P<petId>\d+)", Pet.as_view(), name="get-pet"),
]

Expand Down
77 changes: 77 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
from __future__ import annotations

from copy import deepcopy
from typing import TYPE_CHECKING, Callable
from unittest.mock import MagicMock

import pytest
from rest_framework.response import Response

from tests.schema_converter import SchemaToPythonConverter
from tests.utils import TEST_ROOT

if TYPE_CHECKING:
from pathlib import Path


@pytest.fixture()
def pets_api_schema() -> Path:
return TEST_ROOT / "schemas" / "openapi_v3_reference_schema.yaml"


@pytest.fixture()
def pets_post_request():
request_body = MagicMock()
request_body.read.return_value = b'{"name": "doggie", "tag": "dog"}'
return {
"PATH_INFO": "/api/pets",
"REQUEST_METHOD": "POST",
"SERVER_PORT": "80",
"wsgi.url_scheme": "http",
"CONTENT_LENGTH": "70",
"CONTENT_TYPE": "application/json",
"wsgi.input": request_body,
"QUERY_STRING": "",
}


@pytest.fixture()
def invalid_pets_post_request():
request_body = MagicMock()
request_body.read.return_value = b'{"surname": "doggie", "species": "dog"}'
return {
"PATH_INFO": "/api/pets",
"REQUEST_METHOD": "POST",
"SERVER_PORT": "80",
"wsgi.url_scheme": "http",
"CONTENT_LENGTH": "70",
"CONTENT_TYPE": "application/json",
"wsgi.input": request_body,
"QUERY_STRING": "",
}


@pytest.fixture()
def response_factory() -> Callable:
def response(
schema: dict | None,
url_fragment: str,
method: str,
status_code: int | str = 200,
response_body: dict | None = None,
) -> Response:
converted_schema = None
if schema:
converted_schema = SchemaToPythonConverter(deepcopy(schema)).result
response = Response(status=int(status_code), data=converted_schema)
response.request = {"REQUEST_METHOD": method, "PATH_INFO": url_fragment} # type: ignore
if schema:
response.json = lambda: converted_schema # type: ignore
elif response_body:
response.request["CONTENT_LENGTH"] = len(response_body) # type: ignore
response.request["CONTENT_TYPE"] = "application/json" # type: ignore
response.request["wsgi.input"] = response_body # type: ignore
response.renderer_context = {"request": MagicMock(data=response_body)} # type: ignore
return response

return response
Loading

0 comments on commit 4b5f6fb

Please sign in to comment.