Skip to content

Commit

Permalink
Add support for many-to-many relationships (#868)
Browse files Browse the repository at this point in the history
  • Loading branch information
Dreamsorcerer committed Apr 22, 2024
1 parent 2d3c293 commit 84cf5f1
Show file tree
Hide file tree
Showing 14 changed files with 333 additions and 94 deletions.
6 changes: 1 addition & 5 deletions admin-js/src/App.js
Original file line number Diff line number Diff line change
Expand Up @@ -142,11 +142,7 @@ const dataProvider = {
deleteMany: (resource, params) => dataRequest(resource, "delete_many", params),
getList: (resource, params) => dataRequest(resource, "get_list", params),
getMany: (resource, params) => dataRequest(resource, "get_many", params),
getManyReference: (resource, params) => {
// filter object is reused across requests, so clone it before modifying.
params["filter"] = {...params["filter"], [params["target"]]: params["id"]};
return dataRequest(resource, "get_list", params);
},
getManyReference: (resource, params) => dataRequest(resource, "get_many_ref", params),
getOne: (resource, params) => dataRequest(resource, "get_one", params),
update: (resource, params) => dataRequest(resource, "update", params),
updateMany: (resource, params) => dataRequest(resource, "update_many", params)
Expand Down
68 changes: 68 additions & 0 deletions admin-js/tests/relationships.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ test("datagrid works", async () => {
await userEvent.keyboard("[Escape]");

const grid = await within(table).findByRole("table");
await sleep(0.1);
const childHeaders = within(grid).getAllByRole("columnheader");
expect(childHeaders.slice(1).map((e) => e.textContent)).toEqual(["Id", "Name", "Value"]);
const childRows = within(grid).getAllByRole("row");
Expand Down Expand Up @@ -61,6 +62,73 @@ test("onetomany child displays", async () => {
expect(childCells.map((e) => e.textContent)).toEqual(["Bar", "2"]);
});

test("onetoone parents display", async () => {
await userEvent.click(await screen.findByRole("button", {"name": "Open menu"}));
await userEvent.click(await screen.findByText("Onetoone parents"));

await waitFor(() => screen.getByRole("heading", {"name": "Onetoone parents"}));
await sleep(1);

const table = screen.getByRole("table");
await userEvent.click(within(table).getAllByRole("row")[1]);

await waitFor(() => screen.getByRole("heading", {"name": "Onetoone parent Foo"}));
const grid = await screen.findByRole("table");
const childHeaders = within(grid).getAllByRole("columnheader");
expect(childHeaders.map((e) => e.textContent)).toEqual(["Id", "Name", "Value"]);
const childRows = within(grid).getAllByRole("row");
expect(childRows.length).toBe(2);
const childCells = within(childRows[1]).getAllByRole("cell");
expect(childCells.map((e) => e.textContent)).toEqual(["2", "Child Bar", "2"]);
});

test("manytomany left displays", async () => {
await userEvent.click(await screen.findByRole("button", {"name": "Open menu"}));
await userEvent.click(await screen.findByText("Manytomany lefts"));

await waitFor(() => screen.getByRole("heading", {"name": "Manytomany lefts"}));
await sleep(1);

await userEvent.click(screen.getByRole("button", {"name": "Columns"}));
// TODO: Remove when fixed: https://github.com/marmelab/react-admin/issues/9587
await userEvent.click(within(screen.getByRole("presentation")).getByLabelText("Children"));
await userEvent.keyboard("[Escape]");

const table = screen.getAllByRole("table")[0];
const headers = within(table.querySelector("thead")).getAllByRole("columnheader");
expect(headers.slice(1, -1).map((e) => e.textContent)).toEqual(["Id", "Name", "Value", "Children"]);

const rows = within(table).getAllByRole("row").filter((e) => e.parentElement.parentElement === table);
const firstCells = within(rows[1]).getAllByRole("cell").filter((e) => e.parentElement === rows[1]);
expect(firstCells.slice(1, -2).map((e) => e.textContent)).toEqual(["1", "Foo", "2"]);
const secondCells = within(rows[2]).getAllByRole("cell").filter((e) => e.parentElement === rows[2]);
expect(secondCells.slice(1, -2).map((e) => e.textContent)).toEqual(["2", "Bar", "3"]);

const firstGrid = await within(firstCells.at(-2)).findByRole("table");
const firstHeaders = within(firstGrid).getAllByRole("columnheader");
await waitFor(() => firstHeaders[1].textContent.trim() != "");
expect(firstHeaders.slice(1).map((e) => e.textContent)).toEqual(["Id", "Name", "Value"]);
const firstRows = within(firstGrid).getAllByRole("row");
expect(firstRows.length).toBe(3);
let cells = within(firstRows[1]).getAllByRole("cell");
expect(cells.slice(1).map((e) => e.textContent)).toEqual(["3", "Bar Child", "6"]);
cells = within(firstRows[2]).getAllByRole("cell");
expect(cells.slice(1).map((e) => e.textContent)).toEqual(["1", "Foo Child", "5"]);

const secondGrid = within(secondCells.at(-2)).getByRole("table");
const secondHeaders = within(secondGrid).getAllByRole("columnheader");
await waitFor(() => secondHeaders[0].textContent.trim() != "");
expect(secondHeaders.slice(1).map((e) => e.textContent)).toEqual(["Id", "Name", "Value"]);
const secondRows = within(secondGrid).getAllByRole("row");
expect(secondRows.length).toBe(4);
cells = within(secondRows[1]).getAllByRole("cell");
expect(cells.slice(1).map((e) => e.textContent)).toEqual(["3", "Bar Child", "6"]);
cells = within(secondRows[2]).getAllByRole("cell");
expect(cells.slice(1).map((e) => e.textContent)).toEqual(["2", "Baz Child", "7"]);
cells = within(secondRows[3]).getAllByRole("cell");
expect(cells.slice(1).map((e) => e.textContent)).toEqual(["1", "Foo Child", "5"]);
});

test("composite foreign key child displays table", async () => {
await userEvent.click(await screen.findByRole("button", {"name": "Open menu"}));
await userEvent.click(await screen.findByText("Composite foreign key children"));
Expand Down
130 changes: 101 additions & 29 deletions aiohttp_admin/backends/abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from pydantic import Json

from ..security import check, permissions_as_dict
from ..types import ComponentState, InputState, fk
from ..types import ComponentState, InputState, fk, resources_key

if sys.version_info >= (3, 10):
from typing import TypeAlias
Expand All @@ -26,7 +26,7 @@
else:
from typing_extensions import TypedDict

_ID = TypeVar("_ID")
_ID = TypeVar("_ID", bound=tuple[object, ...])
Record = dict[str, object]
Meta = Optional[dict[str, object]]

Expand Down Expand Up @@ -87,6 +87,22 @@ class GetManyParams(_Params):
ids: Json[tuple[str, ...]]


class GetManyRefAPIParams(_Params):
target: str
id: str
pagination: Json[_Pagination]
sort: Json[_Sort]
filter: Json[dict[str, object]]


class GetManyRefParams(_Params):
target: tuple[str, ...]
id: tuple[object, ...]
pagination: Json[_Pagination]
sort: Json[_Sort]
filter: Json[dict[str, object]]


class _CreateData(TypedDict):
"""Id will not be included for create calls."""
data: Record
Expand Down Expand Up @@ -116,6 +132,11 @@ class DeleteManyParams(_Params):
ids: Json[tuple[str, ...]]


class _ListQuery(TypedDict):
sort: _Sort
filter: dict[str, object]


class AbstractAdminResource(ABC, Generic[_ID]):
name: str
fields: dict[str, ComponentState]
Expand Down Expand Up @@ -155,6 +176,10 @@ async def get_one(self, record_id: _ID, meta: Meta) -> Record:
async def get_many(self, record_ids: Sequence[_ID], meta: Meta) -> list[Record]:
"""Return the matching records."""

@abstractmethod
async def get_many_ref(self, params: GetManyRefParams) -> tuple[list[Record], int]:
"""Return list of records and total count available (when not paginating)."""

@abstractmethod
async def update(self, record_id: _ID, data: Record, previous_data: Record,
meta: Meta) -> Record:
Expand All @@ -176,38 +201,30 @@ async def delete(self, record_id: _ID, previous_data: Record, meta: Meta) -> Rec
async def delete_many(self, record_ids: Sequence[_ID], meta: Meta) -> list[_ID]:
"""Delete the matching records and return their IDs."""

async def get_many_ref_name(self, target: str, meta: Meta) -> str:
"""Return the resource name for the reference.
This can be used to change which resource should be returned by get_many_ref().
For example, if we have an SQLAlchemy model called 'parent' with a relationship
called children, then a normal get_many_ref_name() call would go to the 'child'
model with the details from the parent, and the default behaviour would work.
However, the SQLAlchemy backend uses the meta to switch this and send the request
to the 'parent' model instead and then use the children ORM attribute to fetch
the referenced resources, thus requiring this method to return 'child'.
This allows the SQLAlchemy backend to support complex relationships (e.g.
many-to-many) without needing react-admin to know the details.
"""
return self.name

# https://marmelab.com/react-admin/DataProviderWriting.html

@final
async def _get_list(self, request: web.Request) -> web.Response:
await check_permission(request, f"admin.{self.name}.view", context=(request, None))
query = check(GetListParams, request.query)

# When sort order refers to "id", this should be translated to primary key.
if query["sort"]["field"] == "id":
query["sort"]["field"] = self.primary_key[0]
else:
query["sort"]["field"] = query["sort"]["field"].removeprefix("data.")

query["filter"].update(check(dict[str, object], query["filter"].pop("data", {}))) # type: ignore[type-var]

merged_filter = {}
for k, v in query["filter"].items():
if k.startswith("fk_"):
v = check(str, v)
for c, cv in zip(k.removeprefix("fk_").split("__"), v.split("|")):
merged_filter[c] = check(self._raw_record_type[c], cv)
else:
merged_filter[k] = check(self._raw_record_type[k], v)
query["filter"] = merged_filter

# Add filters from advanced permissions.
# The permissions will be cached on the request from a previous permissions check.
permissions = permissions_as_dict(request["aiohttpadmin_permissions"])
filters = permissions.get(f"admin.{self.name}.view",
permissions.get(f"admin.{self.name}.*", {}))
for k, v in filters.items():
query["filter"][k] = v
self._process_list_query(query, request)

raw_results, total = await self.get_list(query)
results = [await self._convert_record(r, request) for r in raw_results
Expand Down Expand Up @@ -239,6 +256,33 @@ async def _get_many(self, request: web.Request) -> web.Response:
if await permits(request, f"admin.{self.name}.view", context=(request, r))]
return json_response({"data": results})

@final
async def _get_many_ref(self, request: web.Request) -> web.Response:
query = check(GetManyRefAPIParams, request.query)
meta = query["filter"].pop("__meta__", None)
if meta is not None:
query["meta"] = check(dict[str, object], meta)
reference = await self.get_many_ref_name(query["target"], query.get("meta"))
ref_model = request.app[resources_key][reference]

await check_permission(request, f"admin.{ref_model.name}.view", context=(request, None))

ref_model._process_list_query(query, request)

if query["target"].startswith("fk_"):
target = tuple(query["target"].removeprefix("fk_").split("__"))
record_id = tuple(check(self._raw_record_type[k], v)
for k, v in zip(target, query["id"].split("|")))
else:
target = (query["target"],)
record_id = check(self._id_type, query["id"].split("|"))

raw_results, total = await self.get_many_ref({**query, "target": target, "id": record_id})

results = [await ref_model._convert_record(r, request) for r in raw_results
if await permits(request, f"admin.{ref_model.name}.view", context=(request, r))]
return json_response({"data": results, "total": total})

@final
async def _create(self, request: web.Request) -> web.Response:
query = check(CreateParams, request.query)
Expand Down Expand Up @@ -350,7 +394,7 @@ async def _delete_many(self, request: web.Request) -> web.Response:
@final
def _check_record(self, record: Record) -> Record:
"""Check and convert input record."""
return check(self._record_type, record) # type: ignore[no-any-return]
return check(self._record_type, record)

@final
async def _convert_record(self, record: Record, request: web.Request) -> APIRecord:
Expand All @@ -371,6 +415,33 @@ def _convert_ids(self, ids: Sequence[_ID]) -> tuple[str, ...]:
"""Convert IDs to correct output format."""
return tuple(str(i) for i in ids)

def _process_list_query(self, query: _ListQuery, request: web.Request) -> None:
# When sort order refers to "id", this should be translated to primary key.
if query["sort"]["field"] == "id":
query["sort"]["field"] = self.primary_key[0]
else:
query["sort"]["field"] = query["sort"]["field"].removeprefix("data.")

query["filter"].update(check(dict[str, object], query["filter"].pop("data", {})))

merged_filter = {}
for k, v in query["filter"].items():
if k.startswith("fk_"):
v = check(str, v)
for c, cv in zip(k.removeprefix("fk_").split("__"), v.split("|")):
merged_filter[c] = check(self._raw_record_type[c], cv)
else:
merged_filter[k] = check(self._raw_record_type[k], v)
query["filter"] = merged_filter

# Add filters from advanced permissions.
# The permissions will be cached on the request from a previous permissions check.
permissions = permissions_as_dict(request["aiohttpadmin_permissions"])
filters = permissions.get(f"admin.{self.name}.view",
permissions.get(f"admin.{self.name}.*", {}))
for k, v in filters.items():
query["filter"][k] = v

@cached_property
def routes(self) -> tuple[web.RouteDef, ...]:
"""Routes to act on this resource.
Expand All @@ -382,6 +453,7 @@ def routes(self) -> tuple[web.RouteDef, ...]:
web.get(url + "/list", self._get_list, name=self.name + "_get_list"),
web.get(url + "/one", self._get_one, name=self.name + "_get_one"),
web.get(url, self._get_many, name=self.name + "_get_many"),
web.get(url + "/ref", self._get_many_ref, name=self.name + "_get_many_ref"),
web.post(url, self._create, name=self.name + "_create"),
web.put(url + "/update", self._update, name=self.name + "_update"),
web.put(url + "/update_many", self._update_many, name=self.name + "_update_many"),
Expand Down
Loading

0 comments on commit 84cf5f1

Please sign in to comment.