diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index ee51b2a..e1ee6a0 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -38,7 +38,8 @@ repos: (?x)^( docs/.*| tests/.*| - conftest.py + conftest.py| + aiida_restapi/models/json_api.py )$ - repo: https://github.com/PyCQA/pylint @@ -51,6 +52,7 @@ repos: - fastapi~=0.65.1 - uvicorn[standard]>=0.12.0,<0.14.0 - pydantic~=1.8.2 + - pydantic-jsonapi==0.11.0 - python-jose - python-multipart - passlib diff --git a/aiida_restapi/models/__init__.py b/aiida_restapi/models/__init__.py new file mode 100644 index 0000000..b9f1cad --- /dev/null +++ b/aiida_restapi/models/__init__.py @@ -0,0 +1,7 @@ +# -*- coding: utf-8 -*- +"""pydantic models for AiiDA REST API. +""" +# pylint: disable=too-few-public-methods + +from .entities import * +from .responses import * \ No newline at end of file diff --git a/aiida_restapi/models.py b/aiida_restapi/models/entities.py similarity index 97% rename from aiida_restapi/models.py rename to aiida_restapi/models/entities.py index d3da7c1..426639f 100644 --- a/aiida_restapi/models.py +++ b/aiida_restapi/models/entities.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -"""Schemas for AiiDA REST API. +"""ORM entity schemas for AiiDA REST API. Models in this module mirror those in `aiida.backends.djsite.db.models` and `aiida.backends.sqlalchemy.models` @@ -12,6 +12,8 @@ from aiida import orm from pydantic import BaseModel, Field +__all__ = ('AiidaModel', 'Comment', 'User') + # Template type for subclasses of `AiidaModel` ModelType = TypeVar("ModelType", bound="AiidaModel") @@ -94,3 +96,4 @@ class User(AiidaModel): institution: Optional[str] = Field( description="Host institution or workplace of the user" ) + diff --git a/aiida_restapi/models/json_api.py b/aiida_restapi/models/json_api.py new file mode 100644 index 0000000..e1f05e0 --- /dev/null +++ b/aiida_restapi/models/json_api.py @@ -0,0 +1,65 @@ +# -*- coding: utf-8 -*- +"""Adapted JSON API models. + +Adapting pydantic_jsonapi.response to fit the non-compliant AiiDA REST API responses. +Changes made as comments. +""" +# pylint: disable=missing-class-docstring,redefined-builtin,missing-function-docstring,too-few-public-methods +from typing import Generic, Optional, TypeVar, get_type_hints + +from pydantic.generics import GenericModel +from pydantic_jsonapi.filter import filter_none +from pydantic_jsonapi.relationships import ResponseRelationshipsType +from pydantic_jsonapi.resource_links import ResourceLinks + +# TypeT = TypeVar('TypeT', bound=str) +# AttributesT = TypeVar("AttributesT") + + +# class ResponseDataModel(GenericModel): + +# id: str +# relationships: Optional[ResponseRelationshipsType] +# links: Optional[ResourceLinks] + +# class Config: +# validate_all = True +# extra = "allow" # added + + +DataT = TypeVar("DataT") # , bound=ResponseDataModel) + + +class ResponseModel(GenericModel, Generic[DataT]): + + data: DataT + included: Optional[list] + meta: Optional[dict] + links: Optional[ResourceLinks] + + def dict(self, *, serlialize_none: bool = False, **kwargs): + response = super().dict(**kwargs) + if serlialize_none: + return response + return filter_none(response) + + @classmethod + def resource_object( + cls, + *, + id: str, + attributes: Optional[dict] = None, + relationships: Optional[dict] = None, + links: Optional[dict] = None, + ) -> DataT: + data_type = get_type_hints(cls)["data"] + if getattr(data_type, "__origin__", None) is list: + data_type = data_type.__args__[0] + typename = get_type_hints(data_type)["type"].__args__[0] + return data_type( + id=id, + type=typename, + attributes=attributes or {}, + relationships=relationships, + links=links, + ) diff --git a/aiida_restapi/models/responses.py b/aiida_restapi/models/responses.py new file mode 100644 index 0000000..13a8f9e --- /dev/null +++ b/aiida_restapi/models/responses.py @@ -0,0 +1,38 @@ +# -*- coding: utf-8 -*- +"""Response schemas for AiiDA REST API. + +Builds upon response schemas from `json_api` module. +""" + +from typing import List, Type, TypeVar + +from pydantic_jsonapi import ErrorResponse + +from . import json_api +from .entities import AiidaModel # pylint: disable=unused-import + +__all__ = ("EntityResponse", "ErrorResponse") + +ModelType = TypeVar("ModelType", bound="AiidaModel") + + +def EntityResponse( + # type_string: str, + attributes_model: ModelType, + *, + use_list: bool = False, +) -> Type[json_api.ResponseModel]: + """Returns entity-specif pydantic response model.""" + response_data_model = attributes_model + type_string = ( + attributes_model._orm_entity.__name__.lower() # pylint: disable=protected-access + ) + if use_list: + response_data_model.__name__ = f"ListResponseData[{type_string}]" + response_model = json_api.ResponseModel[List[response_data_model]] + response_model.__name__ = f"ListResponse[{type_string}]" + else: + response_data_model.__name__ = f"ResponseData[{type_string}]" + response_model = json_api.ResponseModel[response_data_model] + response_model.__name__ = f"Response[{type_string}]" + return response_model diff --git a/aiida_restapi/routers/users.py b/aiida_restapi/routers/users.py index 624b38e..85e7810 100644 --- a/aiida_restapi/routers/users.py +++ b/aiida_restapi/routers/users.py @@ -1,44 +1,50 @@ # -*- coding: utf-8 -*- """Declaration of FastAPI application.""" -from typing import List, Optional - from aiida import orm from aiida.cmdline.utils.decorators import with_dbenv +from aiida.common.exceptions import NotExistent from fastapi import APIRouter, Depends +from fastapi.exceptions import HTTPException -from aiida_restapi.models import User +from aiida_restapi.models import EntityResponse, User from .auth import get_current_active_user +__all__ = ("router",) + router = APIRouter() +SingleUserResponse = EntityResponse(User) +ManyUserResponse = EntityResponse(User, use_list=True) + -@router.get("/users", response_model=List[User]) +@router.get("/users", response_model=ManyUserResponse) @with_dbenv() -async def read_users() -> List[User]: +async def read_users() -> ManyUserResponse: """Get list of all users""" - return [User.from_orm(u) for u in orm.User.objects.find()] + return ManyUserResponse(data=User.get_entities()) -@router.get("/users/{user_id}", response_model=User) +@router.get("/users/{user_id}", response_model=SingleUserResponse) @with_dbenv() -async def read_user(user_id: int) -> Optional[User]: +async def read_user(user_id: int) -> SingleUserResponse: """Get user by id.""" - orm_user = orm.User.objects.get(id=user_id) + try: + orm_user = orm.User.objects.get(id=user_id) + except NotExistent as exc: + raise HTTPException(status_code=404, detail="User not found") from exc - if orm_user: - return User.from_orm(orm_user) + return SingleUserResponse(user=User.from_orm(orm_user)) - return None - -@router.post("/users", response_model=User) +@router.post("/users", response_model=SingleUserResponse) +@with_dbenv() async def create_user( user: User, current_user: User = Depends( get_current_active_user ), # pylint: disable=unused-argument -) -> User: +) -> SingleUserResponse: """Create new AiiDA user.""" orm_user = orm.User(**user.dict(exclude_unset=True)).store() - return User.from_orm(orm_user) + return SingleUserResponse(data=User.from_orm(orm_user)) diff --git a/setup.json b/setup.json index 84c6947..9cfe5ec 100644 --- a/setup.json +++ b/setup.json @@ -30,7 +30,8 @@ "sqlalchemy<1.4", "fastapi~=0.65.1", "uvicorn[standard]>=0.12.0,<0.14.0", - "pydantic~=1.8.2" + "pydantic~=1.8.2", + "pydantic-jsonapi==0.11.0" ], "extras_require": { "testing": [ diff --git a/tests/test_users.py b/tests/test_users.py index 56a333f..349948a 100644 --- a/tests/test_users.py +++ b/tests/test_users.py @@ -8,7 +8,7 @@ def test_get_single_user(default_users, client): # pylint: disable=unused-argum """Test retrieving a single user.""" response = client.get("/users/1") assert response.status_code == 200 - assert response.json()["first_name"] == "Giuseppe" + assert response.json()["data"]["first_name"] == "Giuseppe" def test_get_users(default_users, client): # pylint: disable=unused-argument @@ -19,7 +19,7 @@ def test_get_users(default_users, client): # pylint: disable=unused-argument """ response = client.get("/users") assert response.status_code == 200 - assert len(response.json()) == 2 + 1 + assert len(response.json()["data"]) == 2 + 1 def test_create_user(client, authenticate): # pylint: disable=unused-argument @@ -30,5 +30,5 @@ def test_create_user(client, authenticate): # pylint: disable=unused-argument assert response.status_code == 200, response.content response = client.get("/users") - first_names = [user["first_name"] for user in response.json()] + first_names = [user["first_name"] for user in response.json()["data"]] assert "New" in first_names