Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use Mypy's strict mode #937

Open
wants to merge 29 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
1f3f315
Use Mypy's strict mode
adamchainz Aug 15, 2022
7ee3a90
Finish hints for cache module
adamchainz Aug 22, 2022
0072a12
Finish hints for utils
adamchainz Aug 22, 2022
273b76c
Finish hints for operations
adamchainz Aug 22, 2022
7d8beef
Finish hints for locks
adamchainz Aug 22, 2022
b99b21e
Finish hints for management commands
adamchainz Aug 22, 2022
447dcc3
Finish hints for status
adamchainz Aug 22, 2022
a885a6b
Stricter signature for contribute_to_class
adamchainz Aug 22, 2022
586531a
Fix JSONExtract output_field arg
adamchainz Aug 22, 2022
7639367
Fix types for IndexLookup.as_sql()
adamchainz Aug 22, 2022
107e55b
Fix type for GroupConcat arg 1
adamchainz Aug 22, 2022
e4449c7
Fix some as_sql() signatures
adamchainz Aug 26, 2022
a6f7929
assert
adamchainz Aug 27, 2022
facdc62
mute isinstance
adamchainz Aug 28, 2022
7aeb183
more fixes
adamchainz Aug 30, 2022
8758ed1
Improve function types
adamchainz Oct 18, 2022
42f84b0
Add cast
adamchainz Oct 18, 2022
1d0bcd9
Fix AsType signature
adamchainz Oct 19, 2022
0d70001
Fix name of pytest fixture to avoid collision
adamchainz Oct 19, 2022
1ce99b7
Fix some errors in cache tests
adamchainz Oct 19, 2022
dd2c058
Fix return type of IndexLookup.as_sql
adamchainz Oct 19, 2022
a2c8e60
make formfield() methods return Any
adamchainz Oct 19, 2022
fc8af0a
Add extra model type asserts for deserialization tests
adamchainz Sep 9, 2024
873b6da
Some extra hints
adamchainz Sep 9, 2024
2533519
Fix formfield() methods
adamchainz Sep 9, 2024
a0874a2
Allow bad arg types for tests checking that
adamchainz Sep 9, 2024
842cbcf
Pass connection to db_type()
adamchainz Sep 9, 2024
07f44d0
Correct mysql_connections()
adamchainz Sep 9, 2024
de983de
Correct types of source expression functions
adamchainz Sep 9, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -73,3 +73,7 @@ repos:
rev: v1.11.2
hooks:
- id: mypy
additional_dependencies:
- django-stubs==5.0.4
- mysqlclient
- pytest==8.3.2
11 changes: 5 additions & 6 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -92,20 +92,19 @@ enable_error_code = [
"redundant-expr",
"truthy-bool",
]
check_untyped_defs = true
disallow_any_generics = true
disallow_incomplete_defs = true
disallow_untyped_defs = true
mypy_path = "src/"
namespace_packages = false
no_implicit_optional = true
plugins = [ "mypy_django_plugin.main" ]
strict = true
warn_unreachable = true
warn_unused_ignores = true

[[tool.mypy.overrides]]
module = "tests.*"
allow_untyped_defs = true

[tool.django-stubs]
django_settings_module = "tests.settings"

[tool.rstcheck]
ignore_directives = [
"automodule",
Expand Down
18 changes: 12 additions & 6 deletions src/django_mysql/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from django.core.cache.backends.base import default_key_func
from django.db import connections
from django.db import router
from django.db.models import Model
from django.utils.encoding import force_bytes
from django.utils.module_loading import import_string

Expand Down Expand Up @@ -62,7 +63,9 @@ def __init__(self, table: str, params: dict[str, Any]) -> None:
super().__init__(params)
self._table = table

class CacheEntry:
CacheEntry: type[Model] # force Mypy to accept duck typing

class CacheEntry: # type: ignore [no-redef]
_meta = Options(table)

self.cache_model_class = CacheEntry
Expand Down Expand Up @@ -183,7 +186,7 @@ def get_many(
self, keys: Iterable[str], version: int | None = None
) -> dict[str, Any]:
made_key_to_key = {self.make_key(key, version=version): key for key in keys}
made_keys = list(made_key_to_key.keys())
made_keys: list[Any] = list(made_key_to_key.keys())
for key in made_keys:
self.validate_key(key)

Expand Down Expand Up @@ -266,7 +269,7 @@ def _base_set(
return True
else: # mode = 'add'
# Use a special code in the add query for "did insert"
insert_id = cursor.lastrowid
insert_id: int = cursor.lastrowid
return insert_id != 444

_set_many_query = collapse_spaces(
Expand Down Expand Up @@ -416,7 +419,8 @@ def _base_delta(
raise ValueError("Key '%s' not found, or not an integer" % key)

# New value stored in insert_id
return cursor.lastrowid
result: int = cursor.lastrowid
return result

# Looks a bit tangled to turn the blob back into an int for updating, but
# it works. Stores the new value for insert_id() with LAST_INSERT_ID
Expand Down Expand Up @@ -448,7 +452,7 @@ def touch(
db = router.db_for_write(self.cache_model_class)
table = connections[db].ops.quote_name(self._table)
with connections[db].cursor() as cursor:
affected_rows = cursor.execute(
affected_rows: int = cursor.execute(
self._touch_query.format(table=table), [exp, key, self._now()]
)
return affected_rows > 0
Expand Down Expand Up @@ -612,18 +616,20 @@ def delete_with_prefix(self, prefix: str, version: int | None = None) -> int:
prefix = self.make_key(prefix + "%", version=version)

with connections[db].cursor() as cursor:
return cursor.execute(
result: int = cursor.execute(
"""DELETE FROM {table}
WHERE cache_key LIKE %s""".format(
table=table
),
(prefix,),
)
return result

def cull(self) -> int:
db = router.db_for_write(self.cache_model_class)
table = connections[db].ops.quote_name(self._table)

num_deleted: int
with connections[db].cursor() as cursor:
# First, try just deleting expired keys
num_deleted = cursor.execute(
Expand Down
2 changes: 2 additions & 0 deletions src/django_mysql/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from typing import TypeVar
from typing import cast

__all__ = ("cache",)

if sys.version_info >= (3, 9):
from functools import cache
else:
Expand Down
6 changes: 5 additions & 1 deletion src/django_mysql/locks.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from django.db import connections
from django.db.backends.utils import CursorWrapper
from django.db.models import Model
from django.db.transaction import Atomic
from django.db.transaction import TransactionManagementError
from django.db.transaction import atomic
from django.db.utils import DEFAULT_DB_ALIAS
Expand Down Expand Up @@ -77,7 +78,8 @@ def is_held(self) -> bool:
def holding_connection_id(self) -> int | None:
with self.get_cursor() as cursor:
cursor.execute("SELECT IS_USED_LOCK(%s)", (self.name,))
return cursor.fetchone()[0]
result: int | None = cursor.fetchone()[0]
return result

@classmethod
def held_with_prefix(
Expand Down Expand Up @@ -108,6 +110,7 @@ def __init__(
self.read: list[str] = self._process_names(read)
self.write: list[str] = self._process_names(write)
self.db = DEFAULT_DB_ALIAS if using is None else using
self._atomic: Atomic | None = None

def _process_names(self, names: list[str | type[Model]] | None) -> list[str]:
"""
Expand Down Expand Up @@ -170,6 +173,7 @@ def release(
) -> None:
connection = connections[self.db]
with connection.cursor() as cursor:
assert self._atomic is not None
self._atomic.__exit__(exc_type, exc_value, exc_traceback)
self._atomic = None
cursor.execute("UNLOCK TABLES")
7 changes: 4 additions & 3 deletions src/django_mysql/management/commands/cull_mysql_caches.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,10 @@ def add_arguments(self, parser: argparse.ArgumentParser) -> None:
help="Specify the cache alias(es) to cull.",
)

def handle(
self, *args: Any, verbosity: int, aliases: list[str], **options: Any
) -> None:
def handle(self, *args: Any, **options: Any) -> None:
verbosity: int = options["verbosity"]
aliases: list[str] = options["aliases"]

if not aliases:
aliases = list(settings.CACHES)

Expand Down
8 changes: 5 additions & 3 deletions src/django_mysql/management/commands/dbparams.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,11 @@ def add_arguments(self, parser: argparse.ArgumentParser) -> None:
"pt-online-schema-change $(./manage.py dbparams --dsn)",
)

def handle(
self, *args: Any, alias: str, show_mysql: bool, show_dsn: bool, **options: Any
) -> None:
def handle(self, *args: Any, **options: Any) -> None:
alias: str = options["alias"]
show_mysql: bool = options["show_mysql"]
show_dsn: bool = options["show_dsn"]

try:
connection = connections[alias]
except ConnectionDoesNotExist:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ def add_arguments(self, parser: argparse.ArgumentParser) -> None:
help="Specify the cache alias(es) to create migrations for.",
)

def handle(self, *args: Any, aliases: list[str], **options: Any) -> None:
def handle(self, *args: Any, **options: Any) -> None:
aliases: list[str] = options["aliases"]
if not aliases:
aliases = list(settings.CACHES)

Expand Down
30 changes: 29 additions & 1 deletion src/django_mysql/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from django_mysql.models.aggregates import BitOr
from django_mysql.models.aggregates import BitXor
from django_mysql.models.aggregates import GroupConcat
from django_mysql.models.base import Model # noqa
from django_mysql.models.base import Model
from django_mysql.models.expressions import ListF
from django_mysql.models.expressions import SetF
from django_mysql.models.fields import Bit1BooleanField
Expand All @@ -25,3 +25,31 @@
from django_mysql.models.query import SmartIterator
from django_mysql.models.query import add_QuerySetMixin
from django_mysql.models.query import pt_visual_explain

__all__ = (
"add_QuerySetMixin",
"ApproximateInt",
"Bit1BooleanField",
"BitAnd",
"BitOr",
"BitXor",
"DynamicField",
"EnumField",
"FixedCharField",
"GroupConcat",
"ListCharField",
"ListF",
"ListTextField",
"Model",
"NullBit1BooleanField",
"pt_visual_explain",
"QuerySet",
"QuerySetMixin",
"SetCharField",
"SetF",
"SetTextField",
"SizedBinaryField",
"SizedTextField",
"SmartChunkedIterator",
"SmartIterator",
)
3 changes: 1 addition & 2 deletions src/django_mysql/models/aggregates.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from django.db.backends.base.base import BaseDatabaseWrapper
from django.db.models import Aggregate
from django.db.models import CharField
from django.db.models import Expression
from django.db.models.sql.compiler import SQLCompiler


Expand All @@ -29,7 +28,7 @@ class GroupConcat(Aggregate):

def __init__(
self,
expression: Expression,
expression: Any,
distinct: bool = False,
separator: str | None = None,
ordering: str | None = None,
Expand Down
16 changes: 9 additions & 7 deletions src/django_mysql/models/expressions.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from __future__ import annotations

from typing import Any
from typing import Iterable
from typing import Sequence

from django.db.backends.base.base import BaseDatabaseWrapper
from django.db.models import F
from django.db.models import Value
from django.db.models.expressions import BaseExpression
from django.db.models.expressions import Combinable
from django.db.models.expressions import Expression
from django.db.models.sql.compiler import SQLCompiler

from django_mysql.utils import collapse_spaces
Expand All @@ -18,10 +20,10 @@ def __init__(self, lhs: BaseExpression, rhs: BaseExpression) -> None:
self.lhs = lhs
self.rhs = rhs

def get_source_expressions(self) -> list[BaseExpression]:
def get_source_expressions(self) -> list[Expression]:
return [self.lhs, self.rhs]

def set_source_expressions(self, exprs: Iterable[BaseExpression]) -> None:
def set_source_expressions(self, exprs: Sequence[Combinable | Expression]) -> None:
self.lhs, self.rhs = exprs


Expand Down Expand Up @@ -138,10 +140,10 @@ def __init__(self, lhs: BaseExpression) -> None:
super().__init__()
self.lhs = lhs

def get_source_expressions(self) -> list[BaseExpression]:
def get_source_expressions(self) -> list[Expression]:
return [self.lhs]

def set_source_expressions(self, exprs: Iterable[BaseExpression]) -> None:
def set_source_expressions(self, exprs: Sequence[Combinable | Expression]) -> None:
(self.lhs,) = exprs

def as_sql(
Expand Down Expand Up @@ -170,10 +172,10 @@ def __init__(self, lhs: BaseExpression) -> None:
super().__init__()
self.lhs = lhs

def get_source_expressions(self) -> list[BaseExpression]:
def get_source_expressions(self) -> list[Expression]:
return [self.lhs]

def set_source_expressions(self, exprs: Iterable[BaseExpression]) -> None:
def set_source_expressions(self, exprs: Sequence[Combinable | Expression]) -> None:
(self.lhs,) = exprs

def as_sql(
Expand Down
4 changes: 2 additions & 2 deletions src/django_mysql/models/fields/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from django_mysql.models.fields.sizes import SizedBinaryField
from django_mysql.models.fields.sizes import SizedTextField

__all__ = [
__all__ = (
"Bit1BooleanField",
"DynamicField",
"EnumField",
Expand All @@ -24,4 +24,4 @@
"SetTextField",
"SizedBinaryField",
"SizedTextField",
]
)
14 changes: 7 additions & 7 deletions src/django_mysql/models/fields/dynamic.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from typing import Union
from typing import cast

from django import forms
from django.core import checks
from django.db.backends.base.base import BaseDatabaseWrapper
from django.db.models import DateField
Expand All @@ -22,12 +23,11 @@
from django.db.models import TimeField
from django.db.models import Transform
from django.db.models.sql.compiler import SQLCompiler
from django.forms import Field as FormField
from django.utils.translation import gettext_lazy as _

from django_mysql.checks import mysql_connections
from django_mysql.models.lookups import DynColHasKey
from django_mysql.typing import DeconstructResult
from django_mysql.utils import mysql_connections

try:
import mariadb_dyncol
Expand Down Expand Up @@ -89,7 +89,7 @@ def check(self, **kwargs: Any) -> list[checks.CheckMessage]:
return errors

def _check_mariadb_dyncol(self) -> list[checks.CheckMessage]:
errors = []
errors: list[checks.CheckMessage] = []
if mariadb_dyncol is None:
errors.append(
checks.Error(
Expand All @@ -102,7 +102,7 @@ def _check_mariadb_dyncol(self) -> list[checks.CheckMessage]:
return errors

def _check_mariadb_version(self) -> list[checks.CheckMessage]:
errors = []
errors: list[checks.CheckMessage] = []

any_conn_works = any(
(conn.vendor == "mysql" and conn.mysql_is_mariadb)
Expand All @@ -121,7 +121,7 @@ def _check_mariadb_version(self) -> list[checks.CheckMessage]:
return errors

def _check_character_set(self) -> list[checks.CheckMessage]:
errors = []
errors: list[checks.CheckMessage] = []

conn = None
for _alias, check_conn in mysql_connections():
Expand Down Expand Up @@ -153,7 +153,7 @@ def _check_character_set(self) -> list[checks.CheckMessage]:
def _check_spec_recursively(
self, spec: Any, path: str = ""
) -> list[checks.CheckMessage]:
errors = []
errors: list[checks.CheckMessage] = []

if not isinstance(spec, dict):
errors.append(
Expand Down Expand Up @@ -292,7 +292,7 @@ def deconstruct(self) -> DeconstructResult:
kwargs["blank"] = False
return name, path, args, kwargs

def formfield(self, *args: Any, **kwargs: Any) -> FormField | None:
def formfield(self, *args: Any, **kwargs: Any) -> forms.Field | None:
"""
Disabled in forms - there is no sensible way of editing this
"""
Expand Down
Loading
Loading