Skip to content

Commit

Permalink
support fields for filter as well + cleanup, private functions to red…
Browse files Browse the repository at this point in the history
…uce internals
  • Loading branch information
Etienne Jodry authored and Etienne Jodry committed May 8, 2024
1 parent c3349c7 commit a462439
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 74 deletions.
16 changes: 7 additions & 9 deletions src/biodm/components/controllers/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from enum import Enum
from typing import Any, List, TYPE_CHECKING, Optional

from marshmallow.schema import Schema, EXCLUDE, INCLUDE
from marshmallow.schema import Schema
from marshmallow.exceptions import ValidationError
from sqlalchemy.exc import MissingGreenlet

Expand Down Expand Up @@ -84,14 +84,12 @@ def validate(cls, data: bytes) -> (Any | list | dict | None):
try:
json_data = json.load(io.BytesIO(data))
cls.schema.many = isinstance(json_data, list)
cls.schema.unknown = EXCLUDE
return cls.schema.loads(json_data=data)

except ValidationError as e:
raise PayloadValidationError(e) from e
except json.JSONDecodeError as e:
raise PayloadJSONDecodingError(e) from e
except Exception as e:
raise e

@classmethod
def serialize(cls, data: dict | Base | List[Base], many: bool, only: Optional[List[str]]=None) -> str:
Expand All @@ -101,20 +99,20 @@ def serialize(cls, data: dict | Base | List[Base], many: bool, only: Optional[Li
:type data: dict, class:`biodm.components.Base`, List[class:`biodm.components.Base`]
:param many: plurality flag, essential to marshmallow
:type data: bool
:param only: List of fields to restrict serialization on, optional, defaults to None
:type only: List[str]
"""
try:
# cls.schema.only = only
# cls.schema.partial = True
cls.schema.unknown = INCLUDE
# Save and plug in restristed fields.
dump_fields = cls.schema.dump_fields
if only:
cls.schema.dump_fields = {
k:v for k, v in dump_fields.items() if k in only
}
serialized = cls.schema.dump(data, many=many)
#  cls.schema.only = None
# cls.schema.partial = None
# Restore to full afterwards.
cls.schema.dump_fields = dump_fields
return json.dumps(serialized, indent=cls.app.config.INDENT)

except MissingGreenlet as e:
raise AsyncDBError(e) from e
7 changes: 5 additions & 2 deletions src/biodm/components/controllers/resourcecontroller.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from functools import partial
from typing import TYPE_CHECKING, List, Any

from marshmallow.schema import EXCLUDE
from starlette.routing import Mount, Route
from starlette.requests import Request
from starlette.responses import Response
Expand Down Expand Up @@ -69,9 +70,11 @@ def __init__(self, app: Api, entity: str=None, table: Base=None, schema: Schema=
super().__init__(app=app)
self.resource = entity if entity else self._infer_entity_name()
self.table = table if table else self._infer_table()

self.pk = tuple(self.table.pk())
self.svc = self._infer_svc()(app=self.app, table=self.table)
self.__class__.schema = schema() if schema else self._infer_schema()
# schema = schema if schema else self._infer_schema()
self.__class__.schema = (schema if schema else self._infer_schema())(unknown=EXCLUDE)

def _infer_entity_name(self) -> str:
"""Infer entity name from controller name."""
Expand Down Expand Up @@ -117,7 +120,7 @@ def _infer_schema(self) -> Schema:
"""Tries to import from instance module reference."""
isn = f"{self.resource}Schema"
try:
return self.app.schemas.__dict__[isn]()
return self.app.schemas.__dict__[isn]
except Exception as e:
raise ValueError(
f"{self.__class__.__name__} could not find {isn} Schema. "
Expand Down
143 changes: 80 additions & 63 deletions src/biodm/components/services/dbservice.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
from sqlalchemy import select, update, delete
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Bundle, Load, load_only, joinedload
from sqlalchemy.orm import load_only, joinedload
from sqlalchemy.sql import Insert, Update, Delete, Select

from biodm.utils.utils import unevalled_all, unevalled_or, to_it, DictBundle, it_to, partition
from biodm.utils.utils import unevalled_all, unevalled_or, to_it, it_to, partition
from biodm.component import CRUDApiComponent
from biodm.components import Base
from biodm.managers import DatabaseManager
Expand Down Expand Up @@ -38,26 +38,6 @@ async def _select(self, stmt: Select, session: AsyncSession) -> (Any | None):
return row
raise FailedRead("Select returned no result.")

@DatabaseManager.in_session
async def _query(self, stmt: Any, filter, joins: List[Any], session: AsyncSession) -> (Any | None):
"""QUERY one from database.
Temporary measure -
> TODO: change when new sqlalchemy release comes out.
Query is part of sqlalchemy legacy API, it is not supported in async mode.
select has a bug for Bundle features. Refer to:
- https://github.com/sqlalchemy/sqlalchemy/discussions/11345
"""
def sync_inner(session, stmt, filter, joins):
stmt = session.query(stmt).filter(filter)
for join in joins:
stmt = stmt.join(join)
row = stmt.one()
if row:
return it_to(row)
raise FailedRead("Query returned no result.")
return await session.run_sync(sync_inner, stmt, filter, joins)

@DatabaseManager.in_session
async def _select_many(self, stmt: Select, session: AsyncSession) -> List[Any]:
"""SELECT many from database."""
Expand Down Expand Up @@ -167,21 +147,80 @@ def _parse_int_operators(self, attr) -> Tuple[str, str]:
f"Expecting either 'field=v1,v2' pairs or integrer"
f" operators 'field.op(v)' op in {SUPPORTED_INT_OPERATORS}")

async def filter(self, query_params: dict, **kwargs) -> List[Base]:
def _filter_process_attr(self, stmt: Select, attr: List[str]):
"""Iterates over attribute parts (e.g. table.attr.x.y.z) joining tables along the way.
:param stmt: select statement in construction
:type stmt: Select
:param attr: attribute name parts of the querystring
:type attr: List[str]
:raises ValueError: When name is incorrect.
:return: Resulting statement and handles to column object and its type
:rtype: Tuple[Select, Tuple[Column, type]]
"""
table = self.table
for nested in attr[:-1]:
jtn = table.target_table(nested)
if jtn is None:
raise ValueError(f"Invalid nested entity name {nested}.")
jtable = jtn.decl_class
stmt = stmt.join(jtable)
table = jtable

return stmt, table.colinfo(attr[-1])

def _restrict_select_on_fields(
self,
stmt: Select,
fields: List[str],
nested: List[str],
serializer: Callable=None
) -> Select:
"""set load_only options of a select(table) statement given a list of fields.
:param stmt: _description_
:type stmt: Select
:param fields: _description_
:type fields: List[str]
:param nested: _description_
:type nested: List[str]
:param serializer: _description_, defaults to None
:type serializer: Callable, optional
:return: _description_
:rtype: Select
"""
# Restrict serializer fields so that it doesn't trigger any lazy loading.
serializer = partial(serializer, only=fields + nested) if serializer else None
return (
stmt.options(
load_only(
*[getattr(self.table, f) for f in fields]
),
*[
joinedload(getattr(self.table, n))
for n in nested
]
),
serializer
)

async def filter(self, query_params: dict, serializer: Callable=None, **kwargs) -> List[Base]:
"""READ rows filted on query parameters."""
# Get special parameters
fields = query_params.pop('fields', None)
offset = query_params.pop('start', None)
limit = query_params.pop('end', None)
# reverse = query_params.pop('reverse')

# TODO: apply same as read
# if fields:
# fields = fields.split(',') if fields else None
# stmt =

reverse = query_params.pop('reverse', None)
# TODO: apply limit to nested lists as well.

stmt = select(self.table)
if fields:
fields = fields.split(",")
nested, fields = partition(fields, lambda x: x in self.table.relationships())
stmt, serializer = self._restrict_select_on_fields(
stmt, fields, nested, serializer
)

for dskey, csval in query_params.items():
attr, values = dskey.split("."), csval.split(",")
# exclude = False
Expand All @@ -192,19 +231,7 @@ async def filter(self, query_params: dict, **kwargs) -> List[Base]:
operator = None if csval else self._parse_int_operators(attr.pop())
# elif any(op in dskey for op in SUPPORTED_INT_OPERATORS):
# raise ValueError("'field.op()=value' type of query is not yet supported.")

# For every nested entity of the attribute, join table.
table = self.table
for nested in attr[:-1]:
jtn = table.target_table(nested)
if jtn is None:
raise ValueError(f"Invalid nested entity name {nested}.")
jtable = jtn.decl_class
stmt = stmt.join(jtable)
table = jtable

# Get field info from last joined table.
col, ctype = table.colinfo(attr[-1])
stmt, (col, ctype) = self._filter_process_attr(stmt, attr)

# Numerical operators.
if operator:
Expand Down Expand Up @@ -241,7 +268,7 @@ async def filter(self, query_params: dict, **kwargs) -> List[Base]:
# if exclude:
# stmt = select(self.table.not_in(stmt))
stmt = stmt.offset(offset).limit(limit)
return await self._select_many(stmt, **kwargs)
return await self._select_many(stmt, serializer=serializer, **kwargs)

async def read(self,
pk_val: List[Any],
Expand All @@ -258,24 +285,12 @@ async def read(self,
:return: SQLAlchemy result item.
:rtype: Base
"""
fields = fields or []
rels = self.table.relationships()
nested, fields = partition(fields, lambda x: x in rels)

nested, fields = partition(fields or [], lambda x: x in self.table.relationships())
stmt = select(self.table)
if fields or nested:
# Restrict fields and joined tables.
stmt = stmt.options(
load_only(
*[getattr(self.table, f) for f in fields]
),
*[
joinedload(getattr(self.table, n))
for n in nested
]
stmt, serializer = self._restrict_select_on_fields(
stmt, fields, nested, serializer
)
# Restrict serializer fields so that it doesn't trigger any lazy loading.
serializer = partial(serializer, only=fields + nested) if serializer else None
stmt = stmt.where(self.gen_cond(pk_val))
return await self._select(stmt, serializer=serializer, **kwargs)

Expand All @@ -289,10 +304,12 @@ async def update(self, pk_val, data: dict, **kwargs) -> Base:
:return: _description_
:rtype: Base
"""
stmt = update(self.table)\
.where(self.gen_cond(pk_val))\
.values(**data)\
.returning(self.table)
stmt = (
update(self.table)
.where(self.gen_cond(pk_val))
.values(**data)
.returning(self.table)
)
return await self._update(stmt, **kwargs)

async def delete(self, pk_val, **kwargs) -> Any:
Expand Down

0 comments on commit a462439

Please sign in to comment.