Skip to content

Commit

Permalink
feat: FastAPI support
Browse files Browse the repository at this point in the history
Add support for the FastAPI framework. Load it automatically when
uvicorn is run with a FastAPI app.
  • Loading branch information
apotterri committed Mar 11, 2024
1 parent 597afe0 commit 27dec47
Show file tree
Hide file tree
Showing 21 changed files with 581 additions and 132 deletions.
39 changes: 28 additions & 11 deletions _appmap/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@
from .recorder import Recorder
from .utils import (
FnType,
FqFnName,
appmap_tls,
compact_dict,
fqname,
get_function_location,
split_function_name,
)

logger = Env.current.getLogger(__name__)
Expand Down Expand Up @@ -173,7 +173,7 @@ def to_dict(self, value):


class CallEvent(Event):
__slots__ = ["_fn", "static", "receiver", "parameters", "labels"]
__slots__ = ["_fn", "_fqfn", "static", "receiver", "parameters", "labels"]

@staticmethod
def make(fn, fntype):
Expand Down Expand Up @@ -209,10 +209,9 @@ def make_params(filterable):
# going to log a message about a mismatch.
wrapped_sig = inspect.signature(fn, follow_wrapped=True)
if sig != wrapped_sig:
logger.debug(
"signature of wrapper %s.%s doesn't match wrapped",
*split_function_name(fn)
)
logger.debug("signature of wrapper %r doesn't match wrapped", fn)
logger.debug("sig: %r", sig)
logger.debug("wrapped_sig: %r", wrapped_sig)

return [Param(p) for p in sig.parameters.values()]

Expand Down Expand Up @@ -270,17 +269,17 @@ def set_params(params, instance, args, kwargs):
@property
@lru_cache(maxsize=None)
def function_name(self):
return split_function_name(self._fn)
return self._fqfn.fqfn

@property
@lru_cache(maxsize=None)
def defined_class(self):
return self.function_name[0]
return self._fqfn.fqclass

@property
@lru_cache(maxsize=None)
def method_id(self):
return self.function_name[1]
return self._fqfn.fqfn[1]

@property
@lru_cache(maxsize=None)
Expand Down Expand Up @@ -308,6 +307,7 @@ def comment(self):
def __init__(self, fn, fntype, parameters, labels):
super().__init__("call")
self._fn = fn
self._fqfn = FqFnName(fn)
self.static = fntype in FnType.STATIC | FnType.CLASS | FnType.MODULE
self.receiver = None
if fntype in FnType.CLASS | FnType.INSTANCE:
Expand Down Expand Up @@ -351,7 +351,15 @@ class MessageEvent(Event): # pylint: disable=too-few-public-methods
def __init__(self, message_parameters):
super().__init__("call")
self.message = []
for name, value in message_parameters.items():
self.message_parameters = message_parameters

@property
def message_parameters(self):
return self.message

@message_parameters.setter
def message_parameters(self, params):
for name, value in params.items():
message_object = describe_value(name, value)
self.message.append(message_object)

Expand Down Expand Up @@ -386,6 +394,7 @@ def __init__(self, request_method, url, message_parameters, headers=None):


# pylint: disable=too-few-public-methods
_NORMALIZED_PATH_INFO_ATTR = "normalized_path_info"
class HttpServerRequestEvent(MessageEvent):
"""A call AppMap event representing an HTTP server request."""

Expand All @@ -406,7 +415,7 @@ def __init__(
"request_method": request_method,
"protocol": protocol,
"path_info": path_info,
"normalized_path_info": normalized_path_info,
_NORMALIZED_PATH_INFO_ATTR: normalized_path_info,
}

if headers is not None:
Expand All @@ -420,6 +429,14 @@ def __init__(

self.http_server_request = compact_dict(request)

@property
def normalized_path_info(self):
return self.http_server_request.get(_NORMALIZED_PATH_INFO_ATTR, None)

@normalized_path_info.setter
def normalized_path_info(self, npi):
self.http_server_request[_NORMALIZED_PATH_INFO_ATTR] = npi


class ReturnEvent(Event):
__slots__ = ["parent_id", "elapsed"]
Expand Down
49 changes: 24 additions & 25 deletions _appmap/importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,56 +10,49 @@
from _appmap import wrapt

from .env import Env
from .utils import FnType
from .utils import FnType, Scope

logger = Env.current.getLogger(__name__)


Filterable = namedtuple("Filterable", "fqname obj")
Filterable = namedtuple("Filterable", "scope fqname obj")


class FilterableMod(Filterable):
__slots__ = ()

def __new__(cls, mod):
fqname = mod.__name__
return super(FilterableMod, cls).__new__(cls, fqname, mod)

def classify_fn(self, _):
return FnType.MODULE
return super(FilterableMod, cls).__new__(cls, Scope.MODULE, fqname, mod)


class FilterableCls(Filterable):
__slots__ = ()

def __new__(cls, clazz):
fqname = "%s.%s" % (clazz.__module__, clazz.__qualname__)
return super(FilterableCls, cls).__new__(cls, fqname, clazz)

def classify_fn(self, static_fn):
return FnType.classify(static_fn)
return super(FilterableCls, cls).__new__(cls, Scope.CLASS, fqname, clazz)


class FilterableFn(
namedtuple(
"FilterableFn",
Filterable._fields
+ (
"scope",
"static_fn",
),
Filterable._fields + ("static_fn",),
)
):
__slots__ = ()

def __new__(cls, scope, fn, static_fn):
fqname = "%s.%s" % (scope.fqname, fn.__name__)
self = super(FilterableFn, cls).__new__(cls, fqname, fn, scope, static_fn)
self = super(FilterableFn, cls).__new__(cls, scope.scope, fqname, fn, static_fn)
return self

@property
def fntype(self):
return self.scope.classify_fn(self.static_fn)
if self.scope == Scope.MODULE:
return FnType.MODULE

return FnType.classify(self.static_fn)


class Filter(ABC): # pylint: disable=too-few-public-methods
Expand Down Expand Up @@ -161,6 +154,17 @@ def initialize(cls):
def use_filter(cls, filter_class):
cls.filter_stack.append(filter_class)

@classmethod
def instrument_function(cls, fn_name, filterableFn: FilterableFn, selected_functions=None):
# Only instrument the function if it was specifically called out for the package
# (e.g. because it should be labeled), or it's included by the filters
matched = cls.filter_chain.filter(filterableFn)
selected = selected_functions and fn_name in selected_functions
if selected or matched:
return cls.filter_chain.wrap(filterableFn)

return filterableFn.obj

@classmethod
def do_import(cls, *args, **kwargs):
mod = args[0]
Expand All @@ -177,15 +181,10 @@ def instrument_functions(filterable, selected_functions=None):
logger.trace(" functions %s", functions)

for fn_name, static_fn, fn in functions:
# Only instrument the function if it was specifically called out for the package
# (e.g. because it should be labeled), or it's included by the filters
filterableFn = FilterableFn(filterable, fn, static_fn)
matched = cls.filter_chain.filter(filterableFn)
selected = selected_functions and fn_name in selected_functions
if selected or matched:
new_fn = cls.filter_chain.wrap(filterableFn)
if fn != new_fn:
wrapt.wrap_function_wrapper(filterable.obj, fn_name, new_fn)
new_fn = cls.instrument_function(fn_name, filterableFn, selected_functions)
if new_fn != fn:
wrapt.wrap_function_wrapper(filterable.obj, fn_name, new_fn)

# Import Config here, to avoid circular top-level imports.
from .configuration import Config # pylint: disable=import-outside-toplevel
Expand Down
3 changes: 3 additions & 0 deletions _appmap/test/data/fastapi/appmap.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
name: FastAPITest
packages:
- path: fastapiapp
Empty file.
67 changes: 67 additions & 0 deletions _appmap/test/data/fastapi/fastapiapp/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
"""
Rudimentary FastAPI application for testing.
NB: This should not explicitly reference the `appmap` module in any way. Doing so invalidates
testing of record-by-default.
"""
# pylint: disable=missing-function-docstring

from typing import List

from fastapi import FastAPI, Query, Request, Response

app = FastAPI()


@app.get("/")
def hello_world():
return {"Hello": "World!"}


@app.post("/echo")
async def echo(request: Request):
body = await request.body()
return Response(content=body, media_type="application/json")


@app.get("/test")
async def get_test(my_params: List[str] = Query(None)):
response = Response(content="testing", media_type="text/html; charset=utf-8")
response.headers["ETag"] = "W/01"
return response


@app.post("/test")
async def post_test(request: Request):
await request.json()
response = Response(content='{"test":true}', media_type="application/json")
response.headers["ETag"] = "W/01"
return response


@app.get("/user/{username}")
def get_user_profile(username):
# show the user profile for that user
return {"user": username}


@app.get("/post/{post_id:int}")
def get_post(post_id):
# show the post with the given id, the id is an integer
return {"post": post_id}


@app.get("/post/{username}/{post_id:int}/summary")
def get_user_post(username, post_id):
# Show the summary of a user's post
return {"user": username, "post": post_id}


@app.get("/{org:int}/posts/{username}")
def get_org_user_posts(org, username):
return {"org": org, "username": username}


@app.route("/exception")
def raise_exception():
raise Exception("An exception")
1 change: 1 addition & 0 deletions _appmap/test/data/fastapi/init/sitecustomize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
import appmap
14 changes: 14 additions & 0 deletions _appmap/test/data/fastapi/test_app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import pytest
from fastapi.testclient import TestClient
from fastapiapp import app


@pytest.fixture
def client():
yield TestClient(app)


def test_request(client):
response = client.get("/")

assert response.status_code == 200
16 changes: 4 additions & 12 deletions _appmap/test/data/remote.appmap.json
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,15 @@
"protocol": "HTTP/1.1",
"request_method": "GET"
},
"id": 1,
"thread_id": 1
"id": 1
},
{
"event": "return",
"http_server_response": {
"headers": { "Content-Type": "text/html; charset=utf-8" },
"mime_type": "text/html; charset=utf-8",
"status_code": 200
},
"id": 2,
"parent_id": 1,
"thread_id": 1
"parent_id": 1
},
{
"event": "call",
Expand All @@ -31,19 +27,15 @@
"protocol": "HTTP/1.1",
"request_method": "GET"
},
"id": 3,
"thread_id": 1
"id": 3
},
{
"event": "return",
"http_server_response": {
"headers": { "Content-Type": "text/html; charset=utf-8" },
"mime_type": "text/html; charset=utf-8",
"status_code": 200
},
"id": 4,
"parent_id": 3,
"thread_id": 1
"parent_id": 3
}
],
"metadata": {
Expand Down
11 changes: 8 additions & 3 deletions _appmap/test/normalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,13 @@ def normalize_headers(dct):
"""Remove some headers which are variable between implementations.
This allows sharing tests between web frameworks, for example.
"""
for hdr in ["User-Agent", "Content-Length", "ETag", "Cookie", "Host"]:
value = dct.pop(hdr, None)
assert value is None or isinstance(value, str)
for key in list(dct.keys()):
value = dct.pop(key, None)
key = key.lower()
if key in ["user-agent", "content-length", "content-type", "etag", "cookie", "host"]:
assert value is None or isinstance(value, str)
else:
dct[key] = value


def normalize_appmap(generated_appmap):
Expand Down Expand Up @@ -82,6 +86,7 @@ def normalize(dct):
if len(dct["headers"]) == 0:
del dct["headers"]
if "http_server_request" in dct:
dct["http_server_request"].pop("headers", None)
normalize(dct["http_server_request"])
if "message" in dct:
del dct["message"]
Expand Down
Loading

0 comments on commit 27dec47

Please sign in to comment.