From 4947c9133406b7dafa66291f5633fefe81ee9439 Mon Sep 17 00:00:00 2001 From: Jesus Lara Date: Thu, 19 Dec 2024 17:03:47 +0100 Subject: [PATCH] manage enum on serialization of ModelViews --- examples/test_model.py | 34 ++++++++++++++++-- navigator/libs/json.pyx | 8 +++++ navigator/version.py | 2 +- navigator/views/abstract.py | 70 ++++++++++++++++++++++++------------- navigator/views/model.py | 34 +++++++++++------- 5 files changed, 107 insertions(+), 41 deletions(-) diff --git a/examples/test_model.py b/examples/test_model.py index 2e8b5e7a..50823c44 100644 --- a/examples/test_model.py +++ b/examples/test_model.py @@ -1,5 +1,6 @@ from typing import Union import asyncio +from enum import Enum from datetime import datetime from aiohttp import web from navconfig.logging import logging @@ -9,14 +10,33 @@ from navigator import Application from navigator.responses import HTMLResponse from navigator.views import ModelView +from navigator.conf import PG_USER, PG_PWD, PG_HOST, PG_PORT + +# Example DSN: +dsn = f'postgresql://{PG_USER}:{PG_PWD}@{PG_HOST}:{PG_PORT}/pruebas' + +class AirportType(Enum): + """ + Enum for Airport Types. + """ + CITY = 1 + INTERNATIONAL = 2 + DOMESTIC = 3 class Country(Model): country_code: str = Column(primary_key=True) country: str + class Airport(Model): iata: str = Column(primary_key=True, required=True, label='IATA Code') airport: str = Column(required=True, label="Airport Name") + airport_type: AirportType = Column( + required=True, + label='Airport Type', + choices=AirportType, + default=AirportType.CITY + ) city: str country: str created_by: int @@ -40,10 +60,19 @@ async def hola(request: web.Request) -> web.Response: class AirportHandler(ModelView): model: Model = Airport pk: Union[str, list] = ['iata'] + dsn: str = dsn - async def _get_created_by(self, value, column, **kwargs): + async def _set_created_by(self, value, column, **kwargs): return await self.get_userid(session=self._session) + async def _put_callback(self, response: web.Response, result, *args, **kwargs): + print('RESULT > ', result) + print('RESPONSE > ', response) + print('PUT CALLBACK') + return response + + _post_callback = _put_callback + async def on_startup(self, *args, **kwargs): print(args, kwargs) print('THIS CODE RUN ON STARTUP') @@ -67,6 +96,7 @@ async def start_example(db): iata character varying(3), airport character varying(60), city character varying(20), + airport_type integer, country character varying(30), created_by integer, created_at timestamp with time zone NOT NULL DEFAULT now(), @@ -122,7 +152,7 @@ async def end_example(db): "password": "12345678", "host": "127.0.0.1", "port": "5432", - "database": "navigator", + "database": "pruebas", "DEBUG": True, } kwargs = { diff --git a/navigator/libs/json.pyx b/navigator/libs/json.pyx index e79b1f09..3ef33f38 100644 --- a/navigator/libs/json.pyx +++ b/navigator/libs/json.pyx @@ -13,6 +13,7 @@ from psycopg2 import Binary # Import Binary from psycopg2 from typing import Any, Union from pathlib import PosixPath, PurePath, Path from decimal import Decimal +from enum import Enum, EnumType from ..exceptions.exceptions cimport ValidationError import orjson @@ -52,6 +53,13 @@ cdef class JSONContent: return [obj.lower, up] elif hasattr(obj, 'tolist'): # numpy array return obj.tolist() + elif isinstance(obj, Enum): # Handle Enum serialization + if obj is None: + return None + return obj.value if hasattr(obj, 'value') else obj.name + elif isinstance(obj, type) and issubclass(obj, Enum): + return [{'value': e.value, 'name': e.name} for e in obj] + # return [e.name for e in obj] # Serialize the names of the Enum class members elif isinstance(obj, _MISSING_TYPE): return None elif obj == MISSING: diff --git a/navigator/version.py b/navigator/version.py index e1ecbdb4..48f2ea44 100644 --- a/navigator/version.py +++ b/navigator/version.py @@ -4,7 +4,7 @@ __description__ = ( "Navigator Web Framework based on aiohttp, " "with batteries included." ) -__version__ = "2.12.4" +__version__ = "2.12.5" __author__ = "Jesus Lara" __author_email__ = "jesuslarag@gmail.com" __license__ = "BSD" diff --git a/navigator/views/abstract.py b/navigator/views/abstract.py index dde977d9..1e694ee7 100644 --- a/navigator/views/abstract.py +++ b/navigator/views/abstract.py @@ -1,10 +1,9 @@ -from typing import Optional, Union, Any, TypeVar -from collections.abc import Callable +from typing import Optional, Union, Any, TypeVar, Type +from collections.abc import Awaitable, Callable import asyncio -import copy -from aiohttp import web, hdrs import traceback from functools import wraps +from aiohttp import web, hdrs try: import babel BABEL_INSTALLED = True @@ -82,13 +81,15 @@ async def __aenter__(self): async def default_connection(self, request: web.Request): if self._dbname in request.app: return request.app[self._dbname] - kwargs = { - "server_settings": { - 'client_min_messages': 'notice', - 'max_parallel_workers': '24', - 'tcp_keepalives_idle': '30' + kwargs = {} + if self.driver == 'pg': + kwargs = { + "server_settings": { + 'client_min_messages': 'notice', + 'max_parallel_workers': '24', + 'tcp_keepalives_idle': '30' + } } - } pool = AsyncPool( self.driver, dsn=default_dsn, @@ -148,37 +149,37 @@ class AbstractModel(BaseView): in: Model type: BaseModel required: true + description: DataModel to be used. + - name: get_model + in: Model + type: BaseModel + required: false + description: DataModel to be used. """ - model: BaseModel = None - get_model: BaseModel = None + model: Type[BaseModel] = None + get_model: Type[BaseModel] = None # Signal for startup method for this ModelView on_startup: Optional[Callable] = None on_shutdown: Optional[Callable] = None model_kwargs: dict = {} name: str = "Model" + # Connection parameters + driver: str = 'pg' + dsn: str = None + credentials: dict = None + dbname: str = 'nav.model' + handler: ConnectionHandler def __init__(self, request, *args, **kwargs): self.__name__ = self.model.__name__ self._session = None - driver = kwargs.pop('driver', 'pg') - dsn = kwargs.pop('dsn', None) - credentials = kwargs.pop('credentials', {}) - dbname = kwargs.pop('dbname', 'nav.model') ## getting get Model: if not self.get_model: self.get_model = self.model super().__init__(request, *args, **kwargs) - # Database Connection Handler - self.handler = ConnectionHandler( - driver, - dsn=dsn, - dbname=dbname, - credentials=credentials, - model_kwargs=self.model_kwargs - ) @classmethod - def configure(cls, app: WebApp, path: str = None) -> WebApp: + def configure(cls, app: WebApp, path: str = None, **kwargs) -> WebApp: """configure. @@ -221,6 +222,25 @@ def configure(cls, app: WebApp, path: str = None) -> WebApp: app.router.add_view( r"{url}{{meta:(:.*)?}}".format(url=url), cls ) + # Use kwargs to reconfigure the connection handler if needed + if 'driver' in kwargs: + cls.driver = kwargs['driver'] + if 'dsn' in kwargs: + cls.dsn = kwargs['dsn'] + if 'credentials' in kwargs: + cls.credentials = kwargs['credentials'] + if 'dbname' in kwargs: + cls.dbname = kwargs['dbname'] + if 'model_kwargs' in kwargs: + cls.model_kwargs = kwargs['model_kwargs'] + # Database Connection Handler + cls.handler = ConnectionHandler( + cls.driver, + dsn=cls.dsn, + dbname=cls.dbname, + credentials=cls.credentials, + model_kwargs=cls.model_kwargs + ) async def validate_payload(self, data: Optional[Union[dict, list]] = None): """Get information for usage in Form.""" diff --git a/navigator/views/model.py b/navigator/views/model.py index 84f6783a..8452d51b 100644 --- a/navigator/views/model.py +++ b/navigator/views/model.py @@ -1,5 +1,5 @@ -from collections.abc import Awaitable -from typing import Optional, Union, Any +from collections.abc import Iterable +from typing import Optional, Union, Any, Awaitable, Callable import importlib import asyncio from aiohttp import web @@ -51,6 +51,9 @@ async def load_model(tablename: str, schema: str, connection: Any) -> Model: ) +CallbackType = Optional[Callable[[web.Response, BaseModel], Awaitable[None]]] + + class ModelView(AbstractModel): """ModelView. @@ -70,16 +73,16 @@ class ModelView(AbstractModel): get_model: BaseModel = None model_name: str = None # Override the current model with other. path: str = None - pk: Union[str, list] = None + pk: Optional[Iterable] = None _required: list = [] _primaries: list = [] _hidden: list = [] # New Callables to be used on response: - _get_callback: Optional[Awaitable] = None - _put_callback: Optional[Awaitable] = None - _post_callback: Optional[Awaitable] = None - _patch_callback: Optional[Awaitable] = None - _delete_callback: Optional[Awaitable] = None + _get_callback: CallbackType = None + _put_callback: CallbackType = None + _post_callback: CallbackType = None + _patch_callback: CallbackType = None + _delete_callback: CallbackType = None def __init__(self, request, *args, **kwargs): if self.model_name is not None: @@ -371,7 +374,12 @@ async def _get_filters(): if len(res) == 1: return res[0] return res - args = {self.pk: _primary} + elif isinstance(self.pk, str): + args = {self.pk: _primary} # pylint: disable=E1143 + else: + raise ValueError( + f"Invalid PK definition for {self.__name__}: {self.pk}" + ) args = {**_filter, **args} return await self.get_model.get(**args) elif len(qp) > 0: @@ -1271,7 +1279,7 @@ def _del_primary(self, args: dict = None) -> Any: try: _args = {} paramlist = [ - item.strip() for item in args["id"].split("/") if item.strip() + item.strip() for item in args.get('id', '').split("/") if item.strip() ] if not paramlist: return None @@ -1306,8 +1314,8 @@ def _del_primary(self, args: dict = None) -> Any: # TODO: use validation from datamodel # evaluate the corrected type for fields: val = paramlist.pop(0) - args[key] = val - return args + _args[key] = val + return _args except KeyError: pass else: @@ -1352,7 +1360,7 @@ async def delete(self): if isinstance(objid, list): data = [] for entry in objid: - args = {self.pk: entry} + args = {self.pk: entry} # noqa obj = await self.model.get(**args) data.append(await obj.delete()) else: