Skip to content

Commit

Permalink
Update models's fields' _Choices type
Browse files Browse the repository at this point in the history
Django 5 allows model field choices to be callables, mappings or
subclasses of models.Choices. This commit introduces these options in
the stubs.
  • Loading branch information
joaoseckler committed Sep 11, 2024
1 parent 316295f commit a162237
Show file tree
Hide file tree
Showing 4 changed files with 268 additions and 3 deletions.
18 changes: 16 additions & 2 deletions django-stubs/db/models/fields/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,13 @@ from django import forms
from django.core import validators # due to weird mypy.stubtest error
from django.core.checks import CheckMessage
from django.db.backends.base.base import BaseDatabaseWrapper
from django.db.models import Model
from django.db.models import Choices, Model
from django.db.models.expressions import Col, Combinable, Expression, Func
from django.db.models.fields.reverse_related import ForeignObjectRel
from django.db.models.query_utils import Q, RegisterLookupMixin
from django.forms import Widget
from django.utils.choices import BlankChoiceIterator, _Choice, _ChoiceNamedGroup, _Choices, _ChoicesCallable
from django.utils.choices import BlankChoiceIterator, _Choice, _ChoiceNamedGroup, _ChoicesCallable, _ChoicesMapping
from django.utils.choices import _Choices as _ChoicesSequence
from django.utils.datastructures import DictWrapper
from django.utils.functional import _Getter, _StrOrPromise, cached_property
from typing_extensions import Self, TypeAlias
Expand All @@ -26,6 +27,9 @@ BLANK_CHOICE_DASH: list[tuple[str, str]]

_ChoicesList: TypeAlias = Sequence[_Choice] | Sequence[_ChoiceNamedGroup]
_LimitChoicesTo: TypeAlias = Q | dict[str, Any]
_Choices: TypeAlias = (
_ChoicesSequence | _ChoicesMapping | type[Choices] | Callable[[], _ChoicesSequence | _ChoicesMapping]
)

_F = TypeVar("_F", bound=Field, covariant=True)

Expand Down Expand Up @@ -151,6 +155,7 @@ class Field(RegisterLookupMixin, Generic[_ST, _GT]):
system_check_removed_details: Any | None
system_check_deprecated_details: Any | None
non_db_attrs: tuple[str, ...]

def __init__(
self,
verbose_name: _StrOrPromise | None = None,
Expand Down Expand Up @@ -265,6 +270,7 @@ class DecimalField(Field[_ST, _GT]):
# attributes
max_digits: int
decimal_places: int

def __init__(
self,
verbose_name: _StrOrPromise | None = None,
Expand Down Expand Up @@ -296,6 +302,7 @@ class CharField(Field[_ST, _GT]):
_pyi_private_get_type: str
# objects are converted to string before comparison
_pyi_lookup_exact_type: Any

def __init__(
self,
verbose_name: _StrOrPromise | None = ...,
Expand Down Expand Up @@ -395,6 +402,7 @@ class TextField(Field[_ST, _GT]):
_pyi_private_get_type: str
# objects are converted to string before comparison
_pyi_lookup_exact_type: Any

def __init__(
self,
verbose_name: _StrOrPromise | None = ...,
Expand Down Expand Up @@ -445,6 +453,7 @@ class GenericIPAddressField(Field[_ST, _GT]):
default_error_messages: _ErrorMessagesDict
unpack_ipv4: bool
protocol: str

def __init__(
self,
verbose_name: _StrOrPromise | None = None,
Expand Down Expand Up @@ -478,6 +487,7 @@ class DateField(DateTimeCheckMixin, Field[_ST, _GT]):
_pyi_lookup_exact_type: str | date
auto_now: bool
auto_now_add: bool

def __init__(
self,
verbose_name: _StrOrPromise | None = None,
Expand Down Expand Up @@ -510,6 +520,7 @@ class TimeField(DateTimeCheckMixin, Field[_ST, _GT]):
_pyi_private_get_type: time
auto_now: bool
auto_now_add: bool

def __init__(
self,
verbose_name: _StrOrPromise | None = None,
Expand Down Expand Up @@ -545,6 +556,7 @@ class UUIDField(Field[_ST, _GT]):
_pyi_private_set_type: str | uuid.UUID
_pyi_private_get_type: uuid.UUID
_pyi_lookup_exact_type: uuid.UUID | str

def __init__(
self,
verbose_name: _StrOrPromise | None = None,
Expand Down Expand Up @@ -580,6 +592,7 @@ class FilePathField(Field[_ST, _GT]):
recursive: bool
allow_files: bool
allow_folders: bool

def __init__(
self,
verbose_name: _StrOrPromise | None = None,
Expand Down Expand Up @@ -618,6 +631,7 @@ class DurationField(Field[_ST, _GT]):

class AutoFieldMixin:
db_returning: bool

def deconstruct(self) -> tuple[str, str, Sequence[Any], dict[str, Any]]: ...

class AutoFieldMeta(type): ...
Expand Down
5 changes: 4 additions & 1 deletion django-stubs/utils/choices.pyi
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from collections.abc import Iterable, Iterator
from collections.abc import Iterable, Iterator, Mapping
from typing import Any, Protocol, TypeVar, type_check_only

from typing_extensions import TypeAlias

_Choice: TypeAlias = tuple[Any, Any]
_ChoiceNamedGroup: TypeAlias = tuple[str, Iterable[_Choice]]
_Choices: TypeAlias = Iterable[_Choice | _ChoiceNamedGroup]
_ChoicesMapping: TypeAlias = Mapping[Any, Any] | Mapping[str, Mapping[Any, Any]] # noqa: PYI047

@type_check_only
class _ChoicesCallable(Protocol):
Expand All @@ -18,10 +19,12 @@ class BaseChoiceIterator:
class BlankChoiceIterator(BaseChoiceIterator):
choices: _Choices
blank_choice: _Choices

def __init__(self, choices: _Choices, blank_choice: _Choices) -> None: ...

class CallableChoiceIterator(BaseChoiceIterator):
func: _ChoicesCallable

def __init__(self, func: _ChoicesCallable) -> None: ...

_V = TypeVar("_V")
Expand Down
89 changes: 89 additions & 0 deletions tests/assert_type/db/models/fields/test_choices.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
from collections.abc import Callable, Mapping, Sequence
from typing import Tuple, TypeVar

from django.db import models
from typing_extensions import assert_type

_T = TypeVar("_T")


def to_named_seq(func: Callable[[], _T]) -> Callable[[], Sequence[Tuple[str, _T]]]:
def inner() -> Sequence[Tuple[str, _T]]:
return [("title", func())]

return inner


def to_named_mapping(func: Callable[[], _T]) -> Callable[[], Mapping[str, _T]]:
def inner() -> Mapping[str, _T]:
return {"title": func()}

return inner


def str_tuple() -> Sequence[Tuple[str, str]]:
return (("foo", "bar"), ("fuzz", "bazz"))


def str_mapping() -> Mapping[str, str]:
return {"foo": "bar", "fuzz": "bazz"}


def int_tuple() -> Sequence[Tuple[int, str]]:
return ((1, "bar"), (2, "bazz"))


def int_mapping() -> Mapping[int, str]:
return {3: "bar", 4: "bazz"}


class TestModel(models.Model):
class TextChoices(models.TextChoices):
FIRST = "foo", "bar"
SECOND = "foo2", "bar"

class IntegerChoices(models.IntegerChoices):
FIRST = 1, "bar"
SECOND = 2, "bar"

char1 = models.CharField[str, str](max_length=5, choices=TextChoices, default="foo")
char2 = models.CharField[str, str](max_length=5, choices=str_tuple, default="foo")
char3 = models.CharField[str, str](max_length=5, choices=str_mapping, default="foo")
char4 = models.CharField[str, str](max_length=5, choices=str_tuple(), default="foo")
char5 = models.CharField[str, str](max_length=5, choices=str_mapping(), default="foo")
char6 = models.CharField[str, str](max_length=5, choices=to_named_seq(str_tuple), default="foo")
char7 = models.CharField[str, str](max_length=5, choices=to_named_mapping(str_mapping), default="foo")
char8 = models.CharField[str, str](max_length=5, choices=to_named_seq(str_tuple)(), default="foo")
char9 = models.CharField[str, str](max_length=5, choices=to_named_mapping(str_mapping)(), default="foo")

int1 = models.IntegerField[int, int](choices=IntegerChoices, default=1)
int2 = models.IntegerField[int, int](choices=int_tuple, default=1)
int3 = models.IntegerField[int, int](choices=int_mapping, default=1)
int4 = models.IntegerField[int, int](choices=int_tuple(), default=1)
int5 = models.IntegerField[int, int](choices=int_mapping(), default=1)
int6 = models.IntegerField[int, int](choices=to_named_seq(int_tuple), default=1)
int7 = models.IntegerField[int, int](choices=to_named_seq(int_mapping), default=1)
int8 = models.IntegerField[int, int](choices=to_named_seq(int_tuple)(), default=1)
int9 = models.IntegerField[int, int](choices=to_named_seq(int_mapping)(), default=1)


instance = TestModel()
assert_type(instance.char1, str)
assert_type(instance.char2, str)
assert_type(instance.char3, str)
assert_type(instance.char4, str)
assert_type(instance.char5, str)
assert_type(instance.char6, str)
assert_type(instance.char7, str)
assert_type(instance.char8, str)
assert_type(instance.char9, str)

assert_type(instance.int1, int)
assert_type(instance.int2, int)
assert_type(instance.int3, int)
assert_type(instance.int4, int)
assert_type(instance.int5, int)
assert_type(instance.int6, int)
assert_type(instance.int7, int)
assert_type(instance.int8, int)
assert_type(instance.int9, int)
159 changes: 159 additions & 0 deletions tests/typecheck/db/models/test_fields.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
- case: db_models_fields_choices
main: |
from collections.abc import Callable, Mapping, Sequence
from datetime import date, time
from decimal import Decimal
from typing import TypeVar
from uuid import UUID
from django.db import models
_T = TypeVar("_T")
def to_named_seq(func: Callable[[], _T]) -> Callable[[], Sequence[tuple[str, _T]]]:
def inner() -> Sequence[tuple[str, _T]]:
return [("title", func())]
return inner
def to_named_mapping(func: Callable[[], _T]) -> Callable[[], Mapping[str, _T]]:
def inner() -> Mapping[str, _T]:
return {"title": func()}
return inner
def str_tuple() -> Sequence[tuple[str, str]]:
return (("foo", "bar"), ("fuzz", "bazz"))
def str_mapping() -> Mapping[str, str]:
return {"foo": "bar", "fuzz": "bazz"}
def int_tuple() -> Sequence[tuple[int, str]]:
return ((1, "bar"), (2, "bazz"))
def int_mapping() -> Mapping[int, str]:
return {3: "bar", 4: "bazz"}
def dec_tuple() -> Sequence[tuple[Decimal, str]]:
return ((Decimal(1), "bar"), (Decimal(2), "bazz"))
def dec_mapping() -> Mapping[Decimal, str]:
return {Decimal(3): "bar", Decimal(4): "bazz"}
def url_tuple() -> Sequence[tuple[str, str]]:
return (("https://python.org", "bar"), ("https://mypy-lang.org", "bazz"))
def url_mapping() -> Mapping[str, str]:
return {"https://python.org": "bar", "https://mypy-lang.org": "bazz"}
def date_tuple() -> Sequence[tuple[date, str]]:
return ((date.today(), "bar"), (date(2024, 1, 1), "bazz"))
def date_mapping() -> Mapping[date, str]:
return {date.today(): "bar", date(2024, 1, 1): "bazz"}
def time_tuple() -> Sequence[tuple[time, str]]:
return ((time(0, 0, 2), "bar"), (time(0, 0, 1), "bazz"))
def time_mapping() -> Mapping[time, str]:
return {time(0, 0, 2): "bar", time(0, 0, 1): "bazz"}
def uuid_tuple() -> Sequence[tuple[UUID, str]]:
return ((UUID(), "bar"), (UUID(), "bazz"))
def uuid_mapping() -> Mapping[UUID, str]:
return {UUID(): "bar", UUID(): "bazz"}
class NewModel(models.Model):
class TextChoices(models.TextChoices):
FIRST = "foo", "bar"
SECOND = "foo2", "bar"
class IntegerChoices(models.IntegerChoices):
FIRST = 1, "bar"
SECOND = 2, "bar"
char1 = models.CharField[str, str](max_length=200, choices=TextChoices)
char2 = models.CharField[str, str](max_length=200, choices=str_tuple)
char3 = models.CharField[str, str](max_length=200, choices=str_mapping)
char4 = models.CharField[str, str](max_length=200, choices=str_tuple())
char5 = models.CharField[str, str](max_length=200, choices=str_mapping())
char6 = models.CharField[str, str](max_length=200, choices=to_named_seq(str_tuple))
char7 = models.CharField[str, str](max_length=200, choices=to_named_seq(str_tuple)())
char8 = models.CharField[str, str](max_length=200, choices=to_named_mapping(str_mapping))
char9 = models.CharField[str, str](max_length=200, choices=to_named_mapping(str_mapping)())
int1 = models.IntegerField[int, int](choices=IntegerChoices)
int2 = models.IntegerField[int, int](choices=int_tuple)
int3 = models.IntegerField[int, int](choices=int_mapping)
int4 = models.IntegerField[int, int](choices=int_tuple())
int5 = models.IntegerField[int, int](choices=int_mapping())
int6 = models.IntegerField[int, int](choices=to_named_seq(str_tuple))
int7 = models.IntegerField[int, int](choices=to_named_seq(str_tuple)())
int8 = models.IntegerField[int, int](choices=to_named_mapping(str_mapping))
int9 = models.IntegerField[int, int](choices=to_named_mapping(str_mapping)())
dec1 = models.DecimalField[Decimal, Decimal](choices=dec_tuple)
dec2 = models.DecimalField[Decimal, Decimal](choices=dec_mapping)
dec3 = models.DecimalField[Decimal, Decimal](choices=dec_tuple())
dec4 = models.DecimalField[Decimal, Decimal](choices=dec_mapping())
slug1 = models.SlugField[str, str](choices=TextChoices)
slug4 = models.SlugField[str, str](choices=str_tuple)
slug5 = models.SlugField[str, str](choices=str_mapping)
slug2 = models.SlugField[str, str](choices=str_tuple())
slug3 = models.SlugField[str, str](choices=str_mapping())
url1 = models.URLField[str, str](choices=str_tuple)
url2 = models.URLField[str, str](choices=str_mapping)
url3 = models.URLField[str, str](choices=str_tuple())
url4 = models.URLField[str, str](choices=str_mapping())
text1 = models.TextField[str, str](choices=TextChoices)
text2 = models.TextField[str, str](choices=str_tuple)
text3 = models.TextField[str, str](choices=str_mapping)
text4 = models.TextField[str, str](choices=str_tuple())
text5 = models.TextField[str, str](choices=str_mapping())
ip1 = models.GenericIPAddressField[int, int](choices=int_tuple)
ip2 = models.GenericIPAddressField[int, int](choices=int_mapping)
ip3 = models.GenericIPAddressField[int, int](choices=int_tuple())
ip4 = models.GenericIPAddressField[int, int](choices=int_mapping())
date1 = models.DateField[date, date](choices=date_tuple)
date2 = models.DateField[date, date](choices=date_mapping)
date3 = models.DateField[date, date](choices=date_tuple())
date4 = models.DateField[date, date](choices=date_mapping())
time1 = models.TimeField[time, time](choices=time_tuple)
time2 = models.TimeField[time, time](choices=time_mapping)
time3 = models.TimeField[time, time](choices=time_tuple())
time4 = models.TimeField[time, time](choices=time_mapping())
uuid1 = models.UUIDField[UUID, UUID](choices=uuid_tuple)
uuid2 = models.UUIDField[UUID, UUID](choices=uuid_mapping)
uuid3 = models.UUIDField[UUID, UUID](choices=uuid_tuple())
uuid4 = models.UUIDField[UUID, UUID](choices=uuid_mapping())
path1 = models.FilePathField[str, str](choices=TextChoices)
path2 = models.FilePathField[str, str](choices=str_tuple)
path3 = models.FilePathField[str, str](choices=str_mapping)
path4 = models.FilePathField[str, str](choices=str_tuple())
path5 = models.FilePathField[str, str](choices=str_mapping())

0 comments on commit a162237

Please sign in to comment.