Skip to content

Commit

Permalink
Support for streaming response (#35)
Browse files Browse the repository at this point in the history
* Support for streaming response.
* Flushing response body right away.
* To avoid a bug where Python was writing data in the same buffer and the response was corrupted before sending it to the user we now copy the full response body to a char* before passing it to Go.
* Freeing request body after the body was processed and adding a new AsgiEvent_cleanup called after all operations have been processed.
  • Loading branch information
mliezun authored Aug 9, 2024
1 parent 3c1af53 commit 1ed00f0
Show file tree
Hide file tree
Showing 7 changed files with 121 additions and 32 deletions.
43 changes: 31 additions & 12 deletions caddysnake.c
Original file line number Diff line number Diff line change
Expand Up @@ -505,6 +505,7 @@ struct AsgiEvent {
PyObject *event_ts;
PyObject *future;
PyObject *request_body;
uint8_t more_body;
};

static PyObject *AsgiEvent_new(PyTypeObject *type, PyObject *args,
Expand All @@ -516,6 +517,7 @@ static PyObject *AsgiEvent_new(PyTypeObject *type, PyObject *args,
self->event_ts = NULL;
self->future = NULL;
self->request_body = NULL;
self->more_body = 0;
}
return (PyObject *)self;
}
Expand All @@ -524,16 +526,26 @@ static void AsgiEvent_dealloc(AsgiEvent *self) {
Py_XDECREF(self->event_ts);
// Future is freed in AsgiEvent_result
// Py_XDECREF(self->future);
// Request body is freed in AsgiEvent_receive_end
// Py_XDECREF(self->request_body);
// Request body is also freed in AsgiEvent_set
Py_XDECREF(self->request_body);
Py_TYPE(self)->tp_free((PyObject *)self);
}

void AsgiEvent_set(AsgiEvent *self, const char *body) {
void AsgiEvent_cleanup(AsgiEvent *event) {
PyGILState_STATE gstate = PyGILState_Ensure();
Py_DECREF(event);
PyGILState_Release(gstate);
}

void AsgiEvent_set(AsgiEvent *self, const char *body, uint8_t more_body) {
PyGILState_STATE gstate = PyGILState_Ensure();
if (body) {
if (self->request_body) {
Py_DECREF(self->request_body);
}
self->request_body = PyBytes_FromString(body);
}
self->more_body = more_body;
PyObject *set_fn = PyObject_GetAttrString((PyObject *)self->event_ts, "set");
PyObject_CallNoArgs(set_fn);
Py_DECREF(set_fn);
Expand All @@ -557,18 +569,27 @@ static PyObject *AsgiEvent_clear(AsgiEvent *self, PyObject *args) {
}

static PyObject *AsgiEvent_receive_start(AsgiEvent *self, PyObject *args) {
asgi_receive_start(self->request_id, self);
Py_RETURN_NONE;
PyObject *result = Py_False;
if (asgi_receive_start(self->request_id, self) == 1) {
result = Py_True;
}
#if PY_MINOR_VERSION < 12
Py_INCREF(result);
#endif
return result;
}

static PyObject *AsgiEvent_receive_end(AsgiEvent *self, PyObject *args) {
PyObject *data = PyDict_New();
PyObject *data_type = PyUnicode_FromString("http.request");
PyDict_SetItemString(data, "type", data_type);
PyDict_SetItemString(data, "body", self->request_body);
PyDict_SetItemString(data, "more_body", Py_False);
PyObject *more_body = Py_False;
if (self->more_body) {
more_body = Py_True;
}
PyDict_SetItemString(data, "more_body", more_body);
Py_DECREF(data_type);
Py_DECREF(self->request_body);
return data;
}

Expand Down Expand Up @@ -647,9 +668,9 @@ static PyObject *AsgiEvent_send(AsgiEvent *self, PyObject *args) {
PyObject_RichCompareBool(more_body, Py_False, Py_EQ) == 1) {
send_more_body = 0;
}
PyObject *body = PyDict_GetItemString(data, "body");
asgi_send_response(self->request_id, PyBytes_AsString(body), send_more_body,
self);
PyObject *pybody = PyDict_GetItemString(data, "body");
char *body = copy_pybytes(pybody);
asgi_send_response(self->request_id, body, send_more_body, self);
}
Py_RETURN_NONE;
}
Expand Down Expand Up @@ -774,8 +795,6 @@ void AsgiApp_handle_request(AsgiApp *app, uint64_t request_id, MapKeyVal *scope,
Py_DECREF(add_done_callback);
Py_DECREF(asgi_event_result);

Py_DECREF(asgi_event);

PyGILState_Release(gstate);
}

Expand Down
67 changes: 53 additions & 14 deletions caddysnake.go
Original file line number Diff line number Diff line change
Expand Up @@ -524,9 +524,13 @@ func (m *Asgi) Cleanup() (err error) {

// AsgiRequestHandler stores pointers to the request and the response writer
type AsgiRequestHandler struct {
w http.ResponseWriter
r *http.Request
done chan error
event *C.AsgiEvent
w http.ResponseWriter
r *http.Request
completed_body bool
completed_response bool
accumulated_response_size int
done chan error

operations chan AsgiOperations

Expand All @@ -546,6 +550,12 @@ func (h *AsgiRequestHandler) consume() {
o.op()
}
if o.stop {
if h.event != nil {
runtime.LockOSThread()
C.AsgiEvent_cleanup(h.event)
runtime.UnlockOSThread()
}
close(h.operations)
break
}
}
Expand All @@ -559,7 +569,7 @@ func NewAsgiRequestHandler(w http.ResponseWriter, r *http.Request) *AsgiRequestH
r: r,
done: make(chan error, 2),

operations: make(chan AsgiOperations, 4),
operations: make(chan AsgiOperations, 16),
}
go h.consume()
return h
Expand Down Expand Up @@ -671,6 +681,7 @@ func (m *Asgi) HandleRequest(w http.ResponseWriter, r *http.Request) error {
asgi_handlers[request_id] = arh
asgi_lock.Unlock()
defer func() {
arh.completed_response = true
arh.operations <- AsgiOperations{stop: true}
asgi_lock.Lock()
delete(asgi_handlers, request_id)
Expand Down Expand Up @@ -698,24 +709,43 @@ func (m *Asgi) HandleRequest(w http.ResponseWriter, r *http.Request) error {
}

//export asgi_receive_start
func asgi_receive_start(request_id C.uint64_t, event *C.AsgiEvent) {
func asgi_receive_start(request_id C.uint64_t, event *C.AsgiEvent) C.uint8_t {
asgi_lock.Lock()
defer asgi_lock.Unlock()
arh := asgi_handlers[uint64(request_id)]
if arh == nil || arh.completed_response {
return C.uint8_t(0)
}

arh.event = event

arh.operations <- AsgiOperations{op: func() {
body, err := io.ReadAll(arh.r.Body)
if err != nil {
arh.done <- err
return
var body_str *C.char
var more_body C.uint8_t
if !arh.completed_body {
buffer := make([]byte, 4096)
_, err := arh.r.Body.Read(buffer)
if err != nil && err != io.EOF {
arh.done <- err
return
}
arh.completed_body = (err == io.EOF)
body_str = C.CString(string(buffer))
defer C.free(unsafe.Pointer(body_str))
}

if arh.completed_body {
more_body = C.uint8_t(0)
} else {
more_body = C.uint8_t(1)
}
body_str := C.CString(string(body))
defer C.free(unsafe.Pointer(body_str))

runtime.LockOSThread()
C.AsgiEvent_set(event, body_str)
C.AsgiEvent_set(event, body_str, more_body)
runtime.UnlockOSThread()
}}

return C.uint8_t(1)
}

//export asgi_set_headers
Expand All @@ -724,6 +754,8 @@ func asgi_set_headers(request_id C.uint64_t, status_code C.int, headers *C.MapKe
defer asgi_lock.Unlock()
arh := asgi_handlers[uint64(request_id)]

arh.event = event

arh.operations <- AsgiOperations{op: func() {
if headers != nil {
size_of_pointer := unsafe.Sizeof(headers.keys)
Expand All @@ -745,7 +777,7 @@ func asgi_set_headers(request_id C.uint64_t, status_code C.int, headers *C.MapKe
arh.w.WriteHeader(int(status_code))

runtime.LockOSThread()
C.AsgiEvent_set(event, nil)
C.AsgiEvent_set(event, nil, C.uint8_t(0))
runtime.UnlockOSThread()
}}
}
Expand All @@ -756,17 +788,24 @@ func asgi_send_response(request_id C.uint64_t, body *C.char, more_body C.uint8_t
defer asgi_lock.Unlock()
arh := asgi_handlers[uint64(request_id)]

arh.event = event

arh.operations <- AsgiOperations{op: func() {
defer C.free(unsafe.Pointer(body))
body_bytes := []byte(C.GoString(body))
arh.accumulated_response_size += len(body_bytes)
_, err := arh.w.Write(body_bytes)
if f, ok := arh.w.(http.Flusher); ok {
f.Flush()
}
if err != nil {
arh.done <- err
} else if int(more_body) == 0 {
arh.done <- nil
}

runtime.LockOSThread()
C.AsgiEvent_set(event, nil)
C.AsgiEvent_set(event, nil, C.uint8_t(0))
runtime.UnlockOSThread()
}}
}
Expand Down
5 changes: 3 additions & 2 deletions caddysnake.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,11 @@ uint8_t AsgiApp_lifespan_startup(AsgiApp *);
uint8_t AsgiApp_lifespan_shutdown(AsgiApp *);
void AsgiApp_handle_request(AsgiApp *, uint64_t, MapKeyVal *, MapKeyVal *,
const char *, int, const char *, int);
void AsgiEvent_set(AsgiEvent *, const char *);
void AsgiEvent_set(AsgiEvent *, const char *, uint8_t);
void AsgiEvent_cleanup(AsgiEvent *);
void AsgiApp_cleanup(AsgiApp *);

extern void asgi_receive_start(uint64_t, AsgiEvent *);
extern uint8_t asgi_receive_start(uint64_t, AsgiEvent *);
extern void asgi_send_response(uint64_t, char *, uint8_t, AsgiEvent *);
extern void asgi_set_headers(uint64_t, int, MapKeyVal *, AsgiEvent *);
extern void asgi_cancel_request(uint64_t);
Expand Down
11 changes: 7 additions & 4 deletions caddysnake.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,13 @@ def set(self):

def build_receive(asgi_event):
async def receive():
asgi_event.receive_start()
await asgi_event.wait()
asgi_event.clear()
return asgi_event.receive_end()
if asgi_event.receive_start():
await asgi_event.wait()
asgi_event.clear()
result = asgi_event.receive_end()
return result
else:
return {"type": "http.disconnect"}

return receive

Expand Down
8 changes: 8 additions & 0 deletions tests/fastapi/Caddyfile
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,14 @@ localhost:9080 {
}
}

route /stream-item/* {
python {
module_asgi "main:app"
lifespan on
venv "./venv"
}
}

route / {
respond 404
}
Expand Down
11 changes: 11 additions & 0 deletions tests/fastapi/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from contextlib import asynccontextmanager

from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from pydantic import BaseModel


Expand Down Expand Up @@ -39,3 +40,13 @@ async def store_item(id: str, item: Item):
async def delete_item(id: str):
del db[id]
return "Deleted"

def chunked_blob(blob: str):
chunk_size = 2**20
for i in range(0, len(blob), chunk_size):
chunk = blob[i:i+chunk_size]
yield chunk

@app.get("/stream-item/{id}")
async def item_stream(id: str) -> StreamingResponse:
return StreamingResponse(chunked_blob(db[id].blob), media_type='text/event-stream')
8 changes: 8 additions & 0 deletions tests/fastapi/main_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,20 @@ def delete_item(id: str):
response = requests.delete(f"{BASE_URL}/item/{id}")
return response.status_code == 200 and b"Deleted" in response.content

def stream_content(id: str, item: dict):
response = requests.get(f"{BASE_URL}/stream-item/{id}", stream=True)
if not response.ok:
return False
blob = b"".join(response.iter_content(chunk_size=2**20))
return blob.decode() == item["blob"]

def item_lifecycle():
id = str(uuid.uuid4())
item = get_dummy_item()
assert store_item(id, item), "Store item failed"
assert get_item(id, item), "Get item failed"
if item["blob"]:
assert stream_content(id, item), "Failed to stream content"
assert delete_item(id), "Delete item failed"
assert not delete_item(id), "Delete item should fail"

Expand Down

0 comments on commit 1ed00f0

Please sign in to comment.