Skip to content

Commit

Permalink
Merge pull request #254 from phenobarbital/dev
Browse files Browse the repository at this point in the history
new decorator for validate payload using dataclasses or basemodels
  • Loading branch information
phenobarbital authored May 13, 2024
2 parents 821ef1f + 3d699d3 commit 6571fac
Show file tree
Hide file tree
Showing 7 changed files with 250 additions and 11 deletions.
27 changes: 27 additions & 0 deletions examples/test_client_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import aiohttp
import asyncio
import json
from navigator.conf import APP_HOST, APP_PORT

async def test_validate_payload():
url = f'http://{APP_HOST}:{APP_PORT}/api/v1/animal'
headers = {'Content-Type': 'application/json'}

payload = {
"Lion": {"name": "Lion", "specie": "Panthera leo", "age": 5},
"Elephant": {"name": "Elephant", "habitat": "Savannah", "is_wild": True},
"Snake": {"name": "Snake", "specie": "Reptilia", "age": 2}
}

async with aiohttp.ClientSession() as session:
async with session.post(url, headers=headers, json=payload) as response:
content_type = response.headers.get('Content-Type', '')
if 'application/json' in content_type:
response_data = await response.json()
else:
response_data = await response.text()
print(f"Status: {response.status}")
print(f"Response: {response_data}")

if __name__ == '__main__':
asyncio.run(test_validate_payload())
2 changes: 2 additions & 0 deletions navigator/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@

#### BASIC Configuration
APP_NAME = config.get("APP_NAME", fallback="Navigator")
APP_HOST = config.get('APP_HOST', fallback="localhost")
APP_PORT = config.get('APP_PORT', fallback=5000)
APP_TITLE = config.get("APP_TITLE", fallback="NAVIGATOR").upper()
APP_LOGNAME = config.get("APP_LOGNAME", fallback="Navigator")
logging.debug(f"::: STARTING APP: {APP_NAME} ::: ")
Expand Down
203 changes: 203 additions & 0 deletions navigator/decorators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@
from typing import Union, Any
from collections.abc import Callable
from functools import wraps
import asyncio
import inspect
from dataclasses import dataclass, is_dataclass
from aiohttp import web
from aiohttp.abc import AbstractView
from aiohttp.web_exceptions import HTTPError
from datamodel import BaseModel
from datamodel.exceptions import ValidationError
from navigator_auth.conf import exclude_list


"""
Useful decorators for the navigator app.
"""
def allow_anonymous(func):
@wraps(func)
async def wrapper(*args, **kwargs):
request = args[-1]
path = request.path
exclude_list.add(path) # Add this path to exclude_list to bypass auth
return await func(*args, **kwargs)
return wrapper


async def validate_model(request: web.Request, model: Union[dataclass, BaseModel]) -> tuple:
"""
validate_model.
Description: Validate a model using a dataclass or BaseModel.
Args:
request (web.Request): aiohttp Request object.
model (Union[dataclass,BaseModel]): Model can be a dataclass or BaseModel.
Returns:
tuple: data, errors (if any)
"""
errors: dict = {}
data = None
if request.method in ('OPTIONS', 'HEAD'):
# There is no validation for OPTIONS/HEAD methods:
return (True, None)
elif request.method in ("POST", "PUT", "PATCH"):
if request.content_type == "application/json":
# getting data from POST
data = await request.json()
else:
data = await request.post()
elif request.method == "GET":
data = {key: val for (key, val) in request.query.items()}
else:
raise web.HTTPNotImplemented(
reason=f"{request.method} Method not Implemented for Data Validation.",
content_type="application/json",
)
if data is None:
raise web.HTTPNotFound(
reason="There is no content for validation.",
content_type="application/json",
)

async def validate_data(data):
valid = None
errors = {}
if issubclass(model, BaseModel):
try:
valid = model(**data)
except (TypeError, ValueError, AttributeError) as exc:
errors = {
"error": f"Invalid Data: {exc}"
}
except ValidationError as exc:
errors = {
"error": f"Invalid Data: {exc}",
"payload": exc.payload
}
elif is_dataclass(model):
try:
valid = model(**data)
except Exception as err:
errors = {"error": f"Invalid Data: {err}"}
else:
errors = {"error": "Invalid Model Type"}
return valid, errors

errors = {}
valid = {}

if isinstance(data, dict):
if isinstance(list(data.values())[0], dict):
for k, v in data.items():
item_valid, item_error = await validate_data(v)
if item_valid:
valid[k] = item_valid
if item_error:
errors.update(item_error)
if valid:
return valid, errors
valid, error = await validate_data(data)
if error:
errors.update(error)
return valid, errors

elif isinstance(data, list):
valid = []
for item in data:
item_valid, item_error = await validate_data(item)
if item_valid:
valid.append(item_valid)
if item_error:
errors.update(item_error)
return valid, errors
else:
return data, {
"error": "Invalid type for Data Input, expecting a Dict or List."
}


def validate_payload(*models: Union[type[BaseModel], type[dataclass]]) -> Callable:
"""validate_payload.
Description: Validate Request payload using dataclasses or Datamodels.
Args:
models (Union[dataclass,BaseModel]): List of models can be used for validation.
kwargs: Any other data passed as arguments to function.
Returns:
Callable: Decorator function adding validated data to handler.
"""
def _validation(func: Callable) -> Callable:
@wraps(func)
async def _wrap(*args: Any, **kwargs) -> web.StreamResponse:
## building arguments:
# Supports class based views see web.View
if isinstance(args[0], AbstractView):
request = args[0].request
elif isinstance(args[0], web.View):
request = args[0].request
else:
request = args[-1]

content_type = request.headers.get('Content-Type')

sig = inspect.signature(func)
bound_args = sig.bind_partial(*args, **kwargs)
bound_args.apply_defaults()

# Dictionary to hold validation results
validated_data = {}
errors = {}

# Validate payload using the model
for model in models:
try:
data, model_errors = await validate_model(
request, model
)
model_name = model.__name__.lower()
validated_data[model_name] = data
if model_errors:
errors[model_name] = model_errors
except Exception as err:
if content_type == "application/json":
return web.json_response(
{
"error": f"Error during validation of model {model.__name__}: {err}"
}, status=400
)
raise web.HTTPBadRequest(
reason=f"Error during validation of model {model.__name__}: {err}",
content_type="application/json"
)

# Assign validated data to respective function arguments
for param_name, param in sig.parameters.items():
model_name = param_name.lower()
if model_name in validated_data:
bound_args.arguments[param_name] = validated_data[model_name]

bound_args.arguments['errors'] = errors

# Call the original function with new arguments
try:
if asyncio.iscoroutinefunction(func):
response = await func(*bound_args.args, **bound_args.kwargs)
else:
response = func(*bound_args.args, **bound_args.kwargs)
return response
except HTTPError as ex:
return ex
except Exception as err:
if content_type == "application/json":
return web.json_response(
{"error": str(err)}, status=500
)
raise web.HTTPInternalServerError(
reason=f"Error Calling Function {func.__name__}: {err}"
) from err

return _wrap

return _validation
2 changes: 1 addition & 1 deletion navigator/libs/cypher.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def decode(self, passphrase: Any) -> bytes:
msg = codecs.decode(passphrase, "hex")
if self.type == "AES":
try:
return self.cipher.decrypt(msg)[len(self.iv) :].decode("utf-8")
return self.cipher.decrypt(msg)[len(self.iv):].decode("utf-8")
except Exception as e:
print(e)
raise (e)
Expand Down
4 changes: 2 additions & 2 deletions navigator/libs/mutables.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def keys(self) -> list:

def set(self, key, value) -> None:
self.mapping[key] = value
if not key in self._columns:
if key not in self._columns:
self._columns.append(key)

### Section: Simple magic methods
Expand All @@ -76,7 +76,7 @@ def __delitem__(self, key) -> None:

def __setitem__(self, key, value):
self.mapping[key] = value
if not key in self._columns:
if key not in self._columns:
self._columns.append(key)

def __getitem__(self, key: Union[str, int]) -> Any:
Expand Down
21 changes: 14 additions & 7 deletions navigator/navigator.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
except FileExistsError:
print('Error: Missing ENV directory for Navconfig.')

from navigator_auth.conf import exclude_list
from .exceptions.handlers import nav_exception_handler, shutdown
from .handlers import BaseAppHandler
from .functions import cPrint
Expand Down Expand Up @@ -247,10 +248,17 @@ def _decorator(func):
def Response(self, content: Any) -> web.Response:
return web.Response(text=content)

def get(self, route: str):
@property
def router(self):
return self.get_app().router

def get(self, route: str, allow_anonymous: bool = False):
app = self.get_app()

def _decorator(func):
if allow_anonymous is True:
# add this route to the exclude list:
exclude_list.append(route)
r = app.router.add_get(route, func)

@wraps(func)
Expand All @@ -267,12 +275,11 @@ async def _wrap(request, *args, **kwargs):

return _decorator

@property
def router(self):
return self.get_app().router

def post(self, route: str):
def post(self, route: str, allow_anonymous: bool = False):
def _decorator(func):
if allow_anonymous is True:
# add this route to the exclude list:
exclude_list.append(route)
self.get_app().router.add_post(route, func)

@wraps(func)
Expand Down Expand Up @@ -348,7 +355,7 @@ async def _wrap(*args: Any) -> web.StreamResponse:

def validate(self, model: Union[dataclass, BaseModel], **kwargs) -> web.Response:
"""validate.
Description: Validate Request input using a Datamodel
Description: Validate Request input using a dataclass or Datamodel.
Args:
model (Union[dataclass,BaseModel]): Model can be a dataclass or BaseModel.
kwargs: Any other data passed as arguments to function.
Expand Down
2 changes: 1 addition & 1 deletion navigator/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
__description__ = (
"Navigator Web Framework based on aiohttp, " "with batteries included."
)
__version__ = "2.8.44"
__version__ = "2.8.45"
__author__ = "Jesus Lara"
__author_email__ = "jesuslarag@gmail.com"
__license__ = "BSD"

0 comments on commit 6571fac

Please sign in to comment.