-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fixers for INSTALLED_APPS & MIDDLEWARE
- Loading branch information
Showing
14 changed files
with
373 additions
and
36 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
import ast | ||
import pkgutil | ||
from collections import defaultdict | ||
|
||
from django_upgrade.data import FIXERS, ASTCallbackMapping, Settings, State, TokenFunc | ||
from tokenize_rt import Offset | ||
|
||
from django_production import fixers | ||
|
||
|
||
def visit( | ||
tree: ast.Module, | ||
settings: Settings, | ||
filename: str, | ||
) -> dict[Offset, list[TokenFunc]]: | ||
ast_funcs = get_ast_funcs() | ||
initial_state = State( | ||
settings=settings, | ||
filename=filename, | ||
from_imports=defaultdict(set), | ||
) | ||
|
||
nodes: list[tuple[State, ast.AST, ast.AST]] = [(initial_state, tree, tree)] | ||
parents: list[ast.AST] = [tree] | ||
ret = defaultdict(list) | ||
while nodes: | ||
state, node, parent = nodes.pop() | ||
if len(parents) > 1 and parent == parents[-2]: | ||
parents.pop() | ||
elif parent != parents[-1]: | ||
parents.append(parent) | ||
|
||
for ast_func in ast_funcs.get(type(node), [None]): | ||
if ast_func is None: | ||
continue | ||
for offset, token_func in ast_func(state, node, parents): | ||
ret[offset].append(token_func) | ||
|
||
for name in reversed(node._fields): | ||
value = getattr(node, name) | ||
next_state = state | ||
|
||
if isinstance(value, ast.AST): | ||
nodes.append((next_state, value, node)) | ||
elif isinstance(value, list): | ||
for subvalue in reversed(value): | ||
if isinstance(subvalue, ast.AST): | ||
nodes.append((next_state, subvalue, node)) | ||
return ret | ||
|
||
|
||
def _import_fixers() -> None: | ||
# https://github.com/python/mypy/issues/1422 | ||
fixers_path: str = fixers.__path__ # type: ignore | ||
mod_infos = pkgutil.walk_packages(fixers_path, f"{fixers.__name__}.") | ||
for _, name, _ in mod_infos: | ||
__import__(name, fromlist=["_trash"]) | ||
|
||
|
||
_import_fixers() | ||
|
||
|
||
def get_ast_funcs() -> ASTCallbackMapping: | ||
ast_funcs: ASTCallbackMapping = defaultdict(list) | ||
for fixer in FIXERS: | ||
if fixer.name.startswith("django_production."): | ||
for type_, type_funcs in fixer.ast_funcs.items(): | ||
ast_funcs[type_].extend(type_funcs) | ||
return ast_funcs |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
import ast | ||
from functools import partial | ||
from typing import Iterable | ||
|
||
from django_upgrade.ast import ast_start_offset | ||
from django_upgrade.data import Fixer, State, TokenFunc | ||
from django_upgrade.tokens import CODE, find_last_token | ||
from tokenize_rt import Offset, Token | ||
|
||
fixer = Fixer( | ||
__name__, | ||
min_version=(0, 0), | ||
) | ||
|
||
|
||
@fixer.register(ast.Assign) | ||
def visit_Assign( | ||
state: State, | ||
node: ast.Assign, | ||
parents: list[ast.AST], | ||
) -> Iterable[tuple[Offset, TokenFunc]]: | ||
""" | ||
Ensure these are in INSTALLED_APPS | ||
[ | ||
"django_webserver", # Allow running webserver from manage.py | ||
"whitenoise.runserver_nostatic", # Use whitenoise with runserver | ||
] | ||
""" | ||
if node.targets[0].id == "INSTALLED_APPS" and isinstance(node.value, ast.List): | ||
yield ast_start_offset(node), partial(add_apps, node=node) | ||
return [] | ||
|
||
|
||
def add_apps( | ||
tokens: list[Token], | ||
i: int, | ||
*, | ||
node: ast.Assign, | ||
) -> None: | ||
j = find_last_token(tokens, i, node=node) | ||
needs_app = [] | ||
current_apps = [v.s for v in node.value.elts] | ||
for wants_app in ["django_webserver", "whitenoise.runserver_nostatic"]: | ||
if wants_app not in current_apps: | ||
needs_app.append(wants_app) | ||
if len(needs_app) == 0: | ||
return | ||
code = ["INSTALLED_APPS = ["] | ||
code.extend([f' "{a}",' for a in current_apps]) | ||
code.extend([f' "{a}",' for a in needs_app]) | ||
code.append("]") | ||
tokens[i : j + 1] = [Token(name=CODE, src="\n".join(code))] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
import ast | ||
from functools import partial | ||
from typing import Iterable | ||
|
||
from django_upgrade.ast import ast_start_offset | ||
from django_upgrade.data import Fixer, State, TokenFunc | ||
from django_upgrade.tokens import CODE, find_last_token | ||
from tokenize_rt import Offset, Token | ||
|
||
fixer = Fixer( | ||
__name__, | ||
min_version=(0, 0), | ||
) | ||
|
||
|
||
@fixer.register(ast.Assign) | ||
def visit_Assign( | ||
state: State, | ||
node: ast.Assign, | ||
parents: list[ast.AST], | ||
) -> Iterable[tuple[Offset, TokenFunc]]: | ||
""" | ||
try: | ||
MIDDLEWARE.insert( | ||
MIDDLEWARE.index("django.middleware.security.SecurityMiddleware") + 1, | ||
"whitenoise.middleware.WhiteNoiseMiddleware", | ||
) | ||
except ValueError: | ||
MIDDLEWARE.insert(0, "whitenoise.middleware.WhiteNoiseMiddleware") | ||
# skip host checking for healthcheck URLs | ||
MIDDLEWARE.insert(0, "django_alive.middleware.healthcheck_bypass_host_check") | ||
""" | ||
if node.targets[0].id == "MIDDLEWARE" and isinstance(node.value, ast.List): | ||
yield ast_start_offset(node), partial(add_middleware, node=node) | ||
return [] | ||
|
||
|
||
def add_middleware( | ||
tokens: list[Token], | ||
i: int, | ||
*, | ||
node: ast.Assign, | ||
) -> None: | ||
j = find_last_token(tokens, i, node=node) | ||
middleware = [v.s for v in node.value.elts] | ||
original_middleware = middleware.copy() | ||
whitenoise_middleware_path = "whitenoise.middleware.WhiteNoiseMiddleware" | ||
if whitenoise_middleware_path not in middleware: | ||
try: | ||
security_middleware_index = middleware.index( | ||
"django.middleware.security.SecurityMiddleware" | ||
) | ||
except ValueError: | ||
security_middleware_index = -1 | ||
middleware.insert(security_middleware_index + 1, whitenoise_middleware_path) | ||
|
||
alive_middleware_path = "django_alive.middleware.healthcheck_bypass_host_check" | ||
if alive_middleware_path not in middleware: | ||
middleware.insert(0, alive_middleware_path) | ||
if middleware == original_middleware: | ||
return | ||
code = ["MIDDLEWARE = ["] | ||
code.extend([f' "{m}",' for m in middleware]) | ||
code.append("]") | ||
tokens[i : j + 1] = [Token(name=CODE, src="\n".join(code))] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
import tokenize | ||
|
||
from django_upgrade.ast import ast_parse | ||
from django_upgrade.data import Settings | ||
from django_upgrade.main import fixup_dedent_tokens | ||
from tokenize_rt import reversed_enumerate, src_to_tokens, tokens_to_src | ||
|
||
from django_production.data import visit | ||
|
||
|
||
def apply_fixers(contents_text: str, filename: str) -> str: | ||
try: | ||
ast_obj = ast_parse(contents_text) | ||
except SyntaxError: | ||
return contents_text | ||
|
||
callbacks = visit(ast_obj, Settings(target_version=(999, 0)), filename) | ||
|
||
if not callbacks: | ||
return contents_text | ||
|
||
try: | ||
tokens = src_to_tokens(contents_text) | ||
except tokenize.TokenError: # pragma: no cover (bpo-2180) | ||
return contents_text | ||
|
||
fixup_dedent_tokens(tokens) | ||
|
||
for i, token in reversed_enumerate(tokens): | ||
if not token.src: | ||
continue | ||
# though this is a defaultdict, by using `.get()` this function's | ||
# self time is almost 50% faster | ||
for callback in callbacks.get(token.offset, ()): | ||
callback(tokens, i) | ||
|
||
# no types for tokenize-rt | ||
return tokens_to_src(tokens) # type: ignore [no-any-return] |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
from __future__ import annotations | ||
|
||
import pytest | ||
|
||
pytest.register_assert_rewrite(f"{__name__}.tools") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
from __future__ import annotations | ||
|
||
from django_upgrade.data import Settings | ||
|
||
from tests.fixers.tools import check_noop, check_transformed | ||
|
||
settings = Settings(target_version=(999, 0)) | ||
|
||
|
||
def test_noop(): | ||
check_noop( | ||
"""\ | ||
INSTALLED_APPS = ["django_webserver", "whitenoise.runserver_nostatic"] | ||
""", | ||
settings, | ||
filename="settings.py", | ||
) | ||
|
||
|
||
def test_updated(): | ||
check_transformed( | ||
"""\ | ||
INSTALLED_APPS = ["appone", "apptwo"] | ||
""", | ||
"""\ | ||
INSTALLED_APPS = [ | ||
"appone", | ||
"apptwo", | ||
"django_webserver", | ||
"whitenoise.runserver_nostatic", | ||
] | ||
""", | ||
settings, | ||
filename="settings.py", | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
from __future__ import annotations | ||
|
||
from django_upgrade.data import Settings | ||
|
||
from tests.fixers.tools import check_noop, check_transformed | ||
|
||
settings = Settings(target_version=(999, 0)) | ||
|
||
|
||
def test_noop(): | ||
check_noop( | ||
"""\ | ||
MIDDLEWARE = [ | ||
# not modified | ||
"django_alive.middleware.healthcheck_bypass_host_check", | ||
"django.middleware.security.SecurityMiddleware", | ||
"whitenoise.middleware.WhiteNoiseMiddleware", | ||
"django.contrib.sessions.middleware.SessionMiddleware", | ||
"django.middleware.common.CommonMiddleware", | ||
"django.middleware.csrf.CsrfViewMiddleware", | ||
"django.contrib.auth.middleware.AuthenticationMiddleware", | ||
"django.contrib.messages.middleware.MessageMiddleware", | ||
"django.middleware.clickjacking.XFrameOptionsMiddleware", | ||
] | ||
""", | ||
settings, | ||
filename="settings.py", | ||
) | ||
|
||
|
||
def test_updated(): | ||
check_transformed( | ||
"""\ | ||
MIDDLEWARE = [ | ||
"django.middleware.security.SecurityMiddleware", | ||
"django.contrib.sessions.middleware.SessionMiddleware", | ||
"django.middleware.common.CommonMiddleware", | ||
"django.middleware.csrf.CsrfViewMiddleware", | ||
"django.contrib.auth.middleware.AuthenticationMiddleware", | ||
"django.contrib.messages.middleware.MessageMiddleware", | ||
"django.middleware.clickjacking.XFrameOptionsMiddleware", | ||
] | ||
""", | ||
"""\ | ||
MIDDLEWARE = [ | ||
"django_alive.middleware.healthcheck_bypass_host_check", | ||
"django.middleware.security.SecurityMiddleware", | ||
"whitenoise.middleware.WhiteNoiseMiddleware", | ||
"django.contrib.sessions.middleware.SessionMiddleware", | ||
"django.middleware.common.CommonMiddleware", | ||
"django.middleware.csrf.CsrfViewMiddleware", | ||
"django.contrib.auth.middleware.AuthenticationMiddleware", | ||
"django.contrib.messages.middleware.MessageMiddleware", | ||
"django.middleware.clickjacking.XFrameOptionsMiddleware", | ||
] | ||
""", | ||
settings, | ||
filename="settings.py", | ||
) |
Oops, something went wrong.