From fe97eaf67209d530589c4eacee245d7d5b0c6e97 Mon Sep 17 00:00:00 2001 From: Jukka Hassinen Date: Tue, 3 Dec 2024 15:49:59 +0200 Subject: [PATCH] fix: related model inheritance fixes and string format reference support --- django2pydantic/handlers/relational.py | 145 ++++++------ .../test_inheriting_and_subclassing_works.py | 221 ++++++++++++++++++ 2 files changed, 294 insertions(+), 72 deletions(-) create mode 100644 tests/test_inheriting_and_subclassing_works.py diff --git a/django2pydantic/handlers/relational.py b/django2pydantic/handlers/relational.py index 9731c12..ac24621 100644 --- a/django2pydantic/handlers/relational.py +++ b/django2pydantic/handlers/relational.py @@ -1,7 +1,9 @@ """Handlers for relational fields.""" +from abc import ABC from typing import Annotated, Any, override +from django.apps import apps from django.db import models from django.db.models.base import ModelBase from django.db.models.fields.related import RelatedField @@ -11,13 +13,77 @@ from django2pydantic.registry import FieldTypeRegistry -class RelatedFieldHandler(DjangoFieldHandler[RelatedField[Any, Any]]): +class RelatedFieldHandler[TDjangoModel: RelatedField[Any, Any]]( + DjangoFieldHandler[TDjangoModel], ABC +): """Base handler for Related fields.""" - @property @override - def examples(self) -> list[Any] | None: - """Return the example value(s) of the field.""" + def get_pydantic_type_raw(self): + """Return the Pydantic type of the field.""" + return ( + FieldTypeRegistry.instance() + .get_handler(self._get_target_field()) + .get_pydantic_type() + ) + + def _get_target_field(self) -> models.Field[Any, Any]: + """Return the target field of the relation.""" + if getattr(self.field_obj, "to_field", False) and isinstance( + self.field_obj.to_field, + ModelBase, + ): + target_field = self.field_obj.related_model._meta.get_field( + self.field_obj.to_field, + ) + elif isinstance(self.field_obj.related_model, str): + if "." in self.field_obj.related_model: + app_label = self.field_obj.related_model.split(".")[0] + model_name = self.field_obj.related_model.split(".")[1] + target_field = apps.get_model( # noqa: SLF001 + app_label=app_label, model_name=model_name + )._meta.pk + elif self.field_obj.related_model != "self": + app_label = self.field_obj.model._meta.app_label # type: ignore[unreachable] # noqa: SLF001, WPS437 + model_name = self.field_obj.related_model + try: + target_field = apps.get_model( # noqa: SLF001 + app_label=app_label, model_name=model_name + )._meta.pk + except LookupError as lookup_exception: + msg = ( + f"{lookup_exception}\n" + f"For field '{self.field_obj.name}'" + f" in model '{self.field_obj.model.__name__}'," + f" did not find related model {self.field_obj.related_model}. " + "If you are using string references to models," + " use the format 'app_label.ModelName'.\n" + "More information: https://docs.djangoproject.com/en/5.1/ref/models/fields/#absolute-relationships" + ) + raise ValueError( + msg, + ) from lookup_exception + + if self.field_obj.related_model == "self": + target_field = self.field_obj.model._meta.pk + else: + target_field = self.field_obj.related_model._meta.pk + if not target_field: + msg = f"Related model {self.field_obj.related_model} does not have a primary key field." + raise ValueError( + msg, + ) # This should never happen, but just in case, we raise an error here. + return target_field + + @property + def _examples(self) -> list[Any] | None: # PRAGMA: NO COVER + """Return the example value(s) of the field. + + Currently disabled as marked as protected method. + + There might be some use case to provide example values for the fields, + by using limit_choices_to or default values. + """ lct = self.field_obj.get_limit_choices_to() if self.field_obj.choices: return list( @@ -34,7 +100,7 @@ def examples(self) -> list[Any] | None: return None -class ForeignKeyHandler(DjangoFieldHandler[models.ForeignKey[models.Model]]): +class ForeignKeyHandler(RelatedFieldHandler[models.ForeignKey[models.Model]]): """Handler for ForeignKey fields.""" @override @@ -42,28 +108,8 @@ class ForeignKeyHandler(DjangoFieldHandler[models.ForeignKey[models.Model]]): def field(cls) -> type[models.ForeignKey[models.Model]]: return models.ForeignKey - @override - def get_pydantic_type_raw(self): - if hasattr(self.field_obj, "to_field") and isinstance( - self.field_obj.to_field, - ModelBase, - ): - target_field = self.field_obj.related_model._meta.get_field( - self.field_obj.to_field, - ) - else: - target_field = self.field_obj.related_model._meta.pk - if not target_field: - msg = f"Related model {self.field_obj.related_model} does not have a primary key field." - raise ValueError( - msg, - ) # This should never happen, but just in case, we raise an error here. - return ( - FieldTypeRegistry.instance().get_handler(target_field).get_pydantic_type() - ) - -class OneToOneFieldHandler(DjangoFieldHandler[models.OneToOneField[models.Model]]): +class OneToOneFieldHandler(RelatedFieldHandler[models.OneToOneField[models.Model]]): """Handler for OneToOne fields.""" @override @@ -71,29 +117,9 @@ class OneToOneFieldHandler(DjangoFieldHandler[models.OneToOneField[models.Model] def field(cls) -> type[models.OneToOneField[models.Model]]: return models.OneToOneField - @override - def get_pydantic_type_raw(self): - if hasattr(self.field_obj, "to_field") and isinstance( - self.field_obj.to_field, - ModelBase, - ): - target_field = self.field_obj.related_model._meta.get_field( - self.field_obj.to_field, - ) - else: - target_field = self.field_obj.related_model._meta.pk - if not target_field: - msg = f"Related model {self.field_obj.related_model} does not have a primary key field." - raise ValueError( - msg, - ) # This should never happen, but just in case, we raise an error here. - return ( - FieldTypeRegistry.instance().get_handler(target_field).get_pydantic_type() - ) - class ManyToManyFieldHandler( - DjangoFieldHandler[models.ManyToManyField[models.Model, models.Model]], + RelatedFieldHandler[models.ManyToManyField[models.Model, models.Model]], ): """Handler for ManyToMany fields.""" @@ -102,31 +128,6 @@ class ManyToManyFieldHandler( def field(cls) -> type[models.ManyToManyField[models.Model, models.Model]]: return models.ManyToManyField - def _get_target_field(self) -> models.Field[Any, Any]: - if hasattr(self.field_obj, "to_field") and isinstance( - self.field_obj.to_field, - ModelBase, - ): - target_field = self.field_obj.related_model._meta.get_field( - self.field_obj.to_field, - ) - else: - target_field = self.field_obj.related_model._meta.pk - if not target_field: - msg = f"Related model {self.field_obj.related_model} does not have a primary key field." - raise ValueError( - msg, - ) # This should never happen, but just in case, we raise an error here. - return target_field - - @override - def get_pydantic_type_raw(self): - return ( - FieldTypeRegistry.instance() - .get_handler(self._get_target_field()) - .get_pydantic_type() - ) - @override def get_pydantic_type(self) -> type[list[Annotated[Any, Any]]]: """Return the Pydantic type of the field.""" diff --git a/tests/test_inheriting_and_subclassing_works.py b/tests/test_inheriting_and_subclassing_works.py new file mode 100644 index 0000000..0f0ed1d --- /dev/null +++ b/tests/test_inheriting_and_subclassing_works.py @@ -0,0 +1,221 @@ +"""Test that inheriting and subclassing works.""" + +@pytest.mark.skip(reason="Placeholder test") +def test_models_can_have_abstract_base_classes() -> None: + """Test that models can have abstract base classes.""" + + class Base(models.Model): + id = models.AutoField(primary_key=True) + + class Meta: + abstract = True + + class ModelA(Base): + name = models.CharField(max_length=100) + + class SchemaA(Schema): + """SchemaA class.""" + + class Meta(Schema.Meta): + """Meta class.""" + + model = ModelA + fields: ClassVar[ModelFields] = { + "id": Infer, + "name": Infer, + } + + openapi_schema = SchemaA.model_json_schema() + assert openapi_schema["properties"]["name"]["type"] == "string" + + +@pytest.mark.skip(reason="Placeholder test") +def test_foreign_key_fields_can_have_string_reference_to_related_model() -> None: + """Test that foreign key fields can have a string reference to the related model.""" + + class ModelA(models.Model): + id = models.AutoField(primary_key=True) + name = models.CharField(max_length=100) + + class ModelB(models.Model): + id = models.AutoField(primary_key=True) + model_a = models.ForeignKey( + "ModelA", on_delete=models.CASCADE + ) # <-- string reference + + class SchemaA(Schema): + """SchemaA class.""" + + class Meta(Schema.Meta): + """Meta class.""" + + model = ModelA + fields: ClassVar[ModelFields] = { + "id": Infer, + "name": Infer, + } + + class SchemaB(Schema): + """SchemaB class.""" + + class Meta(Schema.Meta): + """Meta class.""" + + model = ModelB + fields: ClassVar[ModelFields] = { + "id": Infer, + "model_a": Infer, + } + + openapi_schema = SchemaB.model_json_schema() + assert ( + openapi_schema["properties"]["model_a"]["$ref"] == "#/components/schemas/ModelA" + ) + + +@pytest.mark.skip(reason="Placeholder test") +def test_many_to_many_fields_can_have_string_reference_to_related_model() -> None: + """Test that many to many fields can have a string reference to the related model.""" + + class ModelA(models.Model): + id = models.AutoField(primary_key=True) + name = models.CharField(max_length=100) + + class ModelB(models.Model): + id = models.AutoField(primary_key=True) + model_a = models.ManyToManyField("ModelA") + + class SchemaA(Schema): + """SchemaA class.""" + + class Meta(Schema.Meta): + """Meta class.""" + + model = ModelA + fields: ClassVar[ModelFields] = { + "id": Infer, + "name": Infer, + } + + class SchemaB(Schema): + """SchemaB class.""" + + class Meta(Schema.Meta): + """Meta class.""" + + model = ModelB + fields: ClassVar[ModelFields] = { + "id": Infer, + "model_a": Infer, + } + + openapi_schema = SchemaB.model_json_schema() + assert ( + openapi_schema["properties"]["model_a"]["items"]["$ref"] + == "#/components/schemas/ModelA" + ) + + +@pytest.mark.skip(reason="Placeholder test") +def test_one_to_one_fields_can_have_string_reference_to_related_model() -> None: + """Test that one to one fields can have a string reference to the related model.""" + + class ModelA(models.Model): + id = models.AutoField(primary_key=True) + name = models.CharField(max_length=100) + + class ModelB(models.Model): + id = models.AutoField(primary_key=True) + model_a = models.OneToOneField("ModelA", on_delete=models.CASCADE) + + class SchemaA(Schema): + """SchemaA class.""" + + class Meta(Schema.Meta): + """Meta class.""" + + model = ModelA + fields: ClassVar[ModelFields] = { + "id": Infer, + "name": Infer, + } + + class SchemaB(Schema): + """SchemaB class.""" + + class Meta(Schema.Meta): + """Meta class.""" + + model = ModelB + fields: ClassVar[ModelFields] = { + "id": Infer, + "model_a": Infer, + } + + openapi_schema = SchemaB.model_json_schema() + assert ( + openapi_schema["properties"]["model_a"]["$ref"] == "#/components/schemas/ModelA" + ) + + +@pytest.mark.skip(reason="Placeholder test") +def test_there_can_be_multiple_schemas_for_one_model() -> None: + """Test that there can be multiple schemas for one model.""" + + class ModelA(models.Model): + id = models.AutoField(primary_key=True) + name = models.CharField(max_length=100) + + class SchemaA(Schema): + """SchemaA class.""" + + class Meta(Schema.Meta): + """Meta class.""" + + model = ModelA + fields: ClassVar[ModelFields] = { + "id": Infer, + "name": Infer, + } + + class SchemaB(Schema): + """SchemaB class.""" + + class Meta(Schema.Meta): + """Meta class.""" + + model = ModelA + fields: ClassVar[ModelFields] = { + "id": Infer, + } + + openapi_schema_a = SchemaA.model_json_schema() + openapi_schema_b = SchemaB.model_json_schema() + assert openapi_schema_a["properties"]["name"]["type"] == "string" + assert "name" not in openapi_schema_b["properties"] + + +@pytest.mark.skip(reason="Placeholder test") +def test_there_can_be_self_referencing_fields() -> None: + """Test that there can be self-referencing fields.""" + + class ModelA(models.Model): + id = models.AutoField(primary_key=True) + parent = models.ForeignKey("self", on_delete=models.CASCADE, null=True) + + class SchemaA(Schema): + """SchemaA class.""" + + class Meta(Schema.Meta): + """Meta class.""" + + model = ModelA + fields: ClassVar[ModelFields] = { + "id": Infer, + "parent": Infer, + } + + openapi_schema = SchemaA.model_json_schema() + assert ( + openapi_schema["properties"]["parent"]["$ref"] == "#/components/schemas/ModelA" + )