Skip to content

Commit

Permalink
feat: add FastAPI integration
Browse files Browse the repository at this point in the history
  • Loading branch information
apotterri committed Mar 7, 2024
1 parent 59b355c commit fa786ce
Show file tree
Hide file tree
Showing 19 changed files with 500 additions and 151 deletions.
39 changes: 28 additions & 11 deletions _appmap/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@
from .recorder import Recorder
from .utils import (
FnType,
FqFnName,
appmap_tls,
compact_dict,
fqname,
get_function_location,
split_function_name,
)

logger = Env.current.getLogger(__name__)
Expand Down Expand Up @@ -173,7 +173,7 @@ def to_dict(self, value):


class CallEvent(Event):
__slots__ = ["_fn", "static", "receiver", "parameters", "labels"]
__slots__ = ["_fn", "_fqfn", "static", "receiver", "parameters", "labels"]

@staticmethod
def make(fn, fntype):
Expand Down Expand Up @@ -209,10 +209,9 @@ def make_params(filterable):
# going to log a message about a mismatch.
wrapped_sig = inspect.signature(fn, follow_wrapped=True)
if sig != wrapped_sig:
logger.debug(
"signature of wrapper %s.%s doesn't match wrapped",
*split_function_name(fn)
)
logger.debug("signature of wrapper %r doesn't match wrapped", fn)
logger.debug("sig: %r", sig)
logger.debug("wrapped_sig: %r", wrapped_sig)

return [Param(p) for p in sig.parameters.values()]

Expand Down Expand Up @@ -270,17 +269,17 @@ def set_params(params, instance, args, kwargs):
@property
@lru_cache(maxsize=None)
def function_name(self):
return split_function_name(self._fn)
return self._fqfn.fqfn

@property
@lru_cache(maxsize=None)
def defined_class(self):
return self.function_name[0]
return self._fqfn.fqclass

@property
@lru_cache(maxsize=None)
def method_id(self):
return self.function_name[1]
return self._fqfn.fqfn[1]

@property
@lru_cache(maxsize=None)
Expand Down Expand Up @@ -308,6 +307,7 @@ def comment(self):
def __init__(self, fn, fntype, parameters, labels):
super().__init__("call")
self._fn = fn
self._fqfn = FqFnName(fn)
self.static = fntype in FnType.STATIC | FnType.CLASS | FnType.MODULE
self.receiver = None
if fntype in FnType.CLASS | FnType.INSTANCE:
Expand Down Expand Up @@ -351,7 +351,15 @@ class MessageEvent(Event): # pylint: disable=too-few-public-methods
def __init__(self, message_parameters):
super().__init__("call")
self.message = []
for name, value in message_parameters.items():
self.message_parameters = message_parameters

@property
def message_parameters(self):
return self.message

@message_parameters.setter
def message_parameters(self, params):
for name, value in params.items():
message_object = describe_value(name, value)
self.message.append(message_object)

Expand Down Expand Up @@ -386,6 +394,7 @@ def __init__(self, request_method, url, message_parameters, headers=None):


# pylint: disable=too-few-public-methods
_NORMALIZED_PATH_INFO_ATTR = "normalized_path_info"
class HttpServerRequestEvent(MessageEvent):
"""A call AppMap event representing an HTTP server request."""

Expand All @@ -406,7 +415,7 @@ def __init__(
"request_method": request_method,
"protocol": protocol,
"path_info": path_info,
"normalized_path_info": normalized_path_info,
_NORMALIZED_PATH_INFO_ATTR: normalized_path_info,
}

if headers is not None:
Expand All @@ -420,6 +429,14 @@ def __init__(

self.http_server_request = compact_dict(request)

@property
def normalized_path_info(self):
return self.http_server_request.get(_NORMALIZED_PATH_INFO_ATTR, None)

@normalized_path_info.setter
def normalized_path_info(self, npi):
self.http_server_request[_NORMALIZED_PATH_INFO_ATTR] = npi


class ReturnEvent(Event):
__slots__ = ["parent_id", "elapsed"]
Expand Down
49 changes: 24 additions & 25 deletions _appmap/importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,56 +10,49 @@
from _appmap import wrapt

from .env import Env
from .utils import FnType
from .utils import FnType, Scope

logger = Env.current.getLogger(__name__)


Filterable = namedtuple("Filterable", "fqname obj")
Filterable = namedtuple("Filterable", "scope fqname obj")


class FilterableMod(Filterable):
__slots__ = ()

def __new__(cls, mod):
fqname = mod.__name__
return super(FilterableMod, cls).__new__(cls, fqname, mod)

def classify_fn(self, _):
return FnType.MODULE
return super(FilterableMod, cls).__new__(cls, Scope.MODULE, fqname, mod)


class FilterableCls(Filterable):
__slots__ = ()

def __new__(cls, clazz):
fqname = "%s.%s" % (clazz.__module__, clazz.__qualname__)
return super(FilterableCls, cls).__new__(cls, fqname, clazz)

def classify_fn(self, static_fn):
return FnType.classify(static_fn)
return super(FilterableCls, cls).__new__(cls, Scope.CLASS, fqname, clazz)


class FilterableFn(
namedtuple(
"FilterableFn",
Filterable._fields
+ (
"scope",
"static_fn",
),
Filterable._fields + ("static_fn",),
)
):
__slots__ = ()

def __new__(cls, scope, fn, static_fn):
fqname = "%s.%s" % (scope.fqname, fn.__name__)
self = super(FilterableFn, cls).__new__(cls, fqname, fn, scope, static_fn)
self = super(FilterableFn, cls).__new__(cls, scope.scope, fqname, fn, static_fn)
return self

@property
def fntype(self):
return self.scope.classify_fn(self.static_fn)
if self.scope == Scope.MODULE:
return FnType.MODULE
else:
return FnType.classify(self.static_fn)


class Filter(ABC): # pylint: disable=too-few-public-methods
Expand Down Expand Up @@ -161,6 +154,17 @@ def initialize(cls):
def use_filter(cls, filter_class):
cls.filter_stack.append(filter_class)

@classmethod
def instrument_function(cls, fn_name, filterableFn: FilterableFn, selected_functions=None):
# Only instrument the function if it was specifically called out for the package
# (e.g. because it should be labeled), or it's included by the filters
matched = cls.filter_chain.filter(filterableFn)
selected = selected_functions and fn_name in selected_functions
if selected or matched:
return cls.filter_chain.wrap(filterableFn)

return filterableFn.obj

@classmethod
def do_import(cls, *args, **kwargs):
mod = args[0]
Expand All @@ -177,15 +181,10 @@ def instrument_functions(filterable, selected_functions=None):
logger.trace(" functions %s", functions)

for fn_name, static_fn, fn in functions:
# Only instrument the function if it was specifically called out for the package
# (e.g. because it should be labeled), or it's included by the filters
filterableFn = FilterableFn(filterable, fn, static_fn)
matched = cls.filter_chain.filter(filterableFn)
selected = selected_functions and fn_name in selected_functions
if selected or matched:
new_fn = cls.filter_chain.wrap(filterableFn)
if fn != new_fn:
wrapt.wrap_function_wrapper(filterable.obj, fn_name, new_fn)
new_fn = cls.instrument_function(fn_name, filterableFn, selected_functions)
if new_fn != fn:
wrapt.wrap_function_wrapper(filterable.obj, fn_name, new_fn)

# Import Config here, to avoid circular top-level imports.
from .configuration import Config # pylint: disable=import-outside-toplevel
Expand Down
3 changes: 3 additions & 0 deletions _appmap/test/data/fastapi/appmap.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
name: FastAPITest
packages:
- path: fastapiapp
Empty file.
67 changes: 67 additions & 0 deletions _appmap/test/data/fastapi/fastapiapp/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
"""
Rudimentary FastAPI application for testing.
NB: This should not explicitly reference the `appmap` module in any way. Doing so invalidates
testing of record-by-default.
"""
# pylint: disable=missing-function-docstring

from typing import List

from fastapi import FastAPI, Query, Request, Response

app = FastAPI()


@app.get("/")
def hello_world():
return {"Hello": "World!"}


@app.post("/echo")
async def echo(request: Request):
body = await request.body()
return Response(content=body, media_type="application/json")


@app.get("/test")
async def get_test(my_params: List[str] = Query(None)):
response = Response(content="test", media_type="text/html; charset=utf-8")
response.headers["ETag"] = "W/01"
return response


@app.post("/test")
async def post_test(request: Request):
await request.json()
response = Response(content='{"test":true}', media_type="application/json")
response.headers["ETag"] = "W/01"
return response


@app.get("/user/{username}")
def get_user_profile(username):
# show the user profile for that user
return {"user": username}


@app.get("/post/{post_id:int}")
def get_post(post_id):
# show the post with the given id, the id is an integer
return {"post": post_id}


@app.get("/post/{username}/{post_id:int}/summary")
def get_user_post(username, post_id):
# Show the summary of a user's post
return {"user": username, "post": post_id}


@app.get("/{org:int}/posts/{username}")
def get_org_user_posts(org, username):
return {"org": org, "username": username}


@app.route("/exception")
def raise_exception():
raise Exception("An exception")
1 change: 1 addition & 0 deletions _appmap/test/data/fastapi/init/sitecustomize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
import appmap
14 changes: 14 additions & 0 deletions _appmap/test/data/fastapi/test_app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import pytest
from fastapi.testclient import TestClient
from fastapiapp import app


@pytest.fixture
def client():
yield TestClient(app)


def test_request(client):
response = client.get("/")

assert response.status_code == 200
11 changes: 11 additions & 0 deletions _appmap/test/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,14 @@ class DictIncluding(dict):

def __eq__(self, other):
return other.items() >= self.items()


class HeadersIncluding(dict):
"""Like DictIncluding, but key comparison is case-insensitive."""

def __eq__(self, other):
for k, v in self.items():
v1 = other.get(k, other.get(k.lower(), None))
if v1 is None:
return False
return True
44 changes: 44 additions & 0 deletions _appmap/test/test_fastapi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import importlib.metadata

import pytest
from fastapi.testclient import TestClient

import appmap
from _appmap.env import Env
from _appmap.metadata import Metadata

# pylint: disable=unused-import
from .web_framework import TestRequestCapture

# pylint: enable=unused-import

pytestmark = pytest.mark.web

@pytest.fixture(name="app")
def fastapi_app(data_dir, monkeypatch):
monkeypatch.syspath_prepend(data_dir / "fastapi")

Env.current.set("APPMAP_CONFIG", data_dir / "fastapi" / "appmap.yml")

from fastapiapp import main # pyright: ignore[reportMissingImports]

importlib.reload(main)

# Add the FastAPI middleware to the app. This now happens automatically when a FastAPI app is
# started from the command line, but must be done manually otherwise.
main.app.add_middleware(appmap.fastapi.Middleware)

return main.app


@pytest.fixture(name="client")
def fastapi_client(app):
yield TestClient(app)


@pytest.mark.appmap_enabled(env={"APPMAP_RECORD_REQUESTS": "false"})
def test_framework_metadata(client, events): # pylint: disable=unused-argument
client.get("/")
assert Metadata()["frameworks"] == [
{"name": "FastAPI", "version": importlib.metadata.version("fastapi")}
]
Loading

0 comments on commit fa786ce

Please sign in to comment.