Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TLS layer: c->rtls to optimise recvd TLS data #2523

Merged
merged 1 commit into from
Dec 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
180 changes: 121 additions & 59 deletions mongoose.c
Original file line number Diff line number Diff line change
Expand Up @@ -4292,7 +4292,8 @@ static bool mg_aton6(struct mg_str str, struct mg_addr *addr) {
} else if (str.ptr[i] == '%') { // Scope ID
for (i = i + 1; i < str.len; i++) {
if (str.ptr[i] < '0' || str.ptr[i] > '9') return false;
addr->scope_id *= 10, addr->scope_id += (uint8_t) (str.ptr[i] - '0');
addr->scope_id = (uint8_t) (addr->scope_id * 10);
addr->scope_id = (uint8_t) (addr->scope_id + (str.ptr[i] - '0'));
}
} else {
return false;
Expand All @@ -4319,7 +4320,7 @@ struct mg_connection *mg_alloc_conn(struct mg_mgr *mgr) {
(struct mg_connection *) calloc(1, sizeof(*c) + mgr->extraconnsize);
if (c != NULL) {
c->mgr = mgr;
c->send.align = c->recv.align = MG_IO_SIZE;
c->send.align = c->recv.align = c->rtls.align = MG_IO_SIZE;
c->id = ++mgr->nextid;
MG_PROF_INIT(c);
}
Expand All @@ -4341,6 +4342,7 @@ void mg_close_conn(struct mg_connection *c) {
mg_tls_free(c);
mg_iobuf_free(&c->recv);
mg_iobuf_free(&c->send);
mg_iobuf_free(&c->rtls);
mg_bzero((unsigned char *) c, sizeof(*c));
free(c);
}
Expand Down Expand Up @@ -4413,6 +4415,14 @@ struct mg_timer *mg_timer_add(struct mg_mgr *mgr, uint64_t milliseconds,
return t;
}

long mg_io_recv(struct mg_connection *c, void *buf, size_t len) {
if (c->rtls.len == 0) return MG_IO_WAIT;
if (len > c->rtls.len) len = c->rtls.len;
memcpy(buf, c->rtls.buf, len);
mg_iobuf_del(&c->rtls, 0, len);
return (long) len;
}

void mg_mgr_free(struct mg_mgr *mgr) {
struct mg_connection *c;
struct mg_timer *tmp, *t = mgr->timers;
Expand Down Expand Up @@ -5054,20 +5064,10 @@ long mg_io_send(struct mg_connection *c, const void *buf, size_t len) {
return (long) len;
}

long mg_io_recv(struct mg_connection *c, void *buf, size_t len) {
struct connstate *s = (struct connstate *) (c + 1);
if (s->raw.len == 0) return MG_IO_WAIT;
if (len > s->raw.len) len = s->raw.len;
memcpy(buf, s->raw.buf, len);
mg_iobuf_del(&s->raw, 0, len);
return (long) len;
}

static void read_conn(struct mg_connection *c, struct pkt *pkt) {
struct connstate *s = (struct connstate *) (c + 1);
struct mg_iobuf *io = c->is_tls ? &s->raw : &c->recv;
struct mg_iobuf *io = c->is_tls ? &c->rtls : &c->recv;
uint32_t seq = mg_ntohl(pkt->tcp->seq);
s->raw.align = c->recv.align;
uint32_t rem_ip;
memcpy(&rem_ip, c->rem.ip, sizeof(uint32_t));
if (pkt->tcp->flags & TH_FIN) {
Expand Down Expand Up @@ -5107,9 +5107,9 @@ static void read_conn(struct mg_connection *c, struct pkt *pkt) {
} else {
// Copy TCP payload into the IO buffer. If the connection is plain text,
// we copy to c->recv. If the connection is TLS, this data is encrypted,
// therefore we copy that encrypted data to the s->raw iobuffer instead,
// therefore we copy that encrypted data to the c->rtls iobuffer instead,
// and then call mg_tls_recv() to decrypt it. NOTE: mg_tls_recv() will
// call back mg_io_recv() which grabs raw data from s->raw
// call back mg_io_recv() which grabs raw data from c->rtls
memcpy(&io->buf[io->len], pkt->pay.ptr, pkt->pay.len);
io->len += pkt->pay.len;

Expand Down Expand Up @@ -6735,7 +6735,7 @@ bool mg_open_listener(struct mg_connection *c, const char *url) {
return success;
}

long mg_io_recv(struct mg_connection *c, void *buf, size_t len) {
static long recv_raw(struct mg_connection *c, void *buf, size_t len) {
long n = 0;
if (c->is_udp) {
union usa usa;
Expand All @@ -6751,20 +6751,43 @@ long mg_io_recv(struct mg_connection *c, void *buf, size_t len) {
return n;
}

static bool ioalloc(struct mg_connection *c, struct mg_iobuf *io) {
bool res = false;
if (io->len >= MG_MAX_RECV_SIZE) {
mg_error(c, "MG_MAX_RECV_SIZE");
} else if (io->size <= io->len &&
!mg_iobuf_resize(io, io->size + MG_IO_SIZE)) {
mg_error(c, "OOM");
} else {
res = true;
}
return res;
}

// NOTE(lsm): do only one iteration of reads, cause some systems
// (e.g. FreeRTOS stack) return 0 instead of -1/EWOULDBLOCK when no data
static void read_conn(struct mg_connection *c) {
long n = -1;
if (c->recv.len >= MG_MAX_RECV_SIZE) {
mg_error(c, "max_recv_buf_size reached");
} else if (c->recv.size <= c->recv.len &&
!mg_iobuf_resize(&c->recv, c->recv.size + MG_IO_SIZE)) {
mg_error(c, "oom");
} else {
if (ioalloc(c, &c->recv)) {
char *buf = (char *) &c->recv.buf[c->recv.len];
size_t len = c->recv.size - c->recv.len;
n = c->is_tls ? mg_tls_recv(c, buf, len) : mg_io_recv(c, buf, len);
MG_DEBUG(("%lu %ld snd %ld/%ld rcv %ld/%ld n=%ld err=%d", c->id, c->fd,
long n = -1;
if (c->is_tls) {
if (!ioalloc(c, &c->rtls)) return;
n = recv_raw(c, (char *) &c->rtls.buf[c->rtls.len],
c->rtls.size - c->rtls.len);
// MG_DEBUG(("%lu %ld", c->id, n));
if (n == MG_IO_ERR) {
c->is_closing = 1;
} else if (n > 0) {
c->rtls.len += (size_t) n;
if (c->is_tls_hs) mg_tls_handshake(c);
if (c->is_tls_hs) return;
n = mg_tls_recv(c, buf, len);
}
} else {
n = recv_raw(c, buf, len);
}
MG_DEBUG(("%lu %p snd %ld/%ld rcv %ld/%ld n=%ld err=%d", c->id, c->fd,
(long) c->send.len, (long) c->send.size, (long) c->recv.len,
(long) c->recv.size, n, MG_SOCK_ERR(n)));
iolog(c, buf, n, true);
Expand Down Expand Up @@ -7086,8 +7109,8 @@ void mg_mgr_poll(struct mg_mgr *mgr, int ms) {
if (c->is_readable) accept_conn(mgr, c);
} else if (c->is_connecting) {
if (c->is_readable || c->is_writable) connect_conn(c);
} else if (c->is_tls_hs) {
if ((c->is_readable || c->is_writable)) mg_tls_handshake(c);
//} else if (c->is_tls_hs) {
// if ((c->is_readable || c->is_writable)) mg_tls_handshake(c);
} else {
if (c->is_readable) read_conn(c);
if (c->is_writable) write_conn(c);
Expand Down Expand Up @@ -7454,8 +7477,12 @@ void mg_timer_poll(struct mg_timer **head, uint64_t now_ms) {

#if MG_TLS == MG_TLS_BUILTIN
struct tls_data {
struct mg_iobuf send;
struct mg_iobuf recv;
uint8_t client_random[32]; // From client hello
uint8_t client_pub[32]; // From client hello
};
struct tls_ctx {
struct mg_iobuf server_cert; // Decoded server certificate
struct mg_iobuf server_key; // Decoded server private key
};

#define MG_LOAD_BE16(p) ((uint16_t) ((MG_U8P(p)[0] << 8U) | MG_U8P(p)[1]))
Expand All @@ -7468,19 +7495,21 @@ static inline bool mg_is_big_endian(void) {
static inline uint16_t mg_swap16(uint16_t v) {
return (uint16_t) ((v << 8U) | (v >> 8U));
}
static inline uint16_t mg_be16(uint16_t v) {
return mg_is_big_endian() ? mg_swap16(v) : v;
}
#if 0
static inline uint32_t mg_swap32(uint32_t v) {
return (v >> 24) | (v >> 8 & 0xff00) | (v << 8 & 0xff0000) | (v << 24);
}
static inline uint64_t mg_swap64(uint64_t v) {
return (((uint64_t) mg_swap32((uint32_t) v)) << 32) |
mg_swap32((uint32_t) (v >> 32));
}
static inline uint16_t mg_be16(uint16_t v) {
return mg_is_big_endian() ? mg_swap16(v) : v;
}
static inline uint32_t mg_be32(uint32_t v) {
return mg_is_big_endian() ? mg_swap32(v) : v;
}
#endif

static inline void add8(struct mg_iobuf *io, uint8_t data) {
mg_iobuf_add(io, io->len, &data, sizeof(data));
Expand All @@ -7497,7 +7526,7 @@ static inline void add32(struct mg_iobuf *io, uint32_t data) {
void mg_tls_init(struct mg_connection *c, const struct mg_tls_opts *opts) {
struct tls_data *tls = (struct tls_data *) calloc(1, sizeof(struct tls_data));
if (tls != NULL) {
tls->send.align = tls->recv.align = MG_IO_SIZE;
// tls->send.align = tls->recv.align = MG_IO_SIZE;
c->tls = tls;
c->is_tls = c->is_tls_hs = 1;
} else {
Expand All @@ -7508,8 +7537,8 @@ void mg_tls_init(struct mg_connection *c, const struct mg_tls_opts *opts) {
void mg_tls_free(struct mg_connection *c) {
struct tls_data *tls = c->tls;
if (tls != NULL) {
mg_iobuf_free(&tls->send);
mg_iobuf_free(&tls->recv);
// mg_iobuf_free(&tls->send);
// mg_iobuf_free(&tls->recv);
}
free(c->tls);
c->tls = NULL;
Expand Down Expand Up @@ -7538,36 +7567,24 @@ size_t mg_tls_pending(struct mg_connection *c) {
return 0;
}
void mg_tls_handshake(struct mg_connection *c) {
struct tls_data *tls = c->tls;
struct mg_iobuf *rio = &tls->recv;
struct mg_iobuf *wio = &tls->send;
// Pull data from TCP
for (;;) {
mg_iobuf_resize(rio, rio->len + 1);
long n = mg_io_recv(c, &rio->buf[rio->len], rio->size - rio->len);
if (n > 0) {
rio->len += (size_t) n;
} else if (n == MG_IO_WAIT) {
break;
} else {
mg_error(c, "IO err");
return;
}
}
// struct tls_data *tls = c->tls;
struct mg_iobuf *rio = &c->raw;
struct mg_iobuf *wio = &c->send;

// Look if we've pulled everything
if (rio->len < TLS_HDR_SIZE) return;
uint8_t record_type = rio->buf[0];
uint16_t record_len = MG_LOAD_BE16(rio->buf + 3);
uint16_t record_version = MG_LOAD_BE16(rio->buf + 1);
if (record_type != 22) {
mg_error(c, "no 22");
mg_error(c, "not a handshake");
return;
}
if (rio->len < (size_t) TLS_HDR_SIZE + record_len) return;
// Got full hello
// struct tls_hello *hello = (struct tls_hello *) (hdr + 1);
MG_INFO(("CT=%d V=%hx L=%hu", record_type, record_version, record_len));
mg_hexdump(rio->buf, rio->len);
// mg_hexdump(rio->buf, rio->len);

// Send response. Server Hello
size_t ofs = wio->len;
Expand All @@ -7580,6 +7597,8 @@ void mg_tls_handshake(struct mg_connection *c) {
add8(wio, 0); // Compression method: 0
add16(wio, 46); // Extensions length
add16(wio, 43), add16(wio, 2), add16(wio, 0x304); // extension: TLS 1.3

// Key share: use curve x25519 (id 29)
add16(wio, 51), add16(wio, 36), add16(wio, 29), add16(wio, 32); // keyshare
mg_iobuf_add(wio, wio->len, NULL, 32); // 32 random
mg_random(wio->buf + wio->len - 32, 32); // bytes
Expand All @@ -7599,7 +7618,7 @@ void mg_tls_handshake(struct mg_connection *c) {
add8(wio, 0), add16(wio, 2), add16(wio, 0); // empty 2 bytes
add8(wio, 11); // certificate message
add8(wio, 0), add16(wio, 4), add32(wio, 0x1020304); // len
*(uint16_t *) &wio->buf[ofs + 3] = mg_be16((uint16_t)(wio->len - ofs - 5));
*(uint16_t *) &wio->buf[ofs + 3] = mg_be16((uint16_t) (wio->len - ofs - 5));

mg_io_send(c, wio->buf, wio->len);
wio->len = 0;
Expand All @@ -7609,6 +7628,7 @@ void mg_tls_handshake(struct mg_connection *c) {
mg_error(c, "doh");
}
void mg_tls_ctx_free(struct mg_mgr *mgr) {
free(mgr->tls_ctx);
mgr->tls_ctx = NULL;
}
void mg_tls_ctx_init(struct mg_mgr *mgr) {
Expand Down Expand Up @@ -7759,6 +7779,7 @@ void mg_tls_init(struct mg_connection *c, const struct mg_tls_opts *opts) {
}
if (c->is_listening) goto fail;
MG_DEBUG(("%lu Setting TLS", c->id));
MG_PROF_ADD(c, "mbedtls_init_start");
mbedtls_ssl_init(&tls->ssl);
mbedtls_ssl_config_init(&tls->conf);
mbedtls_x509_crt_init(&tls->ca);
Expand Down Expand Up @@ -7811,6 +7832,7 @@ void mg_tls_init(struct mg_connection *c, const struct mg_tls_opts *opts) {
c->is_tls = 1;
c->is_tls_hs = 1;
mbedtls_ssl_set_bio(&tls->ssl, c, mg_net_send, mg_net_recv, 0);
MG_PROF_ADD(c, "mbedtls_init_end");
if (c->is_client && c->is_resolving == 0 && c->is_connecting == 0) {
mg_tls_handshake(c);
}
Expand Down Expand Up @@ -7882,7 +7904,7 @@ void mg_tls_ctx_free(struct mg_mgr *mgr) {
static int tls_err_cb(const char *s, size_t len, void *c) {
int n = (int) len - 1;
MG_ERROR(("%lu %.*s", ((struct mg_connection *) c)->id, n, s));
return 0; // undocumented
return 0; // undocumented
}

static int mg_tls_err(struct mg_connection *c, struct mg_tls *tls, int res) {
Expand Down Expand Up @@ -7938,10 +7960,41 @@ static X509 *load_cert(struct mg_str s) {
return cert;
}


static long mg_bio_ctrl(BIO *b, int cmd, long larg, void *pargs) {
long ret = 0;
if (cmd == BIO_CTRL_PUSH) ret = 1;
if (cmd == BIO_CTRL_POP) ret = 1;
if (cmd == BIO_CTRL_FLUSH) ret = 1;
if (cmd == BIO_C_SET_NBIO) ret = 1;
// MG_DEBUG(("%d -> %ld", cmd, ret));
(void) b, (void) cmd, (void) larg, (void) pargs;
return ret;
}

static int mg_bio_read(BIO *bio, char *buf, int len) {
struct mg_connection *c = (struct mg_connection *) BIO_get_data(bio);
long res = mg_io_recv(c, buf, (size_t) len);
// MG_DEBUG(("%p %d %ld", buf, len, res));
len = res > 0 ? (int) res : -1;
if (res == MG_IO_WAIT) BIO_set_retry_read(bio);
return len;
}

static int mg_bio_write(BIO *bio, const char *buf, int len) {
struct mg_connection *c = (struct mg_connection *) BIO_get_data(bio);
long res = mg_io_send(c, buf, (size_t) len);
// MG_DEBUG(("%p %d %ld", buf, len, res));
len = res > 0 ? (int) res : -1;
if (res == MG_IO_WAIT) BIO_set_retry_write(bio);
return len;
}

void mg_tls_init(struct mg_connection *c, const struct mg_tls_opts *opts) {
struct mg_tls *tls = (struct mg_tls *) calloc(1, sizeof(*tls));
const char *id = "mongoose";
static unsigned char s_initialised = 0;
BIO *bio = NULL;
int rc;

if (tls == NULL) {
Expand Down Expand Up @@ -8006,7 +8059,7 @@ void mg_tls_init(struct mg_connection *c, const struct mg_tls_opts *opts) {

SSL_set_mode(tls->ssl, SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER);
#if OPENSSL_VERSION_NUMBER > 0x10002000L
SSL_set_ecdh_auto(tls->ssl, 1);
(void) SSL_set_ecdh_auto(tls->ssl, 1);
#endif
#if OPENSSL_VERSION_NUMBER >= 0x10100000L
if (opts->name.len > 0) {
Expand All @@ -8016,6 +8069,16 @@ void mg_tls_init(struct mg_connection *c, const struct mg_tls_opts *opts) {
free(s);
}
#endif

tls->bm = BIO_meth_new(BIO_get_new_index() | BIO_TYPE_SOURCE_SINK, "bio_mg");
BIO_meth_set_write(tls->bm, mg_bio_write);
BIO_meth_set_read(tls->bm, mg_bio_read);
BIO_meth_set_ctrl(tls->bm, mg_bio_ctrl);

bio = BIO_new(tls->bm);
BIO_set_data(bio, c);
SSL_set_bio(tls->ssl, bio, bio);

c->tls = tls;
c->is_tls = 1;
c->is_tls_hs = 1;
Expand All @@ -8030,9 +8093,7 @@ void mg_tls_init(struct mg_connection *c, const struct mg_tls_opts *opts) {

void mg_tls_handshake(struct mg_connection *c) {
struct mg_tls *tls = (struct mg_tls *) c->tls;
int rc;
SSL_set_fd(tls->ssl, (int) (size_t) c->fd);
rc = c->is_client ? SSL_connect(tls->ssl) : SSL_accept(tls->ssl);
int rc = c->is_client ? SSL_connect(tls->ssl) : SSL_accept(tls->ssl);
if (rc == 1) {
MG_DEBUG(("%lu success", c->id));
c->is_tls_hs = 0;
Expand All @@ -8048,6 +8109,7 @@ void mg_tls_free(struct mg_connection *c) {
if (tls == NULL) return;
SSL_free(tls->ssl);
SSL_CTX_free(tls->ctx);
BIO_meth_free(tls->bm);
free(tls);
c->tls = NULL;
}
Expand Down
Loading