diff --git a/.github/workflows/integration_tests.yaml b/.github/workflows/integration_tests.yaml index 49046b1..33bdfcd 100644 --- a/.github/workflows/integration_tests.yaml +++ b/.github/workflows/integration_tests.yaml @@ -13,7 +13,7 @@ jobs: strategy: fail-fast: false matrix: - tool-name: ['flask', 'fastapi', 'simple', 'simple_async'] + tool-name: ['flask', 'fastapi', 'simple', 'simple_async', 'socketio'] python-version: ['3.9', '3.10', '3.11', '3.12'] env: GOEXPERIMENT: cgocheck2 diff --git a/README.md b/README.md index 23b528f..66962c1 100644 --- a/README.md +++ b/README.md @@ -10,6 +10,10 @@ Supports both WSGI and ASGI, which means you can run all types of frameworks lik ## Quickstart +``` +CGO_ENABLED=1 xcaddy build --with github.com/mliezun/caddy-snake +``` + #### Requirements - Python >= 3.9 + dev files diff --git a/caddysnake.c b/caddysnake.c index 32106f4..01d1968 100644 --- a/caddysnake.c +++ b/caddysnake.c @@ -25,6 +25,7 @@ static PyObject *asyncio_run_coroutine_threadsafe; static PyObject *build_receive; static PyObject *build_send; static PyObject *build_lifespan; +static PyObject *websocket_closed; char *concatenate_strings(const char *str1, const char *str2) { size_t new_str_len = strlen(str1) + strlen(str2) + 1; @@ -49,18 +50,19 @@ char *copy_pystring(PyObject *pystr) { return result; } -char *copy_pybytes(PyObject *pybytes) { +char *copy_pybytes(PyObject *pybytes, size_t *size) { Py_ssize_t og_size = 0; + *size = 0; char *og_str; if (PyBytes_AsStringAndSize(pybytes, &og_str, &og_size) < 0) { return NULL; } - size_t new_str_len = og_size + 1; - char *result = malloc(new_str_len * sizeof(char)); + char *result = malloc(og_size * sizeof(char)); if (result == NULL) { return NULL; } - strcpy(result, og_str); + memcpy(result, og_str, og_size); + *size = (size_t)og_size; return result; } @@ -502,28 +504,37 @@ uint8_t AsgiApp_lifespan_shutdown(AsgiApp *app) { struct AsgiEvent { PyObject_HEAD AsgiApp *app; uint64_t request_id; - PyObject *event_ts; + PyObject *event_ts_send; + PyObject *event_ts_receive; PyObject *future; PyObject *request_body; uint8_t more_body; + uint8_t websockets_state; }; +#define WS_NONE 0 +#define WS_CONNECTED 1 +#define WS_DISCONNECTED 2 + static PyObject *AsgiEvent_new(PyTypeObject *type, PyObject *args, PyObject *kwds) { AsgiEvent *self; self = (AsgiEvent *)type->tp_alloc(type, 0); if (self != NULL) { self->request_id = 0; - self->event_ts = NULL; + self->event_ts_send = NULL; + self->event_ts_receive = NULL; self->future = NULL; self->request_body = NULL; self->more_body = 0; + self->websockets_state = WS_NONE; } return (PyObject *)self; } static void AsgiEvent_dealloc(AsgiEvent *self) { - Py_XDECREF(self->event_ts); + Py_XDECREF(self->event_ts_send); + Py_XDECREF(self->event_ts_receive); // Future is freed in AsgiEvent_result // Py_XDECREF(self->future); // Request body is also freed in AsgiEvent_set @@ -537,7 +548,8 @@ void AsgiEvent_cleanup(AsgiEvent *event) { PyGILState_Release(gstate); } -void AsgiEvent_set(AsgiEvent *self, const char *body, uint8_t more_body) { +void AsgiEvent_set(AsgiEvent *self, const char *body, uint8_t more_body, + uint8_t is_send) { PyGILState_STATE gstate = PyGILState_Ensure(); if (body) { if (self->request_body) { @@ -546,53 +558,160 @@ void AsgiEvent_set(AsgiEvent *self, const char *body, uint8_t more_body) { self->request_body = PyBytes_FromString(body); } self->more_body = more_body; - PyObject *set_fn = PyObject_GetAttrString((PyObject *)self->event_ts, "set"); + PyObject *set_fn = NULL; + if (is_send) { + set_fn = PyObject_GetAttrString((PyObject *)self->event_ts_send, "set"); + } else { + set_fn = PyObject_GetAttrString((PyObject *)self->event_ts_receive, "set"); + } PyObject_CallNoArgs(set_fn); Py_DECREF(set_fn); PyGILState_Release(gstate); } -static PyObject *AsgiEvent_wait(AsgiEvent *self, PyObject *args) { - PyObject *wait_fn = - PyObject_GetAttrString((PyObject *)self->event_ts, "wait"); - PyObject *coro = PyObject_CallNoArgs(wait_fn); - Py_DECREF(wait_fn); - return coro; +void AsgiEvent_set_websocket(AsgiEvent *self, const char *body, + uint8_t message_type, uint8_t is_send) { + PyGILState_STATE gstate = PyGILState_Ensure(); + if (body) { + if (!self->request_body) { + self->request_body = PyList_New(0); + } + PyObject *tuple = PyTuple_New(2); + if (message_type == 0) { + PyTuple_SetItem(tuple, 0, PyUnicode_FromString(body)); + } else { + PyTuple_SetItem(tuple, 0, PyBytes_FromString(body)); + } + PyTuple_SetItem(tuple, 1, PyLong_FromLong(message_type)); + PyList_Append(self->request_body, tuple); + Py_DECREF(tuple); // WARNING: not sure if this should go here + } + PyObject *set_fn = NULL; + if (is_send) { + set_fn = PyObject_GetAttrString((PyObject *)self->event_ts_send, "set"); + } else { + set_fn = PyObject_GetAttrString((PyObject *)self->event_ts_receive, "set"); + } + PyObject_CallNoArgs(set_fn); + Py_DECREF(set_fn); + PyGILState_Release(gstate); } -static PyObject *AsgiEvent_clear(AsgiEvent *self, PyObject *args) { - PyObject *clear_fn = - PyObject_GetAttrString((PyObject *)self->event_ts, "clear"); - PyObject_CallNoArgs(clear_fn); - Py_DECREF(clear_fn); - Py_RETURN_NONE; +void AsgiEvent_connect_websocket(AsgiEvent *self) { + self->websockets_state = WS_CONNECTED; +} + +void AsgiEvent_disconnect_websocket(AsgiEvent *self) { + self->websockets_state = WS_DISCONNECTED; } static PyObject *AsgiEvent_receive_start(AsgiEvent *self, PyObject *args) { PyObject *result = Py_False; if (asgi_receive_start(self->request_id, self) == 1) { - result = Py_True; + // WARNING: should I incref here? + Py_INCREF(self->event_ts_receive); + result = self->event_ts_receive; } #if PY_MINOR_VERSION < 12 - Py_INCREF(result); + if (result == Py_False) + Py_INCREF(result); #endif return result; } static PyObject *AsgiEvent_receive_end(AsgiEvent *self, PyObject *args) { PyObject *data = PyDict_New(); - PyObject *data_type = PyUnicode_FromString("http.request"); - PyDict_SetItemString(data, "type", data_type); - PyDict_SetItemString(data, "body", self->request_body); - PyObject *more_body = Py_False; - if (self->more_body) { - more_body = Py_True; - } - PyDict_SetItemString(data, "more_body", more_body); - Py_DECREF(data_type); + switch (self->websockets_state) { + case WS_NONE: { + PyObject *data_type = PyUnicode_FromString("http.request"); + PyDict_SetItemString(data, "type", data_type); + PyDict_SetItemString(data, "body", self->request_body); + PyObject *more_body = Py_False; + if (self->more_body) { + more_body = Py_True; + } + PyDict_SetItemString(data, "more_body", more_body); + Py_DECREF(data_type); + break; + } + + case WS_CONNECTED: { + if (!self->request_body) { + PyObject *data_type = PyUnicode_FromString("websocket.connect"); + PyDict_SetItemString(data, "type", data_type); + Py_DECREF(data_type); + } else { + PyObject *data_type = PyUnicode_FromString("websocket.receive"); + PyDict_SetItemString(data, "type", data_type); + PyObject *pop_fn = PyObject_GetAttrString(self->request_body, "pop"); + PyObject *ix = PyLong_FromLong(0); + PyObject *message = PyObject_CallOneArg(pop_fn, ix); + PyObject *message_data = PyTuple_GetItem(message, 0); + PyObject *message_type = PyTuple_GetItem(message, 1); + if (message_type == ix) { + PyDict_SetItemString(data, "text", message_data); + } else { + PyDict_SetItemString(data, "bytes", message_data); + } + Py_DECREF(message); // WARNING: not sure if this should be here + Py_DECREF(ix); // WARNING: not sure if this should be here + Py_DECREF(pop_fn); + Py_DECREF(data_type); + } + break; + } + + case WS_DISCONNECTED: { + PyObject *data_type = PyUnicode_FromString("websocket.disconnect"); + PyDict_SetItemString(data, "type", data_type); + Py_DECREF(data_type); + PyObject *default_code = PyLong_FromLong(1005); + PyObject *close_code = default_code; + if (self->request_body && PyList_Size(self->request_body) > 0) { + PyObject *pop_fn = PyObject_GetAttrString(self->request_body, "pop"); + PyObject *ix = PyLong_FromLong(0); + PyObject *message = PyObject_CallOneArg(pop_fn, ix); + PyObject *message_data = PyTuple_GetItem(message, 0); + PyObject *message_type = PyTuple_GetItem(message, 1); + if (message_type == ix) { + close_code = PyLong_FromUnicodeObject(message_data, 10); + if (!close_code) { + if (PyErr_Occurred()) { + PyErr_Clear(); + } + close_code = default_code; + } + } + Py_DECREF(message); // WARNING: not sure if this should be here + Py_DECREF(ix); // WARNING: not sure if this should be here + Py_DECREF(pop_fn); + } + PyDict_SetItemString(data, "code", close_code); + if (close_code != default_code) { + Py_DECREF(close_code); // WARNING: not sure if this should be here + } + Py_DECREF(default_code); // WARNING: not sure if this should be here + break; + } + } return data; } +uint8_t is_weboscket_closed(PyObject *exc) { + if (PyErr_GivenExceptionMatches(exc, websocket_closed)) { + return 1; + } + PyObject *cause = PyObject_GetAttrString(exc, "__cause__"); + if (cause) { + if (PyErr_GivenExceptionMatches(cause, websocket_closed)) { + Py_DECREF(cause); + return 1; + } + Py_DECREF(cause); + } + return 0; +} + /* AsgiEvent_result is called when an execution of AsgiApp finishes. */ @@ -601,14 +720,20 @@ static PyObject *AsgiEvent_result(AsgiEvent *self, PyObject *args) { PyObject_GetAttrString(self->future, "exception"); PyObject *exc = PyObject_CallNoArgs(future_exception); if (exc != Py_None) { + if (!is_weboscket_closed(exc)) { #if PY_MINOR_VERSION >= 12 - // PyErr_DisplayException was introduced in Python 3.12 - PyErr_DisplayException(exc); + // PyErr_DisplayException was introduced in Python 3.12 + PyErr_DisplayException(exc); #else - PyErr_Display(NULL, exc, NULL); + PyErr_Display(NULL, exc, NULL); #endif + if (self->websockets_state == WS_NONE) { + asgi_cancel_request(self->request_id); + } else { + asgi_cancel_request_websocket(self->request_id, NULL, 1000); + } + } Py_DECREF(exc); - asgi_cancel_request(self->request_id); } Py_DECREF(future_exception); @@ -638,6 +763,7 @@ static PyObject *AsgiEvent_send(AsgiEvent *self, PyObject *args) { PyObject *key, *value, *item; size_t pos = 0; + size_t len = 0; while ((item = PyIter_Next(iterator))) { // if (!PyTuple_Check(item) || PyTuple_Size(item) != 2) { // PyErr_SetString(PyExc_RuntimeError, @@ -651,8 +777,8 @@ static PyObject *AsgiEvent_send(AsgiEvent *self, PyObject *args) { // } key = PyTuple_GetItem(item, 0); value = PyTuple_GetItem(item, 1); - http_headers->keys[pos] = copy_pybytes(key); - http_headers->values[pos] = copy_pybytes(value); + http_headers->keys[pos] = copy_pybytes(key, &len); + http_headers->values[pos] = copy_pybytes(value, &len); Py_DECREF(item); pos++; } @@ -669,18 +795,135 @@ static PyObject *AsgiEvent_send(AsgiEvent *self, PyObject *args) { send_more_body = 0; } PyObject *pybody = PyDict_GetItemString(data, "body"); - char *body = copy_pybytes(pybody); - asgi_send_response(self->request_id, body, send_more_body, self); + size_t body_len = 0; + char *body = copy_pybytes(pybody, &body_len); + asgi_send_response(self->request_id, body, body_len, send_more_body, self); + } else if (PyUnicode_CompareWithASCIIString(data_type, "websocket.accept") == + 0) { + if (self->websockets_state == WS_DISCONNECTED) { + goto websocket_error; + } + + PyObject *headers = PyDict_GetItemString(data, "headers"); + PyObject *subprotocol = PyDict_GetItemString(data, "subprotocol"); + + Py_ssize_t headers_count = 0; + PyObject *iterator = NULL; + if (headers) { + iterator = PyObject_GetIter(headers); + if (PyTuple_Check(headers)) { + headers_count = PyTuple_Size(headers); + } else if (PyList_Check(headers)) { + headers_count = PyList_Size(headers); + } + } + if (subprotocol && subprotocol != Py_None) { + headers_count += 1; + } + + MapKeyVal *http_headers = MapKeyVal_new(headers_count); + size_t pos = 0; + size_t len = 0; + + if (iterator) { + PyObject *key, *value, *item; + while ((item = PyIter_Next(iterator))) { + // if (!PyTuple_Check(item) || PyTuple_Size(item) != 2) { + // PyErr_SetString(PyExc_RuntimeError, + // "expected response headers to be tuples with 2 + // items"); + // PyErr_Print(); + // Py_DECREF(item); + // Py_DECREF(iterator); + // MapKeyVal_free(http_headers, pos); + // goto finalize_error; + // } + key = PyTuple_GetItem(item, 0); + value = PyTuple_GetItem(item, 1); + http_headers->keys[pos] = copy_pybytes(key, &len); + http_headers->values[pos] = copy_pybytes(value, &len); + Py_DECREF(item); + pos++; + } + Py_DECREF(iterator); + } + + if (subprotocol && subprotocol != Py_None) { + http_headers->keys[pos] = + concatenate_strings("sec-websocket-protocol", ""); + http_headers->values[pos] = copy_pybytes(subprotocol, &len); + pos++; + } + + asgi_set_headers(self->request_id, 101, http_headers, self); + + if (self->websockets_state == WS_DISCONNECTED) { + goto websocket_error; + } + } else if (PyUnicode_CompareWithASCIIString(data_type, "websocket.send") == + 0) { + + if (self->websockets_state == WS_DISCONNECTED) { + goto websocket_error; + } + + PyObject *data_text = PyDict_GetItemString(data, "text"); + char *body = NULL; + size_t body_len = 0; + uint8_t message_type = 0; + if (data_text) { + body = copy_pystring(data_text); + message_type = 0; + } else { + body = copy_pybytes(PyDict_GetItemString(data, "bytes"), &body_len); + message_type = 1; + } + asgi_send_response_websocket(self->request_id, body, body_len, message_type, + self); + + if (self->websockets_state == WS_DISCONNECTED) { + goto websocket_error; + } + } else if (PyUnicode_CompareWithASCIIString(data_type, "websocket.close") == + 0) { + + if (self->websockets_state == WS_DISCONNECTED) { + goto websocket_error; + } + + PyObject *close_code = PyDict_GetItemString(data, "code"); + PyObject *close_reason = PyDict_GetItemString(data, "reason"); + int code = 1000; + if (close_code) { + code = PyLong_AsLong(close_code); + } + char *reason = NULL; + if (close_reason) { + reason = copy_pystring(close_reason); + } + + asgi_cancel_request_websocket(self->request_id, reason, code); + + if (self->websockets_state == WS_DISCONNECTED) { + goto websocket_error; + } } + goto finalize_send; + + PyObject *exc_instance; +websocket_error: + exc_instance = PyObject_CallObject(websocket_closed, NULL); + PyErr_SetObject(websocket_closed, exc_instance); + Py_DECREF(exc_instance); Py_RETURN_NONE; + +finalize_send: + // WARNING: should I incref here? + Py_INCREF(self->event_ts_send); + return self->event_ts_send; } static PyMethodDef AsgiEvent_methods[] = { - {"wait", (PyCFunction)AsgiEvent_wait, METH_VARARGS, - "Wait until ASGI Event is set, calls the underlying asnycio.Event set() " - "method."}, - {"clear", (PyCFunction)AsgiEvent_clear, METH_VARARGS, - "Clear ASGI Event, calls the underlying asnycio.Event clear() method."}, {"receive_start", (PyCFunction)AsgiEvent_receive_start, METH_VARARGS, "Start reading receive data."}, {"receive_end", (PyCFunction)AsgiEvent_receive_end, METH_VARARGS, @@ -706,7 +949,7 @@ static PyTypeObject AsgiEventType = { void AsgiApp_handle_request(AsgiApp *app, uint64_t request_id, MapKeyVal *scope, MapKeyVal *headers, const char *client_host, int client_port, const char *server_host, - int server_port) { + int server_port, const char *subprotocols) { PyGILState_STATE gstate = PyGILState_Ensure(); PyObject *scope_dict = PyDict_New(); @@ -753,6 +996,21 @@ void AsgiApp_handle_request(AsgiApp *app, uint64_t request_id, MapKeyVal *scope, PyDict_SetItemString(scope_dict, "state", state); Py_DECREF(state); + if (subprotocols) { + PyObject *py_subprotocols = PyUnicode_FromString(subprotocols); + PyObject *split_list = + PyObject_CallMethod(py_subprotocols, "split", "s", ","); + if (!split_list) { + if (PyErr_Occurred()) { + PyErr_Clear(); + } + } else { + PyDict_SetItemString(scope_dict, "subprotocols", split_list); + Py_DECREF(split_list); + } + Py_DECREF(py_subprotocols); + } + AsgiEvent *asgi_event = (AsgiEvent *)PyObject_CallObject((PyObject *)&AsgiEventType, NULL); asgi_event->app = app; @@ -761,11 +1019,14 @@ void AsgiApp_handle_request(AsgiApp *app, uint64_t request_id, MapKeyVal *scope, PyObject *noargs = PyTuple_New(0); PyObject *kwargs = PyDict_New(); PyDict_SetItemString(kwargs, "loop", asyncio_Loop); - asgi_event->event_ts = PyObject_Call(asyncio_Event_ts, noargs, kwargs); + asgi_event->event_ts_send = PyObject_Call(asyncio_Event_ts, noargs, kwargs); + asgi_event->event_ts_receive = + PyObject_Call(asyncio_Event_ts, noargs, kwargs); Py_DECREF(kwargs); Py_DECREF(noargs); #else - asgi_event->event_ts = PyObject_CallNoArgs(asyncio_Event_ts); + asgi_event->event_ts_send = PyObject_CallNoArgs(asyncio_Event_ts); + asgi_event->event_ts_receive = PyObject_CallNoArgs(asyncio_Event_ts); #endif PyObject *receive = @@ -876,6 +1137,7 @@ void Py_init_and_release_gil(const char *setup_py) { build_receive = PyTuple_GetItem(asgi_setup_result, 1); build_send = PyTuple_GetItem(asgi_setup_result, 2); build_lifespan = PyTuple_GetItem(asgi_setup_result, 3); + websocket_closed = PyTuple_GetItem(asgi_setup_result, 4); PyRun_SimpleString("del caddysnake_setup_asgi"); // Setup ASGI version asgi_version = PyDict_New(); diff --git a/caddysnake.go b/caddysnake.go index 2a97218..b2158aa 100644 --- a/caddysnake.go +++ b/caddysnake.go @@ -20,12 +20,14 @@ import ( "strconv" "strings" "sync" + "time" "unsafe" "github.com/caddyserver/caddy/v2" "github.com/caddyserver/caddy/v2/caddyconfig/caddyfile" "github.com/caddyserver/caddy/v2/caddyconfig/httpcaddyfile" "github.com/caddyserver/caddy/v2/modules/caddyhttp" + "github.com/gorilla/websocket" "go.uber.org/zap" ) @@ -122,7 +124,7 @@ func (m *CaddySnake) Validate() error { // Cleanup frees resources uses by module func (m *CaddySnake) Cleanup() error { - if m.app != nil { + if m != nil && m.app != nil { m.logger.Info("cleaning up module") return m.app.Cleanup() } @@ -500,7 +502,7 @@ func NewAsgi(asgi_pattern string, venv_path string, lifespan bool) (*Asgi, error // Cleanup deallocates CGO resources used by Asgi app func (m *Asgi) Cleanup() (err error) { - if m.app != nil { + if m != nil && m.app != nil { asgi_lock.Lock() if _, ok := asgiapp_cache[m.asgi_pattern]; !ok { asgi_lock.Unlock() @@ -522,6 +524,19 @@ func (m *Asgi) Cleanup() (err error) { return } +type WebsocketState uint8 + +const ( + WS_STARTING WebsocketState = iota + 2 + WS_CONNECTED + WS_DISCONNECTED +) + +type WsMessage struct { + mt int + message []byte +} + // AsgiRequestHandler stores pointers to the request and the response writer type AsgiRequestHandler struct { event *C.AsgiEvent @@ -534,7 +549,9 @@ type AsgiRequestHandler struct { operations chan AsgiOperations - is_websocket bool + is_websocket bool + websocket_state WebsocketState + websocket_conn *websocket.Conn } // AsgiOperations stores operations that should be executed in the background @@ -578,6 +595,7 @@ func NewAsgiRequestHandler(w http.ResponseWriter, r *http.Request) *AsgiRequestH var asgi_lock sync.RWMutex = sync.RWMutex{} var asgi_request_counter uint64 = 0 var asgi_handlers map[uint64]*AsgiRequestHandler = map[uint64]*AsgiRequestHandler{} +var upgrader = websocket.Upgrader{} // use default options // HandleRequest passes request down to Python ASGI app and writes responses and headers. func (m *Asgi) HandleRequest(w http.ResponseWriter, r *http.Request) error { @@ -597,7 +615,21 @@ func (m *Asgi) HandleRequest(w http.ResponseWriter, r *http.Request) error { client_host_str := C.CString(client_host) defer C.free(unsafe.Pointer(client_host_str)) - is_websocket := r.Header.Get("connection") == "Upgrade" && r.Header.Get("upgrade") == "websocket" && r.Method == "GET" + contains_connection_upgrade := false + for _, v := range r.Header.Values("connection") { + if strings.Contains(strings.ToLower(v), "upgrade") { + contains_connection_upgrade = true + break + } + } + contains_upgrade_websockets := false + for _, v := range r.Header.Values("upgrade") { + if strings.Contains(strings.ToLower(v), "websocket") { + contains_upgrade_websockets = true + break + } + } + is_websocket := contains_connection_upgrade && contains_upgrade_websockets && r.Method == "GET" decodedPath, err := url.PathUnescape(r.URL.Path) if err != nil { @@ -688,6 +720,12 @@ func (m *Asgi) HandleRequest(w http.ResponseWriter, r *http.Request) error { asgi_lock.Unlock() }() + var subprotocols *C.char = nil + if is_websocket { + subprotocols = C.CString(r.Header.Get("sec-websocket-protocol")) + defer C.free(unsafe.Pointer(subprotocols)) + } + runtime.LockOSThread() C.AsgiApp_handle_request( m.app, @@ -698,6 +736,7 @@ func (m *Asgi) HandleRequest(w http.ResponseWriter, r *http.Request) error { C.int(client_port), server_host_str, C.int(server_port), + subprotocols, ) runtime.UnlockOSThread() @@ -719,11 +758,66 @@ func asgi_receive_start(request_id C.uint64_t, event *C.AsgiEvent) C.uint8_t { arh.event = event + if arh.is_websocket { + switch arh.websocket_state { + case WS_STARTING: + // TODO: this shouldn't happen, what do I do here? + fmt.Println("SHOULD NOT SEE THIS - PLEASE REPORT") + case WS_CONNECTED: + go func() { + mt, message, err := arh.websocket_conn.ReadMessage() + if err != nil { + closeError, isClose := err.(*websocket.CloseError) + closeCode := 1005 + if isClose { + closeCode = closeError.Code + } + body_str := C.CString(fmt.Sprintf("%d", closeCode)) + defer C.free(unsafe.Pointer(body_str)) + arh.websocket_state = WS_DISCONNECTED + arh.websocket_conn.Close() + runtime.LockOSThread() + C.AsgiEvent_disconnect_websocket(event) + C.AsgiEvent_set_websocket(event, body_str, C.uint8_t(0), C.uint8_t(0)) + runtime.UnlockOSThread() + arh.done <- fmt.Errorf("websocket closed: %d", closeCode) + return + } + body_str := C.CString(string(message)) + defer C.free(unsafe.Pointer(body_str)) + + message_type := C.uint8_t(0) + if mt == websocket.BinaryMessage { + message_type = C.uint8_t(1) + } + + runtime.LockOSThread() + C.AsgiEvent_set_websocket(event, body_str, message_type, C.uint8_t(0)) + runtime.UnlockOSThread() + }() + case WS_DISCONNECTED: + go func() { + runtime.LockOSThread() + C.AsgiEvent_disconnect_websocket(event) + C.AsgiEvent_set(event, nil, C.uint8_t(0), C.uint8_t(0)) + runtime.UnlockOSThread() + arh.done <- errors.New("websocket closed - receive start") + }() + default: + arh.websocket_state = WS_STARTING + runtime.LockOSThread() + C.AsgiEvent_connect_websocket(event) + C.AsgiEvent_set(event, nil, C.uint8_t(0), C.uint8_t(0)) + runtime.UnlockOSThread() + } + return C.uint8_t(1) + } + arh.operations <- AsgiOperations{op: func() { var body_str *C.char var more_body C.uint8_t if !arh.completed_body { - buffer := make([]byte, 4096) + buffer := make([]byte, 1<<16) _, err := arh.r.Body.Read(buffer) if err != nil && err != io.EOF { arh.done <- err @@ -741,7 +835,7 @@ func asgi_receive_start(request_id C.uint64_t, event *C.AsgiEvent) C.uint8_t { } runtime.LockOSThread() - C.AsgiEvent_set(event, body_str, more_body) + C.AsgiEvent_set(event, body_str, more_body, C.uint8_t(0)) runtime.UnlockOSThread() }} @@ -756,6 +850,51 @@ func asgi_set_headers(request_id C.uint64_t, status_code C.int, headers *C.MapKe arh.event = event + if arh.is_websocket { + ws_headers := arh.w.Header().Clone() + if headers != nil { + size_of_pointer := unsafe.Sizeof(headers.keys) + defer C.free(unsafe.Pointer(headers)) + defer C.free(unsafe.Pointer(headers.keys)) + defer C.free(unsafe.Pointer(headers.values)) + + for i := 0; i < int(headers.count); i++ { + header_name_ptr := unsafe.Pointer(uintptr(unsafe.Pointer(headers.keys)) + uintptr(i)*size_of_pointer) + header_value_ptr := unsafe.Pointer(uintptr(unsafe.Pointer(headers.values)) + uintptr(i)*size_of_pointer) + header_name := *(**C.char)(header_name_ptr) + defer C.free(unsafe.Pointer(header_name)) + header_value := *(**C.char)(header_value_ptr) + defer C.free(unsafe.Pointer(header_value)) + ws_headers.Add(C.GoString(header_name), C.GoString(header_value)) + } + } + switch arh.websocket_state { + case WS_STARTING: + ws_conn, err := upgrader.Upgrade(arh.w, arh.r, ws_headers) + if err != nil { + arh.websocket_state = WS_DISCONNECTED + arh.websocket_conn.Close() + runtime.LockOSThread() + C.AsgiEvent_disconnect_websocket(event) + C.AsgiEvent_set(event, nil, C.uint8_t(0), C.uint8_t(1)) + runtime.UnlockOSThread() + return + } + arh.websocket_state = WS_CONNECTED + arh.websocket_conn = ws_conn + + runtime.LockOSThread() + C.AsgiEvent_set(event, nil, C.uint8_t(0), C.uint8_t(1)) + runtime.UnlockOSThread() + case WS_DISCONNECTED: + runtime.LockOSThread() + C.AsgiEvent_disconnect_websocket(event) + C.AsgiEvent_set(event, nil, C.uint8_t(0), C.uint8_t(1)) + runtime.UnlockOSThread() + } + return + } + arh.operations <- AsgiOperations{op: func() { if headers != nil { size_of_pointer := unsafe.Sizeof(headers.keys) @@ -777,13 +916,13 @@ func asgi_set_headers(request_id C.uint64_t, status_code C.int, headers *C.MapKe arh.w.WriteHeader(int(status_code)) runtime.LockOSThread() - C.AsgiEvent_set(event, nil, C.uint8_t(0)) + C.AsgiEvent_set(event, nil, C.uint8_t(0), C.uint8_t(1)) runtime.UnlockOSThread() }} } //export asgi_send_response -func asgi_send_response(request_id C.uint64_t, body *C.char, more_body C.uint8_t, event *C.AsgiEvent) { +func asgi_send_response(request_id C.uint64_t, body *C.char, body_len C.size_t, more_body C.uint8_t, event *C.AsgiEvent) { asgi_lock.Lock() defer asgi_lock.Unlock() arh := asgi_handlers[uint64(request_id)] @@ -792,7 +931,7 @@ func asgi_send_response(request_id C.uint64_t, body *C.char, more_body C.uint8_t arh.operations <- AsgiOperations{op: func() { defer C.free(unsafe.Pointer(body)) - body_bytes := []byte(C.GoString(body)) + body_bytes := C.GoBytes(unsafe.Pointer(body), C.int(body_len)) arh.accumulated_response_size += len(body_bytes) _, err := arh.w.Write(body_bytes) if f, ok := arh.w.(http.Flusher); ok { @@ -805,7 +944,43 @@ func asgi_send_response(request_id C.uint64_t, body *C.char, more_body C.uint8_t } runtime.LockOSThread() - C.AsgiEvent_set(event, nil, C.uint8_t(0)) + C.AsgiEvent_set(event, nil, C.uint8_t(0), C.uint8_t(1)) + runtime.UnlockOSThread() + }} +} + +//export asgi_send_response_websocket +func asgi_send_response_websocket(request_id C.uint64_t, body *C.char, body_len C.size_t, message_type C.uint8_t, event *C.AsgiEvent) { + asgi_lock.Lock() + defer asgi_lock.Unlock() + arh := asgi_handlers[uint64(request_id)] + + arh.event = event + + arh.operations <- AsgiOperations{op: func() { + defer C.free(unsafe.Pointer(body)) + var body_bytes []byte + var ws_message_type int + if message_type == C.uint8_t(0) { + ws_message_type = websocket.TextMessage + body_bytes = []byte(C.GoString(body)) + } else { + ws_message_type = websocket.BinaryMessage + body_bytes = C.GoBytes(unsafe.Pointer(body), C.int(body_len)) + } + err := arh.websocket_conn.WriteMessage(ws_message_type, body_bytes) + if err != nil { + arh.websocket_state = WS_DISCONNECTED + arh.websocket_conn.Close() + runtime.LockOSThread() + C.AsgiEvent_disconnect_websocket(event) + C.AsgiEvent_set(event, nil, C.uint8_t(0), C.uint8_t(1)) + runtime.UnlockOSThread() + return + } + + runtime.LockOSThread() + C.AsgiEvent_set(event, nil, C.uint8_t(0), C.uint8_t(1)) runtime.UnlockOSThread() }} } @@ -819,3 +994,32 @@ func asgi_cancel_request(request_id C.uint64_t) { arh.done <- errors.New("request cancelled") } } + +//export asgi_cancel_request_websocket +func asgi_cancel_request_websocket(request_id C.uint64_t, reason *C.char, code C.int) { + asgi_lock.Lock() + defer asgi_lock.Unlock() + arh, ok := asgi_handlers[uint64(request_id)] + if ok { + var reasonText string + if reason != nil { + defer C.free(unsafe.Pointer(reason)) + reasonText = C.GoString(reason) + } + closeCode := int(code) + if arh.websocket_state == WS_STARTING { + arh.w.WriteHeader(403) + arh.done <- fmt.Errorf("websocket closed: %d '%s'", closeCode, reasonText) + } else if arh.websocket_state == WS_CONNECTED { + arh.websocket_state = WS_DISCONNECTED + closeMessage := websocket.FormatCloseMessage(closeCode, reasonText) + go func() { + if arh.websocket_conn != nil { + arh.websocket_conn.WriteControl(websocket.CloseMessage, closeMessage, time.Now().Add(5*time.Second)) + arh.websocket_conn.Close() + arh.done <- fmt.Errorf("websocket closed: %d '%s'", closeCode, reasonText) + } + }() + } + } +} diff --git a/caddysnake.h b/caddysnake.h index 86d3816..e7efda0 100644 --- a/caddysnake.h +++ b/caddysnake.h @@ -30,14 +30,20 @@ AsgiApp *AsgiApp_import(const char *, const char *, const char *); uint8_t AsgiApp_lifespan_startup(AsgiApp *); uint8_t AsgiApp_lifespan_shutdown(AsgiApp *); void AsgiApp_handle_request(AsgiApp *, uint64_t, MapKeyVal *, MapKeyVal *, - const char *, int, const char *, int); -void AsgiEvent_set(AsgiEvent *, const char *, uint8_t); + const char *, int, const char *, int, const char *); +void AsgiEvent_set(AsgiEvent *, const char *, uint8_t, uint8_t); +void AsgiEvent_set_websocket(AsgiEvent *, const char *, uint8_t, uint8_t); +void AsgiEvent_connect_websocket(AsgiEvent *); +void AsgiEvent_disconnect_websocket(AsgiEvent *); void AsgiEvent_cleanup(AsgiEvent *); void AsgiApp_cleanup(AsgiApp *); extern uint8_t asgi_receive_start(uint64_t, AsgiEvent *); -extern void asgi_send_response(uint64_t, char *, uint8_t, AsgiEvent *); +extern void asgi_send_response(uint64_t, char *, size_t, uint8_t, AsgiEvent *); +extern void asgi_send_response_websocket(uint64_t, char *, size_t, uint8_t, + AsgiEvent *); extern void asgi_set_headers(uint64_t, int, MapKeyVal *, AsgiEvent *); extern void asgi_cancel_request(uint64_t); +extern void asgi_cancel_request_websocket(uint64_t, char *, int); #endif // CADDYSNAKE_H_ diff --git a/caddysnake.py b/caddysnake.py index d3c8dd0..a449931 100644 --- a/caddysnake.py +++ b/caddysnake.py @@ -32,9 +32,10 @@ def set(self): def build_receive(asgi_event): async def receive(): - if asgi_event.receive_start(): - await asgi_event.wait() - asgi_event.clear() + ev = asgi_event.receive_start() + if ev: + await ev.wait() + ev.clear() result = asgi_event.receive_end() return result else: @@ -44,9 +45,9 @@ async def receive(): def build_send(asgi_event): async def send(data): - asgi_event.send(data) - await asgi_event.wait() - asgi_event.clear() + ev = asgi_event.send(data) + await ev.wait() + ev.clear() return send @@ -121,4 +122,13 @@ def run_lifespan(): Thread(target=loop.run_forever).start() - return Event_ts, build_receive, build_send, build_lifespan + class WebsocketClosed(IOError): + pass + + return ( + Event_ts, + build_receive, + build_send, + build_lifespan, + WebsocketClosed, + ) diff --git a/go.mod b/go.mod index e018187..6bc59f0 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.21.6 require ( github.com/caddyserver/caddy/v2 v2.7.6 + github.com/gorilla/websocket v0.0.0-20170926233335-4201258b820c go.uber.org/zap v1.26.0 ) diff --git a/go.sum b/go.sum index 7b8a39d..787cecd 100644 --- a/go.sum +++ b/go.sum @@ -212,6 +212,7 @@ github.com/gorilla/context v1.1.1/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51 github.com/gorilla/mux v1.4.0/go.mod h1:1lud6UwP+6orDFRuTfBEV8e9/aOM/c4fVVCaMa2zaAs= github.com/gorilla/mux v1.6.2/go.mod h1:1lud6UwP+6orDFRuTfBEV8e9/aOM/c4fVVCaMa2zaAs= github.com/gorilla/mux v1.7.3/go.mod h1:1lud6UwP+6orDFRuTfBEV8e9/aOM/c4fVVCaMa2zaAs= +github.com/gorilla/websocket v0.0.0-20170926233335-4201258b820c h1:Lh2aW+HnU2Nbe1gqD9SOJLJxW1jBMmQOktN2acDyJk8= github.com/gorilla/websocket v0.0.0-20170926233335-4201258b820c/go.mod h1:E7qHFY5m1UJ88s3WnNqhKjPHQ0heANvMoAMk2YaljkQ= github.com/groob/finalizer v0.0.0-20170707115354-4c2ed49aabda/go.mod h1:MyndkAZd5rUMdNogn35MWXBX1UiBigrU8eTj8DoAC2c= github.com/grpc-ecosystem/go-grpc-middleware v1.0.1-0.20190118093823-f849b5445de4/go.mod h1:FiyG127CGDf3tlThmgyCl78X/SZQqEOJBCDaAfeWzPs= diff --git a/tests/fastapi/main.py b/tests/fastapi/main.py index 7bd0eb7..cb9d12c 100644 --- a/tests/fastapi/main.py +++ b/tests/fastapi/main.py @@ -1,9 +1,10 @@ +import random import sys from typing import Optional from contextlib import asynccontextmanager -from fastapi import FastAPI -from fastapi.responses import StreamingResponse +from fastapi import FastAPI, WebSocket +from fastapi.responses import StreamingResponse, HTMLResponse from pydantic import BaseModel @@ -41,12 +42,14 @@ async def delete_item(id: str): del db[id] return "Deleted" + def chunked_blob(blob: str): chunk_size = 2**20 for i in range(0, len(blob), chunk_size): - chunk = blob[i:i+chunk_size] + chunk = blob[i : i + chunk_size] yield chunk + @app.get("/stream-item/{id}") async def item_stream(id: str) -> StreamingResponse: - return StreamingResponse(chunked_blob(db[id].blob), media_type='text/event-stream') + return StreamingResponse(chunked_blob(db[id].blob), media_type="text/event-stream") diff --git a/tests/fastapi/main_test.py b/tests/fastapi/main_test.py index 85da4de..7859d16 100644 --- a/tests/fastapi/main_test.py +++ b/tests/fastapi/main_test.py @@ -37,6 +37,7 @@ def delete_item(id: str): response = requests.delete(f"{BASE_URL}/item/{id}") return response.status_code == 200 and b"Deleted" in response.content + def stream_content(id: str, item: dict): response = requests.get(f"{BASE_URL}/stream-item/{id}", stream=True) if not response.ok: @@ -44,6 +45,7 @@ def stream_content(id: str, item: dict): blob = b"".join(response.iter_content(chunk_size=2**20)) return blob.decode() == item["blob"] + def item_lifecycle(): id = str(uuid.uuid4()) item = get_dummy_item() diff --git a/tests/socketio/Caddyfile b/tests/socketio/Caddyfile new file mode 100644 index 0000000..fb5fbad --- /dev/null +++ b/tests/socketio/Caddyfile @@ -0,0 +1,19 @@ +{ + http_port 9080 + https_port 9443 + log { + level info + } +} +localhost:9080 { + route /socket.io/* { + python { + module_asgi "main:app" + venv "./venv" + } + } + + route / { + respond 404 + } +} diff --git a/tests/socketio/main.py b/tests/socketio/main.py new file mode 100644 index 0000000..38be5b0 --- /dev/null +++ b/tests/socketio/main.py @@ -0,0 +1,28 @@ +import sys + +import socketio + +# Create a Socket.IO server +sio = socketio.AsyncServer( + async_mode='asgi', + cors_allowed_origins=['http://localhost:9080'], +) + +@sio.event +async def connect(sid, environ): + print(f"User connected: {sid}", file=sys.stderr) + +@sio.event +async def disconnect(sid): + print(f"User disconnected: {sid}", file=sys.stderr) + +@sio.event +async def start(sid, data): + return {"sid": sid, "data": data} + +@sio.event +async def ping(sid, data): + await sio.emit("pong", {"sid": sid, "data": data}) + + +app = socketio.ASGIApp(sio) diff --git a/tests/socketio/main_test.py b/tests/socketio/main_test.py new file mode 100644 index 0000000..5df308d --- /dev/null +++ b/tests/socketio/main_test.py @@ -0,0 +1,118 @@ +import os +import base64 +import socketio +import time +from concurrent.futures import ThreadPoolExecutor +import psutil + +user_count = 0 + +BASE_URL = "http://localhost:9080" + +BIG_BLOB = base64.b64encode(os.urandom(2**18)).decode("utf") + + +def get_dummy_user() -> dict: + global user_count + user_count += 1 + return { + "name": f"User {user_count}", + "description": f"User Description {user_count}", + "blob": BIG_BLOB if user_count % 4 == 0 else None, + } + +def user_lifecycle(): + sio = socketio.Client() + + connected_ok = False + disconnected_ok = False + ping_data = get_dummy_user() + pong_data = [] + + + @sio.event + def connect(): + nonlocal connected_ok + connected_ok = True + + @sio.event + def disconnect(): + nonlocal disconnected_ok + disconnected_ok = True + + @sio.event + def pong(data): + nonlocal pong_data + pong_data.append(data) + + sio.connect(BASE_URL) + start_data = sio.call("start", ping_data) + sio.emit("ping", ping_data) + + time.sleep(1 if not ping_data.get("blob") else 5) + assert connected_ok + assert start_data in pong_data + assert start_data["data"] == ping_data + + sio.disconnect() + time.sleep(1) + assert disconnected_ok + + +def make_users(max_workers: int, count: int): + start = time.time() + failed = False + + def user_done(fut): + exc = fut.exception() + if exc: + nonlocal failed + failed = True + raise SystemExit(1) from exc + + with ThreadPoolExecutor(max_workers=max_workers) as executor: + for _ in range(count): + future = executor.submit(user_lifecycle) + future.add_done_callback(user_done) + + if failed: + print("Tests failed") + exit(1) + + print(f"Created and destroyed {count} users") + print(f"Elapsed: {time.time()-start}s") + + +def find_and_terminate_process(process_name): + for proc in psutil.process_iter(["pid", "name"]): + try: + if process_name in proc.info["name"]: + pid = proc.info["pid"] + p = psutil.Process(pid) + p.terminate() + print(f"Process {process_name} with PID {pid} terminated.") + except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess): + pass + + +def check_user_events_on_logs(logs: str): + events_count = { + "User connected": 0, + "User disconnected": 0, + } + with open(logs, "r") as fd: + for line in fd: + event = line.strip() + for event_key in events_count.keys(): + if event_key in event: + events_count[event_key] += 1 + for event, count in events_count.items(): + assert ( + count == 256 + ), f"Expected '{event}' to only be seen once, but seen {count} times" + + +if __name__ == "__main__": + make_users(max_workers=8, count=256) + find_and_terminate_process("caddy") + check_user_events_on_logs("caddy.log") diff --git a/tests/socketio/requirements.txt b/tests/socketio/requirements.txt new file mode 100644 index 0000000..c40e7f8 --- /dev/null +++ b/tests/socketio/requirements.txt @@ -0,0 +1,13 @@ +bidict==0.23.1 +certifi==2024.7.4 +charset-normalizer==3.3.2 +h11==0.14.0 +idna==3.7 +psutil==6.0.0 +python-engineio==4.9.1 +python-socketio==5.11.3 +requests==2.32.3 +simple-websocket==1.0.0 +urllib3==2.2.2 +websocket-client==1.8.0 +wsproto==1.2.0