diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 9a57056..766101e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,42 +1,30 @@ -ci: - autoupdate_branch: "2.1.x" - autoupdate_schedule: monthly repos: - repo: https://github.com/asottile/pyupgrade - rev: v2.34.0 + rev: v3.9.0 hooks: - id: pyupgrade - args: ["--py37-plus"] - - repo: https://github.com/asottile/reorder_python_imports - rev: v3.1.0 + args: ["--py38-plus"] + - repo: https://github.com/pycqa/isort + rev: 5.12.0 hooks: - - id: reorder-python-imports - name: Reorder Python imports (src, tests) - files: "^(?!examples/)" - args: ["--application-directories", ".:src"] - additional_dependencies: ["setuptools>60.9"] - - id: reorder-python-imports - name: Reorder Python imports (examples) - files: "^examples/" - args: ["--application-directories", "examples"] - additional_dependencies: ["setuptools>60.9"] + - id: isort - repo: https://github.com/psf/black - rev: 22.6.0 + rev: 23.7.0 hooks: - id: black - repo: https://github.com/PyCQA/flake8 - rev: 4.0.1 + rev: 6.0.0 hooks: - id: flake8 additional_dependencies: - flake8-bugbear - flake8-implicit-str-concat - repo: https://github.com/peterdemin/pip-compile-multi - rev: v2.4.5 + rev: v2.6.3 hooks: - id: pip-compile-multi-verify - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.3.0 + rev: v4.4.0 hooks: - id: fix-byte-order-marker - id: trailing-whitespace diff --git a/.readthedocs.yaml b/.readthedocs.yaml index d1f4efa..e20d0c2 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -3,7 +3,7 @@ version: 2 build: os: ubuntu-22.04 tools: - python: "3.10" + python: "3.12" python: install: diff --git a/CHANGES.rst b/CHANGES.rst index 5d71eb8..c8ac40f 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -1,6 +1,86 @@ +0.19.6 2024-05-19 +----------------- + +* Bugfix use ContentRange in the right way. See issue #331. +* Bugfix hold a strong reference to background tasks. +* Bugfix avoid ResourceWarning in DataBody.__aiter__. + +0.19.5 2024-04-01 +----------------- + +* Bugfix DeprecationWarning from datetime.utcnow(). +* Bugfix ensure request files are closed. +* Bugfix development restarting when commands are passed. +* Restore teardown_websocket methods. +* Correct the config_class type. +* Allow kwargs to be passed to the test client (matches Flask API). + +0.19.4 2023-11-19 +----------------- + +* Bugfix program not closing on Ctrl+C in Windows. +* Bugfix the typing for AfterWebsocket functions. +* Improve the typing of the ensure_async method. +* Add a shutdown event to the app. + +0.19.3 2023-10-04 +----------------- + +* Bugfix update the default config to better match Flask. + +0.19.2 2023-10-01 +----------------- + +* Bugfix restore the app {after, before}_websocket methods. +* Bugfix correctly set the cli Group in Quart. + +0.19.1 2023-09-30 +----------------- + +* Bugfix remove QUART_ENV and env usage. + +0.19.0 2023-09-30 +----------------- + +* Remove Flask-Patch. It has been replaced with the Quart-Flask-Patch + extension. +* Remove references to first request, as per Flask. +* Await the background tasks before calling the after serving funcs. +* Don't copy the app context into the background task. +* Allow background tasks a grace period to complete during shutdown. +* Base Quart on Flask, utilising Flask code where possible. This + introduces a dependency on Flask. +* Bugfix trailing slash issue in URL concatenation for empty 'path' +* Bugfix Issue #219. Use only CR in SSE documentation. +* Bugfix typing for websocket to accept auth data. +* Bugfix ensure subdomains apply to nested blueprints. +* Bugfix ensure make_response errors if the value is incorrect. +* Bugfix propagated exception handling. +* Bugfix ensure exceptions propagate before logging. +* Bugfix cope with scope extension value being None. +* Bugfix ensure the conditional 304 response is empty. +* Bugfix handle empty path in URL concatenation. +* Bugfix corrected typing hint for abort method at helpers.py. +* Bugfix root_path usage. +* Fix Werkzeug deprecation warnings. +* Add svg's to jinja's autoescaping. +* Improve the WebsocketResponse error, by including the response. +* Add a file mode parameter to the config.from_file method. +* Show the subdomain or host in the routes command output. +* Upgrade to blinker 1.6. +* Require Werkzeug 3.0.0 and Flask 3.0.0. +* Use tomllib rather than toml. + +0.18.4 2023-04-09 +----------------- + +* Restrict blinker to < 1.6 for 0.18.x versions to ensure it works + with Quart's implementation. + 0.18.3 2022-10-08 ----------------- +* Fixed Issue #206. Corrected quart.json.loads type annotation. * Bugfix signal handling on Windows. * Bugfix add missing globals to Flask-Patch. diff --git a/README.rst b/README.rst index 1296e2c..7174398 100644 --- a/README.rst +++ b/README.rst @@ -24,7 +24,7 @@ Quart can be installed via `pip $ pip install quart -and requires Python 3.7.0 or higher (see `python version support +and requires Python 3.8.0 or higher (see `python version support `_ for reasoning). @@ -102,10 +102,10 @@ Relationship with Flask ----------------------- Quart is an asyncio reimplementation of the popular `Flask -`_ microframework API. This means that if you +`_ microframework API. This means that if you understand Flask you understand Quart. -Like Flask Quart has an ecosystem of extensions for more specific +Like Flask, Quart has an ecosystem of extensions for more specific needs. In addition a number of the Flask extensions work with Quart. Migrating from Flask diff --git a/artwork/logo.png b/artwork/logo.png index 91f0cc9..50814a1 100644 Binary files a/artwork/logo.png and b/artwork/logo.png differ diff --git a/debian/changelog b/debian/changelog index 13b1d50..53e8d3a 100644 --- a/debian/changelog +++ b/debian/changelog @@ -1,8 +1,57 @@ -quart (0.18.3-2deepin1) unstable; urgency=medium +quart (0.19.6-1) unstable; urgency=medium - * fix No module named 'py._path'; 'py' is not a package. + * Team upload + * New upstream version 0.19.6 + * d/control: Increase Standards-Version to 4.7.0 + No further modifications needed. + * d/copyright: Update upstream data + + -- Carsten Schoenert Sun, 26 May 2024 07:24:39 +0200 + +quart (0.19.5-1) unstable; urgency=medium + + * Team upload + * New upstream version 0.19.5 + * Rebuild patch queue from patch-queu branch + Dropped patch (fixed upstream): + Fix-issues-with-the-latest-black-mypy-and-pytest.patch + + -- Carsten Schoenert Tue, 02 Apr 2024 22:05:20 +0200 + +quart (0.19.4-2) unstable; urgency=medium - -- LiChengGang Wed, 22 Nov 2023 09:37:34 +0800 + * Team upload + * Rebuild patch queue from patch-queu branch + Added patche: + Fix-issues-with-the-latest-black-mypy-and-pytest.patch + (Closes: #1066777) + + -- Carsten Schoenert Sun, 17 Mar 2024 12:40:00 +0100 + +quart (0.19.4-1) unstable; urgency=medium + + * Team upload + + [ Andreas Tille ] + * New upstream version 0.19.4 + (Closes: #1042259) + * Standards-Version: 4.6.2 (routine-update) + * Add salsa-ci file (routine-update) + * Build-Depends: python3-flask + + [ Carsten Schoenert ] + * d/gbp.conf: Don't use numbers in patch names + * Add patch queue from patch-queue branch + Added patch: + docs-conf.py-Use-sphinx_rtd_theme-instead.patch + * d/control: Keep B-D entries alphabetical + * d/rules: Add override for dh_clean + * d/rules: Drop direct modification of docs/conf.py + Done now through patch queue. + * d/rules: Undo modification of pyproject.toml after build + * d/python-quart-doc.lintian-overrides: Update data content + + -- Carsten Schoenert Mon, 12 Feb 2024 20:48:14 +0100 quart (0.18.3-2) unstable; urgency=medium diff --git a/debian/control b/debian/control index 912d349..3247b5d 100644 --- a/debian/control +++ b/debian/control @@ -1,45 +1,44 @@ Source: quart Maintainer: Debian Python Team -Uploaders: - Andrej Shadura , +Uploaders: Andrej Shadura Section: python -Priority: optional -Build-Depends: - debhelper-compat (= 13), - dh-sequence-python3, - pybuild-plugin-pyproject, - python3-aiofiles, - python3-all (>= 3.7), - python3-blinker, - python3-click, - python3-dotenv, - python3-hypercorn (>= 0.11.2~), - python3-hypothesis, - python3-itsdangerous, - python3-jinja2, - python3-markupsafe, - python3-poetry, - python3-pytest , - python3-pytest-asyncio , - python3-pytest-cov , - python3-sphinx , - python3-sphinx-rtd-theme , - python3-toml, - python3-werkzeug (>= 2.2.0~), -Rules-Requires-Root: no -Standards-Version: 4.6.1 Testsuite: autopkgtest-pkg-python -Homepage: https://github.com/pallets/quart +Priority: optional +Build-Depends: debhelper-compat (= 13), + dh-sequence-python3, + pybuild-plugin-pyproject, + python3-aiofiles, + python3-all, + python3-blinker, + python3-click, + python3-dotenv, + python3-flask , + python3-hypercorn, + python3-hypothesis, + python3-itsdangerous, + python3-jinja2, + python3-markupsafe, + python3-poetry, + python3-poetry-core, + python3-pytest , + python3-pytest-asyncio , + python3-pytest-cov , + python3-sphinx , + python3-sphinx-rtd-theme , + python3-toml, + python3-werkzeug (>= 2.2.0~), +Standards-Version: 4.7.0 Vcs-Browser: https://salsa.debian.org/python-team/packages/quart Vcs-Git: https://salsa.debian.org/python-team/packages/quart.git +Homepage: https://github.com/pallets/quart +Rules-Requires-Root: no Package: python-quart-doc -Section: doc Architecture: all -Depends: - ${misc:Depends}, - ${sphinxdoc:Depends}, Multi-Arch: foreign +Section: doc +Depends: ${misc:Depends}, + ${sphinxdoc:Depends} Description: Python ASGI web microframework with the same API as Flask (Documentation) Quart is a Python ASGI web microframework. It is intended to provide the easiest way to use asyncio functionality in a web context, especially @@ -62,9 +61,8 @@ Description: Python ASGI web microframework with the same API as Flask (Document Package: python3-quart Architecture: all -Depends: - ${misc:Depends}, - ${python3:Depends}, +Depends: ${misc:Depends}, + ${python3:Depends} Description: Python ASGI web microframework with the same API as Flask Quart is a Python ASGI web microframework. It is intended to provide the easiest way to use asyncio functionality in a web context, especially diff --git a/debian/copyright b/debian/copyright index 02bee0c..a7a4b98 100644 --- a/debian/copyright +++ b/debian/copyright @@ -3,11 +3,12 @@ Source: https://github.com/pallets/quart Upstream-Name: Quart Files: * -Copyright: 2017-2022 Philip G Jones +Copyright: 2017-2024 Philip G Jones License: Expat Files: artwork/* Copyright: 2017 Vic Shóstak + 2024 Philip G Jones License: CC0 Files: debian/* diff --git a/debian/gbp.conf b/debian/gbp.conf index 47e7402..ddcd107 100644 --- a/debian/gbp.conf +++ b/debian/gbp.conf @@ -2,3 +2,6 @@ debian-branch = debian/master upstream-branch = upstream/latest pristine-tar = True + +[pq] +patch-numbers = False diff --git a/debian/patches/docs-conf.py-Use-sphinx_rtd_theme-instead.patch b/debian/patches/docs-conf.py-Use-sphinx_rtd_theme-instead.patch new file mode 100644 index 0000000..7a0d153 --- /dev/null +++ b/debian/patches/docs-conf.py-Use-sphinx_rtd_theme-instead.patch @@ -0,0 +1,25 @@ +From: Carsten Schoenert +Date: Mon, 12 Feb 2024 09:39:42 +0100 +Subject: docs/conf.py: Use sphinx_rtd_theme instead + +The used theme by upstream (pydata_sphinx_theme) isn't packaged in +Debian, falling back than to the classical RTD theme. + +Forwarded: not-needed +--- + docs/conf.py | 2 +- + 1 file changed, 1 insertion(+), 1 deletion(-) + +diff --git a/docs/conf.py b/docs/conf.py +index fa5647a..337135c 100644 +--- a/docs/conf.py ++++ b/docs/conf.py +@@ -84,7 +84,7 @@ todo_include_todos = False + # a list of builtin themes. + # + +-html_theme = "pydata_sphinx_theme" ++html_theme = "sphinx_rtd_theme" + html_logo = "_static/logo_short.png" + + # Theme options are theme-specific and customize the look and feel of a theme diff --git a/debian/patches/fix-LocalPath b/debian/patches/fix-LocalPath deleted file mode 100644 index a559b04..0000000 --- a/debian/patches/fix-LocalPath +++ /dev/null @@ -1,23 +0,0 @@ - ---- quart-0.18.3.orig/tests/test_helpers.py -+++ quart-0.18.3/tests/test_helpers.py -@@ -6,7 +6,7 @@ from pathlib import Path - from typing import AsyncGenerator - - import pytest --from py._path.local import LocalPath -+#from py._path.local import LocalPath - from werkzeug.exceptions import NotFound - - from quart import Blueprint, Quart, request ---- quart-0.18.3.orig/tests/wrappers/test_response.py -+++ quart-0.18.3/tests/wrappers/test_response.py -@@ -8,7 +8,7 @@ from typing import Any, AsyncGenerator - - import pytest - from hypothesis import given, strategies as strategies --from py._path.local import LocalPath -+#from py._path.local import LocalPath - from werkzeug.datastructures import Headers - from werkzeug.exceptions import RequestedRangeNotSatisfiable - diff --git a/debian/patches/series b/debian/patches/series index 2741bb4..607b951 100644 --- a/debian/patches/series +++ b/debian/patches/series @@ -1 +1 @@ -fix-LocalPath +docs-conf.py-Use-sphinx_rtd_theme-instead.patch diff --git a/debian/python-quart-doc.lintian-overrides b/debian/python-quart-doc.lintian-overrides index c76ab1b..698b441 100644 --- a/debian/python-quart-doc.lintian-overrides +++ b/debian/python-quart-doc.lintian-overrides @@ -5,7 +5,6 @@ python-quart-doc: repeated-path-segment chat [usr/share/doc/python3-quart/exampl python-quart-doc: repeated-path-segment video [usr/share/doc/python3-quart/examples/video/src/video/] # Intended by sphinx design, we ship the source of the HTML data. -python-quart-doc: file-references-package-build-path [usr/share/doc/python-quart-doc/html/reference/source/quart.flask_patch.html] python-quart-doc: file-references-package-build-path [usr/share/doc/python-quart-doc/html/reference/source/quart.wrappers.base.html] python-quart-doc: file-references-package-build-path [usr/share/doc/python-quart-doc/html/reference/source/quart.wrappers.html] python-quart-doc: file-references-package-build-path [usr/share/doc/python-quart-doc/html/reference/source/quart.wrappers.response.html] diff --git a/debian/rules b/debian/rules index e836415..52cba5d 100755 --- a/debian/rules +++ b/debian/rules @@ -14,7 +14,6 @@ SPHINXOPTS := -N -D html_last_updated_fmt="$(BUILD_DATE)" override_dh_sphinxdoc: ifeq (,$(findstring nodoc, $(DEB_BUILD_OPTIONS))) - sed -i 's,pydata_sphinx_theme,sphinx_rtd_theme,g' docs/conf.py PYTHONPATH=`dirname $$(find .pybuild/ -type d -name "quart*dist-info" | head -n1)` \ python3 -m sphinx -b html $(SPHINXOPTS) docs $(CURDIR)/debian/python-quart-doc/usr/share/doc/python-quart-doc/html dh_sphinxdoc @@ -27,9 +26,13 @@ override_dh_installdocs: override_dh_auto_build: # The name of the project need to be written in lowercase letters, # otherwise pybuild wont find the WHEEL file (as it's created from the - # project name starting with an capital letter. + # project name starting with an capital letter). sed -i 's,name = "Quart",name = "quart",g' pyproject.toml dh_auto_build + # Now undo the previous replacement again, we need to do this hack as + # dpkg will otherwise complain about modified source code of course in + # case the build is getting started twice in a row. + sed -i 's,name = "quart",name = "Quart",g' pyproject.toml override_dh_installexamples: dh_installexamples -ppython-$(PYBUILD_NAME)-doc --doc-main-package=python3-$(PYBUILD_NAME) @@ -41,4 +44,10 @@ override_dh_installchangelogs: override_dh_compress: dh_compress -X.py +override_dh_clean: + rm -rf .pybuild \ + .mypy_cache \ + docs/reference/source/* + dh_clean + .PHONY: override_dh_auto_build override_dh_auto_clean override_dh_installdocs override_dh_installexamples override_dh_installchangelogs override_dh_compress diff --git a/debian/salsa-ci.yml b/debian/salsa-ci.yml new file mode 100644 index 0000000..33c3a64 --- /dev/null +++ b/debian/salsa-ci.yml @@ -0,0 +1,4 @@ +--- +include: + - https://salsa.debian.org/salsa-ci-team/pipeline/raw/master/salsa-ci.yml + - https://salsa.debian.org/salsa-ci-team/pipeline/raw/master/pipeline-jobs.yml diff --git a/docs/Makefile b/docs/Makefile index d46376e..41ce675 100644 --- a/docs/Makefile +++ b/docs/Makefile @@ -17,4 +17,4 @@ help: # Catch-all target: route all unknown targets to Sphinx using the new # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). %: Makefile - @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) \ No newline at end of file + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) diff --git a/docs/_static/logo.png b/docs/_static/logo.png index 72ecc27..50814a1 100644 Binary files a/docs/_static/logo.png and b/docs/_static/logo.png differ diff --git a/docs/_static/logo_short.png b/docs/_static/logo_short.png index 459b150..29ca487 100644 Binary files a/docs/_static/logo_short.png and b/docs/_static/logo_short.png differ diff --git a/docs/discussion/async_compatibility.rst b/docs/discussion/async_compatibility.rst index 20652b2..f2e1b88 100644 --- a/docs/discussion/async_compatibility.rst +++ b/docs/discussion/async_compatibility.rst @@ -57,10 +57,8 @@ whilst the route function can be wrapped with the way to insert the ``await`` before the ``request.form`` and ``render_template`` calls. -It is for this reason that a proxy object, -:class:`~quart.flask_patch.globals.FlaskRequestProxy`, and render, -:func:`~quart.flask_patch.templating.render_template` functions are -created for the Flask extensions. The former adding synchronous +It is for this reason that Quart-Flask-Patch creates sync wrapped +versions for the Flask extensions. The former adding synchronous request methods and the other providing synchronous functions. Quart monkey patches a ``sync_wait`` method onto the base event loop diff --git a/docs/discussion/python_versions.rst b/docs/discussion/python_versions.rst index 13cb988..cdf5e4c 100644 --- a/docs/discussion/python_versions.rst +++ b/docs/discussion/python_versions.rst @@ -3,11 +3,8 @@ Python version support ====================== -The main branch and releases >= 0.7.0 onwards only support Python -3.7.0 or greater. - -Python 3.7.0 is required to utilise ContextVars and doing so -considerably simplifies context management with tasks. +The main branch and releases >= 0.19.0 onwards only support Python +3.8.0 or greater. The 0.6-maintenance branch supported Python3.6. The final 0.6.X release, 0.6.15, was released in October 2019 after the release of diff --git a/docs/how_to_guides/background_tasks.rst b/docs/how_to_guides/background_tasks.rst index 8fdc888..a2cac24 100644 --- a/docs/how_to_guides/background_tasks.rst +++ b/docs/how_to_guides/background_tasks.rst @@ -32,8 +32,11 @@ method: The background tasks will have access to the app context. The tasks will be awaited during shutdown to ensure they complete before the app -shuts down. If your task does not complete it will eventually be -cancelled as the app is forceably shut down by the server. +shuts down. If your task does not complete within the config +``BACKGROUND_TASK_SHUTDOWN_TIMEOUT`` it will be cancelled. + +Note ``BACKGROUND_TASK_SHUTDOWN_TIMEOUT`` should ideally be less than +any server shutdown timeout. Synchronous background tasks are supported and will run in a separate thread. diff --git a/docs/how_to_guides/event_loop.rst b/docs/how_to_guides/event_loop.rst index f3a3457..38a974c 100644 --- a/docs/how_to_guides/event_loop.rst +++ b/docs/how_to_guides/event_loop.rst @@ -52,7 +52,7 @@ or to use the ``app.run_task`` method, loop.run_until_complete(app.run_task()) the Hypercorn (production) solution is to utilise the `Hypercorn API -`_ to do the +`_ to do the following, .. code-block:: python diff --git a/docs/how_to_guides/flask_extensions.rst b/docs/how_to_guides/flask_extensions.rst index cfb4859..9258a64 100644 --- a/docs/how_to_guides/flask_extensions.rst +++ b/docs/how_to_guides/flask_extensions.rst @@ -3,81 +3,13 @@ Using Flask Extensions ====================== -Flask extensions can be used with Quart, with some caveats. To do so -the very first import in your code must be ``import quart.flask_patch`` -as this will add modules purporting to be Flask modules for later use -by the extension. For example, +Some Flask extensions can be used with Quart by patching Quart to act +as Flask, to patch Quart see the `Quart-Flask-Patch +`_ extension. This was +part of Quart until release 0.19.0. -.. code-block:: python +Reference +--------- - import quart.flask_patch - - from quart import Quart - import flask_login - - app = Quart(__name__) - login_manager = flask_login.LoginManager() - login_manager.init_app(app) - - ... - - -Caveats -------- - -Flask extensions must use the global request proxy variable to access -the request, any other access e.g. via -:meth:`~quart.local.LocalProxy._get_current_object` will require -asynchronous access. To enable this the request body must be fully -received before any part of the request is handled, which is a -limitation not present in vanilla flask. - -Trying to use Flask alongside Quart in the same runtime will likely not -work, and lead to surprising errors. - -The flask extension must be limited to creating routes, using the -request and rendering templates. Any other more advanced functionality -may not work. - -Synchronous functions will not run in a separate thread (unlike Quart -normally) and hence may block the event loop. - -Finally the flask_patching system also relies on patching asyncio, and -hence other implementations or event loop policies are unlikely to -work. - -Supported extensions --------------------- - -A list of officially supported flask extensions exist `here -`_ of those a few have been tested -against Quart (the extensions tested are still supported and don't -require external services). The following flask extensions are tested -and known to work with quart, - -- `Flask-BCrypt `_ -- `Flask-Caching `_ -- `Flask-KVSession `_ -- `Flask-Limiter `_ -- `Flask-Login `_ See - also `Quart-Login `_ -- `Flask-Mail `_ -- `Flask-Mako `_ -- `Flask-Seasurf `_ -- `Flask-SQLAlchemy `_ -- `Flask-WTF `_ - -Broken extensions ------------------ - -The following flask extensions have been tested are known not to work -with quart, - -- `Flask-CORS `_, as it - uses ``app.make_response`` which must be awaited. Try `Quart-CORS - `_ instead. -- `Flask-Restful `_ - as it subclasses the Quart (app) class with synchronous methods - overriding asynchronous methods. Try `Quart-OpenApi - `_ or `Quart-Schema - `_ instead. +More information about Flask extensions can be found +`here `_. diff --git a/docs/how_to_guides/flask_migration.rst b/docs/how_to_guides/flask_migration.rst index 777099b..b2bb7f2 100644 --- a/docs/how_to_guides/flask_migration.rst +++ b/docs/how_to_guides/flask_migration.rst @@ -68,6 +68,8 @@ function/method is a syntax error. .. code-block:: python await request.data + await request.get_data() + await request.json await request.get_json() await request.form await request.files diff --git a/docs/how_to_guides/logging.rst b/docs/how_to_guides/logging.rst index 0833a88..f56f4f5 100644 --- a/docs/how_to_guides/logging.rst +++ b/docs/how_to_guides/logging.rst @@ -3,28 +3,22 @@ Logging ======= -By default Quart has two loggers, named ``quart.app`` and -``quart.serving``, both are standard Python Loggers. The former is -usually kept for app logging whilst the latter serving. To use the -former, simply make use of :attr:`~quart.app.Quart.logger`, for -example: +Quart has a standard Python logger sharing the same name as the +``app.name``. To use it, simply make use of +:attr:`~quart.app.Quart.logger`, for example: .. code-block:: python app.logger.info('Interesting') app.logger.warning('Easy Now') -The serving logger is typically reserved for the serving code, but can -be used if required via :func:`logging.getLogger` i.e. -``getLogger('quart.serving')``. - Configuration ------------- -The Quart loggers are not created until their first usage, which may -occur as the app is created. These loggers on creation respect any -existing configuration. This allows the loggers to be configured like -any other python logger, for example +The Quart logger is not created until its first usage, which may occur +as the app is created. These loggers on creation respect any existing +configuration. This allows the loggers to be configured like any other +python logger, for example .. code-block:: python @@ -42,14 +36,12 @@ any other python logger, for example Disabling/removing handlers --------------------------- -The handlers attached to the quart loggers can be removed, the -handlers are :attr:`~quart.logging.default_handler` and -:attr:`~quart.logging.default_serving_handler` and can be removed like -so, +The handler :attr:`~quart.logging.default_handler` attached to the +quart logger can be removed like so, .. code-block:: python from logging import getLogger from quart.logging import default_handler - getLogger('quart.app').removeHandler(default_handler) + getLogger(app.name).removeHandler(default_handler) diff --git a/docs/how_to_guides/middleware.rst b/docs/how_to_guides/middleware.rst index 971ad3f..6040a97 100644 --- a/docs/how_to_guides/middleware.rst +++ b/docs/how_to_guides/middleware.rst @@ -19,7 +19,7 @@ the presence of a header, return await self.app(scope, receive, send) for header, value in scope['headers']: - if header == 'X-Secret' and value == 'very-secret': + if header.lower() == b'x-secret' and value == b'very-secret': return await self.app(scope, receive, send) return await self.error_response(receive, send) diff --git a/docs/how_to_guides/quart_extensions.rst b/docs/how_to_guides/quart_extensions.rst index eafab0c..2148a83 100644 --- a/docs/how_to_guides/quart_extensions.rst +++ b/docs/how_to_guides/quart_extensions.rst @@ -8,7 +8,7 @@ here, - `Quart-Auth `_ Secure cookie sessions, allows login, authentication and logout. -- `Quart-Babel `_ Implements i18n and l10n support for Quart. +- `Quart-Babel `_ Implements i18n and l10n support for Quart. - `Quart-Bcrypt `_ Provides bcrypt hashing utilities for your application. - `Quart-compress `_ compress your application's responses with gzip. @@ -25,24 +25,30 @@ here, port of Flask-Login to work natively with Quart. - `Quart-minify `_ minify quart response for HTML, JS, CSS and less. +- `Quart-Mongo `_ Bridges Quart, Motor, and Odmantic to create a powerful MongoDB + extension. - `Quart-Motor `_ Motor (MongoDB) support for Quart applications. - `Quart-OpenApi `_ RESTful API building. -- `Quart-Session-OpenID `_ - Support for OAuth2 OpenID Connect (OIDC). +- `Quart-Keycloak `_ + Support for Keycloak's OAuth2 OpenID Connect (OIDC). - `Quart-Rapidoc `_ API documentation from OpenAPI Specification. - `Quart-Rate-Limiter `_ Rate limiting support. +- `Quart-Redis + `_ Redis connection handling - `Webargs-Quart `_ Webargs parsing for Quart. +- `Quart-WTF `_ Simple integration of Quart + and WTForms. Including CSRF and file uploading. - `Quart-Schema `_ Schema validation and auto-generated API documentation. - `Quart-session `_ server side session support. -- `Quart-Uploads `_ File upload handling for Quart. +- `Quart-Uploads `_ File upload handling for Quart. Supporting sync code in a Quart Extension ----------------------------------------- diff --git a/docs/how_to_guides/server_sent_events.rst b/docs/how_to_guides/server_sent_events.rst index 39da005..9c5b708 100644 --- a/docs/how_to_guides/server_sent_events.rst +++ b/docs/how_to_guides/server_sent_events.rst @@ -30,7 +30,7 @@ helper class: message = f"{message}\nid: {self.id}" if self.retry is not None: message = f"{message}\nretry: {self.retry}" - message = f"{message}\r\n\r\n" + message = f"{message}\n\n" return message.encode('utf-8') To use a GET route that returns a streaming generator is diff --git a/docs/how_to_guides/startup_shutdown.rst b/docs/how_to_guides/startup_shutdown.rst index 8d44fff..c1ca1bf 100644 --- a/docs/how_to_guides/startup_shutdown.rst +++ b/docs/how_to_guides/startup_shutdown.rst @@ -8,9 +8,8 @@ coroutines before the first byte is received and after the final byte is sent, through the ``startup`` and ``shutdown`` lifespan events. This is particularly useful for creating and destroying connection pools. Quart supports this via the decorators -:func:`~quart.app.Quart.before_serving` and -:func:`~quart.app.Quart.after_serving`, which function like -:func:`~quart.app.Quart.before_first_request`, and +:func:`~quart.app.Quart.before_serving`, +:func:`~quart.app.Quart.after_serving`, and :func:`~quart.app.Quart.while_serving` which expects a function that returns a generator. diff --git a/docs/how_to_guides/templating.rst b/docs/how_to_guides/templating.rst index c9ee590..988a1b7 100644 --- a/docs/how_to_guides/templating.rst +++ b/docs/how_to_guides/templating.rst @@ -3,10 +3,10 @@ Templates ========= -Quart uses the `Jinja2 `_ templating engine, +Quart uses the `Jinja `_ templating engine, which is well `documented -`_. Quart adds a standard -context, and some standard filters to the Jinja2 defaults. Quart also +`_. Quart adds a standard +context, and some standard filters to the Jinja defaults. Quart also adds the ability to define custom filters, tests and contexts at an app and blueprint level. diff --git a/docs/how_to_guides/websockets.rst b/docs/how_to_guides/websockets.rst index 233063a..3092aa0 100644 --- a/docs/how_to_guides/websockets.rst +++ b/docs/how_to_guides/websockets.rst @@ -151,7 +151,7 @@ type of request (WebSocket upgrade or not). As so, async def http(): return "A HTTP request" - @app.route("/ws") + @app.websocket("/ws") async def ws(): ... # Use the WebSocket diff --git a/docs/index.rst b/docs/index.rst index 37cd570..9dc6528 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -31,7 +31,7 @@ If you are, * looking for a cheatsheet then look :ref:`here`. Quart is an asyncio reimplementation of the popular `Flask -`_ microframework API. This means that if you +`_ microframework API. This means that if you understand Flask you understand Quart. See :ref:`flask_evolution` to learn more about how Quart builds on Flask. @@ -52,10 +52,10 @@ ask for help try `on discord `_. If you can't find documentation for what you are looking for here, remember that Quart is an implementation of the Flask API and - hence the `Flask documentation `_ is + hence the `Flask documentation `_ is a great source of help. Quart is also built on the `Jinja - `_ template engine and the `Wekzeug - `_ toolkit. + `_ template engine and the `Werkzeug + `_ toolkit. The Flask documentation is so good that you may be better placed consulting it first then returning here to check how Quart diff --git a/docs/reference/cheatsheet.rst b/docs/reference/cheatsheet.rst index c6433ee..0d86aaf 100644 --- a/docs/reference/cheatsheet.rst +++ b/docs/reference/cheatsheet.rst @@ -63,11 +63,11 @@ Configuration .. code-block:: python import json - import toml + import tomllib app.config["VALUE"] = "something" - app.config.from_file("filename.toml", toml.load) + app.config.from_file("filename.toml", tomllib.load) app.config.from_file("filename.json", json.load) Request diff --git a/docs/tutorials/api_tutorial.rst b/docs/tutorials/api_tutorial.rst index a403e0e..3f9875c 100644 --- a/docs/tutorials/api_tutorial.rst +++ b/docs/tutorials/api_tutorial.rst @@ -9,7 +9,7 @@ response data. This tutorial is meant to serve as an introduction to building APIs in Quart. If you want to skip to the end the code is on `Github -`_. +`_. 1: Creating the project ----------------------- @@ -143,8 +143,11 @@ We can then add schemas for a Todo object by adding the following to from dataclasses import dataclass from datetime import datetime + from quart import Quart from quart_schema import QuartSchema, validate_request, validate_response + app = Quart(__name__) + QuartSchema(app) @dataclass @@ -159,7 +162,7 @@ We can then add schemas for a Todo object by adding the following to @app.post("/todos/") @validate_request(TodoIn) @validate_response(Todo) - async def create_todo(data: Todo) -> Todo: + async def create_todo(data: TodoIn) -> Todo: return Todo(id=1, task=data.task, due=data.due) The OpenAPI schema is then available at diff --git a/docs/tutorials/blog_tutorial.rst b/docs/tutorials/blog_tutorial.rst index cc03811..5f16ea6 100644 --- a/docs/tutorials/blog_tutorial.rst +++ b/docs/tutorials/blog_tutorial.rst @@ -9,7 +9,7 @@ HTML directly to the user. This tutorial is meant to serve as an introduction to building server rendered websites in Quart. If you want to skip to the end the code is -on `Github `_. +on `Github `_. 1: Creating the project ----------------------- @@ -165,7 +165,7 @@ and should be added to *src/blog/templates/posts.html*: {% endfor %} -Now we need a route to to query the database, retrieve the messages, +Now we need a route to query the database, retrieve the messages, and render the template. As done with the following code which should be added to *src/blog/__init__.py*: @@ -303,7 +303,7 @@ If you are running this in the Quart example folder you'll need to add a ``-c pyproject.toml`` option to prevent pytest from using the Quart pytest configuration. -7: Summary +8: Summary ---------- We've built a simple database backed blog server. This should be a diff --git a/docs/tutorials/chat_tutorial.rst b/docs/tutorials/chat_tutorial.rst index 0e80a8d..3d67754 100644 --- a/docs/tutorials/chat_tutorial.rst +++ b/docs/tutorials/chat_tutorial.rst @@ -9,7 +9,7 @@ server. This tutorial is meant to serve as an introduction to WebSockets in Quart. If you want to skip to the end the code is on `Github -`_. +`_. 1: Creating the project ----------------------- @@ -84,7 +84,7 @@ Which allows the following command to start the app: When users visit our chat website we will need to show them a UI which they can use to enter and receive messages. The following HTML -template should be added to *src/chat/templatest/index.html*: +template should be added to *src/chat/templates/index.html*: .. code-block:: html :caption: src/chat/templates/index.html diff --git a/docs/tutorials/deployment.rst b/docs/tutorials/deployment.rst index e6d7fa3..9344c79 100644 --- a/docs/tutorials/deployment.rst +++ b/docs/tutorials/deployment.rst @@ -6,8 +6,8 @@ Deploying Quart It is not recommended to run Quart directly (via :meth:`~quart.app.Quart.run`) in production. Instead it is recommended that Quart be run using `Hypercorn -`_ or an alternative ASGI -server. This is becuase the :meth:`~quart.app.Quart.run` enables +`_ or an alternative ASGI +server. This is because the :meth:`~quart.app.Quart.run` enables features that help development yet slow production performance. Hypercorn is installed with Quart and will be used to serve requests in development mode by default (e.g. with @@ -33,7 +33,7 @@ you can run with Hypercorn using, hypercorn example:app -See the `Hypercorn docs `_. +See the `Hypercorn docs `_. Alternative ASGI Servers ------------------------ @@ -41,7 +41,7 @@ Alternative ASGI Servers ==================================================== ====== ====== =========== ================== Server name HTTP/2 HTTP/3 Server Push Websocket Response ==================================================== ====== ====== =========== ================== -`Hypercorn `_ ✓ ✓ ✓ ✓ +`Hypercorn `_ ✓ ✓ ✓ ✓ `Daphne `_ ✓ ✗ ✗ ✗ `Uvicorn `_ ✗ ✗ ✗ ✗ ==================================================== ====== ====== =========== ================== diff --git a/docs/tutorials/installation.rst b/docs/tutorials/installation.rst index a6c4063..6677461 100644 --- a/docs/tutorials/installation.rst +++ b/docs/tutorials/installation.rst @@ -3,16 +3,13 @@ Installation ============ -Quart is only compatible with Python 3.7 or higher and can be installed -using pip or your favorite python package manager:: +Quart is only compatible with Python 3.8 or higher and can be installed +using pip or your favorite python package manager: .. code-block:: console pip install quart -If you do not have Python 3.7 or better an error message ``Python 3.7 -is the minimum required version`` will be displayed. - Dependencies ------------ @@ -23,11 +20,11 @@ be installed with Quart: - blinker, to manager signals, - click, to manage command line arguments - hypercorn, an ASGI server for development, -- importlib_metadata only for Python 3.7, +- importlib_metadata only for Python 3.8, - itsdangerous, for signing secure cookies, - jinja2, for template rendering, - markupsafe, for markup rendering, -- typing_extensions only for Python 3.7, +- typing_extensions only for Python 3.8, - werkzeug, as the basis of many Quart classes. You can choose to install with the dotenv extra: @@ -36,7 +33,7 @@ You can choose to install with the dotenv extra: pip install quart[dotenv] -Whcih will install the ``python-dotenv`` package which enables support +Which will install the ``python-dotenv`` package which enables support for automatically loading environment variables when running ``quart`` commands. diff --git a/docs/tutorials/video_tutorial.rst b/docs/tutorials/video_tutorial.rst index cfa4467..f6653c3 100644 --- a/docs/tutorials/video_tutorial.rst +++ b/docs/tutorials/video_tutorial.rst @@ -9,7 +9,7 @@ serve a video directly. This tutorial is meant to serve as an introduction to serving large files with conditional responses in Quart. If you want to skip to the end the code is on `Github -`_. +`_. 1: Creating the project ----------------------- @@ -126,7 +126,7 @@ method. The former is shown below, which should be added to @app.route("/video.mp4") async def auto_video(): - return await send_file("video.mp4", conditional=True) + return await send_file(app.static_folder / "video.mp4", conditional=True) 6: Testing ---------- diff --git a/examples/api/src/api/__init__.py b/examples/api/src/api/__init__.py index 198dbbd..905cc25 100644 --- a/examples/api/src/api/__init__.py +++ b/examples/api/src/api/__init__.py @@ -1,32 +1,38 @@ from dataclasses import dataclass from datetime import datetime -from quart import Quart, request from quart_schema import QuartSchema, validate_request, validate_response +from quart import Quart, request + app = Quart(__name__) QuartSchema(app) + @app.post("/echo") async def echo(): print(request.is_json, request.mimetype) data = await request.get_json() return {"input": data, "extra": True} + @dataclass class TodoIn: task: str due: datetime | None + @dataclass class Todo(TodoIn): id: int + @app.post("/todos/") @validate_request(TodoIn) @validate_response(Todo) async def create_todo(data: Todo) -> Todo: return Todo(id=1, task=data.task, due=data.due) + def run() -> None: app.run() diff --git a/examples/api/tests/test_api.py b/examples/api/tests/test_api.py index f055ab3..3ab1c2c 100644 --- a/examples/api/tests/test_api.py +++ b/examples/api/tests/test_api.py @@ -1,10 +1,12 @@ from api import app, TodoIn + async def test_echo() -> None: test_client = app.test_client() response = await test_client.post("/echo", json={"a": "b"}) data = await response.get_json() - assert data == {"extra":True,"input":{"a":"b"}} + assert data == {"extra": True, "input": {"a": "b"}} + async def test_create_todo() -> None: test_client = app.test_client() diff --git a/examples/blog/src/blog/__init__.py b/examples/blog/src/blog/__init__.py index 81e7139..9e5ae90 100644 --- a/examples/blog/src/blog/__init__.py +++ b/examples/blog/src/blog/__init__.py @@ -1,23 +1,28 @@ from sqlite3 import dbapi2 as sqlite3 -from quart import g, Quart, redirect, request, render_template, url_for +from quart import g, Quart, redirect, render_template, request, url_for app = Quart(__name__) -app.config.update({ - "DATABASE": app.root_path / "blog.db", -}) +app.config.update( + { + "DATABASE": app.root_path / "blog.db", + } +) + def _connect_db(): engine = sqlite3.connect(app.config["DATABASE"]) engine.row_factory = sqlite3.Row return engine + def _get_db(): if not hasattr(g, "sqlite_db"): g.sqlite_db = _connect_db() return g.sqlite_db + @app.get("/") async def posts(): db = _get_db() @@ -29,6 +34,7 @@ async def posts(): posts = cur.fetchall() return await render_template("posts.html", posts=posts) + @app.route("/create/", methods=["GET", "POST"]) async def create(): if request.method == "POST": @@ -43,11 +49,13 @@ async def create(): else: return await render_template("create.html") + def init_db(): db = _connect_db() - with open(app.root_path / "schema.sql", mode="r") as file_: + with open(app.root_path / "schema.sql") as file_: db.cursor().executescript(file_.read()) db.commit() + def run() -> None: app.run() diff --git a/examples/blog/tests/conftest.py b/examples/blog/tests/conftest.py index 6656a66..dfe3db5 100644 --- a/examples/blog/tests/conftest.py +++ b/examples/blog/tests/conftest.py @@ -1,8 +1,8 @@ import pytest - from blog import app, init_db + @pytest.fixture(autouse=True) def configure_db(tmpdir): - app.config['DATABASE'] = str(tmpdir.join('blog.db')) + app.config["DATABASE"] = str(tmpdir.join("blog.db")) init_db() diff --git a/examples/blog/tests/test_blog.py b/examples/blog/tests/test_blog.py index e297c49..5917057 100644 --- a/examples/blog/tests/test_blog.py +++ b/examples/blog/tests/test_blog.py @@ -1,5 +1,6 @@ from blog import app + async def test_create_post(): test_client = app.test_client() response = await test_client.post("/create/", form={"title": "Post", "text": "Text"}) diff --git a/examples/chat/src/chat/__init__.py b/examples/chat/src/chat/__init__.py index e72348b..22df47b 100644 --- a/examples/chat/src/chat/__init__.py +++ b/examples/chat/src/chat/__init__.py @@ -1,21 +1,24 @@ import asyncio -from quart import Quart, render_template, websocket - from chat.broker import Broker +from quart import Quart, render_template, websocket + app = Quart(__name__) broker = Broker() + @app.get("/") async def index(): return await render_template("index.html") + async def _receive() -> None: while True: message = await websocket.receive() await broker.publish(message) + @app.websocket("/ws") async def ws() -> None: try: @@ -26,5 +29,6 @@ async def ws() -> None: task.cancel() await task + def run(): app.run(debug=True) diff --git a/examples/chat/src/chat/broker.py b/examples/chat/src/chat/broker.py index 7599df6..c217bbc 100644 --- a/examples/chat/src/chat/broker.py +++ b/examples/chat/src/chat/broker.py @@ -1,7 +1,6 @@ import asyncio from typing import AsyncGenerator -from quart import Quart class Broker: def __init__(self) -> None: diff --git a/examples/chat/tests/test_chat.py b/examples/chat/tests/test_chat.py index f460361..7124a3a 100644 --- a/examples/chat/tests/test_chat.py +++ b/examples/chat/tests/test_chat.py @@ -1,12 +1,14 @@ import asyncio +from chat import app + from quart.testing.connections import TestWebsocketConnection as _TestWebsocketConnection -from chat import app async def _receive(test_websocket: _TestWebsocketConnection) -> str: return await test_websocket.receive() + async def test_websocket() -> None: test_client = app.test_client() async with test_client.websocket("/ws") as test_websocket: diff --git a/examples/video/src/video/__init__.py b/examples/video/src/video/__init__.py index 36d1ebf..0528409 100644 --- a/examples/video/src/video/__init__.py +++ b/examples/video/src/video/__init__.py @@ -1,14 +1,17 @@ -from quart import Quart, render_template, request, send_file +from quart import Quart, render_template, send_file app = Quart(__name__) + @app.get("/") async def index(): return await render_template("index.html") + @app.route("/video.mp4") async def auto_video(): return await send_file(app.static_folder / "video.mp4", conditional=True) + def run() -> None: app.run() diff --git a/pyproject.toml b/pyproject.toml index 6ae1c36..1e22852 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "Quart" -version = "0.18.3" +version = "0.19.6" description = "A Python ASGI web microframework with the same API as Flask" authors = ["pgjones "] classifiers = [ @@ -11,10 +11,11 @@ classifiers = [ "Operating System :: OS Independent", "Programming Language :: Python", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", "Topic :: Internet :: WWW/HTTP :: Dynamic Content", "Topic :: Software Development :: Libraries :: Python Modules", ] @@ -25,10 +26,11 @@ repository = "https://github.com/pallets/quart/" documentation = "https://quart.palletsprojects.com" [tool.poetry.dependencies] -python = ">=3.7" +python = ">=3.8" aiofiles = "*" -blinker = "*" +blinker = ">=1.6" click = ">=8.0.0" +flask = ">=3.0.0" hypercorn = ">=0.11.2" importlib_metadata = { version = "*", python = "<3.10" } itsdangerous = "*" @@ -36,12 +38,11 @@ jinja2 = "*" markupsafe = "*" pydata_sphinx_theme = { version = "*", optional = true } python-dotenv = { version = "*", optional = true } -typing_extensions = { version = "*", python = "<3.8" } -werkzeug = ">=2.2.0" +typing_extensions = { version = "*", python = "<3.10" } +werkzeug = ">=3.0.0" [tool.poetry.dev-dependencies] hypothesis = "*" -mock = "*" pytest = "*" pytest-asyncio = "*" @@ -54,7 +55,7 @@ dotenv = ["python-dotenv"] [tool.black] line-length = 100 -target-version = ["py37"] +target-version = ["py38"] [tool.isort] combine_as_imports = true diff --git a/setup.cfg b/setup.cfg index 69d5a0b..a2e13d5 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,7 +1,7 @@ [flake8] -ignore = E203, E252, FI58, W503, W504 +ignore = E203, E252, E704, FI58, W503, W504 max_line_length = 100 -min_version = 3.7 +min_version = 3.8 per-file-ignores = src/quart/__init__.py:F401 require_code = True diff --git a/src/quart/app.py b/src/quart/app.py index 179dae0..da23382 100644 --- a/src/quart/app.py +++ b/src/quart/app.py @@ -2,15 +2,12 @@ import asyncio import os -import platform import signal import sys import warnings -from collections import OrderedDict +from collections import defaultdict from datetime import timedelta from inspect import isasyncgen, isgenerator -from logging import Logger -from pathlib import Path from types import TracebackType from typing import ( Any, @@ -20,40 +17,34 @@ Callable, cast, Coroutine, - Dict, - Iterable, - List, NoReturn, Optional, + overload, Set, - Tuple, - Type, TypeVar, Union, - ValuesView, ) -from weakref import WeakSet +from urllib.parse import quote from aiofiles import open as async_open from aiofiles.base import AiofilesContextManager from aiofiles.threadpool.binary import AsyncBufferedReader +from flask.sansio.app import App +from flask.sansio.scaffold import setupmethod from hypercorn.asyncio import serve from hypercorn.config import Config as HyperConfig from hypercorn.typing import ASGIReceiveCallable, ASGISendCallable, Scope -from werkzeug.datastructures import Authorization, Headers -from werkzeug.exceptions import Aborter, HTTPException, InternalServerError +from werkzeug.datastructures import Authorization, Headers, ImmutableDict +from werkzeug.exceptions import Aborter, BadRequestKeyError, HTTPException, InternalServerError from werkzeug.routing import BuildError, MapAdapter, RoutingException -from werkzeug.urls import url_quote -from werkzeug.utils import redirect as werkzeug_redirect from werkzeug.wrappers import Response as WerkzeugResponse from .asgi import ASGIHTTPConnection, ASGILifespan, ASGIWebsocketConnection -from .blueprints import Blueprint -from .config import Config, ConfigAttribute, DEFAULT_CONFIG +from .cli import AppGroup +from .config import Config from .ctx import ( _AppCtxGlobals, AppContext, - copy_current_app_context, has_request_context, has_websocket_context, RequestContext, @@ -70,17 +61,8 @@ websocket, websocket_ctx, ) -from .helpers import ( - _split_blueprint_path, - find_package, - get_debug_flag, - get_env, - get_flashed_messages, -) -from .json.provider import DefaultJSONProvider, JSONProvider -from .logging import create_logger +from .helpers import get_debug_flag, get_flashed_messages, send_from_directory from .routing import QuartMap, QuartRule -from .scaffold import _endpoint_from_view_func, Scaffold, setupmethod from .sessions import SecureCookieSessionInterface from .signals import ( appcontext_tearing_down, @@ -95,7 +77,7 @@ websocket_started, websocket_tearing_down, ) -from .templating import _default_template_ctx_processor, DispatchingJinjaLoader, Environment +from .templating import _default_template_ctx_processor, Environment from .testing import ( make_test_body_with_headers, make_test_headers_path_and_query_string, @@ -108,15 +90,17 @@ ) from .typing import ( AfterServingCallable, + AfterWebsocketCallable, ASGIHTTPProtocol, ASGILifespanProtocol, ASGIWebsocketProtocol, - BeforeFirstRequestCallable, BeforeServingCallable, - ErrorHandlerCallable, + BeforeWebsocketCallable, + Event, FilePath, HeadersValue, ResponseReturnValue, + ResponseTypes, ShellContextProcessorCallable, StatusCode, TeardownCallable, @@ -125,11 +109,12 @@ TemplateTestCallable, TestAppProtocol, TestClientProtocol, + WebsocketCallable, WhileServingCallable, ) from .utils import ( + cancel_tasks, file_path_to_path, - is_coroutine_function, MustReloadError, observe_changes, restart, @@ -137,10 +122,16 @@ ) from .wrappers import BaseRequestWebsocket, Request, Response, Websocket +try: + from typing import ParamSpec +except ImportError: + from typing_extensions import ParamSpec # type: ignore + AppOrBlueprintKey = Optional[str] # The App key is None, whereas blueprints are named T_after_serving = TypeVar("T_after_serving", bound=AfterServingCallable) -T_before_first_request = TypeVar("T_before_first_request", bound=BeforeFirstRequestCallable) +T_after_websocket = TypeVar("T_after_websocket", bound=AfterWebsocketCallable) T_before_serving = TypeVar("T_before_serving", bound=BeforeServingCallable) +T_before_websocket = TypeVar("T_before_websocket", bound=BeforeWebsocketCallable) T_shell_context_processor = TypeVar( "T_shell_context_processor", bound=ShellContextProcessorCallable ) @@ -148,16 +139,21 @@ T_template_filter = TypeVar("T_template_filter", bound=TemplateFilterCallable) T_template_global = TypeVar("T_template_global", bound=TemplateGlobalCallable) T_template_test = TypeVar("T_template_test", bound=TemplateTestCallable) +T_websocket = TypeVar("T_websocket", bound=WebsocketCallable) T_while_serving = TypeVar("T_while_serving", bound=WhileServingCallable) +T = TypeVar("T") +P = ParamSpec("P") + + +def _make_timedelta(value: timedelta | int | None) -> timedelta | None: + if value is None or isinstance(value, timedelta): + return value -def _convert_timedelta(value: Union[float, timedelta]) -> timedelta: - if not isinstance(value, timedelta): - return timedelta(seconds=value) - return value + return timedelta(seconds=value) -class Quart(Scaffold): +class Quart(App): """The web framework class, handles requests and returns responses. The primary method from a serving viewpoint is @@ -180,6 +176,8 @@ class Quart(Scaffold): websocket protocol. config_class: The class to use for the configuration. env: The name of the environment the app is running on. + event_class: The class to use to signal an event in an async + manner. debug: Wrapper around configuration DEBUG value, in many places this will result in more output if True. If unset, debug mode will be activated if environ is set to 'development'. @@ -193,21 +191,21 @@ class Quart(Scaffold): response_class: The class to user for responses. secret_key: Warpper around configuration SECRET_KEY value. The app secret for signing sessions. - session_cookie_name: Wrapper around configuration - SESSION_COOKIE_NAME, use to specify the cookie name for session - data. session_interface: The class to use as the session interface. + shutdown_event: This event is set when the app starts to + shutdown allowing waiting tasks to know when to stop. url_map_class: The class to map rules to endpoints. url_rule_class: The class to use for URL rules. websocket_class: The class to use for websockets. """ - asgi_http_class: Type[ASGIHTTPProtocol] - asgi_lifespan_class: Type[ASGILifespanProtocol] - asgi_websocket_class: Type[ASGIWebsocketProtocol] - test_app_class: Type[TestAppProtocol] - test_client_class: Type[TestClientProtocol] + asgi_http_class: type[ASGIHTTPProtocol] + asgi_lifespan_class: type[ASGILifespanProtocol] + asgi_websocket_class: type[ASGIWebsocketProtocol] + shutdown_event: Event + test_app_class: type[TestAppProtocol] + test_client_class: type[TestClientProtocol] # type: ignore[assignment] aborter_class = Aborter app_ctx_globals_class = _AppCtxGlobals @@ -215,42 +213,64 @@ class Quart(Scaffold): asgi_lifespan_class = ASGILifespan asgi_websocket_class = ASGIWebsocketConnection config_class = Config - env = ConfigAttribute("ENV") - jinja_environment = Environment - jinja_options: dict = {} - json_provider_class: Type[JSONProvider] = DefaultJSONProvider + event_class = asyncio.Event + jinja_environment = Environment # type: ignore[assignment] lock_class = asyncio.Lock - permanent_session_lifetime = ConfigAttribute( - "PERMANENT_SESSION_LIFETIME", converter=_convert_timedelta - ) request_class = Request response_class = Response - secret_key = ConfigAttribute("SECRET_KEY") - send_file_max_age_default = ConfigAttribute( - "SEND_FILE_MAX_AGE_DEFAULT", converter=_convert_timedelta - ) - session_cookie_name = ConfigAttribute("SESSION_COOKIE_NAME") session_interface = SecureCookieSessionInterface() test_app_class = TestApp - test_client_class = QuartClient - test_cli_runner_class = QuartCliRunner - testing = ConfigAttribute("TESTING") + test_client_class = QuartClient # type: ignore[assignment] + test_cli_runner_class = QuartCliRunner # type: ignore url_map_class = QuartMap - url_rule_class = QuartRule + url_rule_class = QuartRule # type: ignore[assignment] websocket_class = Websocket + default_config = ImmutableDict( + { + "APPLICATION_ROOT": "/", + "BACKGROUND_TASK_SHUTDOWN_TIMEOUT": 5, # Second + "BODY_TIMEOUT": 60, # Second + "DEBUG": None, + "ENV": None, + "EXPLAIN_TEMPLATE_LOADING": False, + "MAX_CONTENT_LENGTH": 16 * 1024 * 1024, # 16 MB Limit + "MAX_COOKIE_SIZE": 4093, + "PERMANENT_SESSION_LIFETIME": timedelta(days=31), + # Replaces PREFERRED_URL_SCHEME to allow for WebSocket scheme + "PREFER_SECURE_URLS": False, + "PRESERVE_CONTEXT_ON_EXCEPTION": None, + "PROPAGATE_EXCEPTIONS": None, + "RESPONSE_TIMEOUT": 60, # Second + "SECRET_KEY": None, + "SEND_FILE_MAX_AGE_DEFAULT": timedelta(hours=12), + "SERVER_NAME": None, + "SESSION_COOKIE_DOMAIN": None, + "SESSION_COOKIE_HTTPONLY": True, + "SESSION_COOKIE_NAME": "session", + "SESSION_COOKIE_PATH": None, + "SESSION_COOKIE_SAMESITE": None, + "SESSION_COOKIE_SECURE": False, + "SESSION_REFRESH_EACH_REQUEST": True, + "TEMPLATES_AUTO_RELOAD": None, + "TESTING": False, + "TRAP_BAD_REQUEST_ERRORS": None, + "TRAP_HTTP_EXCEPTIONS": False, + } + ) + def __init__( self, import_name: str, - static_url_path: Optional[str] = None, - static_folder: Optional[str] = "static", - static_host: Optional[str] = None, + static_url_path: str | None = None, + static_folder: str | None = "static", + static_host: str | None = None, host_matching: bool = False, subdomain_matching: bool = False, - template_folder: Optional[str] = "templates", - instance_path: Optional[str] = None, + template_folder: str | None = "templates", + instance_path: str | None = None, instance_relative_config: bool = False, - root_path: Optional[str] = None, + root_path: str | None = None, ) -> None: """Construct a Quart web application. @@ -273,48 +293,48 @@ def __init__( request has been handled. after_websocket_funcs: The functions to execute after a websocket has been handled. - before_first_request_func: Functions to execute before the - first request only. before_request_funcs: The functions to execute before handling a request. before_websocket_funcs: The functions to execute before handling a websocket. """ - super().__init__(import_name, static_folder, static_url_path, template_folder, root_path) - - instance_path = Path(instance_path) if instance_path else self.auto_find_instance_path() - if not instance_path.is_absolute(): - raise ValueError("The instance_path must be an absolute path.") - self.instance_path = instance_path - - self.aborter = self.make_aborter() - self.config = self.make_config(instance_relative_config) - - self.after_serving_funcs: List[Callable[[], Awaitable[None]]] = [] - self.background_tasks: WeakSet[asyncio.Task] = WeakSet() - self.before_first_request_funcs: List[BeforeFirstRequestCallable] = [] - self.before_serving_funcs: List[Callable[[], Awaitable[None]]] = [] - self.blueprints: Dict[str, Blueprint] = OrderedDict() - self.extensions: Dict[str, Any] = {} - self.json: JSONProvider = self.json_provider_class(self) - self.shell_context_processors: List[Callable[[], Dict[str, Any]]] = [] - self.teardown_appcontext_funcs: List[TeardownCallable] = [] - self.url_build_error_handlers: List[Callable[[Exception, str, dict], str]] = [] - self.url_map = self.url_map_class(host_matching=host_matching) - self.subdomain_matching = subdomain_matching - self.while_serving_gens: List[AsyncGenerator[None, None]] = [] - - self._got_first_request = False - self._first_request_lock = self.lock_class() - self._jinja_env: Optional[Environment] = None - self._logger: Optional[Logger] = None + super().__init__( + import_name, + static_url_path, + static_folder, + static_host, + host_matching, + subdomain_matching, + template_folder, + instance_path, + instance_relative_config, + root_path, + ) + + self.after_serving_funcs: list[Callable[[], Awaitable[None]]] = [] + self.after_websocket_funcs: dict[AppOrBlueprintKey, list[AfterWebsocketCallable]] = ( + defaultdict(list) + ) + self.background_tasks: Set[asyncio.Task] = set() + self.before_serving_funcs: list[Callable[[], Awaitable[None]]] = [] + self.before_websocket_funcs: dict[AppOrBlueprintKey, list[BeforeWebsocketCallable]] = ( + defaultdict(list) + ) + self.teardown_websocket_funcs: dict[AppOrBlueprintKey, list[TeardownCallable]] = ( + defaultdict(list) + ) + self.while_serving_gens: list[AsyncGenerator[None, None]] = [] + + self.template_context_processors[None] = [_default_template_ctx_processor] + + self.cli = AppGroup() + self.cli.name = self.name if self.has_static_folder: - if bool(static_host) != host_matching: - raise ValueError( - "static_host must be set if there is a static folder and host_matching is " - "enabled" - ) + assert ( + bool(static_host) == host_matching + ), "Invalid static_host/host_matching combination" + self.add_url_rule( f"{self.static_url_path}/", "static", @@ -322,101 +342,53 @@ def __init__( host=static_host, ) - self.template_context_processors[None] = [_default_template_ctx_processor] + def get_send_file_max_age(self, filename: str | None) -> int | None: + """Used by :func:`send_file` to determine the ``max_age`` cache + value for a given file path if it wasn't passed. - def _check_setup_finished(self, f_name: str) -> None: - if self._got_first_request: - raise AssertionError( - f"The setup method '{f_name}' can no longer be called" - " on the application. It has already handled its first" - " request, any changes will not be applied" - " consistently.\n" - "Make sure all imports, decorators, functions, etc." - " needed to set up the application are done before" - " running it." - ) + By default, this returns :data:`SEND_FILE_MAX_AGE_DEFAULT` from + the configuration of :data:`~flask.current_app`. This defaults + to ``None``, which tells the browser to use conditional requests + instead of a timed cache, which is usually preferable. - @property - def name(self) -> str: # type: ignore - """The name of this application. + Note this is a duplicate of the same method in the Quart + class. - This is taken from the :attr:`import_name` and is used for - debugging purposes. """ - if self.import_name == "__main__": - path = Path(getattr(sys.modules["__main__"], "__file__", "__main__.py")) - return path.stem - return self.import_name + value = self.config["SEND_FILE_MAX_AGE_DEFAULT"] - @property - def propagate_exceptions(self) -> bool: - """Return true if exceptions should be propagated into debug pages. + if value is None: + return None - If false the exception will be handled. See the - ``PROPAGATE_EXCEPTIONS`` config setting. - """ - propagate = self.config["PROPAGATE_EXCEPTIONS"] - if propagate is not None: - return propagate - else: - return self.debug or self.testing + if isinstance(value, timedelta): + return int(value.total_seconds()) - @property - def preserve_context_on_exception(self) -> bool: - preserve = self.config["PRESERVE_CONTEXT_ON_EXCEPTION"] - if preserve is not None: - return preserve - else: - return self.debug + return value + return None - @property - def logger(self) -> Logger: - """A :class:`logging.Logger` logger for the app. + async def send_static_file(self, filename: str) -> Response: + if not self.has_static_folder: + raise RuntimeError("No static folder for this object") + return await send_from_directory(self.static_folder, filename) - This can be used to log messages in a format as defined in the - app configuration, for example, + async def open_resource( + self, + path: FilePath, + mode: str = "rb", + ) -> AiofilesContextManager[None, None, AsyncBufferedReader]: + """Open a file for reading. - .. code-block:: python + Use as - app.logger.debug("Request method %s", request.method) - app.logger.error("Error, of some kind") + .. code-block:: python + async with await app.open_resource(path) as file_: + await file_.read() """ - if self._logger is None: - self._logger = create_logger(self) - return self._logger - - @property - def jinja_env(self) -> Environment: - """The jinja environment used to load templates.""" - if self._jinja_env is None: - self._jinja_env = self.create_jinja_environment() - return self._jinja_env - - @property - def got_first_request(self) -> bool: - """Return if the app has received a request.""" - return self._got_first_request - - def make_aborter(self) -> Aborter: - """Create and return the aborter instance.""" - return self.aborter_class() - - def make_config(self, instance_relative: bool = False) -> Config: - """Create and return the configuration with appropriate defaults.""" - config = self.config_class( - self.instance_path if instance_relative else self.root_path, DEFAULT_CONFIG - ) - config["ENV"] = get_env() - config["DEBUG"] = get_debug_flag() - return config + if mode not in {"r", "rb", "rt"}: + raise ValueError("Files can only be opened for reading") - def auto_find_instance_path(self) -> Path: - """Locates the instance_path if it was not provided""" - prefix, package_path = find_package(self.import_name) - if prefix is None: - return package_path / "instance" - return prefix / "var" / f"{self.name}-instance" + return async_open(os.path.join(self.root_path, path), mode) # type: ignore async def open_instance_resource( self, path: FilePath, mode: str = "rb" @@ -432,20 +404,7 @@ async def open_instance_resource( """ return async_open(self.instance_path / file_path_to_path(path), mode) # type: ignore - @property - def templates_auto_reload(self) -> bool: - """Returns True if templates should auto reload.""" - result = self.config["TEMPLATES_AUTO_RELOAD"] - if result is None: - return self.debug - else: - return result - - @templates_auto_reload.setter - def templates_auto_reload(self, value: Optional[bool]) -> None: - self.config["TEMPLATES_AUTO_RELOAD"] = value - - def create_jinja_environment(self) -> Environment: + def create_jinja_environment(self) -> Environment: # type: ignore """Create and return the jinja environment. This will create the environment based on the @@ -456,8 +415,8 @@ def create_jinja_environment(self) -> Environment: if "autoescape" not in options: options["autoescape"] = self.select_jinja_autoescape if "auto_reload" not in options: - options["auto_reload"] = self.templates_auto_reload - jinja_env = self.jinja_environment(self, **options) + options["auto_reload"] = self.config["TEMPLATES_AUTO_RELOAD"] + jinja_env = self.jinja_environment(self, **options) # type: ignore jinja_env.globals.update( { "config": self.config, @@ -471,16 +430,6 @@ def create_jinja_environment(self) -> Environment: jinja_env.policies["json.dumps_function"] = self.json.dumps return jinja_env - def create_global_jinja_loader(self) -> DispatchingJinjaLoader: - """Create and return a global (not blueprint specific) Jinja loader.""" - return DispatchingJinjaLoader(self) - - def select_jinja_autoescape(self, filename: str) -> bool: - """Returns True if the filename indicates that it should be escaped.""" - if filename is None: - return True - return Path(filename).suffix in {".htm", ".html", ".xhtml", ".xml"} - async def update_template_context(self, context: dict) -> None: """Update the provided template context. @@ -499,319 +448,12 @@ async def update_template_context(self, context: dict) -> None: extra_context: dict = {} for name in names: for processor in self.template_context_processors[name]: - extra_context.update(await self.ensure_async(processor)()) + extra_context.update(await self.ensure_async(processor)()) # type: ignore original = context.copy() context.update(extra_context) context.update(original) - def make_shell_context(self) -> dict: - """Create a context for interactive shell usage. - - The :attr:`shell_context_processors` can be used to add - additional context. - """ - context = {"app": self, "g": g} - for processor in self.shell_context_processors: - context.update(processor()) - return context - - @property - def debug(self) -> bool: - """Activate debug mode (extra checks, logging and reloading). - - Should/must be False in production. - """ - return self.config["DEBUG"] - - @debug.setter - def debug(self, value: bool) -> None: - self.config["DEBUG"] = value - self.jinja_env.auto_reload = self.templates_auto_reload - - def test_client(self, use_cookies: bool = True) -> TestClientProtocol: - """Creates and returns a test client.""" - return self.test_client_class(self, use_cookies=use_cookies) - - def test_cli_runner(self, **kwargs: Any) -> QuartCliRunner: - """Creates and returns a CLI test runner.""" - return self.test_cli_runner_class(self, **kwargs) - - @setupmethod - def register_blueprint( - self, - blueprint: Blueprint, - **options: Any, - ) -> None: - """Register a blueprint on the app. - - This results in the blueprint's routes, error handlers - etc... being added to the app. - - Arguments: - blueprint: The blueprint to register. - url_prefix: Optional prefix to apply to all paths. - url_defaults: Blueprint routes will use these default values for view arguments. - subdomain: Blueprint routes will match on this subdomain. - """ - blueprint.register(self, options) - - def iter_blueprints(self) -> ValuesView[Blueprint]: - """Return a iterator over the blueprints.""" - return self.blueprints.values() - - @setupmethod - def add_url_rule( - self, - rule: str, - endpoint: Optional[str] = None, - view_func: Optional[Callable] = None, - provide_automatic_options: Optional[bool] = None, - *, - methods: Optional[Iterable[str]] = None, - defaults: Optional[dict] = None, - host: Optional[str] = None, - subdomain: Optional[str] = None, - is_websocket: bool = False, - strict_slashes: Optional[bool] = None, - merge_slashes: Optional[bool] = None, - ) -> None: - """Add a route/url rule to the application. - - This is designed to be used on the application directly. An - example usage, - - .. code-block:: python - - def route(): - ... - - app.add_url_rule('/', route) - - Arguments: - rule: The path to route on, should start with a ``/``. - endpoint: Optional endpoint name, if not present the - function name is used. - view_func: Callable that returns a response. - provide_automatic_options: Optionally False to prevent - OPTION handling. - methods: List of HTTP verbs the function routes. - defaults: A dictionary of variables to provide automatically, use - to provide a simpler default path for a route, e.g. to allow - for ``/book`` rather than ``/book/0``, - - .. code-block:: python - - @app.route('/book', defaults={'page': 0}) - @app.route('/book/') - def book(page): - ... - - host: The full host name for this route (should include subdomain - if needed) - cannot be used with subdomain. - subdomain: A subdomain for this specific route. - strict_slashes: Strictly match the trailing slash present in the - path. Will redirect a leaf (no slash) to a branch (with slash). - is_websocket: Whether or not the view_func is a websocket. - merge_slashes: Merge consecutive slashes to a single slash (unless - as part of the path variable). - """ - endpoint = endpoint or _endpoint_from_view_func(view_func) - if methods is None: - methods = getattr(view_func, "methods", ["GET"]) - - methods = cast(Set[str], set(methods)) - required_methods = set(getattr(view_func, "required_methods", set())) - - if provide_automatic_options is None: - automatic_options = getattr(view_func, "provide_automatic_options", None) - if automatic_options is None: - automatic_options = "OPTIONS" not in methods - else: - automatic_options = provide_automatic_options - - if automatic_options: - required_methods.add("OPTIONS") - - methods.update(required_methods) - - rule = self.url_rule_class( - rule, - methods=methods, - endpoint=endpoint, - host=host, - subdomain=subdomain, - defaults=defaults, - websocket=is_websocket, - strict_slashes=strict_slashes, - merge_slashes=merge_slashes, - provide_automatic_options=automatic_options, - ) - self.url_map.add(rule) - - if view_func is not None: - old_view_func = self.view_functions.get(endpoint) - if old_view_func is not None and old_view_func != view_func: - raise AssertionError(f"Handler is overwriting existing for endpoint {endpoint}") - - self.view_functions[endpoint] = view_func - - @setupmethod - def template_filter( - self, name: Optional[str] = None - ) -> Callable[[T_template_filter], T_template_filter]: - """Add a template filter. - - This is designed to be used as a decorator. An example usage, - - .. code-block:: python - - @app.template_filter('name') - def to_upper(value): - return value.upper() - - Arguments: - name: The filter name (defaults to function name). - """ - - def decorator(func: T_template_filter) -> T_template_filter: - self.add_template_filter(func, name=name) - return func - - return decorator - - @setupmethod - def add_template_filter(self, func: TemplateFilterCallable, name: Optional[str] = None) -> None: - """Add a template filter. - - This is designed to be used on the application directly. An - example usage, - - .. code-block:: python - - def to_upper(value): - return value.upper() - - app.add_template_filter(to_upper) - - Arguments: - func: The function that is the filter. - name: The filter name (defaults to function name). - """ - self.jinja_env.filters[name or func.__name__] = func - - @setupmethod - def template_test( - self, name: Optional[str] = None - ) -> Callable[[T_template_test], T_template_test]: - """Add a template test. - - This is designed to be used as a decorator. An example usage, - - .. code-block:: python - - @app.template_test('name') - def is_upper(value): - return value.isupper() - - Arguments: - name: The test name (defaults to function name). - """ - - def decorator(func: T_template_test) -> T_template_test: - self.add_template_test(func, name=name) - return func - - return decorator - - @setupmethod - def add_template_test(self, func: TemplateTestCallable, name: Optional[str] = None) -> None: - """Add a template test. - - This is designed to be used on the application directly. An - example usage, - - .. code-block:: python - - def is_upper(value): - return value.isupper() - - app.add_template_test(is_upper) - - Arguments: - func: The function that is the test. - name: The test name (defaults to function name). - """ - self.jinja_env.tests[name or func.__name__] = func - - @setupmethod - def template_global( - self, name: Optional[str] = None - ) -> Callable[[T_template_global], T_template_global]: - """Add a template global. - - This is designed to be used as a decorator. An example usage, - - .. code-block:: python - - @app.template_global('name') - def five(): - return 5 - - Arguments: - name: The global name (defaults to function name). - """ - - def decorator(func: T_template_global) -> T_template_global: - self.add_template_global(func, name=name) - return func - - return decorator - - @setupmethod - def add_template_global(self, func: TemplateGlobalCallable, name: Optional[str] = None) -> None: - """Add a template global. - - This is designed to be used on the application directly. An - example usage, - - .. code-block:: python - - def five(): - return 5 - - app.add_template_global(five) - - Arguments: - func: The function that is the global. - name: The global name (defaults to function name). - """ - self.jinja_env.globals[name or func.__name__] = func - - @setupmethod - def before_first_request( - self, - func: T_before_first_request, - ) -> T_before_first_request: - """Add a before **first** request function. - - This is designed to be used as a decorator, if used to - decorate a synchronous function, the function will be wrapped - in :func:`~quart.utils.run_sync` and run in a thread executor - (with the wrapped function returned). An example usage, - - .. code-block:: python - - @app.before_first_request - async def func(): - ... - - Arguments: - func: The before first request function itself. - """ - self.before_first_request_funcs.append(func) - return func - @setupmethod def before_serving( self, @@ -893,7 +535,7 @@ async def func(): self.after_serving_funcs.append(func) return func - def create_url_adapter(self, request: Optional[BaseRequestWebsocket]) -> Optional[MapAdapter]: + def create_url_adapter(self, request: BaseRequestWebsocket | None) -> MapAdapter | None: """Create and return a URL adapter. This will create the adapter based on the request if present @@ -904,56 +546,131 @@ def create_url_adapter(self, request: Optional[BaseRequestWebsocket]) -> Optiona (self.url_map.default_subdomain or None) if not self.subdomain_matching else None ) - return self.url_map.bind_to_request(request, subdomain, self.config["SERVER_NAME"]) + return self.url_map.bind_to_request( # type: ignore[attr-defined] + request, subdomain, self.config["SERVER_NAME"] + ) if self.config["SERVER_NAME"] is not None: scheme = "https" if self.config["PREFER_SECURE_URLS"] else "http" return self.url_map.bind(self.config["SERVER_NAME"], url_scheme=scheme) return None - @setupmethod - def shell_context_processor(self, func: T_shell_context_processor) -> T_shell_context_processor: - """Add a shell context processor. + def websocket( + self, + rule: str, + **options: Any, + ) -> Callable[[T_websocket], T_websocket]: + """Add a websocket to the application. - This is designed to be used as a decorator. An example usage, + This is designed to be used as a decorator, if used to + decorate a synchronous function, the function will be wrapped + in :func:`~quart.utils.run_sync` and run in a thread executor + (with the wrapped function returned). An example usage, .. code-block:: python - @app.shell_context_processor - def additional_context(): - return context + @app.websocket('/') + async def websocket_route(): + ... - """ - self.shell_context_processors.append(func) - return func + Arguments: + rule: The path to route on, should start with a ``/``. + endpoint: Optional endpoint name, if not present the + function name is used. + defaults: A dictionary of variables to provide automatically, use + to provide a simpler default path for a route, e.g. to allow + for ``/book`` rather than ``/book/0``, + + .. code-block:: python - def inject_url_defaults(self, endpoint: str, values: dict) -> None: - """Injects default URL values into the passed values dict. + @app.websocket('/book', defaults={'page': 0}) + @app.websocket('/book/') + def book(page): + ... - This is used to assist when building urls, see `url_for`. + host: The full host name for this route (should include subdomain + if needed) - cannot be used with subdomain. + subdomain: A subdomain for this specific route. + strict_slashes: Strictly match the trailing slash present in the + path. Will redirect a leaf (no slash) to a branch (with slash). """ - names: List[Optional[str]] = [None] - if "." in endpoint: - names.extend(reversed(_split_blueprint_path(endpoint.rsplit(".", 1)[0]))) - for name in names: - for function in self.url_default_functions[name]: - function(endpoint, values) + def decorator(func: T_websocket) -> T_websocket: + endpoint = options.pop("endpoint", None) + self.add_websocket( + rule, + endpoint, + func, + **options, + ) + return func - def url_for( - self, - endpoint: str, - *, - _anchor: Optional[str] = None, - _external: Optional[bool] = None, - _method: Optional[str] = None, - _scheme: Optional[str] = None, - **values: Any, - ) -> str: - """Return the url for a specific endpoint. + return decorator - This is most useful in templates and redirects to create a URL - that can be used in the browser. + def add_websocket( + self, + rule: str, + endpoint: str | None = None, + view_func: WebsocketCallable | None = None, + **options: Any, + ) -> None: + """Add a websocket url rule to the application. + + This is designed to be used on the application directly. An + example usage, + + .. code-block:: python + + def websocket_route(): + ... + + app.add_websocket('/', websocket_route) + + Arguments: + rule: The path to route on, should start with a ``/``. + endpoint: Optional endpoint name, if not present the + function name is used. + view_func: Callable that returns a response. + defaults: A dictionary of variables to provide automatically, use + to provide a simpler default path for a route, e.g. to allow + for ``/book`` rather than ``/book/0``, + + .. code-block:: python + + @app.websocket('/book', defaults={'page': 0}) + @app.websocket('/book/') + def book(page): + ... + + host: The full host name for this route (should include subdomain + if needed) - cannot be used with subdomain. + subdomain: A subdomain for this specific route. + strict_slashes: Strictly match the trailing slash present in the + path. Will redirect a leaf (no slash) to a branch (with slash). + """ + return self.add_url_rule( + rule, + endpoint, + view_func, + methods={"GET"}, + websocket=True, + **options, + ) + + def url_for( + self, + endpoint: str, + *, + _anchor: str | None = None, + _external: bool | None = None, + _method: str | None = None, + _scheme: str | None = None, + **values: Any, + ) -> str: + """Return the url for a specific endpoint. + + This is most useful in templates and redirects to create a URL + that can be used in the browser. Arguments: endpoint: The endpoint to build a url for, if prefixed with @@ -1020,49 +737,249 @@ def url_for( url_adapter.url_scheme = old_scheme if _anchor is not None: - quoted_anchor = url_quote(_anchor) + quoted_anchor = quote(_anchor, safe="%!#$&'()*+,/:;=?@") url = f"{url}#{quoted_anchor}" return url - def handle_url_build_error(self, error: Exception, endpoint: str, values: dict) -> str: - """Handle a build error. + def make_shell_context(self) -> dict: + """Create a context for interactive shell usage. - Ideally this will return a valid url given the error endpoint - and values. + The :attr:`shell_context_processors` can be used to add + additional context. """ - for handler in self.url_build_error_handlers: - result = handler(error, endpoint, values) - if result is not None: - return result - raise error + context = {"app": self, "g": g} + for processor in self.shell_context_processors: + context.update(processor()) + return context - def _find_error_handler(self, error: Exception) -> Optional[ErrorHandlerCallable]: - error_type, error_code = self._get_error_type_and_code(type(error)) + def run( + self, + host: str | None = None, + port: int | None = None, + debug: bool | None = None, + use_reloader: bool = True, + loop: asyncio.AbstractEventLoop | None = None, + ca_certs: str | None = None, + certfile: str | None = None, + keyfile: str | None = None, + **kwargs: Any, + ) -> None: + """Run this application. - names = [] - if has_request_context(): - names.extend(request_ctx.request.blueprints) - elif has_websocket_context(): - names.extend(websocket_ctx.websocket.blueprints) - names.append(None) + This is best used for development only, see Hypercorn for + production servers. - for code in [error_code, None]: - for name in names: - handlers = self.error_handler_spec[name].get(code) + Arguments: + host: Hostname to listen on. By default this is loopback + only, use 0.0.0.0 to have the server listen externally. + port: Port number to listen on. + debug: If set enable (or disable) debug mode and debug output. + use_reloader: Automatically reload on code changes. + loop: Asyncio loop to create the server in, if None, take default one. + If specified it is the caller's responsibility to close and cleanup the + loop. + ca_certs: Path to the SSL CA certificate file. + certfile: Path to the SSL certificate file. + keyfile: Path to the SSL key file. + """ + if kwargs: + warnings.warn( + f"Additional arguments, {','.join(kwargs.keys())}, are not supported.\n" + "They may be supported by Hypercorn, which is the ASGI server Quart " + "uses by default. This method is meant for development and debugging.", + stacklevel=2, + ) + + if loop is None: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) - if handlers is None: - continue + if "QUART_DEBUG" in os.environ: + self.debug = get_debug_flag() - for cls in error_type.__mro__: - handler = handlers.get(cls) + if debug is not None: + self.debug = debug - if handler is not None: - return handler - return None + loop.set_debug(self.debug) + + shutdown_event = asyncio.Event() + + def _signal_handler(*_: Any) -> None: + shutdown_event.set() + + for signal_name in {"SIGINT", "SIGTERM", "SIGBREAK"}: + if hasattr(signal, signal_name): + try: + loop.add_signal_handler(getattr(signal, signal_name), _signal_handler) + except NotImplementedError: + # Add signal handler may not be implemented on Windows + signal.signal(getattr(signal, signal_name), _signal_handler) + + server_name = self.config.get("SERVER_NAME") + sn_host = None + sn_port = None + if server_name is not None: + sn_host, _, sn_port = server_name.partition(":") + + if host is None: + host = sn_host or "127.0.0.1" + + if port is None: + port = int(sn_port or "5000") + + task = self.run_task( + host, + port, + debug, + ca_certs, + certfile, + keyfile, + shutdown_trigger=shutdown_event.wait, # type: ignore + ) + print(f" * Serving Quart app '{self.name}'") # noqa: T201 + print(f" * Debug mode: {self.debug or False}") # noqa: T201 + print(" * Please use an ASGI server (e.g. Hypercorn) directly in production") # noqa: T201 + scheme = "https" if certfile is not None and keyfile is not None else "http" + print(f" * Running on {scheme}://{host}:{port} (CTRL + C to quit)") # noqa: T201 + + tasks = [loop.create_task(task)] + + if use_reloader: + tasks.append(loop.create_task(observe_changes(asyncio.sleep, shutdown_event))) + + reload_ = False + try: + loop.run_until_complete(asyncio.gather(*tasks)) + except MustReloadError: + reload_ = True + finally: + try: + _cancel_all_tasks(loop) + loop.run_until_complete(loop.shutdown_asyncgens()) + finally: + asyncio.set_event_loop(None) + loop.close() + + if reload_: + restart() + + def run_task( + self, + host: str = "127.0.0.1", + port: int = 5000, + debug: bool | None = None, + ca_certs: str | None = None, + certfile: str | None = None, + keyfile: str | None = None, + shutdown_trigger: Callable[..., Awaitable[None]] | None = None, + ) -> Coroutine[None, None, None]: + """Return a task that when awaited runs this application. + + This is best used for development only, see Hypercorn for + production servers. + + Arguments: + host: Hostname to listen on. By default this is loopback + only, use 0.0.0.0 to have the server listen externally. + port: Port number to listen on. + debug: If set enable (or disable) debug mode and debug output. + ca_certs: Path to the SSL CA certificate file. + certfile: Path to the SSL certificate file. + keyfile: Path to the SSL key file. + + """ + config = HyperConfig() + config.access_log_format = "%(h)s %(r)s %(s)s %(b)s %(D)s" + config.accesslog = "-" + config.bind = [f"{host}:{port}"] + config.ca_certs = ca_certs + config.certfile = certfile + if debug is not None: + self.debug = debug + config.errorlog = config.accesslog + config.keyfile = keyfile + + return serve(self, config, shutdown_trigger=shutdown_trigger) + + def test_client(self, use_cookies: bool = True, **kwargs: Any) -> TestClientProtocol: + """Creates and returns a test client.""" + return self.test_client_class(self, use_cookies=use_cookies, **kwargs) + + def test_cli_runner(self, **kwargs: Any) -> QuartCliRunner: + """Creates and returns a CLI test runner.""" + return self.test_cli_runner_class(self, **kwargs) # type: ignore + + @setupmethod + def before_websocket( + self, + func: T_before_websocket, + ) -> T_before_websocket: + """Add a before websocket function. + + This is designed to be used as a decorator, if used to + decorate a synchronous function, the function will be wrapped + in :func:`~quart.utils.run_sync` and run in a thread executor + (with the wrapped function returned). An example usage, + + .. code-block:: python + + @app.before_websocket + async def func(): + ... + + Arguments: + func: The before websocket function itself. + """ + self.before_websocket_funcs[None].append(func) + return func + + @setupmethod + def after_websocket( + self, + func: T_after_websocket, + ) -> T_after_websocket: + """Add an after websocket function. + + This is designed to be used as a decorator, if used to + decorate a synchronous function, the function will be wrapped + in :func:`~quart.utils.run_sync` and run in a thread executor + (with the wrapped function returned). An example usage, + + .. code-block:: python + + @app.after_websocket + async def func(response): + return response + + Arguments: + func: The after websocket function itself. + """ + self.after_websocket_funcs[None].append(func) + return func + + @setupmethod + def teardown_websocket( + self, + func: T_teardown, + ) -> T_teardown: + """Add a teardown websocket function. + This is designed to be used as a decorator, if used to + decorate a synchronous function, the function will be wrapped + in :func:`~quart.utils.run_sync` and run in a thread executor + (with the wrapped function returned). An example usage, + .. code-block:: python + @app.teardown_websocket + async def func(): + ... + Arguments: + func: The teardown websocket function itself. + """ + self.teardown_websocket_funcs[None].append(func) + return func async def handle_http_exception( self, error: HTTPException - ) -> Union[HTTPException, ResponseReturnValue]: + ) -> HTTPException | ResponseReturnValue: """Handle a HTTPException subclass error. This will attempt to find a handler for the error and if fails @@ -1074,91 +991,112 @@ async def handle_http_exception( if isinstance(error, RoutingException): return error - handler = self._find_error_handler(error) + blueprints = [] + if has_request_context(): + blueprints = request.blueprints + elif has_websocket_context(): + blueprints = websocket.blueprints + + handler = self._find_error_handler(error, blueprints) if handler is None: - return error.get_response() + return error else: - return await self.ensure_async(handler)(error) - - def trap_http_exception(self, error: Exception) -> bool: - """Check it error is http and should be trapped. + return await self.ensure_async(handler)(error) # type: ignore - Trapped errors are not handled by the - :meth:`handle_http_exception`, but instead trapped by the - outer most (or user handlers). This can be useful when - debugging to allow tracebacks to be viewed by the debug page. - """ - return self.config["TRAP_HTTP_EXCEPTIONS"] - - async def handle_user_exception( - self, error: Exception - ) -> Union[HTTPException, ResponseReturnValue]: + async def handle_user_exception(self, error: Exception) -> HTTPException | ResponseReturnValue: """Handle an exception that has been raised. This should forward :class:`~quart.exception.HTTPException` to :meth:`handle_http_exception`, then attempt to handle the error. If it cannot it should reraise the error. """ + if isinstance(error, BadRequestKeyError) and ( + self.debug or self.config["TRAP_BAD_REQUEST_ERRORS"] + ): + error.show_exception = True + if isinstance(error, HTTPException) and not self.trap_http_exception(error): return await self.handle_http_exception(error) - handler = self._find_error_handler(error) + blueprints = [] + if has_request_context(): + blueprints = request.blueprints + elif has_websocket_context(): + blueprints = websocket.blueprints + + handler = self._find_error_handler(error, blueprints) if handler is None: raise error - return await self.ensure_async(handler)(error) + return await self.ensure_async(handler)(error) # type: ignore - async def handle_exception(self, error: Exception) -> Union[Response, WerkzeugResponse]: + async def handle_exception(self, error: Exception) -> ResponseTypes: """Handle an uncaught exception. By default this switches the error response to a 500 internal server error. """ - await got_request_exception.send(self, exception=error) + exc_info = sys.exc_info() + await got_request_exception.send_async( + self, _sync_wrapper=self.ensure_async, exception=error # type: ignore + ) + propagate = self.config["PROPAGATE_EXCEPTIONS"] - self.log_exception(sys.exc_info()) + if propagate is None: + propagate = self.testing or self.debug + + if propagate: + # Re-raise if called with an active exception, otherwise + # raise the passed in exception. + if exc_info[1] is error: + raise - if self.propagate_exceptions: raise error - internal_server_error = InternalServerError(original_exception=error) - handler = self._find_error_handler(internal_server_error) + self.log_exception(exc_info) + server_error: InternalServerError | ResponseReturnValue + server_error = InternalServerError(original_exception=error) + handler = self._find_error_handler(server_error, request.blueprints) - response: Union[Response, WerkzeugResponse, InternalServerError] if handler is not None: - response = await self.ensure_async(handler)(internal_server_error) - else: - response = internal_server_error + server_error = await self.ensure_async(handler)(server_error) # type: ignore - return await self.finalize_request(response, from_error_handler=True) + return await self.finalize_request(server_error, from_error_handler=True) - async def handle_websocket_exception( - self, error: Exception - ) -> Optional[Union[Response, WerkzeugResponse]]: + async def handle_websocket_exception(self, error: Exception) -> ResponseTypes | None: """Handle an uncaught exception. By default this logs the exception and then re-raises it. """ - await got_websocket_exception.send(self, exception=error) + exc_info = sys.exc_info() + await got_websocket_exception.send_async( + self, _sync_wrapper=self.ensure_async, exception=error # type: ignore + ) + propagate = self.config["PROPAGATE_EXCEPTIONS"] - self.log_exception(sys.exc_info()) + if propagate is None: + propagate = self.testing or self.debug + + if propagate: + # Re-raise if called with an active exception, otherwise + # raise the passed in exception. + if exc_info[1] is error: + raise - if self.propagate_exceptions: raise error - internal_server_error = InternalServerError(original_exception=error) - handler = self._find_error_handler(internal_server_error) + self.log_exception(exc_info) + server_error: InternalServerError | ResponseReturnValue + server_error = InternalServerError(original_exception=error) + handler = self._find_error_handler(server_error, websocket.blueprints) - response: Union[Response, WerkzeugResponse, InternalServerError] if handler is not None: - response = await self.ensure_async(handler)(internal_server_error) - else: - response = internal_server_error + server_error = await self.ensure_async(handler)(server_error) # type: ignore - return await self.finalize_websocket(response, from_error_handler=True) + return await self.finalize_websocket(server_error, from_error_handler=True) def log_exception( self, - exception_info: Union[Tuple[type, BaseException, TracebackType], Tuple[None, None, None]], + exception_info: tuple[type, BaseException, TracebackType] | tuple[None, None, None], ) -> None: """Log a exception to the :attr:`logger`. @@ -1175,35 +1113,15 @@ def log_exception( else: self.logger.error("Exception", exc_info=exception_info) - def raise_routing_exception(self, request: BaseRequestWebsocket) -> NoReturn: - raise request.routing_exception + @overload + def ensure_async(self, func: Callable[P, Awaitable[T]]) -> Callable[P, Awaitable[T]]: ... - @setupmethod - def teardown_appcontext( - self, - func: T_teardown, - ) -> T_teardown: - """Add a teardown app (context) function. - - This is designed to be used as a decorator, if used to - decorate a synchronous function, the function will be wrapped - in :func:`~quart.utils.run_sync` and run in a thread executor - (with the wrapped function returned). An example usage, - - .. code-block:: python - - @app.teardown_appcontext - async def func(): - ... + @overload + def ensure_async(self, func: Callable[P, T]) -> Callable[P, Awaitable[T]]: ... - Arguments: - func: The teardown function itself. - name: Optional blueprint key name. - """ - self.teardown_appcontext_funcs.append(func) - return func - - def ensure_async(self, func: Callable[..., Any]) -> Callable[..., Awaitable[Any]]: + def ensure_async( + self, func: Union[Callable[P, Awaitable[T]], Callable[P, T]] + ) -> Callable[P, Awaitable[T]]: """Ensure that the returned func is async and calls the func. .. versionadded:: 0.11 @@ -1212,12 +1130,12 @@ def ensure_async(self, func: Callable[..., Any]) -> Callable[..., Awaitable[Any] run. Before Quart 0.11 this did not run the synchronous code in an executor. """ - if is_coroutine_function(func): + if asyncio.iscoroutinefunction(func): return func else: - return self.sync_to_async(func) + return self.sync_to_async(cast(Callable[P, T], func)) - def sync_to_async(self, func: Callable[..., Any]) -> Callable[..., Awaitable[Any]]: + def sync_to_async(self, func: Callable[P, T]) -> Callable[P, Awaitable[T]]: """Return a async function that will run the synchronous function *func*. This can be used as so,:: @@ -1230,7 +1148,7 @@ def sync_to_async(self, func: Callable[..., Any]) -> Callable[..., Awaitable[Any return run_sync(func) async def do_teardown_request( - self, exc: Optional[BaseException], request_context: Optional[RequestContext] = None + self, exc: BaseException | None, request_context: RequestContext | None = None ) -> None: """Teardown the request, calling the teardown functions. @@ -1242,13 +1160,15 @@ async def do_teardown_request( """ names = [*(request_context or request_ctx).request.blueprints, None] for name in names: - for function in self.teardown_request_funcs[name]: + for function in reversed(self.teardown_request_funcs[name]): await self.ensure_async(function)(exc) - await request_tearing_down.send(self, exc=exc) + await request_tearing_down.send_async( + self, _sync_wrapper=self.ensure_async, exc=exc # type: ignore + ) async def do_teardown_websocket( - self, exc: Optional[BaseException], websocket_context: Optional[WebsocketContext] = None + self, exc: BaseException | None, websocket_context: WebsocketContext | None = None ) -> None: """Teardown the websocket, calling the teardown functions. @@ -1260,16 +1180,20 @@ async def do_teardown_websocket( """ names = [*(websocket_context or websocket_ctx).websocket.blueprints, None] for name in names: - for function in self.teardown_websocket_funcs[name]: + for function in reversed(self.teardown_websocket_funcs[name]): await self.ensure_async(function)(exc) - await websocket_tearing_down.send(self, exc=exc) + await websocket_tearing_down.send_async( + self, _sync_wrapper=self.ensure_async, exc=exc # type: ignore + ) - async def do_teardown_appcontext(self, exc: Optional[BaseException]) -> None: + async def do_teardown_appcontext(self, exc: BaseException | None) -> None: """Teardown the app (context), calling the teardown functions.""" for function in self.teardown_appcontext_funcs: await self.ensure_async(function)(exc) - await appcontext_tearing_down.send(self, exc=exc) + await appcontext_tearing_down.send_async( + self, _sync_wrapper=self.ensure_async, exc=exc # type: ignore + ) def app_context(self) -> AppContext: """Create and return an app context. @@ -1280,197 +1204,40 @@ def app_context(self) -> AppContext: async with app.app_context(): ... - """ - return AppContext(self) - - def request_context(self, request: Request) -> RequestContext: - """Create and return a request context. - - Use the :meth:`test_request_context` whilst testing. This is - best used within a context, i.e. - - .. code-block:: python - - async with app.request_context(request): - ... - - Arguments: - request: A request to build a context around. - """ - return RequestContext(self, request) - - def websocket_context(self, websocket: Websocket) -> WebsocketContext: - """Create and return a websocket context. - - Use the :meth:`test_websocket_context` whilst testing. This is - best used within a context, i.e. - - .. code-block:: python - - async with app.websocket_context(websocket): - ... - - Arguments: - websocket: A websocket to build a context around. - """ - return WebsocketContext(self, websocket) - - def run( - self, - host: Optional[str] = None, - port: Optional[int] = None, - debug: Optional[bool] = None, - use_reloader: bool = True, - loop: Optional[asyncio.AbstractEventLoop] = None, - ca_certs: Optional[str] = None, - certfile: Optional[str] = None, - keyfile: Optional[str] = None, - **kwargs: Any, - ) -> None: - """Run this application. - - This is best used for development only, see Hypercorn for - production servers. - - Arguments: - host: Hostname to listen on. By default this is loopback - only, use 0.0.0.0 to have the server listen externally. - port: Port number to listen on. - debug: If set enable (or disable) debug mode and debug output. - use_reloader: Automatically reload on code changes. - loop: Asyncio loop to create the server in, if None, take default one. - If specified it is the caller's responsibility to close and cleanup the - loop. - ca_certs: Path to the SSL CA certificate file. - certfile: Path to the SSL certificate file. - keyfile: Path to the SSL key file. - """ - if kwargs: - warnings.warn( - f"Additional arguments, {','.join(kwargs.keys())}, are not supported.\n" - "They may be supported by Hypercorn, which is the ASGI server Quart " - "uses by default. This method is meant for development and debugging." - ) - - if loop is None: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - if "QUART_ENV" in os.environ: - self.env = get_env() - self.debug = get_debug_flag() - elif "QUART_DEBUG" in os.environ: - self.debug = get_debug_flag() - - if debug is not None: - self.debug = debug - - loop.set_debug(self.debug) - - shutdown_event = asyncio.Event() - - def _signal_handler(*_: Any) -> None: - shutdown_event.set() - - for signal_name in {"SIGINT", "SIGTERM", "SIGBREAK"}: - if hasattr(signal, signal_name): - try: - loop.add_signal_handler(getattr(signal, signal_name), _signal_handler) - except NotImplementedError: - # Add signal handler may not be implemented on Windows - signal.signal(getattr(signal, signal_name), _signal_handler) - - server_name = self.config.get("SERVER_NAME") - sn_host = None - sn_port = None - if server_name is not None: - sn_host, _, sn_port = server_name.partition(":") - - if host is None: - host = sn_host or "127.0.0.1" - - if port is None: - port = int(sn_port or "5000") - - task = self.run_task( - host, - port, - debug, - ca_certs, - certfile, - keyfile, - shutdown_trigger=shutdown_event.wait, # type: ignore - ) - print(f" * Serving Quart app '{self.name}'") # noqa: T201 - print(f" * Environment: {self.env}") # noqa: T201 - if self.env == "production": - print( # noqa: T201 - " * Please use an ASGI server (e.g. Hypercorn) directly in production" - ) - print(f" * Debug mode: {self.debug or False}") # noqa: T201 - scheme = "https" if certfile is not None and keyfile is not None else "http" - print(f" * Running on {scheme}://{host}:{port} (CTRL + C to quit)") # noqa: T201 + """ + return AppContext(self) - tasks = [loop.create_task(task)] - if platform.system() == "Windows": - tasks.append(loop.create_task(_windows_signal_support())) + def request_context(self, request: Request) -> RequestContext: + """Create and return a request context. - if use_reloader: - tasks.append(loop.create_task(observe_changes(asyncio.sleep, shutdown_event))) + Use the :meth:`test_request_context` whilst testing. This is + best used within a context, i.e. - reload_ = False - try: - loop.run_until_complete(asyncio.gather(*tasks)) - except MustReloadError: - reload_ = True - finally: - try: - _cancel_all_tasks(loop) - loop.run_until_complete(loop.shutdown_asyncgens()) - finally: - asyncio.set_event_loop(None) - loop.close() + .. code-block:: python - if reload_: - restart() + async with app.request_context(request): + ... - def run_task( - self, - host: str = "127.0.0.1", - port: int = 5000, - debug: Optional[bool] = None, - ca_certs: Optional[str] = None, - certfile: Optional[str] = None, - keyfile: Optional[str] = None, - shutdown_trigger: Optional[Callable[..., Awaitable[None]]] = None, - ) -> Coroutine[None, None, None]: - """Return a task that when awaited runs this application. + Arguments: + request: A request to build a context around. + """ + return RequestContext(self, request) - This is best used for development only, see Hypercorn for - production servers. + def websocket_context(self, websocket: Websocket) -> WebsocketContext: + """Create and return a websocket context. - Arguments: - host: Hostname to listen on. By default this is loopback - only, use 0.0.0.0 to have the server listen externally. - port: Port number to listen on. - debug: If set enable (or disable) debug mode and debug output. - ca_certs: Path to the SSL CA certificate file. - certfile: Path to the SSL certificate file. - keyfile: Path to the SSL key file. + Use the :meth:`test_websocket_context` whilst testing. This is + best used within a context, i.e. - """ - config = HyperConfig() - config.access_log_format = "%(h)s %(r)s %(s)s %(b)s %(D)s" - config.accesslog = "-" - config.bind = [f"{host}:{port}"] - config.ca_certs = ca_certs - config.certfile = certfile - if debug is not None: - self.debug = debug - config.errorlog = config.accesslog - config.keyfile = keyfile + .. code-block:: python - return serve(self, config, shutdown_trigger=shutdown_trigger) + async with app.websocket_context(websocket): + ... + + Arguments: + websocket: A websocket to build a context around. + """ + return WebsocketContext(self, websocket) def test_app(self) -> TestAppProtocol: return self.test_app_class(self) @@ -1480,18 +1247,18 @@ def test_request_context( path: str, *, method: str = "GET", - headers: Optional[Union[dict, Headers]] = None, - query_string: Optional[dict] = None, + headers: dict | Headers | None = None, + query_string: dict | None = None, scheme: str = "http", send_push_promise: Callable[[str, Headers], Awaitable[None]] = no_op_push, - data: Optional[AnyStr] = None, - form: Optional[dict] = None, + data: AnyStr | None = None, + form: dict | None = None, json: Any = sentinel, root_path: str = "", http_version: str = "1.1", - scope_base: Optional[dict] = None, - auth: Optional[Union[Authorization, Tuple[str, str]]] = None, - subdomain: Optional[str] = None, + scope_base: dict | None = None, + auth: Authorization | tuple[str, str] | None = None, + subdomain: str | None = None, ) -> RequestContext: """Create a request context for testing purposes. @@ -1550,49 +1317,28 @@ def test_request_context( def add_background_task(self, func: Callable, *args: Any, **kwargs: Any) -> None: async def _wrapper() -> None: try: - await copy_current_app_context(self.ensure_async(func))(*args, **kwargs) + async with self.app_context(): + await self.ensure_async(func)(*args, **kwargs) except Exception as error: await self.handle_background_exception(error) task = asyncio.get_event_loop().create_task(_wrapper()) self.background_tasks.add(task) + task.add_done_callback(self.background_tasks.discard) async def handle_background_exception(self, error: Exception) -> None: - await got_background_exception.send(self, exception=error) + await got_background_exception.send_async( + self, _sync_wrapper=self.ensure_async, exception=error # type: ignore + ) self.log_exception(sys.exc_info()) - async def try_trigger_before_first_request_functions(self) -> None: - """Trigger the before first request methods.""" - if self._got_first_request: - return - - # Reverse the teardown functions, so as to match the expected usage - self.teardown_appcontext_funcs = list(reversed(self.teardown_appcontext_funcs)) - for key, value in self.teardown_request_funcs.items(): - self.teardown_request_funcs[key] = list(reversed(value)) - for key, value in self.teardown_websocket_funcs.items(): - self.teardown_websocket_funcs[key] = list(reversed(value)) - - async with self._first_request_lock: - if self._got_first_request: - return - for function in self.before_first_request_funcs: - await self.ensure_async(function)() - self._got_first_request = True - - def redirect(self, location: str, code: int = 302) -> WerkzeugResponse: - """Create a redirect response object.""" - return werkzeug_redirect(location, code=code, Response=self.response_class) # type: ignore - async def make_default_options_response(self) -> Response: """This is the default route function for OPTIONS requests.""" methods = request_ctx.url_adapter.allowed_methods() return self.response_class("", headers={"Allow": ", ".join(methods)}) - async def make_response( - self, result: Union[ResponseReturnValue, HTTPException] - ) -> Union[Response, WerkzeugResponse]: + async def make_response(self, result: ResponseReturnValue | HTTPException) -> ResponseTypes: """Make a Response from the result of the route handler. The result itself can either be: @@ -1602,24 +1348,31 @@ async def make_response( A ResponseValue is either a Response object (or subclass) or a str. """ - status_or_headers: Optional[Union[StatusCode, HeadersValue]] = None - headers: Optional[HeadersValue] = None - status: Optional[StatusCode] = None + headers: HeadersValue | None = None + status: StatusCode | None = None if isinstance(result, tuple): - value, status_or_headers, headers = result + (None,) * (3 - len(result)) + if len(result) == 3: + value, status, headers = result + elif len(result) == 2: + value, status_or_headers = result + + if isinstance(status_or_headers, (Headers, dict, list)): + headers = status_or_headers + status = None + elif status_or_headers is not None: + status = status_or_headers # type: ignore[assignment] + else: + raise TypeError( + """The response value returned must be either (body, status), (body, + headers), or (body, status, headers)""" + ) else: - value = result + value = result # type: ignore[assignment] if value is None: raise TypeError("The response value returned by the view function cannot be None") - if isinstance(status_or_headers, (Headers, dict, list)): - headers = status_or_headers - status = None - elif status_or_headers is not None: - status = status_or_headers - - response: Union[Response, WerkzeugResponse] + response: ResponseTypes if isinstance(value, HTTPException): response = value.get_response() # type: ignore elif not isinstance(value, (Response, WerkzeugResponse)): @@ -1628,9 +1381,9 @@ async def make_response( or isgenerator(value) or isasyncgen(value) ): - response = self.response_class(value) # type: ignore + response = self.response_class(value) elif isinstance(value, (list, dict)): - response = self.json.response(value) + response = self.json.response(value) # type: ignore[assignment] else: raise TypeError(f"The response value type ({type(value).__name__}) is not valid") else: @@ -1640,11 +1393,11 @@ async def make_response( response.status_code = int(status) if headers is not None: - response.headers.update(headers) # type: ignore + response.headers.update(headers) # type: ignore[arg-type] return response - async def handle_request(self, request: Request) -> Union[Response, WerkzeugResponse]: + async def handle_request(self, request: Request) -> ResponseTypes: async with self.request_context(request) as request_context: try: return await self.full_dispatch_request(request_context) @@ -1656,18 +1409,31 @@ async def handle_request(self, request: Request) -> Union[Response, WerkzeugResp if request.scope.get("_quart._preserve_context", False): self._preserved_context = request_context.copy() + async def handle_websocket(self, websocket: Websocket) -> ResponseTypes | None: + async with self.websocket_context(websocket) as websocket_context: + try: + return await self.full_dispatch_websocket(websocket_context) + except asyncio.CancelledError: + raise # CancelledErrors should be handled by serving code. + except Exception as error: + return await self.handle_websocket_exception(error) + finally: + if websocket.scope.get("_quart._preserve_context", False): + self._preserved_context = websocket_context.copy() + async def full_dispatch_request( - self, request_context: Optional[RequestContext] = None - ) -> Union[Response, WerkzeugResponse]: + self, request_context: RequestContext | None = None + ) -> ResponseTypes: """Adds pre and post processing to the request dispatching. Arguments: request_context: The request context, optional as Flask omits this argument. """ - await self.try_trigger_before_first_request_functions() - await request_started.send(self) try: + await request_started.send_async(self, _sync_wrapper=self.ensure_async) # type: ignore + + result: ResponseReturnValue | HTTPException | None result = await self.preprocess_request(request_context) if result is None: result = await self.dispatch_request(request_context) @@ -1675,9 +1441,31 @@ async def full_dispatch_request( result = await self.handle_user_exception(error) return await self.finalize_request(result, request_context) + async def full_dispatch_websocket( + self, websocket_context: WebsocketContext | None = None + ) -> ResponseTypes | None: + """Adds pre and post processing to the websocket dispatching. + + Arguments: + websocket_context: The websocket context, optional to match + the Flask convention. + """ + try: + await websocket_started.send_async( + self, _sync_wrapper=self.ensure_async # type: ignore + ) + + result: ResponseReturnValue | HTTPException | None + result = await self.preprocess_websocket(websocket_context) + if result is None: + result = await self.dispatch_websocket(websocket_context) + except Exception as error: + result = await self.handle_user_exception(error) + return await self.finalize_websocket(result, websocket_context) + async def preprocess_request( - self, request_context: Optional[RequestContext] = None - ) -> Optional[ResponseReturnValue]: + self, request_context: RequestContext | None = None + ) -> ResponseReturnValue | None: """Preprocess the request i.e. call before_request functions. Arguments: @@ -1694,111 +1482,13 @@ async def preprocess_request( for function in self.before_request_funcs[name]: result = await self.ensure_async(function)() if result is not None: - return result + return result # type: ignore return None - async def dispatch_request( - self, request_context: Optional[RequestContext] = None - ) -> ResponseReturnValue: - """Dispatch the request to the view function. - - Arguments: - request_context: The request context, optional as Flask - omits this argument. - """ - request_ = (request_context or request_ctx).request - if request_.routing_exception is not None: - self.raise_routing_exception(request_) - - if request_.method == "OPTIONS" and request_.url_rule.provide_automatic_options: - return await self.make_default_options_response() - - handler = self.view_functions[request_.url_rule.endpoint] - return await self.ensure_async(handler)(**request_.view_args) - - async def finalize_request( - self, - result: Union[ResponseReturnValue, HTTPException], - request_context: Optional[RequestContext] = None, - from_error_handler: bool = False, - ) -> Union[Response, WerkzeugResponse]: - """Turns the view response return value into a response. - - Arguments: - result: The result of the request to finalize into a response. - request_context: The request context, optional as Flask - omits this argument. - """ - response = await self.make_response(result) - try: - response = await self.process_response(response, request_context) - await request_finished.send(self, response=response) - except Exception: - if not from_error_handler: - raise - self.logger.exception("Request finalizing errored") - return response - - async def process_response( - self, - response: Union[Response, WerkzeugResponse], - request_context: Optional[RequestContext] = None, - ) -> Union[Response, WerkzeugResponse]: - """Postprocess the request acting on the response. - - Arguments: - response: The response after the request is finalized. - request_context: The request context, optional as Flask - omits this argument. - """ - names = [*(request_context or request_ctx).request.blueprints, None] - - for function in (request_context or request_ctx)._after_request_functions: - response = await self.ensure_async(function)(response) - - for name in names: - for function in reversed(self.after_request_funcs[name]): - response = await self.ensure_async(function)(response) - - session_ = (request_context or request_ctx).session - if not self.session_interface.is_null_session(session_): - await self.ensure_async(self.session_interface.save_session)(self, session_, response) - return response - - async def handle_websocket( - self, websocket: Websocket - ) -> Optional[Union[Response, WerkzeugResponse]]: - async with self.websocket_context(websocket) as websocket_context: - try: - return await self.full_dispatch_websocket(websocket_context) - except asyncio.CancelledError: - raise # CancelledErrors should be handled by serving code. - except Exception as error: - return await self.handle_websocket_exception(error) - - async def full_dispatch_websocket( - self, websocket_context: Optional[WebsocketContext] = None - ) -> Optional[Union[Response, WerkzeugResponse]]: - """Adds pre and post processing to the websocket dispatching. - - Arguments: - websocket_context: The websocket context, optional to match - the Flask convention. - """ - await self.try_trigger_before_first_request_functions() - await websocket_started.send(self) - try: - result = await self.preprocess_websocket(websocket_context) - if result is None: - result = await self.dispatch_websocket(websocket_context) - except Exception as error: - result = await self.handle_user_exception(error) - return await self.finalize_websocket(result, websocket_context) - async def preprocess_websocket( - self, websocket_context: Optional[WebsocketContext] = None - ) -> Optional[ResponseReturnValue]: + self, websocket_context: WebsocketContext | None = None + ) -> ResponseReturnValue | None: """Preprocess the websocket i.e. call before_websocket functions. Arguments: @@ -1818,13 +1508,35 @@ async def preprocess_websocket( for function in self.before_websocket_funcs[name]: result = await self.ensure_async(function)() if result is not None: - return result + return result # type: ignore return None + def raise_routing_exception(self, request: BaseRequestWebsocket) -> NoReturn: + raise request.routing_exception + + async def dispatch_request( + self, request_context: RequestContext | None = None + ) -> ResponseReturnValue: + """Dispatch the request to the view function. + + Arguments: + request_context: The request context, optional as Flask + omits this argument. + """ + request_ = (request_context or request_ctx).request + if request_.routing_exception is not None: + self.raise_routing_exception(request_) + + if request_.method == "OPTIONS" and request_.url_rule.provide_automatic_options: + return await self.make_default_options_response() + + handler = self.view_functions[request_.url_rule.endpoint] + return await self.ensure_async(handler)(**request_.view_args) # type: ignore + async def dispatch_websocket( - self, websocket_context: Optional[WebsocketContext] = None - ) -> Optional[ResponseReturnValue]: + self, websocket_context: WebsocketContext | None = None + ) -> ResponseReturnValue | None: """Dispatch the websocket to the view function. Arguments: @@ -1836,14 +1548,39 @@ async def dispatch_websocket( self.raise_routing_exception(websocket_) handler = self.view_functions[websocket_.url_rule.endpoint] - return await self.ensure_async(handler)(**websocket_.view_args) + return await self.ensure_async(handler)(**websocket_.view_args) # type: ignore + + async def finalize_request( + self, + result: ResponseReturnValue | HTTPException, + request_context: RequestContext | None = None, + from_error_handler: bool = False, + ) -> ResponseTypes: + """Turns the view response return value into a response. + + Arguments: + result: The result of the request to finalize into a response. + request_context: The request context, optional as Flask + omits this argument. + """ + response = await self.make_response(result) + try: + response = await self.process_response(response, request_context) + await request_finished.send_async( + self, _sync_wrapper=self.ensure_async, response=response # type: ignore + ) + except Exception: + if not from_error_handler: + raise + self.logger.exception("Request finalizing errored") + return response async def finalize_websocket( self, - result: ResponseReturnValue, - websocket_context: Optional[WebsocketContext] = None, + result: ResponseReturnValue | HTTPException, + websocket_context: WebsocketContext | None = None, from_error_handler: bool = False, - ) -> Optional[Union[Response, WerkzeugResponse]]: + ) -> ResponseTypes | None: """Turns the view response return value into a response. Arguments: @@ -1857,18 +1594,46 @@ async def finalize_websocket( response = None try: response = await self.postprocess_websocket(response, websocket_context) - await websocket_finished.send(self, response=response) + await websocket_finished.send_async( + self, _sync_wrapper=self.ensure_async, response=response # type: ignore + ) except Exception: if not from_error_handler: raise self.logger.exception("Request finalizing errored") return response + async def process_response( + self, + response: ResponseTypes, + request_context: RequestContext | None = None, + ) -> ResponseTypes: + """Postprocess the request acting on the response. + + Arguments: + response: The response after the request is finalized. + request_context: The request context, optional as Flask + omits this argument. + """ + names = [*(request_context or request_ctx).request.blueprints, None] + + for function in (request_context or request_ctx)._after_request_functions: + response = await self.ensure_async(function)(response) # type: ignore + + for name in names: + for function in reversed(self.after_request_funcs[name]): + response = await self.ensure_async(function)(response) + + session_ = (request_context or request_ctx).session + if not self.session_interface.is_null_session(session_): + await self.ensure_async(self.session_interface.save_session)(self, session_, response) + return response + async def postprocess_websocket( self, - response: Optional[Union[Response, WerkzeugResponse]], - websocket_context: Optional[WebsocketContext] = None, - ) -> Union[Response, WerkzeugResponse]: + response: ResponseTypes | None, + websocket_context: WebsocketContext | None = None, + ) -> ResponseTypes: """Postprocess the websocket acting on the response. Arguments: @@ -1879,11 +1644,11 @@ async def postprocess_websocket( names = [*(websocket_context or websocket_ctx).websocket.blueprints, None] for function in (websocket_context or websocket_ctx)._after_websocket_functions: - response = await self.ensure_async(function)(response) + response = await self.ensure_async(function)(response) # type: ignore for name in names: for function in reversed(self.after_websocket_funcs[name]): - response = await self.ensure_async(function)(response) + response = await self.ensure_async(function)(response) # type: ignore session_ = (websocket_context or websocket_ctx).session if not self.session_interface.is_null_session(session_): @@ -1916,7 +1681,7 @@ async def asgi_app( app.asgi_app = middleware(app.asgi_app) """ - asgi_handler: Union[ASGIHTTPProtocol, ASGILifespanProtocol, ASGIWebsocketProtocol] + asgi_handler: ASGIHTTPProtocol | ASGILifespanProtocol | ASGIWebsocketProtocol if scope["type"] == "http": asgi_handler = self.asgi_http_class(self, scope) elif scope["type"] == "websocket": @@ -1928,8 +1693,7 @@ async def asgi_app( await asgi_handler(receive, send) async def startup(self) -> None: - self._got_first_request = False - + self.shutdown_event = self.event_class() try: async with self.app_context(): for func in self.before_serving_funcs: @@ -1937,11 +1701,22 @@ async def startup(self) -> None: for gen in self.while_serving_gens: await gen.__anext__() except Exception as error: - await got_serving_exception.send(self, exception=error) + await got_serving_exception.send_async( + self, _sync_wrapper=self.ensure_async, exception=error # type: ignore + ) self.log_exception(sys.exc_info()) raise async def shutdown(self) -> None: + self.shutdown_event.set() + try: + await asyncio.wait_for( + asyncio.gather(*self.background_tasks), + timeout=self.config["BACKGROUND_TASK_SHUTDOWN_TIMEOUT"], + ) + except asyncio.TimeoutError: + await cancel_tasks(self.background_tasks) + try: async with self.app_context(): for func in self.after_serving_funcs: @@ -1954,12 +1729,12 @@ async def shutdown(self) -> None: else: raise RuntimeError("While serving generator didn't terminate") except Exception as error: - await got_serving_exception.send(self, exception=error) + await got_serving_exception.send_async( + self, _sync_wrapper=self.ensure_async, exception=error # type: ignore + ) self.log_exception(sys.exc_info()) raise - await asyncio.gather(*self.background_tasks) - def _cancel_all_tasks(loop: asyncio.AbstractEventLoop) -> None: tasks = [task for task in asyncio.all_tasks(loop) if not task.done()] @@ -1979,11 +1754,3 @@ def _cancel_all_tasks(loop: asyncio.AbstractEventLoop) -> None: "task": task, } ) - - -async def _windows_signal_support() -> None: - # See https://bugs.python.org/issue23057, to catch signals on - # Windows it is necessary for an IO event to happen periodically. - # Fixed by Python 3.8 - while True: - await asyncio.sleep(1) diff --git a/src/quart/asgi.py b/src/quart/asgi.py index d711296..856b55a 100644 --- a/src/quart/asgi.py +++ b/src/quart/asgi.py @@ -3,7 +3,7 @@ import asyncio import warnings from functools import partial -from typing import AnyStr, cast, List, Optional, Set, TYPE_CHECKING, Union +from typing import AnyStr, cast, Optional, TYPE_CHECKING from urllib.parse import urlparse from hypercorn.typing import ( @@ -28,7 +28,8 @@ from .debug import traceback_response from .signals import websocket_received, websocket_sent -from .utils import encode_headers +from .typing import ResponseTypes +from .utils import cancel_tasks, encode_headers, raise_task_exceptions from .wrappers import Request, Response, Websocket # noqa: F401 if TYPE_CHECKING: @@ -36,7 +37,7 @@ class ASGIHTTPConnection: - def __init__(self, app: "Quart", scope: HTTPScope) -> None: + def __init__(self, app: Quart, scope: HTTPScope) -> None: self.app = app self.scope = scope @@ -47,8 +48,8 @@ async def __call__(self, receive: ASGIReceiveCallable, send: ASGISendCallable) - done, pending = await asyncio.wait( [handler_task, receiver_task], return_when=asyncio.FIRST_COMPLETED ) - await _cancel_tasks(pending) - _raise_exceptions(done) + await cancel_tasks(pending) + raise_task_exceptions(done) async def handle_messages(self, request: Request, receive: ASGIReceiveCallable) -> None: while True: @@ -70,6 +71,13 @@ def _create_request_from_scope(self, send: ASGISendCallable) -> Request: path = self.scope["path"] path = path if path[0] == "/" else urlparse(path).path + root_path = self.scope.get("root_path", "") + if root_path != "": + try: + path = path.split(root_path, 1)[1] + path = " " if path == "" else path + except IndexError: + path = " " # Invalid in paths, hence will result in 404 return self.app.request_class( self.scope["method"], @@ -88,11 +96,8 @@ def _create_request_from_scope(self, send: ASGISendCallable) -> Request: async def handle_request(self, request: Request, send: ASGISendCallable) -> None: try: response = await self.app.handle_request(request) - except Exception: - if self.app.propagate_exceptions: - response = await traceback_response() - else: - raise + except Exception as error: + response = await _handle_exception(self.app, error) if isinstance(response, Response) and response.timeout != Ellipsis: timeout = cast(Optional[float], response.timeout) @@ -103,9 +108,7 @@ async def handle_request(self, request: Request, send: ASGISendCallable) -> None except asyncio.TimeoutError: pass - async def _send_response( - self, send: ASGISendCallable, response: Union[Response, WerkzeugResponse] - ) -> None: + async def _send_response(self, send: ASGISendCallable, response: ResponseTypes) -> None: await send( cast( HTTPResponseStartEvent, @@ -119,7 +122,7 @@ async def _send_response( if isinstance(response, WerkzeugResponse): for data in response.response: - body = data.encode(response.charset) if isinstance(data, str) else data + body = data.encode() if isinstance(data, str) else data await send( cast( HTTPResponseBodyEvent, @@ -129,7 +132,7 @@ async def _send_response( else: async with response.response as response_body: async for data in response_body: - body = data.encode(response.charset) if isinstance(data, str) else data + body = data.encode() if isinstance(data, str) else data await send( cast( HTTPResponseBodyEvent, @@ -144,14 +147,15 @@ async def _send_response( ) async def _send_push_promise(self, send: ASGISendCallable, path: str, headers: Headers) -> None: - if "http.response.push" in self.scope.get("extensions", {}): + extensions = self.scope.get("extensions", {}) or {} + if "http.response.push" in extensions: await send( {"type": "http.response.push", "path": path, "headers": encode_headers(headers)} ) class ASGIWebsocketConnection: - def __init__(self, app: "Quart", scope: WebsocketScope) -> None: + def __init__(self, app: Quart, scope: WebsocketScope) -> None: self.app = app self.scope = scope self.queue: asyncio.Queue = asyncio.Queue() @@ -165,15 +169,15 @@ async def __call__(self, receive: ASGIReceiveCallable, send: ASGISendCallable) - done, pending = await asyncio.wait( [handler_task, receiver_task], return_when=asyncio.FIRST_COMPLETED ) - await _cancel_tasks(pending) - _raise_exceptions(done) + await cancel_tasks(pending) + raise_task_exceptions(done) async def handle_messages(self, receive: ASGIReceiveCallable) -> None: while True: event = await receive() if event["type"] == "websocket.receive": message = event.get("bytes") or event["text"] - await websocket_received.send(message) + await websocket_received.send_async(message) await self.queue.put(message) elif event["type"] == "websocket.disconnect": return @@ -186,6 +190,13 @@ def _create_websocket_from_scope(self, send: ASGISendCallable) -> Websocket: path = self.scope["path"] path = path if path[0] == "/" else urlparse(path).path + root_path = self.scope.get("root_path", "") + if root_path != "": + try: + path = path.split(root_path, 1)[1] + path = " " if path == "" else path + except IndexError: + path = " " # Invalid in paths, hence will result in 404 return self.app.websocket_class( path, @@ -205,14 +216,12 @@ def _create_websocket_from_scope(self, send: ASGISendCallable) -> Websocket: async def handle_websocket(self, websocket: Websocket, send: ASGISendCallable) -> None: try: response = await self.app.handle_websocket(websocket) - except Exception: - if self.app.propagate_exceptions: - raise - else: - response = await traceback_response() + except Exception as error: + response = await _handle_exception(self.app, error) if response is not None and not self._accepted: - if "websocket.http.response" in self.scope.get("extensions", {}): + extensions = self.scope.get("extensions", {}) or {} + if "websocket.http.response" in extensions: headers = [ (key.lower().encode(), value.encode()) for key, value in response.headers.items() @@ -268,10 +277,10 @@ async def send_data(self, send: ASGISendCallable, data: AnyStr) -> None: await send({"type": "websocket.send", "bytes": None, "text": data}) else: await send({"type": "websocket.send", "bytes": data, "text": None}) - await websocket_sent.send(data) + await websocket_sent.send_async(data) async def accept_connection( - self, send: ASGISendCallable, headers: Headers, subprotocol: Optional[str] + self, send: ASGISendCallable, headers: Headers, subprotocol: str | None ) -> None: if not self._accepted: message: WebsocketAcceptEvent = { @@ -300,7 +309,7 @@ async def close_connection(self, send: ASGISendCallable, code: int, reason: str) class ASGILifespan: - def __init__(self, app: "Quart", scope: LifespanScope) -> None: + def __init__(self, app: Quart, scope: LifespanScope) -> None: self.app = app async def __call__(self, receive: ASGIReceiveCallable, send: ASGISendCallable) -> None: @@ -337,21 +346,12 @@ async def __call__(self, receive: ASGIReceiveCallable, send: ASGISendCallable) - break -async def _cancel_tasks(tasks: Set[asyncio.Task]) -> None: - # Cancel any pending, and wait for the cancellation to - # complete i.e. finish any remaining work. - for task in tasks: - task.cancel() - await asyncio.gather(*tasks, return_exceptions=True) - _raise_exceptions(tasks) - - -def _raise_exceptions(tasks: Set[asyncio.Task]) -> None: - # Raise any unexpected exceptions - for task in tasks: - if not task.cancelled() and task.exception() is not None: - raise task.exception() +def _convert_version(raw: str) -> list[int]: + return list(map(int, raw.split("."))) -def _convert_version(raw: str) -> List[int]: - return list(map(int, raw.split("."))) +async def _handle_exception(app: Quart, error: Exception) -> Response: + if not app.testing and app.config["PROPAGATE_EXCEPTIONS"]: + return await traceback_response(error) + else: + raise error diff --git a/src/quart/blueprints.py b/src/quart/blueprints.py index a30f37e..acaa047 100644 --- a/src/quart/blueprints.py +++ b/src/quart/blueprints.py @@ -1,67 +1,48 @@ from __future__ import annotations +import os +import typing as t from collections import defaultdict -from functools import update_wrapper -from typing import ( - Any, - Callable, - Iterable, - List, - Optional, - Tuple, - Type, - TYPE_CHECKING, - TypeVar, - Union, +from datetime import timedelta + +from aiofiles import open as async_open +from aiofiles.base import AiofilesContextManager +from aiofiles.threadpool.binary import AsyncBufferedReader +from flask.sansio.app import App +from flask.sansio.blueprints import ( # noqa + Blueprint as SansioBlueprint, + BlueprintSetupState as BlueprintSetupState, ) +from flask.sansio.scaffold import setupmethod -from .scaffold import _endpoint_from_view_func, Scaffold, setupmethod +from .cli import AppGroup +from .globals import current_app +from .helpers import send_from_directory from .typing import ( - AfterRequestCallable, AfterServingCallable, AfterWebsocketCallable, - BeforeFirstRequestCallable, - BeforeRequestCallable, + AppOrBlueprintKey, BeforeServingCallable, BeforeWebsocketCallable, - ErrorHandlerCallable, - RouteCallable, + FilePath, TeardownCallable, - TemplateContextProcessorCallable, - TemplateFilterCallable, - TemplateGlobalCallable, - TemplateTestCallable, - URLDefaultCallable, - URLValuePreprocessorCallable, WebsocketCallable, WhileServingCallable, ) -if TYPE_CHECKING: - from .app import Quart # noqa - -DeferredSetupFunction = Callable[["BlueprintSetupState"], Callable] -T_after_request = TypeVar("T_after_request", bound=AfterRequestCallable) -T_after_websocket = TypeVar("T_after_websocket", bound=AfterWebsocketCallable) -T_after_serving = TypeVar("T_after_serving", bound=AfterServingCallable) -T_before_first_request = TypeVar("T_before_first_request", bound=BeforeFirstRequestCallable) -T_before_request = TypeVar("T_before_request", bound=BeforeRequestCallable) -T_before_websocket = TypeVar("T_before_websocket", bound=BeforeWebsocketCallable) -T_before_serving = TypeVar("T_before_serving", bound=BeforeServingCallable) -T_error_handler = TypeVar("T_error_handler", bound=ErrorHandlerCallable) -T_teardown = TypeVar("T_teardown", bound=TeardownCallable) -T_template_context_processor = TypeVar( - "T_template_context_processor", bound=TemplateContextProcessorCallable -) -T_template_filter = TypeVar("T_template_filter", bound=TemplateFilterCallable) -T_template_global = TypeVar("T_template_global", bound=TemplateGlobalCallable) -T_template_test = TypeVar("T_template_test", bound=TemplateTestCallable) -T_url_defaults = TypeVar("T_url_defaults", bound=URLDefaultCallable) -T_url_value_preprocessor = TypeVar("T_url_value_preprocessor", bound=URLValuePreprocessorCallable) -T_while_serving = TypeVar("T_while_serving", bound=WhileServingCallable) +if t.TYPE_CHECKING: + from .wrappers import Response + +T_after_serving = t.TypeVar("T_after_serving", bound=AfterServingCallable) +T_after_websocket = t.TypeVar("T_after_websocket", bound=AfterWebsocketCallable) +T_before_serving = t.TypeVar("T_before_serving", bound=BeforeServingCallable) +T_before_websocket = t.TypeVar("T_before_websocket", bound=BeforeWebsocketCallable) +T_teardown = t.TypeVar("T_teardown", bound=TeardownCallable) +T_websocket = t.TypeVar("T_websocket", bound=WebsocketCallable) +T_while_serving = t.TypeVar("T_while_serving", bound=WhileServingCallable) -class Blueprint(Scaffold): +class Blueprint(SansioBlueprint): """A blueprint is a collection of application properties. The application properties include routes, error handlers, and @@ -69,247 +50,241 @@ class Blueprint(Scaffold): modular code as it allows the properties to be defined in a blueprint thereby deferring the addition of these properties to the app. - - Attributes: - url_prefix: An additional prefix to every route rule in the - blueprint. """ - warn_on_modifications = False - _got_registered_once = False + def __init__(self, *args: t.Any, **kwargs: t.Any) -> None: + super().__init__(*args, **kwargs) - def __init__( - self, - name: str, - import_name: str, - static_folder: Optional[str] = None, - static_url_path: Optional[str] = None, - template_folder: Optional[str] = None, - url_prefix: Optional[str] = None, - subdomain: Optional[str] = None, - url_defaults: Optional[dict] = None, - root_path: Optional[str] = None, - cli_group: Optional[str] = Ellipsis, # type: ignore - ) -> None: - super().__init__(import_name, static_folder, static_url_path, template_folder, root_path) - - if "." in name: - raise ValueError("Blueprint names may not contain dot '.' characters.") - - self.name = name - self.url_prefix = url_prefix - self.deferred_functions: List[DeferredSetupFunction] = [] - self.subdomain = subdomain - if url_defaults is None: - url_defaults = {} - self.url_values_defaults = url_defaults - self.cli_group = cli_group - self._blueprints: List[Tuple["Blueprint", dict]] = [] - - def _check_setup_finished(self, f_name: str) -> None: - if self._got_registered_once: - raise AssertionError( - f"The setup method '{f_name}' can no longer be called on" - f" the blueprint '{self.name}'. It has already been" - " registered at least once, any changes will not be" - " applied consistently.\n" - "Make sure all imports, decorators, functions, etc." - " needed to set up the blueprint are done before" - " registering it.\n" - ) + self.cli = AppGroup() + self.cli.name = self.name - @setupmethod - def add_url_rule( - self, - rule: str, - endpoint: Optional[str] = None, - view_func: Optional[Union[RouteCallable, WebsocketCallable]] = None, - provide_automatic_options: Optional[bool] = None, - *, - methods: Optional[Iterable[str]] = None, - defaults: Optional[dict] = None, - host: Optional[str] = None, - subdomain: Optional[str] = None, - is_websocket: bool = False, - strict_slashes: Optional[bool] = None, - merge_slashes: Optional[bool] = None, - ) -> None: - """Add a route/url rule to the blueprint. + self.after_websocket_funcs: t.Dict[AppOrBlueprintKey, t.List[AfterWebsocketCallable]] = ( + defaultdict(list) + ) + self.before_websocket_funcs: t.Dict[AppOrBlueprintKey, t.List[BeforeWebsocketCallable]] = ( + defaultdict(list) + ) + self.teardown_websocket_funcs: dict[AppOrBlueprintKey, list[TeardownCallable]] = ( + defaultdict(list) + ) - This is designed to be used on the blueprint directly, and - has the same arguments as - :meth:`~quart.Quart.add_url_rule`. An example usage, + def get_send_file_max_age(self, filename: str | None) -> int | None: + """Used by :func:`send_file` to determine the ``max_age`` cache + value for a given file path if it wasn't passed. - .. code-block:: python + By default, this returns :data:`SEND_FILE_MAX_AGE_DEFAULT` from + the configuration of :data:`~flask.current_app`. This defaults + to ``None``, which tells the browser to use conditional requests + instead of a timed cache, which is usually preferable. - def route(): - ... + Note this is a duplicate of the same method in the Quart + class. - blueprint = Blueprint(__name__) - blueprint.add_url_rule('/', route) """ - endpoint = endpoint or _endpoint_from_view_func(view_func) - if "." in endpoint: - raise ValueError("Blueprint endpoints should not contain periods") - self.record( - lambda state: state.add_url_rule( - rule, - endpoint, - view_func, - provide_automatic_options=provide_automatic_options, - methods=methods, - defaults=defaults, - host=host, - subdomain=subdomain or self.subdomain, - is_websocket=is_websocket, - strict_slashes=strict_slashes, - merge_slashes=merge_slashes, - ) - ) - - @setupmethod - def app_template_filter( - self, name: Optional[str] = None - ) -> Callable[[T_template_filter], T_template_filter]: - """Add an application wide template filter. + value = current_app.config["SEND_FILE_MAX_AGE_DEFAULT"] - This is designed to be used as a decorator, and has the same arguments - as :meth:`~quart.Quart.template_filter`. An example usage, + if value is None: + return None - .. code-block:: python + if isinstance(value, timedelta): + return int(value.total_seconds()) - blueprint = Blueprint(__name__) - @blueprint.app_template_filter() - def filter(value): - ... - """ + return value + return None - def decorator(func: T_template_filter) -> T_template_filter: - self.add_app_template_filter(func, name=name) - return func + async def send_static_file(self, filename: str) -> Response: + if not self.has_static_folder: + raise RuntimeError("No static folder for this object") + return await send_from_directory(self.static_folder, filename) - return decorator - - @setupmethod - def add_app_template_filter( - self, func: TemplateFilterCallable, name: Optional[str] = None - ) -> None: - """Add an application wide template filter. + async def open_resource( + self, + path: FilePath, + mode: str = "rb", + ) -> AiofilesContextManager[None, None, AsyncBufferedReader]: + """Open a file for reading. - This is designed to be used on the blueprint directly, and - has the same arguments as - :meth:`~quart.Quart.add_template_filter`. An example usage, + Use as .. code-block:: python - def filter(): - ... - - blueprint = Blueprint(__name__) - blueprint.add_app_template_filter(filter) + async with await app.open_resource(path) as file_: + await file_.read() """ - self.record_once(lambda state: state.register_template_filter(func, name)) + if mode not in {"r", "rb", "rt"}: + raise ValueError("Files can only be opened for reading") - @setupmethod - def app_template_test( - self, name: Optional[str] = None - ) -> Callable[[T_template_test], T_template_test]: - """Add an application wide template test. + return async_open(os.path.join(self.root_path, path), mode) # type: ignore - This is designed to be used as a decorator, and has the same arguments - as :meth:`~quart.Quart.template_test`. An example usage, + def websocket( + self, + rule: str, + **options: t.Any, + ) -> t.Callable[[T_websocket], T_websocket]: + """Add a websocket to the application. + + This is designed to be used as a decorator, if used to + decorate a synchronous function, the function will be wrapped + in :func:`~quart.utils.run_sync` and run in a thread executor + (with the wrapped function returned). An example usage, .. code-block:: python - blueprint = Blueprint(__name__) - @blueprint.app_template_test() - def test(value): + @app.websocket('/') + async def websocket_route(): ... + + Arguments: + rule: The path to route on, should start with a ``/``. + endpoint: Optional endpoint name, if not present the + function name is used. + defaults: A dictionary of variables to provide automatically, use + to provide a simpler default path for a route, e.g. to allow + for ``/book`` rather than ``/book/0``, + + .. code-block:: python + + @app.websocket('/book', defaults={'page': 0}) + @app.websocket('/book/') + def book(page): + ... + + host: The full host name for this route (should include subdomain + if needed) - cannot be used with subdomain. + subdomain: A subdomain for this specific route. + strict_slashes: Strictly match the trailing slash present in the + path. Will redirect a leaf (no slash) to a branch (with slash). """ - def decorator(func: T_template_test) -> T_template_test: - self.add_app_template_test(func, name=name) + def decorator(func: T_websocket) -> T_websocket: + endpoint = options.pop("endpoint", None) + self.add_websocket( + rule, + endpoint, + func, + **options, + ) return func return decorator - @setupmethod - def add_app_template_test(self, func: TemplateTestCallable, name: Optional[str] = None) -> None: - """Add an application wide template test. + def add_websocket( + self, + rule: str, + endpoint: str | None = None, + view_func: WebsocketCallable | None = None, + **options: t.Any, + ) -> None: + """Add a websocket url rule to the application. - This is designed to be used on the blueprint directly, and - has the same arguments as - :meth:`~quart.Quart.add_template_test`. An example usage, + This is designed to be used on the application directly. An + example usage, .. code-block:: python - def test(): + def websocket_route(): ... - blueprint = Blueprint(__name__) - blueprint.add_app_template_test(test) + app.add_websocket('/', websocket_route) + + Arguments: + rule: The path to route on, should start with a ``/``. + endpoint: Optional endpoint name, if not present the + function name is used. + view_func: Callable that returns a response. + defaults: A dictionary of variables to provide automatically, use + to provide a simpler default path for a route, e.g. to allow + for ``/book`` rather than ``/book/0``, + + .. code-block:: python + + @app.websocket('/book', defaults={'page': 0}) + @app.websocket('/book/') + def book(page): + ... + + host: The full host name for this route (should include subdomain + if needed) - cannot be used with subdomain. + subdomain: A subdomain for this specific route. + strict_slashes: Strictly match the trailing slash present in the + path. Will redirect a leaf (no slash) to a branch (with slash). """ - self.record_once(lambda state: state.register_template_test(func, name)) + return self.add_url_rule( + rule, + endpoint, + view_func, + methods={"GET"}, + websocket=True, + **options, + ) @setupmethod - def app_template_global( - self, name: Optional[str] = None - ) -> Callable[[T_template_global], T_template_global]: - """Add an application wide template global. + def before_websocket( + self, + func: T_before_websocket, + ) -> T_before_websocket: + """Add a before websocket function. - This is designed to be used as a decorator, and has the same arguments - as :meth:`~quart.Quart.template_global`. An example usage, + This is designed to be used as a decorator, if used to + decorate a synchronous function, the function will be wrapped + in :func:`~quart.utils.run_sync` and run in a thread executor + (with the wrapped function returned). An example usage, .. code-block:: python - blueprint = Blueprint(__name__) - @blueprint.app_template_global() - def global(value): + @app.before_websocket + async def func(): ... - """ - - def decorator(func: T_template_global) -> T_template_global: - self.add_app_template_global(func, name=name) - return func - return decorator + Arguments: + func: The before websocket function itself. + """ + self.before_websocket_funcs[None].append(func) + return func @setupmethod - def add_app_template_global( - self, func: TemplateGlobalCallable, name: Optional[str] = None - ) -> None: - """Add an application wide template global. + def after_websocket( + self, + func: T_after_websocket, + ) -> T_after_websocket: + """Add an after websocket function. - This is designed to be used on the blueprint directly, and - has the same arguments as - :meth:`~quart.Quart.add_template_global`. An example usage, + This is designed to be used as a decorator, if used to + decorate a synchronous function, the function will be wrapped + in :func:`~quart.utils.run_sync` and run in a thread executor + (with the wrapped function returned). An example usage, .. code-block:: python - def global(): - ... + @app.after_websocket + async def func(response): + return response - blueprint = Blueprint(__name__) - blueprint.add_app_template_global(global) + Arguments: + func: The after websocket function itself. """ - self.record_once(lambda state: state.register_template_global(func, name)) + self.after_websocket_funcs[None].append(func) + return func @setupmethod - def before_app_request(self, func: T_before_request) -> T_before_request: - """Add a before request function to the app. - - This is designed to be used as a decorator, and has the same arguments - as :meth:`~quart.Quart.before_request`. It applies to all requests to the - app this blueprint is registered on. An example usage, - + def teardown_websocket( + self, + func: T_teardown, + ) -> T_teardown: + """Add a teardown websocket function. + This is designed to be used as a decorator, if used to + decorate a synchronous function, the function will be wrapped + in :func:`~quart.utils.run_sync` and run in a thread executor + (with the wrapped function returned). An example usage, .. code-block:: python - - blueprint = Blueprint(__name__) - @blueprint.before_app_request - def before(): + @app.teardown_websocket + async def func(): ... + Arguments: + func: The teardown websocket function itself. + name: Optional blueprint key name. """ - self.record_once(lambda state: state.app.before_request(func)) + self.teardown_websocket_funcs[None].append(func) return func @setupmethod @@ -328,7 +303,7 @@ def before(): ... """ - self.record_once(lambda state: state.app.before_websocket(func)) + self.record_once(lambda state: state.app.before_websocket(func)) # type: ignore return func @setupmethod @@ -346,45 +321,7 @@ def before(): ... """ - self.record_once(lambda state: state.app.before_serving(func)) - return func - - @setupmethod - def before_app_first_request(self, func: T_before_first_request) -> T_before_first_request: - """Add a before request first function to the app. - - This is designed to be used as a decorator, and has the same - arguments as :meth:`~quart.Quart.before_first_request`. It is - triggered before the first request to the app this blueprint - is registered on. An example usage, - - .. code-block:: python - - blueprint = Blueprint(__name__) - @blueprint.before_app_first_request - def before_first(): - ... - - """ - self.record_once(lambda state: state.app.before_first_request(func)) - return func - - @setupmethod - def after_app_request(self, func: T_after_request) -> T_after_request: - """Add a after request function to the app. - - This is designed to be used as a decorator, and has the same arguments - as :meth:`~quart.Quart.after_request`. It applies to all requests to the - app this blueprint is registered on. An example usage, - - .. code-block:: python - - blueprint = Blueprint(__name__) - @blueprint.after_app_request - def after(): - ... - """ - self.record_once(lambda state: state.app.after_request(func)) + self.record_once(lambda state: state.app.before_serving(func)) # type: ignore return func @setupmethod @@ -402,7 +339,7 @@ def after_app_websocket(self, func: T_after_websocket) -> T_after_websocket: def after(): ... """ - self.record_once(lambda state: state.app.after_websocket(func)) + self.record_once(lambda state: state.app.after_websocket(func)) # type: ignore return func @setupmethod @@ -419,7 +356,7 @@ def after_app_serving(self, func: T_after_serving) -> T_after_serving: def after(): ... """ - self.record_once(lambda state: state.app.after_serving(func)) + self.record_once(lambda state: state.app.after_serving(func)) # type: ignore[attr-defined] return func @setupmethod @@ -438,26 +375,7 @@ async def func(): ... # Shutdown """ - self.record_once(lambda state: state.app.while_serving(func)) - return func - - @setupmethod - def teardown_app_request(self, func: T_teardown) -> T_teardown: - """Add a teardown request function to the app. - - This is designed to be used as a decorator, and has the same - arguments as :meth:`~quart.Quart.teardown_request`. It applies - to all requests to the app this blueprint is registered on. An - example usage, - - .. code-block:: python - - blueprint = Blueprint(__name__) - @blueprint.teardown_app_request - def teardown(): - ... - """ - self.record_once(lambda state: state.app.teardown_request(func)) + self.record_once(lambda state: state.app.while_serving(func)) # type: ignore[attr-defined] return func @setupmethod @@ -476,323 +394,27 @@ def teardown_app_websocket(self, func: T_teardown) -> T_teardown: def teardown(): ... """ - self.record_once(lambda state: state.app.teardown_websocket(func)) + self.record_once(lambda state: state.app.teardown_websocket(func)) # type: ignore return func - @setupmethod - def app_errorhandler( - self, error: Union[Type[Exception], int] - ) -> Callable[[T_error_handler], T_error_handler]: - """Add an error handler function to the App. - - This is designed to be used as a decorator, and has the same - arguments as :meth:`~quart.Quart.errorhandler`. It applies - only to all errors. An example usage, - - .. code-block:: python - - blueprint = Blueprint(__name__) - @blueprint.app_errorhandler(404) - def not_found(): - ... - """ - - def decorator(func: T_error_handler) -> T_error_handler: - self.record_once(lambda state: state.app.register_error_handler(error, func)) - return func - - return decorator - - @setupmethod - def app_context_processor( - self, - func: T_template_context_processor, - ) -> T_template_context_processor: - """Add a context processor function to the app. - - This is designed to be used as a decorator, and has the same - arguments as :meth:`~quart.Quart.context_processor`. This will - add context to all templates rendered. An example usage, - - .. code-block:: python - - blueprint = Blueprint(__name__) - @blueprint.app_context_processor - def processor(): - ... - """ - self.record_once(lambda state: state.app.context_processor(func)) - return func - - @setupmethod - def app_url_value_preprocessor( - self, func: T_url_value_preprocessor - ) -> T_url_value_preprocessor: - """Add a url value preprocessor. - - This is designed to be used as a decorator, and has the same - arguments as - :meth:`~quart.Quart.app_url_value_preprocessor`. This will - apply to all URLs. An example usage, - - .. code-block:: python - - blueprint = Blueprint(__name__) - @blueprint.app_url_value_preprocessor - def processor(endpoint, view_args): - ... - - """ - self.record_once(lambda state: state.app.url_value_preprocessor(func)) - return func - - @setupmethod - def app_url_defaults(self, func: T_url_defaults) -> T_url_defaults: - """Add a url default preprocessor. - - This is designed to be used as a decorator, and has the same - arguments as :meth:`~quart.Quart.url_defaults`. This will - apply to all urls. An example usage, + def _merge_blueprint_funcs(self, app: App, name: str) -> None: + super()._merge_blueprint_funcs(app, name) - .. code-block:: python - - blueprint = Blueprint(__name__) - @blueprint.app_url_defaults - def default(endpoint, values): - ... - - """ - self.record_once(lambda state: state.app.url_defaults(func)) - return func - - @setupmethod - def record(self, func: DeferredSetupFunction) -> None: - """Used to register a deferred action.""" - self.deferred_functions.append(func) - - @setupmethod - def record_once(self, func: DeferredSetupFunction) -> None: - """Used to register a deferred action that happens only once.""" - - def wrapper(state: "BlueprintSetupState") -> None: - if state.first_registration: - func(state) - - self.record(update_wrapper(wrapper, func)) - - @setupmethod - def register_blueprint(self, blueprint: "Blueprint", **options: Any) -> None: - """Register a :class:`~quart.Blueprint` on this blueprint. - - Keyword arguments passed to this method will override the - defaults set on the blueprint. - """ - if blueprint is self: - raise ValueError("Cannot register a blueprint on itself") - self._blueprints.append((blueprint, options)) - - def register(self, app: "Quart", options: dict) -> None: - """Register this blueprint on the app given. - - Arguments: - app: The application this blueprint is being registered with. - options: Keyword arguments forwarded from - :meth:`~quart.Quart.register_blueprint`. - first_registration: Whether this is the first time this - blueprint has been registered on the application. - """ - - name = f"{options.get('name_prefix', '')}.{options.get('name', self.name)}".lstrip(".") - if name in app.blueprints and app.blueprints[name] is not self: - raise ValueError( - f"Blueprint name '{self.name}' " - f"is already registered by {app.blueprints[self.name]}. " - "Blueprints must have unique names" - ) - - first_blueprint_registration = not any( - blueprint is self for blueprint in app.blueprints.values() - ) - first_name_registration = name not in app.blueprints - - app.blueprints[name] = self - self._got_registered_once = True - - state = self.make_setup_state(app, options, first_blueprint_registration) - - if self.has_static_folder: - state.add_url_rule( - self.static_url_path + "/", - view_func=self.send_static_file, - endpoint="static", - ) - - if first_blueprint_registration or first_name_registration: - - def extend(bp_dict: dict, parent_dict: dict) -> None: - for key, values in bp_dict.items(): - key = name if key is None else f"{name}.{key}" - parent_dict[key].extend(values) - - for key, value in self.error_handler_spec.items(): + def extend(bp_dict: dict, parent_dict: dict) -> None: + for key, values in bp_dict.items(): key = name if key is None else f"{name}.{key}" - value = defaultdict( - dict, - { - code: {exc_class: func for exc_class, func in code_values.items()} - for code, code_values in value.items() - }, - ) - app.error_handler_spec[key] = value - - for endpoint, func in self.view_functions.items(): - app.view_functions[endpoint] = func - - extend(self.before_request_funcs, app.before_request_funcs) - extend(self.before_websocket_funcs, app.before_websocket_funcs) - extend(self.after_request_funcs, app.after_request_funcs) - extend(self.after_websocket_funcs, app.after_websocket_funcs) - extend( - self.teardown_request_funcs, - app.teardown_request_funcs, - ) - extend( - self.teardown_websocket_funcs, - app.teardown_websocket_funcs, + parent_dict[key].extend(values) + + for key, value in self.error_handler_spec.items(): + key = name if key is None else f"{name}.{key}" + value = defaultdict( + dict, + { + code: {exc_class: func for exc_class, func in code_values.items()} + for code, code_values in value.items() + }, ) - extend(self.url_default_functions, app.url_default_functions) - extend(self.url_value_preprocessors, app.url_value_preprocessors) - extend(self.template_context_processors, app.template_context_processors) - - for func in self.deferred_functions: - func(state) - - cli_resolved_group = options.get("cli_group", self.cli_group) - - if self.cli.commands: - if cli_resolved_group is None: - app.cli.commands.update(self.cli.commands) - elif cli_resolved_group is Ellipsis: - self.cli.name = name - app.cli.add_command(self.cli) - else: - self.cli.name = cli_resolved_group - app.cli.add_command(self.cli) - - for blueprint, bp_options in self._blueprints: - bp_options = bp_options.copy() - bp_url_prefix = bp_options.get("url_prefix") - if bp_url_prefix is None: - bp_url_prefix = blueprint.url_prefix - - if state.url_prefix is not None and bp_url_prefix is not None: - bp_options["url_prefix"] = ( - state.url_prefix.rstrip("/") + "/" + bp_url_prefix.lstrip("/") - ) - elif bp_url_prefix is not None: - bp_options["url_prefix"] = bp_url_prefix - elif state.url_prefix is not None: - bp_options["url_prefix"] = state.url_prefix - - bp_options["name_prefix"] = name - blueprint.register(app, bp_options) - - def make_setup_state( - self, app: "Quart", options: dict, first_registration: bool = False - ) -> "BlueprintSetupState": - """Return a blueprint setup state instance. - - Arguments: - first_registration: True if this is the first registration - of this blueprint on the app. - options: Keyword arguments forwarded from - :meth:`~quart.Quart.register_blueprint`. - first_registration: Whether this is the first time this - blueprint has been registered on the application. - """ - return BlueprintSetupState(self, app, options, first_registration) - - -class BlueprintSetupState: - """This setups the blueprint on the app. - - When used it can apply the deferred functions on the Blueprint to - the app. Override if you wish for blueprints to have be registered - in different ways. - - Attributes: - first_registration: True if this is the first registration - of this blueprint on the app. - """ - - def __init__( - self, blueprint: Blueprint, app: "Quart", options: dict, first_registration: bool - ) -> None: - self.blueprint = blueprint - self.app = app - self.options = options - self.url_prefix = options.get("url_prefix") or blueprint.url_prefix - self.first_registration = first_registration - self.subdomain = options.get("subdomain") or blueprint.subdomain - self.url_defaults = dict(self.blueprint.url_values_defaults) - self.url_defaults.update(options.get("url_defaults", {}) or {}) - self.name = self.options.get("name", blueprint.name) - self.name_prefix = self.options.get("name_prefix", "") - - def add_url_rule( - self, - path: str, - endpoint: Optional[str] = None, - view_func: Optional[Callable] = None, - *, - methods: Optional[Iterable[str]] = None, - defaults: Optional[dict] = None, - host: Optional[str] = None, - subdomain: Optional[str] = None, - provide_automatic_options: Optional[bool] = None, - is_websocket: bool = False, - strict_slashes: Optional[bool] = None, - merge_slashes: Optional[bool] = None, - ) -> None: - if self.url_prefix is not None: - path = f"{self.url_prefix.rstrip('/')}/{path.lstrip('/')}" - if subdomain is None: - subdomain = self.subdomain - endpoint = f"{self.name_prefix}.{self.name}.{endpoint}".lstrip(".") - url_defaults = self.url_defaults - if defaults is not None: - url_defaults = {**url_defaults, **defaults} - self.app.add_url_rule( - path, - endpoint, - view_func, - provide_automatic_options=provide_automatic_options, - methods=methods, - defaults=url_defaults, - host=host, - subdomain=subdomain, - is_websocket=is_websocket, - strict_slashes=strict_slashes, - merge_slashes=merge_slashes, - ) - - def register_template_filter(self, func: TemplateFilterCallable, name: Optional[str]) -> None: - self.app.add_template_filter(func, name) - - def register_template_test(self, func: Callable, name: Optional[str]) -> None: - self.app.add_template_test(func, name) - - def register_template_global(self, func: Callable, name: Optional[str]) -> None: - self.app.add_template_global(func, name) - - -def _merge_dict_of_lists(name: str, self_dict: dict, app_dict: dict) -> None: - for key, values in self_dict.items(): - key = name if key is None else f"{name}.{key}" - app_dict[key].extend(values) - + app.error_handler_spec[key] = value -def _merge_dict_of_dicts(name: str, self_dict: dict, app_dict: dict) -> None: - for key, value in self_dict.items(): - key = name if key is None else f"{name}.{key}" - app_dict[key] = value + extend(self.before_websocket_funcs, app.before_websocket_funcs) # type: ignore + extend(self.after_websocket_funcs, app.after_websocket_funcs) # type: ignore diff --git a/src/quart/cli.py b/src/quart/cli.py index 7e8f290..f99e234 100644 --- a/src/quart/cli.py +++ b/src/quart/cli.py @@ -13,7 +13,7 @@ from importlib import import_module from operator import attrgetter from types import ModuleType -from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING +from typing import Any, Callable, TYPE_CHECKING import click from click.core import ParameterSource @@ -166,11 +166,11 @@ def find_app_by_string(module: ModuleType, app_name: str) -> Quart: return app raise NoAppException( - "A valid Quart application was not obtained from" f" '{module.__name__}:{app_name}'." + f"A valid Quart application was not obtained from '{module.__name__}:{app_name}'." ) -def locate_app(module_name: str, app_name: str) -> Optional[Quart]: +def locate_app(module_name: str, app_name: str) -> Quart | None: try: module = import_module(module_name) except ImportError: @@ -222,15 +222,15 @@ def prepare_import(path: str) -> str: class ScriptInfo: def __init__( self, - app_import_path: Optional[str] = None, - create_app: Optional[Callable[..., Quart]] = None, + app_import_path: str | None = None, + create_app: Callable[..., Quart] | None = None, set_debug_flag: bool = True, ) -> None: self.app_import_path = app_import_path self.create_app = create_app - self.data: Dict[Any, Any] = {} + self.data: dict[Any, Any] = {} self.set_debug_flag = set_debug_flag - self._loaded_app: Optional[Quart] = None + self._loaded_app: Quart | None = None def load_app(self) -> Quart: if self._loaded_app is not None: @@ -267,7 +267,7 @@ def load_app(self) -> Quart: pass_script_info = click.make_pass_decorator(ScriptInfo, ensure=True) -def with_appcontext(fn: Optional[Callable] = None) -> Callable: +def with_appcontext(fn: Callable | None = None) -> Callable: # decorator was used with parenthesis if fn is None: return with_appcontext @@ -372,29 +372,6 @@ def _set_app(ctx: click.Context, param: click.Option, value: str | None) -> str ) -def _set_env(ctx: click.Context, param: click.Option, value: str | None) -> str | None: - if value is None: - return None - - # Set with env var instead of ScriptInfo.load so that it can be - # accessed early during a factory function. - os.environ["QUART_ENV"] = value - return value - - -_env_option = click.Option( - ["-E", "--env"], - metavar="NAME", - help=( - "The execution environment name to set in 'app.env'. Defaults to" - " 'production'. 'development' will enable 'app.debug' and start the" - " debugger and reloader when running the server." - ), - expose_value=False, - callback=_set_env, -) - - def _set_debug(ctx: click.Context, param: click.Option, value: bool) -> bool | None: # If the flag isn't provided, it will default to False. Don't use # that, let debug be set by env in that case. @@ -468,7 +445,7 @@ def __init__( # callback. This allows users to make a custom group callback # without losing the behavior. --env-file must come first so # that it is eagerly evaluated before --app. - params.extend((_env_file_option, _app_option, _env_option, _debug_option)) + params.extend((_env_file_option, _app_option, _debug_option)) if add_version_option: params.append(version_option) @@ -526,9 +503,9 @@ def get_command(self, ctx: click.Context, name: str) -> click.Command: click.secho(f"Error: {e.format_message()}\n", err=True, fg="red") return None - return app.cli.get_command(ctx, name) # type: ignore + return app.cli.get_command(ctx, name) - def list_commands(self, ctx: click.Context) -> List[str]: + def list_commands(self, ctx: click.Context) -> list[str]: self._load_plugin_commands() rv = set(super().list_commands(ctx)) @@ -537,7 +514,7 @@ def list_commands(self, ctx: click.Context) -> List[str]: # Add commands provided by the app, showing an error and # continuing if the app couldn't be loaded. try: - rv.update(info.load_app().cli.list_commands(ctx)) # type: ignore + rv.update(info.load_app().cli.list_commands(ctx)) except NoAppException as e: # When an app couldn't be loaded, show the error message # without the traceback. @@ -675,7 +652,7 @@ def shell_command() -> None: """ banner = ( f"Python {sys.version} on {sys.platform}\n" - f"App: {current_app.import_name} [{current_app.env}]\n" + f"App: {current_app.import_name}\n" f"Instance: {current_app.instance_path}" ) ctx: dict = {} @@ -714,7 +691,7 @@ def shell_command() -> None: @click.option( "--sort", "-s", - type=click.Choice(("endpoint", "methods", "rule", "match")), + type=click.Choice(("endpoint", "methods", "domain", "rule", "match")), default="endpoint", help=( 'Method to sort routes by. "match" is the order that Quart will match ' @@ -731,29 +708,35 @@ def routes_command(sort: str, all_methods: bool) -> None: click.echo("No routes were registered.") return - ignored_methods = set(() if all_methods else ("HEAD", "OPTIONS")) + ignored_methods = set() if all_methods else {"HEAD", "OPTIONS"} + host_matching = current_app.url_map.host_matching + has_domain = any(rule.host if host_matching else rule.subdomain for rule in rules) - if sort in ("endpoint", "rule"): + if sort in ("endpoint", "rule", "domain"): rules = sorted(rules, key=attrgetter(sort)) elif sort == "methods": rules = sorted(rules, key=lambda rule: sorted(rule.methods)) - rule_methods = [", ".join(sorted(rule.methods - ignored_methods)) for rule in rules] - - headers = ("Endpoint", "Methods", "Rule") - widths = ( - max(len(rule.endpoint) for rule in rules), - max(len(methods) for methods in rule_methods), - max(len(rule.rule) for rule in rules), - ) - widths = [max(len(h), w) for h, w in zip(headers, widths)] - row = "{{0:<{0}}} {{1:<{1}}} {{2:<{2}}}".format(*widths) - - click.echo(row.format(*headers).strip()) - click.echo(row.format(*("-" * width for width in widths))) - - for rule, methods in zip(rules, rule_methods): - click.echo(row.format(rule.endpoint, methods, rule.rule).rstrip()) + headers = ["Endpoint", "Methods"] + if has_domain: + headers.append("Host" if host_matching else "Subdomain") + headers.append("Rule") + + rows = [] + for rule in rules: + row = [rule.endpoint, ", ".join(sorted(rule.methods - ignored_methods))] + if has_domain: + row.append((rule.host if host_matching else rule.subdomain) or "") + row.append(rule.rule) + rows.append(row) + + rows.insert(0, headers) + widths = [max(len(row[i]) for row in rows) for i in range(len(headers))] + rows.insert(1, ["-" * w for w in widths]) + template = " ".join(f"{{{i}:<{w}}}" for i, w in enumerate(widths)) + + for row in rows: + click.echo(template.format(*row)) cli = QuartGroup( diff --git a/src/quart/config.py b/src/quart/config.py index ecb4dbb..28190f3 100644 --- a/src/quart/config.py +++ b/src/quart/config.py @@ -1,115 +1,12 @@ from __future__ import annotations -import importlib -import importlib.util import json -import os -from configparser import ConfigParser -from datetime import timedelta -from typing import Any, Callable, Dict, IO, Mapping, Optional, Union +from typing import Any, Callable -from .typing import FilePath -from .utils import file_path_to_path +from flask.config import Config as FlaskConfig, ConfigAttribute as ConfigAttribute # noqa: F401 -DEFAULT_CONFIG = { - "APPLICATION_ROOT": None, - "BODY_TIMEOUT": 60, # Second - "DEBUG": None, - "ENV": None, - "MAX_CONTENT_LENGTH": 16 * 1024 * 1024, # 16 MB Limit - "MAX_COOKIE_SIZE": 4093, - "PERMANENT_SESSION_LIFETIME": timedelta(days=31), - "PREFER_SECURE_URLS": False, # Replaces PREFERRED_URL_SCHEME to allow for WebSocket scheme - "PRESERVE_CONTEXT_ON_EXCEPTION": None, - "PROPAGATE_EXCEPTIONS": None, - "RESPONSE_TIMEOUT": 60, # Second - "SECRET_KEY": None, - "SEND_FILE_MAX_AGE_DEFAULT": timedelta(hours=12), - "SERVER_NAME": None, - "SESSION_COOKIE_DOMAIN": None, - "SESSION_COOKIE_HTTPONLY": True, - "SESSION_COOKIE_NAME": "session", - "SESSION_COOKIE_PATH": None, - "SESSION_COOKIE_SAMESITE": None, - "SESSION_COOKIE_SECURE": False, - "SESSION_REFRESH_EACH_REQUEST": True, - "TEMPLATES_AUTO_RELOAD": None, - "TESTING": False, - "TRAP_HTTP_EXCEPTIONS": False, -} - - -class ConfigAttribute: - """Implements a property descriptor for objects with a config attribute. - - When used as a class instance it will look up the key on the class - config object, for example: - - .. code-block:: python - - class Object: - config = {} - foo = ConfigAttribute('foo') - - obj = Object() - obj.foo = 'bob' - assert obj.foo == obj.config['foo'] - """ - - def __init__(self, key: str, converter: Optional[Callable] = None) -> None: - self.key = key - self.converter = converter - - def __get__(self, instance: Any, owner: Any = None) -> Any: - if instance is None: - return self - result = instance.config[self.key] - if self.converter is not None: - return self.converter(result) - else: - return result - - def __set__(self, instance: Any, value: Any) -> None: - instance.config[self.key] = value - - -class Config(dict): - """Extends a standard Python dictionary with additional load (from) methods. - - Note that the convention (as enforced when loading) is that - configuration keys are upper case. Whilst you can set lower case - keys it is not recommended. - """ - - def __init__(self, root_path: FilePath, defaults: Optional[dict] = None) -> None: - super().__init__(defaults or {}) - self.root_path = file_path_to_path(root_path) - - def from_envvar(self, variable_name: str, silent: bool = False) -> bool: - """Load the configuration from a location specified in the environment. - - This will load a cfg file using :meth:`from_pyfile` from the - location specified in the environment, for example the two blocks - below are equivalent. - - .. code-block:: python - - app.config.from_envvar('CONFIG') - - .. code-block:: python - - filename = os.environ['CONFIG'] - app.config.from_pyfile(filename) - """ - value = os.environ.get(variable_name) - if value is None: - if silent: - return False - raise RuntimeError( - f"Environment variable {variable_name} is not present, cannot load config" - ) - return self.from_pyfile(value) +class Config(FlaskConfig): def from_prefixed_env( self, prefix: str = "QUART", *, loads: Callable[[str], Any] = json.loads ) -> bool: @@ -136,205 +33,4 @@ def from_prefixed_env( raised it is ignored and the value remains a string. The default is :func:`json.loads`. """ - prefix = f"{prefix}_" - len_prefix = len(prefix) - - for key in sorted(os.environ): - if not key.startswith(prefix): - continue - - value = os.environ[key] - - try: - value = loads(value) - except Exception: - # Keep the value as a string if loading failed. - pass - - # Change to key.removeprefix(prefix) on Python >= 3.9. - key = key[len_prefix:] - - if "__" not in key: - # A non-nested key, set directly. - self[key] = value - continue - - # Traverse nested dictionaries with keys separated by "__". - current = self - *parts, tail = key.split("__") - - for part in parts: - # If an intermediate dict does not exist, create it. - if part not in current: - current[part] = {} - - current = current[part] - - current[tail] = value - - return True - - def from_pyfile(self, filename: str, silent: bool = False) -> bool: - """Load the configuration from a Python cfg or py file. - - See Python's ConfigParser docs for details on the cfg format. - It is a common practice to load the defaults from the source - using the :meth:`from_object` and then override with a cfg or - py file, for example - - .. code-block:: python - - app.config.from_object('config_module') - app.config.from_pyfile('production.cfg') - - Arguments: - filename: The filename which when appended to - :attr:`root_path` gives the path to the file - - """ - file_path = self.root_path / filename - try: - spec = importlib.util.spec_from_file_location("module.name", file_path) - if spec is None: # Likely passed a cfg file - parser = ConfigParser() - parser.optionxform = str # type: ignore # Prevents lowercasing of keys - with open(file_path) as file_: - config_str = "[section]\n" + file_.read() - parser.read_string(config_str) - self.from_mapping(parser["section"]) - else: - module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) - self.from_object(module) - except (FileNotFoundError, IsADirectoryError): - if not silent: - raise - return True - - def from_object(self, instance: Union[object, str]) -> None: - """Load the configuration from a Python object. - - This can be used to reference modules or objects within - modules for example, - - .. code-block:: python - - app.config.from_object('module') - app.config.from_object('module.instance') - from module import instance - app.config.from_object(instance) - - are valid. - - Arguments: - instance: Either a str referencing a python object or the - object itself. - - """ - if isinstance(instance, str): - try: - path, config = instance.rsplit(".", 1) - except ValueError: - path = instance - instance = importlib.import_module(path) - else: - module = importlib.import_module(path) - instance = getattr(module, config) - - for key in dir(instance): - if key.isupper(): - self[key] = getattr(instance, key) - - def from_file( - self, filename: str, load: Callable[[IO[Any]], Mapping], silent: bool = False - ) -> bool: - """Load the configuration from a data file. - - This allows configuration to be loaded as so - - .. code-block:: python - - app.config.from_file('config.toml', toml.load) - app.config.from_file('config.json', json.load) - - Arguments: - filename: The filename which when appended to - :attr:`root_path` gives the path to the file. - load: Callable that takes a file descriptor and - returns a mapping loaded from the file. - silent: If True any errors will fail silently. - """ - file_path = self.root_path / filename - try: - with open(file_path) as file_: - data = load(file_) - except (FileNotFoundError, IsADirectoryError): - if not silent: - raise - else: - return False - else: - return self.from_mapping(data) - - def from_mapping(self, mapping: Optional[Mapping[str, Any]] = None, **kwargs: Any) -> bool: - """Load the configuration values from a mapping. - - This allows either a mapping to be directly passed or as - keyword arguments, for example, - - .. code-block:: python - - config = {'FOO': 'bar'} - app.config.from_mapping(config) - app.config.form_mapping(FOO='bar') - - Arguments: - mapping: Optionally a mapping object. - kwargs: Optionally a collection of keyword arguments to - form a mapping. - """ - mappings: Dict[str, Any] = {} - if mapping is not None: - mappings.update(mapping) - mappings.update(kwargs) - for key, value in mappings.items(): - if key.isupper(): - self[key] = value - return True - - def get_namespace( - self, namespace: str, lowercase: bool = True, trim_namespace: bool = True - ) -> Dict[str, Any]: - """Return a dictionary of keys within a namespace. - - A namespace is considered to be a key prefix, for example the - keys ``FOO_A, FOO_BAR, FOO_B`` are all within the ``FOO`` - namespace. This method would return a dictionary with these - keys and values present. - - .. code-block:: python - - config = {'FOO_A': 'a', 'FOO_BAR': 'bar', 'BAR': False} - app.config.from_mapping(config) - assert app.config.get_namespace('FOO_') == {'a': 'a', 'bar': 'bar'} - - Arguments: - namespace: The namespace itself (should be uppercase). - lowercase: Lowercase the keys in the returned dictionary. - trim_namespace: Remove the namespace from the returned - keys. - """ - config = {} - for key, value in self.items(): - if key.startswith(namespace): - if trim_namespace: - new_key = key[len(namespace) :] - else: - new_key = key - if lowercase: - new_key = new_key.lower() - config[new_key] = value - return config - - def __repr__(self) -> str: - return f"<{type(self).__name__} {dict.__repr__(self)}>" + return super().from_prefixed_env(prefix, loads=loads) diff --git a/src/quart/ctx.py b/src/quart/ctx.py index 806aff7..1915be1 100644 --- a/src/quart/ctx.py +++ b/src/quart/ctx.py @@ -6,6 +6,7 @@ from types import TracebackType from typing import Any, Callable, cast, Iterator, List, Optional, Tuple, TYPE_CHECKING # noqa: F401 +from flask.ctx import _AppCtxGlobals as _AppCtxGlobals # noqa: F401 from werkzeug.exceptions import HTTPException from .globals import _cv_app, _cv_request, _cv_websocket @@ -32,9 +33,9 @@ class _BaseRequestWebsocketContext: def __init__( self, - app: "Quart", + app: Quart, request_websocket: BaseRequestWebsocket, - session: Optional[SessionMixin] = None, + session: SessionMixin | None = None, ) -> None: self.app = app self.request_websocket = request_websocket @@ -43,9 +44,9 @@ def __init__( self.request_websocket.json_module = app.json self.session = session self.preserved = False - self._cv_tokens: List[Tuple[Token, Optional[AppContext]]] = [] + self._cv_tokens: list[tuple[Token, AppContext | None]] = [] - def copy(self) -> "_BaseRequestWebsocketContext": + def copy(self) -> _BaseRequestWebsocketContext: return self.__class__(self.app, self.request_websocket, self.session) def match_request(self) -> None: @@ -68,18 +69,18 @@ def match_request(self) -> None: async def push(self) -> None: raise NotImplementedError() - async def pop(self, exc: Optional[BaseException]) -> None: + async def pop(self, exc: BaseException | None) -> None: raise NotImplementedError() - async def auto_pop(self, exc: Optional[BaseException]) -> None: + async def auto_pop(self, exc: BaseException | None) -> None: if self.request_websocket.scope.get("_quart._preserve_context", False) or ( - exc is not None and self.app.preserve_context_on_exception + exc is not None and self.app.config["PRESERVE_CONTEXT_ON_EXCEPTION"] ): self.preserved = True else: await self.pop(exc) - async def __aenter__(self) -> "_BaseRequestWebsocketContext": + async def __aenter__(self) -> _BaseRequestWebsocketContext: await self.push() return self @@ -124,13 +125,13 @@ class RequestContext(_BaseRequestWebsocketContext): def __init__( self, - app: "Quart", + app: Quart, request: Request, - session: Optional[SessionMixin] = None, + session: SessionMixin | None = None, ) -> None: super().__init__(app, request, session) self.flashes = None - self._after_request_functions: List[AfterRequestCallable] = [] + self._after_request_functions: list[AfterRequestCallable] = [] @property def request(self) -> Request: @@ -140,12 +141,16 @@ async def push(self) -> None: await super()._push_appctx(_cv_request.set(self)) await super()._push() - async def pop(self, exc: Optional[BaseException] = _sentinel) -> None: # type: ignore + async def pop(self, exc: BaseException | None = _sentinel) -> None: # type: ignore try: if len(self._cv_tokens) == 1: if exc is _sentinel: exc = sys.exc_info()[1] await self.app.do_teardown_request(exc, self) + + request_close = getattr(self.request_websocket, "close", None) + if request_close is not None: + await request_close() finally: ctx = _cv_request.get() token, app_ctx = self._cv_tokens.pop() @@ -157,7 +162,7 @@ async def pop(self, exc: Optional[BaseException] = _sentinel) -> None: # type: if ctx is not self: raise AssertionError(f"Popped wrong request context. ({ctx!r} instead of {self!r})") - async def __aenter__(self) -> "RequestContext": + async def __aenter__(self) -> RequestContext: await self.push() return self @@ -176,12 +181,12 @@ class WebsocketContext(_BaseRequestWebsocketContext): def __init__( self, - app: "Quart", + app: Quart, request: Websocket, - session: Optional[SessionMixin] = None, + session: SessionMixin | None = None, ) -> None: super().__init__(app, request, session) - self._after_websocket_functions: List[AfterWebsocketCallable] = [] + self._after_websocket_functions: list[AfterWebsocketCallable] = [] @property def websocket(self) -> Websocket: @@ -191,7 +196,7 @@ async def push(self) -> None: await super()._push_appctx(_cv_websocket.set(self)) await super()._push() - async def pop(self, exc: Optional[BaseException] = _sentinel) -> None: # type: ignore + async def pop(self, exc: BaseException | None = _sentinel) -> None: # type: ignore try: if len(self._cv_tokens) == 1: if exc is _sentinel: @@ -208,13 +213,12 @@ async def pop(self, exc: Optional[BaseException] = _sentinel) -> None: # type: if ctx is not self: raise AssertionError(f"Popped wrong request context. ({ctx!r} instead of {self!r})") - async def __aenter__(self) -> "WebsocketContext": + async def __aenter__(self) -> WebsocketContext: await self.push() return self class AppContext: - """The context relating to the app bound to the current task. Do not use directly, prefer the @@ -227,22 +231,24 @@ class AppContext: g: An instance of the ctx globals class. """ - def __init__(self, app: "Quart") -> None: + def __init__(self, app: Quart) -> None: self.app = app self.url_adapter = app.create_url_adapter(None) self.g = app.app_ctx_globals_class() - self._cv_tokens: List[Token] = [] + self._cv_tokens: list[Token] = [] - def copy(self) -> "AppContext": + def copy(self) -> AppContext: app_context = self.__class__(self.app) app_context.g = self.g return app_context async def push(self) -> None: self._cv_tokens.append(_cv_app.set(self)) - await appcontext_pushed.send(self.app) + await appcontext_pushed.send_async( + self.app, _sync_wrapper=self.app.ensure_async # type: ignore + ) - async def pop(self, exc: Optional[BaseException] = _sentinel) -> None: # type: ignore + async def pop(self, exc: BaseException | None = _sentinel) -> None: # type: ignore try: if len(self._cv_tokens) == 1: if exc is _sentinel: @@ -255,9 +261,11 @@ async def pop(self, exc: Optional[BaseException] = _sentinel) -> None: # type: if ctx is not self: raise AssertionError(f"Popped wrong app context. ({ctx!r} instead of {self!r})") - await appcontext_popped.send(self.app) + await appcontext_popped.send_async( + self.app, _sync_wrapper=self.app.ensure_async # type: ignore + ) - async def __aenter__(self) -> "AppContext": + async def __aenter__(self) -> AppContext: await self.push() return self @@ -447,49 +455,3 @@ def has_websocket_context() -> bool: See also :func:`has_app_context`. """ return _cv_websocket.get(None) is not None - - -class _AppCtxGlobals: - """The g class, a plain object with some mapping methods.""" - - def get(self, name: str, default: Optional[Any] = None) -> Any: - """Get a named attribute of this instance, or return the default.""" - return self.__dict__.get(name, default) - - def pop(self, name: str, default: Any = _sentinel) -> Any: - """Pop, get and remove the named attribute of this instance.""" - if default is _sentinel: - return self.__dict__.pop(name) - else: - return self.__dict__.pop(name, default) - - def setdefault(self, name: str, default: Any = None) -> Any: - """Set an attribute with a default value.""" - return self.__dict__.setdefault(name, default) - - def __contains__(self, item: Any) -> bool: - return item in self.__dict__ - - def __iter__(self) -> Iterator: - return iter(self.__dict__) - - def __repr__(self) -> str: - ctx = _cv_app.get(None) - if ctx is not None: - return f"" - return object.__repr__(self) - - def __getattr__(self, name: str) -> Any: - try: - return self.__dict__[name] - except KeyError: - raise AttributeError(name) from None - - def __setattr__(self, name: str, value: Any) -> None: - self.__dict__[name] = value - - def __delattr__(self, name: str) -> None: - try: - del self.__dict__[name] - except KeyError: - raise AttributeError(name) from None diff --git a/src/quart/datastructures.py b/src/quart/datastructures.py index 3dbd51e..17fe27b 100644 --- a/src/quart/datastructures.py +++ b/src/quart/datastructures.py @@ -2,7 +2,7 @@ from os import PathLike from pathlib import Path -from typing import IO, Optional +from typing import IO from aiofiles import open as async_open from werkzeug.datastructures import FileStorage as WerkzeugFileStorage, Headers @@ -13,12 +13,12 @@ class FileStorage(WerkzeugFileStorage): def __init__( self, - stream: Optional[IO[bytes]] = None, - filename: Optional[str] = None, - name: Optional[str] = None, - content_type: Optional[str] = None, - content_length: Optional[int] = None, - headers: Optional[Headers] = None, + stream: IO[bytes] | None = None, + filename: str | None = None, + name: str | None = None, + content_type: str | None = None, + content_length: int | None = None, + headers: Headers | None = None, ) -> None: super().__init__(stream, filename, name, content_type, content_length, headers) diff --git a/src/quart/debug.py b/src/quart/debug.py index 561bf4a..9b0cb3e 100644 --- a/src/quart/debug.py +++ b/src/quart/debug.py @@ -1,7 +1,6 @@ from __future__ import annotations import inspect -import sys from jinja2 import Template @@ -87,8 +86,9 @@ """ -async def traceback_response() -> Response: - type_, value, tb = sys.exc_info() +async def traceback_response(error: Exception) -> Response: + type_ = type(error) + tb = error.__traceback__ frames = [] while tb: frame = tb.tb_frame @@ -109,5 +109,5 @@ async def traceback_response() -> Response: name = type_.__name__ template = Template(TEMPLATE) - html = template.render(frames=reversed(frames), name=name, value=value) + html = template.render(frames=reversed(frames), name=name, value=error) return Response(html, 500) diff --git a/src/quart/flask_patch/__init__.py b/src/quart/flask_patch/__init__.py deleted file mode 100644 index 42267a1..0000000 --- a/src/quart/flask_patch/__init__.py +++ /dev/null @@ -1,100 +0,0 @@ -from __future__ import annotations # isort:skip - -import quart.flask_patch.app # isort:skip -import quart.flask_patch.cli # isort:skip # noqa: F401 -import quart.flask_patch.globals # isort:skip # noqa: F401 -import quart.flask_patch.testing # isort:skip # noqa: F401 -import quart.views # isort:skip # noqa: F401 -from quart.flask_patch._patch import patch_all # isort:skip - -patch_all() - -from flask.app import Flask # noqa: E402, I100 -from flask.blueprints import Blueprint # noqa: E402 -from flask.config import Config # noqa: E402 -from flask.ctx import ( # noqa: E402 - after_this_request, - copy_current_request_context, - has_app_context, - has_request_context, -) -from flask.globals import current_app, g, request, session # noqa: E402 -from flask.helpers import ( # noqa: E402 - flash, - get_flashed_messages, - get_template_attribute, - make_response, - send_file, - send_from_directory, - stream_with_context, - url_for, -) -from flask.json import jsonify # noqa: E402 -from flask.signals import ( # noqa: E402 - appcontext_popped, - appcontext_pushed, - appcontext_tearing_down, - before_render_template, - got_request_exception, - message_flashed, - request_finished, - request_started, - request_tearing_down, - signals_available, - template_rendered, -) -from flask.templating import render_template, render_template_string # noqa: E402 -from flask.typing import ResponseReturnValue # noqa: E402 -from flask.wrappers import Request, Response # noqa: E402 -from markupsafe import escape, Markup # noqa: E402 -from werkzeug.exceptions import abort # noqa: E402 -from werkzeug.utils import redirect # noqa: E402 - -__all__ = ( - "abort", - "after_this_request", - "appcontext_popped", - "appcontext_pushed", - "appcontext_tearing_down", - "before_render_template", - "Blueprint", - "Config", - "copy_current_request_context", - "current_app", - "escape", - "flash", - "Flask", - "g", - "get_flashed_messages", - "get_template_attribute", - "got_request_exception", - "has_app_context", - "has_request_context", - "jsonify", - "make_response", - "Markup", - "message_flashed", - "redirect", - "render_template", - "render_template_string", - "request", - "Request", - "request_finished", - "request_started", - "request_tearing_down", - "Response", - "ResponseReturnValue", - "send_file", - "send_from_directory", - "session", - "signals_available", - "stream_with_context", - "template_rendered", - "url_for", -) - - -import sys # isort:skip # noqa: E402, I100 - -json = sys.modules["flask.json"] -sys.modules["flask"] = sys.modules[__name__] diff --git a/src/quart/flask_patch/_patch.py b/src/quart/flask_patch/_patch.py deleted file mode 100644 index be802cb..0000000 --- a/src/quart/flask_patch/_patch.py +++ /dev/null @@ -1,124 +0,0 @@ -from __future__ import annotations - -import asyncio -import inspect -import sys -import types -from typing import Any, Callable - -from ._synchronise import sync_with_context - - -def _patch_asyncio() -> None: - # This patches asyncio to add a sync_wait method to the event - # loop. This method can then be called from within a task - # including a synchronous function called from a task. Sadly it - # requires the python Task and Future implementations, which - # invokes some performance cost. - asyncio.Task = asyncio.tasks._CTask = asyncio.tasks.Task = asyncio.tasks._PyTask # type: ignore - asyncio.Future = ( # type: ignore - asyncio.futures._CFuture # type: ignore - ) = asyncio.futures.Future = asyncio.futures._PyFuture # type: ignore # noqa - - current_policy = asyncio.get_event_loop_policy() - if hasattr(asyncio, "unix_events"): - target_policy = asyncio.unix_events._UnixDefaultEventLoopPolicy - else: - target_policy = object # type: ignore - - if not isinstance(current_policy, target_policy): - raise RuntimeError("Flask Patching only works with the default event loop policy") - - _patch_loop() - _patch_task() - - -def _patch_loop() -> None: - def _sync_wait(self, future): # type: ignore - preserved_ready = list(self._ready) - self._ready.clear() - current_task = asyncio.tasks._current_tasks.get(self) # type: ignore[attr-defined] - future = asyncio.tasks.ensure_future(future, loop=self) - while not future.done() and not future.cancelled(): - self._run_once() - if self._stopping: - break - if current_task._must_cancel: - future.cancel() - self._ready.extendleft(preserved_ready) - return future.result() - - asyncio.BaseEventLoop.sync_wait = _sync_wait # type: ignore - - -def _patch_task() -> None: - # Patch the asyncio task to allow it to be re-entered. - def enter_task(loop, task): # type: ignore - asyncio.tasks._current_tasks[loop] = task # type: ignore[attr-defined] - - asyncio.tasks._enter_task = enter_task - - def leave_task(loop, task): # type: ignore - del asyncio.tasks._current_tasks[loop] # type: ignore[attr-defined] - - asyncio.tasks._leave_task = leave_task - - def step(self, exception=None): # type: ignore - current_task = asyncio.tasks._current_tasks.get(self._loop) # type: ignore[attr-defined] - try: - self._Task__step_orig(exception) - finally: - if current_task is None: - asyncio.tasks._current_tasks.pop(self._loop, None) # type: ignore[attr-defined] - else: - asyncio.tasks._current_tasks[ # type: ignore[attr-defined] - self._loop - ] = current_task - - asyncio.Task._Task__step_orig = asyncio.Task._Task__step # type: ignore - asyncio.Task._Task__step = step # type: ignore - - -def _context_decorator(func: Callable) -> Callable: - def wrapper(*args: Any, **kwargs: Any) -> Any: - return sync_with_context(func(*args, **kwargs)) - - return wrapper - - -def _convert_module(new_name, module): # type: ignore - new_module = types.ModuleType(new_name) - for name, member in inspect.getmembers(module): - if inspect.getmodule(member) == module and inspect.iscoroutinefunction(member): - setattr(new_module, name, _context_decorator(member)) - else: - setattr(new_module, name, member) - setattr(new_module, "_QUART_PATCHED", True) - return new_module - - -def _patch_modules() -> None: - if "flask" in sys.modules: - raise ImportError("Cannot mock flask, already imported") - - # Create a set of Flask modules, prioritising those within the - # flask_patch namespace over simple references to the Quart - # versions. - flask_modules = {} - for name, module in list(sys.modules.items()): - if name.startswith("quart.flask_patch._"): - continue - elif name.startswith("quart.flask_patch"): - setattr(module, "_QUART_PATCHED", True) - flask_modules[name.replace("quart.flask_patch", "flask")] = module - elif name.startswith("quart.") and not name.startswith("quart.serving"): - flask_name = name.replace("quart.", "flask.") - if flask_name not in flask_modules: - flask_modules[flask_name] = _convert_module(flask_name, module) - - sys.modules.update(flask_modules) - - -def patch_all() -> None: - _patch_asyncio() - _patch_modules() diff --git a/src/quart/flask_patch/_synchronise.py b/src/quart/flask_patch/_synchronise.py deleted file mode 100644 index 9ca97f5..0000000 --- a/src/quart/flask_patch/_synchronise.py +++ /dev/null @@ -1,25 +0,0 @@ -from __future__ import annotations - -import asyncio -from typing import Any, Awaitable - -from quart.globals import _cv_app, _cv_request, _cv_websocket - - -def sync_with_context(future: Awaitable) -> Any: - context: Any = None - if _cv_request.get(None) is not None: - context = _cv_request.get().copy() - elif _cv_websocket.get(None) is not None: - context = _cv_websocket.get().copy() - elif _cv_app.get(None) is not None: - context = _cv_app.get().copy() - - async def context_wrapper() -> Any: - if context is not None: - async with context: - return await future - else: - return await future - - return asyncio.get_event_loop().sync_wait(context_wrapper()) # type: ignore diff --git a/src/quart/flask_patch/app.py b/src/quart/flask_patch/app.py deleted file mode 100644 index 3a58207..0000000 --- a/src/quart/flask_patch/app.py +++ /dev/null @@ -1,72 +0,0 @@ -# The aim is to replace the Quart class exception handling defaults to -# allow for Werkzeug HTTPExceptions to be considered in a special way -# (like the quart HTTPException). In addition a Flask reference is -# created. -from __future__ import annotations - -from functools import wraps -from inspect import iscoroutine -from typing import Any, Awaitable, Callable, Optional, Union - -from werkzeug.wrappers import Response as WerkzeugResponse - -from quart import Response -from quart.app import Quart -from quart.ctx import RequestContext -from quart.globals import request_ctx -from quart.utils import is_coroutine_function -from ._synchronise import sync_with_context - -old_full_dispatch_request = Quart.full_dispatch_request - - -async def new_full_dispatch_request( - self: Quart, request_context: Optional[RequestContext] = None -) -> Union[Response, WerkzeugResponse]: - request_ = (request_context or request_ctx).request - await request_.get_data() - return await old_full_dispatch_request(self, request_context) - - -Quart.full_dispatch_request = new_full_dispatch_request # type: ignore - - -def new_ensure_async( # type: ignore - self, func: Callable[..., Any] -) -> Callable[..., Awaitable[Any]]: - - if is_coroutine_function(func): - return func - else: - - @wraps(func) - async def _wrapper(*args: Any, **kwargs: Any) -> Any: - result = func(*args, **kwargs) - if iscoroutine(result): - return await result - else: - return result - - return _wrapper - - -Quart.ensure_async = new_ensure_async # type: ignore - - -def ensure_sync(self, func: Callable) -> Callable: # type: ignore - if is_coroutine_function(func): - - @wraps(func) - def _wrapper(*args: Any, **kwargs: Any) -> Any: - return sync_with_context(func(*args, **kwargs)) - - return _wrapper - else: - return func - - -Quart.ensure_sync = ensure_sync # type: ignore - -Flask = Quart - -__all__ = ("Quart",) diff --git a/src/quart/flask_patch/cli.py b/src/quart/flask_patch/cli.py deleted file mode 100644 index 2a5fd62..0000000 --- a/src/quart/flask_patch/cli.py +++ /dev/null @@ -1,12 +0,0 @@ -from __future__ import annotations - -from quart.cli import ( # noqa: F401 - AppGroup, - QuartGroup, - run_command, - ScriptInfo, - shell_command, - with_appcontext, -) - -FlaskGroup = QuartGroup diff --git a/src/quart/flask_patch/globals.py b/src/quart/flask_patch/globals.py deleted file mode 100644 index 28eae0f..0000000 --- a/src/quart/flask_patch/globals.py +++ /dev/null @@ -1,59 +0,0 @@ -from __future__ import annotations - -from typing import Any, AnyStr - -from werkzeug.datastructures import MultiDict -from werkzeug.local import LocalProxy - -from quart.globals import ( - _cv_app, - _cv_request, - _cv_websocket, - app_ctx, - current_app, - g, - request as quart_request, - request_ctx, - session, -) -from ._synchronise import sync_with_context - - -class FlaskRequestProxy(LocalProxy): - @property - def data(self) -> bytes: - return sync_with_context(self._get_current_object().data) # type: ignore - - @property - def form(self) -> MultiDict: - return sync_with_context(self._get_current_object().form) # type: ignore - - @property - def files(self) -> MultiDict: - return sync_with_context(self._get_current_object().files) # type: ignore - - @property - def json(self) -> Any: - return sync_with_context(self._get_current_object().json) # type: ignore - - def get_json(self, *args: Any, **kwargs: Any) -> Any: - return sync_with_context(self._get_current_object().get_json(*args, **kwargs)) # type: ignore # noqa: E501 - - def get_data(self, *args: Any, **kwargs: Any) -> AnyStr: - return sync_with_context(self._get_current_object().get_data(*args, **kwargs)) # type: ignore # noqa: E501 - - -request = FlaskRequestProxy(lambda: quart_request) - - -__all__ = ( - "_cv_app", - "_cv_request", - "_cv_websocket", - "app_ctx", - "current_app", - "g", - "request", - "request_ctx", - "session", -) diff --git a/src/quart/flask_patch/testing.py b/src/quart/flask_patch/testing.py deleted file mode 100644 index 892e453..0000000 --- a/src/quart/flask_patch/testing.py +++ /dev/null @@ -1,7 +0,0 @@ -from __future__ import annotations - -from quart.testing import QuartClient - -FlaskClient = QuartClient - -__all__ = ("FlaskClient",) diff --git a/src/quart/formparser.py b/src/quart/formparser.py index 7bc498e..eab878a 100644 --- a/src/quart/formparser.py +++ b/src/quart/formparser.py @@ -7,20 +7,17 @@ cast, Dict, IO, - List, NoReturn, Optional, Tuple, - Type, TYPE_CHECKING, - Union, ) +from urllib.parse import parse_qsl from werkzeug.datastructures import Headers, MultiDict from werkzeug.formparser import default_stream_factory from werkzeug.http import parse_options_header from werkzeug.sansio.multipart import Data, Epilogue, Field, File, MultipartDecoder, NeedData -from werkzeug.urls import url_decode from .datastructures import FileStorage @@ -44,29 +41,25 @@ class FormDataParser: def __init__( self, stream_factory: StreamFactory = default_stream_factory, - charset: str = "utf-8", - errors: str = "replace", - max_form_memory_size: Optional[int] = None, - max_content_length: Optional[int] = None, - cls: Optional[Type[MultiDict]] = MultiDict, + max_form_memory_size: int | None = None, + max_content_length: int | None = None, + cls: type[MultiDict] | None = MultiDict, silent: bool = True, ) -> None: self.stream_factory = stream_factory - self.charset = charset - self.errors = errors self.cls = cls self.silent = silent - def get_parse_func(self, mimetype: str, options: Dict[str, str]) -> Optional[ParserFunc]: + def get_parse_func(self, mimetype: str, options: dict[str, str]) -> ParserFunc | None: return self.parse_functions.get(mimetype) async def parse( self, - body: "Body", + body: Body, mimetype: str, - content_length: Optional[int], - options: Optional[Dict[str, str]] = None, - ) -> Tuple[MultiDict, MultiDict]: + content_length: int | None, + options: dict[str, str] | None = None, + ) -> tuple[MultiDict, MultiDict]: if options is None: options = {} @@ -83,15 +76,13 @@ async def parse( async def _parse_multipart( self, - body: "Body", + body: Body, mimetype: str, - content_length: Optional[int], - options: Dict[str, str], - ) -> Tuple[MultiDict, MultiDict]: + content_length: int | None, + options: dict[str, str], + ) -> tuple[MultiDict, MultiDict]: parser = MultiPartParser( self.stream_factory, - self.charset, - self.errors, cls=self.cls, file_storage_cls=self.file_storage_class, ) @@ -104,15 +95,18 @@ async def _parse_multipart( async def _parse_urlencoded( self, - body: "Body", + body: Body, mimetype: str, - content_length: Optional[int], - options: Dict[str, str], - ) -> Tuple[MultiDict, MultiDict]: - form = url_decode(await body, self.charset, errors=self.errors, cls=self.cls) - return form, self.cls() + content_length: int | None, + options: dict[str, str], + ) -> tuple[MultiDict, MultiDict]: + form = parse_qsl( + (await body).decode(), + keep_blank_values=True, + ) + return self.cls(form), self.cls() - parse_functions: Dict[str, ParserFunc] = { + parse_functions: dict[str, ParserFunc] = { "multipart/form-data": _parse_multipart, "application/x-www-form-urlencoded": _parse_urlencoded, "application/x-url-encoded": _parse_urlencoded, @@ -123,15 +117,11 @@ class MultiPartParser: def __init__( self, stream_factory: StreamFactory = default_stream_factory, - charset: str = "utf-8", - errors: str = "replace", - max_form_memory_size: Optional[int] = None, - cls: Type[MultiDict] = MultiDict, + max_form_memory_size: int | None = None, + cls: type[MultiDict] = MultiDict, buffer_size: int = 64 * 1024, - file_storage_cls: Type[FileStorage] = FileStorage, + file_storage_cls: type[FileStorage] = FileStorage, ) -> None: - self.charset = charset - self.errors = errors self.max_form_memory_size = max_form_memory_size self.stream_factory = stream_factory self.cls = cls @@ -145,10 +135,15 @@ def get_part_charset(self, headers: Headers) -> str: content_type = headers.get("content-type") if content_type: - mimetype, ct_params = parse_options_header(content_type) - return ct_params.get("charset", self.charset) + parameters = parse_options_header(content_type)[1] + ct_charset = parameters.get("charset", "").lower() + + # A safe list of encodings. Modern clients should only send ASCII or UTF-8. + # This list will not be extended further. + if ct_charset in {"ascii", "us-ascii", "utf-8", "iso-8859-1"}: + return ct_charset - return self.charset + return "utf-8" def start_file_streaming(self, event: File, total_content_length: int) -> IO[bytes]: content_type = event.headers.get("content-type") @@ -167,9 +162,9 @@ def start_file_streaming(self, event: File, total_content_length: int) -> IO[byt return container async def parse( - self, body: "Body", boundary: bytes, content_length: int - ) -> Tuple[MultiDict, MultiDict]: - container: Union[IO[bytes], List[bytes]] + self, body: Body, boundary: bytes, content_length: int + ) -> tuple[MultiDict, MultiDict]: + container: IO[bytes] | list[bytes] _write: Callable[[bytes], Any] parser = MultipartDecoder(boundary, self.max_form_memory_size) @@ -177,7 +172,7 @@ async def parse( fields = [] files = [] - current_part: Union[Field, File] + current_part: Field | File async for data in body: parser.receive_data(data) event = parser.next_event() @@ -195,7 +190,7 @@ async def parse( if not event.more_data: if isinstance(current_part, Field): value = b"".join(container).decode( - self.get_part_charset(current_part.headers), self.errors + self.get_part_charset(current_part.headers), "replace" ) fields.append((current_part.name, value)) else: diff --git a/src/quart/globals.py b/src/quart/globals.py index 173cb68..af2c790 100644 --- a/src/quart/globals.py +++ b/src/quart/globals.py @@ -1,7 +1,7 @@ from __future__ import annotations from contextvars import ContextVar -from typing import TYPE_CHECKING, Union +from typing import TYPE_CHECKING from werkzeug.local import LocalProxy @@ -42,7 +42,7 @@ ) -def _session_lookup() -> Union[RequestContext, WebsocketContext]: +def _session_lookup() -> RequestContext | WebsocketContext: try: return _cv_request.get() except LookupError: diff --git a/src/quart/helpers.py b/src/quart/helpers.py index f61b3fa..8dd1146 100644 --- a/src/quart/helpers.py +++ b/src/quart/helpers.py @@ -4,20 +4,21 @@ import os import pkgutil import sys -from datetime import datetime, timedelta +from datetime import datetime, timedelta, timezone from functools import lru_cache, wraps from io import BytesIO from pathlib import Path -from typing import Any, Callable, Iterable, List, NoReturn, Optional, Tuple, Union +from typing import Any, Callable, cast, Iterable, NoReturn from zlib import adler32 +from flask.helpers import get_root_path as get_root_path # noqa: F401 from werkzeug.exceptions import abort as werkzeug_abort, NotFound from werkzeug.utils import redirect as werkzeug_redirect, safe_join from werkzeug.wrappers import Response as WerkzeugResponse from .globals import _cv_request, current_app, request, request_ctx, session from .signals import message_flashed -from .typing import FilePath +from .typing import FilePath, ResponseReturnValue, ResponseTypes from .utils import file_path_to_path from .wrappers import Response from .wrappers.response import ResponseBody @@ -33,11 +34,7 @@ def get_debug_flag() -> bool: configured, it will be enabled automatically. """ value = os.getenv("QUART_DEBUG", None) - - if value is None: - return "development" == get_env() - - return value.lower() not in {"0", "false", "no"} + return bool(value and value.lower() not in {"0", "false", "no"}) def get_load_dotenv(default: bool = True) -> bool: @@ -54,14 +51,7 @@ def get_load_dotenv(default: bool = True) -> bool: return val.lower() in ("0", "false", "no") -def get_env(default: Optional[str] = "production") -> str: - """Reads QUART_ENV environment variable to determine in which environment - the app is running on. Defaults to 'production' when unset. - """ - return os.getenv("QUART_ENV", default) - - -async def make_response(*args: Any) -> Union[Response, WerkzeugResponse]: +async def make_response(*args: Any) -> ResponseTypes: """Create a response, a simple wrapper function. This is most useful when you want to alter a Response before @@ -78,7 +68,7 @@ async def make_response(*args: Any) -> Union[Response, WerkzeugResponse]: if len(args) == 1: args = args[0] - return await current_app.make_response(args) + return await current_app.make_response(cast(ResponseReturnValue, args)) async def make_push_promise(path: str) -> None: @@ -113,14 +103,15 @@ async def login(): flashes = session.get("_flashes", []) flashes.append((category, message)) session["_flashes"] = flashes - await message_flashed.send( - current_app._get_current_object(), message=message, category=category # type: ignore + app = current_app._get_current_object() # type: ignore + await message_flashed.send_async( + app, _sync_wrapper=app.ensure_async, message=message, category=category ) def get_flashed_messages( with_categories: bool = False, category_filter: Iterable[str] = () -) -> Union[List[str], List[Tuple[str, str]]]: +) -> list[str] | list[tuple[str, str]]: """Retrieve the flashed messages stored in the session. This is mostly useful in templates where it is exposed as a global @@ -149,20 +140,6 @@ def get_flashed_messages( return flashes -def get_root_path(import_name: str) -> str: - """Find the root path of the *import_name*""" - module = sys.modules.get(import_name) - if module is not None and hasattr(module, "__file__"): - file_path = module.__file__ - else: - loader = pkgutil.get_loader(import_name) - if loader is None or import_name == "__main__": - return str(Path.cwd()) - else: - file_path = loader.get_filename(import_name) # type: ignore - return str(Path(file_path).resolve().parent) - - def get_template_attribute(template_name: str, attribute: str) -> Any: """Load a attribute from a template. @@ -179,10 +156,10 @@ def get_template_attribute(template_name: str, attribute: str) -> Any: def url_for( endpoint: str, *, - _anchor: Optional[str] = None, - _external: Optional[bool] = None, - _method: Optional[str] = None, - _scheme: Optional[str] = None, + _anchor: str | None = None, + _external: bool | None = None, + _method: str | None = None, + _scheme: str | None = None, **values: Any, ) -> str: """Return the url for a specific endpoint. @@ -240,7 +217,7 @@ async def generator(*args: Any, **kwargs: Any) -> Any: return generator -def find_package(name: str) -> Tuple[Optional[Path], Path]: +def find_package(name: str) -> tuple[Path | None, Path]: """Finds packages install prefix (or None) and it's containing Folder""" module = name.split(".")[0] loader = pkgutil.get_loader(module) @@ -248,13 +225,13 @@ def find_package(name: str) -> Tuple[Optional[Path], Path]: package_path = Path.cwd() else: if hasattr(loader, "get_filename"): - filename = loader.get_filename(module) # type: ignore + filename = loader.get_filename(module) else: __import__(name) filename = sys.modules[name].__file__ package_path = Path(filename).resolve().parent if hasattr(loader, "is_package"): - is_package = loader.is_package(module) # type: ignore + is_package = loader.is_package(module) if is_package: package_path = Path(package_path).resolve().parent sys_prefix = Path(sys.prefix).resolve() @@ -270,13 +247,13 @@ async def send_from_directory( directory: FilePath, file_name: str, *, - mimetype: Optional[str] = None, + mimetype: str | None = None, as_attachment: bool = False, - attachment_filename: Optional[str] = None, + attachment_filename: str | None = None, add_etags: bool = True, - cache_timeout: Optional[int] = None, + cache_timeout: int | None = None, conditional: bool = True, - last_modified: Optional[datetime] = None, + last_modified: datetime | None = None, ) -> Response: """Send a file from a given directory. @@ -307,14 +284,14 @@ async def send_from_directory( async def send_file( - filename_or_io: Union[FilePath, BytesIO], - mimetype: Optional[str] = None, + filename_or_io: FilePath | BytesIO, + mimetype: str | None = None, as_attachment: bool = False, - attachment_filename: Optional[str] = None, + attachment_filename: str | None = None, add_etags: bool = True, - cache_timeout: Optional[int] = None, + cache_timeout: int | None = None, conditional: bool = False, - last_modified: Optional[datetime] = None, + last_modified: datetime | None = None, ) -> Response: """Return a Response to send the filename given. @@ -333,8 +310,8 @@ async def send_file( """ file_body: ResponseBody - file_size: Optional[int] = None - etag: Optional[str] = None + file_size: int | None = None + etag: str | None = None if isinstance(filename_or_io, BytesIO): file_body = current_app.response_class.io_body_class(filename_or_io) file_size = filename_or_io.getbuffer().nbytes @@ -371,7 +348,7 @@ async def send_file( response.cache_control.public = True if cache_timeout is not None: response.cache_control.max_age = cache_timeout - response.expires = datetime.utcnow() + timedelta(seconds=cache_timeout) + response.expires = datetime.now(timezone.utc) + timedelta(seconds=cache_timeout) if add_etags and etag is not None: response.set_etag(etag) @@ -382,14 +359,14 @@ async def send_file( @lru_cache(maxsize=None) -def _split_blueprint_path(name: str) -> List[str]: +def _split_blueprint_path(name: str) -> list[str]: bps = [name] while "." in bps[-1]: bps.append(bps[-1].rpartition(".")[0]) return bps -def abort(code: int, *args: Any, **kwargs: Any) -> NoReturn: # type: ignore[misc] +def abort(code: int | Response, *args: Any, **kwargs: Any) -> NoReturn: """Raise an HTTPException for the given status code.""" if current_app: current_app.aborter(code, *args, **kwargs) diff --git a/src/quart/json/__init__.py b/src/quart/json/__init__.py index 4574506..9753dd9 100644 --- a/src/quart/json/__init__.py +++ b/src/quart/json/__init__.py @@ -3,7 +3,8 @@ import json from typing import Any, IO, TYPE_CHECKING -from .provider import _default +from flask.json.provider import _default + from ..globals import current_app if TYPE_CHECKING: @@ -11,22 +12,34 @@ def dumps(object_: Any, **kwargs: Any) -> str: - kwargs.setdefault("default", _default) - return json.dumps(object_, **kwargs) + if current_app: + return current_app.json.dumps(object_, **kwargs) + else: + kwargs.setdefault("default", _default) + return json.dumps(object_, **kwargs) def dump(object_: Any, fp: IO[str], **kwargs: Any) -> None: - kwargs.setdefault("default", _default) - json.dump(object_, fp, **kwargs) + if current_app: + current_app.json.dump(object_, fp, **kwargs) + else: + kwargs.setdefault("default", _default) + json.dump(object_, fp, **kwargs) -def loads(object_: str, **kwargs: Any) -> Any: - return json.loads(object_, **kwargs) +def loads(object_: str | bytes, **kwargs: Any) -> Any: + if current_app: + return current_app.json.loads(object_, **kwargs) + else: + return json.loads(object_, **kwargs) def load(fp: IO[str], **kwargs: Any) -> Any: - return json.load(fp, **kwargs) + if current_app: + return current_app.json.load(fp, **kwargs) + else: + return json.load(fp, **kwargs) -def jsonify(*args: Any, **kwargs: Any) -> "Response": - return current_app.json.response(*args, **kwargs) +def jsonify(*args: Any, **kwargs: Any) -> Response: + return current_app.json.response(*args, **kwargs) # type: ignore diff --git a/src/quart/json/provider.py b/src/quart/json/provider.py index ed2eaad..17fbea4 100644 --- a/src/quart/json/provider.py +++ b/src/quart/json/provider.py @@ -1,205 +1,4 @@ -from __future__ import annotations - -import json -import weakref -from dataclasses import asdict, is_dataclass -from datetime import date -from decimal import Decimal -from typing import Any, AnyStr, Callable, Dict, IO, Tuple, TYPE_CHECKING -from uuid import UUID - -from werkzeug.http import http_date - -if TYPE_CHECKING: - from ..app import Quart - from ..wrappers import Response - - -class JSONProvider: - """A standard set of JSON operations for an application. Subclasses - of this can be used to customize JSON behavior or use different - JSON libraries. - - To implement a provider for a specific library, subclass this base - class and implement at least :meth:`dumps` and :meth:`loads`. All - other methods have default implementations. - - To use a different provider, either subclass ``Quart`` and set - :attr:`~quart.Quart.json_provider_class` to a provider class, or set - :attr:`app.json ` to an instance of the class. - :param app: An application instance. This will be stored as a - :class:`weakref.proxy` on the :attr:`_app` attribute. - """ - - def __init__(self, app: Quart) -> None: - self._app = weakref.proxy(app) - - def dumps(self, object_: Any, **kwargs: Any) -> str: - """Serialize data as JSON. - - Arguments: - object_: The data to serialize. - kwargs: May be passed to the underlying JSON library. - """ - raise NotImplementedError - - def dump(self, object_: Any, fp: IO[str], **kwargs: Any) -> None: - """Serialize data as JSON and write to a file. - - Arguments: - object_: The data to serialize. - fp: A file opened for writing text. Should use the UTF-8 - encoding to be valid JSON. - kwargs: May be passed to the underlying JSON library. - """ - fp.write(self.dumps(object_, **kwargs)) - - def loads(self, object_: str | bytes, **kwargs: Any) -> Any: - """Deserialize data as JSON. - - Arguments: - s: Text or UTF-8 bytes. - kwargs: May be passed to the underlying JSON library. - """ - raise NotImplementedError - - def load(self, fp: IO[AnyStr], **kwargs: Any) -> Any: - """Deserialize data as JSON read from a file. - :param fp: A file opened for reading text or UTF-8 bytes. - :param kwargs: May be passed to the underlying JSON library. - """ - return self.loads(fp.read(), **kwargs) - - def _prepare_response_obj(self, args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Any: - if args and kwargs: - raise TypeError("app.json.response() takes either args or kwargs, not both") - - if not args and not kwargs: - return None - - if len(args) == 1: - return args[0] - - return args or kwargs - - def response(self, *args: Any, **kwargs: Any) -> Response: - """Serialize the given arguments as JSON, and return a - :class:`~quart.Response` object with the ``application/json`` - mimetype. - - The :func:`~quart.json.jsonify` function calls this method for - the current application. - Either positional or keyword arguments can be given, not both. - If no arguments are given, ``None`` is serialized. - - Arguments: - args: A single value to serialize, or multiple values to - treat as a list to serialize. - kwargs: Treat as a dict to serialize. - """ - object_ = self._prepare_response_obj(args, kwargs) - return self._app.response_class(self.dumps(object_), mimetype="application/json") - - -def _default(object_: Any) -> Any: - if isinstance(object_, date): - return http_date(object_) - if isinstance(object_, (Decimal, UUID)): - return str(object_) - if is_dataclass(object_): - return asdict(object_) - if hasattr(object_, "__html__"): - return str(object_.__html__()) - - raise TypeError(f"Object of type {type(object_).__name__} is not JSON serializable") - - -class DefaultJSONProvider(JSONProvider): - """Provide JSON operations using Python's built-in :mod:`json` - library. Serializes the following additional data types: - - :class:`datetime.datetime` and :class:`datetime.date` are - serialized to :rfc:`822` strings. This is the same as the HTTP - date format. - - :class:`uuid.UUID` is serialized to a string. - - :class:`dataclasses.dataclass` is passed to - :func:`dataclasses.asdict`. - - :class:`~markupsafe.Markup` (or any object with a ``__html__`` - method) will call the ``__html__`` method to get a string. - """ - - default: Callable[[Any], Any] = staticmethod(_default) - """Apply this function to any object that :meth:`json.dumps` does - not know how to serialize. It should return a valid JSON type or - raise a ``TypeError``. - """ - - ensure_ascii = True - """Replace non-ASCII characters with escape sequences. This may be - more compatible with some clients, but can be disabled for better - performance and size. - """ - - sort_keys = True - """Sort the keys in any serialized dicts. This may be useful for - some caching situations, but can be disabled for better performance. - When enabled, keys must all be strings, they are not converted - before sorting. - """ - - compact: bool | None = None - """If ``True``, or ``None`` out of debug mode, the :meth:`response` - output will not add indentation, newlines, or spaces. If ``False``, - or ``None`` in debug mode, it will use a non-compact representation. - """ - - mimetype = "application/json" - """The mimetype set in :meth:`response`.""" - - def dumps(self, object_: Any, **kwargs: Any) -> str: - """Serialize data as JSON to a string. - Keyword arguments are passed to :func:`json.dumps`. Sets some - parameter defaults from the :attr:`default`, - :attr:`ensure_ascii`, and :attr:`sort_keys` attributes. - - Arguments: - object_: The data to serialize. - kwargs: Passed to :func:`json.dumps`. - """ - kwargs.setdefault("default", self.default) - kwargs.setdefault("ensure_ascii", self.ensure_ascii) - kwargs.setdefault("sort_keys", self.sort_keys) - return json.dumps(object_, **kwargs) - - def loads(self, object_: str | bytes, **kwargs: Any) -> Any: - """Deserialize data as JSON from a string or bytes. - - Arguments: - object_: Text or UTF-8 bytes. - kwargs: Passed to :func:`json.loads`. - """ - return json.loads(object_, **kwargs) - - def response(self, *args: Any, **kwargs: Any) -> Response: - """Serialize the given arguments as JSON, and return a - :class:`~quart.Response` object with it. The response mimetype - will be "application/json" and can be changed with - :attr:`mimetype`. - If :attr:`compact` is ``False`` or debug mode is enabled, the - output will be formatted to be easier to read. - Either positional or keyword arguments can be given, not both. - If no arguments are given, ``None`` is serialized. - - Arguments: - args: A single value to serialize, or multiple values to - treat as a list to serialize. - kwargs: Treat as a dict to serialize. - """ - object_ = self._prepare_response_obj(args, kwargs) - dump_args: Dict[str, Any] = {} - - if (self.compact is None and self._app.debug) or self.compact is False: - dump_args.setdefault("indent", 2) - else: - dump_args.setdefault("separators", (",", ":")) - - return self._app.response_class(self.dumps(object_, **dump_args), mimetype=self.mimetype) +from flask.json.provider import ( # noqa: F401 + DefaultJSONProvider as DefaultJSONProvider, + JSONProvider as JSONProvider, +) diff --git a/src/quart/json/tag.py b/src/quart/json/tag.py index 6b6e8b9..0ff177b 100644 --- a/src/quart/json/tag.py +++ b/src/quart/json/tag.py @@ -1,203 +1,12 @@ -from __future__ import annotations - -from base64 import b64decode, b64encode -from datetime import datetime -from typing import Any, Dict, List, Optional, Tuple, Type -from uuid import UUID - -from markupsafe import Markup -from werkzeug.http import parse_date - -from quart.json import dumps, loads - - -class JSONTag: - key: Optional[str] = None - - def __init__(self, serializer: "TaggedJSONSerializer") -> None: - self.serializer = serializer - - def check(self, value: Any) -> bool: - raise NotImplementedError - - def to_json(self, value: Any) -> Any: - raise NotImplementedError - - def to_python(self, value: str) -> Any: - raise NotImplementedError - - def tag(self, value: Any) -> Any: - return {self.key: self.to_json(value)} - - -class TagDict(JSONTag): - key = " di" - - def check(self, value: Any) -> bool: - return ( - isinstance(value, dict) - and len(value) == 1 - and next(iter(value)) in self.serializer.tags - ) - - def to_json(self, value: Any) -> Dict[str, Any]: - key = next(iter(value)) - return {key + "__": self.serializer.tag(value[key])} - - def to_python(self, value: str) -> Dict[str, Any]: - key, item = next(iter(value)) # type: ignore - return {key[:-2]: item} # type: ignore - - -class PassDict(JSONTag): - def check(self, value: Any) -> bool: - return isinstance(value, dict) - - def to_json(self, value: Any) -> Dict[str, Any]: - return {key: self.serializer.tag(item) for key, item in value.items()} - - tag = to_json - - -class TagTuple(JSONTag): - key = " t" - - def check(self, value: Any) -> bool: - return isinstance(value, tuple) - - def to_json(self, value: Tuple[Any]) -> List[Any]: - return [self.serializer.tag(item) for item in value] - - def to_python(self, value: Any) -> Tuple[Any, ...]: - return tuple(value) - - -class PassList(JSONTag): - def check(self, value: Any) -> bool: - return isinstance(value, list) - - def to_json(self, value: List[Any]) -> List[Any]: - return [self.serializer.tag(item) for item in value] - - tag = to_json - - -class TagBytes(JSONTag): - key = " b" - - def check(self, value: Any) -> bool: - return isinstance(value, bytes) - - def to_json(self, value: bytes) -> str: - return b64encode(value).decode("ascii") - - def to_python(self, value: str) -> bytes: - return b64decode(value) - - -class TagMarkup(JSONTag): - key = " m" - - def check(self, value: Any) -> bool: - return callable(getattr(value, "__html__", None)) - - def to_json(self, value: Any) -> str: - return str(value.__html__()) - - def to_python(self, value: str) -> Markup: - return Markup(value) - - -class TagUUID(JSONTag): - key = " u" - - def check(self, value: Any) -> bool: - return isinstance(value, UUID) - - def to_json(self, value: Any) -> str: - return value.hex - - def to_python(self, value: str) -> UUID: - return UUID(value) - - -def _parse_datetime(value: str) -> datetime: - try: - return datetime.fromisoformat(value) - except ValueError: - return parse_date(value) - - -class TagDateTime(JSONTag): - key = " d" - - def check(self, value: Any) -> bool: - return isinstance(value, datetime) - - def to_json(self, value: datetime) -> str: - return value.isoformat(timespec="microseconds") - - def to_python(self, value: str) -> datetime: - return _parse_datetime(value) - - -class TaggedJSONSerializer: - - default_tags = [ - TagDict, - PassDict, - TagTuple, - PassList, - TagBytes, - TagMarkup, - TagUUID, - TagDateTime, - ] - - def __init__(self) -> None: - self.tags: Dict[str, JSONTag] = {} - self.order: List[JSONTag] = [] - - for tag_class in self.default_tags: - self.register(tag_class) - - def register( - self, tag_class: Type[JSONTag], force: bool = False, index: Optional[int] = None - ) -> None: - tag = tag_class(self) - key = tag.key - - if key is not None: - if not force and key in self.tags: - raise KeyError(f"Tag '{key}' is already registered.") - - self.tags[key] = tag - - if index is None: - self.order.append(tag) - else: - self.order.insert(index, tag) - - def tag(self, value: Any) -> Dict[str, Any]: - for tag in self.order: - if tag.check(value): - return tag.tag(value) - - return value - - def untag(self, value: Dict[str, Any]) -> Any: - if len(value) != 1: - return value - - key = next(iter(value)) - - if key not in self.tags: - return value - - return self.tags[key].to_python(value[key]) - - def dumps(self, value: Any) -> str: - return dumps(self.tag(value), separators=(",", ":")) - - def loads(self, value: str) -> Any: - return loads(value, object_hook=self.untag) +from flask.json.tag import ( # noqa: F401 + JSONTag as JSONTag, + PassDict as PassDict, + PassList as PassList, + TagBytes as TagBytes, + TagDateTime as TagDateTime, + TagDict as TagDict, + TaggedJSONSerializer as TaggedJSONSerializer, + TagMarkup as TagMarkup, + TagTuple as TagTuple, + TagUUID as TagUUID, +) diff --git a/src/quart/logging.py b/src/quart/logging.py index de5f0bb..699a143 100644 --- a/src/quart/logging.py +++ b/src/quart/logging.py @@ -52,7 +52,7 @@ def has_level_handler(logger: Logger) -> bool: return False -def create_logger(app: "Quart") -> Logger: +def create_logger(app: Quart) -> Logger: """Create a logger for the app based on the app settings. This creates a logger named quart.app that has a log level based diff --git a/src/quart/routing.py b/src/quart/routing.py index 40ea7d6..e1b4047 100644 --- a/src/quart/routing.py +++ b/src/quart/routing.py @@ -1,7 +1,7 @@ from __future__ import annotations import warnings -from typing import Iterable, Optional +from typing import Iterable from werkzeug.routing import Map, MapAdapter, Rule @@ -12,13 +12,13 @@ class QuartRule(Rule): def __init__( self, string: str, - defaults: Optional[dict] = None, - subdomain: Optional[str] = None, - methods: Optional[Iterable[str]] = None, - endpoint: Optional[str] = None, - strict_slashes: Optional[bool] = None, - merge_slashes: Optional[bool] = None, - host: Optional[str] = None, + defaults: dict | None = None, + subdomain: str | None = None, + methods: Iterable[str] | None = None, + endpoint: str | None = None, + strict_slashes: bool | None = None, + merge_slashes: bool | None = None, + host: str | None = None, websocket: bool = False, provide_automatic_options: bool = False, ) -> None: @@ -38,7 +38,7 @@ def __init__( class QuartMap(Map): def bind_to_request( - self, request: BaseRequestWebsocket, subdomain: Optional[str], server_name: Optional[str] + self, request: BaseRequestWebsocket, subdomain: str | None, server_name: str | None ) -> MapAdapter: host: str if server_name is None: @@ -63,21 +63,13 @@ def bind_to_request( else: subdomain = ".".join(filter(None, request_host_parts[:offset])) - if request.root_path: - try: - path = request.path.split(request.root_path, 1)[1] - except IndexError: - path = " " # Invalid in paths, hence will result in 404 - else: - path = request.path - return super().bind( host, request.root_path, subdomain, request.scheme, request.method, - path, + request.path, request.query_string.decode(), ) diff --git a/src/quart/scaffold.py b/src/quart/scaffold.py deleted file mode 100644 index 76517a2..0000000 --- a/src/quart/scaffold.py +++ /dev/null @@ -1,820 +0,0 @@ -from __future__ import annotations - -import os -from collections import defaultdict -from functools import wraps -from pathlib import Path -from typing import ( - Any, - Callable, - cast, - Dict, - Iterable, - List, - Optional, - Tuple, - Type, - TYPE_CHECKING, - TypeVar, - Union, -) - -from aiofiles import open as async_open -from aiofiles.base import AiofilesContextManager -from aiofiles.threadpool.binary import AsyncBufferedReader -from jinja2 import FileSystemLoader -from werkzeug.exceptions import default_exceptions, HTTPException - -from .cli import AppGroup -from .globals import current_app -from .helpers import get_root_path, send_from_directory -from .templating import _default_template_ctx_processor -from .typing import ( - AfterRequestCallable, - AfterWebsocketCallable, - AppOrBlueprintKey, - BeforeRequestCallable, - BeforeWebsocketCallable, - ErrorHandlerCallable, - FilePath, - RouteCallable, - TeardownCallable, - TemplateContextProcessorCallable, - URLDefaultCallable, - URLValuePreprocessorCallable, - WebsocketCallable, -) -from .utils import file_path_to_path - -if TYPE_CHECKING: - from .wrappers import Response - - -F = TypeVar("F", bound=Callable) -T_after_request = TypeVar("T_after_request", bound=AfterRequestCallable) -T_after_websocket = TypeVar("T_after_websocket", bound=AfterWebsocketCallable) -T_before_request = TypeVar("T_before_request", bound=BeforeRequestCallable) -T_before_websocket = TypeVar("T_before_websocket", bound=BeforeWebsocketCallable) -T_error_handler = TypeVar("T_error_handler", bound=ErrorHandlerCallable) -T_teardown = TypeVar("T_teardown", bound=TeardownCallable) -T_template_context_processor = TypeVar( - "T_template_context_processor", bound=TemplateContextProcessorCallable -) -T_url_defaults = TypeVar("T_url_defaults", bound=URLDefaultCallable) -T_url_value_preprocessor = TypeVar("T_url_value_preprocessor", bound=URLValuePreprocessorCallable) -T_route = TypeVar("T_route", bound=RouteCallable) -T_websocket = TypeVar("T_websocket", bound=WebsocketCallable) - - -def setupmethod(func: F) -> F: - @wraps(func) - def wrapper(self: "Scaffold", *args: Any, **kwargs: Any) -> Any: - self._check_setup_finished(func.__name__) - return func(self, *args, **kwargs) - - return cast(F, wrapper) - - -class Scaffold: - """Base class for Quart and Blueprint classes.""" - - name: str - - def __init__( - self, - import_name: str, - static_folder: Optional[str] = None, - static_url_path: Optional[str] = None, - template_folder: Optional[str] = None, - root_path: Optional[str] = None, - ) -> None: - self.import_name = import_name - self.template_folder = Path(template_folder) if template_folder is not None else None - - if root_path is None: - self.root_path = Path(get_root_path(import_name)) - else: - self.root_path = Path(root_path) - - self._static_folder: Optional[Path] = None - self._static_url_path: Optional[str] = None - self.static_folder = static_folder # type: ignore - self.static_url_path = static_url_path - - self.cli = AppGroup() - - # Functions that are called after a HTTP view function has - # handled a request and returned a response. - self.after_request_funcs: Dict[AppOrBlueprintKey, List[AfterRequestCallable]] = defaultdict( - list - ) - - # Functions that are called after a WebSocket view function - # handled a websocket request and has returned (possibly - # returning a response). - self.after_websocket_funcs: Dict[ - AppOrBlueprintKey, List[AfterWebsocketCallable] - ] = defaultdict(list) - - # Called before a HTTP view function handles a request. - self.before_request_funcs: Dict[ - AppOrBlueprintKey, List[BeforeRequestCallable] - ] = defaultdict(list) - - # Called before a WebSocket view function handles a websocket - # request. - self.before_websocket_funcs: Dict[ - AppOrBlueprintKey, List[BeforeWebsocketCallable] - ] = defaultdict(list) - - # The registered error handlers, keyed by blueprint (None for - # app) then by Exception type. - self.error_handler_spec: Dict[ - AppOrBlueprintKey, - Dict[Optional[int], Dict[Type[Exception], ErrorHandlerCallable]], - ] = defaultdict(lambda: defaultdict(dict)) - - # Called after a HTTP request has been handled, even if the - # handling results in an exception. - self.teardown_request_funcs: Dict[AppOrBlueprintKey, List[TeardownCallable]] = defaultdict( - list - ) - - # Called after a WebSocket request has been handled, even if - # the handling results in an exception. - self.teardown_websocket_funcs: Dict[ - AppOrBlueprintKey, List[TeardownCallable] - ] = defaultdict(list) - - # Template context processors keyed by blueprint (None for - # app). - self.template_context_processors: Dict[ - AppOrBlueprintKey, List[TemplateContextProcessorCallable] - ] = defaultdict(list, {None: [_default_template_ctx_processor]}) - - # View functions keyed by endpoint. - self.view_functions: Dict[str, Callable] = {} - - # The URL value preprocessor functions keyed by blueprint - # (None for app) as used when matching - self.url_value_preprocessors: Dict[ - AppOrBlueprintKey, - List[URLValuePreprocessorCallable], - ] = defaultdict(list) - - # The URL value default injector functions keyed by blueprint - # (None for app) as used when building urls. - self.url_default_functions: Dict[AppOrBlueprintKey, List[URLDefaultCallable]] = defaultdict( - list - ) - - def __repr__(self) -> str: - return f"<{type(self).__name__} {self.name!r}>" - - @property - def static_folder(self) -> Optional[Path]: - if self._static_folder is not None: - return self.root_path / self._static_folder - else: - return None - - @static_folder.setter - def static_folder(self, static_folder: Optional[FilePath]) -> None: - if static_folder is not None: - self._static_folder = file_path_to_path(static_folder) - else: - self._static_folder = None - - @property - def static_url_path(self) -> Optional[str]: - if self._static_url_path is not None: - return self._static_url_path - if self.static_folder is not None: - return "/" + self.static_folder.name - else: - return None - - @static_url_path.setter - def static_url_path(self, static_url_path: str) -> None: - self._static_url_path = static_url_path - - @property - def has_static_folder(self) -> bool: - return self.static_folder is not None - - def get_send_file_max_age(self, filename: str) -> Optional[int]: - if current_app.send_file_max_age_default is not None: - return int(current_app.send_file_max_age_default.total_seconds()) - return None - - async def send_static_file(self, filename: str) -> Response: - if not self.has_static_folder: - raise RuntimeError("No static folder for this object") - return await send_from_directory(self.static_folder, filename) - - @property - def jinja_loader(self) -> Optional[FileSystemLoader]: - if self.template_folder is not None: - return FileSystemLoader(os.fspath(self.root_path / self.template_folder)) - else: - return None - - async def open_resource( - self, - path: FilePath, - mode: str = "rb", - ) -> AiofilesContextManager[None, None, AsyncBufferedReader]: - """Open a file for reading. - - Use as - - .. code-block:: python - - async with await app.open_resource(path) as file_: - await file_.read() - """ - if mode not in {"r", "rb"}: - raise ValueError("Files can only be opened for reading") - return async_open(self.root_path / file_path_to_path(path), mode) # type: ignore - - def _method_route(self, method: str, rule: str, options: dict) -> Callable[[T_route], T_route]: - if "methods" in options: - raise TypeError("Methods cannot be supplied, use the 'route' decorator instead.") - - return self.route(rule, methods=[method], **options) - - @setupmethod - def get(self, rule: str, **options: Any) -> Callable[[T_route], T_route]: - """Syntactic sugar for :meth:`route` with ``methods=["GET"]``.""" - return self._method_route("GET", rule, options) - - @setupmethod - def post(self, rule: str, **options: Any) -> Callable[[T_route], T_route]: - """Syntactic sugar for :meth:`route` with ``methods=["POST"]``.""" - return self._method_route("POST", rule, options) - - @setupmethod - def put(self, rule: str, **options: Any) -> Callable[[T_route], T_route]: - """Syntactic sugar for :meth:`route` with ``methods=["PUT"]``.""" - return self._method_route("PUT", rule, options) - - @setupmethod - def delete(self, rule: str, **options: Any) -> Callable[[T_route], T_route]: - """Syntactic sugar for :meth:`route` with ``methods=["DELETE"]``.""" - return self._method_route("DELETE", rule, options) - - @setupmethod - def patch(self, rule: str, **options: Any) -> Callable[[T_route], T_route]: - """Syntactic sugar for :meth:`route` with ``methods=["PATCH"]``.""" - return self._method_route("PATCH", rule, options) - - @setupmethod - def route( - self, - rule: str, - methods: Optional[List[str]] = None, - endpoint: Optional[str] = None, - defaults: Optional[dict] = None, - host: Optional[str] = None, - subdomain: Optional[str] = None, - *, - provide_automatic_options: Optional[bool] = None, - strict_slashes: Optional[bool] = None, - ) -> Callable[[T_route], T_route]: - """Add a HTTP request handling route. - - This is designed to be used as a decorator, if used to - decorate a synchronous function, the function will be wrapped - in :func:`~quart.utils.run_sync` and run in a thread executor - (with the wrapped function returned). An example usage, - - .. code-block:: python - - @app.route('/') - async def route(): - ... - - Arguments: - rule: The path to route on, should start with a ``/``. - methods: List of HTTP verbs the function routes. - endpoint: Optional endpoint name, if not present the - function name is used. - defaults: A dictionary of variables to provide automatically, use - to provide a simpler default path for a route, e.g. to allow - for ``/book`` rather than ``/book/0``, - - .. code-block:: python - - @app.route('/book', defaults={'page': 0}) - @app.route('/book/') - def book(page): - ... - - host: The full host name for this route (should include subdomain - if needed) - cannot be used with subdomain. - subdomain: A subdomain for this specific route. - provide_automatic_options: Optionally False to prevent - OPTION handling. - strict_slashes: Strictly match the trailing slash present in the - path. Will redirect a leaf (no slash) to a branch (with slash). - """ - - def decorator(func: T_route) -> T_route: - self.add_url_rule( - rule, - endpoint, - func, - provide_automatic_options=provide_automatic_options, - methods=methods, - defaults=defaults, - host=host, - subdomain=subdomain, - strict_slashes=strict_slashes, - ) - return func - - return decorator - - @setupmethod - def add_url_rule( - self, - rule: str, - endpoint: Optional[str] = None, - view_func: Optional[RouteCallable] = None, - provide_automatic_options: Optional[bool] = None, - *, - methods: Optional[Iterable[str]] = None, - defaults: Optional[dict] = None, - host: Optional[str] = None, - subdomain: Optional[str] = None, - is_websocket: bool = False, - strict_slashes: Optional[bool] = None, - merge_slashes: Optional[bool] = None, - ) -> None: - """Add a route/url rule to the application. - - This is designed to be used on the application directly. An - example usage, - - .. code-block:: python - - def route(): - ... - - app.add_url_rule('/', route) - - Arguments: - rule: The path to route on, should start with a ``/``. - endpoint: Optional endpoint name, if not present the - function name is used. - view_func: Callable that returns a response. - provide_automatic_options: Optionally False to prevent - OPTION handling. - methods: List of HTTP verbs the function routes. - defaults: A dictionary of variables to provide automatically, use - to provide a simpler default path for a route, e.g. to allow - for ``/book`` rather than ``/book/0``, - - .. code-block:: python - - @app.route('/book', defaults={'page': 0}) - @app.route('/book/') - def book(page): - ... - - host: The full host name for this route (should include subdomain - if needed) - cannot be used with subdomain. - subdomain: A subdomain for this specific route. - strict_slashes: Strictly match the trailing slash present in the - path. Will redirect a leaf (no slash) to a branch (with slash). - is_websocket: Whether or not the view_func is a websocket. - merge_slashes: Merge consecutive slashes to a single slash (unless - as part of the path variable). - """ - raise NotImplementedError() - - def websocket( - self, - rule: str, - endpoint: Optional[str] = None, - defaults: Optional[dict] = None, - host: Optional[str] = None, - subdomain: Optional[str] = None, - *, - strict_slashes: Optional[bool] = None, - ) -> Callable[[T_websocket], T_websocket]: - """Add a websocket to the application. - - This is designed to be used as a decorator, if used to - decorate a synchronous function, the function will be wrapped - in :func:`~quart.utils.run_sync` and run in a thread executor - (with the wrapped function returned). An example usage, - - .. code-block:: python - - @app.websocket('/') - async def websocket_route(): - ... - - Arguments: - rule: The path to route on, should start with a ``/``. - endpoint: Optional endpoint name, if not present the - function name is used. - defaults: A dictionary of variables to provide automatically, use - to provide a simpler default path for a route, e.g. to allow - for ``/book`` rather than ``/book/0``, - - .. code-block:: python - - @app.websocket('/book', defaults={'page': 0}) - @app.websocket('/book/') - def book(page): - ... - - host: The full host name for this route (should include subdomain - if needed) - cannot be used with subdomain. - subdomain: A subdomain for this specific route. - strict_slashes: Strictly match the trailing slash present in the - path. Will redirect a leaf (no slash) to a branch (with slash). - """ - - def decorator(func: T_websocket) -> T_websocket: - self.add_websocket( - rule, - endpoint, - func, - defaults=defaults, - host=host, - subdomain=subdomain, - strict_slashes=strict_slashes, - ) - return func - - return decorator - - def add_websocket( - self, - rule: str, - endpoint: Optional[str] = None, - view_func: Optional[WebsocketCallable] = None, - defaults: Optional[dict] = None, - host: Optional[str] = None, - subdomain: Optional[str] = None, - *, - strict_slashes: Optional[bool] = None, - ) -> None: - """Add a websocket url rule to the application. - - This is designed to be used on the application directly. An - example usage, - - .. code-block:: python - - def websocket_route(): - ... - - app.add_websocket('/', websocket_route) - - Arguments: - rule: The path to route on, should start with a ``/``. - endpoint: Optional endpoint name, if not present the - function name is used. - view_func: Callable that returns a response. - defaults: A dictionary of variables to provide automatically, use - to provide a simpler default path for a route, e.g. to allow - for ``/book`` rather than ``/book/0``, - - .. code-block:: python - - @app.websocket('/book', defaults={'page': 0}) - @app.websocket('/book/') - def book(page): - ... - - host: The full host name for this route (should include subdomain - if needed) - cannot be used with subdomain. - subdomain: A subdomain for this specific route. - strict_slashes: Strictly match the trailing slash present in the - path. Will redirect a leaf (no slash) to a branch (with slash). - """ - return self.add_url_rule( - rule, - endpoint, - view_func, - methods={"GET"}, - defaults=defaults, - host=host, - subdomain=subdomain, - provide_automatic_options=False, - is_websocket=True, - strict_slashes=strict_slashes, - ) - - @setupmethod - def endpoint(self, endpoint: str) -> Callable[[F], F]: - """Register a function as an endpoint. - - This is designed to be used as a decorator, if used to - decorate a synchronous function, the function will be wrapped - in :func:`~quart.utils.run_sync` and run in a thread executor - (with the wrapped function returned). An example usage, - - .. code-block:: python - - @app.endpoint('name') - async def endpoint(): - ... - - Arguments: - endpoint: The endpoint name to use. - """ - - def decorator(func: F) -> F: - self.view_functions[endpoint] = func - return func - - return decorator - - @setupmethod - def before_request( - self, - func: T_before_request, - ) -> T_before_request: - """Add a before request function. - - This is designed to be used as a decorator, if used to - decorate a synchronous function, the function will be wrapped - in :func:`~quart.utils.run_sync` and run in a thread executor - (with the wrapped function returned). An example usage, - - .. code-block:: python - - @app.before_request - async def func(): - ... - - Arguments: - func: The before request function itself. - """ - self.before_request_funcs[None].append(func) - return func - - @setupmethod - def after_request( - self, - func: T_after_request, - ) -> T_after_request: - """Add an after request function. - - This is designed to be used as a decorator, if used to - decorate a synchronous function, the function will be wrapped - in :func:`~quart.utils.run_sync` and run in a thread executor - (with the wrapped function returned). An example usage, - - .. code-block:: python - - @app.after_request - async def func(response): - return response - - Arguments: - func: The after request function itself. - """ - self.after_request_funcs[None].append(func) - return func - - @setupmethod - def before_websocket( - self, - func: T_before_websocket, - ) -> T_before_websocket: - """Add a before websocket function. - - This is designed to be used as a decorator, if used to - decorate a synchronous function, the function will be wrapped - in :func:`~quart.utils.run_sync` and run in a thread executor - (with the wrapped function returned). An example usage, - - .. code-block:: python - - @app.before_websocket - async def func(): - ... - - Arguments: - func: The before websocket function itself. - """ - self.before_websocket_funcs[None].append(func) - return func - - @setupmethod - def after_websocket( - self, - func: T_after_websocket, - ) -> T_after_websocket: - """Add an after websocket function. - - This is designed to be used as a decorator, if used to - decorate a synchronous function, the function will be wrapped - in :func:`~quart.utils.run_sync` and run in a thread executor - (with the wrapped function returned). An example usage, - - .. code-block:: python - - @app.after_websocket - async def func(response): - return response - - Arguments: - func: The after websocket function itself. - """ - self.after_websocket_funcs[None].append(func) - return func - - @setupmethod - def teardown_request( - self, - func: T_teardown, - ) -> T_teardown: - """Add a teardown request function. - - This is designed to be used as a decorator, if used to - decorate a synchronous function, the function will be wrapped - in :func:`~quart.utils.run_sync` and run in a thread executor - (with the wrapped function returned). An example usage, - - .. code-block:: python - - @app.teardown_request - async def func(): - ... - - Arguments: - func: The teardown request function itself. - """ - self.teardown_request_funcs[None].append(func) - return func - - @setupmethod - def teardown_websocket( - self, - func: T_teardown, - ) -> T_teardown: - """Add a teardown websocket function. - - This is designed to be used as a decorator, if used to - decorate a synchronous function, the function will be wrapped - in :func:`~quart.utils.run_sync` and run in a thread executor - (with the wrapped function returned). An example usage, - - .. code-block:: python - - @app.teardown_websocket - async def func(): - ... - - Arguments: - func: The teardown websocket function itself. - name: Optional blueprint key name. - """ - self.teardown_websocket_funcs[None].append(func) - return func - - @setupmethod - def context_processor( - self, - func: T_template_context_processor, - ) -> T_template_context_processor: - """Add a template context processor. - - This is designed to be used as a decorator, if used to - decorate a synchronous function, the function will be wrapped - in :func:`~quart.utils.run_sync` and run in a thread executor - (with the wrapped function returned). An example usage, - - .. code-block:: python - - @app.context_processor - async def update_context(context): - return context - - """ - self.template_context_processors[None].append(func) - return func - - @setupmethod - def url_value_preprocessor( - self, - func: T_url_value_preprocessor, - ) -> T_url_value_preprocessor: - """Add a url value preprocessor. - - This is designed to be used as a decorator. An example usage, - - .. code-block:: python - - @app.url_value_preprocessor - def value_preprocessor(endpoint, view_args): - ... - """ - self.url_value_preprocessors[None].append(func) - return func - - @setupmethod - def url_defaults(self, func: T_url_defaults) -> T_url_defaults: - """Add a url default preprocessor. - - This is designed to be used as a decorator. An example usage, - - .. code-block:: python - - @app.url_defaults - def default(endpoint, values): - ... - """ - self.url_default_functions[None].append(func) - return func - - @setupmethod - def errorhandler( - self, error: Union[Type[Exception], int] - ) -> Callable[[T_error_handler], T_error_handler]: - """Register a function as an error handler. - - This is designed to be used as a decorator. An example usage, - - .. code-block:: python - - @app.errorhandler(500) - def error_handler(): - return "Error", 500 - - Arguments: - error: The error code or Exception to handle. - """ - - def decorator(func: T_error_handler) -> T_error_handler: - self.register_error_handler(error, func) - return func - - return decorator - - @setupmethod - def register_error_handler( - self, - error: Union[Type[Exception], int], - func: ErrorHandlerCallable, - ) -> None: - """Register a function as an error handler. - - This is designed to be used on the application directly. An - example usage, - - .. code-block:: python - - def error_handler(): - return "Error", 500 - - app.register_error_handler(500, error_handler) - - Arguments: - error: The error code or Exception to handle. - func: The function to handle the error. - """ - if isinstance(error, HTTPException): - raise ValueError( - "error must be an exception Type or int, not an instance of an exception" - ) - - try: - error_type, code = self._get_error_type_and_code(error) - except KeyError: - raise KeyError(f"{error} is not a recognised HTTP error code or HTTPException subclass") - - handlers = self.error_handler_spec[None].setdefault(code, {}) - handlers[error_type] = func - - def _get_error_type_and_code( - self, error: Union[Type[Exception], int] - ) -> Tuple[Type[Exception], Optional[int]]: - error_type: Type[Exception] - if isinstance(error, int): - error_type = default_exceptions[error] - else: - error_type = error - - if not issubclass(error_type, Exception): - raise KeyError("Custom exceptions must be subclasses of Exception.") - - if issubclass(error_type, HTTPException): - return error_type, error_type.code - else: - return error_type, None - - def _check_setup_finished(self, f_name: str) -> None: - raise NotImplementedError() - - -def _endpoint_from_view_func(view_func: Callable) -> str: - assert view_func is not None - return view_func.__name__ diff --git a/src/quart/sessions.py b/src/quart/sessions.py index cd8da31..9df230b 100644 --- a/src/quart/sessions.py +++ b/src/quart/sessions.py @@ -1,110 +1,24 @@ from __future__ import annotations import hashlib -from collections.abc import MutableMapping -from datetime import datetime -from functools import wraps -from typing import Any, Callable, Optional, TYPE_CHECKING, Union - +from datetime import datetime, timezone +from typing import TYPE_CHECKING + +from flask.sessions import ( # noqa: F401 + NullSession as NullSession, + SecureCookieSession as SecureCookieSession, + session_json_serializer as session_json_serializer, + SessionMixin as SessionMixin, +) from itsdangerous import BadSignature, URLSafeTimedSerializer from werkzeug.wrappers import Response as WerkzeugResponse -from .json.tag import TaggedJSONSerializer from .wrappers import BaseRequestWebsocket, Response if TYPE_CHECKING: from .app import Quart # noqa -class SessionMixin(MutableMapping): - """Use to extend a dict with Session attributes. - - The attributes add standard and expected Session modification flags. - - Attributes: - accessed: Indicates if the Session has been accessed during - the request, thereby allowing the Vary: Cookie header. - modified: Indicates if the Session has been modified during - the request handling. - new: Indicates if the Session is new. - """ - - accessed = True - modified = True - new = False - - @property - def permanent(self) -> bool: - return self.get("_permanent", False) - - @permanent.setter - def permanent(self, value: bool) -> None: - self["_permanent"] = value - - -def _wrap_modified(method: Callable) -> Callable: - @wraps(method) - def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: - self.accessed = True - self.modified = True - return method(self, *args, **kwargs) - - return wrapper - - -def _wrap_accessed(method: Callable) -> Callable: - @wraps(method) - def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: - self.accessed = True - return method(self, *args, **kwargs) - - return wrapper - - -class SecureCookieSession(dict, SessionMixin): - """A session implementation using cookies. - - Note that the intention is for this session to use cookies, this - class does not implement anything bar modification and accessed - flags. - """ - - def __init__(self, *args: Any, **kwargs: Any) -> None: - super().__init__(*args, **kwargs) - self.accessed = False - self.modified = False - - __delitem__ = _wrap_modified(dict.__delitem__) - __getitem__ = _wrap_accessed(dict.__getitem__) - __setitem__ = _wrap_modified(dict.__setitem__) - clear = _wrap_modified(dict.clear) - get = _wrap_accessed(dict.get) - pop = _wrap_modified(dict.pop) - popitem = _wrap_modified(dict.popitem) - setdefault = _wrap_modified(dict.setdefault) - update = _wrap_modified(dict.update) - - -def _wrap_no_modification(method: Callable) -> Callable: - @wraps(method) - def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: - raise RuntimeError("Cannot create session, ensure there is a app secret key.") - - return wrapper - - -class NullSession(SecureCookieSession): - """A session implementation for sessions without storage.""" - - __delitem__ = _wrap_no_modification(dict.__delitem__) - __setitem__ = _wrap_no_modification(dict.__setitem__) - clear = _wrap_no_modification(dict.clear) - pop = _wrap_no_modification(dict.pop) - popitem = _wrap_no_modification(dict.popitem) - setdefault = _wrap_no_modification(dict.setdefault) - update = _wrap_no_modification(dict.update) - - class SessionInterface: """Base class for session interfaces. @@ -117,7 +31,7 @@ class SessionInterface: null_session_class = NullSession pickle_based = False - async def make_null_session(self, app: "Quart") -> NullSession: + async def make_null_session(self, app: Quart) -> NullSession: """Create a Null session object. This is used in replacement of an actual session if sessions @@ -129,47 +43,43 @@ def is_null_session(self, instance: object) -> bool: """Returns True is the instance is a null session.""" return isinstance(instance, self.null_session_class) - def get_cookie_name(self, app: "Quart") -> str: + def get_cookie_name(self, app: Quart) -> str: """Helper method to return the Cookie Name for the App.""" - return app.session_cookie_name + return app.config["SESSION_COOKIE_NAME"] - def get_cookie_domain(self, app: "Quart") -> Optional[str]: + def get_cookie_domain(self, app: Quart) -> str | None: """Helper method to return the Cookie Domain for the App.""" - if app.config["SESSION_COOKIE_DOMAIN"] is not None: - return app.config["SESSION_COOKIE_DOMAIN"] - elif app.config["SERVER_NAME"] is not None: - return "." + app.config["SERVER_NAME"].rsplit(":", 1)[0] - else: - return None + rv = app.config["SESSION_COOKIE_DOMAIN"] + return rv if rv else None - def get_cookie_path(self, app: "Quart") -> str: + def get_cookie_path(self, app: Quart) -> str: """Helper method to return the Cookie path for the App.""" - return app.config["SESSION_COOKIE_PATH"] or app.config["APPLICATION_ROOT"] or "/" + return app.config["SESSION_COOKIE_PATH"] or app.config["APPLICATION_ROOT"] - def get_cookie_httponly(self, app: "Quart") -> bool: + def get_cookie_httponly(self, app: Quart) -> bool: """Helper method to return if the Cookie should be HTTPOnly for the App.""" return app.config["SESSION_COOKIE_HTTPONLY"] - def get_cookie_secure(self, app: "Quart") -> bool: + def get_cookie_secure(self, app: Quart) -> bool: """Helper method to return if the Cookie should be Secure for the App.""" return app.config["SESSION_COOKIE_SECURE"] - def get_cookie_samesite(self, app: "Quart") -> str: + def get_cookie_samesite(self, app: Quart) -> str: """Helper method to return the Cookie Samesite configuration for the App.""" return app.config["SESSION_COOKIE_SAMESITE"] - def get_expiration_time(self, app: "Quart", session: SessionMixin) -> Optional[datetime]: + def get_expiration_time(self, app: Quart, session: SessionMixin) -> datetime | None: """Helper method to return the Session expiration time. If the session is not 'permanent' it will expire as and when the browser stops accessing the app. """ if session.permanent: - return datetime.utcnow() + app.permanent_session_lifetime + return datetime.now(timezone.utc) + app.permanent_session_lifetime else: return None - def should_set_cookie(self, app: "Quart", session: SessionMixin) -> bool: + def should_set_cookie(self, app: Quart, session: SessionMixin) -> bool: """Helper method to return if the Set Cookie header should be present. This triggers if the session is marked as modified or the app @@ -180,9 +90,7 @@ def should_set_cookie(self, app: "Quart", session: SessionMixin) -> bool: save_each = app.config["SESSION_REFRESH_EACH_REQUEST"] return save_each and session.permanent - async def open_session( - self, app: "Quart", request: BaseRequestWebsocket - ) -> Optional[SessionMixin]: + async def open_session(self, app: Quart, request: BaseRequestWebsocket) -> SessionMixin | None: """Open an existing session from the request or create one. Returns: @@ -193,7 +101,7 @@ async def open_session( raise NotImplementedError() async def save_session( - self, app: "Quart", session: SessionMixin, response: Union[Response, WerkzeugResponse, None] + self, app: Quart, session: SessionMixin, response: Response | WerkzeugResponse | None ) -> None: """Save the session argument to the response. @@ -218,10 +126,10 @@ class SecureCookieSessionInterface(SessionInterface): digest_method = staticmethod(hashlib.sha1) key_derivation = "hmac" salt = "cookie-session" - serializer = TaggedJSONSerializer() + serializer = session_json_serializer session_class = SecureCookieSession - def get_signing_serializer(self, app: "Quart") -> Optional[URLSafeTimedSerializer]: + def get_signing_serializer(self, app: Quart) -> URLSafeTimedSerializer | None: """Return a serializer for the session that also signs data. This will return None if the app is not configured for secrets. @@ -235,8 +143,8 @@ def get_signing_serializer(self, app: "Quart") -> Optional[URLSafeTimedSerialize ) async def open_session( - self, app: "Quart", request: BaseRequestWebsocket - ) -> Optional[SecureCookieSession]: + self, app: Quart, request: BaseRequestWebsocket + ) -> SecureCookieSession | None: """Open a secure cookie based session. This will return None if a signing serializer is not available, @@ -249,17 +157,18 @@ async def open_session( cookie = request.cookies.get(self.get_cookie_name(app)) if cookie is None: return self.session_class() + max_age = int(app.permanent_session_lifetime.total_seconds()) try: - data = signer.loads(cookie, max_age=app.permanent_session_lifetime.total_seconds()) - return self.session_class(**data) + data = signer.loads(cookie, max_age=max_age) + return self.session_class(data) except BadSignature: return self.session_class() async def save_session( self, - app: "Quart", + app: Quart, session: SessionMixin, - response: Union[Response, WerkzeugResponse, None], + response: Response | WerkzeugResponse | None, ) -> None: """Saves the session to the response in a secure cookie.""" if response is None: @@ -273,25 +182,43 @@ async def save_session( name = self.get_cookie_name(app) domain = self.get_cookie_domain(app) path = self.get_cookie_path(app) - if not session: - if session.modified: - response.delete_cookie(name, domain=domain, path=path) - return + secure = self.get_cookie_secure(app) + samesite = self.get_cookie_samesite(app) + httponly = self.get_cookie_httponly(app) + # Add a "Vary: Cookie" header if the session was accessed at all. if session.accessed: response.vary.add("Cookie") + # If the session is modified to be empty, remove the cookie. + # If the session is empty, return without setting the cookie. + if not session: + if session.modified: + response.delete_cookie( + name, + domain=domain, + path=path, + secure=secure, + samesite=samesite, + httponly=httponly, + ) + response.vary.add("Cookie") + + return + if not self.should_set_cookie(app, session): return - data = self.get_signing_serializer(app).dumps(dict(session)) + expires = self.get_expiration_time(app, session) + val = self.get_signing_serializer(app).dumps(dict(session)) response.set_cookie( name, - data, # type: ignore - expires=self.get_expiration_time(app, session), - httponly=self.get_cookie_httponly(app), + val, + expires=expires, + httponly=httponly, domain=domain, path=path, - secure=self.get_cookie_secure(app), - samesite=self.get_cookie_samesite(app), + secure=secure, + samesite=samesite, ) + response.vary.add("Cookie") diff --git a/src/quart/signals.py b/src/quart/signals.py index a23b902..de5e9fc 100644 --- a/src/quart/signals.py +++ b/src/quart/signals.py @@ -1,54 +1,10 @@ from __future__ import annotations -from functools import wraps -from typing import Any, Callable, List, Optional, Tuple - -from blinker import NamedSignal, Namespace # type: ignore[import] - -from .utils import is_coroutine_function +from blinker import Namespace signals_available = True - -class AsyncNamedSignal(NamedSignal): # type: ignore - def __init__(self, name: str, doc: Optional[str] = None) -> None: - super().__init__(name, doc) - - async def send(self, *sender: Any, **kwargs: Any) -> List[Tuple[Callable, Any]]: - coroutines = super().send(*sender, **kwargs) - result: List[Tuple[Callable, Any]] = [] - for handler, coroutine in coroutines: - result.append((handler, await coroutine)) - return result - - def connect(self, receiver: Callable, *args: Any, **kwargs: Any) -> Callable: - if is_coroutine_function(receiver): - handler = receiver - else: - - @wraps(receiver) - async def handler(*a: Any, **k: Any) -> Any: - return receiver(*a, **k) - - if handler is not receiver and kwargs.get("weak", True): - # Blinker will take a weakref to handler, which goes out - # of scope with this method as it is a wrapper around the - # receiver. Whereas we'd want it to go out of scope when - # receiver does. Therefore we can place it on the receiver - # function. (Ideally I'll think of a better way). - receiver._quart_wrapper_func = handler # type: ignore - return super().connect(handler, *args, **kwargs) - - -class AsyncNamespace(Namespace): # type: ignore - def signal(self, name: str, doc: Optional[str] = None) -> AsyncNamedSignal: - try: - return self[name] - except KeyError: - return self.setdefault(name, AsyncNamedSignal(name, doc)) - - -_signals = AsyncNamespace() +_signals = Namespace() #: Called before a template is rendered, connection functions # should have a signature of Callable[[Quart, Template, dict], None] diff --git a/src/quart/templating.py b/src/quart/templating.py index 416c2ee..c53874b 100644 --- a/src/quart/templating.py +++ b/src/quart/templating.py @@ -1,19 +1,9 @@ from __future__ import annotations -from typing import ( - Any, - AsyncIterator, - Callable, - Dict, - Generator, - List, - Optional, - Tuple, - TYPE_CHECKING, - Union, -) - -from jinja2 import BaseLoader, Environment as BaseEnvironment, Template, TemplateNotFound +from typing import Any, AsyncIterator, TYPE_CHECKING + +from flask.templating import DispatchingJinjaLoader as DispatchingJinjaLoader # noqa: F401 +from jinja2 import Environment as BaseEnvironment, Template from .ctx import has_app_context, has_request_context from .globals import app_ctx, current_app, request_ctx @@ -25,18 +15,18 @@ class Environment(BaseEnvironment): - """Quart specific Jinja2 Environment. + """Quart specific Jinja Environment. - This changes the default Jinja2 loader to use the + This changes the default Jinja loader to use the DispatchingJinjaLoader, and enables async Jinja by default. """ - def __init__(self, app: "Quart", **options: Any) -> None: - """Create a Quart specific Jinja2 Environment. + def __init__(self, app: Quart, **options: Any) -> None: + """Create a Quart specific Jinja Environment. Arguments: app: The Quart app to bind to. - options: The standard Jinja2 Environment options. + options: The standard Jinja Environment options. """ if "loader" not in options: options["loader"] = app.create_global_jinja_loader() @@ -44,53 +34,7 @@ def __init__(self, app: "Quart", **options: Any) -> None: super().__init__(**options) -class DispatchingJinjaLoader(BaseLoader): - """Quart specific Jinja2 Loader. - - This changes the default sourcing to consider the app - and blueprints. - """ - - def __init__(self, app: "Quart") -> None: - self.app = app - - def get_source( - self, environment: BaseEnvironment, template: str - ) -> Tuple[str, Optional[str], Optional[Callable[[], bool]]]: - """Returns the template source from the environment. - - This considers the loaders on the :attr:`app` and blueprints. - """ - for loader in self._loaders(): - try: - return loader.get_source(environment, template) - except TemplateNotFound: - continue - raise TemplateNotFound(template) - - def _loaders(self) -> Generator[BaseLoader, None, None]: - loader = self.app.jinja_loader - if loader is not None: - yield loader - - for blueprint in self.app.iter_blueprints(): - loader = blueprint.jinja_loader - if loader is not None: - yield loader - - def list_templates(self) -> List[str]: - """Returns a list of all available templates in environment. - - This considers the loaders on the :attr:`app` and blueprints. - """ - result = set() - for loader in self._loaders(): - for template in loader.list_templates(): - result.add(str(template)) - return list(result) - - -async def render_template(template_name_or_list: Union[str, List[str]], **context: Any) -> str: +async def render_template(template_name_or_list: str | list[str], **context: Any) -> str: """Render the template with the context given. Arguments: @@ -115,14 +59,18 @@ async def render_template_string(source: str, **context: Any) -> str: return await _render(template, context, current_app._get_current_object()) # type: ignore -async def _render(template: Template, context: dict, app: "Quart") -> str: - await before_render_template.send(app, template=template, context=context) +async def _render(template: Template, context: dict, app: Quart) -> str: + await before_render_template.send_async( + app, _sync_wrapper=app.ensure_async, template=template, context=context # type: ignore + ) rendered_template = await template.render_async(context) - await template_rendered.send(app, template=template, context=context) + await template_rendered.send_async( + app, _sync_wrapper=app.ensure_async, template=template, context=context # type: ignore + ) return rendered_template -async def _default_template_ctx_processor() -> Dict[str, Any]: +async def _default_template_ctx_processor() -> dict[str, Any]: context = {} if has_app_context(): context["g"] = app_ctx.g @@ -133,7 +81,7 @@ async def _default_template_ctx_processor() -> Dict[str, Any]: async def stream_template( - template_name_or_list: Union[str, Template, List[Union[str, Template]]], **context: Any + template_name_or_list: str | Template | list[str | Template], **context: Any ) -> AsyncIterator[str]: """Render a template by name with the given context as a stream. @@ -165,13 +113,17 @@ async def stream_template_string(source: str, **context: Any) -> AsyncIterator[s return await _stream(current_app._get_current_object(), template, context) # type: ignore -async def _stream(app: "Quart", template: Template, context: Dict[str, Any]) -> AsyncIterator[str]: - await before_render_template.send(app, template=template, context=context) +async def _stream(app: Quart, template: Template, context: dict[str, Any]) -> AsyncIterator[str]: + await before_render_template.send_async( + app, _sync_wrapper=app.ensure_async, template=template, context=context # type: ignore + ) async def generate() -> AsyncIterator[str]: async for chunk in template.generate_async(context): yield chunk - await template_rendered.send(app, template=template, context=context) + await template_rendered.send_async( + app, _sync_wrapper=app.ensure_async, template=template, context=context # type: ignore + ) # If a request context is active, keep it while generating. if has_request_context(): diff --git a/src/quart/testing/__init__.py b/src/quart/testing/__init__.py index 0182d89..3cc18bb 100644 --- a/src/quart/testing/__init__.py +++ b/src/quart/testing/__init__.py @@ -21,13 +21,13 @@ class QuartCliRunner(CliRunner): - def __init__(self, app: "Quart", **kwargs: Any) -> None: + def __init__(self, app: Quart, **kwargs: Any) -> None: self.app = app super().__init__(**kwargs) def invoke(self, cli: Any = None, args: Any = None, **kwargs: Any) -> Any: # type: ignore if cli is None: - cli = self.app.cli # type: ignore + cli = self.app.cli if "obj" not in kwargs: kwargs["obj"] = ScriptInfo(create_app=lambda: self.app) diff --git a/src/quart/testing/app.py b/src/quart/testing/app.py index e35f756..7875590 100644 --- a/src/quart/testing/app.py +++ b/src/quart/testing/app.py @@ -21,7 +21,7 @@ class LifespanError(Exception): class TestApp: def __init__( self, - app: "Quart", + app: Quart, startup_timeout: int = DEFAULT_TIMEOUT, shutdown_timeout: int = DEFAULT_TIMEOUT, ) -> None: @@ -50,7 +50,7 @@ async def shutdown(self) -> None: await asyncio.wait_for(self._shutdown.wait(), timeout=self.shutdown_timeout) await self._task - async def __aenter__(self) -> "TestApp": + async def __aenter__(self) -> TestApp: await self.startup() return self diff --git a/src/quart/testing/client.py b/src/quart/testing/client.py index 2440273..066b15c 100644 --- a/src/quart/testing/client.py +++ b/src/quart/testing/client.py @@ -4,18 +4,7 @@ from datetime import datetime, timedelta from http.cookiejar import CookieJar from types import TracebackType -from typing import ( - Any, - AnyStr, - AsyncGenerator, - Dict, - List, - Optional, - Tuple, - Type, - TYPE_CHECKING, - Union, -) +from typing import Any, AnyStr, AsyncGenerator, TYPE_CHECKING from urllib.request import Request as U2Request from werkzeug.datastructures import Authorization, Headers @@ -42,7 +31,7 @@ class _TestWrapper: def __init__(self, headers: Headers) -> None: self.headers = headers - def get_all(self, name: str, default: Optional[Any] = None) -> List[str]: + def get_all(self, name: str, default: Any | None = None) -> list[str]: name = name.lower() result = [] for key, value in self.headers: @@ -60,40 +49,40 @@ def info(self) -> _TestWrapper: class QuartClient: - http_connection_class: Type[TestHTTPConnectionProtocol] - websocket_connection_class: Type[TestWebsocketConnectionProtocol] + http_connection_class: type[TestHTTPConnectionProtocol] + websocket_connection_class: type[TestWebsocketConnectionProtocol] http_connection_class = TestHTTPConnection websocket_connection_class = TestWebsocketConnection - def __init__(self, app: "Quart", use_cookies: bool = True) -> None: + def __init__(self, app: Quart, use_cookies: bool = True) -> None: self.app = app - self.cookie_jar: Optional[CookieJar] + self.cookie_jar: CookieJar | None if use_cookies: self.cookie_jar = CookieJar() else: self.cookie_jar = None self.preserve_context = False - self.push_promises: List[Tuple[str, Headers]] = [] + self.push_promises: list[tuple[str, Headers]] = [] async def open( self, path: str, *, method: str = "GET", - headers: Optional[Union[dict, Headers]] = None, - data: Optional[AnyStr] = None, - form: Optional[dict] = None, - files: Optional[Dict[str, FileStorage]] = None, - query_string: Optional[dict] = None, + headers: dict | Headers | None = None, + data: AnyStr | None = None, + form: dict | None = None, + files: dict[str, FileStorage] | None = None, + query_string: dict | None = None, json: Any = sentinel, scheme: str = "http", follow_redirects: bool = False, root_path: str = "", http_version: str = "1.1", - scope_base: Optional[dict] = None, - auth: Optional[Union[Authorization, Tuple[str, str]]] = None, - subdomain: Optional[str] = None, + scope_base: dict | None = None, + auth: Authorization | tuple[str, str] | None = None, + subdomain: str | None = None, ) -> Response: self.push_promises = [] response = await self._make_request( @@ -144,14 +133,14 @@ def request( path: str, *, method: str = "GET", - headers: Optional[Union[dict, Headers]] = None, - query_string: Optional[dict] = None, + headers: dict | Headers | None = None, + query_string: dict | None = None, scheme: str = "http", root_path: str = "", http_version: str = "1.1", - scope_base: Optional[dict] = None, - auth: Optional[Union[Authorization, Tuple[str, str]]] = None, - subdomain: Optional[str] = None, + scope_base: dict | None = None, + auth: Authorization | tuple[str, str] | None = None, + subdomain: str | None = None, ) -> TestHTTPConnectionProtocol: headers, path, query_string_bytes = make_test_headers_path_and_query_string( self.app, @@ -182,15 +171,15 @@ def websocket( self, path: str, *, - headers: Optional[Union[dict, Headers]] = None, - query_string: Optional[dict] = None, + headers: dict | Headers | None = None, + query_string: dict | None = None, scheme: str = "ws", - subprotocols: Optional[List[str]] = None, + subprotocols: list[str] | None = None, root_path: str = "", http_version: str = "1.1", - scope_base: Optional[dict] = None, - auth: Optional[Union[Authorization, Tuple[str, str]]] = None, - subdomain: Optional[str] = None, + scope_base: dict | None = None, + auth: Authorization | tuple[str, str] | None = None, + subdomain: str | None = None, ) -> TestWebsocketConnectionProtocol: headers, path, query_string_bytes = make_test_headers_path_and_query_string( self.app, @@ -286,14 +275,13 @@ def set_cookie( server_name: str, key: str, value: str = "", - max_age: Optional[Union[int, timedelta]] = None, - expires: Optional[Union[int, float, datetime]] = None, + max_age: int | timedelta | None = None, + expires: int | float | datetime | None = None, path: str = "/", - domain: Optional[str] = None, + domain: str | None = None, secure: bool = False, httponly: bool = False, samesite: str = None, - charset: str = "utf-8", ) -> None: """Set a cookie in the cookie jar. @@ -309,7 +297,6 @@ def set_cookie( domain=domain, secure=secure, httponly=httponly, - charset=charset, samesite=samesite, ) self.cookie_jar.extract_cookies( @@ -318,7 +305,7 @@ def set_cookie( ) def delete_cookie( - self, server_name: str, key: str, path: str = "/", domain: Optional[str] = None + self, server_name: str, key: str, path: str = "/", domain: str | None = None ) -> None: """Delete a cookie (set to expire immediately).""" self.set_cookie(server_name, key, expires=0, max_age=0, path=path, domain=domain) @@ -329,15 +316,15 @@ async def session_transaction( path: str = "/", *, method: str = "GET", - headers: Optional[Union[dict, Headers]] = None, - query_string: Optional[dict] = None, + headers: dict | Headers | None = None, + query_string: dict | None = None, scheme: str = "http", - data: Optional[AnyStr] = None, - form: Optional[dict] = None, + data: AnyStr | None = None, + form: dict | None = None, json: Any = sentinel, root_path: str = "", http_version: str = "1.1", - auth: Optional[Union[Authorization, Tuple[str, str]]] = None, + auth: Authorization | tuple[str, str] | None = None, ) -> AsyncGenerator[SessionMixin, None]: if self.cookie_jar is None: raise RuntimeError("Session transactions only make sense with cookies enabled.") @@ -384,7 +371,7 @@ async def session_transaction( U2Request(ctx.request.url), ) - async def __aenter__(self) -> "QuartClient": + async def __aenter__(self) -> QuartClient: if self.preserve_context: raise RuntimeError("Cannot nest client invocations") self.preserve_context = True @@ -405,18 +392,18 @@ async def _make_request( self, path: str, method: str, - headers: Optional[Union[dict, Headers]], - data: Optional[AnyStr], - form: Optional[dict], - files: Optional[Dict[str, FileStorage]], - query_string: Optional[dict], + headers: dict | Headers | None, + data: AnyStr | None, + form: dict | None, + files: dict[str, FileStorage] | None, + query_string: dict | None, json: Any, scheme: str, root_path: str, http_version: str, - scope_base: Optional[dict], - auth: Optional[Union[Authorization, Tuple[str, str]]] = None, - subdomain: Optional[str] = None, + scope_base: dict | None, + auth: Authorization | tuple[str, str] | None = None, + subdomain: str | None = None, ) -> Response: headers, path, query_string_bytes = make_test_headers_path_and_query_string( self.app, path, headers, query_string, auth, subdomain diff --git a/src/quart/testing/connections.py b/src/quart/testing/connections.py index ff58515..d2e9a93 100644 --- a/src/quart/testing/connections.py +++ b/src/quart/testing/connections.py @@ -2,7 +2,7 @@ import asyncio from types import TracebackType -from typing import Any, AnyStr, Awaitable, List, Optional, Tuple, TYPE_CHECKING +from typing import Any, AnyStr, Awaitable, TYPE_CHECKING from hypercorn.typing import ASGIReceiveEvent, ASGISendEvent, HTTPScope, WebsocketScope from werkzeug.datastructures import Headers @@ -25,18 +25,18 @@ class WebsocketDisconnectError(Exception): class WebsocketResponseError(Exception): def __init__(self, response: Response) -> None: - super().__init__() + super().__init__(response) self.response = response class TestHTTPConnection: def __init__(self, app: Quart, scope: HTTPScope, _preserve_context: bool = False) -> None: self.app = app - self.headers: Optional[Headers] = None - self.push_promises: List[Tuple[str, Headers]] = [] + self.headers: Headers | None = None + self.push_promises: list[tuple[str, Headers]] = [] self.response_data = bytearray() self.scope = scope - self.status_code: Optional[int] = None + self.status_code: int | None = None self._preserve_context = _preserve_context self._send_queue: asyncio.Queue = asyncio.Queue() self._receive_queue: asyncio.Queue = asyncio.Queue() @@ -58,7 +58,7 @@ async def receive(self) -> bytes: async def disconnect(self) -> None: await self._send_queue.put({"type": "http.disconnect"}) - async def __aenter__(self) -> "TestHTTPConnection": + async def __aenter__(self) -> TestHTTPConnection: self._task = asyncio.ensure_future( self.app(self.scope, self._asgi_receive, self._asgi_send) ) @@ -101,15 +101,15 @@ class TestWebsocketConnection: def __init__(self, app: Quart, scope: WebsocketScope) -> None: self.accepted = False self.app = app - self.headers: Optional[Headers] = None + self.headers: Headers | None = None self.response_data = bytearray() self.scope = scope - self.status_code: Optional[int] = None + self.status_code: int | None = None self._send_queue: asyncio.Queue = asyncio.Queue() self._receive_queue: asyncio.Queue = asyncio.Queue() self._task: Awaitable[None] = None - async def __aenter__(self) -> "TestWebsocketConnection": + async def __aenter__(self) -> TestWebsocketConnection: self._task = asyncio.ensure_future( self.app(self.scope, self._asgi_receive, self._asgi_send) ) diff --git a/src/quart/testing/utils.py b/src/quart/testing/utils.py index 371a2bb..baf193a 100644 --- a/src/quart/testing/utils.py +++ b/src/quart/testing/utils.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any, AnyStr, cast, Dict, Optional, overload, Tuple, TYPE_CHECKING, Union +from typing import Any, AnyStr, cast, overload, TYPE_CHECKING from urllib.parse import unquote, urlencode from hypercorn.typing import HTTPScope, Scope, WebsocketScope @@ -24,13 +24,13 @@ def make_test_headers_path_and_query_string( - app: "Quart", + app: Quart, path: str, - headers: Optional[Union[dict, Headers]] = None, - query_string: Optional[dict] = None, - auth: Optional[Union[Authorization, Tuple[str, str]]] = None, - subdomain: Optional[str] = None, -) -> Tuple[Headers, str, bytes]: + headers: dict | Headers | None = None, + query_string: dict | None = None, + auth: Authorization | tuple[str, str] | None = None, + subdomain: str | None = None, +) -> tuple[Headers, str, bytes]: """Make the headers and path with defaults for testing. Arguments: @@ -71,12 +71,12 @@ def make_test_headers_path_and_query_string( def make_test_body_with_headers( *, - data: Optional[AnyStr] = None, - form: Optional[dict] = None, - files: Optional[Dict[str, FileStorage]] = None, + data: AnyStr | None = None, + form: dict | None = None, + files: dict[str, FileStorage] | None = None, json: Any = sentinel, - app: Optional["Quart"] = None, -) -> Tuple[bytes, Headers]: + app: Quart | None = None, +) -> tuple[bytes, Headers]: """Make the body bytes with associated headers. Arguments: @@ -143,11 +143,10 @@ def make_test_scope( scheme: str, root_path: str, http_version: str, - scope_base: Optional[dict], + scope_base: dict | None, *, _preserve_context: bool = False, -) -> HTTPScope: - ... +) -> HTTPScope: ... @overload @@ -160,11 +159,10 @@ def make_test_scope( scheme: str, root_path: str, http_version: str, - scope_base: Optional[dict], + scope_base: dict | None, *, _preserve_context: bool = False, -) -> WebsocketScope: - ... +) -> WebsocketScope: ... def make_test_scope( @@ -176,7 +174,7 @@ def make_test_scope( scheme: str, root_path: str, http_version: str, - scope_base: Optional[dict], + scope_base: dict | None, *, _preserve_context: bool = False, ) -> Scope: diff --git a/src/quart/typing.py b/src/quart/typing.py index 8437296..97b46e7 100644 --- a/src/quart/typing.py +++ b/src/quart/typing.py @@ -12,12 +12,12 @@ Awaitable, Callable, Dict, - Generator, + Iterator, List, Mapping, Optional, + Sequence, Tuple, - Type, TYPE_CHECKING, Union, ) @@ -30,13 +30,15 @@ WebsocketScope, ) +from .datastructures import FileStorage + try: from typing import Protocol except ImportError: from typing_extensions import Protocol # type: ignore if TYPE_CHECKING: - from werkzeug.datastructures import Headers # noqa: F401 + from werkzeug.datastructures import Authorization, Headers # noqa: F401 from werkzeug.wrappers import Response as WerkzeugResponse from .app import Quart @@ -49,11 +51,12 @@ ResponseValue = Union[ "Response", "WerkzeugResponse", - AnyStr, + bytes, + str, Mapping[str, Any], # any jsonify-able dict List[Any], # any jsonify-able list - AsyncGenerator[AnyStr, None], - Generator[AnyStr, None, None], + Iterator[bytes], + Iterator[str], ] StatusCode = int @@ -62,7 +65,9 @@ HeaderValue = Union[str, List[str], Tuple[str, ...]] # the possible types for HTTP headers -HeadersValue = Union["Headers", Dict[HeaderName, HeaderValue], List[Tuple[HeaderName, HeaderValue]]] +HeadersValue = Union[ + "Headers", Mapping[HeaderName, HeaderValue], Sequence[Tuple[HeaderName, HeaderValue]] +] # The possible types returned by a route function. ResponseReturnValue = Union[ @@ -72,16 +77,17 @@ Tuple[ResponseValue, StatusCode, HeadersValue], ] +ResponseTypes = Union["Response", "WerkzeugResponse"] + AppOrBlueprintKey = Optional[str] # The App key is None, whereas blueprints are named AfterRequestCallable = Union[ - Callable[["Response"], "Response"], Callable[["Response"], Awaitable["Response"]] + Callable[[ResponseTypes], ResponseTypes], Callable[[ResponseTypes], Awaitable[ResponseTypes]] ] AfterServingCallable = Union[Callable[[], None], Callable[[], Awaitable[None]]] AfterWebsocketCallable = Union[ - Callable[["Response"], Optional["Response"]], - Callable[["Response"], Awaitable[Optional["Response"]]], + Callable[[Optional[ResponseTypes]], Optional[ResponseTypes]], + Callable[[Optional[ResponseTypes]], Awaitable[Optional[ResponseTypes]]], ] -BeforeFirstRequestCallable = Union[Callable[[], None], Callable[[], Awaitable[None]]] BeforeRequestCallable = Union[ Callable[[], Optional[ResponseReturnValue]], Callable[[], Awaitable[Optional[ResponseReturnValue]]], @@ -92,8 +98,8 @@ Callable[[], Awaitable[Optional[ResponseReturnValue]]], ] ErrorHandlerCallable = Union[ - Callable[[Exception], ResponseReturnValue], - Callable[[Exception], Awaitable[ResponseReturnValue]], + Callable[[Any], ResponseReturnValue], + Callable[[Any], Awaitable[ResponseReturnValue]], ] ShellContextProcessorCallable = Callable[[], Dict[str, Any]] TeardownCallable = Union[ @@ -121,227 +127,200 @@ class ASGIHTTPProtocol(Protocol): - def __init__(self, app: Quart, scope: HTTPScope) -> None: - ... + def __init__(self, app: Quart, scope: HTTPScope) -> None: ... - async def __call__(self, receive: ASGIReceiveCallable, send: ASGISendCallable) -> None: - ... + async def __call__(self, receive: ASGIReceiveCallable, send: ASGISendCallable) -> None: ... class ASGILifespanProtocol(Protocol): - def __init__(self, app: Quart, scope: LifespanScope) -> None: - ... + def __init__(self, app: Quart, scope: LifespanScope) -> None: ... - async def __call__(self, receive: ASGIReceiveCallable, send: ASGISendCallable) -> None: - ... + async def __call__(self, receive: ASGIReceiveCallable, send: ASGISendCallable) -> None: ... class ASGIWebsocketProtocol(Protocol): - def __init__(self, app: Quart, scope: WebsocketScope) -> None: - ... + def __init__(self, app: Quart, scope: WebsocketScope) -> None: ... - async def __call__(self, receive: ASGIReceiveCallable, send: ASGISendCallable) -> None: - ... + async def __call__(self, receive: ASGIReceiveCallable, send: ASGISendCallable) -> None: ... class TestHTTPConnectionProtocol(Protocol): - push_promises: List[Tuple[str, Headers]] + push_promises: list[tuple[str, Headers]] - def __init__(self, app: Quart, scope: HTTPScope, _preserve_context: bool = False) -> None: - ... + def __init__(self, app: Quart, scope: HTTPScope, _preserve_context: bool = False) -> None: ... - async def send(self, data: bytes) -> None: - ... + async def send(self, data: bytes) -> None: ... - async def send_complete(self) -> None: - ... + async def send_complete(self) -> None: ... - async def receive(self) -> bytes: - ... + async def receive(self) -> bytes: ... - async def disconnect(self) -> None: - ... + async def disconnect(self) -> None: ... - async def __aenter__(self) -> TestHTTPConnectionProtocol: - ... + async def __aenter__(self) -> TestHTTPConnectionProtocol: ... - async def __aexit__(self, exc_type: type, exc_value: BaseException, tb: TracebackType) -> None: - ... + async def __aexit__( + self, exc_type: type, exc_value: BaseException, tb: TracebackType + ) -> None: ... - async def as_response(self) -> Response: - ... + async def as_response(self) -> Response: ... class TestWebsocketConnectionProtocol(Protocol): - def __init__(self, app: Quart, scope: WebsocketScope) -> None: - ... + def __init__(self, app: Quart, scope: WebsocketScope) -> None: ... - async def __aenter__(self) -> TestWebsocketConnectionProtocol: - ... + async def __aenter__(self) -> TestWebsocketConnectionProtocol: ... - async def __aexit__(self, exc_type: type, exc_value: BaseException, tb: TracebackType) -> None: - ... + async def __aexit__( + self, exc_type: type, exc_value: BaseException, tb: TracebackType + ) -> None: ... - async def receive(self) -> AnyStr: - ... + async def receive(self) -> AnyStr: ... - async def send(self, data: AnyStr) -> None: - ... + async def send(self, data: AnyStr) -> None: ... - async def receive_json(self) -> Any: - ... + async def receive_json(self) -> Any: ... - async def send_json(self, data: Any) -> None: - ... + async def send_json(self, data: Any) -> None: ... - async def close(self, code: int) -> None: - ... + async def close(self, code: int) -> None: ... - async def disconnect(self) -> None: - ... + async def disconnect(self) -> None: ... class TestClientProtocol(Protocol): app: Quart - cookie_jar: Optional[CookieJar] - http_connection_class: Type[TestHTTPConnectionProtocol] - push_promises: List[Tuple[str, Headers]] - websocket_connection_class: Type[TestWebsocketConnectionProtocol] + cookie_jar: CookieJar | None + http_connection_class: type[TestHTTPConnectionProtocol] + push_promises: list[tuple[str, Headers]] + websocket_connection_class: type[TestWebsocketConnectionProtocol] - def __init__(self, app: Quart, use_cookies: bool = True) -> None: - ... + def __init__(self, app: Quart, use_cookies: bool = True) -> None: ... async def open( self, path: str, *, method: str = "GET", - headers: Optional[Union[dict, Headers]] = None, - data: Optional[AnyStr] = None, - form: Optional[dict] = None, - query_string: Optional[dict] = None, - json: Any = None, + headers: dict | Headers | None = None, + data: AnyStr | None = None, + form: dict | None = None, + files: dict[str, FileStorage] | None = None, + query_string: dict | None = None, + json: Any, scheme: str = "http", follow_redirects: bool = False, root_path: str = "", http_version: str = "1.1", - ) -> Response: - ... + scope_base: dict | None = None, + auth: Authorization | tuple[str, str] | None = None, + subdomain: str | None = None, + ) -> Response: ... def request( self, path: str, *, method: str = "GET", - headers: Optional[Union[dict, Headers]] = None, - query_string: Optional[dict] = None, + headers: dict | Headers | None = None, + query_string: dict | None = None, scheme: str = "http", root_path: str = "", http_version: str = "1.1", - ) -> TestHTTPConnectionProtocol: - ... + scope_base: dict | None = None, + auth: Authorization | tuple[str, str] | None = None, + subdomain: str | None = None, + ) -> TestHTTPConnectionProtocol: ... def websocket( self, path: str, *, - headers: Optional[Union[dict, Headers]] = None, - query_string: Optional[dict] = None, + headers: dict | Headers | None = None, + query_string: dict | None = None, scheme: str = "ws", - subprotocols: Optional[List[str]] = None, + subprotocols: list[str] | None = None, root_path: str = "", http_version: str = "1.1", - ) -> TestWebsocketConnectionProtocol: - ... + scope_base: dict | None = None, + auth: Authorization | tuple[str, str] | None = None, + subdomain: str | None = None, + ) -> TestWebsocketConnectionProtocol: ... - async def delete(self, *args: Any, **kwargs: Any) -> Response: - ... + async def delete(self, *args: Any, **kwargs: Any) -> Response: ... - async def get(self, *args: Any, **kwargs: Any) -> Response: - ... + async def get(self, *args: Any, **kwargs: Any) -> Response: ... - async def head(self, *args: Any, **kwargs: Any) -> Response: - ... + async def head(self, *args: Any, **kwargs: Any) -> Response: ... - async def options(self, *args: Any, **kwargs: Any) -> Response: - ... + async def options(self, *args: Any, **kwargs: Any) -> Response: ... - async def patch(self, *args: Any, **kwargs: Any) -> Response: - ... + async def patch(self, *args: Any, **kwargs: Any) -> Response: ... - async def post(self, *args: Any, **kwargs: Any) -> Response: - ... + async def post(self, *args: Any, **kwargs: Any) -> Response: ... - async def put(self, *args: Any, **kwargs: Any) -> Response: - ... + async def put(self, *args: Any, **kwargs: Any) -> Response: ... - async def trace(self, *args: Any, **kwargs: Any) -> Response: - ... + async def trace(self, *args: Any, **kwargs: Any) -> Response: ... def set_cookie( self, server_name: str, key: str, value: str = "", - max_age: Optional[Union[int, timedelta]] = None, - expires: Optional[Union[int, float, datetime]] = None, + max_age: int | timedelta | None = None, + expires: int | float | datetime | None = None, path: str = "/", - domain: Optional[str] = None, + domain: str | None = None, secure: bool = False, httponly: bool = False, samesite: str = None, charset: str = "utf-8", - ) -> None: - ... + ) -> None: ... def delete_cookie( - self, server_name: str, key: str, path: str = "/", domain: Optional[str] = None - ) -> None: - ... + self, server_name: str, key: str, path: str = "/", domain: str | None = None + ) -> None: ... def session_transaction( self, path: str = "/", *, method: str = "GET", - headers: Optional[Union[dict, Headers]] = None, - query_string: Optional[dict] = None, + headers: dict | Headers | None = None, + query_string: dict | None = None, scheme: str = "http", - data: Optional[AnyStr] = None, - form: Optional[dict] = None, + data: AnyStr | None = None, + form: dict | None = None, json: Any = None, root_path: str = "", http_version: str = "1.1", - ) -> AsyncContextManager[SessionMixin]: - ... + ) -> AsyncContextManager[SessionMixin]: ... - async def __aenter__(self) -> TestClientProtocol: - ... + async def __aenter__(self) -> TestClientProtocol: ... - async def __aexit__(self, exc_type: type, exc_value: BaseException, tb: TracebackType) -> None: - ... + async def __aexit__( + self, exc_type: type, exc_value: BaseException, tb: TracebackType + ) -> None: ... class TestAppProtocol(Protocol): - def __init__(self, app: Quart) -> None: - ... + def __init__(self, app: Quart) -> None: ... - def test_client(self) -> TestClientProtocol: - ... + def test_client(self) -> TestClientProtocol: ... - async def startup(self) -> None: - ... + async def startup(self) -> None: ... - async def shutdown(self) -> None: - ... + async def shutdown(self) -> None: ... - async def __aenter__(self) -> TestAppProtocol: - ... + async def __aenter__(self) -> TestAppProtocol: ... - async def __aexit__(self, exc_type: type, exc_value: BaseException, tb: TracebackType) -> None: - ... + async def __aexit__( + self, exc_type: type, exc_value: BaseException, tb: TracebackType + ) -> None: ... class Event(Protocol): - def is_set(self) -> bool: - ... + def is_set(self) -> bool: ... + + def set(self) -> None: ... diff --git a/src/quart/utils.py b/src/quart/utils.py index dfa709f..7e48ace 100644 --- a/src/quart/utils.py +++ b/src/quart/utils.py @@ -14,13 +14,9 @@ Awaitable, Callable, Coroutine, - Dict, Generator, Iterable, - List, - Tuple, TYPE_CHECKING, - Union, ) from werkzeug.datastructures import Headers @@ -37,7 +33,7 @@ class MustReloadError(Exception): def file_path_to_path(*paths: FilePath) -> Path: # Flask supports bytes paths - safe_paths: List[Union[str, os.PathLike]] = [] + safe_paths: list[str | os.PathLike] = [] for path in paths: if isinstance(path, bytes): safe_paths.append(path.decode()) @@ -94,45 +90,16 @@ def _inner() -> Any: return _gen_wrapper() -def is_coroutine_function(func: Any) -> bool: - # Python < 3.8 does not correctly determine partially wrapped - # coroutine functions are coroutine functions, hence the need for - # this to exist. Code taken from CPython. - if sys.version_info >= (3, 8): - return asyncio.iscoroutinefunction(func) - else: - # Note that there is something special about the AsyncMock - # such that it isn't determined as a coroutine function - # without an explicit check. - try: - from mock import AsyncMock - - if isinstance(func, AsyncMock): - return True - except ImportError: - # Not testing, no asynctest to import - pass - - while inspect.ismethod(func): - func = func.__func__ - while isinstance(func, partial): - func = func.func - if not inspect.isfunction(func): - return False - result = bool(func.__code__.co_flags & inspect.CO_COROUTINE) - return result or getattr(func, "_is_coroutine", None) is asyncio.coroutines._is_coroutine - - -def encode_headers(headers: Headers) -> List[Tuple[bytes, bytes]]: +def encode_headers(headers: Headers) -> list[tuple[bytes, bytes]]: return [(key.lower().encode(), value.encode()) for key, value in headers.items()] -def decode_headers(headers: Iterable[Tuple[bytes, bytes]]) -> Headers: +def decode_headers(headers: Iterable[tuple[bytes, bytes]]) -> Headers: return Headers([(key.decode(), value.decode()) for key, value in headers]) async def observe_changes(sleep: Callable[[float], Awaitable[Any]], shutdown_event: Event) -> None: - last_updates: Dict[Path, float] = {} + last_updates: dict[Path, float] = {} for module in list(sys.modules.values()): filename = getattr(module, "__file__", None) if filename is None: @@ -178,14 +145,14 @@ def restart() -> None: executable = str(script_path.with_suffix(".exe")) else: # python run.py - args.append(str(script_path)) + args = [str(script_path), *args] else: if script_path.is_file() and os.access(script_path, os.X_OK): # hypercorn run:app --reload executable = str(script_path) else: # python run.py - args.append(str(script_path)) + args = [str(script_path), *args] else: # Executed as a module e.g. python -m run module = script_path.stem @@ -195,3 +162,19 @@ def restart() -> None: args[:0] = ["-m", import_name.lstrip(".")] os.execv(executable, [executable] + args) + + +async def cancel_tasks(tasks: set[asyncio.Task]) -> None: + # Cancel any pending, and wait for the cancellation to + # complete i.e. finish any remaining work. + for task in tasks: + task.cancel() + await asyncio.gather(*tasks, return_exceptions=True) + raise_task_exceptions(tasks) + + +def raise_task_exceptions(tasks: set[asyncio.Task]) -> None: + # Raise any unexpected exceptions + for task in tasks: + if not task.cancelled() and task.exception() is not None: + raise task.exception() diff --git a/src/quart/views.py b/src/quart/views.py index cf3565f..9df77b0 100644 --- a/src/quart/views.py +++ b/src/quart/views.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any, Callable, ClassVar, Collection, List, Optional, Type +from typing import Any, Callable, ClassVar, Collection from .globals import current_app, request from .typing import ResponseReturnValue, RouteCallable @@ -40,9 +40,9 @@ async def dispatch_request(id): for every request. """ - decorators: ClassVar[List[Callable]] = [] - methods: ClassVar[Optional[Collection[str]]] = None - provide_automatic_options: ClassVar[Optional[bool]] = None + decorators: ClassVar[list[Callable]] = [] + methods: ClassVar[Collection[str] | None] = None + provide_automatic_options: ClassVar[bool | None] = None init_every_request: ClassVar[bool] = True async def dispatch_request(self, **kwargs: Any) -> ResponseReturnValue: @@ -65,7 +65,7 @@ async def view(**kwargs: Any) -> ResponseReturnValue: self = cls(*class_args, **class_kwargs) async def view(**kwargs: Any) -> ResponseReturnValue: - return current_app.ensure_async(self.dispatch_request)(**kwargs) + return await current_app.ensure_async(self.dispatch_request)(**kwargs) if cls.decorators: view.__name__ = name @@ -73,7 +73,7 @@ async def view(**kwargs: Any) -> ResponseReturnValue: for decorator in cls.decorators: view = decorator(view) - view.view_class: Type[View] = cls # type: ignore + view.view_class: type[View] = cls # type: ignore view.__name__ = name view.__doc__ = cls.__doc__ view.__module__ = cls.__module__ diff --git a/src/quart/wrappers/base.py b/src/quart/wrappers/base.py index dfdca88..7b7971e 100644 --- a/src/quart/wrappers/base.py +++ b/src/quart/wrappers/base.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any, Dict, List, Optional, TYPE_CHECKING +from typing import Any, TYPE_CHECKING from hypercorn.typing import WWWScope from werkzeug.datastructures import Headers @@ -27,9 +27,9 @@ class BaseRequestWebsocket(SansIORequest): """ json_module: json.provider.JSONProvider = json # type: ignore - routing_exception: Optional[Exception] = None - url_rule: Optional["QuartRule"] = None - view_args: Optional[Dict[str, Any]] = None + routing_exception: Exception | None = None + url_rule: QuartRule | None = None + view_args: dict[str, Any] | None = None def __init__( self, @@ -73,7 +73,7 @@ def __init__( self.scope = scope @property - def max_content_length(self) -> Optional[int]: + def max_content_length(self) -> int | None: """Read-only view of the ``MAX_CONTENT_LENGTH`` config key.""" if current_app: return current_app.config["MAX_CONTENT_LENGTH"] @@ -81,7 +81,7 @@ def max_content_length(self) -> Optional[int]: return None @property - def endpoint(self) -> Optional[str]: + def endpoint(self) -> str | None: """Returns the corresponding endpoint matched for this request. This can be None if the request has not been matched with a @@ -93,7 +93,7 @@ def endpoint(self) -> Optional[str]: return None @property - def blueprint(self) -> Optional[str]: + def blueprint(self) -> str | None: """Returns the blueprint the matched endpoint belongs to. This can be None if the request has not been matched or the @@ -105,7 +105,7 @@ def blueprint(self) -> Optional[str]: return None @property - def blueprints(self) -> List[str]: + def blueprints(self) -> list[str]: """Return the names of the current blueprints. The returned list is ordered from the current blueprint, upwards through parent blueprints. diff --git a/src/quart/wrappers/request.py b/src/quart/wrappers/request.py index 4bbbba5..a80fa81 100644 --- a/src/quart/wrappers/request.py +++ b/src/quart/wrappers/request.py @@ -1,21 +1,10 @@ from __future__ import annotations import asyncio -from typing import ( - Any, - AnyStr, - Awaitable, - Callable, - Dict, - Generator, - List, - NoReturn, - Optional, - overload, -) +from typing import Any, AnyStr, Awaitable, Callable, Generator, NoReturn, overload from hypercorn.typing import HTTPScope -from werkzeug.datastructures import CombinedMultiDict, Headers, MultiDict +from werkzeug.datastructures import CombinedMultiDict, Headers, iter_multi_items, MultiDict from werkzeug.exceptions import BadRequest, RequestEntityTooLarge, RequestTimeout from .base import BaseRequestWebsocket @@ -53,9 +42,7 @@ class Body: it. """ - def __init__( - self, expected_content_length: Optional[int], max_content_length: Optional[int] - ) -> None: + def __init__(self, expected_content_length: int | None, max_content_length: int | None) -> None: self._data = bytearray() self._complete: asyncio.Event = asyncio.Event() self._has_data: asyncio.Event = asyncio.Event() @@ -63,7 +50,7 @@ def __init__( # Exceptions must be raised within application (not ASGI) # calls, this is achieved by having the ASGI methods set this # to an exception on error. - self._must_raise: Optional[Exception] = None + self._must_raise: Exception | None = None if ( expected_content_length is not None and max_content_length is not None @@ -71,7 +58,7 @@ def __init__( ): self._must_raise = RequestEntityTooLarge() - def __aiter__(self) -> "Body": + def __aiter__(self) -> Body: return self async def __anext__(self) -> bytes: @@ -155,8 +142,8 @@ def __init__( http_version: str, scope: HTTPScope, *, - max_content_length: Optional[int] = None, - body_timeout: Optional[int] = None, + max_content_length: int | None = None, + body_timeout: int | None = None, send_push_promise: Callable[[str, Headers], Awaitable[None]], ) -> None: """Create a request object. @@ -185,9 +172,9 @@ def __init__( ) self.body_timeout = body_timeout self.body = self.body_class(self.content_length, max_content_length) - self._cached_json: Dict[bool, Any] = {False: Ellipsis, True: Ellipsis} - self._form: Optional[MultiDict] = None - self._files: Optional[MultiDict] = None + self._cached_json: dict[bool, Any] = {False: Ellipsis, True: Ellipsis} + self._form: MultiDict | None = None + self._files: MultiDict | None = None self._parsing_lock = self.lock_class() self._send_push_promise = send_push_promise @@ -200,18 +187,17 @@ async def data(self) -> bytes: return await self.get_data(as_text=False, parse_form_data=True) @overload - async def get_data(self, cache: bool, as_text: Literal[False], parse_form_data: bool) -> bytes: - ... + async def get_data( + self, cache: bool, as_text: Literal[False], parse_form_data: bool + ) -> bytes: ... @overload - async def get_data(self, cache: bool, as_text: Literal[True], parse_form_data: bool) -> str: - ... + async def get_data(self, cache: bool, as_text: Literal[True], parse_form_data: bool) -> str: ... @overload async def get_data( self, cache: bool = True, as_text: bool = False, parse_form_data: bool = False - ) -> AnyStr: - ... + ) -> AnyStr: ... async def get_data( self, cache: bool = True, as_text: bool = False, parse_form_data: bool = False @@ -239,7 +225,7 @@ async def get_data( self.body.clear() if as_text: - return raw_data.decode(self.charset, self.encoding_errors) + return raw_data.decode() else: return raw_data @@ -253,7 +239,7 @@ async def values(self) -> CombinedMultiDict: form = await self.form sources.append(form) - multidict_sources: List[MultiDict] = [] + multidict_sources: list[MultiDict] = [] for source in sources: if not isinstance(source, MultiDict): multidict_sources.append(MultiDict(source)) @@ -284,8 +270,6 @@ async def files(self) -> MultiDict: def make_form_data_parser(self) -> FormDataParser: return self.form_data_parser_class( - charset=self.charset, - errors=self.encoding_errors, max_content_length=self.max_content_length, cls=self.parameter_storage_class, ) @@ -362,3 +346,7 @@ async def send_push_promise(self, path: str) -> None: for value in self.headers.getlist(name): headers.add(name, value) await self._send_push_promise(path, headers) + + async def close(self) -> None: + for _key, value in iter_multi_items(self._files or ()): + value.close() diff --git a/src/quart/wrappers/response.py b/src/quart/wrappers/response.py index 75e916a..065eabf 100644 --- a/src/quart/wrappers/response.py +++ b/src/quart/wrappers/response.py @@ -13,10 +13,8 @@ AsyncIterable, AsyncIterator, Iterable, - Optional, overload, TYPE_CHECKING, - Union, ) from aiofiles import open as async_open @@ -73,19 +71,16 @@ def __init__(self, data: bytes) -> None: self.begin = 0 self.end = len(self.data) - async def __aenter__(self) -> "DataBody": + async def __aenter__(self) -> DataBody: return self async def __aexit__(self, exc_type: type, exc_value: BaseException, tb: TracebackType) -> None: pass - def __aiter__(self) -> AsyncIterator: - async def _aiter() -> AsyncGenerator[bytes, None]: - yield self.data[self.begin : self.end] - - return _aiter() + def __aiter__(self) -> AsyncIterator[bytes]: + return _DataBodyGen(self) - async def make_conditional(self, begin: int, end: Optional[int]) -> int: + async def make_conditional(self, begin: int, end: int | None) -> int: self.begin = begin self.end = len(self.data) if end is None else end self.end = min(len(self.data), self.end) @@ -93,8 +88,21 @@ async def make_conditional(self, begin: int, end: Optional[int]) -> int: return len(self.data) +class _DataBodyGen(AsyncIterator[bytes]): + def __init__(self, data_body: DataBody): + self._data_body = data_body + self._iterated = False + + async def __anext__(self) -> bytes: + if self._iterated: + raise StopAsyncIteration + + self._iterated = True + return self._data_body.data[self._data_body.begin : self._data_body.end] + + class IterableBody(ResponseBody): - def __init__(self, iterable: Union[AsyncGenerator[bytes, None], Iterable]) -> None: + def __init__(self, iterable: AsyncGenerator[bytes, None] | Iterable) -> None: self.iter: AsyncGenerator[bytes, None] if isasyncgen(iterable): self.iter = iterable @@ -108,7 +116,7 @@ async def _aiter() -> AsyncGenerator[bytes, None]: self.iter = _aiter() - async def __aenter__(self) -> "IterableBody": + async def __aenter__(self) -> IterableBody: return self async def __aexit__(self, exc_type: type, exc_value: BaseException, tb: TracebackType) -> None: @@ -132,19 +140,17 @@ class FileBody(ResponseBody): buffer_size = 8192 - def __init__( - self, file_path: Union[str, PathLike], *, buffer_size: Optional[int] = None - ) -> None: + def __init__(self, file_path: str | PathLike, *, buffer_size: int | None = None) -> None: self.file_path = file_path_to_path(file_path) self.size = self.file_path.stat().st_size self.begin = 0 self.end = self.size if buffer_size is not None: self.buffer_size = buffer_size - self.file: Optional[AsyncBufferedIOBase] = None + self.file: AsyncBufferedIOBase | None = None self.file_manager: AiofilesContextManager[None, None, AsyncBufferedReader] = None - async def __aenter__(self) -> "FileBody": + async def __aenter__(self) -> FileBody: self.file_manager = async_open(self.file_path, mode="rb") self.file = await self.file_manager.__aenter__() await self.file.seek(self.begin) @@ -153,7 +159,7 @@ async def __aenter__(self) -> "FileBody": async def __aexit__(self, exc_type: type, exc_value: BaseException, tb: TracebackType) -> None: await self.file_manager.__aexit__(exc_type, exc_value, tb) - def __aiter__(self) -> "FileBody": + def __aiter__(self) -> FileBody: return self async def __anext__(self) -> bytes: @@ -168,7 +174,7 @@ async def __anext__(self) -> bytes: else: raise StopAsyncIteration() - async def make_conditional(self, begin: int, end: Optional[int]) -> int: + async def make_conditional(self, begin: int, end: int | None) -> int: self.begin = begin self.end = self.size if end is None else end self.end = min(self.size, self.end) @@ -190,7 +196,7 @@ class IOBody(ResponseBody): buffer_size = 8192 - def __init__(self, io_stream: BytesIO, *, buffer_size: Optional[int] = None) -> None: + def __init__(self, io_stream: BytesIO, *, buffer_size: int | None = None) -> None: self.io_stream = io_stream self.size = io_stream.getbuffer().nbytes self.begin = 0 @@ -198,14 +204,14 @@ def __init__(self, io_stream: BytesIO, *, buffer_size: Optional[int] = None) -> if buffer_size is not None: self.buffer_size = buffer_size - async def __aenter__(self) -> "IOBody": + async def __aenter__(self) -> IOBody: self.io_stream.seek(self.begin) return self async def __aexit__(self, exc_type: type, exc_value: BaseException, tb: TracebackType) -> None: return None - def __aiter__(self) -> "IOBody": + def __aiter__(self) -> IOBody: return self async def __anext__(self) -> bytes: @@ -220,7 +226,7 @@ async def __anext__(self) -> bytes: else: raise StopAsyncIteration() - async def make_conditional(self, begin: int, end: Optional[int]) -> int: + async def make_conditional(self, begin: int, end: int | None) -> int: self.begin = begin self.end = self.size if end is None else end self.end = min(self.size, self.end) @@ -256,11 +262,11 @@ class Response(SansIOResponse): def __init__( self, - response: Union[ResponseBody, AnyStr, Iterable, None] = None, - status: Optional[int] = None, - headers: Optional[Union[dict, Headers]] = None, - mimetype: Optional[str] = None, - content_type: Optional[str] = None, + response: ResponseBody | AnyStr | Iterable | None = None, + status: int | None = None, + headers: dict | Headers | None = None, + mimetype: str | None = None, + content_type: str | None = None, ) -> None: """Create a response object. @@ -302,16 +308,13 @@ def max_cookie_size(self) -> int: # type: ignore return super().max_cookie_size @overload - async def get_data(self, as_text: Literal[True]) -> str: - ... + async def get_data(self, as_text: Literal[True]) -> str: ... @overload - async def get_data(self, as_text: Literal[False]) -> bytes: - ... + async def get_data(self, as_text: Literal[False]) -> bytes: ... @overload - async def get_data(self, as_text: bool = True) -> AnyStr: - ... + async def get_data(self, as_text: bool = True) -> AnyStr: ... async def get_data(self, as_text: bool = False) -> AnyStr: """Return the body data.""" @@ -321,7 +324,7 @@ async def get_data(self, as_text: bool = False) -> AnyStr: async with self.response as body: async for data in body: if as_text: - result += data.decode(self.charset) + result += data.decode() else: result += data return result # type: ignore @@ -332,7 +335,7 @@ def set_data(self, data: AnyStr) -> None: This will encode using the :attr:`charset`. """ if isinstance(data, str): - bytes_data = data.encode(self.charset) + bytes_data = data.encode() else: bytes_data = data self.response = self.data_body_class(bytes_data) @@ -371,7 +374,7 @@ async def get_json(self, force: bool = False, silent: bool = False) -> Any: raise return None - def _is_range_request_processable(self, request: "Request") -> bool: + def _is_range_request_processable(self, request: Request) -> bool: return ( "If-Range" not in request.headers or not is_resource_modified( @@ -389,9 +392,9 @@ def _is_range_request_processable(self, request: "Request") -> bool: async def _process_range_request( self, - request: "Request", - complete_length: Optional[int] = None, - accept_ranges: Optional[str] = None, + request: Request, + complete_length: int | None = None, + accept_ranges: str | None = None, ) -> bool: if ( accept_ranges is None @@ -421,7 +424,7 @@ async def _process_range_request( self.content_range = ContentRange( request_range.units, self.response.begin, # type: ignore - self.response.end - 1, # type: ignore + self.response.end, # type: ignore complete_length, ) self.status_code = 206 @@ -430,10 +433,10 @@ async def _process_range_request( async def make_conditional( self, - request: "Request", - accept_ranges: Union[bool, str] = False, - complete_length: Optional[int] = None, - ) -> "Response": + request: Request, + accept_ranges: bool | str = False, + complete_length: int | None = None, + ) -> Response: if request.method in {"GET", "HEAD"}: accept_ranges = _clean_accept_ranges(accept_ranges) is206 = await self._process_range_request(request, complete_length, accept_ranges) @@ -452,6 +455,8 @@ async def make_conditional( self.status_code = 412 else: self.status_code = 304 + self.set_data(b"") + del self.content_length return self @@ -463,17 +468,17 @@ async def iter_encode(self) -> AsyncGenerator[bytes, None]: async with self.response as response_body: async for item in response_body: if isinstance(item, str): - yield item.encode(self.charset) + yield item.encode() else: yield item async def freeze(self) -> None: """Freeze this object ready for pickling.""" - self.set_data((await self.get_data())) + self.set_data(await self.get_data()) async def add_etag(self, overwrite: bool = False, weak: bool = False) -> None: if overwrite or "etag" not in self.headers: - self.set_etag(md5((await self.get_data(as_text=False))).hexdigest(), weak) + self.set_etag(md5(await self.get_data(as_text=False)).hexdigest(), weak) def _set_or_pop_header(self, key: str, value: str) -> None: if value == "": @@ -482,7 +487,7 @@ def _set_or_pop_header(self, key: str, value: str) -> None: self.headers[key] = value -def _clean_accept_ranges(accept_ranges: Union[bool, str]) -> str: +def _clean_accept_ranges(accept_ranges: bool | str) -> str: if accept_ranges is True: return "bytes" elif accept_ranges is False: diff --git a/src/quart/wrappers/websocket.py b/src/quart/wrappers/websocket.py index 837e08b..ff8d72a 100644 --- a/src/quart/wrappers/websocket.py +++ b/src/quart/wrappers/websocket.py @@ -1,7 +1,7 @@ from __future__ import annotations import asyncio -from typing import Any, AnyStr, Callable, List, Optional, Union +from typing import Any, AnyStr, Callable from hypercorn.typing import WebsocketScope from werkzeug.datastructures import Headers @@ -18,7 +18,7 @@ def __init__( headers: Headers, root_path: str, http_version: str, - subprotocols: List[str], + subprotocols: list[str], receive: Callable, send: Callable, accept: Callable, @@ -48,7 +48,7 @@ def __init__( self._subprotocols = subprotocols @property - def requested_subprotocols(self) -> List[str]: + def requested_subprotocols(self) -> list[str]: return self._subprotocols async def receive(self) -> AnyStr: @@ -79,7 +79,7 @@ async def send_json(self, *args: Any, **kwargs: Any) -> None: await self.send(raw) async def accept( - self, headers: Optional[Union[dict, Headers]] = None, subprotocol: Optional[str] = None + self, headers: dict | Headers | None = None, subprotocol: str | None = None ) -> None: """Manually chose to accept the websocket connection. diff --git a/tests/assets/config.json b/tests/assets/config.json deleted file mode 100644 index a249fdd..0000000 --- a/tests/assets/config.json +++ /dev/null @@ -1,4 +0,0 @@ -{ - "FOO": "bar", - "BOB": "jeff" -} diff --git a/tests/assets/config.py b/tests/assets/config.py deleted file mode 100644 index 036e60a..0000000 --- a/tests/assets/config.py +++ /dev/null @@ -1,4 +0,0 @@ -from __future__ import annotations - -FOO = "bar" -BOB = "jeff" diff --git a/tests/assets/config.toml b/tests/assets/config.toml deleted file mode 100644 index 70ec155..0000000 --- a/tests/assets/config.toml +++ /dev/null @@ -1,2 +0,0 @@ -BOB = "jeff" -FOO = "bar" diff --git a/tests/test_app.py b/tests/test_app.py index 2f31ff1..02f012d 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -1,10 +1,10 @@ from __future__ import annotations import asyncio -from typing import AsyncGenerator, NoReturn, Optional, Set, Union +from typing import AsyncGenerator, NoReturn +from unittest.mock import AsyncMock import pytest -from _pytest.monkeypatch import MonkeyPatch from hypercorn.typing import HTTPScope, WebsocketScope from werkzeug.datastructures import Headers from werkzeug.exceptions import InternalServerError @@ -14,17 +14,11 @@ from quart.globals import session, websocket from quart.sessions import SecureCookieSession, SessionInterface from quart.testing import no_op_push, WebsocketResponseError -from quart.typing import ResponseReturnValue +from quart.typing import ResponseReturnValue, ResponseTypes from quart.wrappers import Request, Response TEST_RESPONSE = Response("") -try: - from unittest.mock import AsyncMock -except ImportError: - # Python < 3.8 - from mock import AsyncMock # type: ignore - class SimpleError(Exception): pass @@ -68,7 +62,7 @@ async def route3() -> str: ], ) def test_add_url_rule_methods( - methods: Set[str], required_methods: Set[str], automatic_options: bool + methods: set[str], required_methods: set[str], automatic_options: bool ) -> None: app = Quart(__name__) @@ -83,8 +77,6 @@ def route() -> str: "/", "end", route, methods=non_func_methods, provide_automatic_options=automatic_options ) result = {"PATCH"} if not methods else set() - if automatic_options: - result.add("OPTIONS") result.update(methods) result.update(required_methods) if "GET" in result: @@ -95,19 +87,19 @@ def route() -> str: @pytest.mark.parametrize( "methods, arg_automatic, func_automatic, expected_methods, expected_automatic", [ - ({"GET"}, True, None, {"HEAD", "GET", "OPTIONS"}, True), + ({"GET"}, True, None, {"HEAD", "GET"}, True), ({"GET"}, None, None, {"HEAD", "GET", "OPTIONS"}, True), - ({"GET"}, None, True, {"HEAD", "GET", "OPTIONS"}, True), + ({"GET"}, None, True, {"HEAD", "GET"}, True), ({"GET", "OPTIONS"}, None, None, {"HEAD", "GET", "OPTIONS"}, False), ({"GET"}, False, True, {"HEAD", "GET"}, False), ({"GET"}, None, False, {"HEAD", "GET"}, False), ], ) def test_add_url_rule_automatic_options( - methods: Set[str], - arg_automatic: Optional[bool], - func_automatic: Optional[bool], - expected_methods: Set[str], + methods: set[str], + arg_automatic: bool | None, + func_automatic: bool | None, + expected_methods: set[str], expected_automatic: bool, ) -> None: app = Quart(__name__) @@ -172,12 +164,12 @@ async def route(subdomain: str) -> str: False, ), (InternalServerError(), InternalServerError().get_response(), False), - ((val for val in "abcd"), Response((val for val in "abcd")), False), + ((val for val in "abcd"), Response(val for val in "abcd"), False), (int, None, True), ], ) async def test_make_response( - result: ResponseReturnValue, expected: Union[Response, WerkzeugResponse], raises: bool + result: ResponseReturnValue, expected: Response | WerkzeugResponse, raises: bool ) -> None: app = Quart(__name__) app.config["RESPONSE_TIMEOUT"] = None @@ -195,36 +187,6 @@ async def test_make_response( assert response.get_data() == expected.get_data() -@pytest.mark.parametrize( - "quart_env, quart_debug, expected_env, expected_debug", - [ - (None, None, "production", False), - ("development", None, "development", True), - ("development", False, "development", False), - ], -) -def test_env_and_debug_environments( - quart_env: Optional[str], - quart_debug: Optional[bool], - expected_env: bool, - expected_debug: bool, - monkeypatch: MonkeyPatch, -) -> None: - if quart_env is None: - monkeypatch.delenv("QUART_ENV", raising=False) - else: - monkeypatch.setenv("QUART_ENV", quart_env) - - if quart_debug is None: - monkeypatch.delenv("QUART_DEBUG", raising=False) - else: - monkeypatch.setenv("QUART_DEBUG", str(quart_debug)) - - app = Quart(__name__) - assert app.env == expected_env - assert app.debug is expected_debug - - @pytest.fixture(name="basic_app") def _basic_app() -> Quart: app = Quart(__name__) @@ -258,7 +220,7 @@ def before() -> None: async def test_app_after_request_exception(basic_app: Quart) -> None: @basic_app.after_request - def after(_: Response) -> None: + def after(_: ResponseTypes) -> None: raise Exception() test_client = basic_app.test_client() @@ -268,7 +230,7 @@ def after(_: Response) -> None: async def test_app_after_request_handler_exception(basic_app: Quart) -> None: @basic_app.after_request - def after(_: Response) -> None: + def after(_: ResponseTypes) -> None: raise Exception() test_client = basic_app.test_client() diff --git a/tests/test_asgi.py b/tests/test_asgi.py index f1f34e9..13f0da4 100644 --- a/tests/test_asgi.py +++ b/tests/test_asgi.py @@ -1,22 +1,21 @@ from __future__ import annotations import asyncio -from typing import Optional +from unittest.mock import AsyncMock, Mock import pytest from hypercorn.typing import ASGIReceiveEvent, ASGISendEvent, HTTPScope, WebsocketScope from werkzeug.datastructures import Headers from quart import Quart -from quart.asgi import _convert_version, ASGIHTTPConnection, ASGIWebsocketConnection +from quart.asgi import ( + _convert_version, + _handle_exception, + ASGIHTTPConnection, + ASGIWebsocketConnection, +) from quart.utils import encode_headers -try: - from unittest.mock import AsyncMock -except ImportError: - # Python < 3.8 - from mock import AsyncMock # type: ignore - @pytest.mark.parametrize("headers, expected", [([(b"host", b"quart")], "quart"), ([], "")]) async def test_http_1_0_host_header(headers: list, expected: str) -> None: @@ -175,6 +174,32 @@ def test_http_path_from_absolute_target() -> None: assert request.path == "/path" +@pytest.mark.parametrize( + "path, expected", + [("/app", "/ "), ("/", "/ "), ("/app/", "/"), ("/app/2", "/2")], +) +def test_http_path_with_root_path(path: str, expected: str) -> None: + app = Quart(__name__) + scope: HTTPScope = { + "type": "http", + "asgi": {}, + "http_version": "1.1", + "method": "GET", + "scheme": "https", + "path": path, + "raw_path": b"/", + "query_string": b"", + "root_path": "/app", + "headers": [(b"host", b"quart")], + "client": ("127.0.0.1", 80), + "server": None, + "extensions": {}, + } + connection = ASGIHTTPConnection(app, scope) + request = connection._create_request_from_scope(lambda: None) # type: ignore + assert request.path == expected + + def test_websocket_path_from_absolute_target() -> None: app = Quart(__name__) scope: WebsocketScope = { @@ -197,6 +222,32 @@ def test_websocket_path_from_absolute_target() -> None: assert websocket.path == "/path" +@pytest.mark.parametrize( + "path, expected", + [("/app", "/ "), ("/", "/ "), ("/app/", "/"), ("/app/2", "/2")], +) +def test_websocket_path_with_root_path(path: str, expected: str) -> None: + app = Quart(__name__) + scope: WebsocketScope = { + "type": "websocket", + "asgi": {}, + "http_version": "1.1", + "scheme": "wss", + "path": path, + "raw_path": b"/", + "query_string": b"", + "root_path": "/app", + "headers": [(b"host", b"quart")], + "client": ("127.0.0.1", 80), + "server": None, + "subprotocols": [], + "extensions": {"websocket.http.response": {}}, + } + connection = ASGIWebsocketConnection(app, scope) + websocket = connection._create_websocket_from_scope(lambda: None) # type: ignore + assert websocket.path == expected + + @pytest.mark.parametrize( "scope, headers, subprotocol, has_headers", [ @@ -207,7 +258,7 @@ def test_websocket_path_from_absolute_target() -> None: ], ) async def test_websocket_accept_connection( - scope: dict, headers: Headers, subprotocol: Optional[str], has_headers: bool + scope: dict, headers: Headers, subprotocol: str | None, has_headers: bool ) -> None: connection = ASGIWebsocketConnection(Quart(__name__), scope) # type: ignore mock_send = AsyncMock() @@ -233,7 +284,7 @@ async def test_websocket_accept_connection_warns(websocket_scope: WebsocketScope async def mock_send(message: ASGISendEvent) -> None: pass - with pytest.warns(None): + with pytest.warns(UserWarning): await connection.accept_connection(mock_send, Headers({"a": "b"}), None) @@ -255,3 +306,25 @@ def test_http_asgi_scope_from_request() -> None: connection = ASGIHTTPConnection(app, scope) # type: ignore request = connection._create_request_from_scope(lambda: None) # type: ignore assert request.scope["test_result"] == "PASSED" # type: ignore + + +@pytest.mark.parametrize( + "propagate_exceptions, testing, raises", + [ + (True, False, False), + (True, True, True), + (False, True, True), + (False, False, True), + ], +) +async def test__handle_exception(propagate_exceptions: bool, testing: bool, raises: bool) -> None: + app = Mock() + app.config = {} + app.config["PROPAGATE_EXCEPTIONS"] = propagate_exceptions + app.testing = testing + + if raises: + with pytest.raises(ValueError): + await _handle_exception(app, ValueError()) + else: + await _handle_exception(app, ValueError()) diff --git a/tests/test_basic.py b/tests/test_basic.py index b96203a..339e16e 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -44,9 +44,9 @@ async def error() -> ResponseReturnValue: abort(409) return "OK" - @app.route("/param/") - async def param() -> ResponseReturnValue: - return param + @app.route("/param/") + async def param(value: str) -> ResponseReturnValue: + return value @app.route("/stream") async def stream() -> ResponseReturnValue: @@ -54,7 +54,7 @@ async def _gen() -> AsyncGenerator[str, None]: yield "Hello " yield "World" - return _gen() + return _gen() # type: ignore @app.errorhandler(409) async def generic_http_handler(_: Exception) -> ResponseReturnValue: @@ -107,21 +107,21 @@ async def test_json(app: Quart) -> None: test_client = app.test_client() response = await test_client.post("/json/", json={"value": "json"}) assert response.status_code == 200 - assert b'{"value":"json"}' == (await response.get_data()) # type: ignore + assert b'{"value":"json"}\n' == (await response.get_data()) # type: ignore async def test_implicit_json(app: Quart) -> None: test_client = app.test_client() response = await test_client.post("/implicit_json/", json={"value": "json"}) assert response.status_code == 200 - assert b'{"value":"json"}' == (await response.get_data()) # type: ignore + assert b'{"value":"json"}\n' == (await response.get_data()) # type: ignore async def test_implicit_json_list(app: Quart) -> None: test_client = app.test_client() response = await test_client.post("/implicit_json/", json=["a", 2]) assert response.status_code == 200 - assert b'["a",2]' == (await response.get_data()) # type: ignore + assert b'["a",2]\n' == (await response.get_data()) # type: ignore async def test_werkzeug(app: Quart) -> None: @@ -141,7 +141,7 @@ async def test_generic_error(app: Quart) -> None: async def test_url_defaults(app: Quart) -> None: @app.url_defaults def defaults(_: str, values: dict) -> None: - values["param"] = "hello" + values["value"] = "hello" async with app.test_request_context("/"): assert url_for("param") == "/param/hello" @@ -159,6 +159,7 @@ async def test_make_response_str(app: Quart) -> None: assert response.status_code == 200 assert (await response.get_data()) == b"Result" # type: ignore + response = await app.make_response(("Result", 200)) response = await app.make_response(("Result", {"name": "value"})) assert response.status_code == 200 assert (await response.get_data()) == b"Result" # type: ignore @@ -186,6 +187,15 @@ async def test_make_response_response(app: Quart) -> None: assert response.headers["name"] == "value" +async def test_make_response_errors(app: Quart) -> None: + with pytest.raises(TypeError): + await app.make_response(("Result", {"name": "value"}, 200)) # type: ignore + with pytest.raises(TypeError): + await app.make_response(("Result", {"name": "value"}, 200, "a")) # type: ignore + with pytest.raises(TypeError): + await app.make_response(("Result",)) # type: ignore + + async def test_websocket(app: Quart) -> None: test_client = app.test_client() data = b"bob" diff --git a/tests/test_blueprints.py b/tests/test_blueprints.py index f69dc0b..759f118 100644 --- a/tests/test_blueprints.py +++ b/tests/test_blueprints.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import cast, List, Optional +from typing import cast import click import pytest @@ -73,6 +73,22 @@ async def route() -> ResponseReturnValue: assert request.blueprint == "blueprint" +async def test_empty_path_with_url_prefix() -> None: + app = Quart(__name__) + prefix = Blueprint("prefix", __name__, url_prefix="/prefix") + + @prefix.route("") + async def empty_path_route() -> ResponseReturnValue: + return "OK" + + app.register_blueprint(prefix) + + test_client = app.test_client() + response = await test_client.get("/prefix") + assert response.status_code == 200 + assert await response.get_data() == b"OK" # type: ignore + + async def test_blueprint_template_filter() -> None: app = Quart(__name__) blueprint = Blueprint("blueprint", __name__) @@ -141,7 +157,7 @@ async def post(self) -> ResponseReturnValue: (Ellipsis, ["blueprint", "cmd"]), ], ) -def test_cli_blueprints(cli_group: Optional[str], args: List[str]) -> None: +def test_cli_blueprints(cli_group: str | None, args: list[str]) -> None: app = Quart(__name__) blueprint = Blueprint("blueprint", __name__, cli_group=cli_group) @@ -168,10 +184,10 @@ def command() -> None: ], ) async def test_nesting_url_prefixes( - parent_init: Optional[str], - child_init: Optional[str], - parent_registration: Optional[str], - child_registration: Optional[str], + parent_init: str | None, + child_init: str | None, + parent_registration: str | None, + child_registration: str | None, ) -> None: app = Quart(__name__) @@ -190,6 +206,39 @@ def index() -> ResponseReturnValue: assert response.status_code == 200 +@pytest.mark.parametrize( + "parent_subdomain, child_subdomain, expected_subdomain", + [ + (None, None, None), + ("parent", None, "parent"), + (None, "child", "child"), + ("parent", "child", "child.parent"), + ], +) +async def test_nesting_subdomains( + parent_subdomain: str | None, + child_subdomain: str | None, + expected_subdomain: str | None, +) -> None: + app = Quart(__name__) + domain_name = "domain.tld" + app.config["SERVER_NAME"] = domain_name + + parent = Blueprint("parent", __name__, subdomain=parent_subdomain) + child = Blueprint("child", __name__, subdomain=child_subdomain) + + @child.route("/") + def index() -> ResponseReturnValue: + return "index" + + parent.register_blueprint(child) + app.register_blueprint(parent) + + test_client = app.test_client() + response = await test_client.get("/", subdomain=expected_subdomain) + assert response.status_code == 200 + + async def test_nesting_and_sibling() -> None: app = Quart(__name__) @@ -217,9 +266,13 @@ def test_unique_blueprint_names() -> None: bp2 = Blueprint("bp", __name__) app.register_blueprint(bp) - app.register_blueprint(bp) # Should not error + + with pytest.raises(ValueError): + app.register_blueprint(bp) + with pytest.raises(ValueError): app.register_blueprint(bp2, url_prefix="/a") + app.register_blueprint(bp, name="alt") @@ -339,7 +392,7 @@ async def app_before1() -> None: g.setdefault("seen", []).append("app_1") @app.teardown_request - async def app_teardown1(exc: Optional[BaseException] = None) -> None: + async def app_teardown1(exc: BaseException | None = None) -> None: assert g.seen.pop() == "app_1" @app.before_request @@ -347,7 +400,7 @@ async def app_before2() -> None: g.setdefault("seen", []).append("app_2") @app.teardown_request - async def app_teardown2(exc: Optional[BaseException] = None) -> None: + async def app_teardown2(exc: BaseException | None = None) -> None: assert g.seen.pop() == "app_2" @app.context_processor @@ -359,7 +412,7 @@ async def parent_before1() -> None: g.setdefault("seen", []).append("parent_1") @parent.teardown_request - async def parent_teardown1(exc: Optional[BaseException] = None) -> None: + async def parent_teardown1(exc: BaseException | None = None) -> None: assert g.seen.pop() == "parent_1" @parent.before_request @@ -367,7 +420,7 @@ async def parent_before2() -> None: g.setdefault("seen", []).append("parent_2") @parent.teardown_request - async def parent_teardown2(exc: Optional[BaseException] = None) -> None: + async def parent_teardown2(exc: BaseException | None = None) -> None: assert g.seen.pop() == "parent_2" @parent.context_processor @@ -379,7 +432,7 @@ async def child_before1() -> None: g.setdefault("seen", []).append("child_1") @child.teardown_request - async def child_teardown1(exc: Optional[BaseException] = None) -> None: + async def child_teardown1(exc: BaseException | None = None) -> None: assert g.seen.pop() == "child_1" @child.before_request @@ -387,7 +440,7 @@ async def child_before2() -> None: g.setdefault("seen", []).append("child_2") @child.teardown_request - async def child_teardown2(exc: Optional[BaseException] = None) -> None: + async def child_teardown2(exc: BaseException | None = None) -> None: assert g.seen.pop() == "child_2" @child.context_processor diff --git a/tests/test_cli.py b/tests/test_cli.py index 9dd17da..5fe0e7a 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -17,7 +17,6 @@ @pytest.fixture(scope="module") def reset_env() -> None: - os.environ.pop("QUART_ENV", None) os.environ.pop("QUART_DEBUG", None) @@ -33,16 +32,10 @@ def loadable_app(monkeypatch: MonkeyPatch) -> Mock: @pytest.fixture(name="dev_app") def loadable_dev_app(app: Mock) -> Mock: - app.env == "development" app.debug = True return app -@pytest.fixture(name="dev_env") -def dev_env_patch(monkeypatch: MonkeyPatch) -> None: - monkeypatch.setenv("QUART_ENV", "development") - - @pytest.fixture(name="debug_env") def debug_env_patch(monkeypatch: MonkeyPatch) -> None: monkeypatch.setenv("QUART_DEBUG", "true") @@ -84,17 +77,7 @@ def test_run_command(app: Mock) -> None: ) -def test_run_command_development(dev_app: Mock, dev_env: None) -> None: - runner = CliRunner() - runner.invoke(cli, ["--app", "module:app", "run"]) - dev_app.run.assert_called_once_with( - debug=True, host="127.0.0.1", port=5000, certfile=None, keyfile=None, use_reloader=True - ) - - -def test_run_command_development_debug_disabled( - dev_app: Mock, dev_env: None, no_debug_env: None -) -> None: +def test_run_command_development_debug_disabled(dev_app: Mock, no_debug_env: None) -> None: runner = CliRunner() runner.invoke(cli, ["--app", "module:app", "run"]) dev_app.run.assert_called_once_with( diff --git a/tests/test_config.py b/tests/test_config.py deleted file mode 100644 index c741ef1..0000000 --- a/tests/test_config.py +++ /dev/null @@ -1,167 +0,0 @@ -from __future__ import annotations - -import json -import os -from pathlib import Path - -import pytest -import toml -from _pytest.monkeypatch import MonkeyPatch - -from quart.config import Config, ConfigAttribute - -FOO = "bar" -BOB = "jeff" - - -class ConfigInstance: - value = ConfigAttribute("VALUE") - config: dict = {} - - -def test_config_attribute() -> None: - instance = ConfigInstance() - instance.value = "test" - assert instance.config["VALUE"] == "test" - - -def _check_standard_config(config: Config) -> None: - assert config.pop("FOO") == "bar" - assert config.pop("BOB") == "jeff" - assert len(config) == 0 - - -def test_config_from_object() -> None: - config = Config(Path(__file__).parent) - config.from_object(__name__) - _check_standard_config(config) - - -def test_from_prefixed_env(monkeypatch: MonkeyPatch) -> None: - monkeypatch.setenv("QUART_STRING", "value") - monkeypatch.setenv("QUART_BOOL", "true") - monkeypatch.setenv("QUART_INT", "1") - monkeypatch.setenv("QUART_FLOAT", "1.2") - monkeypatch.setenv("QUART_LIST", "[1, 2]") - monkeypatch.setenv("QUART_DICT", '{"k": "v"}') - monkeypatch.setenv("NOT_QUART_OTHER", "other") - - config = Config(Path(__file__).parent) - config.from_prefixed_env() - - assert config["STRING"] == "value" - assert config["BOOL"] is True - assert config["INT"] == 1 - assert config["FLOAT"] == 1.2 - assert config["LIST"] == [1, 2] - assert config["DICT"] == {"k": "v"} - assert "OTHER" not in config - - -def test_from_prefixed_env_custom_prefix(monkeypatch: MonkeyPatch) -> None: - monkeypatch.setenv("QUART_A", "a") - monkeypatch.setenv("NOT_QUART_A", "b") - - config = Config(Path(__file__).parent) - config.from_prefixed_env("NOT_QUART") - - assert config["A"] == "b" - - -def test_from_prefixed_env_nested(monkeypatch: MonkeyPatch) -> None: - monkeypatch.setenv("QUART_EXIST__ok", "other") - monkeypatch.setenv("QUART_EXIST__inner__ik", "2") - monkeypatch.setenv("QUART_EXIST__new__more", '{"k": false}') - monkeypatch.setenv("QUART_NEW__K", "v") - - config = Config(Path(__file__).parent) - config["EXIST"] = {"ok": "value", "flag": True, "inner": {"ik": 1}} - config.from_prefixed_env() - - if os.name != "nt": - assert config["EXIST"] == { - "ok": "other", - "flag": True, - "inner": {"ik": 2}, - "new": {"more": {"k": False}}, - } - else: - # Windows env var keys are always uppercase. - assert config["EXIST"] == { - "ok": "value", - "OK": "other", - "flag": True, - "inner": {"ik": 1}, - "INNER": {"IK": 2}, - "NEW": {"MORE": {"k": False}}, - } - - assert config["NEW"] == {"K": "v"} - - -def test_config_from_pyfile_this() -> None: - config = Config(Path(__file__).parent) - config.from_pyfile(__file__) - _check_standard_config(config) - - -def test_config_from_pyfile_py() -> None: - config = Config(Path(__file__).parent) - config.from_pyfile("assets/config.py") - _check_standard_config(config) - - -def test_config_from_pyfile_cfg() -> None: - config = Config(Path(__file__).parent) - config.from_pyfile("assets/config.cfg") - _check_standard_config(config) - - -def test_config_from_pyfile_no_file() -> None: - config = Config(Path(__file__).parent) - with pytest.raises(FileNotFoundError): - config.from_pyfile("assets/no_file.cfg") - - -def test_config_from_pyfile_directory() -> None: - config = Config(Path(__file__).parent) - with pytest.raises(PermissionError if os.name == "nt" else IsADirectoryError): - config.from_pyfile("assets") - - -def test_config_from_envvar() -> None: - config = Config(Path(__file__).parent) - os.environ["CONFIG"] = "assets/config.cfg" - config.from_envvar("CONFIG") - _check_standard_config(config) - - -def test_config_from_envvar_not_set_with_silent() -> None: - config = Config(Path(__file__).parent) - config.from_envvar("UNKNOWN_CONFIG", silent=True) - - -def test_config_from_envvar_not_set_without_silent() -> None: - config = Config(Path(__file__).parent) - with pytest.raises(RuntimeError): - config.from_envvar("UNKNOWN_CONFIG") - - -def test_config_from_json() -> None: - config = Config(Path(__file__).parent) - config.from_file("assets/config.json", json.load) - _check_standard_config(config) - - -def test_config_from_toml() -> None: - config = Config(Path(__file__).parent) - config.from_file("assets/config.toml", toml.load) - _check_standard_config(config) - - -def test_config_get_namespace() -> None: - config = Config(Path(__file__).parent) - config["FOO_A"] = "a" - config["FOO_BAR"] = "bar" - config["BAR"] = "bar" - assert config.get_namespace("FOO_") == {"a": "a", "bar": "bar"} diff --git a/tests/test_ctx.py b/tests/test_ctx.py index 5d5d9a0..0bccabf 100644 --- a/tests/test_ctx.py +++ b/tests/test_ctx.py @@ -11,7 +11,6 @@ from quart.app import Quart from quart.ctx import ( - _AppCtxGlobals, after_this_request, AppContext, copy_current_app_context, @@ -117,42 +116,6 @@ async def test_has_app_context() -> None: assert has_app_context() is False -def test_app_ctx_globals_get() -> None: - g = _AppCtxGlobals() - g.foo = "bar" - assert g.get("foo") == "bar" - assert g.get("bar", "something") == "something" - - -def test_app_ctx_globals_pop() -> None: - g = _AppCtxGlobals() - g.foo = "bar" - assert g.pop("foo") == "bar" - assert g.pop("foo", None) is None - with pytest.raises(KeyError): - g.pop("foo") - - -def test_app_ctx_globals_setdefault() -> None: - g = _AppCtxGlobals() - g.setdefault("foo", []).append("bar") - assert g.foo == ["bar"] - - -def test_app_ctx_globals_contains() -> None: - g = _AppCtxGlobals() - g.foo = "bar" - assert "foo" in g - assert "bar" not in g - - -def test_app_ctx_globals_iter() -> None: - g = _AppCtxGlobals() - g.foo = "bar" - g.bar = "foo" - assert sorted(iter(g)) == ["bar", "foo"] - - async def test_copy_current_app_context() -> None: app = Quart(__name__) diff --git a/tests/test_debug.py b/tests/test_debug.py index 86a2a11..5cba6a1 100644 --- a/tests/test_debug.py +++ b/tests/test_debug.py @@ -7,10 +7,7 @@ async def test_debug() -> None: app = Quart(__name__) async with app.test_request_context("/"): - try: - raise Exception("Unique error") - except Exception: - response = await traceback_response() + response = await traceback_response(Exception("Unique error")) assert response.status_code == 500 assert b"Unique error" in (await response.get_data()) # type: ignore diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index 9a0ef13..4a7eda4 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -1,7 +1,6 @@ from __future__ import annotations from http import HTTPStatus -from typing import Union import pytest from werkzeug.exceptions import abort, HTTPException @@ -10,7 +9,7 @@ @pytest.mark.parametrize("status", [400, HTTPStatus.BAD_REQUEST]) -def test_abort(status: Union[int, HTTPStatus]) -> None: +def test_abort(status: int | HTTPStatus) -> None: with pytest.raises(HTTPException) as exc_info: abort(status) assert exc_info.value.get_response().status_code == 400 diff --git a/tests/test_helpers.py b/tests/test_helpers.py index 7065bc9..d286b01 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -6,7 +6,6 @@ from typing import AsyncGenerator import pytest -from py._path.local import LocalPath from werkzeug.exceptions import NotFound from quart import Blueprint, Quart, request @@ -130,8 +129,8 @@ async def test_url_for_scheme(app: Quart) -> None: async def test_url_for_anchor(app: Quart) -> None: async with app.test_request_context("/"): - assert url_for("index", _anchor="&foo") == "/#%26foo" - assert url_for("resource", id=5, _anchor="&foo") == "/resource/5#%26foo" + assert url_for("index", _anchor="&foo") == "/#&foo" + assert url_for("resource", id=5, _anchor="&foo") == "/resource/5#&foo" async def test_url_for_blueprint_relative(app: Quart) -> None: @@ -158,7 +157,7 @@ async def test_url_for_root_path(app: Quart) -> None: async def test_stream_with_context() -> None: app = Quart(__name__) - @app.route("/") + @app.route("/") # type: ignore async def index() -> AsyncGenerator[bytes, None]: @stream_with_context async def generator() -> AsyncGenerator[bytes, None]: @@ -179,13 +178,13 @@ async def test_send_from_directory_raises() -> None: await send_from_directory(str(ROOT_PATH), "no_file.no") -async def test_send_file_path(tmpdir: LocalPath) -> None: +async def test_send_file_path(tmp_path: Path) -> None: app = Quart(__name__) - file_ = tmpdir.join("send.img") - file_.write("something") + file_ = tmp_path / "send.img" + file_.write_text("something") async with app.app_context(): - response = await send_file(Path(file_.realpath())) - assert (await response.get_data(as_text=False)) == file_.read_binary() + response = await send_file(Path(file_)) + assert (await response.get_data(as_text=False)) == file_.read_bytes() async def test_send_file_bytes_io() -> None: @@ -203,61 +202,59 @@ async def test_send_file_no_mimetype() -> None: await send_file(BytesIO(b"something")) -async def test_send_file_as_attachment(tmpdir: LocalPath) -> None: +async def test_send_file_as_attachment(tmp_path: Path) -> None: app = Quart(__name__) - file_ = tmpdir.join("send.img") - file_.write("something") + file_ = tmp_path / "send.img" + file_.write_text("something") async with app.app_context(): - response = await send_file(Path(file_.realpath()), as_attachment=True) + response = await send_file(Path(file_), as_attachment=True) assert response.headers["content-disposition"] == "attachment; filename=send.img" -async def test_send_file_as_attachment_name(tmpdir: LocalPath) -> None: +async def test_send_file_as_attachment_name(tmp_path: Path) -> None: app = Quart(__name__) - file_ = tmpdir.join("send.img") - file_.write("something") + file_ = tmp_path / "send.img" + file_.write_text("something") async with app.app_context(): - response = await send_file( - Path(file_.realpath()), as_attachment=True, attachment_filename="send.html" - ) + response = await send_file(Path(file_), as_attachment=True, attachment_filename="send.html") assert response.headers["content-disposition"] == "attachment; filename=send.html" -async def test_send_file_mimetype(tmpdir: LocalPath) -> None: +async def test_send_file_mimetype(tmp_path: Path) -> None: app = Quart(__name__) - file_ = tmpdir.join("send.bob") - file_.write("something") + file_ = tmp_path / "send.bob" + file_.write_text("something") async with app.app_context(): - response = await send_file(Path(file_.realpath()), mimetype="application/bob") - assert (await response.get_data(as_text=False)) == file_.read_binary() + response = await send_file(Path(file_), mimetype="application/bob") + assert (await response.get_data(as_text=False)) == file_.read_bytes() assert response.headers["Content-Type"] == "application/bob" -async def test_send_file_last_modified(tmpdir: LocalPath) -> None: +async def test_send_file_last_modified(tmp_path: Path) -> None: app = Quart(__name__) - file_ = tmpdir.join("send.img") - file_.write("something") + file_ = tmp_path / "send.img" + file_.write_text("something") async with app.app_context(): - response = await send_file(str(file_.realpath())) - mtime = datetime.fromtimestamp(file_.mtime(), tz=timezone.utc) + response = await send_file(str(file_)) + mtime = datetime.fromtimestamp(file_.stat().st_mtime, tz=timezone.utc) mtime = mtime.replace(microsecond=0) assert response.last_modified == mtime -async def test_send_file_last_modified_override(tmpdir: LocalPath) -> None: +async def test_send_file_last_modified_override(tmp_path: Path) -> None: app = Quart(__name__) - file_ = tmpdir.join("send.img") - file_.write("something") + file_ = tmp_path / "send.img" + file_.write_text("something") last_modified = datetime(2015, 10, 10, tzinfo=timezone.utc) async with app.app_context(): - response = await send_file(str(file_.realpath()), last_modified=last_modified) + response = await send_file(str(file_), last_modified=last_modified) assert response.last_modified == last_modified -async def test_send_file_max_age(tmpdir: LocalPath) -> None: +async def test_send_file_max_age(tmp_path: Path) -> None: app = Quart(__name__) - file_ = tmpdir.join("send.img") - file_.write("something") + file_ = tmp_path / "send.img" + file_.write_text("something") async with app.app_context(): - response = await send_file(str(file_.realpath())) - assert response.cache_control.max_age == app.send_file_max_age_default.total_seconds() + response = await send_file(str(file_)) + assert response.cache_control.max_age == app.config["SEND_FILE_MAX_AGE_DEFAULT"].total_seconds() diff --git a/tests/test_json.py b/tests/test_json.py deleted file mode 100644 index a879209..0000000 --- a/tests/test_json.py +++ /dev/null @@ -1,30 +0,0 @@ -from __future__ import annotations - -from typing import Any - -import pytest -from hypothesis import given, strategies as strategies - -from quart.app import Quart -from quart.json.tag import TaggedJSONSerializer - - -@pytest.mark.parametrize("as_ascii, expected", [(True, '"\\ud83c\\udf8a"'), (False, '"🎊"')]) -async def test_ascii_dumps(as_ascii: bool, expected: str) -> None: - app = Quart(__name__) - async with app.app_context(): - app.json.ensure_ascii = as_ascii # type: ignore - assert app.json.dumps("🎊") == expected - - -@given( - value=strategies.one_of( - strategies.datetimes(), - strategies.uuids(), - strategies.binary(), - strategies.tuples(strategies.integers()), - ) -) -def test_jsonserializer(value: Any) -> None: - serializer = TaggedJSONSerializer() - assert serializer.loads(serializer.dumps(value)) == value diff --git a/tests/test_routing.py b/tests/test_routing.py index d9c3e41..c01ecf9 100644 --- a/tests/test_routing.py +++ b/tests/test_routing.py @@ -1,5 +1,7 @@ from __future__ import annotations +import warnings + import pytest from hypercorn.typing import HTTPScope from werkzeug.datastructures import Headers @@ -10,10 +12,10 @@ @pytest.mark.parametrize( - "server_name, expected", - [("localhost", 0), ("quart.com", 1)], + "server_name, warns", + [("localhost", False), ("quart.com", True)], ) -async def test_bind_warning(server_name: str, expected: int, http_scope: HTTPScope) -> None: +async def test_bind_warning(server_name: str, warns: bool, http_scope: HTTPScope) -> None: map_ = QuartMap(host_matching=False) request = Request( "GET", @@ -26,7 +28,11 @@ async def test_bind_warning(server_name: str, expected: int, http_scope: HTTPSco http_scope, send_push_promise=no_op_push, ) - with pytest.warns(None) as record: - map_.bind_to_request(request, subdomain=None, server_name=server_name) - assert len(record) == expected + if warns: + with pytest.warns(UserWarning): + map_.bind_to_request(request, subdomain=None, server_name=server_name) + else: + with warnings.catch_warnings(): + warnings.simplefilter("error") + map_.bind_to_request(request, subdomain=None, server_name=server_name) diff --git a/tests/test_sessions.py b/tests/test_sessions.py index 94f94d9..1579731 100644 --- a/tests/test_sessions.py +++ b/tests/test_sessions.py @@ -1,66 +1,16 @@ from __future__ import annotations -from contextlib import contextmanager from http.cookies import SimpleCookie -from sys import version_info -from typing import Generator -import pytest from hypercorn.typing import HTTPScope from werkzeug.datastructures import Headers from quart.app import Quart -from quart.sessions import NullSession, SecureCookieSession, SecureCookieSessionInterface +from quart.sessions import SecureCookieSession, SecureCookieSessionInterface from quart.testing import no_op_push from quart.wrappers import Request, Response -@contextmanager -def _test_secure_cookie_session(attribute: str) -> Generator[SecureCookieSession, None, None]: - session = SecureCookieSession({"a": "b"}) - assert hasattr(session, attribute) - assert not getattr(session, attribute) - yield session - assert getattr(session, attribute) - - -def test_secure_cookie_access() -> None: - with _test_secure_cookie_session("accessed") as session: - _ = session["a"] - with _test_secure_cookie_session("accessed") as session: - _ = session.get("a") # noqa: F841 - - -def test_secure_cookie_modification() -> None: - with _test_secure_cookie_session("modified") as session: - session.clear() - with _test_secure_cookie_session("modified") as session: - session.setdefault("a", []) - with _test_secure_cookie_session("modified") as session: - session.update({"a": "b"}) - with _test_secure_cookie_session("modified") as session: - session["a"] = "b" - with _test_secure_cookie_session("modified") as session: - session.pop("a", None) - with _test_secure_cookie_session("modified") as session: - session.popitem() - with _test_secure_cookie_session("modified") as session: - del session["a"] - session = SecureCookieSession({"a": "b"}) - _ = session["a"] # noqa - assert not session.modified - - -def test_null_session_no_modification() -> None: - session = NullSession() - with pytest.raises(RuntimeError): - session.setdefault("a", []) - with pytest.raises(RuntimeError): - session.update({"a": "b"}) - with pytest.raises(RuntimeError): - session["a"] = "b" - - async def test_secure_cookie_session_interface_open_session(http_scope: HTTPScope) -> None: session = SecureCookieSession() session["something"] = "else" @@ -87,12 +37,11 @@ async def test_secure_cookie_session_interface_save_session() -> None: await interface.save_session(app, session, response) cookies: SimpleCookie = SimpleCookie() cookies.load(response.headers["Set-Cookie"]) - cookie = cookies[app.session_cookie_name] + cookie = cookies[app.config["SESSION_COOKIE_NAME"]] assert cookie["path"] == interface.get_cookie_path(app) assert cookie["httponly"] == "" if not interface.get_cookie_httponly(app) else True assert cookie["secure"] == "" if not interface.get_cookie_secure(app) else True - if version_info >= (3, 8): - assert cookie["samesite"] == (interface.get_cookie_samesite(app) or "") + assert cookie["samesite"] == (interface.get_cookie_samesite(app) or "") assert cookie["domain"] == (interface.get_cookie_domain(app) or "") assert cookie["expires"] == (interface.get_expiration_time(app, session) or "") assert response.headers["Vary"] == "Cookie" diff --git a/tests/test_signals.py b/tests/test_signals.py deleted file mode 100644 index 671d594..0000000 --- a/tests/test_signals.py +++ /dev/null @@ -1,37 +0,0 @@ -from __future__ import annotations - -from typing import Any - -import pytest - -from quart.signals import AsyncNamedSignal - - -@pytest.mark.parametrize("weak", [True, False]) -async def test_sync_signal(weak: bool) -> None: - signal = AsyncNamedSignal("name") - fired = False - - def sync_fired(*_: Any) -> None: - nonlocal fired - fired = True - - signal.connect(sync_fired, weak=weak) - - await signal.send() - assert fired - - -@pytest.mark.parametrize("weak", [True, False]) -async def test_async_signal(weak: bool) -> None: - signal = AsyncNamedSignal("name") - fired = False - - async def async_fired(*_: Any) -> None: - nonlocal fired - fired = True - - signal.connect(async_fired, weak=weak) - - await signal.send() - assert fired diff --git a/tests/test_templating.py b/tests/test_templating.py index 46aa5f5..f745020 100644 --- a/tests/test_templating.py +++ b/tests/test_templating.py @@ -143,7 +143,7 @@ def app_test(value: int) -> bool: async def test_simple_stream(app: Quart) -> None: @app.get("/") async def index() -> ResponseReturnValue: - return await stream_template_string("{{ config }}", config=42) + return await stream_template_string("{{ config }}", config=42) # type: ignore test_client = app.test_client() response = await test_client.get("/") diff --git a/tests/test_testing.py b/tests/test_testing.py index 9bafc03..a77c563 100644 --- a/tests/test_testing.py +++ b/tests/test_testing.py @@ -1,7 +1,7 @@ from __future__ import annotations from io import BytesIO -from typing import Callable, Optional +from typing import Callable import pytest from werkzeug.datastructures import Headers @@ -47,8 +47,8 @@ async def echo() -> str: ) def test_build_headers_path_and_query_string( path: str, - query_string: Optional[dict], - subdomain: Optional[str], + query_string: dict | None, + subdomain: str | None, expected_path: str, expected_query_string: bytes, expected_host: str, @@ -212,6 +212,22 @@ async def echo() -> Response: assert (await response.get_json()) == {"a": "b"} +async def test_files() -> None: + app = Quart(__name__) + + @app.route("/", methods=["POST"]) + async def echo() -> Response: + files = await request.files + data = files["file"].read() + return data + + client = Client(app) + response = await client.post( + "/", files={"file": FileStorage(BytesIO(b"bar"), filename="a.txt")} + ) + assert (await response.get_data(as_text=True)) == "bar" + + async def test_data() -> None: app = Quart(__name__) diff --git a/tests/test_utils.py b/tests/test_utils.py index 311fa7d..a5d852c 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,18 +1,8 @@ from __future__ import annotations -from functools import partial - from werkzeug.datastructures import Headers -from quart.utils import decode_headers, encode_headers, is_coroutine_function - - -def test_is_coroutine_function() -> None: - async def async_func() -> None: - pass - - assert is_coroutine_function(async_func) - assert is_coroutine_function(partial(async_func)) +from quart.utils import decode_headers, encode_headers def test_encode_headers() -> None: diff --git a/tests/wrappers/test_request.py b/tests/wrappers/test_request.py index 4a43eb6..8cf00ab 100644 --- a/tests/wrappers/test_request.py +++ b/tests/wrappers/test_request.py @@ -1,7 +1,6 @@ from __future__ import annotations import asyncio -from typing import List, Tuple from urllib.parse import urlencode import pytest @@ -58,7 +57,7 @@ async def test_body_streaming_no_data() -> None: semaphore = asyncio.Semaphore(0) asyncio.ensure_future(_fill_body(body, semaphore, 0)) async for _ in body: # noqa: F841 - assert False # Should not reach this + raise AssertionError("Should not reach this line") assert b"" == await body @@ -111,7 +110,7 @@ async def test_request_get_data_timeout(http_scope: HTTPScope) -> None: "method, expected", [("GET", ["b", "c"]), ("POST", ["b", "c", "d"])], ) -async def test_request_values(method: str, expected: List[str], http_scope: HTTPScope) -> None: +async def test_request_values(method: str, expected: list[str], http_scope: HTTPScope) -> None: request = Request( method, "http", @@ -129,7 +128,7 @@ async def test_request_values(method: str, expected: List[str], http_scope: HTTP async def test_request_send_push_promise(http_scope: HTTPScope) -> None: - push_promise: Tuple[str, Headers] = None + push_promise: tuple[str, Headers] = None async def _push(path: str, headers: Headers) -> None: nonlocal push_promise diff --git a/tests/wrappers/test_response.py b/tests/wrappers/test_response.py index e69da4b..99a77a3 100644 --- a/tests/wrappers/test_response.py +++ b/tests/wrappers/test_response.py @@ -8,7 +8,6 @@ import pytest from hypothesis import given, strategies as strategies -from py._path.local import LocalPath from werkzeug.datastructures import Headers from werkzeug.exceptions import RequestedRangeNotSatisfiable @@ -44,10 +43,10 @@ async def test_iterable_wrapper(iterable: Any) -> None: assert results == [b"abc", b"def"] -async def test_file_wrapper(tmpdir: LocalPath) -> None: - file_ = tmpdir.join("file_wrapper") - file_.write("abcdef") - wrapper = FileBody(Path(file_.realpath()), buffer_size=3) +async def test_file_wrapper(tmp_path: Path) -> None: + file_ = tmp_path / "file_wrapper" + file_.write_text("abcdef") + wrapper = FileBody(Path(file_), buffer_size=3) results = [] async with wrapper as response: async for data in response: @@ -98,7 +97,7 @@ async def test_response_make_conditional(http_scope: HTTPScope) -> None: assert response.accept_ranges == "bytes" assert response.content_range.units == "bytes" assert response.content_range.start == 0 - assert response.content_range.stop == 3 + assert response.content_range.stop == 4 assert response.content_range.length == 6 @@ -130,6 +129,26 @@ async def test_response_make_conditional_out_of_bound(http_scope: HTTPScope) -> assert response.status_code == 206 +async def test_response_make_conditional_not_modified(http_scope: HTTPScope) -> None: + response = Response(b"abcdef") + await response.add_etag() + request = Request( + "GET", + "https", + "/", + b"", + Headers([("If-None-Match", response.get_etag()[0])]), + "", + "1.1", + http_scope, + send_push_promise=no_op_push, + ) + await response.make_conditional(request) + assert response.status_code == 304 + assert b"" == (await response.get_data()) # type: ignore + assert "content-length" not in response.headers + + @pytest.mark.parametrize( "range_", ["second=0-3", "bytes=0-2,3-5", "bytes=8-16"], diff --git a/tox.ini b/tox.ini index 6db44d7..bc71c21 100644 --- a/tox.ini +++ b/tox.ini @@ -5,16 +5,16 @@ envlist = mypy package pep8 - py37 py38 py39 py310 + py311 + py312 minversion = 3.3 isolated_build = true [testenv] deps = - py37: mock hypothesis pytest pytest-asyncio @@ -24,7 +24,7 @@ deps = commands = pytest --cov=quart {posargs} [testenv:docs] -basepython = python3.10 +basepython = python3.12 deps = pydata-sphinx-theme sphinx @@ -32,7 +32,7 @@ commands = sphinx-build -b html -d {envtmpdir}/doctrees docs/ docs/_build/html/ [testenv:format] -basepython = python3.10 +basepython = python3.12 deps = black isort @@ -42,7 +42,7 @@ commands = skip_install = true [testenv:pep8] -basepython = python3.10 +basepython = python3.12 deps = flake8 pep8-naming @@ -52,11 +52,10 @@ commands = flake8 src/quart/ tests/ skip_install = true [testenv:mypy] -basepython = python3.10 +basepython = python3.12 deps = flask mypy - types-toml hypothesis pytest pytest-asyncio @@ -68,7 +67,7 @@ commands = mypy src/quart/ tests/ [testenv:package] -basepython = python3.10 +basepython = python3.12 deps = poetry twine