diff --git a/src/biodm/components/controllers/controller.py b/src/biodm/components/controllers/controller.py index 6a8e1bc..325e510 100644 --- a/src/biodm/components/controllers/controller.py +++ b/src/biodm/components/controllers/controller.py @@ -5,7 +5,7 @@ from enum import Enum from typing import Any, List, TYPE_CHECKING, Optional -from marshmallow.schema import Schema, EXCLUDE, INCLUDE +from marshmallow.schema import Schema from marshmallow.exceptions import ValidationError from sqlalchemy.exc import MissingGreenlet @@ -84,14 +84,12 @@ def validate(cls, data: bytes) -> (Any | list | dict | None): try: json_data = json.load(io.BytesIO(data)) cls.schema.many = isinstance(json_data, list) - cls.schema.unknown = EXCLUDE return cls.schema.loads(json_data=data) + except ValidationError as e: raise PayloadValidationError(e) from e except json.JSONDecodeError as e: raise PayloadJSONDecodingError(e) from e - except Exception as e: - raise e @classmethod def serialize(cls, data: dict | Base | List[Base], many: bool, only: Optional[List[str]]=None) -> str: @@ -101,20 +99,20 @@ def serialize(cls, data: dict | Base | List[Base], many: bool, only: Optional[Li :type data: dict, class:`biodm.components.Base`, List[class:`biodm.components.Base`] :param many: plurality flag, essential to marshmallow :type data: bool + :param only: List of fields to restrict serialization on, optional, defaults to None + :type only: List[str] """ try: - # cls.schema.only = only - # cls.schema.partial = True - cls.schema.unknown = INCLUDE + # Save and plug in restristed fields. dump_fields = cls.schema.dump_fields if only: cls.schema.dump_fields = { k:v for k, v in dump_fields.items() if k in only } serialized = cls.schema.dump(data, many=many) - #  cls.schema.only = None - # cls.schema.partial = None + # Restore to full afterwards. cls.schema.dump_fields = dump_fields return json.dumps(serialized, indent=cls.app.config.INDENT) + except MissingGreenlet as e: raise AsyncDBError(e) from e diff --git a/src/biodm/components/controllers/resourcecontroller.py b/src/biodm/components/controllers/resourcecontroller.py index b295f3c..bbdfdde 100644 --- a/src/biodm/components/controllers/resourcecontroller.py +++ b/src/biodm/components/controllers/resourcecontroller.py @@ -2,6 +2,7 @@ from functools import partial from typing import TYPE_CHECKING, List, Any +from marshmallow.schema import EXCLUDE from starlette.routing import Mount, Route from starlette.requests import Request from starlette.responses import Response @@ -69,9 +70,11 @@ def __init__(self, app: Api, entity: str=None, table: Base=None, schema: Schema= super().__init__(app=app) self.resource = entity if entity else self._infer_entity_name() self.table = table if table else self._infer_table() + self.pk = tuple(self.table.pk()) self.svc = self._infer_svc()(app=self.app, table=self.table) - self.__class__.schema = schema() if schema else self._infer_schema() + # schema = schema if schema else self._infer_schema() + self.__class__.schema = (schema if schema else self._infer_schema())(unknown=EXCLUDE) def _infer_entity_name(self) -> str: """Infer entity name from controller name.""" @@ -117,7 +120,7 @@ def _infer_schema(self) -> Schema: """Tries to import from instance module reference.""" isn = f"{self.resource}Schema" try: - return self.app.schemas.__dict__[isn]() + return self.app.schemas.__dict__[isn] except Exception as e: raise ValueError( f"{self.__class__.__name__} could not find {isn} Schema. " diff --git a/src/biodm/components/services/dbservice.py b/src/biodm/components/services/dbservice.py index 11866ce..d072bac 100644 --- a/src/biodm/components/services/dbservice.py +++ b/src/biodm/components/services/dbservice.py @@ -4,10 +4,10 @@ from sqlalchemy import select, update, delete from sqlalchemy.dialects.postgresql import insert from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.orm import Bundle, Load, load_only, joinedload +from sqlalchemy.orm import load_only, joinedload from sqlalchemy.sql import Insert, Update, Delete, Select -from biodm.utils.utils import unevalled_all, unevalled_or, to_it, DictBundle, it_to, partition +from biodm.utils.utils import unevalled_all, unevalled_or, to_it, it_to, partition from biodm.component import CRUDApiComponent from biodm.components import Base from biodm.managers import DatabaseManager @@ -38,26 +38,6 @@ async def _select(self, stmt: Select, session: AsyncSession) -> (Any | None): return row raise FailedRead("Select returned no result.") - @DatabaseManager.in_session - async def _query(self, stmt: Any, filter, joins: List[Any], session: AsyncSession) -> (Any | None): - """QUERY one from database. - - Temporary measure - - > TODO: change when new sqlalchemy release comes out. - Query is part of sqlalchemy legacy API, it is not supported in async mode. - select has a bug for Bundle features. Refer to: - - https://github.com/sqlalchemy/sqlalchemy/discussions/11345 - """ - def sync_inner(session, stmt, filter, joins): - stmt = session.query(stmt).filter(filter) - for join in joins: - stmt = stmt.join(join) - row = stmt.one() - if row: - return it_to(row) - raise FailedRead("Query returned no result.") - return await session.run_sync(sync_inner, stmt, filter, joins) - @DatabaseManager.in_session async def _select_many(self, stmt: Select, session: AsyncSession) -> List[Any]: """SELECT many from database.""" @@ -167,21 +147,80 @@ def _parse_int_operators(self, attr) -> Tuple[str, str]: f"Expecting either 'field=v1,v2' pairs or integrer" f" operators 'field.op(v)' op in {SUPPORTED_INT_OPERATORS}") - async def filter(self, query_params: dict, **kwargs) -> List[Base]: + def _filter_process_attr(self, stmt: Select, attr: List[str]): + """Iterates over attribute parts (e.g. table.attr.x.y.z) joining tables along the way. + + :param stmt: select statement in construction + :type stmt: Select + :param attr: attribute name parts of the querystring + :type attr: List[str] + :raises ValueError: When name is incorrect. + :return: Resulting statement and handles to column object and its type + :rtype: Tuple[Select, Tuple[Column, type]] + """ + table = self.table + for nested in attr[:-1]: + jtn = table.target_table(nested) + if jtn is None: + raise ValueError(f"Invalid nested entity name {nested}.") + jtable = jtn.decl_class + stmt = stmt.join(jtable) + table = jtable + + return stmt, table.colinfo(attr[-1]) + + def _restrict_select_on_fields( + self, + stmt: Select, + fields: List[str], + nested: List[str], + serializer: Callable=None + ) -> Select: + """set load_only options of a select(table) statement given a list of fields. + + :param stmt: _description_ + :type stmt: Select + :param fields: _description_ + :type fields: List[str] + :param nested: _description_ + :type nested: List[str] + :param serializer: _description_, defaults to None + :type serializer: Callable, optional + :return: _description_ + :rtype: Select + """ + # Restrict serializer fields so that it doesn't trigger any lazy loading. + serializer = partial(serializer, only=fields + nested) if serializer else None + return ( + stmt.options( + load_only( + *[getattr(self.table, f) for f in fields] + ), + *[ + joinedload(getattr(self.table, n)) + for n in nested + ] + ), + serializer + ) + + async def filter(self, query_params: dict, serializer: Callable=None, **kwargs) -> List[Base]: """READ rows filted on query parameters.""" # Get special parameters fields = query_params.pop('fields', None) offset = query_params.pop('start', None) limit = query_params.pop('end', None) - # reverse = query_params.pop('reverse') - - # TODO: apply same as read - # if fields: - # fields = fields.split(',') if fields else None - # stmt = - + reverse = query_params.pop('reverse', None) + # TODO: apply limit to nested lists as well. stmt = select(self.table) + if fields: + fields = fields.split(",") + nested, fields = partition(fields, lambda x: x in self.table.relationships()) + stmt, serializer = self._restrict_select_on_fields( + stmt, fields, nested, serializer + ) + for dskey, csval in query_params.items(): attr, values = dskey.split("."), csval.split(",") # exclude = False @@ -192,19 +231,7 @@ async def filter(self, query_params: dict, **kwargs) -> List[Base]: operator = None if csval else self._parse_int_operators(attr.pop()) # elif any(op in dskey for op in SUPPORTED_INT_OPERATORS): # raise ValueError("'field.op()=value' type of query is not yet supported.") - - # For every nested entity of the attribute, join table. - table = self.table - for nested in attr[:-1]: - jtn = table.target_table(nested) - if jtn is None: - raise ValueError(f"Invalid nested entity name {nested}.") - jtable = jtn.decl_class - stmt = stmt.join(jtable) - table = jtable - - # Get field info from last joined table. - col, ctype = table.colinfo(attr[-1]) + stmt, (col, ctype) = self._filter_process_attr(stmt, attr) # Numerical operators. if operator: @@ -241,7 +268,7 @@ async def filter(self, query_params: dict, **kwargs) -> List[Base]: # if exclude: # stmt = select(self.table.not_in(stmt)) stmt = stmt.offset(offset).limit(limit) - return await self._select_many(stmt, **kwargs) + return await self._select_many(stmt, serializer=serializer, **kwargs) async def read(self, pk_val: List[Any], @@ -258,24 +285,12 @@ async def read(self, :return: SQLAlchemy result item. :rtype: Base """ - fields = fields or [] - rels = self.table.relationships() - nested, fields = partition(fields, lambda x: x in rels) - + nested, fields = partition(fields or [], lambda x: x in self.table.relationships()) stmt = select(self.table) if fields or nested: - # Restrict fields and joined tables. - stmt = stmt.options( - load_only( - *[getattr(self.table, f) for f in fields] - ), - *[ - joinedload(getattr(self.table, n)) - for n in nested - ] + stmt, serializer = self._restrict_select_on_fields( + stmt, fields, nested, serializer ) - # Restrict serializer fields so that it doesn't trigger any lazy loading. - serializer = partial(serializer, only=fields + nested) if serializer else None stmt = stmt.where(self.gen_cond(pk_val)) return await self._select(stmt, serializer=serializer, **kwargs) @@ -289,10 +304,12 @@ async def update(self, pk_val, data: dict, **kwargs) -> Base: :return: _description_ :rtype: Base """ - stmt = update(self.table)\ - .where(self.gen_cond(pk_val))\ - .values(**data)\ - .returning(self.table) + stmt = ( + update(self.table) + .where(self.gen_cond(pk_val)) + .values(**data) + .returning(self.table) + ) return await self._update(stmt, **kwargs) async def delete(self, pk_val, **kwargs) -> Any: