Skip to content

Commit

Permalink
Full code coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
scnerd committed Jun 3, 2024
1 parent 7603d99 commit 8859733
Show file tree
Hide file tree
Showing 3 changed files with 117 additions and 28 deletions.
13 changes: 0 additions & 13 deletions src/pynamodb_single_table/__main__.py

This file was deleted.

46 changes: 31 additions & 15 deletions src/pynamodb_single_table/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
from typing import Type

from pydantic import BaseModel
from pydantic import PrivateAttr
from pydantic import computed_field
from pydantic.fields import ModelPrivateAttr
from pynamodb.attributes import JSONAttribute
from pynamodb.attributes import UnicodeAttribute
from pynamodb.attributes import VersionAttribute
Expand Down Expand Up @@ -39,7 +39,7 @@ class RootModelPrototype(Model):


class SingleTableBaseModel(BaseModel):
_PynamodbMeta: MetaProtocol = PrivateAttr()
_PynamodbMeta: MetaProtocol = None
__pynamodb_model__: Type[RootModelPrototype] = None

uid: Optional[uuid.UUID] = None
Expand All @@ -50,7 +50,7 @@ def __init_subclass__(cls, **kwargs):
cls.__pynamodb_model__, RootModelPrototype
) # TODO: Just duck type?
else:
if not cls._PynamodbMeta:
if isinstance(cls._PynamodbMeta, ModelPrivateAttr):
raise TypeError(f"Must define the PynamoDB metadata for {cls}")

class RootModel(RootModelPrototype):
Expand All @@ -61,12 +61,12 @@ class RootModel(RootModelPrototype):
if isabstract(cls) or abc.ABC in cls.__bases__:
return super().__init_subclass__(**kwargs)

if not cls.__table_name__:
if not getattr(cls, "__table_name__", None):
raise TypeError(
f"Must define the table name for {cls} (the inner table, not the pynamodb table)"
)

if not cls.__str_id_field__:
if not getattr(cls, "__str_id_field__", None):
raise TypeError(f"Must define the string ID field for {cls}")

@computed_field
Expand Down Expand Up @@ -110,26 +110,32 @@ def get_by_str(cls, str_id: str) -> Self:
raise cls.MultipleObjectsFound()
uuid_ = results[0].uid

return cls.get_by_uid(uuid_)

@classmethod
def get_by_uid(cls, uuid_: uuid.UUID) -> Self:
try:
return cls.get_by_uid(uuid_)
response = cls.__pynamodb_model__.get(cls.__table_name__, uuid_)
except cls.__pynamodb_model__.DoesNotExist as e:
raise cls.DoesNotExist() from e
return cls._from_item(response)

@classmethod
def get_by_uid(cls, uuid_: uuid.UUID) -> Self:
response = cls.__pynamodb_model__.get(cls.__table_name__, uuid_)
return cls.model_validate(dict(uid=response.uid, **response.data))
def _from_item(cls, item) -> Self:
return cls(uid=item.uid, **item.data)

def create(self):
item = self.__pynamodb_model__(
self.__table_name__,
str_id=self.str_id,
data=self.model_dump(mode="json", exclude={"uid", "str_id"}),
)
condition = self.__pynamodb_model__._hash_key_attribute().does_not_exist()
rk_attr = self.__pynamodb_model__._range_key_attribute()
if rk_attr:
condition &= rk_attr.does_not_exist()
if self.uid is not None:
item.uid = self.uid
condition = (
self.__pynamodb_model__.table_name.does_not_exist()
& self.__pynamodb_model__.uid.does_not_exist()
)
item.save(condition=condition, add_version_condition=False)
assert item.uid is not None
self.uid = item.uid
Expand All @@ -152,8 +158,18 @@ def count(cls, *args, **kwargs):

@classmethod
def query(cls, *args, **kwargs):
return cls.__pynamodb_model__.query(cls.__table_name__, *args, **kwargs)
return (
cls._from_item(item)
for item in cls.__pynamodb_model__.query(
cls.__table_name__, *args, **kwargs
)
)

@classmethod
def scan(cls, *args, **kwargs):
return cls.__pynamodb_model__.scan(cls.__table_name__, *args, **kwargs)
return (
cls._from_item(item)
for item in cls.__pynamodb_model__.scan(
cls.__pynamodb_model__.table_name == cls.__table_name__, *args, **kwargs
)
)
86 changes: 86 additions & 0 deletions tests/test_basics.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,3 +73,89 @@ def test_duplicate_creation():
assert not user2_was_created
assert user1.uid == user2.uid
assert user1.group_id == user2.group_id


def test_error_no_metadata():
with pytest.raises(TypeError):

class BadBaseTableModel(SingleTableBaseModel):
pass

with pytest.raises(TypeError):

class BadSingleTableNoTableName(_BaseTableModel):
# Missing __table_name__
pass

with pytest.raises(TypeError):

class BadSingleTableEmptyTableName(_BaseTableModel):
# Empty __table_name__
__table_name__ = ""

with pytest.raises(TypeError):

class BadSingleTableNoStrIdField(_BaseTableModel):
__table_name__ = "table"
# Missing __str_id_field__

with pytest.raises(TypeError):

class BadSingleTableEmptyStrIdField(_BaseTableModel):
# Empty __str_id_field__
__table_name__ = "table"
__str_id_field__ = ""


def test_preexisting_uid():
uid = uuid.uuid4()
user, was_created = User.get_or_create(name="Joe Shmoe", uid=uid)
assert user.uid == uid
assert was_created

user, was_created = User.get_or_create(name="John Doe", uid=uid)
assert user.uid == uid
assert not was_created
assert user.name == "Joe Shmoe"


def test_preexisting_str_id():
user, was_created = User.get_or_create(name="Joe Shmoe")
assert was_created

user, was_created = User.get_or_create(name="Joe Shmoe")
assert not was_created
assert user.name == "Joe Shmoe"


def test_duplicate_str_id():
user1 = User(name="Username", uid=uuid.uuid4())
user2 = User(name="Username", uid=uuid.uuid4())
user1.save()
user2.save()
with pytest.raises(User.MultipleObjectsFound):
User.get_by_str("Username")


def test_query():
user1, _ = User.get_or_create(name="John Doe")
user2, _ = User.get_or_create(name="Joe Schmoe")

group1, _ = Group.get_or_create(name="Admins")

users = list(User.query())
assert len(users) == 2
groups = list(Group.query())
assert len(groups) == 1


def test_scan():
user1, _ = User.get_or_create(name="John Doe")
user2, _ = User.get_or_create(name="Joe Schmoe")

group1, _ = Group.get_or_create(name="Admins")

users = list(User.scan())
assert len(users) == 2
groups = list(Group.scan())
assert len(groups) == 1

0 comments on commit 8859733

Please sign in to comment.