Skip to content

Commit

Permalink
TLS layer: use c->raw, enhance tls_builtin
Browse files Browse the repository at this point in the history
  • Loading branch information
cpq committed Aug 29, 2023
1 parent d5b5cec commit 13b4b8a
Show file tree
Hide file tree
Showing 8 changed files with 199 additions and 109 deletions.
152 changes: 98 additions & 54 deletions mongoose.c
Original file line number Diff line number Diff line change
Expand Up @@ -3692,7 +3692,7 @@ struct mg_connection *mg_alloc_conn(struct mg_mgr *mgr) {
(struct mg_connection *) calloc(1, sizeof(*c) + mgr->extraconnsize);
if (c != NULL) {
c->mgr = mgr;
c->send.align = c->recv.align = MG_IO_SIZE;
c->send.align = c->recv.align = c->raw.align = MG_IO_SIZE;
c->id = ++mgr->nextid;
}
return c;
Expand All @@ -3711,6 +3711,7 @@ void mg_close_conn(struct mg_connection *c) {
mg_tls_free(c);
mg_iobuf_free(&c->recv);
mg_iobuf_free(&c->send);
mg_iobuf_free(&c->raw);
memset(c, 0, sizeof(*c));
free(c);
}
Expand Down Expand Up @@ -3756,7 +3757,7 @@ struct mg_connection *mg_listen(struct mg_mgr *mgr, const char *url,
c->fn = fn;
c->fn_data = fn_data;
mg_call(c, MG_EV_OPEN, NULL);
if (mg_url_is_ssl(url)) c->is_tls = 1; // Accepted connection must
if (mg_url_is_ssl(url)) c->is_tls = 1; // Accepted connection must
MG_DEBUG(("%lu %p %s", c->id, c->fd, url));
}
return c;
Expand Down Expand Up @@ -3786,6 +3787,14 @@ struct mg_timer *mg_timer_add(struct mg_mgr *mgr, uint64_t milliseconds,
return t;
}

long mg_io_recv(struct mg_connection *c, void *buf, size_t len) {
if (c->raw.len == 0) return MG_IO_WAIT;
if (len > c->raw.len) len = c->raw.len;
memcpy(buf, c->raw.buf, len);
mg_iobuf_del(&c->raw, 0, len);
return (long) len;
}

void mg_mgr_free(struct mg_mgr *mgr) {
struct mg_connection *c;
struct mg_timer *tmp, *t = mgr->timers;
Expand Down Expand Up @@ -4403,20 +4412,10 @@ long mg_io_send(struct mg_connection *c, const void *buf, size_t len) {
return (long) len;
}

long mg_io_recv(struct mg_connection *c, void *buf, size_t len) {
struct connstate *s = (struct connstate *) (c + 1);
if (s->raw.len == 0) return MG_IO_WAIT;
if (len > s->raw.len) len = s->raw.len;
memcpy(buf, s->raw.buf, len);
mg_iobuf_del(&s->raw, 0, len);
return (long) len;
}

static void read_conn(struct mg_connection *c, struct pkt *pkt) {
struct connstate *s = (struct connstate *) (c + 1);
struct mg_iobuf *io = c->is_tls ? &s->raw : &c->recv;
struct mg_iobuf *io = c->is_tls ? &c->raw : &c->recv;
uint32_t seq = mg_ntohl(pkt->tcp->seq);
s->raw.align = c->recv.align;
uint32_t rem_ip;
memcpy(&rem_ip, c->rem.ip, sizeof(uint32_t));
if (pkt->tcp->flags & TH_FIN) {
Expand Down Expand Up @@ -4456,9 +4455,9 @@ static void read_conn(struct mg_connection *c, struct pkt *pkt) {
} else {
// Copy TCP payload into the IO buffer. If the connection is plain text,
// we copy to c->recv. If the connection is TLS, this data is encrypted,
// therefore we copy that encrypted data to the s->raw iobuffer instead,
// therefore we copy that encrypted data to the c->raw iobuffer instead,
// and then call mg_tls_recv() to decrypt it. NOTE: mg_tls_recv() will
// call back mg_io_recv() which grabs raw data from s->raw
// call back mg_io_recv() which grabs raw data from c->raw
memcpy(&io->buf[io->len], pkt->pay.ptr, pkt->pay.len);
io->len += pkt->pay.len;

Expand Down Expand Up @@ -5831,7 +5830,7 @@ bool mg_open_listener(struct mg_connection *c, const char *url) {
return success;
}

long mg_io_recv(struct mg_connection *c, void *buf, size_t len) {
static long recv_raw(struct mg_connection *c, void *buf, size_t len) {
long n = 0;
if (c->is_udp) {
union usa usa;
Expand All @@ -5847,19 +5846,42 @@ long mg_io_recv(struct mg_connection *c, void *buf, size_t len) {
return n;
}

static bool ioalloc(struct mg_connection *c, struct mg_iobuf *io) {
bool res = false;
if (io->len >= MG_MAX_RECV_SIZE) {
mg_error(c, "MG_MAX_RECV_SIZE");
} else if (io->size <= io->len &&
!mg_iobuf_resize(io, io->size + MG_IO_SIZE)) {
mg_error(c, "OOM");
} else {
res = true;
}
return res;
}

// NOTE(lsm): do only one iteration of reads, cause some systems
// (e.g. FreeRTOS stack) return 0 instead of -1/EWOULDBLOCK when no data
static void read_conn(struct mg_connection *c) {
long n = -1;
if (c->recv.len >= MG_MAX_RECV_SIZE) {
mg_error(c, "max_recv_buf_size reached");
} else if (c->recv.size <= c->recv.len &&
!mg_iobuf_resize(&c->recv, c->recv.size + MG_IO_SIZE)) {
mg_error(c, "oom");
} else {
if (ioalloc(c, &c->recv)) {
char *buf = (char *) &c->recv.buf[c->recv.len];
size_t len = c->recv.size - c->recv.len;
n = c->is_tls ? mg_tls_recv(c, buf, len) : mg_io_recv(c, buf, len);
long n = -1;
if (c->is_tls) {
if (!ioalloc(c, &c->raw)) return;
n = recv_raw(c, (char *) &c->raw.buf[c->raw.len],
c->raw.size - c->raw.len);
MG_INFO(("%lu %ld", c->id, n));
if (n == MG_IO_ERR) {
c->is_closing = 1;
} else if (n > 0) {
c->raw.len += (size_t) n;
if (c->is_tls_hs) mg_tls_handshake(c);
if (c->is_tls_hs) return;
n = mg_tls_recv(c, buf, len);
}
} else {
n = recv_raw(c, buf, len);
}
MG_DEBUG(("%lu %p snd %ld/%ld rcv %ld/%ld n=%ld err=%d", c->id, c->fd,
(long) c->send.len, (long) c->send.size, (long) c->recv.len,
(long) c->recv.size, n, MG_SOCK_ERR(n)));
Expand Down Expand Up @@ -6179,8 +6201,8 @@ void mg_mgr_poll(struct mg_mgr *mgr, int ms) {
if (c->is_readable) accept_conn(mgr, c);
} else if (c->is_connecting) {
if (c->is_readable || c->is_writable) connect_conn(c);
} else if (c->is_tls_hs) {
if ((c->is_readable || c->is_writable)) mg_tls_handshake(c);
//} else if (c->is_tls_hs) {
// if ((c->is_readable || c->is_writable)) mg_tls_handshake(c);
} else {
if (c->is_readable) read_conn(c);
if (c->is_writable) write_conn(c);
Expand Down Expand Up @@ -6547,8 +6569,12 @@ void mg_timer_poll(struct mg_timer **head, uint64_t now_ms) {

#if MG_TLS == MG_TLS_BUILTIN
struct tls_data {
struct mg_iobuf send;
struct mg_iobuf recv;
uint8_t client_random[32]; // From client hello
uint8_t client_pub[32]; // From client hello
};
struct tls_ctx {
struct mg_iobuf server_cert; // Decoded server certificate
struct mg_iobuf server_key; // Decoded server private key
};

#define MG_LOAD_BE16(p) ((uint16_t) ((MG_U8P(p)[0] << 8U) | MG_U8P(p)[1]))
Expand All @@ -6561,19 +6587,21 @@ static inline bool mg_is_big_endian(void) {
static inline uint16_t mg_swap16(uint16_t v) {
return (uint16_t) ((v << 8U) | (v >> 8U));
}
static inline uint16_t mg_be16(uint16_t v) {
return mg_is_big_endian() ? mg_swap16(v) : v;
}
#if 0
static inline uint32_t mg_swap32(uint32_t v) {
return (v >> 24) | (v >> 8 & 0xff00) | (v << 8 & 0xff0000) | (v << 24);
}
static inline uint64_t mg_swap64(uint64_t v) {
return (((uint64_t) mg_swap32((uint32_t) v)) << 32) |
mg_swap32((uint32_t) (v >> 32));
}
static inline uint16_t mg_be16(uint16_t v) {
return mg_is_big_endian() ? mg_swap16(v) : v;
}
static inline uint32_t mg_be32(uint32_t v) {
return mg_is_big_endian() ? mg_swap32(v) : v;
}
#endif

static inline void add8(struct mg_iobuf *io, uint8_t data) {
mg_iobuf_add(io, io->len, &data, sizeof(data));
Expand All @@ -6590,7 +6618,7 @@ static inline void add32(struct mg_iobuf *io, uint32_t data) {
void mg_tls_init(struct mg_connection *c, struct mg_str hostname) {
struct tls_data *tls = (struct tls_data *) calloc(1, sizeof(struct tls_data));
if (tls != NULL) {
tls->send.align = tls->recv.align = MG_IO_SIZE;
// tls->send.align = tls->recv.align = MG_IO_SIZE;
c->tls = tls;
c->is_tls = c->is_tls_hs = 1;
} else {
Expand All @@ -6601,8 +6629,8 @@ void mg_tls_init(struct mg_connection *c, struct mg_str hostname) {
void mg_tls_free(struct mg_connection *c) {
struct tls_data *tls = c->tls;
if (tls != NULL) {
mg_iobuf_free(&tls->send);
mg_iobuf_free(&tls->recv);
// mg_iobuf_free(&tls->send);
// mg_iobuf_free(&tls->recv);
}
free(c->tls);
c->tls = NULL;
Expand Down Expand Up @@ -6631,36 +6659,24 @@ size_t mg_tls_pending(struct mg_connection *c) {
return 0;
}
void mg_tls_handshake(struct mg_connection *c) {
struct tls_data *tls = c->tls;
struct mg_iobuf *rio = &tls->recv;
struct mg_iobuf *wio = &tls->send;
// Pull data from TCP
for (;;) {
mg_iobuf_resize(rio, rio->len + 1);
long n = mg_io_recv(c, &rio->buf[rio->len], rio->size - rio->len);
if (n > 0) {
rio->len += (size_t) n;
} else if (n == MG_IO_WAIT) {
break;
} else {
mg_error(c, "IO err");
return;
}
}
// struct tls_data *tls = c->tls;
struct mg_iobuf *rio = &c->raw;
struct mg_iobuf *wio = &c->send;

// Look if we've pulled everything
if (rio->len < TLS_HDR_SIZE) return;
uint8_t record_type = rio->buf[0];
uint16_t record_len = MG_LOAD_BE16(rio->buf + 3);
uint16_t record_version = MG_LOAD_BE16(rio->buf + 1);
if (record_type != 22) {
mg_error(c, "no 22");
mg_error(c, "not a handshake");
return;
}
if (rio->len < (size_t) TLS_HDR_SIZE + record_len) return;
// Got full hello
// struct tls_hello *hello = (struct tls_hello *) (hdr + 1);
MG_INFO(("CT=%d V=%hx L=%hu", record_type, record_version, record_len));
mg_hexdump(rio->buf, rio->len);
// mg_hexdump(rio->buf, rio->len);

// Send response. Server Hello
size_t ofs = wio->len;
Expand All @@ -6673,6 +6689,8 @@ void mg_tls_handshake(struct mg_connection *c) {
add8(wio, 0); // Compression method: 0
add16(wio, 46); // Extensions length
add16(wio, 43), add16(wio, 2), add16(wio, 0x304); // extension: TLS 1.3

// Key share: use curve x25519 (id 29)
add16(wio, 51), add16(wio, 36), add16(wio, 29), add16(wio, 32); // keyshare
mg_iobuf_add(wio, wio->len, NULL, 32); // 32 random
mg_random(wio->buf + wio->len - 32, 32); // bytes
Expand All @@ -6692,7 +6710,7 @@ void mg_tls_handshake(struct mg_connection *c) {
add8(wio, 0), add16(wio, 2), add16(wio, 0); // empty 2 bytes
add8(wio, 11); // certificate message
add8(wio, 0), add16(wio, 4), add32(wio, 0x1020304); // len
*(uint16_t *) &wio->buf[ofs + 3] = mg_be16((uint16_t)(wio->len - ofs - 5));
*(uint16_t *) &wio->buf[ofs + 3] = mg_be16((uint16_t) (wio->len - ofs - 5));

mg_io_send(c, wio->buf, wio->len);
wio->len = 0;
Expand All @@ -6702,10 +6720,36 @@ void mg_tls_handshake(struct mg_connection *c) {
mg_error(c, "doh");
}
void mg_tls_ctx_free(struct mg_mgr *mgr) {
free(mgr->tls_ctx);
mgr->tls_ctx = NULL;
}
static void pem_load(struct mg_str pem, struct mg_iobuf *io) {
if (pem.len > 0) {
if (pem.ptr[0] == '-') {
char in[4], out[4];
size_t i = 0, j = 0;
while (i < pem.len && pem.ptr[i] != '\n') i++; // Skip first line, -----
for (; i < pem.len && pem.ptr[i] != '-'; i++) { // Until the end -----
if (pem.ptr[i] == '\r' || pem.ptr[i] == '\n') continue;
in[j++] = pem.ptr[i];
if (j >= sizeof(in)) {
int n = mg_base64_decode(in, 4, out);
mg_iobuf_add(io, io->len, out, n);
j = 0;
}
}
mg_hexdump(io->buf, io->len);
} else {
mg_iobuf_add(io, io->len, pem.ptr, pem.len); // Already in DER
}
}
}
void mg_tls_ctx_init(struct mg_mgr *mgr, const struct mg_tls_opts *opts) {
(void) opts, (void) mgr;
struct tls_ctx *ctx = (struct tls_ctx *) calloc(1, sizeof(*ctx));
ctx->server_key.align = ctx->server_cert.align = 128;
pem_load(opts->server_cert, &ctx->server_cert);
pem_load(opts->server_key, &ctx->server_key);
mgr->tls_ctx = ctx;
}
#endif

Expand Down
1 change: 1 addition & 0 deletions mongoose.h
Original file line number Diff line number Diff line change
Expand Up @@ -1195,6 +1195,7 @@ struct mg_connection {
unsigned long id; // Auto-incrementing unique connection ID
struct mg_iobuf recv; // Incoming data
struct mg_iobuf send; // Outgoing data
struct mg_iobuf raw; // TLS only. Incoming encrypted data
mg_event_handler_t fn; // User-specified event handler function
void *fn_data; // User-specified function parameter
mg_event_handler_t pfn; // Protocol-specific handler function
Expand Down
13 changes: 11 additions & 2 deletions src/net.c
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ struct mg_connection *mg_alloc_conn(struct mg_mgr *mgr) {
(struct mg_connection *) calloc(1, sizeof(*c) + mgr->extraconnsize);
if (c != NULL) {
c->mgr = mgr;
c->send.align = c->recv.align = MG_IO_SIZE;
c->send.align = c->recv.align = c->raw.align = MG_IO_SIZE;
c->id = ++mgr->nextid;
}
return c;
Expand All @@ -143,6 +143,7 @@ void mg_close_conn(struct mg_connection *c) {
mg_tls_free(c);
mg_iobuf_free(&c->recv);
mg_iobuf_free(&c->send);
mg_iobuf_free(&c->raw);
memset(c, 0, sizeof(*c));
free(c);
}
Expand Down Expand Up @@ -188,7 +189,7 @@ struct mg_connection *mg_listen(struct mg_mgr *mgr, const char *url,
c->fn = fn;
c->fn_data = fn_data;
mg_call(c, MG_EV_OPEN, NULL);
if (mg_url_is_ssl(url)) c->is_tls = 1; // Accepted connection must
if (mg_url_is_ssl(url)) c->is_tls = 1; // Accepted connection must
MG_DEBUG(("%lu %p %s", c->id, c->fd, url));
}
return c;
Expand Down Expand Up @@ -218,6 +219,14 @@ struct mg_timer *mg_timer_add(struct mg_mgr *mgr, uint64_t milliseconds,
return t;
}

long mg_io_recv(struct mg_connection *c, void *buf, size_t len) {
if (c->raw.len == 0) return MG_IO_WAIT;
if (len > c->raw.len) len = c->raw.len;
memcpy(buf, c->raw.buf, len);
mg_iobuf_del(&c->raw, 0, len);
return (long) len;
}

void mg_mgr_free(struct mg_mgr *mgr) {
struct mg_connection *c;
struct mg_timer *tmp, *t = mgr->timers;
Expand Down
1 change: 1 addition & 0 deletions src/net.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ struct mg_connection {
unsigned long id; // Auto-incrementing unique connection ID
struct mg_iobuf recv; // Incoming data
struct mg_iobuf send; // Outgoing data
struct mg_iobuf raw; // TLS only. Incoming encrypted data
mg_event_handler_t fn; // User-specified event handler function
void *fn_data; // User-specified function parameter
mg_event_handler_t pfn; // Protocol-specific handler function
Expand Down
16 changes: 3 additions & 13 deletions src/net_builtin.c
Original file line number Diff line number Diff line change
Expand Up @@ -572,20 +572,10 @@ long mg_io_send(struct mg_connection *c, const void *buf, size_t len) {
return (long) len;
}

long mg_io_recv(struct mg_connection *c, void *buf, size_t len) {
struct connstate *s = (struct connstate *) (c + 1);
if (s->raw.len == 0) return MG_IO_WAIT;
if (len > s->raw.len) len = s->raw.len;
memcpy(buf, s->raw.buf, len);
mg_iobuf_del(&s->raw, 0, len);
return (long) len;
}

static void read_conn(struct mg_connection *c, struct pkt *pkt) {
struct connstate *s = (struct connstate *) (c + 1);
struct mg_iobuf *io = c->is_tls ? &s->raw : &c->recv;
struct mg_iobuf *io = c->is_tls ? &c->raw : &c->recv;
uint32_t seq = mg_ntohl(pkt->tcp->seq);
s->raw.align = c->recv.align;
uint32_t rem_ip;
memcpy(&rem_ip, c->rem.ip, sizeof(uint32_t));
if (pkt->tcp->flags & TH_FIN) {
Expand Down Expand Up @@ -625,9 +615,9 @@ static void read_conn(struct mg_connection *c, struct pkt *pkt) {
} else {
// Copy TCP payload into the IO buffer. If the connection is plain text,
// we copy to c->recv. If the connection is TLS, this data is encrypted,
// therefore we copy that encrypted data to the s->raw iobuffer instead,
// therefore we copy that encrypted data to the c->raw iobuffer instead,
// and then call mg_tls_recv() to decrypt it. NOTE: mg_tls_recv() will
// call back mg_io_recv() which grabs raw data from s->raw
// call back mg_io_recv() which grabs raw data from c->raw
memcpy(&io->buf[io->len], pkt->pay.ptr, pkt->pay.len);
io->len += pkt->pay.len;

Expand Down
Loading

0 comments on commit 13b4b8a

Please sign in to comment.