Skip to content

Commit

Permalink
Fixers for INSTALLED_APPS & MIDDLEWARE
Browse files Browse the repository at this point in the history
  • Loading branch information
ipmb committed Oct 20, 2023
1 parent 2476001 commit a00122a
Show file tree
Hide file tree
Showing 14 changed files with 373 additions and 36 deletions.
9 changes: 5 additions & 4 deletions django_production/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import django

from django_production.modifiers import add_imports
from django_production.main import apply_fixers

START_MARKER = "\n# BEGIN: added by django-production"
END_MARKER = "# END: added by django-production\n"
Expand Down Expand Up @@ -55,8 +55,9 @@ def do_patch():
sys.path.insert(0, os.getcwd())
django.setup()
settings = import_module(os.environ["DJANGO_SETTINGS_MODULE"])
patch_settings(settings)
patch_urlconf(settings)
apply_fixers(
contents_text=Path(settings.__file__).read_text(), filename=settings.__file__
)


def fix_file(
Expand Down Expand Up @@ -86,4 +87,4 @@ def fix_file(

if exit_zero_even_if_changed:
return 0
return contents_text != contents_text_orig
return contents_text != contents_text_orig
69 changes: 69 additions & 0 deletions django_production/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import ast
import pkgutil
from collections import defaultdict

from django_upgrade.data import FIXERS, ASTCallbackMapping, Settings, State, TokenFunc
from tokenize_rt import Offset

from django_production import fixers


def visit(
tree: ast.Module,
settings: Settings,
filename: str,
) -> dict[Offset, list[TokenFunc]]:
ast_funcs = get_ast_funcs()
initial_state = State(
settings=settings,
filename=filename,
from_imports=defaultdict(set),
)

nodes: list[tuple[State, ast.AST, ast.AST]] = [(initial_state, tree, tree)]
parents: list[ast.AST] = [tree]
ret = defaultdict(list)
while nodes:
state, node, parent = nodes.pop()
if len(parents) > 1 and parent == parents[-2]:
parents.pop()
elif parent != parents[-1]:
parents.append(parent)

for ast_func in ast_funcs.get(type(node), [None]):
if ast_func is None:
continue
for offset, token_func in ast_func(state, node, parents):
ret[offset].append(token_func)

for name in reversed(node._fields):
value = getattr(node, name)
next_state = state

if isinstance(value, ast.AST):
nodes.append((next_state, value, node))
elif isinstance(value, list):
for subvalue in reversed(value):
if isinstance(subvalue, ast.AST):
nodes.append((next_state, subvalue, node))
return ret


def _import_fixers() -> None:
# https://github.com/python/mypy/issues/1422
fixers_path: str = fixers.__path__ # type: ignore
mod_infos = pkgutil.walk_packages(fixers_path, f"{fixers.__name__}.")
for _, name, _ in mod_infos:
__import__(name, fromlist=["_trash"])


_import_fixers()


def get_ast_funcs() -> ASTCallbackMapping:
ast_funcs: ASTCallbackMapping = defaultdict(list)
for fixer in FIXERS:
if fixer.name.startswith("django_production."):
for type_, type_funcs in fixer.ast_funcs.items():
ast_funcs[type_].extend(type_funcs)
return ast_funcs
Empty file.
52 changes: 52 additions & 0 deletions django_production/fixers/installed_apps.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import ast
from functools import partial
from typing import Iterable

from django_upgrade.ast import ast_start_offset
from django_upgrade.data import Fixer, State, TokenFunc
from django_upgrade.tokens import CODE, find_last_token
from tokenize_rt import Offset, Token

fixer = Fixer(
__name__,
min_version=(0, 0),
)


@fixer.register(ast.Assign)
def visit_Assign(
state: State,
node: ast.Assign,
parents: list[ast.AST],
) -> Iterable[tuple[Offset, TokenFunc]]:
"""
Ensure these are in INSTALLED_APPS
[
"django_webserver", # Allow running webserver from manage.py
"whitenoise.runserver_nostatic", # Use whitenoise with runserver
]
"""
if node.targets[0].id == "INSTALLED_APPS" and isinstance(node.value, ast.List):
yield ast_start_offset(node), partial(add_apps, node=node)
return []


def add_apps(
tokens: list[Token],
i: int,
*,
node: ast.Assign,
) -> None:
j = find_last_token(tokens, i, node=node)
needs_app = []
current_apps = [v.s for v in node.value.elts]
for wants_app in ["django_webserver", "whitenoise.runserver_nostatic"]:
if wants_app not in current_apps:
needs_app.append(wants_app)
if len(needs_app) == 0:
return
code = ["INSTALLED_APPS = ["]
code.extend([f' "{a}",' for a in current_apps])
code.extend([f' "{a}",' for a in needs_app])
code.append("]")
tokens[i : j + 1] = [Token(name=CODE, src="\n".join(code))]
66 changes: 66 additions & 0 deletions django_production/fixers/middleware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import ast
from functools import partial
from typing import Iterable

from django_upgrade.ast import ast_start_offset
from django_upgrade.data import Fixer, State, TokenFunc
from django_upgrade.tokens import CODE, find_last_token
from tokenize_rt import Offset, Token

fixer = Fixer(
__name__,
min_version=(0, 0),
)


@fixer.register(ast.Assign)
def visit_Assign(
state: State,
node: ast.Assign,
parents: list[ast.AST],
) -> Iterable[tuple[Offset, TokenFunc]]:
"""
try:
MIDDLEWARE.insert(
MIDDLEWARE.index("django.middleware.security.SecurityMiddleware") + 1,
"whitenoise.middleware.WhiteNoiseMiddleware",
)
except ValueError:
MIDDLEWARE.insert(0, "whitenoise.middleware.WhiteNoiseMiddleware")
# skip host checking for healthcheck URLs
MIDDLEWARE.insert(0, "django_alive.middleware.healthcheck_bypass_host_check")
"""
if node.targets[0].id == "MIDDLEWARE" and isinstance(node.value, ast.List):
yield ast_start_offset(node), partial(add_middleware, node=node)
return []


def add_middleware(
tokens: list[Token],
i: int,
*,
node: ast.Assign,
) -> None:
j = find_last_token(tokens, i, node=node)
middleware = [v.s for v in node.value.elts]
original_middleware = middleware.copy()
whitenoise_middleware_path = "whitenoise.middleware.WhiteNoiseMiddleware"
if whitenoise_middleware_path not in middleware:
try:
security_middleware_index = middleware.index(
"django.middleware.security.SecurityMiddleware"
)
except ValueError:
security_middleware_index = -1
middleware.insert(security_middleware_index + 1, whitenoise_middleware_path)

alive_middleware_path = "django_alive.middleware.healthcheck_bypass_host_check"
if alive_middleware_path not in middleware:
middleware.insert(0, alive_middleware_path)
if middleware == original_middleware:
return
code = ["MIDDLEWARE = ["]
code.extend([f' "{m}",' for m in middleware])
code.append("]")
tokens[i : j + 1] = [Token(name=CODE, src="\n".join(code))]
38 changes: 38 additions & 0 deletions django_production/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import tokenize

from django_upgrade.ast import ast_parse
from django_upgrade.data import Settings
from django_upgrade.main import fixup_dedent_tokens
from tokenize_rt import reversed_enumerate, src_to_tokens, tokens_to_src

from django_production.data import visit


def apply_fixers(contents_text: str, filename: str) -> str:
try:
ast_obj = ast_parse(contents_text)
except SyntaxError:
return contents_text

callbacks = visit(ast_obj, Settings(target_version=(999, 0)), filename)

if not callbacks:
return contents_text

try:
tokens = src_to_tokens(contents_text)
except tokenize.TokenError: # pragma: no cover (bpo-2180)
return contents_text

fixup_dedent_tokens(tokens)

for i, token in reversed_enumerate(tokens):
if not token.src:
continue
# though this is a defaultdict, by using `.get()` this function's
# self time is almost 50% faster
for callback in callbacks.get(token.offset, ()):
callback(tokens, i)

# no types for tokenize-rt
return tokens_to_src(tokens) # type: ignore [no-any-return]
17 changes: 0 additions & 17 deletions django_production/modifiers.py

This file was deleted.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ classifiers = [
]
dynamic = ["version", "description"]
dependencies = [
"django-upgrade",
"django-environ",
"whitenoise",
"django-webserver[gunicorn]",
Expand Down
5 changes: 5 additions & 0 deletions tests/fixers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from __future__ import annotations

import pytest

pytest.register_assert_rewrite(f"{__name__}.tools")
35 changes: 35 additions & 0 deletions tests/fixers/test_installed_apps.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from __future__ import annotations

from django_upgrade.data import Settings

from tests.fixers.tools import check_noop, check_transformed

settings = Settings(target_version=(999, 0))


def test_noop():
check_noop(
"""\
INSTALLED_APPS = ["django_webserver", "whitenoise.runserver_nostatic"]
""",
settings,
filename="settings.py",
)


def test_updated():
check_transformed(
"""\
INSTALLED_APPS = ["appone", "apptwo"]
""",
"""\
INSTALLED_APPS = [
"appone",
"apptwo",
"django_webserver",
"whitenoise.runserver_nostatic",
]
""",
settings,
filename="settings.py",
)
59 changes: 59 additions & 0 deletions tests/fixers/test_middleware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
from __future__ import annotations

from django_upgrade.data import Settings

from tests.fixers.tools import check_noop, check_transformed

settings = Settings(target_version=(999, 0))


def test_noop():
check_noop(
"""\
MIDDLEWARE = [
# not modified
"django_alive.middleware.healthcheck_bypass_host_check",
"django.middleware.security.SecurityMiddleware",
"whitenoise.middleware.WhiteNoiseMiddleware",
"django.contrib.sessions.middleware.SessionMiddleware",
"django.middleware.common.CommonMiddleware",
"django.middleware.csrf.CsrfViewMiddleware",
"django.contrib.auth.middleware.AuthenticationMiddleware",
"django.contrib.messages.middleware.MessageMiddleware",
"django.middleware.clickjacking.XFrameOptionsMiddleware",
]
""",
settings,
filename="settings.py",
)


def test_updated():
check_transformed(
"""\
MIDDLEWARE = [
"django.middleware.security.SecurityMiddleware",
"django.contrib.sessions.middleware.SessionMiddleware",
"django.middleware.common.CommonMiddleware",
"django.middleware.csrf.CsrfViewMiddleware",
"django.contrib.auth.middleware.AuthenticationMiddleware",
"django.contrib.messages.middleware.MessageMiddleware",
"django.middleware.clickjacking.XFrameOptionsMiddleware",
]
""",
"""\
MIDDLEWARE = [
"django_alive.middleware.healthcheck_bypass_host_check",
"django.middleware.security.SecurityMiddleware",
"whitenoise.middleware.WhiteNoiseMiddleware",
"django.contrib.sessions.middleware.SessionMiddleware",
"django.middleware.common.CommonMiddleware",
"django.middleware.csrf.CsrfViewMiddleware",
"django.contrib.auth.middleware.AuthenticationMiddleware",
"django.contrib.messages.middleware.MessageMiddleware",
"django.middleware.clickjacking.XFrameOptionsMiddleware",
]
""",
settings,
filename="settings.py",
)
Loading

0 comments on commit a00122a

Please sign in to comment.