From b07eb175d8cc7e64c321f9a5673783d9149968a7 Mon Sep 17 00:00:00 2001 From: Mathieu Carbou Date: Tue, 7 Jan 2025 11:00:50 +0100 Subject: [PATCH] feat(webserver): Middleware with default middleware for cors, authc, curl-like logging (#10750) * feat(webserver): Middleware with default middleware for cors, authc, curl-like logging * ci(pre-commit): Apply automatic fixes --------- Co-authored-by: Rodrigo Garcia Co-authored-by: pre-commit-ci-lite[bot] <117423508+pre-commit-ci-lite[bot]@users.noreply.github.com> --- CMakeLists.txt | 6 +- .../examples/Middleware/Middleware.ino | 186 ++++++++++++++ .../WebServer/examples/Middleware/ci.json | 5 + libraries/WebServer/src/Middlewares.h | 66 +++++ libraries/WebServer/src/Parsing.cpp | 39 ++- libraries/WebServer/src/WebServer.cpp | 242 +++++++++++++----- libraries/WebServer/src/WebServer.h | 79 +++--- .../WebServer/src/detail/RequestHandler.h | 10 +- .../src/detail/RequestHandlersImpl.h | 33 +++ .../middleware/AuthenticationMiddleware.cpp | 82 ++++++ .../src/middleware/CorsMiddleware.cpp | 47 ++++ .../src/middleware/LoggingMiddleware.cpp | 74 ++++++ .../WebServer/src/middleware/Middleware.h | 54 ++++ .../src/middleware/MiddlewareChain.cpp | 73 ++++++ 14 files changed, 895 insertions(+), 101 deletions(-) create mode 100644 libraries/WebServer/examples/Middleware/Middleware.ino create mode 100644 libraries/WebServer/examples/Middleware/ci.json create mode 100644 libraries/WebServer/src/Middlewares.h create mode 100644 libraries/WebServer/src/middleware/AuthenticationMiddleware.cpp create mode 100644 libraries/WebServer/src/middleware/CorsMiddleware.cpp create mode 100644 libraries/WebServer/src/middleware/LoggingMiddleware.cpp create mode 100644 libraries/WebServer/src/middleware/Middleware.h create mode 100644 libraries/WebServer/src/middleware/MiddlewareChain.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index dd15e06dac8..3718b965386 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -242,7 +242,11 @@ set(ARDUINO_LIBRARY_USB_SRCS set(ARDUINO_LIBRARY_WebServer_SRCS libraries/WebServer/src/WebServer.cpp libraries/WebServer/src/Parsing.cpp - libraries/WebServer/src/detail/mimetable.cpp) + libraries/WebServer/src/detail/mimetable.cpp + libraries/WebServer/src/middleware/MiddlewareChain.cpp + libraries/WebServer/src/middleware/AuthenticationMiddleware.cpp + libraries/WebServer/src/middleware/CorsMiddleware.cpp + libraries/WebServer/src/middleware/LoggingMiddleware.cpp) set(ARDUINO_LIBRARY_NetworkClientSecure_SRCS libraries/NetworkClientSecure/src/ssl_client.cpp diff --git a/libraries/WebServer/examples/Middleware/Middleware.ino b/libraries/WebServer/examples/Middleware/Middleware.ino new file mode 100644 index 00000000000..9d957341c2d --- /dev/null +++ b/libraries/WebServer/examples/Middleware/Middleware.ino @@ -0,0 +1,186 @@ +/** + * Basic example of using Middlewares with WebServer + * + * Middleware are common request/response processing functions that can be applied globally to all incoming requests or to specific handlers. + * They allow for a common processing thus saving memory and space to avoid duplicating code or states on multiple handlers. + * + * Once the example is flashed (with the correct WiFi credentials), you can test the following scenarios with the listed curl commands: + * - CORS Middleware: answers to OPTIONS requests with the specified CORS headers and also add CORS headers to the response when the request has the Origin header + * - Logging Middleware: logs the request and response to an output in a curl-like format + * - Authentication Middleware: test the authentication with Digest Auth + * + * You can also add your own Middleware by extending the Middleware class and implementing the run method. + * When implementing a Middleware, you can decide when to call the next Middleware in the chain by calling next(). + * + * Middleware are execute in order of addition, the ones attached to the server will be executed first. + */ +#include +#include +#include + +// Your AP WiFi Credentials +// ( This is the AP your ESP will broadcast ) +const char *ap_ssid = "ESP32_Demo"; +const char *ap_password = ""; + +WebServer server(80); + +LoggingMiddleware logger; +CorsMiddleware cors; +AuthenticationMiddleware auth; + +void setup(void) { + Serial.begin(115200); + WiFi.softAP(ap_ssid, ap_password); + + Serial.print("IP address: "); + Serial.println(WiFi.AP.localIP()); + + // curl-like output example: + // + // > curl -v -X OPTIONS -H "origin: http://192.168.4.1" http://192.168.4.1/ + // + // Connection from 192.168.4.2:51683 + // > OPTIONS / HTTP/1.1 + // > Host: 192.168.4.1 + // > User-Agent: curl/8.10.0 + // > Accept: */* + // > origin: http://192.168.4.1 + // > + // * Processed in 5 ms + // < HTTP/1.HTTP/1.1 200 OK + // < Content-Type: text/html + // < Access-Control-Allow-Origin: http://192.168.4.1 + // < Access-Control-Allow-Methods: POST,GET,OPTIONS,DELETE + // < Access-Control-Allow-Headers: X-Custom-Header + // < Access-Control-Allow-Credentials: false + // < Access-Control-Max-Age: 600 + // < Content-Length: 0 + // < Connection: close + // < + logger.setOutput(Serial); + + cors.setOrigin("http://192.168.4.1"); + cors.setMethods("POST,GET,OPTIONS,DELETE"); + cors.setHeaders("X-Custom-Header"); + cors.setAllowCredentials(false); + cors.setMaxAge(600); + + auth.setUsername("admin"); + auth.setPassword("admin"); + auth.setRealm("My Super App"); + auth.setAuthMethod(DIGEST_AUTH); + auth.setAuthFailureMessage("Authentication Failed"); + + server.addMiddleware(&logger); + server.addMiddleware(&cors); + + // Not authenticated + // + // Test CORS preflight request with: + // > curl -v -X OPTIONS -H "origin: http://192.168.4.1" http://192.168.4.1/ + // + // Test cross-domain request with: + // > curl -v -X GET -H "origin: http://192.168.4.1" http://192.168.4.1/ + // + server.on("/", []() { + server.send(200, "text/plain", "Home"); + }); + + // Authenticated + // + // > curl -v -X GET -H "origin: http://192.168.4.1" http://192.168.4.1/protected + // + // Outputs: + // + // * Connection from 192.168.4.2:51750 + // > GET /protected HTTP/1.1 + // > Host: 192.168.4.1 + // > User-Agent: curl/8.10.0 + // > Accept: */* + // > origin: http://192.168.4.1 + // > + // * Processed in 7 ms + // < HTTP/1.HTTP/1.1 401 Unauthorized + // < Content-Type: text/html + // < Access-Control-Allow-Origin: http://192.168.4.1 + // < Access-Control-Allow-Methods: POST,GET,OPTIONS,DELETE + // < Access-Control-Allow-Headers: X-Custom-Header + // < Access-Control-Allow-Credentials: false + // < Access-Control-Max-Age: 600 + // < WWW-Authenticate: Digest realm="My Super App", qop="auth", nonce="ac388a64184e3e102aae6fff1c9e8d76", opaque="e7d158f2b54d25328142d118ff0f932d" + // < Content-Length: 21 + // < Connection: close + // < + // + // > curl -v -X GET -H "origin: http://192.168.4.1" --digest -u admin:admin http://192.168.4.1/protected + // + // Outputs: + // + // * Connection from 192.168.4.2:53662 + // > GET /protected HTTP/1.1 + // > Authorization: Digest username="admin", realm="My Super App", nonce="db9e6824eb2a13bc7b2bf8f3c43db896", uri="/protected", cnonce="NTliZDZiNTcwODM2MzAyY2JjMDBmZGJmNzFiY2ZmNzk=", nc=00000001, qop=auth, response="6ebd145ba0d3496a4a73f5ae79ff5264", opaque="23d739c22810282ff820538cba98bda4" + // > Host: 192.168.4.1 + // > User-Agent: curl/8.10.0 + // > Accept: */* + // > origin: http://192.168.4.1 + // > + // Request handling... + // * Processed in 7 ms + // < HTTP/1.HTTP/1.1 200 OK + // < Content-Type: text/plain + // < Access-Control-Allow-Origin: http://192.168.4.1 + // < Access-Control-Allow-Methods: POST,GET,OPTIONS,DELETE + // < Access-Control-Allow-Headers: X-Custom-Header + // < Access-Control-Allow-Credentials: false + // < Access-Control-Max-Age: 600 + // < Content-Length: 9 + // < Connection: close + // < + server + .on( + "/protected", + []() { + Serial.println("Request handling..."); + server.send(200, "text/plain", "Protected"); + } + ) + .addMiddleware(&auth); + + // Not found is also handled by global middleware + // + // curl -v -X GET -H "origin: http://192.168.4.1" http://192.168.4.1/inexsting + // + // Outputs: + // + // * Connection from 192.168.4.2:53683 + // > GET /inexsting HTTP/1.1 + // > Host: 192.168.4.1 + // > User-Agent: curl/8.10.0 + // > Accept: */* + // > origin: http://192.168.4.1 + // > + // * Processed in 16 ms + // < HTTP/1.HTTP/1.1 404 Not Found + // < Content-Type: text/plain + // < Access-Control-Allow-Origin: http://192.168.4.1 + // < Access-Control-Allow-Methods: POST,GET,OPTIONS,DELETE + // < Access-Control-Allow-Headers: X-Custom-Header + // < Access-Control-Allow-Credentials: false + // < Access-Control-Max-Age: 600 + // < Content-Length: 14 + // < Connection: close + // < + server.onNotFound([]() { + server.send(404, "text/plain", "Page not found"); + }); + + server.collectAllHeaders(); + server.begin(); + Serial.println("HTTP server started"); +} + +void loop(void) { + server.handleClient(); + delay(2); //allow the cpu to switch to other tasks +} diff --git a/libraries/WebServer/examples/Middleware/ci.json b/libraries/WebServer/examples/Middleware/ci.json new file mode 100644 index 00000000000..36babb82730 --- /dev/null +++ b/libraries/WebServer/examples/Middleware/ci.json @@ -0,0 +1,5 @@ +{ + "requires": [ + "CONFIG_SOC_WIFI_SUPPORTED=y" + ] +} diff --git a/libraries/WebServer/src/Middlewares.h b/libraries/WebServer/src/Middlewares.h new file mode 100644 index 00000000000..04fab52790b --- /dev/null +++ b/libraries/WebServer/src/Middlewares.h @@ -0,0 +1,66 @@ +#ifndef MIDDLEWARES_H +#define MIDDLEWARES_H + +#include +#include + +#include + +// curl-like logging middleware +class LoggingMiddleware : public Middleware { +public: + void setOutput(Print &output); + + bool run(WebServer &server, Middleware::Callback next) override; + +private: + Print *_out = nullptr; +}; + +class CorsMiddleware : public Middleware { +public: + CorsMiddleware &setOrigin(const char *origin); + CorsMiddleware &setMethods(const char *methods); + CorsMiddleware &setHeaders(const char *headers); + CorsMiddleware &setAllowCredentials(bool credentials); + CorsMiddleware &setMaxAge(uint32_t seconds); + + void addCORSHeaders(WebServer &server); + + bool run(WebServer &server, Middleware::Callback next) override; + +private: + String _origin = F("*"); + String _methods = F("*"); + String _headers = F("*"); + bool _credentials = true; + uint32_t _maxAge = 86400; +}; + +class AuthenticationMiddleware : public Middleware { +public: + AuthenticationMiddleware &setUsername(const char *username); + AuthenticationMiddleware &setPassword(const char *password); + AuthenticationMiddleware &setPasswordHash(const char *sha1AsBase64orHex); + AuthenticationMiddleware &setCallback(WebServer::THandlerFunctionAuthCheck fn); + + AuthenticationMiddleware &setRealm(const char *realm); + AuthenticationMiddleware &setAuthMethod(HTTPAuthMethod method); + AuthenticationMiddleware &setAuthFailureMessage(const char *message); + + bool isAllowed(WebServer &server) const; + + bool run(WebServer &server, Middleware::Callback next) override; + +private: + String _username; + String _password; + bool _hash = false; + WebServer::THandlerFunctionAuthCheck _callback; + + const char *_realm = nullptr; + HTTPAuthMethod _method = BASIC_AUTH; + String _authFailMsg; +}; + +#endif diff --git a/libraries/WebServer/src/Parsing.cpp b/libraries/WebServer/src/Parsing.cpp index 040338bb749..3030317eeea 100644 --- a/libraries/WebServer/src/Parsing.cpp +++ b/libraries/WebServer/src/Parsing.cpp @@ -78,8 +78,14 @@ bool WebServer::_parseRequest(NetworkClient &client) { String req = client.readStringUntil('\r'); client.readStringUntil('\n'); //reset header value - for (int i = 0; i < _headerKeysCount; ++i) { - _currentHeaders[i].value = String(); + if (_collectAllHeaders) { + // clear previous headers + collectAllHeaders(); + } else { + // clear previous headers + for (RequestArgument *header = _currentHeaders; header; header = header->next) { + header->value = String(); + } } // First line of HTTP request looks like "GET /path HTTP/1.1" @@ -154,9 +160,6 @@ bool WebServer::_parseRequest(NetworkClient &client) { headerValue.trim(); _collectHeader(headerName.c_str(), headerValue.c_str()); - log_v("headerName: %s", headerName.c_str()); - log_v("headerValue: %s", headerValue.c_str()); - if (headerName.equalsIgnoreCase(FPSTR(Content_Type))) { using namespace mime; if (headerValue.startsWith(FPSTR(mimeTable[txt].mimeType))) { @@ -254,9 +257,6 @@ bool WebServer::_parseRequest(NetworkClient &client) { headerValue = req.substring(headerDiv + 2); _collectHeader(headerName.c_str(), headerValue.c_str()); - log_v("headerName: %s", headerName.c_str()); - log_v("headerValue: %s", headerValue.c_str()); - if (headerName.equalsIgnoreCase("Host")) { _hostHeader = headerValue; } @@ -272,12 +272,29 @@ bool WebServer::_parseRequest(NetworkClient &client) { } bool WebServer::_collectHeader(const char *headerName, const char *headerValue) { - for (int i = 0; i < _headerKeysCount; i++) { - if (_currentHeaders[i].key.equalsIgnoreCase(headerName)) { - _currentHeaders[i].value = headerValue; + RequestArgument *last = nullptr; + for (RequestArgument *header = _currentHeaders; header; header = header->next) { + if (header->next == nullptr) { + last = header; + } + if (header->key.equalsIgnoreCase(headerName)) { + header->value = headerValue; + log_v("header collected: %s: %s", headerName, headerValue); return true; } } + assert(last); + if (_collectAllHeaders) { + last->next = new RequestArgument(); + last->next->key = headerName; + last->next->value = headerValue; + _headerKeysCount++; + log_v("header collected: %s: %s", headerName, headerValue); + return true; + } + + log_v("header skipped: %s: %s", headerName, headerValue); + return false; } diff --git a/libraries/WebServer/src/WebServer.cpp b/libraries/WebServer/src/WebServer.cpp index 53c575d2c56..652a86f587f 100644 --- a/libraries/WebServer/src/WebServer.cpp +++ b/libraries/WebServer/src/WebServer.cpp @@ -41,31 +41,28 @@ static const char WWW_Authenticate[] = "WWW-Authenticate"; static const char Content_Length[] = "Content-Length"; static const char ETAG_HEADER[] = "If-None-Match"; -WebServer::WebServer(IPAddress addr, int port) - : _corsEnabled(false), _server(addr, port), _currentMethod(HTTP_ANY), _currentVersion(0), _currentStatus(HC_NONE), _statusChange(0), _nullDelay(true), - _currentHandler(nullptr), _firstHandler(nullptr), _lastHandler(nullptr), _currentArgCount(0), _currentArgs(nullptr), _postArgsLen(0), _postArgs(nullptr), - _headerKeysCount(0), _currentHeaders(nullptr), _contentLength(0), _clientContentLength(0), _chunked(false) { +WebServer::WebServer(IPAddress addr, int port) : _server(addr, port) { log_v("WebServer::Webserver(addr=%s, port=%d)", addr.toString().c_str(), port); } -WebServer::WebServer(int port) - : _corsEnabled(false), _server(port), _currentMethod(HTTP_ANY), _currentVersion(0), _currentStatus(HC_NONE), _statusChange(0), _nullDelay(true), - _currentHandler(nullptr), _firstHandler(nullptr), _lastHandler(nullptr), _currentArgCount(0), _currentArgs(nullptr), _postArgsLen(0), _postArgs(nullptr), - _headerKeysCount(0), _currentHeaders(nullptr), _contentLength(0), _clientContentLength(0), _chunked(false) { +WebServer::WebServer(int port) : _server(port) { log_v("WebServer::Webserver(port=%d)", port); } WebServer::~WebServer() { _server.close(); - if (_currentHeaders) { - delete[] _currentHeaders; - } + + _clearRequestHeaders(); + _clearResponseHeaders(); + delete _chain; + RequestHandler *handler = _firstHandler; while (handler) { RequestHandler *next = handler->next(); delete handler; handler = next; } + _firstHandler = nullptr; } void WebServer::begin() { @@ -436,7 +433,17 @@ void WebServer::handleClient() { _currentClient.setTimeout(HTTP_MAX_SEND_WAIT); /* / 1000 removed, WifiClient setTimeout changed to ms */ if (_parseRequest(_currentClient)) { _contentLength = CONTENT_LENGTH_NOT_SET; - _handleRequest(); + _responseCode = 0; + _clearResponseHeaders(); + + // Run server-level middlewares + if (_chain) { + _chain->runChain(*this, [this]() { + return _handleRequest(); + }); + } else { + _handleRequest(); + } if (_currentClient.isSSE()) { _currentStatus = HC_WAIT_CLOSE; @@ -495,16 +502,22 @@ void WebServer::stop() { } void WebServer::sendHeader(const String &name, const String &value, bool first) { - String headerLine = name; - headerLine += F(": "); - headerLine += value; - headerLine += "\r\n"; + RequestArgument *header = new RequestArgument(); + header->key = name; + header->value = value; - if (first) { - _responseHeaders = headerLine + _responseHeaders; + if (!_responseHeaders || first) { + header->next = _responseHeaders; + _responseHeaders = header; } else { - _responseHeaders += headerLine; + RequestArgument *last = _responseHeaders; + while (last->next) { + last = last->next; + } + last->next = header; } + + _responseHeaderCount++; } void WebServer::setContentLength(const size_t contentLength) { @@ -529,11 +542,14 @@ void WebServer::enableETag(bool enable, ETagFunction fn) { } void WebServer::_prepareHeader(String &response, int code, const char *content_type, size_t contentLength) { - response = String(F("HTTP/1.")) + String(_currentVersion) + ' '; - response += String(code); - response += ' '; - response += _responseCodeToString(code); - response += "\r\n"; + _responseCode = code; + + response.concat(version()); + response.concat(' '); + response.concat(String(code)); + response.concat(' '); + response.concat(responseCodeToString(code)); + response.concat(F("\r\n")); using namespace mime; if (!content_type) { @@ -558,9 +574,14 @@ void WebServer::_prepareHeader(String &response, int code, const char *content_t } sendHeader(String(F("Connection")), String(F("close"))); - response += _responseHeaders; - response += "\r\n"; - _responseHeaders = ""; + for (RequestArgument *header = _responseHeaders; header; header = header->next) { + response.concat(header->key); + response.concat(F(": ")); + response.concat(header->value); + response.concat(F("\r\n")); + } + + response.concat(F("\r\n")); } void WebServer::send(int code, const char *content_type, const String &content) { @@ -568,9 +589,6 @@ void WebServer::send(int code, const char *content_type, const String &content) // Can we assume the following? //if(code == 200 && content.length() == 0 && _contentLength == CONTENT_LENGTH_NOT_SET) // _contentLength = CONTENT_LENGTH_UNKNOWN; - if (content.length() == 0) { - log_w("content length is zero"); - } _prepareHeader(header, code, content_type, content.length()); _currentClientWrite(header.c_str(), header.length()); if (content.length()) { @@ -728,39 +746,43 @@ bool WebServer::hasArg(const String &name) const { } String WebServer::header(const String &name) const { - for (int i = 0; i < _headerKeysCount; ++i) { - if (_currentHeaders[i].key.equalsIgnoreCase(name)) { - return _currentHeaders[i].value; + for (RequestArgument *current = _currentHeaders; current; current = current->next) { + if (current->key.equalsIgnoreCase(name)) { + return current->value; } } return ""; } void WebServer::collectHeaders(const char *headerKeys[], const size_t headerKeysCount) { - _headerKeysCount = headerKeysCount + 2; - if (_currentHeaders) { - delete[] _currentHeaders; - } - _currentHeaders = new RequestArgument[_headerKeysCount]; - _currentHeaders[0].key = FPSTR(AUTHORIZATION_HEADER); - _currentHeaders[1].key = FPSTR(ETAG_HEADER); + collectAllHeaders(); + _collectAllHeaders = false; + + _headerKeysCount += headerKeysCount; + + RequestArgument *last = _currentHeaders->next; + for (int i = 2; i < _headerKeysCount; i++) { - _currentHeaders[i].key = headerKeys[i - 2]; + last->next = new RequestArgument(); + last->next->key = headerKeys[i - 2]; + last = last->next; } } String WebServer::header(int i) const { - if (i < _headerKeysCount) { - return _currentHeaders[i].value; + RequestArgument *current = _currentHeaders; + while (current && i--) { + current = current->next; } - return ""; + return current ? current->value : emptyString; } String WebServer::headerName(int i) const { - if (i < _headerKeysCount) { - return _currentHeaders[i].key; + RequestArgument *current = _currentHeaders; + while (current && i--) { + current = current->next; } - return ""; + return current ? current->key : emptyString; } int WebServer::headers() const { @@ -768,12 +790,7 @@ int WebServer::headers() const { } bool WebServer::hasHeader(const String &name) const { - for (int i = 0; i < _headerKeysCount; ++i) { - if ((_currentHeaders[i].key.equalsIgnoreCase(name)) && (_currentHeaders[i].value.length() > 0)) { - return true; - } - } - return false; + return header(name).length() > 0; } String WebServer::hostHeader() const { @@ -788,16 +805,17 @@ void WebServer::onNotFound(THandlerFunction fn) { _notFoundHandler = fn; } -void WebServer::_handleRequest() { +bool WebServer::_handleRequest() { bool handled = false; - if (!_currentHandler) { - log_e("request handler not found"); - } else { - handled = _currentHandler->handle(*this, _currentMethod, _currentUri); + if (_currentHandler) { + handled = _currentHandler->process(*this, _currentMethod, _currentUri); if (!handled) { log_e("request handler failed to handle request"); } } + // DO NOT LOG if _currentHandler == null !! + // This is is valid use case to handle any other requests + // Also, this is just causing log flooding if (!handled && _notFoundHandler) { _notFoundHandler(); handled = true; @@ -811,6 +829,7 @@ void WebServer::_handleRequest() { _finalizeResponse(); } _currentUri = ""; + return handled; } void WebServer::_finalizeResponse() { @@ -819,7 +838,7 @@ void WebServer::_finalizeResponse() { } } -String WebServer::_responseCodeToString(int code) { +String WebServer::responseCodeToString(int code) { switch (code) { case 100: return F("Continue"); case 101: return F("Switching Protocols"); @@ -864,3 +883,108 @@ String WebServer::_responseCodeToString(int code) { default: return F(""); } } + +void WebServer::_clearResponseHeaders() { + _responseHeaderCount = 0; + RequestArgument *current = _responseHeaders; + while (current) { + RequestArgument *next = current->next; + delete current; + current = next; + } + _responseHeaders = nullptr; +} + +void WebServer::_clearRequestHeaders() { + _headerKeysCount = 0; + RequestArgument *current = _currentHeaders; + while (current) { + RequestArgument *next = current->next; + delete current; + current = next; + } + _currentHeaders = nullptr; +} + +void WebServer::collectAllHeaders() { + _clearRequestHeaders(); + + _currentHeaders = new RequestArgument(); + _currentHeaders->key = FPSTR(AUTHORIZATION_HEADER); + + _currentHeaders->next = new RequestArgument(); + _currentHeaders->next->key = FPSTR(ETAG_HEADER); + + _headerKeysCount = 2; + _collectAllHeaders = true; +} + +const String &WebServer::responseHeader(String name) const { + for (RequestArgument *current = _responseHeaders; current; current = current->next) { + if (current->key.equalsIgnoreCase(name)) { + return current->value; + } + } + return emptyString; +} + +const String &WebServer::responseHeader(int i) const { + RequestArgument *current = _responseHeaders; + while (current && i--) { + current = current->next; + } + return current ? current->value : emptyString; +} + +const String &WebServer::responseHeaderName(int i) const { + RequestArgument *current = _responseHeaders; + while (current && i--) { + current = current->next; + } + return current ? current->key : emptyString; +} + +bool WebServer::hasResponseHeader(const String &name) const { + return header(name).length() > 0; +} + +int WebServer::clientContentLength() const { + return _clientContentLength; +} + +const String WebServer::version() const { + String v; + v.reserve(8); + v.concat(F("HTTP/1.")); + v.concat(_currentVersion); + return v; +} +int WebServer::responseCode() const { + return _responseCode; +} +int WebServer::responseHeaders() const { + return _responseHeaderCount; +} + +WebServer &WebServer::addMiddleware(Middleware *middleware) { + if (!_chain) { + _chain = new MiddlewareChain(); + } + _chain->addMiddleware(middleware); + return *this; +} + +WebServer &WebServer::addMiddleware(Middleware::Function fn) { + if (!_chain) { + _chain = new MiddlewareChain(); + } + _chain->addMiddleware(fn); + return *this; +} + +WebServer &WebServer::removeMiddleware(Middleware *middleware) { + if (_chain) { + _chain->removeMiddleware(middleware); + } + return *this; +} diff --git a/libraries/WebServer/src/WebServer.h b/libraries/WebServer/src/WebServer.h index 0f3405430a7..8daf12c5c30 100644 --- a/libraries/WebServer/src/WebServer.h +++ b/libraries/WebServer/src/WebServer.h @@ -92,6 +92,7 @@ typedef struct { void *data; // additional data } HTTPRaw; +#include "middleware/Middleware.h" #include "detail/RequestHandler.h" namespace fs { @@ -158,6 +159,10 @@ class WebServer { void onNotFound(THandlerFunction fn); //called when handler is not assigned void onFileUpload(THandlerFunction ufn); //handle file uploads + WebServer &addMiddleware(Middleware *middleware); + WebServer &addMiddleware(Middleware::Function fn); + WebServer &removeMiddleware(Middleware *middleware); + String uri() const { return _currentUri; } @@ -181,17 +186,23 @@ class WebServer { int args() const; // get arguments count bool hasArg(const String &name) const; // check if argument exists void collectHeaders(const char *headerKeys[], const size_t headerKeysCount); // set the request headers to collect + void collectAllHeaders(); // collect all request headers String header(const String &name) const; // get request header value by name String header(int i) const; // get request header value by number String headerName(int i) const; // get request header name by number int headers() const; // get header count bool hasHeader(const String &name) const; // check if header exists - int clientContentLength() const { - return _clientContentLength; - } // return "content-length" of incoming HTTP header from "_currentClient" + int clientContentLength() const; // return "content-length" of incoming HTTP header from "_currentClient" + const String version() const; // get the HTTP version string + String hostHeader() const; // get request host header if available or empty String if not - String hostHeader() const; // get request host header if available or empty String if not + int responseCode() const; // get the HTTP response code set + int responseHeaders() const; // get the HTTP response headers count + const String &responseHeader(String name) const; // get the HTTP response header value by name + const String &responseHeader(int i) const; // get the HTTP response header value by number + const String &responseHeaderName(int i) const; // get the HTTP response header name by number + bool hasResponseHeader(const String &name) const; // check if response header exists // send response to the client // code - HTTP response code, can be 200 or 404 @@ -228,6 +239,8 @@ class WebServer { bool _eTagEnabled = false; ETagFunction _eTagFunction = nullptr; + static String responseCodeToString(int code); + protected: virtual size_t _currentClientWrite(const char *b, size_t l) { return _currentClient.write(b, l); @@ -237,11 +250,10 @@ class WebServer { } void _addRequestHandler(RequestHandler *handler); bool _removeRequestHandler(RequestHandler *handler); - void _handleRequest(); + bool _handleRequest(); void _finalizeResponse(); bool _parseRequest(NetworkClient &client); void _parseArguments(const String &data); - static String _responseCodeToString(int code); bool _parseForm(NetworkClient &client, const String &boundary, uint32_t len); bool _parseFormUploadAborted(); void _uploadWriteByte(uint8_t b); @@ -255,48 +267,57 @@ class WebServer { // for extracting Auth parameters String _extractParam(String &authReq, const String ¶m, const char delimit = '"'); + void _clearResponseHeaders(); + void _clearRequestHeaders(); + struct RequestArgument { String key; String value; + RequestArgument *next; }; - boolean _corsEnabled; + boolean _corsEnabled = false; NetworkServer _server; NetworkClient _currentClient; - HTTPMethod _currentMethod; + HTTPMethod _currentMethod = HTTP_ANY; String _currentUri; - uint8_t _currentVersion; - HTTPClientStatus _currentStatus; - unsigned long _statusChange; - boolean _nullDelay; - - RequestHandler *_currentHandler; - RequestHandler *_firstHandler; - RequestHandler *_lastHandler; - THandlerFunction _notFoundHandler; - THandlerFunction _fileUploadHandler; - - int _currentArgCount; - RequestArgument *_currentArgs; - int _postArgsLen; - RequestArgument *_postArgs; + uint8_t _currentVersion = 0; + HTTPClientStatus _currentStatus = HC_NONE; + unsigned long _statusChange = 0; + boolean _nullDelay = true; + + RequestHandler *_currentHandler = nullptr; + RequestHandler *_firstHandler = nullptr; + RequestHandler *_lastHandler = nullptr; + THandlerFunction _notFoundHandler = nullptr; + THandlerFunction _fileUploadHandler = nullptr; + + int _currentArgCount = 0; + RequestArgument *_currentArgs = nullptr; + int _postArgsLen = 0; + RequestArgument *_postArgs = nullptr; std::unique_ptr _currentUpload; std::unique_ptr _currentRaw; - int _headerKeysCount; - RequestArgument *_currentHeaders; - size_t _contentLength; - int _clientContentLength; // "Content-Length" from header of incoming POST or GET request - String _responseHeaders; + int _headerKeysCount = 0; + RequestArgument *_currentHeaders = nullptr; + size_t _contentLength = 0; + int _clientContentLength = 0; // "Content-Length" from header of incoming POST or GET request + RequestArgument *_responseHeaders = nullptr; String _hostHeader; - bool _chunked; + bool _chunked = false; String _snonce; // Store noance and opaque for future comparison String _sopaque; String _srealm; // Store the Auth realm between Calls + + int _responseHeaderCount = 0; + int _responseCode = 0; + bool _collectAllHeaders = false; + MiddlewareChain *_chain = nullptr; }; #endif //ESP8266WEBSERVER_H diff --git a/libraries/WebServer/src/detail/RequestHandler.h b/libraries/WebServer/src/detail/RequestHandler.h index c730ce25bcb..75e11c94ba8 100644 --- a/libraries/WebServer/src/detail/RequestHandler.h +++ b/libraries/WebServer/src/detail/RequestHandler.h @@ -6,7 +6,9 @@ class RequestHandler { public: - virtual ~RequestHandler() {} + virtual ~RequestHandler() { + delete _chain; + } /* note: old handler API for backward compatibility @@ -75,8 +77,14 @@ class RequestHandler { _next = r; } + RequestHandler &addMiddleware(Middleware *middleware); + RequestHandler &addMiddleware(Middleware::Function fn); + RequestHandler &removeMiddleware(Middleware *middleware); + bool process(WebServer &server, HTTPMethod requestMethod, String requestUri); + private: RequestHandler *_next = nullptr; + MiddlewareChain *_chain = nullptr; protected: std::vector pathArgs; diff --git a/libraries/WebServer/src/detail/RequestHandlersImpl.h b/libraries/WebServer/src/detail/RequestHandlersImpl.h index c66c294dd33..3750b594ab2 100644 --- a/libraries/WebServer/src/detail/RequestHandlersImpl.h +++ b/libraries/WebServer/src/detail/RequestHandlersImpl.h @@ -10,6 +10,39 @@ using namespace mime; +RequestHandler &RequestHandler::addMiddleware(Middleware *middleware) { + if (!_chain) { + _chain = new MiddlewareChain(); + } + _chain->addMiddleware(middleware); + return *this; +} + +RequestHandler &RequestHandler::addMiddleware(Middleware::Function fn) { + if (!_chain) { + _chain = new MiddlewareChain(); + } + _chain->addMiddleware(fn); + return *this; +} + +RequestHandler &RequestHandler::removeMiddleware(Middleware *middleware) { + if (_chain) { + _chain->removeMiddleware(middleware); + } + return *this; +} + +bool RequestHandler::process(WebServer &server, HTTPMethod requestMethod, String requestUri) { + if (_chain) { + return _chain->runChain(server, [this, &server, &requestMethod, &requestUri]() { + return handle(server, requestMethod, requestUri); + }); + } else { + return handle(server, requestMethod, requestUri); + } +} + class FunctionRequestHandler : public RequestHandler { public: FunctionRequestHandler(WebServer::THandlerFunction fn, WebServer::THandlerFunction ufn, const Uri &uri, HTTPMethod method) diff --git a/libraries/WebServer/src/middleware/AuthenticationMiddleware.cpp b/libraries/WebServer/src/middleware/AuthenticationMiddleware.cpp new file mode 100644 index 00000000000..cab25ba4e50 --- /dev/null +++ b/libraries/WebServer/src/middleware/AuthenticationMiddleware.cpp @@ -0,0 +1,82 @@ +#include "Middlewares.h" + +AuthenticationMiddleware &AuthenticationMiddleware::setUsername(const char *username) { + _username = username; + _callback = nullptr; + return *this; +} + +AuthenticationMiddleware &AuthenticationMiddleware::setPassword(const char *password) { + _password = password; + _hash = false; + _callback = nullptr; + return *this; +} + +AuthenticationMiddleware &AuthenticationMiddleware::setPasswordHash(const char *sha1AsBase64orHex) { + _password = sha1AsBase64orHex; + _hash = true; + _callback = nullptr; + return *this; +} + +AuthenticationMiddleware &AuthenticationMiddleware::setCallback(WebServer::THandlerFunctionAuthCheck fn) { + assert(fn); + _callback = fn; + _hash = false; + _username = emptyString; + _password = emptyString; + return *this; +} + +AuthenticationMiddleware &AuthenticationMiddleware::setRealm(const char *realm) { + _realm = realm; + return *this; +} + +AuthenticationMiddleware &AuthenticationMiddleware::setAuthMethod(HTTPAuthMethod method) { + _method = method; + return *this; +} + +AuthenticationMiddleware &AuthenticationMiddleware::setAuthFailureMessage(const char *message) { + _authFailMsg = message; + return *this; +} + +bool AuthenticationMiddleware::isAllowed(WebServer &server) const { + if (_callback) { + return server.authenticate(_callback); + } + + if (!_username.isEmpty() && !_password.isEmpty()) { + if (_hash) { + return server.authenticateBasicSHA1(_username.c_str(), _password.c_str()); + } else { + return server.authenticate(_username.c_str(), _password.c_str()); + } + } + + return true; +} + +bool AuthenticationMiddleware::run(WebServer &server, Middleware::Callback next) { + bool authenticationRequired = false; + + if (_callback) { + authenticationRequired = !server.authenticate(_callback); + } else if (!_username.isEmpty() && !_password.isEmpty()) { + if (_hash) { + authenticationRequired = !server.authenticateBasicSHA1(_username.c_str(), _password.c_str()); + } else { + authenticationRequired = !server.authenticate(_username.c_str(), _password.c_str()); + } + } + + if (authenticationRequired) { + server.requestAuthentication(_method, _realm, _authFailMsg); + return true; + } else { + return next(); + } +} diff --git a/libraries/WebServer/src/middleware/CorsMiddleware.cpp b/libraries/WebServer/src/middleware/CorsMiddleware.cpp new file mode 100644 index 00000000000..a52ccd59f23 --- /dev/null +++ b/libraries/WebServer/src/middleware/CorsMiddleware.cpp @@ -0,0 +1,47 @@ +#include "Middlewares.h" + +CorsMiddleware &CorsMiddleware::setOrigin(const char *origin) { + _origin = origin; + return *this; +} + +CorsMiddleware &CorsMiddleware::setMethods(const char *methods) { + _methods = methods; + return *this; +} + +CorsMiddleware &CorsMiddleware::setHeaders(const char *headers) { + _headers = headers; + return *this; +} + +CorsMiddleware &CorsMiddleware::setAllowCredentials(bool credentials) { + _credentials = credentials; + return *this; +} + +CorsMiddleware &CorsMiddleware::setMaxAge(uint32_t seconds) { + _maxAge = seconds; + return *this; +} + +void CorsMiddleware::addCORSHeaders(WebServer &server) { + server.sendHeader(F("Access-Control-Allow-Origin"), _origin.c_str()); + server.sendHeader(F("Access-Control-Allow-Methods"), _methods.c_str()); + server.sendHeader(F("Access-Control-Allow-Headers"), _headers.c_str()); + server.sendHeader(F("Access-Control-Allow-Credentials"), _credentials ? F("true") : F("false")); + server.sendHeader(F("Access-Control-Max-Age"), String(_maxAge).c_str()); +} + +bool CorsMiddleware::run(WebServer &server, Middleware::Callback next) { + // Origin header ? => CORS handling + if (server.hasHeader(F("Origin"))) { + addCORSHeaders(server); + // check if this is a preflight request => handle it and return + if (server.method() == HTTP_OPTIONS) { + server.send(200); + return true; + } + } + return next(); +} diff --git a/libraries/WebServer/src/middleware/LoggingMiddleware.cpp b/libraries/WebServer/src/middleware/LoggingMiddleware.cpp new file mode 100644 index 00000000000..e1f6d708c2e --- /dev/null +++ b/libraries/WebServer/src/middleware/LoggingMiddleware.cpp @@ -0,0 +1,74 @@ +#include "Middlewares.h" + +void LoggingMiddleware::setOutput(Print &output) { + _out = &output; +} + +bool LoggingMiddleware::run(WebServer &server, Middleware::Callback next) { + if (_out == nullptr) { + return next(); + } + + _out->print(F("* Connection from ")); + _out->print(server.client().remoteIP().toString()); + _out->print(F(":")); + _out->println(server.client().remotePort()); + + _out->print(F("> ")); + const HTTPMethod method = server.method(); + if (method == HTTP_ANY) { + _out->print(F("HTTP_ANY")); + } else { + _out->print(http_method_str(method)); + } + _out->print(F(" ")); + _out->print(server.uri()); + _out->print(F(" ")); + _out->println(server.version()); + + int n = server.headers(); + for (int i = 0; i < n; i++) { + String v = server.header(i); + if (!v.isEmpty()) { + // because these 2 are always there, eventually empty: "Authorization", "If-None-Match" + _out->print(F("> ")); + _out->print(server.headerName(i)); + _out->print(F(": ")); + _out->println(server.header(i)); + } + } + + _out->println(F(">")); + + uint32_t elapsed = millis(); + const bool ret = next(); + elapsed = millis() - elapsed; + + if (ret) { + _out->print(F("* Processed in ")); + _out->print(elapsed); + _out->println(F(" ms")); + _out->print(F("< ")); + _out->print(F("HTTP/1.")); + _out->print(server.version()); + _out->print(F(" ")); + _out->print(server.responseCode()); + _out->print(F(" ")); + _out->println(WebServer::responseCodeToString(server.responseCode())); + + n = server.responseHeaders(); + for (int i = 0; i < n; i++) { + _out->print(F("< ")); + _out->print(server.responseHeaderName(i)); + _out->print(F(": ")); + _out->println(server.responseHeader(i)); + } + + _out->println(F("<")); + + } else { + _out->println(F("* Not processed!")); + } + + return ret; +} diff --git a/libraries/WebServer/src/middleware/Middleware.h b/libraries/WebServer/src/middleware/Middleware.h new file mode 100644 index 00000000000..080f5be0aba --- /dev/null +++ b/libraries/WebServer/src/middleware/Middleware.h @@ -0,0 +1,54 @@ +#ifndef MIDDLEWARE_H +#define MIDDLEWARE_H + +#include +#include + +class MiddlewareChain; +class WebServer; + +class Middleware { +public: + typedef std::function Callback; + typedef std::function Function; + + virtual ~Middleware() {} + + virtual bool run(WebServer &server, Callback next) { + return next(); + }; + +private: + friend MiddlewareChain; + Middleware *_next = nullptr; + bool _freeOnRemoval = false; +}; + +class MiddlewareFunction : public Middleware { +public: + MiddlewareFunction(Middleware::Function fn) : _fn(fn) {} + + bool run(WebServer &server, Middleware::Callback next) override { + return _fn(server, next); + } + +private: + Middleware::Function _fn; +}; + +class MiddlewareChain { +public: + ~MiddlewareChain(); + + void addMiddleware(Middleware::Function fn); + void addMiddleware(Middleware *middleware); + bool removeMiddleware(Middleware *middleware); + + bool runChain(WebServer &server, Middleware::Callback finalizer); + +private: + Middleware *_root = nullptr; + Middleware *_current = nullptr; +}; + +#endif diff --git a/libraries/WebServer/src/middleware/MiddlewareChain.cpp b/libraries/WebServer/src/middleware/MiddlewareChain.cpp new file mode 100644 index 00000000000..56b3066caea --- /dev/null +++ b/libraries/WebServer/src/middleware/MiddlewareChain.cpp @@ -0,0 +1,73 @@ +#include "Middleware.h" + +MiddlewareChain::~MiddlewareChain() { + Middleware *current = _root; + while (current) { + Middleware *next = current->_next; + if (current->_freeOnRemoval) { + delete current; + } + current = next; + } + _root = nullptr; +} + +void MiddlewareChain::addMiddleware(Middleware::Function fn) { + MiddlewareFunction *middleware = new MiddlewareFunction(fn); + middleware->_freeOnRemoval = true; + addMiddleware(middleware); +} + +void MiddlewareChain::addMiddleware(Middleware *middleware) { + if (!_root) { + _root = middleware; + return; + } + Middleware *current = _root; + while (current->_next) { + current = current->_next; + } + current->_next = middleware; +} + +bool MiddlewareChain::removeMiddleware(Middleware *middleware) { + if (!_root) { + return false; + } + if (_root == middleware) { + _root = _root->_next; + if (middleware->_freeOnRemoval) { + delete middleware; + } + return true; + } + Middleware *current = _root; + while (current->_next) { + if (current->_next == middleware) { + current->_next = current->_next->_next; + if (middleware->_freeOnRemoval) { + delete middleware; + } + return true; + } + current = current->_next; + } + return false; +} + +bool MiddlewareChain::runChain(WebServer &server, Middleware::Callback finalizer) { + if (!_root) { + return finalizer(); + } + _current = _root; + Middleware::Callback next; + next = [this, &server, &next, finalizer]() { + if (!_current) { + return finalizer(); + } + Middleware *that = _current; + _current = _current->_next; + return that->run(server, next); + }; + return next(); +}