From 2476001c41c448e9aa45653394e6663ddca2ce25 Mon Sep 17 00:00:00 2001 From: Peter Baumgartner Date: Thu, 19 Oct 2023 11:44:40 -0400 Subject: [PATCH] WIP ast parser --- django_production/__main__.py | 32 ++++++++ django_production/modifiers.py | 17 +++++ settings.py | 0 tests/default_django_settings.py | 123 +++++++++++++++++++++++++++++++ tests/test_rewrite.py | 21 ++++++ 5 files changed, 193 insertions(+) create mode 100644 django_production/modifiers.py create mode 100644 settings.py create mode 100644 tests/default_django_settings.py create mode 100644 tests/test_rewrite.py diff --git a/django_production/__main__.py b/django_production/__main__.py index 503d781..36b7ec1 100644 --- a/django_production/__main__.py +++ b/django_production/__main__.py @@ -7,6 +7,8 @@ import django +from django_production.modifiers import add_imports + START_MARKER = "\n# BEGIN: added by django-production" END_MARKER = "# END: added by django-production\n" @@ -55,3 +57,33 @@ def do_patch(): settings = import_module(os.environ["DJANGO_SETTINGS_MODULE"]) patch_settings(settings) patch_urlconf(settings) + + +def fix_file( + filename: str, + exit_zero_even_if_changed: bool, +) -> int: + if filename == "-": + contents_bytes = sys.stdin.buffer.read() + else: + with open(filename, "rb") as fb: + contents_bytes = fb.read() + + try: + contents_text_orig = contents_text = contents_bytes.decode() + except UnicodeDecodeError: + print(f"{filename} is non-utf-8 (not supported)") + return 1 + + contents_text = add_imports(contents_text, filename) + + if filename == "-": + print(contents_text, end="") + elif contents_text != contents_text_orig: + print(f"Rewriting {filename}", file=sys.stderr) + with open(filename, "w", encoding="UTF-8", newline="") as f: + f.write(contents_text) + + if exit_zero_even_if_changed: + return 0 + return contents_text != contents_text_orig \ No newline at end of file diff --git a/django_production/modifiers.py b/django_production/modifiers.py new file mode 100644 index 0000000..e8e1618 --- /dev/null +++ b/django_production/modifiers.py @@ -0,0 +1,17 @@ +import ast + + +def add_imports(contents: str, filename: str) -> str: + tree = ast.parse(contents.encode(), filename=filename) + + # Check if the 'os' module has been imported + os_imported = any( + isinstance(node, ast.Import) and any(alias.name == 'os' for alias in node.names) for node in ast.walk(tree)) + + # If 'os' module is not imported, add the import statement + if not os_imported: + import_os = ast.Import(names=[ast.alias(name='os', asname=None)]) + tree.body.insert(0, import_os) + + # Generate the modified code + return compile(tree, filename, 'exec') diff --git a/settings.py b/settings.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/default_django_settings.py b/tests/default_django_settings.py new file mode 100644 index 0000000..803caec --- /dev/null +++ b/tests/default_django_settings.py @@ -0,0 +1,123 @@ +""" +Django settings for tmp project. + +Generated by 'django-admin startproject' using Django 4.2.6. + +For more information on this file, see +https://docs.djangoproject.com/en/4.2/topics/settings/ + +For the full list of settings and their values, see +https://docs.djangoproject.com/en/4.2/ref/settings/ +""" + +from pathlib import Path + +# Build paths inside the project like this: BASE_DIR / 'subdir'. +BASE_DIR = Path(__file__).resolve().parent.parent + + +# Quick-start development settings - unsuitable for production +# See https://docs.djangoproject.com/en/4.2/howto/deployment/checklist/ + +# SECURITY WARNING: keep the secret key used in production secret! +SECRET_KEY = "django-insecure-!$!f8*cmiu&wbjtd=a6ypeufxues9z7r#sylrsuudodjo52s$(" + +# SECURITY WARNING: don't run with debug turned on in production! +DEBUG = True + +ALLOWED_HOSTS = [] + + +# Application definition + +INSTALLED_APPS = [ + "django.contrib.admin", + "django.contrib.auth", + "django.contrib.contenttypes", + "django.contrib.sessions", + "django.contrib.messages", + "django.contrib.staticfiles", +] + +MIDDLEWARE = [ + "django.middleware.security.SecurityMiddleware", + "django.contrib.sessions.middleware.SessionMiddleware", + "django.middleware.common.CommonMiddleware", + "django.middleware.csrf.CsrfViewMiddleware", + "django.contrib.auth.middleware.AuthenticationMiddleware", + "django.contrib.messages.middleware.MessageMiddleware", + "django.middleware.clickjacking.XFrameOptionsMiddleware", +] + +ROOT_URLCONF = "tmp.urls" + +TEMPLATES = [ + { + "BACKEND": "django.template.backends.django.DjangoTemplates", + "DIRS": [], + "APP_DIRS": True, + "OPTIONS": { + "context_processors": [ + "django.template.context_processors.debug", + "django.template.context_processors.request", + "django.contrib.auth.context_processors.auth", + "django.contrib.messages.context_processors.messages", + ], + }, + }, +] + +WSGI_APPLICATION = "tmp.wsgi.application" + + +# Database +# https://docs.djangoproject.com/en/4.2/ref/settings/#databases + +DATABASES = { + "default": { + "ENGINE": "django.db.backends.sqlite3", + "NAME": BASE_DIR / "db.sqlite3", + } +} + + +# Password validation +# https://docs.djangoproject.com/en/4.2/ref/settings/#auth-password-validators + +AUTH_PASSWORD_VALIDATORS = [ + { + "NAME": "django.contrib.auth.password_validation.UserAttributeSimilarityValidator", + }, + { + "NAME": "django.contrib.auth.password_validation.MinimumLengthValidator", + }, + { + "NAME": "django.contrib.auth.password_validation.CommonPasswordValidator", + }, + { + "NAME": "django.contrib.auth.password_validation.NumericPasswordValidator", + }, +] + + +# Internationalization +# https://docs.djangoproject.com/en/4.2/topics/i18n/ + +LANGUAGE_CODE = "en-us" + +TIME_ZONE = "UTC" + +USE_I18N = True + +USE_TZ = True + + +# Static files (CSS, JavaScript, Images) +# https://docs.djangoproject.com/en/4.2/howto/static-files/ + +STATIC_URL = "static/" + +# Default primary key field type +# https://docs.djangoproject.com/en/4.2/ref/settings/#default-auto-field + +DEFAULT_AUTO_FIELD = "django.db.models.BigAutoField" diff --git a/tests/test_rewrite.py b/tests/test_rewrite.py new file mode 100644 index 0000000..3828ca2 --- /dev/null +++ b/tests/test_rewrite.py @@ -0,0 +1,21 @@ +import io +from pathlib import Path + +import pytest +from django_production.modifiers import add_imports + +# Define a fixture to create an in-memory file with content +@pytest.fixture +def temporary_settings_file(tmp_path) -> [str, str]: + content = (Path(__file__).parent / "default_django_settings.py").read_text() + tmp_settings = tmp_path / "settings.py" + tmp_settings.write_text(content) + return content, str(tmp_settings.name) + + +def test_imports_added(temporary_settings_file): + contents, filename = temporary_settings_file + modded = add_imports(contents, filename) + + # Check if 'os' import has been added + assert "import os" in modded.splitlines()