Skip to content

Commit

Permalink
Fix type errors.
Browse files Browse the repository at this point in the history
  • Loading branch information
Dreamsorcerer committed Jul 30, 2023
1 parent c331ba9 commit 727756c
Show file tree
Hide file tree
Showing 11 changed files with 149 additions and 135 deletions.
6 changes: 3 additions & 3 deletions aiohttp_admin/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@
from aiohttp import web
from aiohttp.typedefs import Handler
from aiohttp_session.cookie_storage import EncryptedCookieStorage
from pydantic import ValidationError, parse_obj_as
from pydantic import ValidationError

from .routes import setup_resources, setup_routes
from .security import AdminAuthorizationPolicy, Permissions, TokenIdentityPolicy
from .security import AdminAuthorizationPolicy, Permissions, TokenIdentityPolicy, check
from .types import Schema, UserDetails

__all__ = ("Permissions", "Schema", "UserDetails", "setup")
Expand Down Expand Up @@ -67,7 +67,7 @@ def value(r: web.RouteDef) -> tuple[str, str]:
m = res["model"]
admin["state"]["resources"][m.name]["urls"] = {key(r): value(r) for r in m.routes}

schema = parse_obj_as(Schema, schema)
schema = check(Schema, schema)
if secret is None:
secret = secrets.token_bytes()

Expand Down
30 changes: 15 additions & 15 deletions aiohttp_admin/backends/abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@

from aiohttp import web
from aiohttp_security import check_permission, permits
from pydantic import Json, parse_obj_as
from pydantic import Json

from ..security import permissions_as_dict
from ..security import check, permissions_as_dict
from ..types import ComponentState, InputState

if sys.version_info >= (3, 12):
Expand Down Expand Up @@ -155,7 +155,7 @@ async def delete_many(self, params: DeleteManyParams) -> list[Union[int, str]]:

async def _get_list(self, request: web.Request) -> web.Response:
await check_permission(request, f"admin.{self.name}.view", context=(request, None))
query = parse_obj_as(GetListParams, request.query)
query = check(GetListParams, request.query)

# When sort order refers to "id", this should be translated to primary key.
if query["sort"]["field"] == "id":
Expand All @@ -180,7 +180,7 @@ async def _get_list(self, request: web.Request) -> web.Response:

async def _get_one(self, request: web.Request) -> web.Response:
await check_permission(request, f"admin.{self.name}.view", context=(request, None))
query = parse_obj_as(GetOneParams, request.query)
query = check(GetOneParams, request.query)

result = await self.get_one(query)
if not await permits(request, f"admin.{self.name}.view", context=(request, result)):
Expand All @@ -191,7 +191,7 @@ async def _get_one(self, request: web.Request) -> web.Response:

async def _get_many(self, request: web.Request) -> web.Response:
await check_permission(request, f"admin.{self.name}.view", context=(request, None))
query = parse_obj_as(GetManyParams, request.query)
query = check(GetManyParams, request.query)

results = await self.get_many(query)
if not results:
Expand All @@ -204,12 +204,12 @@ async def _get_many(self, request: web.Request) -> web.Response:
return json_response({"data": results})

async def _create(self, request: web.Request) -> web.Response:
query = parse_obj_as(CreateParams, request.query)
query = check(CreateParams, request.query)
# TODO(Pydantic): Dissallow extra arguments
for k in query["data"]:
if k not in self.inputs and k != "id":
raise web.HTTPBadRequest(reason=f"Invalid field '{k}'")
query["data"] = parse_obj_as(self._record_type, query["data"])
query["data"] = check(self._record_type, query["data"])
await check_permission(request, f"admin.{self.name}.add", context=(request, query["data"]))
for k, v in query["data"].items():
if v is not None:
Expand All @@ -223,13 +223,13 @@ async def _create(self, request: web.Request) -> web.Response:

async def _update(self, request: web.Request) -> web.Response:
await check_permission(request, f"admin.{self.name}.edit", context=(request, None))
query = parse_obj_as(UpdateParams, request.query)
query = check(UpdateParams, request.query)
# TODO(Pydantic): Dissallow extra arguments
for k in query["data"]:
if k not in self.inputs and k != "id":
raise web.HTTPBadRequest(reason=f"Invalid field '{k}'")
query["data"] = parse_obj_as(self._record_type, query["data"])
query["previousData"] = parse_obj_as(self._record_type, query["previousData"])
query["data"] = check(self._record_type, query["data"])
query["previousData"] = check(self._record_type, query["previousData"])

if self.primary_key != "id":
query["data"].pop("id", None)
Expand Down Expand Up @@ -257,12 +257,12 @@ async def _update(self, request: web.Request) -> web.Response:

async def _update_many(self, request: web.Request) -> web.Response:
await check_permission(request, f"admin.{self.name}.edit", context=(request, None))
query = parse_obj_as(UpdateManyParams, request.query)
query = check(UpdateManyParams, request.query)
# TODO(Pydantic): Dissallow extra arguments
for k in query["data"]:
if k not in self.inputs and k != "id":
raise web.HTTPBadRequest(reason=f"Invalid field '{k}'")
query["data"] = parse_obj_as(self._record_type, query["data"])
query["data"] = check(self._record_type, query["data"])

# Check original records are allowed by permission filters.
originals = await self.get_many({"ids": query["ids"]})
Expand All @@ -284,8 +284,8 @@ async def _update_many(self, request: web.Request) -> web.Response:

async def _delete(self, request: web.Request) -> web.Response:
await check_permission(request, f"admin.{self.name}.delete", context=(request, None))
query = parse_obj_as(DeleteParams, request.query)
query["previousData"] = parse_obj_as(self._record_type, query["previousData"])
query = check(DeleteParams, request.query)
query["previousData"] = check(self._record_type, query["previousData"])

original = await self.get_one({"id": query["id"]})
if not await permits(request, f"admin.{self.name}.delete", context=(request, original)):
Expand All @@ -298,7 +298,7 @@ async def _delete(self, request: web.Request) -> web.Response:

async def _delete_many(self, request: web.Request) -> web.Response:
await check_permission(request, f"admin.{self.name}.delete", context=(request, None))
query = parse_obj_as(DeleteManyParams, request.query)
query = check(DeleteManyParams, request.query)

originals = await self.get_many(query)
allowed = await asyncio.gather(*(permits(request, f"admin.{self.name}.delete",
Expand Down
12 changes: 5 additions & 7 deletions aiohttp_admin/backends/sqlalchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from .abc import (
AbstractAdminResource, CreateParams, DeleteManyParams, DeleteParams, GetListParams,
GetManyParams, GetOneParams, Record, UpdateManyParams, UpdateParams)
from ..types import comp, func, regex
from ..types import FunctionState, comp, func, regex

if sys.version_info >= (3, 10):
from typing import ParamSpec
Expand Down Expand Up @@ -192,7 +192,7 @@ def __init__(self, db: AsyncEngine, model_or_table: Union[sa.Table, type[Declara
props = props.copy()
show = c is not table.autoincrement_column
props["validate"] = self._get_validators(table, c)
self.inputs[c.name] = comp(inp, props)
self.inputs[c.name] = comp(inp, props) # type: ignore[assignment]
self.inputs[c.name]["show_create"] = show

if not isinstance(model_or_table, sa.Table):
Expand Down Expand Up @@ -228,7 +228,7 @@ def __init__(self, db: AsyncEngine, model_or_table: Union[sa.Table, type[Declara
c_props["source"] = kc.name
children.append(comp(field, c_props))
container = "Datagrid" if t == "ReferenceManyField" else "DatagridSingle"
datagrid = comp(container, {"children": children, "rowClick": show})
datagrid = comp(container, {"children": children, "rowClick": "show"})
props["children"] = (datagrid,)

self.fields[name] = comp(t, props)
Expand Down Expand Up @@ -322,10 +322,8 @@ async def delete_many(self, params: DeleteManyParams) -> list[Union[str, int]]:
r = await conn.scalars(stmt.returning(self._table.c[self.primary_key]))
return list(r)

def _get_validators(
self, table: sa.Table, c: sa.Column[object]
) -> list[tuple[Union[str, int], ...]]:
validators: list[tuple[Union[str, int], ...]] = []
def _get_validators(self, table: sa.Table, c: sa.Column[object]) -> list[FunctionState]:
validators: list[FunctionState] = []
if c.default is None and c.server_default is None and not c.nullable:
validators.append(func("required", ()))
max_length = getattr(c.type, "length", None)
Expand Down
21 changes: 17 additions & 4 deletions aiohttp_admin/security.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,28 @@
import json
from collections.abc import Collection, Mapping, Sequence
from collections.abc import Collection, Hashable, Mapping, Sequence
from enum import Enum
from typing import Optional, Union
from functools import lru_cache
from typing import Optional, Type, TypeVar, Union

from aiohttp import web
from aiohttp_security import AbstractAuthorizationPolicy, SessionIdentityPolicy
from cryptography.fernet import Fernet, InvalidToken
from pydantic import Json, ValidationError, parse_obj_as
from pydantic import Json, TypeAdapter, ValidationError

from .types import IdentityDict, Schema, UserDetails

_T = TypeVar("_T", bound=Hashable)


@lru_cache
def _get_schema(t: Type[_T]) -> TypeAdapter[_T]:
return TypeAdapter(t)


def check(t: Type[_T], value: object) -> _T:
"""Validate value is of static type t."""
return _get_schema(t).validate_python(value)


class Permissions(str, Enum):
view = "admin.view"
Expand Down Expand Up @@ -101,7 +114,7 @@ async def identify(self, request: web.Request) -> Optional[str]:
# Validate JS token
hdr = request.headers.get("Authorization")
try:
identity_data = parse_obj_as(Json[IdentityDict], hdr)
identity_data = check(Json[IdentityDict], hdr)
except ValidationError:
return None

Expand Down
6 changes: 3 additions & 3 deletions aiohttp_admin/types.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import sys
from collections.abc import Callable, Collection, Sequence
from typing import Any, Awaitable, Literal, Optional, Union
from typing import Any, Awaitable, Literal, Mapping, Optional, Union

if sys.version_info >= (3, 12):
from typing import TypedDict

Check warning on line 6 in aiohttp_admin/types.py

View check run for this annotation

Codecov / codecov/patch

aiohttp_admin/types.py#L6

Added line #L6 was not covered by tests
Expand Down Expand Up @@ -124,9 +124,9 @@ class State(TypedDict):
js_module: Optional[str]


def comp(t: str, props: Optional[dict[str, object]] = None) -> ComponentState:
def comp(t: str, props: Optional[Mapping[str, object]] = None) -> ComponentState:
"""Use a component of type t with the given props."""
return {"__type__": "component", "type": t, "props": props or {}}
return {"__type__": "component", "type": t, "props": dict(props or {})}


def func(name: str, args: Optional[Sequence[object]] = None) -> FunctionState:
Expand Down
6 changes: 4 additions & 2 deletions aiohttp_admin/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@

from aiohttp import web
from aiohttp_security import forget, remember
from pydantic import Json, parse_obj_as
from pydantic import Json

from .security import check

if sys.version_info >= (3, 12):
from typing import TypedDict

Check warning on line 12 in aiohttp_admin/views.py

View check run for this annotation

Codecov / codecov/patch

aiohttp_admin/views.py#L12

Added line #L12 was not covered by tests
Expand Down Expand Up @@ -52,7 +54,7 @@ async def index(request: web.Request) -> web.Response:

async def token(request: web.Request) -> web.Response:
"""Validate user credentials and log the user in."""
data = parse_obj_as(Json[_Login], await request.read())
data = check(Json[_Login], await request.read())

check_credentials = request.app["check_credentials"]
if not await check_credentials(data["username"], data["password"]):
Expand Down
4 changes: 2 additions & 2 deletions tests/_resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
from aiohttp_admin.backends.abc import (
AbstractAdminResource, CreateParams, DeleteManyParams, DeleteParams, GetListParams,
GetManyParams, GetOneParams, Record, UpdateManyParams, UpdateParams)
from aiohttp_admin.types import FieldState, InputState
from aiohttp_admin.types import ComponentState, InputState


class DummyResource(AbstractAdminResource):
def __init__(self, name: str, fields: dict[str, FieldState],
def __init__(self, name: str, fields: dict[str, ComponentState],
inputs: dict[str, InputState], primary_key: str):
self.name = name
self.fields = fields
Expand Down
47 changes: 23 additions & 24 deletions tests/test_admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import aiohttp_admin
from _auth import check_credentials
from _resources import DummyResource
from aiohttp_admin.types import comp, func


def test_path() -> None:
Expand Down Expand Up @@ -39,23 +40,25 @@ def test_no_js_module() -> None:

def test_validators() -> None:
dummy = DummyResource(
"dummy", {"id": {"type": "NumberField", "props": {}}},
{"id": {"type": "NumberInput", "props": {}, "show_create": True,
"validators": (("required",),)}}, "id")
"dummy",
{"id": {"__type__": "component", "type": "NumberField", "props": {}}},
{"id": {"__type__": "component", "type": "NumberInput",
"props": {"validate": ({"__type__": "function", "name": "required", "args": ()},)},
"show_create": True}},
"id")
app = web.Application()
schema: aiohttp_admin.Schema = {"security": {"check_credentials": check_credentials},
"resources": ({"model": dummy,
"validators": {"id": (("minValue", 3),)}},)}
schema: aiohttp_admin.Schema = {
"security": {"check_credentials": check_credentials},
"resources": ({"model": dummy, "validators": {"id": (func("minValue", (3,)),)}},)}
admin = aiohttp_admin.setup(app, schema)
validators = admin["state"]["resources"]["dummy"]["inputs"]["id"]["validators"]
# TODO(Pydantic2): Should be int 3 in both lines.
assert validators == (("required",), ("minValue", "3"))
assert ("minValue", "3") not in dummy.inputs["id"]["validators"]
validators = admin["state"]["resources"]["dummy"]["inputs"]["id"]["props"]["validate"]
assert validators == (func("required", ()), func("minValue", (3,)))
assert ("minValue", 3) not in dummy.inputs["id"]["props"]["validate"] # type: ignore[operator]


def test_re() -> None:
test_re = DummyResource("testre", {"id": {"type": "NumberField", "props": {}},
"value": {"type": "TextField", "props": {}}}, {}, "id")
test_re = DummyResource(
"testre", {"id": comp("NumberField"), "value": comp("TextField")}, {}, "id")

app = web.Application()
schema: aiohttp_admin.Schema = {"security": {"check_credentials": check_credentials},
Expand Down Expand Up @@ -90,10 +93,9 @@ def test_display() -> None:
app = web.Application()
model = DummyResource(
"test",
{"id": {"type": "TextField", "props": {}}, "foo": {"type": "TextField", "props": {}}},
{"id": {"type": "TextInput", "props": {}, "show_create": False,
"validators": (("required",),)},
"foo": {"type": "TextInput", "props": {}, "show_create": True, "validators": ()}},
{"id": comp("TextField"), "foo": comp("TextField")},
{"id": comp("TextInput", {"validate": (func("required", ()),)}) | {"show_create": False}, # type: ignore[dict-item]
"foo": comp("TextInput") | {"show_create": True}}, # type: ignore[dict-item]
"id")
schema: aiohttp_admin.Schema = {"security": {"check_credentials": check_credentials},
"resources": ({"model": model, "display": ("foo",)},)}
Expand All @@ -102,14 +104,13 @@ def test_display() -> None:

test_state = admin["state"]["resources"]["test"]
assert test_state["list_omit"] == ("id",)
assert test_state["inputs"]["id"]["props"] == {}
assert test_state["inputs"]["id"]["props"] == {"validate": (func("required", ()),)}
assert test_state["inputs"]["foo"]["props"] == {"alwaysOn": "alwaysOn"}


def test_display_invalid() -> None:
app = web.Application()
model = DummyResource("test", {"id": {"type": "TextField", "props": {}},
"foo": {"type": "TextField", "props": {}}}, {}, "id")
model = DummyResource("test", {"id": comp("TextField"), "foo": comp("TextField")}, {}, "id")
schema: aiohttp_admin.Schema = {"security": {"check_credentials": check_credentials},
"resources": ({"model": model, "display": ("bar",)},)}

Expand All @@ -121,9 +122,8 @@ def test_extra_props() -> None:
app = web.Application()
model = DummyResource(
"test",
{"id": {"type": "TextField", "props": {"textAlign": "right", "placeholder": "foo"}}},
{"id": {"type": "TextInput", "props": {"resettable": False, "type": "text"},
"show_create": False, "validators": ()}},
{"id": comp("TextField", {"textAlign": "right", "placeholder": "foo"})},
{"id": comp("TextInput", {"resettable": False, "type": "text"}) | {"show_create": False}}, # type: ignore[dict-item]
"id")
schema: aiohttp_admin.Schema = {
"security": {"check_credentials": check_credentials},
Expand All @@ -144,8 +144,7 @@ def test_extra_props() -> None:

def test_invalid_repr() -> None:
app = web.Application()
model = DummyResource("test", {"id": {"type": "TextField", "props": {}},
"foo": {"type": "TextField", "props": {}}}, {}, "id")
model = DummyResource("test", {"id": comp("TextField"), "foo": comp("TextField")}, {}, "id")
schema: aiohttp_admin.Schema = {"security": {"check_credentials": check_credentials},
"resources": ({"model": model, "repr": "bar"},)}

Expand Down
Loading

0 comments on commit 727756c

Please sign in to comment.