From 88597338122ad8aeba83727189ddf286040ed875 Mon Sep 17 00:00:00 2001 From: David Maxson Date: Mon, 3 Jun 2024 11:17:18 -0700 Subject: [PATCH] Full code coverage --- src/pynamodb_single_table/__main__.py | 13 ---- src/pynamodb_single_table/base.py | 46 +++++++++----- tests/test_basics.py | 86 +++++++++++++++++++++++++++ 3 files changed, 117 insertions(+), 28 deletions(-) delete mode 100644 src/pynamodb_single_table/__main__.py diff --git a/src/pynamodb_single_table/__main__.py b/src/pynamodb_single_table/__main__.py deleted file mode 100644 index 2017c27..0000000 --- a/src/pynamodb_single_table/__main__.py +++ /dev/null @@ -1,13 +0,0 @@ -"""Command-line interface.""" - -import click - - -@click.command() -@click.version_option() -def main() -> None: - """PynamoDB Single Table.""" - - -if __name__ == "__main__": - main(prog_name="pynamodb_single_table") # pragma: no cover diff --git a/src/pynamodb_single_table/base.py b/src/pynamodb_single_table/base.py index 141cd8a..b14c6f8 100644 --- a/src/pynamodb_single_table/base.py +++ b/src/pynamodb_single_table/base.py @@ -7,8 +7,8 @@ from typing import Type from pydantic import BaseModel -from pydantic import PrivateAttr from pydantic import computed_field +from pydantic.fields import ModelPrivateAttr from pynamodb.attributes import JSONAttribute from pynamodb.attributes import UnicodeAttribute from pynamodb.attributes import VersionAttribute @@ -39,7 +39,7 @@ class RootModelPrototype(Model): class SingleTableBaseModel(BaseModel): - _PynamodbMeta: MetaProtocol = PrivateAttr() + _PynamodbMeta: MetaProtocol = None __pynamodb_model__: Type[RootModelPrototype] = None uid: Optional[uuid.UUID] = None @@ -50,7 +50,7 @@ def __init_subclass__(cls, **kwargs): cls.__pynamodb_model__, RootModelPrototype ) # TODO: Just duck type? else: - if not cls._PynamodbMeta: + if isinstance(cls._PynamodbMeta, ModelPrivateAttr): raise TypeError(f"Must define the PynamoDB metadata for {cls}") class RootModel(RootModelPrototype): @@ -61,12 +61,12 @@ class RootModel(RootModelPrototype): if isabstract(cls) or abc.ABC in cls.__bases__: return super().__init_subclass__(**kwargs) - if not cls.__table_name__: + if not getattr(cls, "__table_name__", None): raise TypeError( f"Must define the table name for {cls} (the inner table, not the pynamodb table)" ) - if not cls.__str_id_field__: + if not getattr(cls, "__str_id_field__", None): raise TypeError(f"Must define the string ID field for {cls}") @computed_field @@ -110,15 +110,19 @@ def get_by_str(cls, str_id: str) -> Self: raise cls.MultipleObjectsFound() uuid_ = results[0].uid + return cls.get_by_uid(uuid_) + + @classmethod + def get_by_uid(cls, uuid_: uuid.UUID) -> Self: try: - return cls.get_by_uid(uuid_) + response = cls.__pynamodb_model__.get(cls.__table_name__, uuid_) except cls.__pynamodb_model__.DoesNotExist as e: raise cls.DoesNotExist() from e + return cls._from_item(response) @classmethod - def get_by_uid(cls, uuid_: uuid.UUID) -> Self: - response = cls.__pynamodb_model__.get(cls.__table_name__, uuid_) - return cls.model_validate(dict(uid=response.uid, **response.data)) + def _from_item(cls, item) -> Self: + return cls(uid=item.uid, **item.data) def create(self): item = self.__pynamodb_model__( @@ -126,10 +130,12 @@ def create(self): str_id=self.str_id, data=self.model_dump(mode="json", exclude={"uid", "str_id"}), ) - condition = self.__pynamodb_model__._hash_key_attribute().does_not_exist() - rk_attr = self.__pynamodb_model__._range_key_attribute() - if rk_attr: - condition &= rk_attr.does_not_exist() + if self.uid is not None: + item.uid = self.uid + condition = ( + self.__pynamodb_model__.table_name.does_not_exist() + & self.__pynamodb_model__.uid.does_not_exist() + ) item.save(condition=condition, add_version_condition=False) assert item.uid is not None self.uid = item.uid @@ -152,8 +158,18 @@ def count(cls, *args, **kwargs): @classmethod def query(cls, *args, **kwargs): - return cls.__pynamodb_model__.query(cls.__table_name__, *args, **kwargs) + return ( + cls._from_item(item) + for item in cls.__pynamodb_model__.query( + cls.__table_name__, *args, **kwargs + ) + ) @classmethod def scan(cls, *args, **kwargs): - return cls.__pynamodb_model__.scan(cls.__table_name__, *args, **kwargs) + return ( + cls._from_item(item) + for item in cls.__pynamodb_model__.scan( + cls.__pynamodb_model__.table_name == cls.__table_name__, *args, **kwargs + ) + ) diff --git a/tests/test_basics.py b/tests/test_basics.py index 3d1c170..26ebd64 100644 --- a/tests/test_basics.py +++ b/tests/test_basics.py @@ -73,3 +73,89 @@ def test_duplicate_creation(): assert not user2_was_created assert user1.uid == user2.uid assert user1.group_id == user2.group_id + + +def test_error_no_metadata(): + with pytest.raises(TypeError): + + class BadBaseTableModel(SingleTableBaseModel): + pass + + with pytest.raises(TypeError): + + class BadSingleTableNoTableName(_BaseTableModel): + # Missing __table_name__ + pass + + with pytest.raises(TypeError): + + class BadSingleTableEmptyTableName(_BaseTableModel): + # Empty __table_name__ + __table_name__ = "" + + with pytest.raises(TypeError): + + class BadSingleTableNoStrIdField(_BaseTableModel): + __table_name__ = "table" + # Missing __str_id_field__ + + with pytest.raises(TypeError): + + class BadSingleTableEmptyStrIdField(_BaseTableModel): + # Empty __str_id_field__ + __table_name__ = "table" + __str_id_field__ = "" + + +def test_preexisting_uid(): + uid = uuid.uuid4() + user, was_created = User.get_or_create(name="Joe Shmoe", uid=uid) + assert user.uid == uid + assert was_created + + user, was_created = User.get_or_create(name="John Doe", uid=uid) + assert user.uid == uid + assert not was_created + assert user.name == "Joe Shmoe" + + +def test_preexisting_str_id(): + user, was_created = User.get_or_create(name="Joe Shmoe") + assert was_created + + user, was_created = User.get_or_create(name="Joe Shmoe") + assert not was_created + assert user.name == "Joe Shmoe" + + +def test_duplicate_str_id(): + user1 = User(name="Username", uid=uuid.uuid4()) + user2 = User(name="Username", uid=uuid.uuid4()) + user1.save() + user2.save() + with pytest.raises(User.MultipleObjectsFound): + User.get_by_str("Username") + + +def test_query(): + user1, _ = User.get_or_create(name="John Doe") + user2, _ = User.get_or_create(name="Joe Schmoe") + + group1, _ = Group.get_or_create(name="Admins") + + users = list(User.query()) + assert len(users) == 2 + groups = list(Group.query()) + assert len(groups) == 1 + + +def test_scan(): + user1, _ = User.get_or_create(name="John Doe") + user2, _ = User.get_or_create(name="Joe Schmoe") + + group1, _ = Group.get_or_create(name="Admins") + + users = list(User.scan()) + assert len(users) == 2 + groups = list(Group.scan()) + assert len(groups) == 1