From edfcd8cfdf395593dfc7dc1904e3b81611b7550f Mon Sep 17 00:00:00 2001 From: David Maxson Date: Wed, 22 May 2024 12:17:03 -0700 Subject: [PATCH] Update to extract class definitions from method, making intellisense actually work --- src/pynamodb_single_table/base.py | 252 ++++++++++++++---------------- tests/test_basics.py | 62 +++++--- 2 files changed, 155 insertions(+), 159 deletions(-) diff --git a/src/pynamodb_single_table/base.py b/src/pynamodb_single_table/base.py index cc44312..aa9278a 100644 --- a/src/pynamodb_single_table/base.py +++ b/src/pynamodb_single_table/base.py @@ -1,170 +1,156 @@ import abc import itertools import uuid +from inspect import isabstract from pydantic import BaseModel -from pydantic import ConfigDict +from pydantic import PrivateAttr from pydantic import computed_field from pynamodb.attributes import JSONAttribute from pynamodb.attributes import UnicodeAttribute from pynamodb.attributes import VersionAttribute from pynamodb.indexes import KeysOnlyProjection from pynamodb.indexes import LocalSecondaryIndex +from pynamodb.models import MetaProtocol from pynamodb.models import Model from pynamodb_attributes import UUIDAttribute from typing_extensions import Self -def build_base_model(table_name, host=None) -> type[BaseModel]: - """Builds a DynamoDB table model for storing a single-table database design. +class StrIndex(LocalSecondaryIndex): + class Meta: + projection = KeysOnlyProjection() - This table uses: - - The concrete table name as a hash key - - A UUID as the row-level primary key - - A string representation for each row that can be used as a secondary primary key + table_name = UnicodeAttribute(hash_key=True) + str_id = UnicodeAttribute(range_key=True) + uid = UUIDAttribute(null=False) - Parameters: - - table_name (str): The name of the DynamoDB table. - - host (str, optional): The host URL of the DynamoDB instance. If not specified, - the default AWS endpoint is used. - Returns: - - type[BaseModel]: A Pydantic base model for creating concrete tables - """ +class RootModelPrototype(Model): + table_name = UnicodeAttribute(hash_key=True) + uid = UUIDAttribute(range_key=True, default_for_new=uuid.uuid4) + str_id = UnicodeAttribute(null=False) + index_by_str = StrIndex() + version = VersionAttribute(null=False, default=1) + data = JSONAttribute(null=False) - class StrIndex(LocalSecondaryIndex): - class Meta: - projection = KeysOnlyProjection() - table_name = UnicodeAttribute(hash_key=True) - str_id = UnicodeAttribute(range_key=True) - uid = UUIDAttribute(null=False) +class SingleTableBaseModel(BaseModel): + _PynamodbMeta: MetaProtocol = PrivateAttr() + __pynamodb_model__: type[RootModelPrototype] = None - class RootMeta: - pass + uid: uuid.UUID | None = None - RootMeta.table_name = table_name - RootMeta.host = host + def __init_subclass__(cls, **kwargs): + if cls.__pynamodb_model__: + assert issubclass( + cls.__pynamodb_model__, RootModelPrototype + ) # TODO: Just duck type? + else: + if not cls._PynamodbMeta: + raise TypeError(f"Must define the PynamoDB metadata for {cls}") - class RootModel(Model): - Meta = RootMeta + class RootModel(RootModelPrototype): + Meta = cls._PynamodbMeta - table_name = UnicodeAttribute(hash_key=True) - uid = UUIDAttribute(range_key=True, default_for_new=uuid.uuid4) - str_id = UnicodeAttribute(null=False) - index_by_str = StrIndex() - version = VersionAttribute(null=False, default=1) - data = JSONAttribute(null=False) + cls.__pynamodb_model__ = RootModel - class BaseTableModel(BaseModel, abc.ABC): - __pynamodb_model__: type[RootModel] = RootModel - __table_name__ = None - __str_id__ = None - model_config = ConfigDict(from_attributes=True) + if isabstract(cls) or abc.ABC in cls.__bases__: + return super().__init_subclass__(**kwargs) - uid: uuid.UUID | None = None + if not cls.__table_name__: + raise TypeError( + f"Must define the table name for {cls} (the inner table, not the pynamodb table)" + ) - @computed_field - @property - def str_id(self) -> str: - return getattr(self, self.__str_id__) + if not cls.__str_id_field__: + raise TypeError(f"Must define the string ID field for {cls}") - class DoesNotExist(Exception): - pass + @computed_field + @property + def str_id(self) -> str: + return getattr(self, self.__str_id_field__) - class MultipleObjectsFound(Exception): - pass + class DoesNotExist(Exception): + pass - @classmethod - def __init_subclass__(cls, **kwargs): - super().__init_subclass__(**kwargs) - if cls.__table_name__ is None: - raise ValueError( - f"Must provide table name ({cls.__name__}.__table_name__) " - f"when subclassing BaseTableModel" - ) - if cls.__str_id__ is None: - raise ValueError( - f"Must provide string ID ({cls.__name__}.__str_id__) " - f"when subclassing BaseTableModel" - ) - - @classmethod - def get_or_create(cls, **kwargs) -> tuple[Self, bool]: - obj = cls.model_validate(kwargs) - - if obj.uid is not None: - try: - return cls.get_by_uid(obj.uid), False - except cls.DoesNotExist: - pass + class MultipleObjectsFound(Exception): + pass + + @classmethod + def get_or_create(cls, **kwargs) -> tuple[Self, bool]: + obj = cls.model_validate(kwargs) + if obj.uid is not None: try: - return cls.get_by_str(obj.str_id), False + return cls.get_by_uid(obj.uid), False except cls.DoesNotExist: pass - obj.create() - return obj, True - - @classmethod - def get_by_str(cls, str_id: str) -> Self: - results = cls.__pynamodb_model__.index_by_str.query( - cls.__table_name__, cls.__pynamodb_model__.str_id == str_id - ) - results = list(itertools.islice(results, 2)) - if len(results) == 0: - raise cls.DoesNotExist() - if len(results) > 1: - raise cls.MultipleObjectsFound() - uuid_ = results[0].uid - - try: - return cls.get_by_uid(uuid_) - except cls.__pynamodb_model__.DoesNotExist as e: - raise cls.DoesNotExist() from e - - @classmethod - def get_by_uid(cls, uuid_: uuid.UUID) -> Self: - response = cls.__pynamodb_model__.get(cls.__table_name__, uuid_) - return cls.model_validate(dict(uid=response.uid, **response.data)) - - def create(self): - item = self.__pynamodb_model__( - self.__table_name__, - str_id=self.str_id, - data=self.model_dump(mode="json", exclude={"uid", "str_id"}), - ) - condition = self.__pynamodb_model__._hash_key_attribute().does_not_exist() - rk_attr = self.__pynamodb_model__._range_key_attribute() - if rk_attr: - condition &= rk_attr.does_not_exist() - item.save(condition=condition, add_version_condition=False) - assert item.uid is not None - self.uid = item.uid - return self - - def save(self): - item = self.__pynamodb_model__( - self.__table_name__, - uid=self.uid, - str_id=self.str_id, - data=self.model_dump(mode="json", exclude={"uid", "str_id"}), - ) - item.save(add_version_condition=False) - assert item.uid is not None - self.uid = item.uid - - @classmethod - def count(cls, *args, **kwargs): - return cls.__pynamodb_model__.count(cls.__table_name__, *args, **kwargs) - - @classmethod - def query(cls, *args, **kwargs): - return cls.__pynamodb_model__.query(cls.__table_name__, *args, **kwargs) - - @classmethod - def scan(cls, *args, **kwargs): - return cls.__pynamodb_model__.scan(cls.__table_name__, *args, **kwargs) + try: + return cls.get_by_str(obj.str_id), False + except cls.DoesNotExist: + pass - return BaseTableModel + obj.create() + return obj, True + + @classmethod + def get_by_str(cls, str_id: str) -> Self: + results = cls.__pynamodb_model__.index_by_str.query( + cls.__table_name__, cls.__pynamodb_model__.str_id == str_id + ) + results = list(itertools.islice(results, 2)) + if len(results) == 0: + raise cls.DoesNotExist() + if len(results) > 1: + raise cls.MultipleObjectsFound() + uuid_ = results[0].uid + + try: + return cls.get_by_uid(uuid_) + except cls.__pynamodb_model__.DoesNotExist as e: + raise cls.DoesNotExist() from e + + @classmethod + def get_by_uid(cls, uuid_: uuid.UUID) -> Self: + response = cls.__pynamodb_model__.get(cls.__table_name__, uuid_) + return cls.model_validate(dict(uid=response.uid, **response.data)) + + def create(self): + item = self.__pynamodb_model__( + self.__table_name__, + str_id=self.str_id, + data=self.model_dump(mode="json", exclude={"uid", "str_id"}), + ) + condition = self.__pynamodb_model__._hash_key_attribute().does_not_exist() + rk_attr = self.__pynamodb_model__._range_key_attribute() + if rk_attr: + condition &= rk_attr.does_not_exist() + item.save(condition=condition, add_version_condition=False) + assert item.uid is not None + self.uid = item.uid + return self + + def save(self): + item = self.__pynamodb_model__( + self.__table_name__, + uid=self.uid, + str_id=self.str_id, + data=self.model_dump(mode="json", exclude={"uid", "str_id"}), + ) + item.save(add_version_condition=False) + assert item.uid is not None + self.uid = item.uid + + @classmethod + def count(cls, *args, **kwargs): + return cls.__pynamodb_model__.count(cls.__table_name__, *args, **kwargs) + + @classmethod + def query(cls, *args, **kwargs): + return cls.__pynamodb_model__.query(cls.__table_name__, *args, **kwargs) + + @classmethod + def scan(cls, *args, **kwargs): + return cls.__pynamodb_model__.scan(cls.__table_name__, *args, **kwargs) diff --git a/tests/test_basics.py b/tests/test_basics.py index 785ccd8..1214897 100644 --- a/tests/test_basics.py +++ b/tests/test_basics.py @@ -1,16 +1,32 @@ +import abc import uuid import pytest -from pynamodb_single_table.base import build_base_model +from pynamodb_single_table.base import SingleTableBaseModel -@pytest.fixture(scope="module") -def BaseTableModel(): - _BaseTableModel = build_base_model( - table_name="PSTTestRoot", - host="http://localhost:8000", - ) +class _BaseTableModel(SingleTableBaseModel, abc.ABC): + class _PynamodbMeta: + table_name = "PSTTestRoot" + host = "http://localhost:8000" + + +class User(_BaseTableModel): + __table_name__ = "user" + __str_id_field__ = "name" + name: str + group_id: uuid.UUID | None = None + + +class Group(_BaseTableModel): + __table_name__ = "group" + __str_id_field__ = "name" + name: str + + +@pytest.fixture(scope="function", autouse=True) +def recreate_pynamodb_table() -> type[SingleTableBaseModel]: _BaseTableModel.__pynamodb_model__.create_table( wait=True, billing_mode="PAY_PER_REQUEST" ) @@ -21,25 +37,7 @@ def BaseTableModel(): _BaseTableModel.__pynamodb_model__.delete_table() -@pytest.fixture(scope="module") -def user_group_models(BaseTableModel): - class User(BaseTableModel): - __table_name__ = "user" - __str_id__ = "name" - name: str - group_id: uuid.UUID | None = None - - class Group(BaseTableModel): - __table_name__ = "group" - __str_id__ = "name" - name: str - - return User, Group - - -def test_basic_interface(user_group_models): - User, Group = user_group_models - +def test_basic_interface(): group, was_created = Group.get_or_create(name="Admins") user, was_created = User.get_or_create( @@ -56,3 +54,15 @@ def test_basic_interface(user_group_models): # Check that we have exactly one user and one group assert Group.count() == 1, list(Group.scan()) assert User.count() == 1, list(User.scan()) + + +def test_duplicate_creation(): + group, _ = Group.get_or_create(name="Admins") + + user1, user1_was_created = User.get_or_create(name="Joe Shmoe", group_id=group.uid) + user2, user2_was_created = User.get_or_create(name="Joe Shmoe") + + assert user1_was_created + assert not user2_was_created + assert user1.uid == user2.uid + assert user1.group_id == user2.group_id