-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Update to extract class definitions from method, making intellisense …
…actually work
- Loading branch information
Showing
2 changed files
with
155 additions
and
159 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,170 +1,156 @@ | ||
import abc | ||
import itertools | ||
import uuid | ||
from inspect import isabstract | ||
|
||
from pydantic import BaseModel | ||
from pydantic import ConfigDict | ||
from pydantic import PrivateAttr | ||
from pydantic import computed_field | ||
from pynamodb.attributes import JSONAttribute | ||
from pynamodb.attributes import UnicodeAttribute | ||
from pynamodb.attributes import VersionAttribute | ||
from pynamodb.indexes import KeysOnlyProjection | ||
from pynamodb.indexes import LocalSecondaryIndex | ||
from pynamodb.models import MetaProtocol | ||
from pynamodb.models import Model | ||
from pynamodb_attributes import UUIDAttribute | ||
from typing_extensions import Self | ||
|
||
|
||
def build_base_model(table_name, host=None) -> type[BaseModel]: | ||
"""Builds a DynamoDB table model for storing a single-table database design. | ||
class StrIndex(LocalSecondaryIndex): | ||
class Meta: | ||
projection = KeysOnlyProjection() | ||
|
||
This table uses: | ||
- The concrete table name as a hash key | ||
- A UUID as the row-level primary key | ||
- A string representation for each row that can be used as a secondary primary key | ||
table_name = UnicodeAttribute(hash_key=True) | ||
str_id = UnicodeAttribute(range_key=True) | ||
uid = UUIDAttribute(null=False) | ||
|
||
Parameters: | ||
- table_name (str): The name of the DynamoDB table. | ||
- host (str, optional): The host URL of the DynamoDB instance. If not specified, | ||
the default AWS endpoint is used. | ||
|
||
Returns: | ||
- type[BaseModel]: A Pydantic base model for creating concrete tables | ||
""" | ||
class RootModelPrototype(Model): | ||
table_name = UnicodeAttribute(hash_key=True) | ||
uid = UUIDAttribute(range_key=True, default_for_new=uuid.uuid4) | ||
str_id = UnicodeAttribute(null=False) | ||
index_by_str = StrIndex() | ||
version = VersionAttribute(null=False, default=1) | ||
data = JSONAttribute(null=False) | ||
|
||
class StrIndex(LocalSecondaryIndex): | ||
class Meta: | ||
projection = KeysOnlyProjection() | ||
|
||
table_name = UnicodeAttribute(hash_key=True) | ||
str_id = UnicodeAttribute(range_key=True) | ||
uid = UUIDAttribute(null=False) | ||
class SingleTableBaseModel(BaseModel): | ||
_PynamodbMeta: MetaProtocol = PrivateAttr() | ||
__pynamodb_model__: type[RootModelPrototype] = None | ||
|
||
class RootMeta: | ||
pass | ||
uid: uuid.UUID | None = None | ||
|
||
RootMeta.table_name = table_name | ||
RootMeta.host = host | ||
def __init_subclass__(cls, **kwargs): | ||
if cls.__pynamodb_model__: | ||
assert issubclass( | ||
cls.__pynamodb_model__, RootModelPrototype | ||
) # TODO: Just duck type? | ||
else: | ||
if not cls._PynamodbMeta: | ||
raise TypeError(f"Must define the PynamoDB metadata for {cls}") | ||
|
||
class RootModel(Model): | ||
Meta = RootMeta | ||
class RootModel(RootModelPrototype): | ||
Meta = cls._PynamodbMeta | ||
|
||
table_name = UnicodeAttribute(hash_key=True) | ||
uid = UUIDAttribute(range_key=True, default_for_new=uuid.uuid4) | ||
str_id = UnicodeAttribute(null=False) | ||
index_by_str = StrIndex() | ||
version = VersionAttribute(null=False, default=1) | ||
data = JSONAttribute(null=False) | ||
cls.__pynamodb_model__ = RootModel | ||
|
||
class BaseTableModel(BaseModel, abc.ABC): | ||
__pynamodb_model__: type[RootModel] = RootModel | ||
__table_name__ = None | ||
__str_id__ = None | ||
model_config = ConfigDict(from_attributes=True) | ||
if isabstract(cls) or abc.ABC in cls.__bases__: | ||
return super().__init_subclass__(**kwargs) | ||
|
||
uid: uuid.UUID | None = None | ||
if not cls.__table_name__: | ||
raise TypeError( | ||
f"Must define the table name for {cls} (the inner table, not the pynamodb table)" | ||
) | ||
|
||
@computed_field | ||
@property | ||
def str_id(self) -> str: | ||
return getattr(self, self.__str_id__) | ||
if not cls.__str_id_field__: | ||
raise TypeError(f"Must define the string ID field for {cls}") | ||
|
||
class DoesNotExist(Exception): | ||
pass | ||
@computed_field | ||
@property | ||
def str_id(self) -> str: | ||
return getattr(self, self.__str_id_field__) | ||
|
||
class MultipleObjectsFound(Exception): | ||
pass | ||
class DoesNotExist(Exception): | ||
pass | ||
|
||
@classmethod | ||
def __init_subclass__(cls, **kwargs): | ||
super().__init_subclass__(**kwargs) | ||
if cls.__table_name__ is None: | ||
raise ValueError( | ||
f"Must provide table name ({cls.__name__}.__table_name__) " | ||
f"when subclassing BaseTableModel" | ||
) | ||
if cls.__str_id__ is None: | ||
raise ValueError( | ||
f"Must provide string ID ({cls.__name__}.__str_id__) " | ||
f"when subclassing BaseTableModel" | ||
) | ||
|
||
@classmethod | ||
def get_or_create(cls, **kwargs) -> tuple[Self, bool]: | ||
obj = cls.model_validate(kwargs) | ||
|
||
if obj.uid is not None: | ||
try: | ||
return cls.get_by_uid(obj.uid), False | ||
except cls.DoesNotExist: | ||
pass | ||
class MultipleObjectsFound(Exception): | ||
pass | ||
|
||
@classmethod | ||
def get_or_create(cls, **kwargs) -> tuple[Self, bool]: | ||
obj = cls.model_validate(kwargs) | ||
|
||
if obj.uid is not None: | ||
try: | ||
return cls.get_by_str(obj.str_id), False | ||
return cls.get_by_uid(obj.uid), False | ||
except cls.DoesNotExist: | ||
pass | ||
|
||
obj.create() | ||
return obj, True | ||
|
||
@classmethod | ||
def get_by_str(cls, str_id: str) -> Self: | ||
results = cls.__pynamodb_model__.index_by_str.query( | ||
cls.__table_name__, cls.__pynamodb_model__.str_id == str_id | ||
) | ||
results = list(itertools.islice(results, 2)) | ||
if len(results) == 0: | ||
raise cls.DoesNotExist() | ||
if len(results) > 1: | ||
raise cls.MultipleObjectsFound() | ||
uuid_ = results[0].uid | ||
|
||
try: | ||
return cls.get_by_uid(uuid_) | ||
except cls.__pynamodb_model__.DoesNotExist as e: | ||
raise cls.DoesNotExist() from e | ||
|
||
@classmethod | ||
def get_by_uid(cls, uuid_: uuid.UUID) -> Self: | ||
response = cls.__pynamodb_model__.get(cls.__table_name__, uuid_) | ||
return cls.model_validate(dict(uid=response.uid, **response.data)) | ||
|
||
def create(self): | ||
item = self.__pynamodb_model__( | ||
self.__table_name__, | ||
str_id=self.str_id, | ||
data=self.model_dump(mode="json", exclude={"uid", "str_id"}), | ||
) | ||
condition = self.__pynamodb_model__._hash_key_attribute().does_not_exist() | ||
rk_attr = self.__pynamodb_model__._range_key_attribute() | ||
if rk_attr: | ||
condition &= rk_attr.does_not_exist() | ||
item.save(condition=condition, add_version_condition=False) | ||
assert item.uid is not None | ||
self.uid = item.uid | ||
return self | ||
|
||
def save(self): | ||
item = self.__pynamodb_model__( | ||
self.__table_name__, | ||
uid=self.uid, | ||
str_id=self.str_id, | ||
data=self.model_dump(mode="json", exclude={"uid", "str_id"}), | ||
) | ||
item.save(add_version_condition=False) | ||
assert item.uid is not None | ||
self.uid = item.uid | ||
|
||
@classmethod | ||
def count(cls, *args, **kwargs): | ||
return cls.__pynamodb_model__.count(cls.__table_name__, *args, **kwargs) | ||
|
||
@classmethod | ||
def query(cls, *args, **kwargs): | ||
return cls.__pynamodb_model__.query(cls.__table_name__, *args, **kwargs) | ||
|
||
@classmethod | ||
def scan(cls, *args, **kwargs): | ||
return cls.__pynamodb_model__.scan(cls.__table_name__, *args, **kwargs) | ||
try: | ||
return cls.get_by_str(obj.str_id), False | ||
except cls.DoesNotExist: | ||
pass | ||
|
||
return BaseTableModel | ||
obj.create() | ||
return obj, True | ||
|
||
@classmethod | ||
def get_by_str(cls, str_id: str) -> Self: | ||
results = cls.__pynamodb_model__.index_by_str.query( | ||
cls.__table_name__, cls.__pynamodb_model__.str_id == str_id | ||
) | ||
results = list(itertools.islice(results, 2)) | ||
if len(results) == 0: | ||
raise cls.DoesNotExist() | ||
if len(results) > 1: | ||
raise cls.MultipleObjectsFound() | ||
uuid_ = results[0].uid | ||
|
||
try: | ||
return cls.get_by_uid(uuid_) | ||
except cls.__pynamodb_model__.DoesNotExist as e: | ||
raise cls.DoesNotExist() from e | ||
|
||
@classmethod | ||
def get_by_uid(cls, uuid_: uuid.UUID) -> Self: | ||
response = cls.__pynamodb_model__.get(cls.__table_name__, uuid_) | ||
return cls.model_validate(dict(uid=response.uid, **response.data)) | ||
|
||
def create(self): | ||
item = self.__pynamodb_model__( | ||
self.__table_name__, | ||
str_id=self.str_id, | ||
data=self.model_dump(mode="json", exclude={"uid", "str_id"}), | ||
) | ||
condition = self.__pynamodb_model__._hash_key_attribute().does_not_exist() | ||
rk_attr = self.__pynamodb_model__._range_key_attribute() | ||
if rk_attr: | ||
condition &= rk_attr.does_not_exist() | ||
item.save(condition=condition, add_version_condition=False) | ||
assert item.uid is not None | ||
self.uid = item.uid | ||
return self | ||
|
||
def save(self): | ||
item = self.__pynamodb_model__( | ||
self.__table_name__, | ||
uid=self.uid, | ||
str_id=self.str_id, | ||
data=self.model_dump(mode="json", exclude={"uid", "str_id"}), | ||
) | ||
item.save(add_version_condition=False) | ||
assert item.uid is not None | ||
self.uid = item.uid | ||
|
||
@classmethod | ||
def count(cls, *args, **kwargs): | ||
return cls.__pynamodb_model__.count(cls.__table_name__, *args, **kwargs) | ||
|
||
@classmethod | ||
def query(cls, *args, **kwargs): | ||
return cls.__pynamodb_model__.query(cls.__table_name__, *args, **kwargs) | ||
|
||
@classmethod | ||
def scan(cls, *args, **kwargs): | ||
return cls.__pynamodb_model__.scan(cls.__table_name__, *args, **kwargs) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters