Skip to content

Commit

Permalink
Merge pull request #357 from mesozoic/orm_meta
Browse files Browse the repository at this point in the history
Do not call orm.Model.Meta callables during `__init_subclass__`
  • Loading branch information
mesozoic authored Mar 22, 2024
2 parents 769d924 + c289427 commit cfd304b
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 17 deletions.
36 changes: 24 additions & 12 deletions pyairtable/orm/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,24 +142,36 @@ def __init__(self, **fields: Any):
setattr(self, key, value)

@classmethod
def _get_meta(cls, name: str, default: Any = None, required: bool = False) -> Any:
def _get_meta(
cls, name: str, default: Any = None, required: bool = False, call: bool = True
) -> Any:
"""
Retrieves the value of a Meta attribute.
Args:
default: The default value to return if the attribute is not set.
required: Raise an exception if the attribute is not set.
call: If the value is callable, call it before returning a result.
"""
if not hasattr(cls, "Meta"):
raise AttributeError(f"{cls.__name__}.Meta must be defined")
if required and not hasattr(cls.Meta, name):
raise ValueError(f"{cls.__name__}.Meta.{name} must be defined")
value = getattr(cls.Meta, name, default)
if callable(value):
if not hasattr(cls.Meta, name):
if required:
raise ValueError(f"{cls.__name__}.Meta.{name} must be defined")
return default
value = getattr(cls.Meta, name)
if call and callable(value):
value = value()
if required and value is None:
raise ValueError(f"{cls.__name__}.Meta.{name} cannot be None")
return value

@classmethod
def _validate_class(cls) -> None:
# Verify required Meta attributes were set
assert cls._get_meta("api_key", required=True)
assert cls._get_meta("base_id", required=True)
assert cls._get_meta("table_name", required=True)
# Verify required Meta attributes were set (but don't call any callables)
assert cls._get_meta("api_key", required=True, call=False)
assert cls._get_meta("base_id", required=True, call=False)
assert cls._get_meta("table_name", required=True, call=False)

model_attributes = [a for a in cls.__dict__.keys() if not a.startswith("__")]
overridden = set(model_attributes).intersection(Model.__dict__.keys())
Expand All @@ -174,17 +186,17 @@ def _validate_class(cls) -> None:
@lru_cache
def get_api(cls) -> Api:
return Api(
api_key=cls._get_meta("api_key"),
api_key=cls._get_meta("api_key", required=True),
timeout=cls._get_meta("timeout"),
)

@classmethod
def get_base(cls) -> Base:
return cls.get_api().base(cls._get_meta("base_id"))
return cls.get_api().base(cls._get_meta("base_id", required=True))

@classmethod
def get_table(cls) -> Table:
return cls.get_base().table(cls._get_meta("table_name"))
return cls.get_base().table(cls._get_meta("table_name", required=True))

@classmethod
def _typecast(cls) -> bool:
Expand Down
27 changes: 22 additions & 5 deletions tests/test_orm_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,22 @@ class Address(Model):
street = f.TextField("Street")


def test_model_empty_meta_with_callable():
"""
Test that we throw an exception when a required Meta attribute is
defined as a callable which returns None.
"""

class Address(Model):
Meta = fake_meta(api_key=lambda: None)
street = f.TextField("Street")

with mock.patch("pyairtable.Table.first", return_value=fake_record()) as m:
with pytest.raises(ValueError):
Address.first()
m.assert_not_called()


@pytest.mark.parametrize("name", ("exists", "id"))
def test_model_overlapping(name):
"""
Expand Down Expand Up @@ -197,7 +213,8 @@ def test_passthrough(methodname):
def test_dynamic_model_meta():
"""
Test that we can provide callables in our Meta class to provide
the access token, base ID, and table name at runtime.
the access token, base ID, and table name at runtime. Also ensure
that callable Meta attributes don't get called until they're needed.
"""
data = {
"api_key": "FakeApiKey",
Expand All @@ -209,12 +226,12 @@ class Fake(Model):
class Meta:
api_key = lambda: data["api_key"] # noqa
base_id = partial(data.get, "base_id")

@staticmethod
def table_name():
return data["table_name"]
table_name = mock.Mock(return_value=data["table_name"])

f = Fake()
Fake.Meta.table_name.assert_not_called()

assert f._get_meta("api_key") == data["api_key"]
assert f._get_meta("base_id") == data["base_id"]
assert f._get_meta("table_name") == data["table_name"]
Fake.Meta.table_name.assert_called_once()

0 comments on commit cfd304b

Please sign in to comment.