From 84cf5f16d754d073b32f5d34b43cf4b694e07696 Mon Sep 17 00:00:00 2001 From: Sam Bull Date: Mon, 22 Apr 2024 19:54:30 +0100 Subject: [PATCH] Add support for many-to-many relationships (#868) --- admin-js/src/App.js | 6 +- admin-js/tests/relationships.test.js | 68 ++++++++++++++ aiohttp_admin/backends/abc.py | 130 +++++++++++++++++++++------ aiohttp_admin/backends/sqlalchemy.py | 88 +++++++++++++----- aiohttp_admin/routes.py | 8 +- aiohttp_admin/security.py | 7 +- aiohttp_admin/types.py | 2 +- examples/relationships.py | 6 +- tests/_resources.py | 30 ++++--- tests/conftest.py | 4 +- tests/test_backends_abc.py | 10 +++ tests/test_backends_sqlalchemy.py | 4 +- tests/test_security.py | 6 +- tests/test_views.py | 58 +++++++++--- 14 files changed, 333 insertions(+), 94 deletions(-) diff --git a/admin-js/src/App.js b/admin-js/src/App.js index 15eb6be4..9086e433 100644 --- a/admin-js/src/App.js +++ b/admin-js/src/App.js @@ -142,11 +142,7 @@ const dataProvider = { deleteMany: (resource, params) => dataRequest(resource, "delete_many", params), getList: (resource, params) => dataRequest(resource, "get_list", params), getMany: (resource, params) => dataRequest(resource, "get_many", params), - getManyReference: (resource, params) => { - // filter object is reused across requests, so clone it before modifying. - params["filter"] = {...params["filter"], [params["target"]]: params["id"]}; - return dataRequest(resource, "get_list", params); - }, + getManyReference: (resource, params) => dataRequest(resource, "get_many_ref", params), getOne: (resource, params) => dataRequest(resource, "get_one", params), update: (resource, params) => dataRequest(resource, "update", params), updateMany: (resource, params) => dataRequest(resource, "update_many", params) diff --git a/admin-js/tests/relationships.test.js b/admin-js/tests/relationships.test.js index bd04917f..b1b4d874 100644 --- a/admin-js/tests/relationships.test.js +++ b/admin-js/tests/relationships.test.js @@ -12,6 +12,7 @@ test("datagrid works", async () => { await userEvent.keyboard("[Escape]"); const grid = await within(table).findByRole("table"); + await sleep(0.1); const childHeaders = within(grid).getAllByRole("columnheader"); expect(childHeaders.slice(1).map((e) => e.textContent)).toEqual(["Id", "Name", "Value"]); const childRows = within(grid).getAllByRole("row"); @@ -61,6 +62,73 @@ test("onetomany child displays", async () => { expect(childCells.map((e) => e.textContent)).toEqual(["Bar", "2"]); }); +test("onetoone parents display", async () => { + await userEvent.click(await screen.findByRole("button", {"name": "Open menu"})); + await userEvent.click(await screen.findByText("Onetoone parents")); + + await waitFor(() => screen.getByRole("heading", {"name": "Onetoone parents"})); + await sleep(1); + + const table = screen.getByRole("table"); + await userEvent.click(within(table).getAllByRole("row")[1]); + + await waitFor(() => screen.getByRole("heading", {"name": "Onetoone parent Foo"})); + const grid = await screen.findByRole("table"); + const childHeaders = within(grid).getAllByRole("columnheader"); + expect(childHeaders.map((e) => e.textContent)).toEqual(["Id", "Name", "Value"]); + const childRows = within(grid).getAllByRole("row"); + expect(childRows.length).toBe(2); + const childCells = within(childRows[1]).getAllByRole("cell"); + expect(childCells.map((e) => e.textContent)).toEqual(["2", "Child Bar", "2"]); +}); + +test("manytomany left displays", async () => { + await userEvent.click(await screen.findByRole("button", {"name": "Open menu"})); + await userEvent.click(await screen.findByText("Manytomany lefts")); + + await waitFor(() => screen.getByRole("heading", {"name": "Manytomany lefts"})); + await sleep(1); + + await userEvent.click(screen.getByRole("button", {"name": "Columns"})); + // TODO: Remove when fixed: https://github.com/marmelab/react-admin/issues/9587 + await userEvent.click(within(screen.getByRole("presentation")).getByLabelText("Children")); + await userEvent.keyboard("[Escape]"); + + const table = screen.getAllByRole("table")[0]; + const headers = within(table.querySelector("thead")).getAllByRole("columnheader"); + expect(headers.slice(1, -1).map((e) => e.textContent)).toEqual(["Id", "Name", "Value", "Children"]); + + const rows = within(table).getAllByRole("row").filter((e) => e.parentElement.parentElement === table); + const firstCells = within(rows[1]).getAllByRole("cell").filter((e) => e.parentElement === rows[1]); + expect(firstCells.slice(1, -2).map((e) => e.textContent)).toEqual(["1", "Foo", "2"]); + const secondCells = within(rows[2]).getAllByRole("cell").filter((e) => e.parentElement === rows[2]); + expect(secondCells.slice(1, -2).map((e) => e.textContent)).toEqual(["2", "Bar", "3"]); + + const firstGrid = await within(firstCells.at(-2)).findByRole("table"); + const firstHeaders = within(firstGrid).getAllByRole("columnheader"); + await waitFor(() => firstHeaders[1].textContent.trim() != ""); + expect(firstHeaders.slice(1).map((e) => e.textContent)).toEqual(["Id", "Name", "Value"]); + const firstRows = within(firstGrid).getAllByRole("row"); + expect(firstRows.length).toBe(3); + let cells = within(firstRows[1]).getAllByRole("cell"); + expect(cells.slice(1).map((e) => e.textContent)).toEqual(["3", "Bar Child", "6"]); + cells = within(firstRows[2]).getAllByRole("cell"); + expect(cells.slice(1).map((e) => e.textContent)).toEqual(["1", "Foo Child", "5"]); + + const secondGrid = within(secondCells.at(-2)).getByRole("table"); + const secondHeaders = within(secondGrid).getAllByRole("columnheader"); + await waitFor(() => secondHeaders[0].textContent.trim() != ""); + expect(secondHeaders.slice(1).map((e) => e.textContent)).toEqual(["Id", "Name", "Value"]); + const secondRows = within(secondGrid).getAllByRole("row"); + expect(secondRows.length).toBe(4); + cells = within(secondRows[1]).getAllByRole("cell"); + expect(cells.slice(1).map((e) => e.textContent)).toEqual(["3", "Bar Child", "6"]); + cells = within(secondRows[2]).getAllByRole("cell"); + expect(cells.slice(1).map((e) => e.textContent)).toEqual(["2", "Baz Child", "7"]); + cells = within(secondRows[3]).getAllByRole("cell"); + expect(cells.slice(1).map((e) => e.textContent)).toEqual(["1", "Foo Child", "5"]); +}); + test("composite foreign key child displays table", async () => { await userEvent.click(await screen.findByRole("button", {"name": "Open menu"})); await userEvent.click(await screen.findByText("Composite foreign key children")); diff --git a/aiohttp_admin/backends/abc.py b/aiohttp_admin/backends/abc.py index 51f4db8f..d83cce79 100644 --- a/aiohttp_admin/backends/abc.py +++ b/aiohttp_admin/backends/abc.py @@ -14,7 +14,7 @@ from pydantic import Json from ..security import check, permissions_as_dict -from ..types import ComponentState, InputState, fk +from ..types import ComponentState, InputState, fk, resources_key if sys.version_info >= (3, 10): from typing import TypeAlias @@ -26,7 +26,7 @@ else: from typing_extensions import TypedDict -_ID = TypeVar("_ID") +_ID = TypeVar("_ID", bound=tuple[object, ...]) Record = dict[str, object] Meta = Optional[dict[str, object]] @@ -87,6 +87,22 @@ class GetManyParams(_Params): ids: Json[tuple[str, ...]] +class GetManyRefAPIParams(_Params): + target: str + id: str + pagination: Json[_Pagination] + sort: Json[_Sort] + filter: Json[dict[str, object]] + + +class GetManyRefParams(_Params): + target: tuple[str, ...] + id: tuple[object, ...] + pagination: Json[_Pagination] + sort: Json[_Sort] + filter: Json[dict[str, object]] + + class _CreateData(TypedDict): """Id will not be included for create calls.""" data: Record @@ -116,6 +132,11 @@ class DeleteManyParams(_Params): ids: Json[tuple[str, ...]] +class _ListQuery(TypedDict): + sort: _Sort + filter: dict[str, object] + + class AbstractAdminResource(ABC, Generic[_ID]): name: str fields: dict[str, ComponentState] @@ -155,6 +176,10 @@ async def get_one(self, record_id: _ID, meta: Meta) -> Record: async def get_many(self, record_ids: Sequence[_ID], meta: Meta) -> list[Record]: """Return the matching records.""" + @abstractmethod + async def get_many_ref(self, params: GetManyRefParams) -> tuple[list[Record], int]: + """Return list of records and total count available (when not paginating).""" + @abstractmethod async def update(self, record_id: _ID, data: Record, previous_data: Record, meta: Meta) -> Record: @@ -176,38 +201,30 @@ async def delete(self, record_id: _ID, previous_data: Record, meta: Meta) -> Rec async def delete_many(self, record_ids: Sequence[_ID], meta: Meta) -> list[_ID]: """Delete the matching records and return their IDs.""" + async def get_many_ref_name(self, target: str, meta: Meta) -> str: + """Return the resource name for the reference. + + This can be used to change which resource should be returned by get_many_ref(). + + For example, if we have an SQLAlchemy model called 'parent' with a relationship + called children, then a normal get_many_ref_name() call would go to the 'child' + model with the details from the parent, and the default behaviour would work. + + However, the SQLAlchemy backend uses the meta to switch this and send the request + to the 'parent' model instead and then use the children ORM attribute to fetch + the referenced resources, thus requiring this method to return 'child'. + This allows the SQLAlchemy backend to support complex relationships (e.g. + many-to-many) without needing react-admin to know the details. + """ + return self.name + # https://marmelab.com/react-admin/DataProviderWriting.html @final async def _get_list(self, request: web.Request) -> web.Response: await check_permission(request, f"admin.{self.name}.view", context=(request, None)) query = check(GetListParams, request.query) - - # When sort order refers to "id", this should be translated to primary key. - if query["sort"]["field"] == "id": - query["sort"]["field"] = self.primary_key[0] - else: - query["sort"]["field"] = query["sort"]["field"].removeprefix("data.") - - query["filter"].update(check(dict[str, object], query["filter"].pop("data", {}))) # type: ignore[type-var] - - merged_filter = {} - for k, v in query["filter"].items(): - if k.startswith("fk_"): - v = check(str, v) - for c, cv in zip(k.removeprefix("fk_").split("__"), v.split("|")): - merged_filter[c] = check(self._raw_record_type[c], cv) - else: - merged_filter[k] = check(self._raw_record_type[k], v) - query["filter"] = merged_filter - - # Add filters from advanced permissions. - # The permissions will be cached on the request from a previous permissions check. - permissions = permissions_as_dict(request["aiohttpadmin_permissions"]) - filters = permissions.get(f"admin.{self.name}.view", - permissions.get(f"admin.{self.name}.*", {})) - for k, v in filters.items(): - query["filter"][k] = v + self._process_list_query(query, request) raw_results, total = await self.get_list(query) results = [await self._convert_record(r, request) for r in raw_results @@ -239,6 +256,33 @@ async def _get_many(self, request: web.Request) -> web.Response: if await permits(request, f"admin.{self.name}.view", context=(request, r))] return json_response({"data": results}) + @final + async def _get_many_ref(self, request: web.Request) -> web.Response: + query = check(GetManyRefAPIParams, request.query) + meta = query["filter"].pop("__meta__", None) + if meta is not None: + query["meta"] = check(dict[str, object], meta) + reference = await self.get_many_ref_name(query["target"], query.get("meta")) + ref_model = request.app[resources_key][reference] + + await check_permission(request, f"admin.{ref_model.name}.view", context=(request, None)) + + ref_model._process_list_query(query, request) + + if query["target"].startswith("fk_"): + target = tuple(query["target"].removeprefix("fk_").split("__")) + record_id = tuple(check(self._raw_record_type[k], v) + for k, v in zip(target, query["id"].split("|"))) + else: + target = (query["target"],) + record_id = check(self._id_type, query["id"].split("|")) + + raw_results, total = await self.get_many_ref({**query, "target": target, "id": record_id}) + + results = [await ref_model._convert_record(r, request) for r in raw_results + if await permits(request, f"admin.{ref_model.name}.view", context=(request, r))] + return json_response({"data": results, "total": total}) + @final async def _create(self, request: web.Request) -> web.Response: query = check(CreateParams, request.query) @@ -350,7 +394,7 @@ async def _delete_many(self, request: web.Request) -> web.Response: @final def _check_record(self, record: Record) -> Record: """Check and convert input record.""" - return check(self._record_type, record) # type: ignore[no-any-return] + return check(self._record_type, record) @final async def _convert_record(self, record: Record, request: web.Request) -> APIRecord: @@ -371,6 +415,33 @@ def _convert_ids(self, ids: Sequence[_ID]) -> tuple[str, ...]: """Convert IDs to correct output format.""" return tuple(str(i) for i in ids) + def _process_list_query(self, query: _ListQuery, request: web.Request) -> None: + # When sort order refers to "id", this should be translated to primary key. + if query["sort"]["field"] == "id": + query["sort"]["field"] = self.primary_key[0] + else: + query["sort"]["field"] = query["sort"]["field"].removeprefix("data.") + + query["filter"].update(check(dict[str, object], query["filter"].pop("data", {}))) + + merged_filter = {} + for k, v in query["filter"].items(): + if k.startswith("fk_"): + v = check(str, v) + for c, cv in zip(k.removeprefix("fk_").split("__"), v.split("|")): + merged_filter[c] = check(self._raw_record_type[c], cv) + else: + merged_filter[k] = check(self._raw_record_type[k], v) + query["filter"] = merged_filter + + # Add filters from advanced permissions. + # The permissions will be cached on the request from a previous permissions check. + permissions = permissions_as_dict(request["aiohttpadmin_permissions"]) + filters = permissions.get(f"admin.{self.name}.view", + permissions.get(f"admin.{self.name}.*", {})) + for k, v in filters.items(): + query["filter"][k] = v + @cached_property def routes(self) -> tuple[web.RouteDef, ...]: """Routes to act on this resource. @@ -382,6 +453,7 @@ def routes(self) -> tuple[web.RouteDef, ...]: web.get(url + "/list", self._get_list, name=self.name + "_get_list"), web.get(url + "/one", self._get_one, name=self.name + "_get_one"), web.get(url, self._get_many, name=self.name + "_get_many"), + web.get(url + "/ref", self._get_many_ref, name=self.name + "_get_many_ref"), web.post(url, self._create, name=self.name + "_create"), web.put(url + "/update", self._update, name=self.name + "_update"), web.put(url + "/update_many", self._update_many, name=self.name + "_update_many"), diff --git a/aiohttp_admin/backends/sqlalchemy.py b/aiohttp_admin/backends/sqlalchemy.py index be1121e2..90c804b3 100644 --- a/aiohttp_admin/backends/sqlalchemy.py +++ b/aiohttp_admin/backends/sqlalchemy.py @@ -9,10 +9,11 @@ import sqlalchemy as sa from aiohttp import web -from sqlalchemy.ext.asyncio import AsyncEngine -from sqlalchemy.orm import DeclarativeBase, DeclarativeBaseNoMeta, Mapper, QueryableAttribute +from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession +from sqlalchemy.orm import (DeclarativeBase, DeclarativeBaseNoMeta, Mapper, + QueryableAttribute, selectinload) -from .abc import AbstractAdminResource, GetListParams, Meta, Record +from .abc import AbstractAdminResource, GetListParams, GetManyRefParams, Meta, Record from ..types import FunctionState, comp, data, fk, func, regex if sys.version_info >= (3, 10): @@ -27,6 +28,7 @@ Union[_FValues, Sequence[_FValues]]] _ModelOrTable = Union[sa.Table, type[DeclarativeBase], type[DeclarativeBaseNoMeta]] _SABoolExpression = sa.sql.roles.ExpressionElementRole[bool] +# _RelationshipAttr = InstrumentedAttribute[Union[DeclarativeBase, DeclarativeBaseNoMeta]] logger = logging.getLogger(__name__) @@ -160,7 +162,9 @@ def create_filters(columns: sa.ColumnCollection[str, sa.Column[object]], # ID is based on PK, which we can't infer from types, so must use Any here. -class SAResource(AbstractAdminResource[Any]): +class SAResource(AbstractAdminResource[tuple[Any, ...]]): + _model: Union[type[DeclarativeBase], type[DeclarativeBaseNoMeta], None] = None + def __init__(self, db: AsyncEngine, model_or_table: _ModelOrTable): if isinstance(model_or_table, sa.Table): table = model_or_table @@ -168,8 +172,17 @@ def __init__(self, db: AsyncEngine, model_or_table: _ModelOrTable): if not isinstance(model_or_table.__table__, sa.Table): raise ValueError("Non-table mappings are not supported.") table = model_or_table.__table__ + self._model = model_or_table + self._db = db + self._table = table self.name = table.name + self.primary_key = tuple(filter(lambda c: table.c[c].primary_key, self._table.c.keys())) + if not self.primary_key: + self.primary_key = tuple(self._table.c.keys()) + pk_types = tuple(table.c[pk].type.python_type for pk in self.primary_key) + self._id_type = tuple.__class_getitem__(pk_types) # type: ignore[assignment] + self.fields = {} self.inputs = {} self.omit_fields = set() @@ -256,6 +269,10 @@ def __init__(self, db: AsyncEngine, model_or_table: _ModelOrTable): props["link"] = "show" elif relationship.uselist: t = "ReferenceManyField" + props["reference"] = self.name + props["target"] = name + props["source"] = "id" + props["filter"] = {"__meta__": {"orm": True}} else: t = "ReferenceOneField" props["link"] = "show" @@ -277,15 +294,6 @@ def __init__(self, db: AsyncEngine, model_or_table: _ModelOrTable): self.fields[name] = comp(t, props) self.omit_fields.add(name) - self._db = db - self._table = table - - self.primary_key = tuple(filter(lambda c: table.c[c].primary_key, self._table.c.keys())) - if not self.primary_key: - raise ValueError("No primary key found.") - pk_types = tuple(table.c[pk].type.python_type for pk in self.primary_key) - self._id_type = tuple.__class_getitem__(pk_types) # type: ignore[assignment] - super().__init__(record_type) @handle_errors @@ -315,19 +323,53 @@ async def get_entities() -> list[Record]: return await asyncio.gather(get_entities(), get_count()) @handle_errors - async def get_one(self, record_id: tuple[Any], meta: Meta) -> Record: + async def get_one(self, record_id: tuple[Any, ...], meta: Meta) -> Record: async with self._db.connect() as conn: stmt = sa.select(self._table).where(*self._cmp_pk(record_id)) result = await conn.execute(stmt) return result.one()._asdict() @handle_errors - async def get_many(self, record_ids: Sequence[tuple[Any]], meta: Meta) -> list[Record]: + async def get_many(self, record_ids: Sequence[tuple[Any, ...]], meta: Meta) -> list[Record]: async with self._db.connect() as conn: stmt = sa.select(self._table).where(self._cmp_pk_many(record_ids)) result = await conn.execute(stmt) return [r._asdict() for r in result] + async def get_many_ref_name(self, target: str, meta: Meta) -> str: + if meta and meta.get("orm", False): + # TODO(pydantic): arbitrary_types_allowed=True check(_RelationshipAttr, ...) + relationship = getattr(self._model, target) + return relationship.entity.persist_selectable.name # type: ignore[no-any-return] + + return self.name + + @handle_errors + async def get_many_ref(self, params: GetManyRefParams) -> tuple[list[Record], int]: + meta = params.get("meta") + if meta and meta.get("orm", False): + if self._model is None: + raise web.HTTPBadRequest(reason="Not an ORM model.") + + # Use an ORM relationship to get the records (essentially the inverse of a + # normal manyReference request). This makes it easy to support complex + # relationships (such as many-to-many) without react-admin needing the details + target = params["target"][0] + reverse = params["sort"]["order"] == "DESC" + # TODO(pydantic): arbitrary_types_allowed=True check(_RelationshipAttr, ...) + relationship = getattr(self._model, target) + async with AsyncSession(self._db) as sess: + result = await sess.get(self._model, params["id"], + options=(selectinload(relationship),)) + records = [{c.name: getattr(r, c.name) for c in r.__table__.c} + for r in getattr(result, target)] + records.sort(key=lambda r: r[params["sort"]["field"]], reverse=reverse) + return records, len(records) + + for k, v in zip(params["target"], params["id"]): + params["filter"][k] = v + return await self.get_list(params) + @handle_errors async def create(self, data: Record, meta: Meta) -> Record: async with self._db.begin() as conn: @@ -340,7 +382,7 @@ async def create(self, data: Record, meta: Meta) -> Record: return row.one()._asdict() @handle_errors - async def update(self, record_id: tuple[Any], data: Record, previous_data: Record, + async def update(self, record_id: tuple[Any, ...], data: Record, previous_data: Record, meta: Meta) -> Record: async with self._db.begin() as conn: stmt = sa.update(self._table).where(*self._cmp_pk(record_id)) @@ -349,31 +391,33 @@ async def update(self, record_id: tuple[Any], data: Record, previous_data: Recor return row.one()._asdict() @handle_errors - async def update_many(self, record_ids: Sequence[tuple[Any]], data: Record, - meta: Meta) -> list[Any]: + async def update_many(self, record_ids: Sequence[tuple[Any, ...]], data: Record, + meta: Meta) -> list[tuple[Any, ...]]: async with self._db.begin() as conn: stmt = sa.update(self._table).where(self._cmp_pk_many(record_ids)) stmt = stmt.values(data).returning(*(self._table.c[pk] for pk in self.primary_key)) return list(await conn.scalars(stmt)) @handle_errors - async def delete(self, record_id: tuple[Any], previous_data: Record, meta: Meta) -> Record: + async def delete(self, record_id: tuple[Any, ...], previous_data: Record, + meta: Meta) -> Record: async with self._db.begin() as conn: stmt = sa.delete(self._table).where(*self._cmp_pk(record_id)) row = await conn.execute(stmt.returning(*self._table.c)) return row.one()._asdict() @handle_errors - async def delete_many(self, record_ids: Sequence[tuple[Any]], meta: Meta) -> list[Any]: + async def delete_many(self, record_ids: Sequence[tuple[Any, ...]], + meta: Meta) -> list[tuple[Any, ...]]: async with self._db.begin() as conn: stmt = sa.delete(self._table).where(self._cmp_pk_many(record_ids)) r = await conn.scalars(stmt.returning(*(self._table.c[pk] for pk in self.primary_key))) return list(r) - def _cmp_pk(self, record_id: tuple[Any]) -> Iterator[_SABoolExpression]: + def _cmp_pk(self, record_id: tuple[Any, ...]) -> Iterator[_SABoolExpression]: return (self._table.c[pk] == r_id for pk, r_id in zip(self.primary_key, record_id)) - def _cmp_pk_many(self, record_ids: Sequence[tuple[Any]]) -> _SABoolExpression: + def _cmp_pk_many(self, record_ids: Sequence[tuple[Any, ...]]) -> _SABoolExpression: return sa.tuple_(*(self._table.c[pk] for pk in self.primary_key)).in_(record_ids) def _get_validators(self, table: sa.Table, c: sa.Column[object]) -> list[FunctionState]: diff --git a/aiohttp_admin/routes.py b/aiohttp_admin/routes.py index 8c2adf1b..7c4082b3 100644 --- a/aiohttp_admin/routes.py +++ b/aiohttp_admin/routes.py @@ -6,15 +6,15 @@ from aiohttp import web from . import views +from .backends.abc import AbstractAdminResource from .types import Schema, _ResourceState, data, resources_key, state_key def setup_resources(admin: web.Application, schema: Schema) -> None: - admin[resources_key] = [] - + resources: dict[str, AbstractAdminResource[tuple[object, ...]]] = {} for r in schema["resources"]: m = r["model"] - admin[resources_key].append(m) + resources[m.name] = m admin.router.add_routes(m.routes) try: @@ -24,6 +24,7 @@ def setup_resources(admin: web.Application, schema: Schema) -> None: else: if not all(f in m.fields for f in r["display"]): raise ValueError(f"Display includes non-existent field {r['display']}") + # TODO: Use label: https://github.com/marmelab/react-admin/issues/9587 omit_fields = tuple(m.fields[f]["props"].get("source") for f in omit_fields) repr_field = r.get("repr", data(m.primary_key[0])) @@ -54,6 +55,7 @@ def setup_resources(admin: web.Application, schema: Schema) -> None: "bulk_update": r.get("bulk_update", {}), "urls": {}, "show_actions": r.get("show_actions", ())} admin[state_key]["resources"][m.name] = state + admin[resources_key] = resources def setup_routes(admin: web.Application) -> None: diff --git a/aiohttp_admin/security.py b/aiohttp_admin/security.py index 9648cc03..5355a088 100644 --- a/aiohttp_admin/security.py +++ b/aiohttp_admin/security.py @@ -1,5 +1,5 @@ import json -from collections.abc import Collection, Hashable, Mapping, Sequence +from collections.abc import Collection, Mapping, Sequence from enum import Enum from functools import lru_cache from typing import Optional, Type, TypeVar, Union @@ -11,7 +11,7 @@ from .types import IdentityDict, Schema, UserDetails -_T = TypeVar("_T", bound=Hashable) +_T = TypeVar("_T") @lru_cache # https://github.com/python/typeshed/issues/6347 @@ -21,7 +21,8 @@ def _get_schema(t: Type[_T]) -> TypeAdapter[_T]: # type: ignore[misc] def check(t: Type[_T], value: object) -> _T: """Validate value is of static type t.""" - return _get_schema(t).validate_python(value) # type: ignore[no-any-return] + # https://github.com/python/mypy/issues/11470 + return _get_schema(t).validate_python(value) # type: ignore[arg-type,no-any-return] class Permissions(str, Enum): diff --git a/aiohttp_admin/types.py b/aiohttp_admin/types.py index fa9dc556..cbc5e0f2 100644 --- a/aiohttp_admin/types.py +++ b/aiohttp_admin/types.py @@ -172,5 +172,5 @@ def regex(value: str) -> RegexState: check_credentials_key = AppKey[Callable[[str, str], Awaitable[bool]]]("check_credentials") permission_re_key = AppKey("permission_re", re.Pattern[str]) -resources_key = AppKey("resources", list[Any]) # TODO(pydantic): AbstractAdminResource +resources_key = AppKey("resources", dict[str, Any]) # TODO(pydantic): AbstractAdminResource state_key = AppKey("state", State) diff --git a/examples/relationships.py b/examples/relationships.py index 9e4f2bc9..50212be7 100644 --- a/examples/relationships.py +++ b/examples/relationships.py @@ -160,10 +160,12 @@ async def create_app() -> web.Application: manytomany_p2 = ManyToManyParent(name="Bar", value=3) manytomany_c1 = ManyToManyChild(name="Foo Child", value=5) manytomany_c2 = ManyToManyChild(name="Bar Child", value=6) + manytomany_c3 = ManyToManyChild(name="Baz Child", value=7) manytomany_p1.children.append(manytomany_c1) manytomany_p1.children.append(manytomany_c2) manytomany_p2.children.append(manytomany_c1) manytomany_p2.children.append(manytomany_c2) + manytomany_p2.children.append(manytomany_c3) sess.add(manytomany_p1) sess.add(manytomany_p2) sess.add(manytomany_c1) @@ -199,8 +201,8 @@ async def create_app() -> web.Application: {"model": SAResource(engine, ManyToOneChild)}, {"model": SAResource(engine, OneToOneParent), "repr": aiohttp_admin.data("name")}, {"model": SAResource(engine, OneToOneChild)}, - # {"model": SAResource(engine, ManyToManyParent)}, - # {"model": SAResource(engine, ManyToManyChild)}, + {"model": SAResource(engine, ManyToManyParent)}, + {"model": SAResource(engine, ManyToManyChild)}, {"model": SAResource(engine, CompositeForeignKeyChild), "repr": aiohttp_admin.data("description")}, {"model": SAResource(engine, CompositeForeignKeyParent)} diff --git a/tests/_resources.py b/tests/_resources.py index 57eb222b..1d29d482 100644 --- a/tests/_resources.py +++ b/tests/_resources.py @@ -1,10 +1,11 @@ from typing import Sequence -from aiohttp_admin.backends.abc import AbstractAdminResource, GetListParams, Meta, Record +from aiohttp_admin.backends.abc import (AbstractAdminResource, GetListParams, + GetManyRefParams, Meta, Record) from aiohttp_admin.types import ComponentState, InputState -class DummyResource(AbstractAdminResource[str]): +class DummyResource(AbstractAdminResource[tuple[str]]): def __init__(self, name: str, fields: dict[str, ComponentState], inputs: dict[str, InputState], primary_key: str): self.name = name @@ -12,30 +13,39 @@ def __init__(self, name: str, fields: dict[str, ComponentState], self.inputs = inputs self.primary_key = (primary_key,) self.omit_fields = set() - self._id_type = str + self._id_type = tuple[str] # type: ignore[assignment] self._foreign_rows = set() super().__init__() - async def get_list(self, params: GetListParams) -> tuple[list[Record], int]: # pragma: no cover # noqa: B950 + async def get_list(self, params: GetListParams) -> tuple[list[Record], int]: # pragma: no cover raise NotImplementedError() - async def get_one(self, record_id: str, meta: Meta) -> Record: # pragma: no cover + async def get_one(self, record_id: tuple[str], meta: Meta) -> Record: # pragma: no cover raise NotImplementedError() - async def get_many(self, record_ids: Sequence[str], meta: Meta) -> list[Record]: # pragma: no cover # noqa: B950 + async def get_many(self, record_ids: Sequence[tuple[str]], meta: Meta) -> list[Record]: # pragma: no cover raise NotImplementedError() - async def update(self, record_id: str, data: Record, previous_data: Record, meta: Meta) -> Record: # pragma: no cover # noqa: B950 + async def get_many_ref(self, params: GetManyRefParams) -> tuple[list[Record], int]: # pragma: no cover raise NotImplementedError() - async def update_many(self, record_ids: Sequence[str], data: Record, meta: Meta) -> list[str]: # pragma: no cover # noqa: B950 + async def update( # pragma: no cover + self, record_id: tuple[str], data: Record, previous_data: Record, meta: Meta + ) -> Record: + raise NotImplementedError() + + async def update_many( # pragma: no cover + self, record_ids: Sequence[tuple[str]], data: Record, meta: Meta + ) -> list[tuple[str]]: raise NotImplementedError() async def create(self, data: Record, meta: Meta) -> Record: # pragma: no cover raise NotImplementedError() - async def delete(self, record_id: str, previous_data: Record, meta: Meta) -> Record: # pragma: no cover # noqa: B950 + async def delete(self, record_id: tuple[str], previous_data: Record, meta: Meta) -> Record: # pragma: no cover raise NotImplementedError() - async def delete_many(self, record_ids: Sequence[str], meta: Meta) -> list[str]: # pragma: no cover # noqa: B950 + async def delete_many( # pragma: no cover + self, record_ids: Sequence[tuple[str]], meta: Meta + ) -> list[tuple[str]]: raise NotImplementedError() diff --git a/tests/conftest.py b/tests/conftest.py index f1dbdb3e..5313a53c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -8,7 +8,7 @@ from aiohttp.test_utils import TestClient from sqlalchemy.ext.asyncio import (AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine) -from sqlalchemy.orm import DeclarativeBaseNoMeta, Mapped, mapped_column +from sqlalchemy.orm import DeclarativeBaseNoMeta, Mapped, mapped_column, relationship import aiohttp_admin from _auth import check_credentials @@ -26,6 +26,8 @@ class DummyModel(Base): id: Mapped[int] = mapped_column(primary_key=True) + foreigns: Mapped[list["ForeignModel"]] = relationship() + class Dummy2Model(Base): __tablename__ = "dummy2" diff --git a/tests/test_backends_abc.py b/tests/test_backends_abc.py index 2493fb68..36b696e8 100644 --- a/tests/test_backends_abc.py +++ b/tests/test_backends_abc.py @@ -16,3 +16,13 @@ async def test_create_with_null(admin_client: TestClient, login: _Login) -> None async with admin_client.post(url, params=p, headers=h) as resp: assert resp.status == 200, await resp.text() assert await resp.json() == {"data": {"id": "4", "data": {"id": 4, "msg": None}}} + + +async def test_invalid_field(admin_client: TestClient, login: _Login) -> None: + h = await login(admin_client) + assert admin_client.app + url = admin_client.app[admin].router["dummy2_create"].url_for() + p = {"data": json.dumps({"data": {"incorrect": "foo"}})} + async with admin_client.post(url, params=p, headers=h) as resp: + assert resp.status == 400, await resp.text() + assert "Invalid field 'incorrect'" in await resp.text() diff --git a/tests/test_backends_sqlalchemy.py b/tests/test_backends_sqlalchemy.py index 671cec8a..34080507 100644 --- a/tests/test_backends_sqlalchemy.py +++ b/tests/test_backends_sqlalchemy.py @@ -227,8 +227,8 @@ class TestOne(base): # type: ignore[misc,valid-type] {"children": comp("Datagrid", { "rowClick": "show", "children": [comp("NumberField", {"source": data("id")})], "bulkActionButtons": comp("BulkDeleteButton", {"mutationMode": "pessimistic"})}), - "label": "Ones", "reference": "one", "source": fk("id"), "target": fk("many_id"), - "sortable": False, "key": "ones"}) + "label": "Ones", "reference": "many", "source": "id", "target": "ones", + "sortable": False, "key": "ones", "filter": {"__meta__": {"orm": True}}}) assert "ones" not in r.inputs r = SAResource(mock_engine, TestOne) diff --git a/tests/test_security.py b/tests/test_security.py index 0a3e69bb..bed70970 100644 --- a/tests/test_security.py +++ b/tests/test_security.py @@ -41,7 +41,7 @@ async def test_valid_login_logout(admin_client: TestClient) -> None: h = {"Authorization": token} async with admin_client.get(get_one_url, params=p, headers=h) as resp: assert resp.status == 200 - assert await resp.json() == {"data": {"id": "1", "data": {"id": 1}}} + assert await resp.json() == {"data": {"id": "1", "fk_id": "1", "data": {"id": 1}}} # Continue to test logout logout_url = admin_client.app[admin].router["logout"].url_for() @@ -134,7 +134,7 @@ async def identity_callback(identity: Optional[str]) -> UserDetails: h = await login(admin_client) async with admin_client.get(url, params={"id": "1"}, headers=h) as resp: assert resp.status == 200 - assert await resp.json() == {"data": {"id": "1", "data": {"id": 1}}} + assert await resp.json() == {"data": {"id": "1", "fk_id": "1", "data": {"id": 1}}} async def test_get_fk_with_permission(create_admin_client: _CreateClient, login: _Login) -> None: @@ -182,7 +182,7 @@ async def identity_callback(identity: Optional[str]) -> UserDetails: h = await login(admin_client) async with admin_client.get(url, params={"id": "1"}, headers=h) as resp: assert resp.status == 200 - assert await resp.json() == {"data": {"id": "1", "data": {"id": 1}}} + assert await resp.json() == {"data": {"id": "1", "fk_id": "1", "data": {"id": 1}}} async def test_get_resource_with_negative_permission(create_admin_client: _CreateClient, diff --git a/tests/test_views.py b/tests/test_views.py index 3dc72614..cc5e6536 100644 --- a/tests/test_views.py +++ b/tests/test_views.py @@ -32,11 +32,13 @@ async def test_admin_view(admin_client: TestClient) -> None: state = json.loads(m.group(1)) r = state["resources"]["dummy"] - assert r["list_omit"] == [] - assert r["fields"] == {"id": comp("NumberField", {"source": data("id"), "key": "id"})} + # TODO: https://github.com/marmelab/react-admin/issues/9587 + assert r["list_omit"] == ["id"] + assert r["fields"].keys() == {"id", "foreigns"} + assert r["fields"]["id"] == comp("NumberField", {"source": data("id"), "key": "id"}) assert r["inputs"] == { "id": comp("NumberInput", - {"source": data("id"), "key": "id", "alwaysOn": "alwaysOn", + {"source": data("id"), "key": "id", # "alwaysOn": "alwaysOn", "validate": [func("required", [])]}) | {"show_create": False}} assert r["repr"] == data("id") @@ -86,9 +88,10 @@ async def test_list_filtering_by_pk(admin_client: TestClient, login: _Login) -> url = admin_client.app[admin].router["dummy_get_list"].url_for() p = {"pagination": '{"page": 1, "perPage": 10}', "sort": '{"field": "id", "order": "ASC"}', "filter": '{"id": 3}'} + exp_rec = {"id": "3", "fk_id": "3", "data": {"id": 3}} async with admin_client.get(url, params=p, headers=h) as resp: assert resp.status == 200 - assert await resp.json() == {"data": [{"id": "3", "data": {"id": 3}}], "total": 1} + assert await resp.json() == {"data": [exp_rec], "total": 1} @pytest.mark.xfail(reason="Need to implement #668 to make this work properly") @@ -114,7 +117,7 @@ async def test_get_one(admin_client: TestClient, login: _Login) -> None: async with admin_client.get(url, params={"id": 1}, headers=h) as resp: assert resp.status == 200 - assert await resp.json() == {"data": {"id": "1", "data": {"id": 1}}} + assert await resp.json() == {"data": {"id": "1", "fk_id": "1", "data": {"id": 1}}} async def test_get_one_not_exists(admin_client: TestClient, login: _Login) -> None: @@ -137,9 +140,9 @@ async def test_get_many(admin_client: TestClient, login: _Login) -> None: p = {"ids": '["3", "7", "12"]'} async with admin_client.get(url, params=p, headers=h) as resp: assert resp.status == 200 - assert await resp.json() == {"data": [{"id": "3", "data": {"id": 3}}, - {"id": "7", "data": {"id": 7}}, - {"id": "12", "data": {"id": 12}}]} + assert await resp.json() == {"data": [{"id": "3", "fk_id": "3", "data": {"id": 3}}, + {"id": "7", "fk_id": "7", "data": {"id": 7}}, + {"id": "12", "fk_id": "12", "data": {"id": 12}}]} async def test_get_many_not_exists(admin_client: TestClient, login: _Login) -> None: @@ -153,14 +156,43 @@ async def test_get_many_not_exists(admin_client: TestClient, login: _Login) -> N p = {"ids": '["3", "4", "8"]'} async with admin_client.get(url, params=p, headers=h) as resp: assert resp.status == 200 - assert await resp.json() == {"data": [{"id": "3", "data": {"id": 3}}, - {"id": "4", "data": {"id": 4}}]} + assert await resp.json() == {"data": [{"id": "3", "fk_id": "3", "data": {"id": 3}}, + {"id": "4", "fk_id": "4", "data": {"id": 4}}]} p = {"ids": '["9", "10", "11"]'} async with admin_client.get(url, params=p, headers=h) as resp: assert resp.status == 404 +async def test_get_many_ref(admin_client: TestClient, login: _Login) -> None: + h = await login(admin_client) + assert admin_client.app + + url = admin_client.app[admin].router["foreign_get_many_ref"].url_for() + page = json.dumps({"page": 1, "perPage": 10}) + sort = json.dumps({"field": "id", "order": "DESC"}) + p = {"target": "dummy", "id": "1", "pagination": page, "sort": sort, "filter": "{}"} + expected_record = {"id": "1", "fk_dummy": "1", "data": {"id": 1, "dummy": 1}} + async with admin_client.get(url, params=p, headers=h) as resp: + assert resp.status == 200, await resp.text() + assert await resp.json() == {"data": [expected_record], "total": 1} + + +async def test_get_many_ref_orm(admin_client: TestClient, login: _Login) -> None: + h = await login(admin_client) + assert admin_client.app + + url = admin_client.app[admin].router["dummy_get_many_ref"].url_for() + page = json.dumps({"page": 1, "perPage": 10}) + sort = json.dumps({"field": "id", "order": "DESC"}) + f = json.dumps({"__meta__": {"orm": True}}) + p = {"target": "foreigns", "id": "1", "pagination": page, "sort": sort, "filter": f} + expected_record = {"id": "1", "fk_dummy": "1", "data": {"id": 1, "dummy": 1}} + async with admin_client.get(url, params=p, headers=h) as resp: + assert resp.status == 200, await resp.text() + assert await resp.json() == {"data": [expected_record], "total": 1} + + async def test_create(admin_client: TestClient, login: _Login) -> None: h = await login(admin_client) assert admin_client.app @@ -168,7 +200,7 @@ async def test_create(admin_client: TestClient, login: _Login) -> None: p = {"data": json.dumps({"data": {}})} async with admin_client.post(url, params=p, headers=h) as resp: assert resp.status == 200 - assert await resp.json() == {"data": {"id": "2", "data": {"id": 2}}} + assert await resp.json() == {"data": {"id": "2", "fk_id": "2", "data": {"id": 2}}} async with admin_client.app[db]() as sess: r = await sess.get(admin_client.app[model], 2) @@ -193,7 +225,7 @@ async def test_update(admin_client: TestClient, login: _Login) -> None: "previousData": json.dumps({"id": "1", "data": {"id": 1}})} async with admin_client.put(url, params=p, headers=h) as resp: assert resp.status == 200 - assert await resp.json() == {"data": {"id": "4", "data": {"id": 4}}} + assert await resp.json() == {"data": {"id": "4", "fk_id": "4", "data": {"id": 4}}} async with admin_client.app[db]() as sess: r = await sess.get(admin_client.app[model], 4) @@ -269,7 +301,7 @@ async def test_delete(admin_client: TestClient, login: _Login) -> None: p = {"id": "1", "previousData": '{"id": "1", "data": {"id": 1}}'} async with admin_client.delete(url, params=p, headers=h) as resp: assert resp.status == 200 - assert await resp.json() == {"data": {"id": "1", "data": {"id": 1}}} + assert await resp.json() == {"data": {"id": "1", "fk_id": "1", "data": {"id": 1}}} async with admin_client.app[db]() as sess: assert await sess.get(admin_client.app[model], 1) is None