From aac19d743ee482b235b48ee7c8b6044fcffc8fd2 Mon Sep 17 00:00:00 2001 From: Alan Potter Date: Mon, 11 Mar 2024 14:16:28 -0400 Subject: [PATCH] feat: FastAPI support Add support for the FastAPI framework. Load it automatically when uvicorn is run with a FastAPI app. --- _appmap/event.py | 39 +++- _appmap/importer.py | 49 +++-- _appmap/test/data/fastapi/appmap.yml | 3 + .../test/data/fastapi/fastapiapp/__init__.py | 0 _appmap/test/data/fastapi/fastapiapp/main.py | 67 ++++++ .../test/data/fastapi/init/sitecustomize.py | 1 + _appmap/test/data/fastapi/test_app.py | 14 ++ _appmap/test/data/remote.appmap.json | 16 +- _appmap/test/normalize.py | 11 +- _appmap/test/test_configuration.py | 15 +- _appmap/test/test_fastapi.py | 109 ++++++++++ _appmap/test/test_recording.py | 2 +- _appmap/test/web_framework.py | 42 ++-- _appmap/utils.py | 93 +++++---- _appmap/web_framework.py | 5 +- appmap/__init__.py | 13 +- appmap/fastapi.py | 194 ++++++++++++++++++ appmap/uvicorn.py | 27 +++ pylintrc | 2 +- pyproject.toml | 3 + requirements-dev.txt | 6 +- 21 files changed, 579 insertions(+), 132 deletions(-) create mode 100644 _appmap/test/data/fastapi/appmap.yml create mode 100644 _appmap/test/data/fastapi/fastapiapp/__init__.py create mode 100644 _appmap/test/data/fastapi/fastapiapp/main.py create mode 100644 _appmap/test/data/fastapi/init/sitecustomize.py create mode 100644 _appmap/test/data/fastapi/test_app.py create mode 100644 _appmap/test/test_fastapi.py create mode 100644 appmap/fastapi.py create mode 100644 appmap/uvicorn.py diff --git a/_appmap/event.py b/_appmap/event.py index 76385ed6..0b1964c1 100644 --- a/_appmap/event.py +++ b/_appmap/event.py @@ -10,11 +10,11 @@ from .recorder import Recorder from .utils import ( FnType, + FqFnName, appmap_tls, compact_dict, fqname, get_function_location, - split_function_name, ) logger = Env.current.getLogger(__name__) @@ -173,7 +173,7 @@ def to_dict(self, value): class CallEvent(Event): - __slots__ = ["_fn", "static", "receiver", "parameters", "labels"] + __slots__ = ["_fn", "_fqfn", "static", "receiver", "parameters", "labels"] @staticmethod def make(fn, fntype): @@ -209,10 +209,9 @@ def make_params(filterable): # going to log a message about a mismatch. wrapped_sig = inspect.signature(fn, follow_wrapped=True) if sig != wrapped_sig: - logger.debug( - "signature of wrapper %s.%s doesn't match wrapped", - *split_function_name(fn) - ) + logger.debug("signature of wrapper %r doesn't match wrapped", fn) + logger.debug("sig: %r", sig) + logger.debug("wrapped_sig: %r", wrapped_sig) return [Param(p) for p in sig.parameters.values()] @@ -270,17 +269,17 @@ def set_params(params, instance, args, kwargs): @property @lru_cache(maxsize=None) def function_name(self): - return split_function_name(self._fn) + return self._fqfn.fqfn @property @lru_cache(maxsize=None) def defined_class(self): - return self.function_name[0] + return self._fqfn.fqclass @property @lru_cache(maxsize=None) def method_id(self): - return self.function_name[1] + return self._fqfn.fqfn[1] @property @lru_cache(maxsize=None) @@ -308,6 +307,7 @@ def comment(self): def __init__(self, fn, fntype, parameters, labels): super().__init__("call") self._fn = fn + self._fqfn = FqFnName(fn) self.static = fntype in FnType.STATIC | FnType.CLASS | FnType.MODULE self.receiver = None if fntype in FnType.CLASS | FnType.INSTANCE: @@ -351,7 +351,15 @@ class MessageEvent(Event): # pylint: disable=too-few-public-methods def __init__(self, message_parameters): super().__init__("call") self.message = [] - for name, value in message_parameters.items(): + self.message_parameters = message_parameters + + @property + def message_parameters(self): + return self.message + + @message_parameters.setter + def message_parameters(self, params): + for name, value in params.items(): message_object = describe_value(name, value) self.message.append(message_object) @@ -386,6 +394,7 @@ def __init__(self, request_method, url, message_parameters, headers=None): # pylint: disable=too-few-public-methods +_NORMALIZED_PATH_INFO_ATTR = "normalized_path_info" class HttpServerRequestEvent(MessageEvent): """A call AppMap event representing an HTTP server request.""" @@ -406,7 +415,7 @@ def __init__( "request_method": request_method, "protocol": protocol, "path_info": path_info, - "normalized_path_info": normalized_path_info, + _NORMALIZED_PATH_INFO_ATTR: normalized_path_info, } if headers is not None: @@ -420,6 +429,14 @@ def __init__( self.http_server_request = compact_dict(request) + @property + def normalized_path_info(self): + return self.http_server_request.get(_NORMALIZED_PATH_INFO_ATTR, None) + + @normalized_path_info.setter + def normalized_path_info(self, npi): + self.http_server_request[_NORMALIZED_PATH_INFO_ATTR] = npi + class ReturnEvent(Event): __slots__ = ["parent_id", "elapsed"] diff --git a/_appmap/importer.py b/_appmap/importer.py index 5cbee04e..b853c384 100644 --- a/_appmap/importer.py +++ b/_appmap/importer.py @@ -10,12 +10,12 @@ from _appmap import wrapt from .env import Env -from .utils import FnType +from .utils import FnType, Scope logger = Env.current.getLogger(__name__) -Filterable = namedtuple("Filterable", "fqname obj") +Filterable = namedtuple("Filterable", "scope fqname obj") class FilterableMod(Filterable): @@ -23,10 +23,7 @@ class FilterableMod(Filterable): def __new__(cls, mod): fqname = mod.__name__ - return super(FilterableMod, cls).__new__(cls, fqname, mod) - - def classify_fn(self, _): - return FnType.MODULE + return super(FilterableMod, cls).__new__(cls, Scope.MODULE, fqname, mod) class FilterableCls(Filterable): @@ -34,32 +31,28 @@ class FilterableCls(Filterable): def __new__(cls, clazz): fqname = "%s.%s" % (clazz.__module__, clazz.__qualname__) - return super(FilterableCls, cls).__new__(cls, fqname, clazz) - - def classify_fn(self, static_fn): - return FnType.classify(static_fn) + return super(FilterableCls, cls).__new__(cls, Scope.CLASS, fqname, clazz) class FilterableFn( namedtuple( "FilterableFn", - Filterable._fields - + ( - "scope", - "static_fn", - ), + Filterable._fields + ("static_fn",), ) ): __slots__ = () def __new__(cls, scope, fn, static_fn): fqname = "%s.%s" % (scope.fqname, fn.__name__) - self = super(FilterableFn, cls).__new__(cls, fqname, fn, scope, static_fn) + self = super(FilterableFn, cls).__new__(cls, scope.scope, fqname, fn, static_fn) return self @property def fntype(self): - return self.scope.classify_fn(self.static_fn) + if self.scope == Scope.MODULE: + return FnType.MODULE + + return FnType.classify(self.static_fn) class Filter(ABC): # pylint: disable=too-few-public-methods @@ -161,6 +154,17 @@ def initialize(cls): def use_filter(cls, filter_class): cls.filter_stack.append(filter_class) + @classmethod + def instrument_function(cls, fn_name, filterableFn: FilterableFn, selected_functions=None): + # Only instrument the function if it was specifically called out for the package + # (e.g. because it should be labeled), or it's included by the filters + matched = cls.filter_chain.filter(filterableFn) + selected = selected_functions and fn_name in selected_functions + if selected or matched: + return cls.filter_chain.wrap(filterableFn) + + return filterableFn.obj + @classmethod def do_import(cls, *args, **kwargs): mod = args[0] @@ -177,15 +181,10 @@ def instrument_functions(filterable, selected_functions=None): logger.trace(" functions %s", functions) for fn_name, static_fn, fn in functions: - # Only instrument the function if it was specifically called out for the package - # (e.g. because it should be labeled), or it's included by the filters filterableFn = FilterableFn(filterable, fn, static_fn) - matched = cls.filter_chain.filter(filterableFn) - selected = selected_functions and fn_name in selected_functions - if selected or matched: - new_fn = cls.filter_chain.wrap(filterableFn) - if fn != new_fn: - wrapt.wrap_function_wrapper(filterable.obj, fn_name, new_fn) + new_fn = cls.instrument_function(fn_name, filterableFn, selected_functions) + if new_fn != fn: + wrapt.wrap_function_wrapper(filterable.obj, fn_name, new_fn) # Import Config here, to avoid circular top-level imports. from .configuration import Config # pylint: disable=import-outside-toplevel diff --git a/_appmap/test/data/fastapi/appmap.yml b/_appmap/test/data/fastapi/appmap.yml new file mode 100644 index 00000000..84e2d4da --- /dev/null +++ b/_appmap/test/data/fastapi/appmap.yml @@ -0,0 +1,3 @@ +name: FastAPITest +packages: +- path: fastapiapp diff --git a/_appmap/test/data/fastapi/fastapiapp/__init__.py b/_appmap/test/data/fastapi/fastapiapp/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/_appmap/test/data/fastapi/fastapiapp/main.py b/_appmap/test/data/fastapi/fastapiapp/main.py new file mode 100644 index 00000000..267fabd5 --- /dev/null +++ b/_appmap/test/data/fastapi/fastapiapp/main.py @@ -0,0 +1,67 @@ +""" +Rudimentary FastAPI application for testing. + +NB: This should not explicitly reference the `appmap` module in any way. Doing so invalidates +testing of record-by-default. +""" +# pylint: disable=missing-function-docstring + +from typing import List + +from fastapi import FastAPI, Query, Request, Response + +app = FastAPI() + + +@app.get("/") +def hello_world(): + return {"Hello": "World!"} + + +@app.post("/echo") +async def echo(request: Request): + body = await request.body() + return Response(content=body, media_type="application/json") + + +@app.get("/test") +async def get_test(my_params: List[str] = Query(None)): + response = Response(content="testing", media_type="text/html; charset=utf-8") + response.headers["ETag"] = "W/01" + return response + + +@app.post("/test") +async def post_test(request: Request): + await request.json() + response = Response(content='{"test":true}', media_type="application/json") + response.headers["ETag"] = "W/01" + return response + + +@app.get("/user/{username}") +def get_user_profile(username): + # show the user profile for that user + return {"user": username} + + +@app.get("/post/{post_id:int}") +def get_post(post_id): + # show the post with the given id, the id is an integer + return {"post": post_id} + + +@app.get("/post/{username}/{post_id:int}/summary") +def get_user_post(username, post_id): + # Show the summary of a user's post + return {"user": username, "post": post_id} + + +@app.get("/{org:int}/posts/{username}") +def get_org_user_posts(org, username): + return {"org": org, "username": username} + + +@app.route("/exception") +def raise_exception(): + raise Exception("An exception") diff --git a/_appmap/test/data/fastapi/init/sitecustomize.py b/_appmap/test/data/fastapi/init/sitecustomize.py new file mode 100644 index 00000000..d1fe4fec --- /dev/null +++ b/_appmap/test/data/fastapi/init/sitecustomize.py @@ -0,0 +1 @@ +import appmap diff --git a/_appmap/test/data/fastapi/test_app.py b/_appmap/test/data/fastapi/test_app.py new file mode 100644 index 00000000..5355126d --- /dev/null +++ b/_appmap/test/data/fastapi/test_app.py @@ -0,0 +1,14 @@ +import pytest +from fastapi.testclient import TestClient +from fastapiapp import app + + +@pytest.fixture +def client(): + yield TestClient(app) + + +def test_request(client): + response = client.get("/") + + assert response.status_code == 200 diff --git a/_appmap/test/data/remote.appmap.json b/_appmap/test/data/remote.appmap.json index 3caeef0f..15232c50 100644 --- a/_appmap/test/data/remote.appmap.json +++ b/_appmap/test/data/remote.appmap.json @@ -9,19 +9,15 @@ "protocol": "HTTP/1.1", "request_method": "GET" }, - "id": 1, - "thread_id": 1 + "id": 1 }, { "event": "return", "http_server_response": { - "headers": { "Content-Type": "text/html; charset=utf-8" }, - "mime_type": "text/html; charset=utf-8", "status_code": 200 }, "id": 2, - "parent_id": 1, - "thread_id": 1 + "parent_id": 1 }, { "event": "call", @@ -31,19 +27,15 @@ "protocol": "HTTP/1.1", "request_method": "GET" }, - "id": 3, - "thread_id": 1 + "id": 3 }, { "event": "return", "http_server_response": { - "headers": { "Content-Type": "text/html; charset=utf-8" }, - "mime_type": "text/html; charset=utf-8", "status_code": 200 }, "id": 4, - "parent_id": 3, - "thread_id": 1 + "parent_id": 3 } ], "metadata": { diff --git a/_appmap/test/normalize.py b/_appmap/test/normalize.py index 462d4dce..9041e8ae 100644 --- a/_appmap/test/normalize.py +++ b/_appmap/test/normalize.py @@ -51,9 +51,13 @@ def normalize_headers(dct): """Remove some headers which are variable between implementations. This allows sharing tests between web frameworks, for example. """ - for hdr in ["User-Agent", "Content-Length", "ETag", "Cookie", "Host"]: - value = dct.pop(hdr, None) - assert value is None or isinstance(value, str) + for key in list(dct.keys()): + value = dct.pop(key, None) + key = key.lower() + if key in ["user-agent", "content-length", "content-type", "etag", "cookie", "host"]: + assert value is None or isinstance(value, str) + else: + dct[key] = value def normalize_appmap(generated_appmap): @@ -82,6 +86,7 @@ def normalize(dct): if len(dct["headers"]) == 0: del dct["headers"] if "http_server_request" in dct: + dct["http_server_request"].pop("headers", None) normalize(dct["http_server_request"]) if "message" in dct: del dct["message"] diff --git a/_appmap/test/test_configuration.py b/_appmap/test/test_configuration.py index fca0871f..4c8eed6e 100644 --- a/_appmap/test/test_configuration.py +++ b/_appmap/test/test_configuration.py @@ -84,39 +84,40 @@ def test_config_no_message(caplog): assert caplog.text == "" -cf = lambda: ConfigFilter(NullFilter()) +def cf(): + return ConfigFilter(NullFilter()) @pytest.mark.appmap_enabled(config="appmap-class.yml") def test_class_included(): - f = Filterable("package1.package2.Mod1Class", None) + f = Filterable(None, "package1.package2.Mod1Class", None) assert cf().filter(f) is True @pytest.mark.appmap_enabled(config="appmap-func.yml") def test_function_included(): - f = Filterable("package1.package2.Mod1Class.func", None) + f = Filterable(None, "package1.package2.Mod1Class.func", None) assert cf().filter(f) is True @pytest.mark.appmap_enabled(config="appmap-class.yml") def test_function_included_by_class(): - f = Filterable("package1.package2.Mod1Class.func", None) + f = Filterable(None, "package1.package2.Mod1Class.func", None) assert cf().filter(f) is True @pytest.mark.appmap_enabled class TestConfiguration: def test_package_included(self): - f = Filterable("package1.cls", None) + f = Filterable(None, "package1.cls", None) assert cf().filter(f) is True def test_function_included_by_package(self): - f = Filterable("package1.package2.Mod1Class.func", None) + f = Filterable(None, "package1.package2.Mod1Class.func", None) assert cf().filter(f) is True def test_class_prefix_doesnt_match(self): - f = Filterable("package1_prefix.cls", None) + f = Filterable(None, "package1_prefix.cls", None) assert cf().filter(f) is False diff --git a/_appmap/test/test_fastapi.py b/_appmap/test/test_fastapi.py new file mode 100644 index 00000000..501fe771 --- /dev/null +++ b/_appmap/test/test_fastapi.py @@ -0,0 +1,109 @@ +import importlib +import socket +import sys +from importlib.metadata import version +from pathlib import Path +from types import SimpleNamespace as NS + +import pytest +from fastapi.testclient import TestClient +from xprocess import ProcessStarter + +import appmap +from _appmap.env import Env +from _appmap.metadata import Metadata + +from .web_framework import ( + _TestRecordRequests, + _TestRemoteRecording, + _TestRequestCapture, +) + + +# Opt in to these tests. (I couldn't find a way to DRY this up that allowed +# pytest collection to find all these tests.) +class TestRecordRequests(_TestRecordRequests): + pass + + +@pytest.mark.app(remote_enabled=True) +class TestRemoteRecording(_TestRemoteRecording): + def setup_method(self): + self.expected_thread_id = 1 + self.expected_content_type = "application/json" + + +class TestRequestCapture(_TestRequestCapture): + pass + + +@pytest.fixture(name="app") +def fastapi_app(data_dir, monkeypatch, request): + monkeypatch.syspath_prepend(data_dir / "fastapi") + + Env.current.set("APPMAP_CONFIG", data_dir / "fastapi" / "appmap.yml") + + from fastapiapp import main # pyright: ignore[reportMissingImports] + + importlib.reload(main) + + # FastAPI doesn't provide a way to say what environment the app is running + # in. So, instead use a mark to indicate whether remote recording should be + # enabled. (When we're running as part of a server integration, we infer the + # environment from the server configure, e.g. "uvicorn --reload".) + mark = request.node.get_closest_marker("app") + remote_enabled = None + if mark is not None: + remote_enabled = mark.kwargs.get("remote_enabled", None) + + # Add the FastAPI middleware to the app. This now happens automatically when + # a FastAPI app is started from uvicorn, but must be done manually + # otherwise. + return appmap.fastapi.Middleware(main.app, remote_enabled).init_app() + + +@pytest.fixture(name="client") +def fastapi_client(app): + yield TestClient(app, headers={}) + + +@pytest.mark.appmap_enabled(env={"APPMAP_RECORD_REQUESTS": "false"}) +def test_framework_metadata(client, events): # pylint: disable=unused-argument + client.get("/") + assert Metadata()["frameworks"] == [{"name": "FastAPI", "version": version("fastapi")}] + + +@pytest.fixture(name="server") +def fastapi_server(xprocess, server_base): + host, port, debug, server_env = server_base + reload = "--reload" if debug else "" + + class Starter(ProcessStarter): + def startup_check(self): + try: + s = socket.socket() + s.connect((host, port)) + return True + except ConnectionRefusedError: + pass + return False + + pattern = f"Uvicorn running on http://{host}:{port}" + # Can't set popen_kwargs["cwd"] until + # https://github.com/pytest-dev/pytest-xprocess/issues/89 is fixed. + args = [ + "bash", + "-ec", + f"cd {Path(__file__).parent / 'data'/ 'fastapi'};" + + f" {sys.executable} -m uvicorn fastapiapp.main:app" + + f" {reload} --host {host} --port {port}", + ] + env = { + "PYTHONUNBUFFERED": "1", + "APPMAP_OUTPUT_DIR": "/tmp", + **server_env, + } + + xprocess.ensure("myserver", Starter) + yield NS(debug=debug, url=f"http://{host}:{port}") + xprocess.getinfo("myserver").terminate() diff --git a/_appmap/test/test_recording.py b/_appmap/test/test_recording.py index 0e53a558..1edbd153 100644 --- a/_appmap/test/test_recording.py +++ b/_appmap/test/test_recording.py @@ -208,7 +208,7 @@ def test_process_recording(data_dir, shell, tmp_path): appmap_dir = tmp / "tmp" / "appmap" / "process" appmap_files = list(appmap_dir.glob("*.appmap.json")) - assert len(appmap_files) == 1 + assert len(appmap_files) == 1, "this only fails when run from VS Code?" actual = json.loads(appmap_files[0].read_text()) assert len(actual["events"]) > 0 assert len(actual["classMap"]) > 0 diff --git a/_appmap/test/web_framework.py b/_appmap/test/web_framework.py index 00a5b61c..9e2dbbdb 100644 --- a/_appmap/test/web_framework.py +++ b/_appmap/test/web_framework.py @@ -299,6 +299,20 @@ def test_can_record(self, data_dir, client): data = res.data if hasattr(res, "data") else res.content generated_appmap = normalize_appmap(data) + for evt in generated_appmap["events"]: + # Strip out thread id. The values for these vary by framework, and + # may not even be the same within an AppMap (e.g. FastAPI). They + # should always be ints, though + if "thread_id" in evt: + value = evt.pop("thread_id") + assert isinstance(value, int) + + # Check mime_type. These also vary by framework, but will be + # consistent within an AppMap. + if "http_server_response" in evt: + actual_content_type = evt["http_server_response"].pop("mime_type") + assert actual_content_type == self.expected_content_type + expected_path = data_dir / "remote.appmap.json" with open(expected_path, encoding="utf-8") as expected: expected_appmap = json.load(expected) @@ -307,34 +321,6 @@ def test_can_record(self, data_dir, client): res = client.delete("/_appmap/record") assert res.status_code == 404 - -def port_state(address, port): - ret = None - s = socket.socket() - try: - s.connect((address, port)) - ret = "open" - except Exception: # pylint: disable=broad-except - ret = "closed" - s.close() - return ret - - -def wait_until_port_is(address, port, desired_state): - max_wait_seconds = 10 - sleep_time = 0.1 - max_count = 1 / sleep_time * max_wait_seconds - count = 0 - # don't "while True" to not lock-up the testsuite if something goes wrong - while count < max_count: - current_state = port_state(address, port) - if current_state == desired_state: - break - - time.sleep(sleep_time) - count += 1 - - class _TestRecordRequests: """Common tests for per-requests recording (record requests.)""" server_host = "127.0.0.1" diff --git a/_appmap/utils.py b/_appmap/utils.py index 69dcf600..10c4e225 100644 --- a/_appmap/utils.py +++ b/_appmap/utils.py @@ -1,17 +1,21 @@ -from contextvars import ContextVar import inspect import os import re import shlex import subprocess -import threading import types -from collections.abc import MutableMapping -from enum import IntFlag, auto +from contextlib import contextmanager +from contextvars import ContextVar +from enum import Enum, IntFlag, auto from .env import Env +class Scope(Enum): + MODULE = 0 + CLASS = 1 + + def compact_dict(dictionary): """Return a copy of dictionary with None values filtered out.""" return {k: v for k, v in dictionary.items() if v is not None} @@ -41,27 +45,6 @@ def classify(fn): return FnType.INSTANCE -class ThreadLocalDict(threading.local, MutableMapping): - def __init__(self): - super().__init__() - self.values = {} - - def __getitem__(self, k): - return self.values[k] - - def __setitem__(self, k, v): - self.values[k] = v - - def __delitem__(self, k): - del self.values[k] - - def __iter__(self): - return iter(self.values) - - def __len__(self): - return len(self.values) - - _appmap_tls = ContextVar("tls") @@ -73,23 +56,57 @@ def appmap_tls(): return _appmap_tls.get() +@contextmanager +def appmap_tls_context(): + token = _appmap_tls.set({}) + try: + yield + finally: + _appmap_tls.reset(token) + + def fqname(cls): return "%s.%s" % (cls.__module__, cls.__qualname__) -def split_function_name(fn): - """ - Given a method, return a tuple containing its fully-qualified - class name and the method name. - """ - qualname = fn.__qualname__ - if "." in qualname: - class_name, fn_name = qualname.rsplit(".", 1) - class_name = "%s.%s" % (fn.__module__, class_name) - else: - class_name = fn.__module__ - fn_name = qualname - return (class_name, fn_name) +class FqFnName: + def __init__(self, fn): + + # def split_function_name(fn): + # """ + # Given a method, return a tuple containing its fully-qualified + # class name and the method name. + # """ + self._modname = fn.__module__ + qualname = fn.__qualname__ + if "." in qualname: + self._scope = Scope.CLASS + self._class_name, self._fn_name = qualname.rsplit(".", 1) + else: + self._scope = Scope.MODULE + self._class_name = None + self._fn_name = qualname + # return (fn.__module__, class_name, fn_name) + + @property + def scope(self): + return self._scope + + @property + def fqmod(self): + return self._modname + + @property + def fqclass(self): + return self._modname if self._class_name is None else f"{self._modname}.{self._class_name}" + + @property + def fqfn(self): + return (self.fqclass, self._fn_name) + + @property + def fn_name(self): + return self._fn_name def root_relative_path(path): diff --git a/_appmap/web_framework.py b/_appmap/web_framework.py index d5978d02..7bc9f66c 100644 --- a/_appmap/web_framework.py +++ b/_appmap/web_framework.py @@ -142,6 +142,7 @@ def before_request_main(self, rec, req: Any) -> Tuple[float, int]: raise NotImplementedError def after_request_main(self, rec, status, headers, start, call_event_id) -> None: + duration = time.monotonic() - start return_event = HttpServerResponseEvent( parent_id=call_event_id, @@ -239,7 +240,7 @@ def middleware_present(self): @abstractmethod def insert_middleware(self): - """Insert the AppMap middleware.""" + """Insert the AppMap middleware. Optionally return a new instance of the app.""" @abstractmethod def remote_enabled(self): @@ -247,7 +248,7 @@ def remote_enabled(self): def run(self): if not self.middleware_present(): - self.insert_middleware() + return self.insert_middleware() if self.remote_enabled() and not self.debug: self._show_warning() diff --git a/appmap/__init__.py b/appmap/__init__.py index a2a83a47..65e38973 100644 --- a/appmap/__init__.py +++ b/appmap/__init__.py @@ -1,4 +1,5 @@ """AppMap recorder for Python""" + from _appmap import generation # noqa: F401 from _appmap.env import Env # noqa: F401 from _appmap.importer import instrument_module # noqa: F401 @@ -9,13 +10,21 @@ try: from . import django # noqa: F401 except ImportError: - # not using django pass try: from . import flask # noqa: F401 except ImportError: - # not using flask + pass + +try: + from . import fastapi # noqa: F401 +except ImportError: + pass + +try: + from . import uvicorn # noqa: F401 +except ImportError: pass # Note: pytest integration is configured as a pytest plugin, so it doesn't need to be imported here diff --git a/appmap/fastapi.py b/appmap/fastapi.py new file mode 100644 index 00000000..449c6bce --- /dev/null +++ b/appmap/fastapi.py @@ -0,0 +1,194 @@ +import sys +import time +from importlib.metadata import version +from urllib.parse import urlunparse + +import fastapi +from starlette.middleware.base import BaseHTTPMiddleware + +from _appmap import utils, wrapt +from _appmap.env import Env +from _appmap.event import HttpServerRequestEvent +from _appmap.flask import app as flask_remote +from _appmap.importer import Filterable, FilterableFn, Importer +from _appmap.metadata import Metadata +from _appmap.utils import appmap_tls_context, values_dict +from _appmap.web_framework import ( + JSON_ERRORS, + REMOTE_ENABLED_ATTR, + REQUEST_ENABLED_ATTR, + AppmapMiddleware, + MiddlewareInserter, +) + +logger = Env.current.getLogger(__name__) + + +def _add_api_route(wrapped, _, args, kwargs): + if not Env.current.enabled: + wrapped(*args, **kwargs) + return + + fn = args[1] + + fqn = utils.FqFnName(fn) + scope = Filterable(fqn.scope, fqn.fqclass, None) + + filterable_fn = FilterableFn(scope, fn, fn) + logger.debug("_add_api_route, fn: %s", filterable_fn.fqname) + instrumented_fn = Importer.instrument_function(fqn.fn_name, filterable_fn) + + if instrumented_fn != filterable_fn.obj: + instrumented_fn = wrapt.FunctionWrapper(fn, instrumented_fn) + wrapped(args[0], instrumented_fn, **kwargs) + + +if Env.current.enabled: + wrapt.wrap_function_wrapper("fastapi.routing", "APIRouter.add_api_route", _add_api_route) + + +_REQUEST_EVENT_ATTR = "_appmap_server_request_event" +_MAX_JSON_LENGTH = 2048 +class Middleware(AppmapMiddleware, BaseHTTPMiddleware): + + def __init__(self, app, remote_enabled=None): + super().__init__("FastAPI") + BaseHTTPMiddleware.__init__(self, app) + self._json = None + self._remote_enabled = remote_enabled + + def init_app(self): + # pylint: disable=import-outside-toplevel + from fastapi.middleware.wsgi import WSGIMiddleware + from starlette.routing import Mount, Router + + # pylint: enable=import-outside-toplevel + + routes = [Mount("/", Middleware(self.app))] + setattr(self.app, REQUEST_ENABLED_ATTR, True) + + if self._remote_enabled is not None: + enable_by_default = "true" if self._remote_enabled else "false" + else: + enable_by_default = "false" + + remote_enabled = Env.current.enables("remote", enable_by_default) + if remote_enabled: + routes.insert(0, Mount("/_appmap", WSGIMiddleware(flask_remote))) + setattr(self.app, REMOTE_ENABLED_ATTR, remote_enabled) + + return Router(routes=routes) + + def before_request_main(self, rec, req: fastapi.Request): + self.add_framework_metadata() + start = time.monotonic() + scope = req.scope + scope[_REQUEST_EVENT_ATTR] = call_event = HttpServerRequestEvent( + request_method=req.method, + path_info=scope["path"], + message_parameters={}, + headers=req.headers, + protocol=f"{scope['scheme'].upper()}/{scope['http_version']}", + ) + rec.add_event(call_event) + + return start, call_event.id + + async def dispatch(self, request, call_next): + with appmap_tls_context(): + response = await self._dispatch(request, call_next) + return response + + async def _dispatch(self, request, call_next): + if not self.should_record: + response = await call_next(request) + return response + + await self._parse_json(request) + + rec, start, call_event_id = self.before_request_hook(request) + + try: + response = await call_next(request) + except: + self.on_exception(rec, start, call_event_id, sys.exc_info()) + raise + + self._update_request_event(request) + + parsed = request.url.components + baseurl = urlunparse((parsed.scheme, parsed.netloc, parsed.path, "", "", "")) + self.after_request_hook( + request.url.path, + request.method, + baseurl, + response.status_code, + response.headers, + start, + call_event_id, + ) + return response + + async def _parse_json(self, request): + content_length = int(request.headers.get("Content-Length", 0)) + json_content = request.headers.get("Content-Type", "").startswith("application/json") + if not json_content or not 0 < content_length <= _MAX_JSON_LENGTH: + return + + # Calling Request.json loads and caches the entire body. The cache + # will be used when any code subsequently tries to access the body + # in any way (e.g. Request.stream, Request.body, etc) + try: + self._json = await request.json() + if not isinstance(self._json, dict): + # parseable, but not a JSON object + self._json = None + except JSON_ERRORS: + # parsing failed, igore + pass + + def _update_request_event(self, request): + # This updates the http_server_request event that was previously added + # to the recording. This is ok for now, because we haven't done anything + # with the events, e.g. streamed them to disk. + # + # If, at some point in the future, we implement some sort of + # checkpointing, we'll need to change this so it adds the event to the + # recording's `eventUpdates` instead. + scope = request.scope + if "route" not in scope: + return + + request_event = scope[_REQUEST_EVENT_ATTR] + route = scope["route"] + request_event.normalized_path_info = route.path_format + query_params = {k: request.query_params.getlist(k) for k in request.query_params.keys()} + if self._json is not None: + for k, v in self._json.items(): + if k in query_params: + query_params[k].append(v) + else: + query_params[k] = [v] + + params = values_dict(query_params.items()) + # path_params are orthogonal to query_params, so update is ok + params.update(request.path_params) + request_event.message_parameters = params + + def add_framework_metadata(self): + Metadata.add_framework("FastAPI", version("fastapi")) + + +class FastAPIInserter(MiddlewareInserter): + def __init__(self, app, remote_enabled): + super().__init__(remote_enabled) + self.app = app + + def middleware_present(self): + return hasattr(self.app, REQUEST_ENABLED_ATTR) + + def insert_middleware(self): + return Middleware(self.app, self.debug).init_app() + + def remote_enabled(self): + return getattr(self.app, REMOTE_ENABLED_ATTR, None) diff --git a/appmap/uvicorn.py b/appmap/uvicorn.py new file mode 100644 index 00000000..edb373f9 --- /dev/null +++ b/appmap/uvicorn.py @@ -0,0 +1,27 @@ +# uvicorn integration +from uvicorn.config import Config + +from _appmap import wrapt +from _appmap.env import Env + + +def install_extension(wrapped, config, args, kwargs): + wrapped(*args, **kwargs) + try: + # pylint: disable=import-outside-toplevel + from .fastapi import FastAPIInserter + + # pylint: enable=import-outside-toplevel + + app = config.loaded_app + if app: + # uvicorn doc recommends running with `--reload` in development, so use + # that to decide whether to enable remote recording + config.loaded_app = FastAPIInserter(config.loaded_app, config.reload).run() + except ImportError: + # Not FastAPI + pass + + +if Env.current.enabled: + Config.load = wrapt.wrap_function_wrapper("uvicorn.config", "Config.load", install_extension) diff --git a/pylintrc b/pylintrc index 7f95670e..06d340d4 100644 --- a/pylintrc +++ b/pylintrc @@ -1,6 +1,6 @@ [MAIN] # Specify a score threshold under which the program will exit with error. -fail-under=9.82 +fail-under=9.83 # Analyse import fallback blocks. This can be used to support both Python 2 and diff --git a/pyproject.toml b/pyproject.toml index de98128a..3c748421 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -75,6 +75,9 @@ tox = "^3.22.0" # v2.30.0 of "requests" depends on urllib3 v2, which breaks the tests for http_client_requests. Pin # to v1 until this gets fixed. urllib3 = "^1" +uvicorn = "^0.27.1" +fastapi = "^0.110.0" +httpx = "^0.27.0" [build-system] requires = ["poetry-core>=1.1.0"] diff --git a/requirements-dev.txt b/requirements-dev.txt index cebc1a3f..c2d5913e 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,5 +1,7 @@ #requirements-dev.txt tox django -flask >=2, < 3 -pytest-django<4.8 \ No newline at end of file +flask >=2, <= 3 +pytest-django<4.8 +fastapi +httpx \ No newline at end of file