Skip to content

Commit

Permalink
fix: related model inheritance fixes and string format reference support
Browse files Browse the repository at this point in the history
  • Loading branch information
jhassine committed Dec 3, 2024
1 parent db3e1ba commit fe97eaf
Show file tree
Hide file tree
Showing 2 changed files with 294 additions and 72 deletions.
145 changes: 73 additions & 72 deletions django2pydantic/handlers/relational.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
"""Handlers for relational fields."""

from abc import ABC
from typing import Annotated, Any, override

from django.apps import apps
from django.db import models
from django.db.models.base import ModelBase
from django.db.models.fields.related import RelatedField
Expand All @@ -11,13 +13,77 @@
from django2pydantic.registry import FieldTypeRegistry


class RelatedFieldHandler(DjangoFieldHandler[RelatedField[Any, Any]]):
class RelatedFieldHandler[TDjangoModel: RelatedField[Any, Any]](
DjangoFieldHandler[TDjangoModel], ABC
):
"""Base handler for Related fields."""

@property
@override
def examples(self) -> list[Any] | None:
"""Return the example value(s) of the field."""
def get_pydantic_type_raw(self):
"""Return the Pydantic type of the field."""
return (
FieldTypeRegistry.instance()
.get_handler(self._get_target_field())
.get_pydantic_type()
)

def _get_target_field(self) -> models.Field[Any, Any]:
"""Return the target field of the relation."""
if getattr(self.field_obj, "to_field", False) and isinstance(
self.field_obj.to_field,
ModelBase,
):
target_field = self.field_obj.related_model._meta.get_field(
self.field_obj.to_field,
)
elif isinstance(self.field_obj.related_model, str):
if "." in self.field_obj.related_model:
app_label = self.field_obj.related_model.split(".")[0]
model_name = self.field_obj.related_model.split(".")[1]
target_field = apps.get_model( # noqa: SLF001
app_label=app_label, model_name=model_name
)._meta.pk
elif self.field_obj.related_model != "self":
app_label = self.field_obj.model._meta.app_label # type: ignore[unreachable] # noqa: SLF001, WPS437
model_name = self.field_obj.related_model
try:
target_field = apps.get_model( # noqa: SLF001
app_label=app_label, model_name=model_name
)._meta.pk
except LookupError as lookup_exception:
msg = (
f"{lookup_exception}\n"
f"For field '{self.field_obj.name}'"
f" in model '{self.field_obj.model.__name__}',"
f" did not find related model {self.field_obj.related_model}. "
"If you are using string references to models,"
" use the format 'app_label.ModelName'.\n"
"More information: https://docs.djangoproject.com/en/5.1/ref/models/fields/#absolute-relationships"
)
raise ValueError(
msg,
) from lookup_exception

if self.field_obj.related_model == "self":
target_field = self.field_obj.model._meta.pk
else:
target_field = self.field_obj.related_model._meta.pk
if not target_field:
msg = f"Related model {self.field_obj.related_model} does not have a primary key field."
raise ValueError(
msg,
) # This should never happen, but just in case, we raise an error here.
return target_field

@property
def _examples(self) -> list[Any] | None: # PRAGMA: NO COVER
"""Return the example value(s) of the field.
Currently disabled as marked as protected method.
There might be some use case to provide example values for the fields,
by using limit_choices_to or default values.
"""
lct = self.field_obj.get_limit_choices_to()
if self.field_obj.choices:
return list(
Expand All @@ -34,66 +100,26 @@ def examples(self) -> list[Any] | None:
return None


class ForeignKeyHandler(DjangoFieldHandler[models.ForeignKey[models.Model]]):
class ForeignKeyHandler(RelatedFieldHandler[models.ForeignKey[models.Model]]):
"""Handler for ForeignKey fields."""

@override
@classmethod
def field(cls) -> type[models.ForeignKey[models.Model]]:
return models.ForeignKey

@override
def get_pydantic_type_raw(self):
if hasattr(self.field_obj, "to_field") and isinstance(
self.field_obj.to_field,
ModelBase,
):
target_field = self.field_obj.related_model._meta.get_field(
self.field_obj.to_field,
)
else:
target_field = self.field_obj.related_model._meta.pk
if not target_field:
msg = f"Related model {self.field_obj.related_model} does not have a primary key field."
raise ValueError(
msg,
) # This should never happen, but just in case, we raise an error here.
return (
FieldTypeRegistry.instance().get_handler(target_field).get_pydantic_type()
)


class OneToOneFieldHandler(DjangoFieldHandler[models.OneToOneField[models.Model]]):
class OneToOneFieldHandler(RelatedFieldHandler[models.OneToOneField[models.Model]]):
"""Handler for OneToOne fields."""

@override
@classmethod
def field(cls) -> type[models.OneToOneField[models.Model]]:
return models.OneToOneField

@override
def get_pydantic_type_raw(self):
if hasattr(self.field_obj, "to_field") and isinstance(
self.field_obj.to_field,
ModelBase,
):
target_field = self.field_obj.related_model._meta.get_field(
self.field_obj.to_field,
)
else:
target_field = self.field_obj.related_model._meta.pk
if not target_field:
msg = f"Related model {self.field_obj.related_model} does not have a primary key field."
raise ValueError(
msg,
) # This should never happen, but just in case, we raise an error here.
return (
FieldTypeRegistry.instance().get_handler(target_field).get_pydantic_type()
)


class ManyToManyFieldHandler(
DjangoFieldHandler[models.ManyToManyField[models.Model, models.Model]],
RelatedFieldHandler[models.ManyToManyField[models.Model, models.Model]],
):
"""Handler for ManyToMany fields."""

Expand All @@ -102,31 +128,6 @@ class ManyToManyFieldHandler(
def field(cls) -> type[models.ManyToManyField[models.Model, models.Model]]:
return models.ManyToManyField

def _get_target_field(self) -> models.Field[Any, Any]:
if hasattr(self.field_obj, "to_field") and isinstance(
self.field_obj.to_field,
ModelBase,
):
target_field = self.field_obj.related_model._meta.get_field(
self.field_obj.to_field,
)
else:
target_field = self.field_obj.related_model._meta.pk
if not target_field:
msg = f"Related model {self.field_obj.related_model} does not have a primary key field."
raise ValueError(
msg,
) # This should never happen, but just in case, we raise an error here.
return target_field

@override
def get_pydantic_type_raw(self):
return (
FieldTypeRegistry.instance()
.get_handler(self._get_target_field())
.get_pydantic_type()
)

@override
def get_pydantic_type(self) -> type[list[Annotated[Any, Any]]]:
"""Return the Pydantic type of the field."""
Expand Down
Loading

0 comments on commit fe97eaf

Please sign in to comment.