Skip to content

Commit

Permalink
Update to extract class definitions from method, making intellisense …
Browse files Browse the repository at this point in the history
…actually work
  • Loading branch information
scnerd committed May 22, 2024
1 parent 4ff9776 commit edfcd8c
Show file tree
Hide file tree
Showing 2 changed files with 155 additions and 159 deletions.
252 changes: 119 additions & 133 deletions src/pynamodb_single_table/base.py
Original file line number Diff line number Diff line change
@@ -1,170 +1,156 @@
import abc
import itertools
import uuid
from inspect import isabstract

from pydantic import BaseModel
from pydantic import ConfigDict
from pydantic import PrivateAttr
from pydantic import computed_field
from pynamodb.attributes import JSONAttribute
from pynamodb.attributes import UnicodeAttribute
from pynamodb.attributes import VersionAttribute
from pynamodb.indexes import KeysOnlyProjection
from pynamodb.indexes import LocalSecondaryIndex
from pynamodb.models import MetaProtocol
from pynamodb.models import Model
from pynamodb_attributes import UUIDAttribute
from typing_extensions import Self


def build_base_model(table_name, host=None) -> type[BaseModel]:
"""Builds a DynamoDB table model for storing a single-table database design.
class StrIndex(LocalSecondaryIndex):
class Meta:
projection = KeysOnlyProjection()

This table uses:
- The concrete table name as a hash key
- A UUID as the row-level primary key
- A string representation for each row that can be used as a secondary primary key
table_name = UnicodeAttribute(hash_key=True)
str_id = UnicodeAttribute(range_key=True)
uid = UUIDAttribute(null=False)

Parameters:
- table_name (str): The name of the DynamoDB table.
- host (str, optional): The host URL of the DynamoDB instance. If not specified,
the default AWS endpoint is used.

Returns:
- type[BaseModel]: A Pydantic base model for creating concrete tables
"""
class RootModelPrototype(Model):
table_name = UnicodeAttribute(hash_key=True)
uid = UUIDAttribute(range_key=True, default_for_new=uuid.uuid4)
str_id = UnicodeAttribute(null=False)
index_by_str = StrIndex()
version = VersionAttribute(null=False, default=1)
data = JSONAttribute(null=False)

class StrIndex(LocalSecondaryIndex):
class Meta:
projection = KeysOnlyProjection()

table_name = UnicodeAttribute(hash_key=True)
str_id = UnicodeAttribute(range_key=True)
uid = UUIDAttribute(null=False)
class SingleTableBaseModel(BaseModel):
_PynamodbMeta: MetaProtocol = PrivateAttr()
__pynamodb_model__: type[RootModelPrototype] = None

class RootMeta:
pass
uid: uuid.UUID | None = None

RootMeta.table_name = table_name
RootMeta.host = host
def __init_subclass__(cls, **kwargs):
if cls.__pynamodb_model__:
assert issubclass(
cls.__pynamodb_model__, RootModelPrototype
) # TODO: Just duck type?
else:
if not cls._PynamodbMeta:
raise TypeError(f"Must define the PynamoDB metadata for {cls}")

class RootModel(Model):
Meta = RootMeta
class RootModel(RootModelPrototype):
Meta = cls._PynamodbMeta

table_name = UnicodeAttribute(hash_key=True)
uid = UUIDAttribute(range_key=True, default_for_new=uuid.uuid4)
str_id = UnicodeAttribute(null=False)
index_by_str = StrIndex()
version = VersionAttribute(null=False, default=1)
data = JSONAttribute(null=False)
cls.__pynamodb_model__ = RootModel

class BaseTableModel(BaseModel, abc.ABC):
__pynamodb_model__: type[RootModel] = RootModel
__table_name__ = None
__str_id__ = None
model_config = ConfigDict(from_attributes=True)
if isabstract(cls) or abc.ABC in cls.__bases__:
return super().__init_subclass__(**kwargs)

uid: uuid.UUID | None = None
if not cls.__table_name__:
raise TypeError(
f"Must define the table name for {cls} (the inner table, not the pynamodb table)"
)

@computed_field
@property
def str_id(self) -> str:
return getattr(self, self.__str_id__)
if not cls.__str_id_field__:
raise TypeError(f"Must define the string ID field for {cls}")

class DoesNotExist(Exception):
pass
@computed_field
@property
def str_id(self) -> str:
return getattr(self, self.__str_id_field__)

class MultipleObjectsFound(Exception):
pass
class DoesNotExist(Exception):
pass

@classmethod
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
if cls.__table_name__ is None:
raise ValueError(
f"Must provide table name ({cls.__name__}.__table_name__) "
f"when subclassing BaseTableModel"
)
if cls.__str_id__ is None:
raise ValueError(
f"Must provide string ID ({cls.__name__}.__str_id__) "
f"when subclassing BaseTableModel"
)

@classmethod
def get_or_create(cls, **kwargs) -> tuple[Self, bool]:
obj = cls.model_validate(kwargs)

if obj.uid is not None:
try:
return cls.get_by_uid(obj.uid), False
except cls.DoesNotExist:
pass
class MultipleObjectsFound(Exception):
pass

@classmethod
def get_or_create(cls, **kwargs) -> tuple[Self, bool]:
obj = cls.model_validate(kwargs)

if obj.uid is not None:
try:
return cls.get_by_str(obj.str_id), False
return cls.get_by_uid(obj.uid), False
except cls.DoesNotExist:
pass

obj.create()
return obj, True

@classmethod
def get_by_str(cls, str_id: str) -> Self:
results = cls.__pynamodb_model__.index_by_str.query(
cls.__table_name__, cls.__pynamodb_model__.str_id == str_id
)
results = list(itertools.islice(results, 2))
if len(results) == 0:
raise cls.DoesNotExist()
if len(results) > 1:
raise cls.MultipleObjectsFound()
uuid_ = results[0].uid

try:
return cls.get_by_uid(uuid_)
except cls.__pynamodb_model__.DoesNotExist as e:
raise cls.DoesNotExist() from e

@classmethod
def get_by_uid(cls, uuid_: uuid.UUID) -> Self:
response = cls.__pynamodb_model__.get(cls.__table_name__, uuid_)
return cls.model_validate(dict(uid=response.uid, **response.data))

def create(self):
item = self.__pynamodb_model__(
self.__table_name__,
str_id=self.str_id,
data=self.model_dump(mode="json", exclude={"uid", "str_id"}),
)
condition = self.__pynamodb_model__._hash_key_attribute().does_not_exist()
rk_attr = self.__pynamodb_model__._range_key_attribute()
if rk_attr:
condition &= rk_attr.does_not_exist()
item.save(condition=condition, add_version_condition=False)
assert item.uid is not None
self.uid = item.uid
return self

def save(self):
item = self.__pynamodb_model__(
self.__table_name__,
uid=self.uid,
str_id=self.str_id,
data=self.model_dump(mode="json", exclude={"uid", "str_id"}),
)
item.save(add_version_condition=False)
assert item.uid is not None
self.uid = item.uid

@classmethod
def count(cls, *args, **kwargs):
return cls.__pynamodb_model__.count(cls.__table_name__, *args, **kwargs)

@classmethod
def query(cls, *args, **kwargs):
return cls.__pynamodb_model__.query(cls.__table_name__, *args, **kwargs)

@classmethod
def scan(cls, *args, **kwargs):
return cls.__pynamodb_model__.scan(cls.__table_name__, *args, **kwargs)
try:
return cls.get_by_str(obj.str_id), False
except cls.DoesNotExist:
pass

return BaseTableModel
obj.create()
return obj, True

@classmethod
def get_by_str(cls, str_id: str) -> Self:
results = cls.__pynamodb_model__.index_by_str.query(
cls.__table_name__, cls.__pynamodb_model__.str_id == str_id
)
results = list(itertools.islice(results, 2))
if len(results) == 0:
raise cls.DoesNotExist()
if len(results) > 1:
raise cls.MultipleObjectsFound()
uuid_ = results[0].uid

try:
return cls.get_by_uid(uuid_)
except cls.__pynamodb_model__.DoesNotExist as e:
raise cls.DoesNotExist() from e

@classmethod
def get_by_uid(cls, uuid_: uuid.UUID) -> Self:
response = cls.__pynamodb_model__.get(cls.__table_name__, uuid_)
return cls.model_validate(dict(uid=response.uid, **response.data))

def create(self):
item = self.__pynamodb_model__(
self.__table_name__,
str_id=self.str_id,
data=self.model_dump(mode="json", exclude={"uid", "str_id"}),
)
condition = self.__pynamodb_model__._hash_key_attribute().does_not_exist()
rk_attr = self.__pynamodb_model__._range_key_attribute()
if rk_attr:
condition &= rk_attr.does_not_exist()
item.save(condition=condition, add_version_condition=False)
assert item.uid is not None
self.uid = item.uid
return self

def save(self):
item = self.__pynamodb_model__(
self.__table_name__,
uid=self.uid,
str_id=self.str_id,
data=self.model_dump(mode="json", exclude={"uid", "str_id"}),
)
item.save(add_version_condition=False)
assert item.uid is not None
self.uid = item.uid

@classmethod
def count(cls, *args, **kwargs):
return cls.__pynamodb_model__.count(cls.__table_name__, *args, **kwargs)

@classmethod
def query(cls, *args, **kwargs):
return cls.__pynamodb_model__.query(cls.__table_name__, *args, **kwargs)

@classmethod
def scan(cls, *args, **kwargs):
return cls.__pynamodb_model__.scan(cls.__table_name__, *args, **kwargs)
62 changes: 36 additions & 26 deletions tests/test_basics.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,32 @@
import abc
import uuid

import pytest

from pynamodb_single_table.base import build_base_model
from pynamodb_single_table.base import SingleTableBaseModel


@pytest.fixture(scope="module")
def BaseTableModel():
_BaseTableModel = build_base_model(
table_name="PSTTestRoot",
host="http://localhost:8000",
)
class _BaseTableModel(SingleTableBaseModel, abc.ABC):
class _PynamodbMeta:
table_name = "PSTTestRoot"
host = "http://localhost:8000"


class User(_BaseTableModel):
__table_name__ = "user"
__str_id_field__ = "name"
name: str
group_id: uuid.UUID | None = None


class Group(_BaseTableModel):
__table_name__ = "group"
__str_id_field__ = "name"
name: str


@pytest.fixture(scope="function", autouse=True)
def recreate_pynamodb_table() -> type[SingleTableBaseModel]:
_BaseTableModel.__pynamodb_model__.create_table(
wait=True, billing_mode="PAY_PER_REQUEST"
)
Expand All @@ -21,25 +37,7 @@ def BaseTableModel():
_BaseTableModel.__pynamodb_model__.delete_table()


@pytest.fixture(scope="module")
def user_group_models(BaseTableModel):
class User(BaseTableModel):
__table_name__ = "user"
__str_id__ = "name"
name: str
group_id: uuid.UUID | None = None

class Group(BaseTableModel):
__table_name__ = "group"
__str_id__ = "name"
name: str

return User, Group


def test_basic_interface(user_group_models):
User, Group = user_group_models

def test_basic_interface():
group, was_created = Group.get_or_create(name="Admins")

user, was_created = User.get_or_create(
Expand All @@ -56,3 +54,15 @@ def test_basic_interface(user_group_models):
# Check that we have exactly one user and one group
assert Group.count() == 1, list(Group.scan())
assert User.count() == 1, list(User.scan())


def test_duplicate_creation():
group, _ = Group.get_or_create(name="Admins")

user1, user1_was_created = User.get_or_create(name="Joe Shmoe", group_id=group.uid)
user2, user2_was_created = User.get_or_create(name="Joe Shmoe")

assert user1_was_created
assert not user2_was_created
assert user1.uid == user2.uid
assert user1.group_id == user2.group_id

0 comments on commit edfcd8c

Please sign in to comment.