From 13b4b8a824337eba6cdab11405b4b6d112cb2820 Mon Sep 17 00:00:00 2001 From: cpq Date: Tue, 15 Aug 2023 11:52:22 +0100 Subject: [PATCH] TLS layer: use c->raw, enhance tls_builtin --- mongoose.c | 152 ++++++++++++++++++++++++++++++---------------- mongoose.h | 1 + src/net.c | 13 +++- src/net.h | 1 + src/net_builtin.c | 16 +---- src/sock.c | 45 ++++++++++---- src/tls_builtin.c | 78 +++++++++++++++--------- test/fuzz.c | 2 +- 8 files changed, 199 insertions(+), 109 deletions(-) diff --git a/mongoose.c b/mongoose.c index 9edde71d4b..01bd70e592 100644 --- a/mongoose.c +++ b/mongoose.c @@ -3692,7 +3692,7 @@ struct mg_connection *mg_alloc_conn(struct mg_mgr *mgr) { (struct mg_connection *) calloc(1, sizeof(*c) + mgr->extraconnsize); if (c != NULL) { c->mgr = mgr; - c->send.align = c->recv.align = MG_IO_SIZE; + c->send.align = c->recv.align = c->raw.align = MG_IO_SIZE; c->id = ++mgr->nextid; } return c; @@ -3711,6 +3711,7 @@ void mg_close_conn(struct mg_connection *c) { mg_tls_free(c); mg_iobuf_free(&c->recv); mg_iobuf_free(&c->send); + mg_iobuf_free(&c->raw); memset(c, 0, sizeof(*c)); free(c); } @@ -3756,7 +3757,7 @@ struct mg_connection *mg_listen(struct mg_mgr *mgr, const char *url, c->fn = fn; c->fn_data = fn_data; mg_call(c, MG_EV_OPEN, NULL); - if (mg_url_is_ssl(url)) c->is_tls = 1; // Accepted connection must + if (mg_url_is_ssl(url)) c->is_tls = 1; // Accepted connection must MG_DEBUG(("%lu %p %s", c->id, c->fd, url)); } return c; @@ -3786,6 +3787,14 @@ struct mg_timer *mg_timer_add(struct mg_mgr *mgr, uint64_t milliseconds, return t; } +long mg_io_recv(struct mg_connection *c, void *buf, size_t len) { + if (c->raw.len == 0) return MG_IO_WAIT; + if (len > c->raw.len) len = c->raw.len; + memcpy(buf, c->raw.buf, len); + mg_iobuf_del(&c->raw, 0, len); + return (long) len; +} + void mg_mgr_free(struct mg_mgr *mgr) { struct mg_connection *c; struct mg_timer *tmp, *t = mgr->timers; @@ -4403,20 +4412,10 @@ long mg_io_send(struct mg_connection *c, const void *buf, size_t len) { return (long) len; } -long mg_io_recv(struct mg_connection *c, void *buf, size_t len) { - struct connstate *s = (struct connstate *) (c + 1); - if (s->raw.len == 0) return MG_IO_WAIT; - if (len > s->raw.len) len = s->raw.len; - memcpy(buf, s->raw.buf, len); - mg_iobuf_del(&s->raw, 0, len); - return (long) len; -} - static void read_conn(struct mg_connection *c, struct pkt *pkt) { struct connstate *s = (struct connstate *) (c + 1); - struct mg_iobuf *io = c->is_tls ? &s->raw : &c->recv; + struct mg_iobuf *io = c->is_tls ? &c->raw : &c->recv; uint32_t seq = mg_ntohl(pkt->tcp->seq); - s->raw.align = c->recv.align; uint32_t rem_ip; memcpy(&rem_ip, c->rem.ip, sizeof(uint32_t)); if (pkt->tcp->flags & TH_FIN) { @@ -4456,9 +4455,9 @@ static void read_conn(struct mg_connection *c, struct pkt *pkt) { } else { // Copy TCP payload into the IO buffer. If the connection is plain text, // we copy to c->recv. If the connection is TLS, this data is encrypted, - // therefore we copy that encrypted data to the s->raw iobuffer instead, + // therefore we copy that encrypted data to the c->raw iobuffer instead, // and then call mg_tls_recv() to decrypt it. NOTE: mg_tls_recv() will - // call back mg_io_recv() which grabs raw data from s->raw + // call back mg_io_recv() which grabs raw data from c->raw memcpy(&io->buf[io->len], pkt->pay.ptr, pkt->pay.len); io->len += pkt->pay.len; @@ -5831,7 +5830,7 @@ bool mg_open_listener(struct mg_connection *c, const char *url) { return success; } -long mg_io_recv(struct mg_connection *c, void *buf, size_t len) { +static long recv_raw(struct mg_connection *c, void *buf, size_t len) { long n = 0; if (c->is_udp) { union usa usa; @@ -5847,19 +5846,42 @@ long mg_io_recv(struct mg_connection *c, void *buf, size_t len) { return n; } +static bool ioalloc(struct mg_connection *c, struct mg_iobuf *io) { + bool res = false; + if (io->len >= MG_MAX_RECV_SIZE) { + mg_error(c, "MG_MAX_RECV_SIZE"); + } else if (io->size <= io->len && + !mg_iobuf_resize(io, io->size + MG_IO_SIZE)) { + mg_error(c, "OOM"); + } else { + res = true; + } + return res; +} + // NOTE(lsm): do only one iteration of reads, cause some systems // (e.g. FreeRTOS stack) return 0 instead of -1/EWOULDBLOCK when no data static void read_conn(struct mg_connection *c) { - long n = -1; - if (c->recv.len >= MG_MAX_RECV_SIZE) { - mg_error(c, "max_recv_buf_size reached"); - } else if (c->recv.size <= c->recv.len && - !mg_iobuf_resize(&c->recv, c->recv.size + MG_IO_SIZE)) { - mg_error(c, "oom"); - } else { + if (ioalloc(c, &c->recv)) { char *buf = (char *) &c->recv.buf[c->recv.len]; size_t len = c->recv.size - c->recv.len; - n = c->is_tls ? mg_tls_recv(c, buf, len) : mg_io_recv(c, buf, len); + long n = -1; + if (c->is_tls) { + if (!ioalloc(c, &c->raw)) return; + n = recv_raw(c, (char *) &c->raw.buf[c->raw.len], + c->raw.size - c->raw.len); + MG_INFO(("%lu %ld", c->id, n)); + if (n == MG_IO_ERR) { + c->is_closing = 1; + } else if (n > 0) { + c->raw.len += (size_t) n; + if (c->is_tls_hs) mg_tls_handshake(c); + if (c->is_tls_hs) return; + n = mg_tls_recv(c, buf, len); + } + } else { + n = recv_raw(c, buf, len); + } MG_DEBUG(("%lu %p snd %ld/%ld rcv %ld/%ld n=%ld err=%d", c->id, c->fd, (long) c->send.len, (long) c->send.size, (long) c->recv.len, (long) c->recv.size, n, MG_SOCK_ERR(n))); @@ -6179,8 +6201,8 @@ void mg_mgr_poll(struct mg_mgr *mgr, int ms) { if (c->is_readable) accept_conn(mgr, c); } else if (c->is_connecting) { if (c->is_readable || c->is_writable) connect_conn(c); - } else if (c->is_tls_hs) { - if ((c->is_readable || c->is_writable)) mg_tls_handshake(c); + //} else if (c->is_tls_hs) { + // if ((c->is_readable || c->is_writable)) mg_tls_handshake(c); } else { if (c->is_readable) read_conn(c); if (c->is_writable) write_conn(c); @@ -6547,8 +6569,12 @@ void mg_timer_poll(struct mg_timer **head, uint64_t now_ms) { #if MG_TLS == MG_TLS_BUILTIN struct tls_data { - struct mg_iobuf send; - struct mg_iobuf recv; + uint8_t client_random[32]; // From client hello + uint8_t client_pub[32]; // From client hello +}; +struct tls_ctx { + struct mg_iobuf server_cert; // Decoded server certificate + struct mg_iobuf server_key; // Decoded server private key }; #define MG_LOAD_BE16(p) ((uint16_t) ((MG_U8P(p)[0] << 8U) | MG_U8P(p)[1])) @@ -6561,6 +6587,10 @@ static inline bool mg_is_big_endian(void) { static inline uint16_t mg_swap16(uint16_t v) { return (uint16_t) ((v << 8U) | (v >> 8U)); } +static inline uint16_t mg_be16(uint16_t v) { + return mg_is_big_endian() ? mg_swap16(v) : v; +} +#if 0 static inline uint32_t mg_swap32(uint32_t v) { return (v >> 24) | (v >> 8 & 0xff00) | (v << 8 & 0xff0000) | (v << 24); } @@ -6568,12 +6598,10 @@ static inline uint64_t mg_swap64(uint64_t v) { return (((uint64_t) mg_swap32((uint32_t) v)) << 32) | mg_swap32((uint32_t) (v >> 32)); } -static inline uint16_t mg_be16(uint16_t v) { - return mg_is_big_endian() ? mg_swap16(v) : v; -} static inline uint32_t mg_be32(uint32_t v) { return mg_is_big_endian() ? mg_swap32(v) : v; } +#endif static inline void add8(struct mg_iobuf *io, uint8_t data) { mg_iobuf_add(io, io->len, &data, sizeof(data)); @@ -6590,7 +6618,7 @@ static inline void add32(struct mg_iobuf *io, uint32_t data) { void mg_tls_init(struct mg_connection *c, struct mg_str hostname) { struct tls_data *tls = (struct tls_data *) calloc(1, sizeof(struct tls_data)); if (tls != NULL) { - tls->send.align = tls->recv.align = MG_IO_SIZE; + // tls->send.align = tls->recv.align = MG_IO_SIZE; c->tls = tls; c->is_tls = c->is_tls_hs = 1; } else { @@ -6601,8 +6629,8 @@ void mg_tls_init(struct mg_connection *c, struct mg_str hostname) { void mg_tls_free(struct mg_connection *c) { struct tls_data *tls = c->tls; if (tls != NULL) { - mg_iobuf_free(&tls->send); - mg_iobuf_free(&tls->recv); + // mg_iobuf_free(&tls->send); + // mg_iobuf_free(&tls->recv); } free(c->tls); c->tls = NULL; @@ -6631,36 +6659,24 @@ size_t mg_tls_pending(struct mg_connection *c) { return 0; } void mg_tls_handshake(struct mg_connection *c) { - struct tls_data *tls = c->tls; - struct mg_iobuf *rio = &tls->recv; - struct mg_iobuf *wio = &tls->send; - // Pull data from TCP - for (;;) { - mg_iobuf_resize(rio, rio->len + 1); - long n = mg_io_recv(c, &rio->buf[rio->len], rio->size - rio->len); - if (n > 0) { - rio->len += (size_t) n; - } else if (n == MG_IO_WAIT) { - break; - } else { - mg_error(c, "IO err"); - return; - } - } + // struct tls_data *tls = c->tls; + struct mg_iobuf *rio = &c->raw; + struct mg_iobuf *wio = &c->send; + // Look if we've pulled everything if (rio->len < TLS_HDR_SIZE) return; uint8_t record_type = rio->buf[0]; uint16_t record_len = MG_LOAD_BE16(rio->buf + 3); uint16_t record_version = MG_LOAD_BE16(rio->buf + 1); if (record_type != 22) { - mg_error(c, "no 22"); + mg_error(c, "not a handshake"); return; } if (rio->len < (size_t) TLS_HDR_SIZE + record_len) return; // Got full hello // struct tls_hello *hello = (struct tls_hello *) (hdr + 1); MG_INFO(("CT=%d V=%hx L=%hu", record_type, record_version, record_len)); - mg_hexdump(rio->buf, rio->len); + // mg_hexdump(rio->buf, rio->len); // Send response. Server Hello size_t ofs = wio->len; @@ -6673,6 +6689,8 @@ void mg_tls_handshake(struct mg_connection *c) { add8(wio, 0); // Compression method: 0 add16(wio, 46); // Extensions length add16(wio, 43), add16(wio, 2), add16(wio, 0x304); // extension: TLS 1.3 + + // Key share: use curve x25519 (id 29) add16(wio, 51), add16(wio, 36), add16(wio, 29), add16(wio, 32); // keyshare mg_iobuf_add(wio, wio->len, NULL, 32); // 32 random mg_random(wio->buf + wio->len - 32, 32); // bytes @@ -6692,7 +6710,7 @@ void mg_tls_handshake(struct mg_connection *c) { add8(wio, 0), add16(wio, 2), add16(wio, 0); // empty 2 bytes add8(wio, 11); // certificate message add8(wio, 0), add16(wio, 4), add32(wio, 0x1020304); // len - *(uint16_t *) &wio->buf[ofs + 3] = mg_be16((uint16_t)(wio->len - ofs - 5)); + *(uint16_t *) &wio->buf[ofs + 3] = mg_be16((uint16_t) (wio->len - ofs - 5)); mg_io_send(c, wio->buf, wio->len); wio->len = 0; @@ -6702,10 +6720,36 @@ void mg_tls_handshake(struct mg_connection *c) { mg_error(c, "doh"); } void mg_tls_ctx_free(struct mg_mgr *mgr) { + free(mgr->tls_ctx); mgr->tls_ctx = NULL; } +static void pem_load(struct mg_str pem, struct mg_iobuf *io) { + if (pem.len > 0) { + if (pem.ptr[0] == '-') { + char in[4], out[4]; + size_t i = 0, j = 0; + while (i < pem.len && pem.ptr[i] != '\n') i++; // Skip first line, ----- + for (; i < pem.len && pem.ptr[i] != '-'; i++) { // Until the end ----- + if (pem.ptr[i] == '\r' || pem.ptr[i] == '\n') continue; + in[j++] = pem.ptr[i]; + if (j >= sizeof(in)) { + int n = mg_base64_decode(in, 4, out); + mg_iobuf_add(io, io->len, out, n); + j = 0; + } + } + mg_hexdump(io->buf, io->len); + } else { + mg_iobuf_add(io, io->len, pem.ptr, pem.len); // Already in DER + } + } +} void mg_tls_ctx_init(struct mg_mgr *mgr, const struct mg_tls_opts *opts) { - (void) opts, (void) mgr; + struct tls_ctx *ctx = (struct tls_ctx *) calloc(1, sizeof(*ctx)); + ctx->server_key.align = ctx->server_cert.align = 128; + pem_load(opts->server_cert, &ctx->server_cert); + pem_load(opts->server_key, &ctx->server_key); + mgr->tls_ctx = ctx; } #endif diff --git a/mongoose.h b/mongoose.h index df643a1628..6b77d9a84a 100644 --- a/mongoose.h +++ b/mongoose.h @@ -1195,6 +1195,7 @@ struct mg_connection { unsigned long id; // Auto-incrementing unique connection ID struct mg_iobuf recv; // Incoming data struct mg_iobuf send; // Outgoing data + struct mg_iobuf raw; // TLS only. Incoming encrypted data mg_event_handler_t fn; // User-specified event handler function void *fn_data; // User-specified function parameter mg_event_handler_t pfn; // Protocol-specific handler function diff --git a/src/net.c b/src/net.c index 45b6090421..baaaed5df5 100644 --- a/src/net.c +++ b/src/net.c @@ -124,7 +124,7 @@ struct mg_connection *mg_alloc_conn(struct mg_mgr *mgr) { (struct mg_connection *) calloc(1, sizeof(*c) + mgr->extraconnsize); if (c != NULL) { c->mgr = mgr; - c->send.align = c->recv.align = MG_IO_SIZE; + c->send.align = c->recv.align = c->raw.align = MG_IO_SIZE; c->id = ++mgr->nextid; } return c; @@ -143,6 +143,7 @@ void mg_close_conn(struct mg_connection *c) { mg_tls_free(c); mg_iobuf_free(&c->recv); mg_iobuf_free(&c->send); + mg_iobuf_free(&c->raw); memset(c, 0, sizeof(*c)); free(c); } @@ -188,7 +189,7 @@ struct mg_connection *mg_listen(struct mg_mgr *mgr, const char *url, c->fn = fn; c->fn_data = fn_data; mg_call(c, MG_EV_OPEN, NULL); - if (mg_url_is_ssl(url)) c->is_tls = 1; // Accepted connection must + if (mg_url_is_ssl(url)) c->is_tls = 1; // Accepted connection must MG_DEBUG(("%lu %p %s", c->id, c->fd, url)); } return c; @@ -218,6 +219,14 @@ struct mg_timer *mg_timer_add(struct mg_mgr *mgr, uint64_t milliseconds, return t; } +long mg_io_recv(struct mg_connection *c, void *buf, size_t len) { + if (c->raw.len == 0) return MG_IO_WAIT; + if (len > c->raw.len) len = c->raw.len; + memcpy(buf, c->raw.buf, len); + mg_iobuf_del(&c->raw, 0, len); + return (long) len; +} + void mg_mgr_free(struct mg_mgr *mgr) { struct mg_connection *c; struct mg_timer *tmp, *t = mgr->timers; diff --git a/src/net.h b/src/net.h index fac5fe8f06..706626ce58 100644 --- a/src/net.h +++ b/src/net.h @@ -48,6 +48,7 @@ struct mg_connection { unsigned long id; // Auto-incrementing unique connection ID struct mg_iobuf recv; // Incoming data struct mg_iobuf send; // Outgoing data + struct mg_iobuf raw; // TLS only. Incoming encrypted data mg_event_handler_t fn; // User-specified event handler function void *fn_data; // User-specified function parameter mg_event_handler_t pfn; // Protocol-specific handler function diff --git a/src/net_builtin.c b/src/net_builtin.c index 47049f1429..ad01ffad67 100644 --- a/src/net_builtin.c +++ b/src/net_builtin.c @@ -572,20 +572,10 @@ long mg_io_send(struct mg_connection *c, const void *buf, size_t len) { return (long) len; } -long mg_io_recv(struct mg_connection *c, void *buf, size_t len) { - struct connstate *s = (struct connstate *) (c + 1); - if (s->raw.len == 0) return MG_IO_WAIT; - if (len > s->raw.len) len = s->raw.len; - memcpy(buf, s->raw.buf, len); - mg_iobuf_del(&s->raw, 0, len); - return (long) len; -} - static void read_conn(struct mg_connection *c, struct pkt *pkt) { struct connstate *s = (struct connstate *) (c + 1); - struct mg_iobuf *io = c->is_tls ? &s->raw : &c->recv; + struct mg_iobuf *io = c->is_tls ? &c->raw : &c->recv; uint32_t seq = mg_ntohl(pkt->tcp->seq); - s->raw.align = c->recv.align; uint32_t rem_ip; memcpy(&rem_ip, c->rem.ip, sizeof(uint32_t)); if (pkt->tcp->flags & TH_FIN) { @@ -625,9 +615,9 @@ static void read_conn(struct mg_connection *c, struct pkt *pkt) { } else { // Copy TCP payload into the IO buffer. If the connection is plain text, // we copy to c->recv. If the connection is TLS, this data is encrypted, - // therefore we copy that encrypted data to the s->raw iobuffer instead, + // therefore we copy that encrypted data to the c->raw iobuffer instead, // and then call mg_tls_recv() to decrypt it. NOTE: mg_tls_recv() will - // call back mg_io_recv() which grabs raw data from s->raw + // call back mg_io_recv() which grabs raw data from c->raw memcpy(&io->buf[io->len], pkt->pay.ptr, pkt->pay.len); io->len += pkt->pay.len; diff --git a/src/sock.c b/src/sock.c index 727c1c617c..11a565acb7 100644 --- a/src/sock.c +++ b/src/sock.c @@ -238,7 +238,7 @@ bool mg_open_listener(struct mg_connection *c, const char *url) { return success; } -long mg_io_recv(struct mg_connection *c, void *buf, size_t len) { +static long recv_raw(struct mg_connection *c, void *buf, size_t len) { long n = 0; if (c->is_udp) { union usa usa; @@ -254,19 +254,42 @@ long mg_io_recv(struct mg_connection *c, void *buf, size_t len) { return n; } +static bool ioalloc(struct mg_connection *c, struct mg_iobuf *io) { + bool res = false; + if (io->len >= MG_MAX_RECV_SIZE) { + mg_error(c, "MG_MAX_RECV_SIZE"); + } else if (io->size <= io->len && + !mg_iobuf_resize(io, io->size + MG_IO_SIZE)) { + mg_error(c, "OOM"); + } else { + res = true; + } + return res; +} + // NOTE(lsm): do only one iteration of reads, cause some systems // (e.g. FreeRTOS stack) return 0 instead of -1/EWOULDBLOCK when no data static void read_conn(struct mg_connection *c) { - long n = -1; - if (c->recv.len >= MG_MAX_RECV_SIZE) { - mg_error(c, "max_recv_buf_size reached"); - } else if (c->recv.size <= c->recv.len && - !mg_iobuf_resize(&c->recv, c->recv.size + MG_IO_SIZE)) { - mg_error(c, "oom"); - } else { + if (ioalloc(c, &c->recv)) { char *buf = (char *) &c->recv.buf[c->recv.len]; size_t len = c->recv.size - c->recv.len; - n = c->is_tls ? mg_tls_recv(c, buf, len) : mg_io_recv(c, buf, len); + long n = -1; + if (c->is_tls) { + if (!ioalloc(c, &c->raw)) return; + n = recv_raw(c, (char *) &c->raw.buf[c->raw.len], + c->raw.size - c->raw.len); + MG_INFO(("%lu %ld", c->id, n)); + if (n == MG_IO_ERR) { + c->is_closing = 1; + } else if (n > 0) { + c->raw.len += (size_t) n; + if (c->is_tls_hs) mg_tls_handshake(c); + if (c->is_tls_hs) return; + n = mg_tls_recv(c, buf, len); + } + } else { + n = recv_raw(c, buf, len); + } MG_DEBUG(("%lu %p snd %ld/%ld rcv %ld/%ld n=%ld err=%d", c->id, c->fd, (long) c->send.len, (long) c->send.size, (long) c->recv.len, (long) c->recv.size, n, MG_SOCK_ERR(n))); @@ -586,8 +609,8 @@ void mg_mgr_poll(struct mg_mgr *mgr, int ms) { if (c->is_readable) accept_conn(mgr, c); } else if (c->is_connecting) { if (c->is_readable || c->is_writable) connect_conn(c); - } else if (c->is_tls_hs) { - if ((c->is_readable || c->is_writable)) mg_tls_handshake(c); + //} else if (c->is_tls_hs) { + // if ((c->is_readable || c->is_writable)) mg_tls_handshake(c); } else { if (c->is_readable) read_conn(c); if (c->is_writable) write_conn(c); diff --git a/src/tls_builtin.c b/src/tls_builtin.c index eebe5a48cc..826ae078cb 100644 --- a/src/tls_builtin.c +++ b/src/tls_builtin.c @@ -2,8 +2,12 @@ #if MG_TLS == MG_TLS_BUILTIN struct tls_data { - struct mg_iobuf send; - struct mg_iobuf recv; + uint8_t client_random[32]; // From client hello + uint8_t client_pub[32]; // From client hello +}; +struct tls_ctx { + struct mg_iobuf server_cert; // Decoded server certificate + struct mg_iobuf server_key; // Decoded server private key }; #define MG_LOAD_BE16(p) ((uint16_t) ((MG_U8P(p)[0] << 8U) | MG_U8P(p)[1])) @@ -16,6 +20,10 @@ static inline bool mg_is_big_endian(void) { static inline uint16_t mg_swap16(uint16_t v) { return (uint16_t) ((v << 8U) | (v >> 8U)); } +static inline uint16_t mg_be16(uint16_t v) { + return mg_is_big_endian() ? mg_swap16(v) : v; +} +#if 0 static inline uint32_t mg_swap32(uint32_t v) { return (v >> 24) | (v >> 8 & 0xff00) | (v << 8 & 0xff0000) | (v << 24); } @@ -23,12 +31,10 @@ static inline uint64_t mg_swap64(uint64_t v) { return (((uint64_t) mg_swap32((uint32_t) v)) << 32) | mg_swap32((uint32_t) (v >> 32)); } -static inline uint16_t mg_be16(uint16_t v) { - return mg_is_big_endian() ? mg_swap16(v) : v; -} static inline uint32_t mg_be32(uint32_t v) { return mg_is_big_endian() ? mg_swap32(v) : v; } +#endif static inline void add8(struct mg_iobuf *io, uint8_t data) { mg_iobuf_add(io, io->len, &data, sizeof(data)); @@ -45,7 +51,7 @@ static inline void add32(struct mg_iobuf *io, uint32_t data) { void mg_tls_init(struct mg_connection *c, struct mg_str hostname) { struct tls_data *tls = (struct tls_data *) calloc(1, sizeof(struct tls_data)); if (tls != NULL) { - tls->send.align = tls->recv.align = MG_IO_SIZE; + // tls->send.align = tls->recv.align = MG_IO_SIZE; c->tls = tls; c->is_tls = c->is_tls_hs = 1; } else { @@ -56,8 +62,8 @@ void mg_tls_init(struct mg_connection *c, struct mg_str hostname) { void mg_tls_free(struct mg_connection *c) { struct tls_data *tls = c->tls; if (tls != NULL) { - mg_iobuf_free(&tls->send); - mg_iobuf_free(&tls->recv); + // mg_iobuf_free(&tls->send); + // mg_iobuf_free(&tls->recv); } free(c->tls); c->tls = NULL; @@ -86,36 +92,24 @@ size_t mg_tls_pending(struct mg_connection *c) { return 0; } void mg_tls_handshake(struct mg_connection *c) { - struct tls_data *tls = c->tls; - struct mg_iobuf *rio = &tls->recv; - struct mg_iobuf *wio = &tls->send; - // Pull data from TCP - for (;;) { - mg_iobuf_resize(rio, rio->len + 1); - long n = mg_io_recv(c, &rio->buf[rio->len], rio->size - rio->len); - if (n > 0) { - rio->len += (size_t) n; - } else if (n == MG_IO_WAIT) { - break; - } else { - mg_error(c, "IO err"); - return; - } - } + // struct tls_data *tls = c->tls; + struct mg_iobuf *rio = &c->raw; + struct mg_iobuf *wio = &c->send; + // Look if we've pulled everything if (rio->len < TLS_HDR_SIZE) return; uint8_t record_type = rio->buf[0]; uint16_t record_len = MG_LOAD_BE16(rio->buf + 3); uint16_t record_version = MG_LOAD_BE16(rio->buf + 1); if (record_type != 22) { - mg_error(c, "no 22"); + mg_error(c, "not a handshake"); return; } if (rio->len < (size_t) TLS_HDR_SIZE + record_len) return; // Got full hello // struct tls_hello *hello = (struct tls_hello *) (hdr + 1); MG_INFO(("CT=%d V=%hx L=%hu", record_type, record_version, record_len)); - mg_hexdump(rio->buf, rio->len); + // mg_hexdump(rio->buf, rio->len); // Send response. Server Hello size_t ofs = wio->len; @@ -128,6 +122,8 @@ void mg_tls_handshake(struct mg_connection *c) { add8(wio, 0); // Compression method: 0 add16(wio, 46); // Extensions length add16(wio, 43), add16(wio, 2), add16(wio, 0x304); // extension: TLS 1.3 + + // Key share: use curve x25519 (id 29) add16(wio, 51), add16(wio, 36), add16(wio, 29), add16(wio, 32); // keyshare mg_iobuf_add(wio, wio->len, NULL, 32); // 32 random mg_random(wio->buf + wio->len - 32, 32); // bytes @@ -147,7 +143,7 @@ void mg_tls_handshake(struct mg_connection *c) { add8(wio, 0), add16(wio, 2), add16(wio, 0); // empty 2 bytes add8(wio, 11); // certificate message add8(wio, 0), add16(wio, 4), add32(wio, 0x1020304); // len - *(uint16_t *) &wio->buf[ofs + 3] = mg_be16((uint16_t)(wio->len - ofs - 5)); + *(uint16_t *) &wio->buf[ofs + 3] = mg_be16((uint16_t) (wio->len - ofs - 5)); mg_io_send(c, wio->buf, wio->len); wio->len = 0; @@ -157,9 +153,35 @@ void mg_tls_handshake(struct mg_connection *c) { mg_error(c, "doh"); } void mg_tls_ctx_free(struct mg_mgr *mgr) { + free(mgr->tls_ctx); mgr->tls_ctx = NULL; } +static void pem_load(struct mg_str pem, struct mg_iobuf *io) { + if (pem.len > 0) { + if (pem.ptr[0] == '-') { + char in[4], out[4]; + size_t i = 0, j = 0; + while (i < pem.len && pem.ptr[i] != '\n') i++; // Skip first line, ----- + for (; i < pem.len && pem.ptr[i] != '-'; i++) { // Until the end ----- + if (pem.ptr[i] == '\r' || pem.ptr[i] == '\n') continue; + in[j++] = pem.ptr[i]; + if (j >= sizeof(in)) { + int n = mg_base64_decode(in, 4, out); + mg_iobuf_add(io, io->len, out, n); + j = 0; + } + } + mg_hexdump(io->buf, io->len); + } else { + mg_iobuf_add(io, io->len, pem.ptr, pem.len); // Already in DER + } + } +} void mg_tls_ctx_init(struct mg_mgr *mgr, const struct mg_tls_opts *opts) { - (void) opts, (void) mgr; + struct tls_ctx *ctx = (struct tls_ctx *) calloc(1, sizeof(*ctx)); + ctx->server_key.align = ctx->server_cert.align = 128; + pem_load(opts->server_cert, &ctx->server_cert); + pem_load(opts->server_key, &ctx->server_key); + mgr->tls_ctx = ctx; } #endif diff --git a/test/fuzz.c b/test/fuzz.c index 1c5dec3271..89a100b162 100644 --- a/test/fuzz.c +++ b/test/fuzz.c @@ -23,7 +23,7 @@ static void fn(struct mg_connection *c, int ev, void *ev_data, void *fn_data) { } int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) { - mg_log_set(MG_LL_NONE); + mg_log_set(MG_LL_INFO); struct mg_dns_message dm; mg_dns_parse(data, size, &dm);