diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index ad875596..dd5bffd6 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -73,3 +73,7 @@ repos: rev: v1.11.2 hooks: - id: mypy + additional_dependencies: + - django-stubs==5.0.4 + - mysqlclient + - pytest==8.3.2 diff --git a/pyproject.toml b/pyproject.toml index 1fc4b11f..85e24dcc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -92,20 +92,19 @@ enable_error_code = [ "redundant-expr", "truthy-bool", ] -check_untyped_defs = true -disallow_any_generics = true -disallow_incomplete_defs = true -disallow_untyped_defs = true mypy_path = "src/" namespace_packages = false -no_implicit_optional = true +plugins = [ "mypy_django_plugin.main" ] +strict = true warn_unreachable = true -warn_unused_ignores = true [[tool.mypy.overrides]] module = "tests.*" allow_untyped_defs = true +[tool.django-stubs] +django_settings_module = "tests.settings" + [tool.rstcheck] ignore_directives = [ "automodule", diff --git a/src/django_mysql/cache.py b/src/django_mysql/cache.py index 6cbcf146..230b5e78 100644 --- a/src/django_mysql/cache.py +++ b/src/django_mysql/cache.py @@ -19,6 +19,7 @@ from django.core.cache.backends.base import default_key_func from django.db import connections from django.db import router +from django.db.models import Model from django.utils.encoding import force_bytes from django.utils.module_loading import import_string @@ -62,7 +63,9 @@ def __init__(self, table: str, params: dict[str, Any]) -> None: super().__init__(params) self._table = table - class CacheEntry: + CacheEntry: type[Model] # force Mypy to accept duck typing + + class CacheEntry: # type: ignore [no-redef] _meta = Options(table) self.cache_model_class = CacheEntry @@ -183,7 +186,7 @@ def get_many( self, keys: Iterable[str], version: int | None = None ) -> dict[str, Any]: made_key_to_key = {self.make_key(key, version=version): key for key in keys} - made_keys = list(made_key_to_key.keys()) + made_keys: list[Any] = list(made_key_to_key.keys()) for key in made_keys: self.validate_key(key) @@ -266,7 +269,7 @@ def _base_set( return True else: # mode = 'add' # Use a special code in the add query for "did insert" - insert_id = cursor.lastrowid + insert_id: int = cursor.lastrowid return insert_id != 444 _set_many_query = collapse_spaces( @@ -416,7 +419,8 @@ def _base_delta( raise ValueError("Key '%s' not found, or not an integer" % key) # New value stored in insert_id - return cursor.lastrowid + result: int = cursor.lastrowid + return result # Looks a bit tangled to turn the blob back into an int for updating, but # it works. Stores the new value for insert_id() with LAST_INSERT_ID @@ -448,7 +452,7 @@ def touch( db = router.db_for_write(self.cache_model_class) table = connections[db].ops.quote_name(self._table) with connections[db].cursor() as cursor: - affected_rows = cursor.execute( + affected_rows: int = cursor.execute( self._touch_query.format(table=table), [exp, key, self._now()] ) return affected_rows > 0 @@ -612,18 +616,20 @@ def delete_with_prefix(self, prefix: str, version: int | None = None) -> int: prefix = self.make_key(prefix + "%", version=version) with connections[db].cursor() as cursor: - return cursor.execute( + result: int = cursor.execute( """DELETE FROM {table} WHERE cache_key LIKE %s""".format( table=table ), (prefix,), ) + return result def cull(self) -> int: db = router.db_for_write(self.cache_model_class) table = connections[db].ops.quote_name(self._table) + num_deleted: int with connections[db].cursor() as cursor: # First, try just deleting expired keys num_deleted = cursor.execute( diff --git a/src/django_mysql/compat.py b/src/django_mysql/compat.py index a072e556..97a0b089 100644 --- a/src/django_mysql/compat.py +++ b/src/django_mysql/compat.py @@ -6,6 +6,8 @@ from typing import TypeVar from typing import cast +__all__ = ("cache",) + if sys.version_info >= (3, 9): from functools import cache else: diff --git a/src/django_mysql/locks.py b/src/django_mysql/locks.py index 9470df1f..39d8b0f1 100644 --- a/src/django_mysql/locks.py +++ b/src/django_mysql/locks.py @@ -6,6 +6,7 @@ from django.db import connections from django.db.backends.utils import CursorWrapper from django.db.models import Model +from django.db.transaction import Atomic from django.db.transaction import TransactionManagementError from django.db.transaction import atomic from django.db.utils import DEFAULT_DB_ALIAS @@ -77,7 +78,8 @@ def is_held(self) -> bool: def holding_connection_id(self) -> int | None: with self.get_cursor() as cursor: cursor.execute("SELECT IS_USED_LOCK(%s)", (self.name,)) - return cursor.fetchone()[0] + result: int | None = cursor.fetchone()[0] + return result @classmethod def held_with_prefix( @@ -108,6 +110,7 @@ def __init__( self.read: list[str] = self._process_names(read) self.write: list[str] = self._process_names(write) self.db = DEFAULT_DB_ALIAS if using is None else using + self._atomic: Atomic | None = None def _process_names(self, names: list[str | type[Model]] | None) -> list[str]: """ @@ -170,6 +173,7 @@ def release( ) -> None: connection = connections[self.db] with connection.cursor() as cursor: + assert self._atomic is not None self._atomic.__exit__(exc_type, exc_value, exc_traceback) self._atomic = None cursor.execute("UNLOCK TABLES") diff --git a/src/django_mysql/management/commands/cull_mysql_caches.py b/src/django_mysql/management/commands/cull_mysql_caches.py index ae8d6d81..27e2683a 100644 --- a/src/django_mysql/management/commands/cull_mysql_caches.py +++ b/src/django_mysql/management/commands/cull_mysql_caches.py @@ -31,9 +31,10 @@ def add_arguments(self, parser: argparse.ArgumentParser) -> None: help="Specify the cache alias(es) to cull.", ) - def handle( - self, *args: Any, verbosity: int, aliases: list[str], **options: Any - ) -> None: + def handle(self, *args: Any, **options: Any) -> None: + verbosity: int = options["verbosity"] + aliases: list[str] = options["aliases"] + if not aliases: aliases = list(settings.CACHES) diff --git a/src/django_mysql/management/commands/dbparams.py b/src/django_mysql/management/commands/dbparams.py index fa6791f9..4de2c3d7 100644 --- a/src/django_mysql/management/commands/dbparams.py +++ b/src/django_mysql/management/commands/dbparams.py @@ -51,9 +51,11 @@ def add_arguments(self, parser: argparse.ArgumentParser) -> None: "pt-online-schema-change $(./manage.py dbparams --dsn)", ) - def handle( - self, *args: Any, alias: str, show_mysql: bool, show_dsn: bool, **options: Any - ) -> None: + def handle(self, *args: Any, **options: Any) -> None: + alias: str = options["alias"] + show_mysql: bool = options["show_mysql"] + show_dsn: bool = options["show_dsn"] + try: connection = connections[alias] except ConnectionDoesNotExist: diff --git a/src/django_mysql/management/commands/mysql_cache_migration.py b/src/django_mysql/management/commands/mysql_cache_migration.py index c73008cf..9222c23d 100644 --- a/src/django_mysql/management/commands/mysql_cache_migration.py +++ b/src/django_mysql/management/commands/mysql_cache_migration.py @@ -30,7 +30,8 @@ def add_arguments(self, parser: argparse.ArgumentParser) -> None: help="Specify the cache alias(es) to create migrations for.", ) - def handle(self, *args: Any, aliases: list[str], **options: Any) -> None: + def handle(self, *args: Any, **options: Any) -> None: + aliases: list[str] = options["aliases"] if not aliases: aliases = list(settings.CACHES) diff --git a/src/django_mysql/models/__init__.py b/src/django_mysql/models/__init__.py index 5b074b13..36ea1f52 100644 --- a/src/django_mysql/models/__init__.py +++ b/src/django_mysql/models/__init__.py @@ -4,7 +4,7 @@ from django_mysql.models.aggregates import BitOr from django_mysql.models.aggregates import BitXor from django_mysql.models.aggregates import GroupConcat -from django_mysql.models.base import Model # noqa +from django_mysql.models.base import Model from django_mysql.models.expressions import ListF from django_mysql.models.expressions import SetF from django_mysql.models.fields import Bit1BooleanField @@ -25,3 +25,31 @@ from django_mysql.models.query import SmartIterator from django_mysql.models.query import add_QuerySetMixin from django_mysql.models.query import pt_visual_explain + +__all__ = ( + "add_QuerySetMixin", + "ApproximateInt", + "Bit1BooleanField", + "BitAnd", + "BitOr", + "BitXor", + "DynamicField", + "EnumField", + "FixedCharField", + "GroupConcat", + "ListCharField", + "ListF", + "ListTextField", + "Model", + "NullBit1BooleanField", + "pt_visual_explain", + "QuerySet", + "QuerySetMixin", + "SetCharField", + "SetF", + "SetTextField", + "SizedBinaryField", + "SizedTextField", + "SmartChunkedIterator", + "SmartIterator", +) diff --git a/src/django_mysql/models/aggregates.py b/src/django_mysql/models/aggregates.py index 6370831c..58170761 100644 --- a/src/django_mysql/models/aggregates.py +++ b/src/django_mysql/models/aggregates.py @@ -5,7 +5,6 @@ from django.db.backends.base.base import BaseDatabaseWrapper from django.db.models import Aggregate from django.db.models import CharField -from django.db.models import Expression from django.db.models.sql.compiler import SQLCompiler @@ -29,7 +28,7 @@ class GroupConcat(Aggregate): def __init__( self, - expression: Expression, + expression: Any, distinct: bool = False, separator: str | None = None, ordering: str | None = None, diff --git a/src/django_mysql/models/expressions.py b/src/django_mysql/models/expressions.py index 464314ba..35c8ff51 100644 --- a/src/django_mysql/models/expressions.py +++ b/src/django_mysql/models/expressions.py @@ -1,12 +1,14 @@ from __future__ import annotations from typing import Any -from typing import Iterable +from typing import Sequence from django.db.backends.base.base import BaseDatabaseWrapper from django.db.models import F from django.db.models import Value from django.db.models.expressions import BaseExpression +from django.db.models.expressions import Combinable +from django.db.models.expressions import Expression from django.db.models.sql.compiler import SQLCompiler from django_mysql.utils import collapse_spaces @@ -18,10 +20,10 @@ def __init__(self, lhs: BaseExpression, rhs: BaseExpression) -> None: self.lhs = lhs self.rhs = rhs - def get_source_expressions(self) -> list[BaseExpression]: + def get_source_expressions(self) -> list[Expression]: return [self.lhs, self.rhs] - def set_source_expressions(self, exprs: Iterable[BaseExpression]) -> None: + def set_source_expressions(self, exprs: Sequence[Combinable | Expression]) -> None: self.lhs, self.rhs = exprs @@ -138,10 +140,10 @@ def __init__(self, lhs: BaseExpression) -> None: super().__init__() self.lhs = lhs - def get_source_expressions(self) -> list[BaseExpression]: + def get_source_expressions(self) -> list[Expression]: return [self.lhs] - def set_source_expressions(self, exprs: Iterable[BaseExpression]) -> None: + def set_source_expressions(self, exprs: Sequence[Combinable | Expression]) -> None: (self.lhs,) = exprs def as_sql( @@ -170,10 +172,10 @@ def __init__(self, lhs: BaseExpression) -> None: super().__init__() self.lhs = lhs - def get_source_expressions(self) -> list[BaseExpression]: + def get_source_expressions(self) -> list[Expression]: return [self.lhs] - def set_source_expressions(self, exprs: Iterable[BaseExpression]) -> None: + def set_source_expressions(self, exprs: Sequence[Combinable | Expression]) -> None: (self.lhs,) = exprs def as_sql( diff --git a/src/django_mysql/models/fields/__init__.py b/src/django_mysql/models/fields/__init__.py index f413947d..f2d0e011 100644 --- a/src/django_mysql/models/fields/__init__.py +++ b/src/django_mysql/models/fields/__init__.py @@ -12,7 +12,7 @@ from django_mysql.models.fields.sizes import SizedBinaryField from django_mysql.models.fields.sizes import SizedTextField -__all__ = [ +__all__ = ( "Bit1BooleanField", "DynamicField", "EnumField", @@ -24,4 +24,4 @@ "SetTextField", "SizedBinaryField", "SizedTextField", -] +) diff --git a/src/django_mysql/models/fields/dynamic.py b/src/django_mysql/models/fields/dynamic.py index db6891b3..63efbcc1 100644 --- a/src/django_mysql/models/fields/dynamic.py +++ b/src/django_mysql/models/fields/dynamic.py @@ -10,6 +10,7 @@ from typing import Union from typing import cast +from django import forms from django.core import checks from django.db.backends.base.base import BaseDatabaseWrapper from django.db.models import DateField @@ -22,12 +23,11 @@ from django.db.models import TimeField from django.db.models import Transform from django.db.models.sql.compiler import SQLCompiler -from django.forms import Field as FormField from django.utils.translation import gettext_lazy as _ -from django_mysql.checks import mysql_connections from django_mysql.models.lookups import DynColHasKey from django_mysql.typing import DeconstructResult +from django_mysql.utils import mysql_connections try: import mariadb_dyncol @@ -89,7 +89,7 @@ def check(self, **kwargs: Any) -> list[checks.CheckMessage]: return errors def _check_mariadb_dyncol(self) -> list[checks.CheckMessage]: - errors = [] + errors: list[checks.CheckMessage] = [] if mariadb_dyncol is None: errors.append( checks.Error( @@ -102,7 +102,7 @@ def _check_mariadb_dyncol(self) -> list[checks.CheckMessage]: return errors def _check_mariadb_version(self) -> list[checks.CheckMessage]: - errors = [] + errors: list[checks.CheckMessage] = [] any_conn_works = any( (conn.vendor == "mysql" and conn.mysql_is_mariadb) @@ -121,7 +121,7 @@ def _check_mariadb_version(self) -> list[checks.CheckMessage]: return errors def _check_character_set(self) -> list[checks.CheckMessage]: - errors = [] + errors: list[checks.CheckMessage] = [] conn = None for _alias, check_conn in mysql_connections(): @@ -153,7 +153,7 @@ def _check_character_set(self) -> list[checks.CheckMessage]: def _check_spec_recursively( self, spec: Any, path: str = "" ) -> list[checks.CheckMessage]: - errors = [] + errors: list[checks.CheckMessage] = [] if not isinstance(spec, dict): errors.append( @@ -292,7 +292,7 @@ def deconstruct(self) -> DeconstructResult: kwargs["blank"] = False return name, path, args, kwargs - def formfield(self, *args: Any, **kwargs: Any) -> FormField | None: + def formfield(self, *args: Any, **kwargs: Any) -> forms.Field | None: """ Disabled in forms - there is no sensible way of editing this """ diff --git a/src/django_mysql/models/fields/lists.py b/src/django_mysql/models/fields/lists.py index 5d25f01f..52732f8e 100644 --- a/src/django_mysql/models/fields/lists.py +++ b/src/django_mysql/models/fields/lists.py @@ -2,9 +2,9 @@ from typing import Any from typing import Callable -from typing import Iterable from typing import cast +from django import forms from django.core import checks from django.db.backends.base.base import BaseDatabaseWrapper from django.db.models import CharField @@ -14,7 +14,7 @@ from django.db.models import Model from django.db.models import TextField from django.db.models.expressions import BaseExpression -from django.forms import Field as FormField +from django.db.models.sql.compiler import SQLCompiler from django.utils.translation import gettext_lazy as _ from django_mysql.forms import SimpleListField @@ -74,7 +74,7 @@ def check(self, **kwargs: Any) -> list[checks.CheckMessage]: return errors @property - def description(self) -> Any: + def description(self) -> str: return _("List of %(base_description)s") % { "base_description": self.base_field.description } @@ -154,7 +154,7 @@ def value_to_string(self, obj: Any) -> str: vals = self.value_from_object(obj) return self.get_prep_value(vals) - def formfield(self, **kwargs: Any) -> FormField: + def formfield(self, **kwargs: Any) -> forms.Field | None: defaults = { "form_class": SimpleListField, "base_field": self.base_field.formfield(), @@ -163,8 +163,10 @@ def formfield(self, **kwargs: Any) -> FormField: defaults.update(kwargs) return super().formfield(**defaults) - def contribute_to_class(self, cls: type[Model], name: str, **kwargs: Any) -> None: - super().contribute_to_class(cls, name, **kwargs) + def contribute_to_class( + self, cls: type[Model], name: str, private_only: bool = False + ) -> None: + super().contribute_to_class(cls, name, private_only=private_only) self.base_field.model = cls @@ -228,11 +230,11 @@ def __init__(self, index: int, *args: Any, **kwargs: Any) -> None: self.index = index def as_sql( - self, qn: Callable[[str], str], connection: BaseDatabaseWrapper - ) -> tuple[str, Iterable[Any]]: - lhs, lhs_params = self.process_lhs(qn, connection) - rhs, rhs_params = self.process_rhs(qn, connection) - params = tuple(lhs_params) + tuple(rhs_params) + self, compiler: SQLCompiler, connection: BaseDatabaseWrapper + ) -> tuple[str, list[str | int]]: + lhs, lhs_params = self.process_lhs(compiler, connection) + rhs, rhs_params = self.process_rhs(compiler, connection) + params = list(lhs_params) + list(rhs_params) # Put rhs on the left since that's the order FIND_IN_SET uses return f"(FIND_IN_SET({rhs}, {lhs}) = {self.index})", params diff --git a/src/django_mysql/models/fields/sets.py b/src/django_mysql/models/fields/sets.py index b7c0d6d4..24a29c8d 100644 --- a/src/django_mysql/models/fields/sets.py +++ b/src/django_mysql/models/fields/sets.py @@ -3,6 +3,7 @@ from typing import Any from typing import cast +from django import forms from django.core import checks from django.db.backends.base.base import BaseDatabaseWrapper from django.db.models import CharField @@ -11,7 +12,6 @@ from django.db.models import Model from django.db.models import TextField from django.db.models.expressions import BaseExpression -from django.forms import Field as FormField from django.utils.translation import gettext_lazy as _ from django_mysql.forms import SimpleSetField @@ -136,7 +136,7 @@ def value_to_string(self, obj: Any) -> str: vals = self.value_from_object(obj) return self.get_prep_value(vals) - def formfield(self, **kwargs: Any) -> FormField: + def formfield(self, **kwargs: Any) -> forms.Field | None: defaults = { "form_class": SimpleSetField, "base_field": self.base_field.formfield(), @@ -145,8 +145,10 @@ def formfield(self, **kwargs: Any) -> FormField: defaults.update(kwargs) return super().formfield(**defaults) - def contribute_to_class(self, cls: type[Model], name: str, **kwargs: Any) -> None: - super().contribute_to_class(cls, name, **kwargs) + def contribute_to_class( + self, cls: type[Model], name: str, private_only: bool = False + ) -> None: + super().contribute_to_class(cls, name, private_only=private_only) self.base_field.model = cls @@ -167,6 +169,7 @@ def check(self, **kwargs: Any) -> list[checks.CheckMessage]: and isinstance(self.base_field, CharField) and self.size ): + assert self.base_field.max_length is not None max_size = ( # The chars used (self.size * (self.base_field.max_length)) diff --git a/src/django_mysql/models/functions.py b/src/django_mysql/models/functions.py index 62623f32..fa252190 100644 --- a/src/django_mysql/models/functions.py +++ b/src/django_mysql/models/functions.py @@ -169,7 +169,8 @@ def get(cls, using: str = DEFAULT_DB_ALIAS) -> int: # database connections in Django, and the reason was not clear with connections[using].cursor() as cursor: cursor.execute("SELECT LAST_INSERT_ID()") - return cursor.fetchone()[0] + id_: int = cursor.fetchone()[0] + return id_ # JSON Functions @@ -182,7 +183,7 @@ def __init__( self, expression: ExpressionArgument, *paths: ExpressionArgument, - output_field: type[DjangoField] | None = None, + output_field: DjangoField | None = None, ) -> None: exprs = [expression] for path in paths: @@ -258,7 +259,11 @@ def as_sql( if connection.vendor != "mysql": # pragma: no cover raise AssertionError("JSONValue only supports MySQL/MariaDB") json_string = json.dumps(self._data, allow_nan=False) - if connection.vendor == "mysql" and connection.mysql_is_mariadb: + if ( + connection.vendor == "mysql" + # type narrowed by vendor check + and connection.mysql_is_mariadb # type: ignore [attr-defined] + ): # MariaDB doesn't support explicit cast to JSON. return "JSON_EXTRACT(%s, '$')", (json_string,) else: @@ -270,7 +275,7 @@ def __init__( self, expression: ExpressionArgument, data: dict[ - str, + ExpressionArgument, ( ExpressionArgument | None @@ -288,12 +293,12 @@ def __init__( exprs = [expression] for path, value in data.items(): - if not hasattr(path, "resolve_expression"): + if not isinstance(path, Expression): path = Value(path) exprs.append(path) - if not hasattr(value, "resolve_expression"): + if not isinstance(value, Expression): value = JSONValue(value) exprs.append(value) @@ -373,7 +378,11 @@ class AsType(Func): function = "" template = "%(expressions)s AS %(data_type)s" - def __init__(self, expression: ExpressionArgument, data_type: str) -> None: + def __init__( + self, + expression: Expression | str | float | int | dt.date | dt.time | dt.datetime, + data_type: str, + ) -> None: from django_mysql.models.fields.dynamic import KeyTransform if not hasattr(expression, "resolve_expression"): @@ -392,7 +401,7 @@ def __init__( self, expression: ExpressionArgument, to_add: dict[ - str, + ExpressionArgument, ExpressionArgument | float | int | dt.date | dt.time | dt.datetime, ], ) -> None: @@ -400,12 +409,12 @@ def __init__( expressions = [expression] for name, value in to_add.items(): - if not hasattr(name, "resolve_expression"): + if not isinstance(name, Expression): name = Value(name) if isinstance(value, dict): raise ValueError("ColumnAdd with nested values is not supported") - if not hasattr(value, "resolve_expression"): + if not isinstance(value, Expression): value = Value(value) expressions.extend((name, value)) diff --git a/src/django_mysql/models/lookups.py b/src/django_mysql/models/lookups.py index 5b67b842..8eb922af 100644 --- a/src/django_mysql/models/lookups.py +++ b/src/django_mysql/models/lookups.py @@ -1,7 +1,6 @@ from __future__ import annotations from typing import Any -from typing import Callable from typing import Iterable from django.db.backends.base.base import BaseDatabaseWrapper @@ -24,11 +23,11 @@ class SoundsLike(Lookup): def as_sql( self, - qn: Callable[[str], str], + compiler: SQLCompiler, connection: BaseDatabaseWrapper, ) -> tuple[str, Iterable[Any]]: - lhs, lhs_params = self.process_lhs(qn, connection) - rhs, rhs_params = self.process_rhs(qn, connection) + lhs, lhs_params = self.process_lhs(compiler, connection) + rhs, rhs_params = self.process_rhs(compiler, connection) params = tuple(lhs_params) + tuple(rhs_params) return f"{lhs} SOUNDS LIKE {rhs}", params @@ -66,10 +65,10 @@ def get_prep_lookup(self) -> Any: return super().get_prep_lookup() def as_sql( - self, qn: Callable[[str], str], connection: BaseDatabaseWrapper + self, compiler: SQLCompiler, connection: BaseDatabaseWrapper ) -> tuple[str, Iterable[Any]]: - lhs, lhs_params = self.process_lhs(qn, connection) - rhs, rhs_params = self.process_rhs(qn, connection) + lhs, lhs_params = self.process_lhs(compiler, connection) + rhs, rhs_params = self.process_rhs(compiler, connection) # Put rhs (and params) on the left since that's the order FIND_IN_SET uses params = tuple(rhs_params) + tuple(lhs_params) return f"FIND_IN_SET({rhs}, {lhs})", params @@ -86,9 +85,9 @@ class DynColHasKey(Lookup): lookup_name = "has_key" def as_sql( - self, qn: Callable[[str], str], connection: BaseDatabaseWrapper + self, compiler: SQLCompiler, connection: BaseDatabaseWrapper ) -> tuple[str, Iterable[Any]]: - lhs, lhs_params = self.process_lhs(qn, connection) - rhs, rhs_params = self.process_rhs(qn, connection) + lhs, lhs_params = self.process_lhs(compiler, connection) + rhs, rhs_params = self.process_rhs(compiler, connection) params = tuple(lhs_params) + tuple(rhs_params) return f"COLUMN_EXISTS({lhs}, {rhs})", params diff --git a/src/django_mysql/models/query.py b/src/django_mysql/models/query.py index 61a16837..bcb67854 100644 --- a/src/django_mysql/models/query.py +++ b/src/django_mysql/models/query.py @@ -698,7 +698,7 @@ def approx_count(queryset: models.QuerySet) -> int: ) # N.B. when we support more complex QuerySets they should be estimated # with 'EXPLAIN SELECT' - approx_count = cursor.fetchone()[0] + approx_count: int = cursor.fetchone()[0] return approx_count diff --git a/src/django_mysql/operations.py b/src/django_mysql/operations.py index ea1ff5b4..c5003bd5 100644 --- a/src/django_mysql/operations.py +++ b/src/django_mysql/operations.py @@ -2,7 +2,7 @@ from django.db.backends.base.schema import BaseDatabaseSchemaEditor from django.db.migrations.operations.base import Operation -from django.db.migrations.state import ModelState +from django.db.migrations.state import ProjectState from django.utils.functional import cached_property @@ -15,15 +15,15 @@ def __init__(self, name: str, soname: str) -> None: self.name = name self.soname = soname - def state_forwards(self, app_label: str, state: ModelState) -> None: + def state_forwards(self, app_label: str, state: ProjectState) -> None: pass # pragma: no cover def database_forwards( self, app_label: str, schema_editor: BaseDatabaseSchemaEditor, - from_st: ModelState, - to_st: ModelState, + from_st: ProjectState, + to_st: ProjectState, ) -> None: if not self.plugin_installed(schema_editor): schema_editor.execute( @@ -34,8 +34,8 @@ def database_backwards( self, app_label: str, schema_editor: BaseDatabaseSchemaEditor, - from_st: ModelState, - to_st: ModelState, + from_st: ProjectState, + to_st: ProjectState, ) -> None: if self.plugin_installed(schema_editor): schema_editor.execute("UNINSTALL PLUGIN %s" % self.name) @@ -48,7 +48,7 @@ def plugin_installed(self, schema_editor: BaseDatabaseSchemaEditor) -> bool: WHERE PLUGIN_NAME LIKE %s""", (self.name,), ) - count = cursor.fetchone()[0] + count: int = cursor.fetchone()[0] return count > 0 def describe(self) -> str: @@ -63,15 +63,15 @@ class InstallSOName(Operation): def __init__(self, soname: str) -> None: self.soname = soname - def state_forwards(self, app_label: str, state: ModelState) -> None: + def state_forwards(self, app_label: str, state: ProjectState) -> None: pass # pragma: no cover def database_forwards( self, app_label: str, schema_editor: BaseDatabaseSchemaEditor, - from_st: ModelState, - to_st: ModelState, + from_st: ProjectState, + to_st: ProjectState, ) -> None: schema_editor.execute("INSTALL SONAME %s", (self.soname,)) @@ -79,8 +79,8 @@ def database_backwards( self, app_label: str, schema_editor: BaseDatabaseSchemaEditor, - from_st: ModelState, - to_st: ModelState, + from_st: ProjectState, + to_st: ProjectState, ) -> None: schema_editor.execute("UNINSTALL SONAME %s", (self.soname,)) @@ -96,19 +96,17 @@ def __init__( self.engine = to_engine self.from_engine = from_engine - @property - def reversible(self) -> bool: - return self.from_engine is not None + self.reversible = self.from_engine is not None - def state_forwards(self, app_label: str, state: ModelState) -> None: + def state_forwards(self, app_label: str, state: ProjectState) -> None: pass def database_forwards( self, app_label: str, schema_editor: BaseDatabaseSchemaEditor, - from_state: ModelState, - to_state: ModelState, + from_state: ProjectState, + to_state: ProjectState, ) -> None: self._change_engine(app_label, schema_editor, to_state, engine=self.engine) @@ -116,8 +114,8 @@ def database_backwards( self, app_label: str, schema_editor: BaseDatabaseSchemaEditor, - from_state: ModelState, - to_state: ModelState, + from_state: ProjectState, + to_state: ProjectState, ) -> None: if self.from_engine is None: raise NotImplementedError("You cannot reverse this operation") @@ -128,7 +126,7 @@ def _change_engine( self, app_label: str, schema_editor: BaseDatabaseSchemaEditor, - to_state: ModelState, + to_state: ProjectState, engine: str, ) -> None: new_model = to_state.apps.get_model(app_label, self.name) diff --git a/src/django_mysql/status.py b/src/django_mysql/status.py index fea900ae..9fd76db2 100644 --- a/src/django_mysql/status.py +++ b/src/django_mysql/status.py @@ -6,7 +6,6 @@ from django.db import connections from django.db.backends.utils import CursorWrapper from django.db.utils import DEFAULT_DB_ALIAS -from django.utils.functional import SimpleLazyObject from django_mysql.exceptions import TimeoutError @@ -16,6 +15,7 @@ class BaseStatus: Base class for the status classes """ + __slots__ = ("db",) query = "" def __init__(self, using: str | None = None) -> None: @@ -57,7 +57,7 @@ def get_many(self, names: Iterable[str]) -> dict[str, int | float | bool | str]: ] ) - cursor.execute(query, names) + cursor.execute(query, tuple(names)) return {name: self._cast(value) for name, value in cursor.fetchall()} @@ -127,5 +127,5 @@ class SessionStatus(BaseStatus): query = "SHOW SESSION STATUS" -global_status = SimpleLazyObject(GlobalStatus) -session_status = SimpleLazyObject(SessionStatus) +global_status = GlobalStatus() +session_status = SessionStatus() diff --git a/src/django_mysql/utils.py b/src/django_mysql/utils.py index 6a98540f..5d5de374 100644 --- a/src/django_mysql/utils.py +++ b/src/django_mysql/utils.py @@ -132,7 +132,9 @@ def collapse_spaces(string: str) -> str: return " ".join(filter(None, bits)) -def index_name(model: Model, *field_names: str, using: str = DEFAULT_DB_ALIAS) -> str: +def index_name( + model: type[Model], *field_names: str, using: str = DEFAULT_DB_ALIAS +) -> str: """ Returns the name of the index existing on field_names, or raises KeyError if no such index exists. @@ -162,7 +164,7 @@ def index_name(model: Model, *field_names: str, using: str = DEFAULT_DB_ALIAS) - ), (model._meta.db_table,) + column_names, ) - indexes = defaultdict(list) + indexes: defaultdict[str, list[str]] = defaultdict(list) for index_name, _, column_name in cursor.fetchall(): indexes[index_name].append(column_name) @@ -177,7 +179,7 @@ def get_list_sql(sequence: list[str] | tuple[str, ...]) -> str: return "({})".format(",".join("%s" for x in sequence)) -def mysql_connections() -> Generator[BaseDatabaseWrapper]: +def mysql_connections() -> Generator[tuple[str, BaseDatabaseWrapper]]: conn_names = [DEFAULT_DB_ALIAS] + list(set(connections) - {DEFAULT_DB_ALIAS}) for alias in conn_names: connection = connections[alias] diff --git a/tests/testapp/test_bit1_field.py b/tests/testapp/test_bit1_field.py index 47df7a12..92f8ae95 100644 --- a/tests/testapp/test_bit1_field.py +++ b/tests/testapp/test_bit1_field.py @@ -95,6 +95,7 @@ def test_loading(self): objs = list(serializers.deserialize("json", test_data)) assert len(objs) == 1 instance = objs[0].object + assert isinstance(instance, Bit1Model) assert not instance.flag_a assert instance.flag_b @@ -180,6 +181,7 @@ def test_loading(self): objs = list(serializers.deserialize("json", test_data)) assert len(objs) == 1 instance = objs[0].object + assert isinstance(instance, NullBit1Model) assert instance.flag is None else: diff --git a/tests/testapp/test_cache.py b/tests/testapp/test_cache.py index cc19bd45..60cb33a3 100644 --- a/tests/testapp/test_cache.py +++ b/tests/testapp/test_cache.py @@ -9,6 +9,7 @@ from typing import Any import pytest +from django.core.cache import BaseCache from django.core.cache import CacheKeyWarning from django.core.cache import cache from django.core.cache import caches @@ -100,7 +101,10 @@ def reverse_custom_key_func(full_key): } -def caches_setting_for_tests(options=None, **params): +def caches_setting_for_tests( + options: dict[str, Any] | None = None, + **params: Any, +) -> dict[str, Any]: # `params` are test specific overrides and `_caches_settings_base` is the # base config for the tests. # This results in the following search order: @@ -117,8 +121,10 @@ def caches_setting_for_tests(options=None, **params): # Spaces are used in the table name to ensure quoting/escaping is working def override_cache_settings( - BACKEND="django_mysql.cache.MySQLCache", LOCATION="test cache table", **kwargs -): + BACKEND: str = "django_mysql.cache.MySQLCache", + LOCATION: str = "test cache table", + **kwargs: Any, +) -> override_settings: return override_settings( CACHES=caches_setting_for_tests(BACKEND=BACKEND, LOCATION=LOCATION, **kwargs) ) @@ -128,13 +134,13 @@ class MySQLCacheTableMixin(TransactionTestCase): table_name = "test cache table" @classmethod - def create_table(self): + def create_table(self) -> None: sql = MySQLCache.create_table_sql.format(table_name=self.table_name) with connection.cursor() as cursor: cursor.execute(sql) @classmethod - def drop_table(self): + def drop_table(self) -> None: with connection.cursor() as cursor: cursor.execute("DROP TABLE `%s`" % self.table_name) @@ -153,10 +159,11 @@ def tearDownClass(cls): super().tearDownClass() cls.drop_table() - def table_count(self): + def table_count(self) -> int: with connection.cursor() as cursor: cursor.execute("SELECT COUNT(*) FROM `%s`" % self.table_name) - return cursor.fetchone()[0] + count: int = cursor.fetchone()[0] + return count # These tests were copied from django's tests/cache/tests.py file @@ -726,7 +733,7 @@ def test_cache_write_unpicklable_object(self): fetch_middleware = FetchFromCacheMiddleware(empty_response) request = self.factory.get("/cache/test") - request._cache_update_cache = True + request._cache_update_cache = True # type: ignore [attr-defined] get_cache_data = FetchFromCacheMiddleware(empty_response).process_request( request ) @@ -779,10 +786,10 @@ def test_get_or_set_version(self): cache.get_or_set("brian", 1979, version=2) with pytest.raises(TypeError, match=msg_re): - cache.get_or_set("brian") + cache.get_or_set("brian") # type: ignore [call-arg] with pytest.raises(TypeError, match=msg_re): - cache.get_or_set("brian", version=1) + cache.get_or_set("brian", version=1) # type: ignore [call-arg] assert cache.get("brian", version=1) is None assert cache.get_or_set("brian", 42, version=1) == 42 @@ -915,6 +922,7 @@ def func(key, *args): # Original tests def test_base_set_bad_value(self): + assert isinstance(cache, MySQLCache) with pytest.raises(ValueError) as excinfo: cache._base_set("foo", "key", "value") assert "'mode' should be" in str(excinfo.value) @@ -997,7 +1005,9 @@ def test_cull_deletes_expired_first(self): self._perform_cull_test(cull_cache, 30, 30) assert cull_cache.get("key") is None - def _perform_cull_test(self, cull_cache, initial_count, final_count): + def _perform_cull_test( + self, cull_cache: BaseCache, initial_count: int, final_count: int + ) -> None: # Create initial cache key entries. This will overflow the cache, # causing a cull. for i in range(1, initial_count + 1): @@ -1137,6 +1147,7 @@ def test_keys_with_prefix_version(self, cache_name): @override_cache_settings(KEY_FUNCTION=custom_key_func) def test_keys_with_prefix_with_bad_cache(self): + assert isinstance(cache, MySQLCache) with pytest.raises(ValueError) as excinfo: cache.keys_with_prefix("") assert str(excinfo.value).startswith("To use the _with_prefix commands") @@ -1176,6 +1187,7 @@ def test_get_with_prefix_version(self, cache_name): @override_cache_settings(KEY_FUNCTION=custom_key_func) def test_get_with_prefix_with_bad_cache(self): + assert isinstance(cache, MySQLCache) with pytest.raises(ValueError) as excinfo: cache.get_with_prefix("") assert str(excinfo.value).startswith("To use the _with_prefix commands") @@ -1233,6 +1245,7 @@ def test_delete_with_prefix_version(self, cache_name): @override_cache_settings(KEY_FUNCTION=custom_key_func) def test_delete_with_prefix_with_no_reverse_works(self): + assert isinstance(cache, MySQLCache) cache.set_many({"K1": "value", "K2": "value2", "B2": "Anothervalue"}) assert cache.delete_with_prefix("K") == 2 assert cache.get_many(["K1", "K2", "B2"]) == {"B2": "Anothervalue"} @@ -1262,6 +1275,7 @@ def test_mysql_cache_migration_no_mysql_caches(self): def test_cull_max_entries_minus_one(self): # cull with MAX_ENTRIES = -1 should never clear anything that is not # expired + assert isinstance(cache, MySQLCache) # one expired key cache.set("key", "value", 0.1) @@ -1301,7 +1315,7 @@ def test_cull_mysql_caches_bad_cache_name(self): @override_cache_settings() class MySQLCacheMigrationTests(MySQLCacheTableMixin, TransactionTestCase): @pytest.fixture(autouse=True) - def flake8_path(self, flake8_path): + def set_flake8_path(self, flake8_path): self.flake8_path = flake8_path def test_mysql_cache_migration(self): @@ -1341,7 +1355,7 @@ def test_mysql_cache_migration(self): operation.database_backwards("testapp", editor, new_state, state) assert not self.table_exists(self.table_name) - def table_exists(self, table_name): + def table_exists(self, table_name: str) -> bool: with connection.cursor() as cursor: cursor.execute( """SELECT COUNT(*) FROM INFORMATION_SCHEMA.TABLES diff --git a/tests/testapp/test_dynamicfield.py b/tests/testapp/test_dynamicfield.py index 7beaecf2..c07cab4a 100644 --- a/tests/testapp/test_dynamicfield.py +++ b/tests/testapp/test_dynamicfield.py @@ -2,6 +2,7 @@ import datetime as dt import json +from typing import cast from unittest import SkipTest from unittest import mock @@ -12,6 +13,7 @@ from django.db import connection from django.db import connections from django.db import models +from django.db.backends.mysql.base import DatabaseWrapper from django.db.migrations.writer import MigrationWriter from django.db.models import CharField from django.db.models import Transform @@ -27,7 +29,7 @@ class DynColTestCase(TestCase): @classmethod def setUpClass(cls): - if not connection.mysql_is_mariadb: + if not cast(DatabaseWrapper, connection).mysql_is_mariadb: raise SkipTest("Dynamic Columns require MariaDB") super().setUpClass() @@ -493,6 +495,7 @@ def test_dumping(self): def test_loading(self): deserialized = list(serializers.deserialize("json", self.test_data)) instance = deserialized[0].object + assert isinstance(instance, DynamicModel) assert instance.attrs == {"a": "b"} diff --git a/tests/testapp/test_listcharfield.py b/tests/testapp/test_listcharfield.py index 28f8a850..8e1afa98 100644 --- a/tests/testapp/test_listcharfield.py +++ b/tests/testapp/test_listcharfield.py @@ -567,6 +567,7 @@ def test_loading(self): """ objs = list(serializers.deserialize("json", test_data)) instance = objs[0].object + assert isinstance(instance, CharListModel) assert instance.field == ["big", "leather", "comfy"] diff --git a/tests/testapp/test_listtextfield.py b/tests/testapp/test_listtextfield.py index eb0f8523..6476f291 100644 --- a/tests/testapp/test_listtextfield.py +++ b/tests/testapp/test_listtextfield.py @@ -358,6 +358,7 @@ def test_loading(self): """ objs = list(serializers.deserialize("json", test_data)) instance = objs[0].object + assert isinstance(instance, BigCharListModel) assert instance.field == ["big", "leather", "comfy"] def test_dumping_loading_empty(self): diff --git a/tests/testapp/test_locks.py b/tests/testapp/test_locks.py index 2269675d..c3ca46e6 100644 --- a/tests/testapp/test_locks.py +++ b/tests/testapp/test_locks.py @@ -3,11 +3,13 @@ import queue from threading import Thread from typing import TYPE_CHECKING +from typing import cast import pytest from django.db import OperationalError from django.db import connection from django.db import connections +from django.db.backends.mysql.base import DatabaseWrapper from django.db.transaction import TransactionManagementError from django.db.transaction import atomic from django.test import TestCase @@ -34,7 +36,7 @@ class LockTests(TestCase): def setUpClass(cls): super().setUpClass() - cls.supports_lock_info = connection.mysql_is_mariadb + cls.supports_lock_info = cast(DatabaseWrapper, connection).mysql_is_mariadb if cls.supports_lock_info: with connection.cursor() as cursor: cursor.execute( @@ -229,7 +231,7 @@ def tearDown(self): Customer.objects.using("other").all().delete() super().tearDown() - def is_locked(self, connection_name, table_name): + def is_locked(self, connection_name: str, table_name: str) -> bool: conn = connections[connection_name] with conn.cursor() as cursor: cursor.execute( @@ -239,7 +241,9 @@ def is_locked(self, connection_name, table_name): rows = cursor.fetchall() if rows: assert len(rows) == 1 - return rows[0][2] > 0 + value = rows[0][2] + assert isinstance(value, int) + return value > 0 else: # pragma: no cover # MySQL 8+ closes the table really quickly. If it's closed, # it's not locked. diff --git a/tests/testapp/test_models.py b/tests/testapp/test_models.py index f9251d27..ac8eb161 100644 --- a/tests/testapp/test_models.py +++ b/tests/testapp/test_models.py @@ -330,7 +330,7 @@ def test_force_index_at_least_one(self): def test_force_index_invalid_for(self): with pytest.raises(ValueError) as excinfo: - Author.objects.force_index("a", for_="INVALID") + Author.objects.force_index("a", for_="INVALID") # type: ignore [arg-type] assert "for_ must be one of" in str(excinfo.value) def test_index_hint_force_order_by(self): @@ -527,7 +527,7 @@ def test_objects_pk_range_reversed(self): def test_objects_pk_range_bad(self): with pytest.raises(ValueError) as excinfo: - list(Author.objects.iter_smart(pk_range="My Bad Value")) + list(Author.objects.iter_smart(pk_range="My Bad Value")) # type: ignore [arg-type] assert "Unrecognized value for pk_range" in str(excinfo.value) def test_pk_range_race_condition(self): diff --git a/tests/testapp/test_setcharfield.py b/tests/testapp/test_setcharfield.py index 004a41d6..25e9066e 100644 --- a/tests/testapp/test_setcharfield.py +++ b/tests/testapp/test_setcharfield.py @@ -543,6 +543,7 @@ def test_loading(self): """ objs = list(serializers.deserialize("json", test_data)) instance = objs[0].object + assert isinstance(instance, CharSetModel) assert instance.field == {"big", "leather", "comfy"} diff --git a/tests/testapp/test_settextfield.py b/tests/testapp/test_settextfield.py index fe19f16c..ed8501ab 100644 --- a/tests/testapp/test_settextfield.py +++ b/tests/testapp/test_settextfield.py @@ -336,6 +336,7 @@ def test_loading(self): """ objs = list(serializers.deserialize("json", test_data)) instance = objs[0].object + assert isinstance(instance, BigCharSetModel) assert instance.field == {"big", "leather", "comfy"} def test_empty(self): diff --git a/tests/testapp/test_size_fields.py b/tests/testapp/test_size_fields.py index b480ab1a..399597f4 100644 --- a/tests/testapp/test_size_fields.py +++ b/tests/testapp/test_size_fields.py @@ -22,7 +22,7 @@ forceDataError = override_mysql_variables(SQL_MODE="STRICT_TRANS_TABLES") -def migrate(name): +def migrate(name: str) -> None: call_command( "migrate", "testapp", name, verbosity=0, skip_checks=True, interactive=False ) @@ -50,7 +50,7 @@ def test_binaryfield_default_length(self): # By default, SizedBinaryField should act like BinaryField field = SizedBinaryField() assert field.size_class == 4 - assert field.db_type(None) == "longblob" + assert field.db_type(connection) == "longblob" @atomic def test_binary_1_max_length(self): @@ -153,7 +153,7 @@ def test_textfield_default_length(self): # By default, SizedTextField should act like TextField field = SizedTextField() assert field.size_class == 4 - assert field.db_type(None) == "longtext" + assert field.db_type(connection) == "longtext" def test_tinytext_max_length(self): # Okay