Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mbedtls] Support mbedtls 3.x and fix bugs and compiler warnings. #11646

Merged
merged 7 commits into from
Apr 26, 2024
4 changes: 2 additions & 2 deletions libs/mbedtls/mbedtls.ml
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ external mbedtls_ssl_setup : mbedtls_ssl_context -> mbedtls_ssl_config -> mbedtl
external mbedtls_ssl_write : mbedtls_ssl_context -> bytes -> int -> int -> mbedtls_result = "ml_mbedtls_ssl_write"

external mbedtls_pk_init : unit -> mbedtls_pk_context = "ml_mbedtls_pk_init"
external mbedtls_pk_parse_key : mbedtls_pk_context -> bytes -> string option -> mbedtls_result = "ml_mbedtls_pk_parse_key"
external mbedtls_pk_parse_keyfile : mbedtls_pk_context -> string -> string option -> mbedtls_result = "ml_mbedtls_pk_parse_keyfile"
external mbedtls_pk_parse_key : mbedtls_pk_context -> bytes -> string option -> mbedtls_ctr_drbg_context -> mbedtls_result = "ml_mbedtls_pk_parse_key"
external mbedtls_pk_parse_keyfile : mbedtls_pk_context -> string -> string option -> mbedtls_ctr_drbg_context -> mbedtls_result = "ml_mbedtls_pk_parse_keyfile"
external mbedtls_pk_parse_public_keyfile : mbedtls_pk_context -> string -> mbedtls_result = "ml_mbedtls_pk_parse_public_keyfile"
external mbedtls_pk_parse_public_key : mbedtls_pk_context -> bytes -> mbedtls_result = "ml_mbedtls_pk_parse_public_key"

Expand Down
103 changes: 51 additions & 52 deletions libs/mbedtls/mbedtls_stubs.c
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
#include <ctype.h>
#include <string.h>
#include <stdio.h>

Expand All @@ -18,13 +17,10 @@
#include <caml/callback.h>
#include <caml/custom.h>

#include "mbedtls/debug.h"
#include "mbedtls/error.h"
#include "mbedtls/config.h"
#include "mbedtls/ssl.h"
#include "mbedtls/entropy.h"
#include "mbedtls/ctr_drbg.h"
#include "mbedtls/certs.h"
#include "mbedtls/oid.h"

#define PVoid_val(v) (*((void**) Data_custom_val(v)))
Expand Down Expand Up @@ -84,7 +80,7 @@ CAMLprim value ml_mbedtls_ctr_drbg_init(void) {

CAMLprim value ml_mbedtls_ctr_drbg_random(value p_rng, value output, value output_len) {
CAMLparam3(p_rng, output, output_len);
CAMLreturn(Val_int(mbedtls_ctr_drbg_random(CtrDrbg_val(p_rng), String_val(output), Int_val(output_len))));
CAMLreturn(Val_int(mbedtls_ctr_drbg_random(CtrDrbg_val(p_rng), Bytes_val(output), Int_val(output_len))));
}

CAMLprim value ml_mbedtls_ctr_drbg_seed(value ctx, value p_entropy, value custom) {
Expand Down Expand Up @@ -124,7 +120,7 @@ CAMLprim value ml_mbedtls_entropy_init(void) {

CAMLprim value ml_mbedtls_entropy_func(value data, value output, value len) {
CAMLparam3(data, output, len);
CAMLreturn(Val_int(mbedtls_entropy_func(PVoid_val(data), String_val(output), Int_val(len))));
CAMLreturn(Val_int(mbedtls_entropy_func(PVoid_val(data), Bytes_val(output), Int_val(len))));
}

// Certificate
Expand Down Expand Up @@ -171,7 +167,7 @@ CAMLprim value ml_mbedtls_x509_next(value chain) {

CAMLprim value ml_mbedtls_x509_crt_parse(value chain, value bytes) {
CAMLparam2(chain, bytes);
const char* buf = String_val(bytes);
const unsigned char* buf = Bytes_val(bytes);
int len = caml_string_length(bytes);
CAMLreturn(Val_int(mbedtls_x509_crt_parse(X509Crt_val(chain), buf, len + 1)));
}
Expand All @@ -191,16 +187,19 @@ CAMLprim value ml_mbedtls_x509_crt_parse_path(value chain, value path) {
value caml_string_of_asn1_buf(mbedtls_asn1_buf* dat) {
CAMLparam0();
CAMLlocal1(s);
s = caml_alloc_string(dat->len);
memcpy(String_val(s), dat->p, dat->len);
s = caml_alloc_initialized_string(dat->len, (const char *)dat->p);
CAMLreturn(s);
}

CAMLprim value hx_cert_get_alt_names(value chain) {
CAMLparam1(chain);
CAMLlocal1(obj);
mbedtls_x509_crt* cert = X509Crt_val(chain);
if (cert->ext_types & MBEDTLS_X509_EXT_SUBJECT_ALT_NAME == 0 || &cert->subject_alt_names == NULL) {
#if MBEDTLS_VERSION_MAJOR >= 3
if (!mbedtls_x509_crt_has_ext_type(cert, MBEDTLS_X509_EXT_SUBJECT_ALT_NAME)) {
#else
if ((cert->ext_types & MBEDTLS_X509_EXT_SUBJECT_ALT_NAME) == 0) {
#endif
obj = Atom(0);
} else {
mbedtls_asn1_sequence* cur = &cert->subject_alt_names;
Expand Down Expand Up @@ -366,29 +365,39 @@ CAMLprim value ml_mbedtls_pk_init(void) {
CAMLreturn(obj);
}

CAMLprim value ml_mbedtls_pk_parse_key(value ctx, value key, value password) {
CAMLparam3(ctx, key, password);
const char* pwd = NULL;
CAMLprim value ml_mbedtls_pk_parse_key(value ctx, value key, value password, value rng) {
CAMLparam4(ctx, key, password, rng);
const unsigned char* pwd = NULL;
size_t pwdlen = 0;
if (password != Val_none) {
pwd = String_val(Field(password, 0));
pwd = Bytes_val(Field(password, 0));
pwdlen = caml_string_length(Field(password, 0));
}
CAMLreturn(mbedtls_pk_parse_key(PkContext_val(ctx), String_val(key), caml_string_length(key) + 1, pwd, pwdlen));
#if MBEDTLS_VERSION_MAJOR >= 3
mbedtls_ctr_drbg_context *ctr_drbg = CtrDrbg_val(rng);
CAMLreturn(mbedtls_pk_parse_key(PkContext_val(ctx), Bytes_val(key), caml_string_length(key) + 1, pwd, pwdlen, mbedtls_ctr_drbg_random, NULL));
#else
CAMLreturn(mbedtls_pk_parse_key(PkContext_val(ctx), Bytes_val(key), caml_string_length(key) + 1, pwd, pwdlen));
#endif
}

CAMLprim value ml_mbedtls_pk_parse_keyfile(value ctx, value path, value password) {
CAMLparam3(ctx, path, password);
CAMLprim value ml_mbedtls_pk_parse_keyfile(value ctx, value path, value password, value rng) {
CAMLparam4(ctx, path, password, rng);
const char* pwd = NULL;
if (password != Val_none) {
pwd = String_val(Field(password, 0));
}
#if MBEDTLS_VERSION_MAJOR >= 3
mbedtls_ctr_drbg_context *ctr_drbg = CtrDrbg_val(rng);
CAMLreturn(mbedtls_pk_parse_keyfile(PkContext_val(ctx), String_val(path), pwd, mbedtls_ctr_drbg_random, ctr_drbg));
#else
CAMLreturn(mbedtls_pk_parse_keyfile(PkContext_val(ctx), String_val(path), pwd));
#endif
}

CAMLprim value ml_mbedtls_pk_parse_public_key(value ctx, value key) {
CAMLparam2(ctx, key);
CAMLreturn(mbedtls_pk_parse_public_key(PkContext_val(ctx), String_val(key), caml_string_length(key) + 1));
CAMLreturn(mbedtls_pk_parse_public_key(PkContext_val(ctx), Bytes_val(key), caml_string_length(key) + 1));
}

CAMLprim value ml_mbedtls_pk_parse_public_keyfile(value ctx, value path) {
Expand Down Expand Up @@ -446,23 +455,22 @@ CAMLprim value ml_mbedtls_ssl_handshake(value ssl) {

CAMLprim value ml_mbedtls_ssl_read(value ssl, value buf, value pos, value len) {
CAMLparam4(ssl, buf, pos, len);
CAMLreturn(Val_int(mbedtls_ssl_read(SslContext_val(ssl), String_val(buf) + Int_val(pos), Int_val(len))));
CAMLreturn(Val_int(mbedtls_ssl_read(SslContext_val(ssl), Bytes_val(buf) + Int_val(pos), Int_val(len))));
}

static int bio_write_cb(void* ctx, const unsigned char* buf, size_t len) {
CAMLparam0();
CAMLlocal3(r, s, vctx);
vctx = (value)ctx;
s = caml_alloc_string(len);
memcpy(String_val(s), buf, len);
vctx = *(value*)ctx;
s = caml_alloc_initialized_string(len, (const char*)buf);
r = caml_callback2(Field(vctx, 1), Field(vctx, 0), s);
CAMLreturn(Int_val(r));
}

static int bio_read_cb(void* ctx, unsigned char* buf, size_t len) {
CAMLparam0();
CAMLlocal3(r, s, vctx);
vctx = (value)ctx;
vctx = *(value*)ctx;
s = caml_alloc_string(len);
r = caml_callback2(Field(vctx, 2), Field(vctx, 0), s);
memcpy(buf, String_val(s), len);
Expand All @@ -476,7 +484,11 @@ CAMLprim value ml_mbedtls_ssl_set_bio(value ssl, value p_bio, value f_send, valu
Store_field(ctx, 0, p_bio);
Store_field(ctx, 1, f_send);
Store_field(ctx, 2, f_recv);
mbedtls_ssl_set_bio(SslContext_val(ssl), (void*)ctx, bio_write_cb, bio_read_cb, NULL);
// TODO: this allocation is leaked
value *location = malloc(sizeof(value));
*location = ctx;
caml_register_generational_global_root(location);
mbedtls_ssl_set_bio(SslContext_val(ssl), (void*)location, bio_write_cb, bio_read_cb, NULL);
CAMLreturn(Val_unit);
}

Expand All @@ -492,7 +504,7 @@ CAMLprim value ml_mbedtls_ssl_setup(value ssl, value conf) {

CAMLprim value ml_mbedtls_ssl_write(value ssl, value buf, value pos, value len) {
CAMLparam4(ssl, buf, pos, len);
CAMLreturn(Val_int(mbedtls_ssl_write(SslContext_val(ssl), String_val(buf) + Int_val(pos), Int_val(len))));
CAMLreturn(Val_int(mbedtls_ssl_write(SslContext_val(ssl), Bytes_val(buf) + Int_val(pos), Int_val(len))));
}

// glue
Expand Down Expand Up @@ -520,36 +532,23 @@ CAMLprim value hx_cert_load_defaults(value certificate) {
#endif

#ifdef __APPLE__
CFMutableDictionaryRef search;
CFArrayRef result;
SecKeychainRef keychain;
SecCertificateRef item;
CFDataRef dat;
// Load keychain
if (SecKeychainOpen("/System/Library/Keychains/SystemRootCertificates.keychain", &keychain) == errSecSuccess) {
// Search for certificates
search = CFDictionaryCreateMutable(NULL, 0, NULL, NULL);
CFDictionarySetValue(search, kSecClass, kSecClassCertificate);
CFDictionarySetValue(search, kSecMatchLimit, kSecMatchLimitAll);
CFDictionarySetValue(search, kSecReturnRef, kCFBooleanTrue);
CFDictionarySetValue(search, kSecMatchSearchList, CFArrayCreate(NULL, (const void **)&keychain, 1, NULL));
if (SecItemCopyMatching(search, (CFTypeRef *)&result) == errSecSuccess) {
CFIndex n = CFArrayGetCount(result);
for (CFIndex i = 0; i < n; i++) {
item = (SecCertificateRef)CFArrayGetValueAtIndex(result, i);

// Get certificate in DER format
dat = SecCertificateCopyData(item);
if (dat) {
r = mbedtls_x509_crt_parse_der(chain, (unsigned char *)CFDataGetBytePtr(dat), CFDataGetLength(dat));
CFRelease(dat);
if (r != 0) {
CAMLreturn(Val_int(r));
}
CFArrayRef certs;
if (SecTrustCopyAnchorCertificates(&certs) == errSecSuccess) {
CFIndex count = CFArrayGetCount(certs);
for(CFIndex i = 0; i < count; i++) {
SecCertificateRef item = (SecCertificateRef)CFArrayGetValueAtIndex(certs, i);

// Get certificate in DER format
CFDataRef data = SecCertificateCopyData(item);
if(data) {
r = mbedtls_x509_crt_parse_der(chain, (unsigned char *)CFDataGetBytePtr(data), CFDataGetLength(data));
CFRelease(data);
if (r != 0) {
CAMLreturn(Val_int(r));
}
}
}
CFRelease(keychain);
CFRelease(certs);
}
#endif

Expand Down
8 changes: 4 additions & 4 deletions src/macro/eval/evalSsl.ml
Original file line number Diff line number Diff line change
Expand Up @@ -160,11 +160,11 @@ let init_fields init_fields builtins =
"strerror",vfun1 (fun code -> encode_string (mbedtls_strerror (decode_int code)));
] [];
init_fields builtins (["mbedtls"],"PkContext") [] [
"parse_key",vifun2 (fun this key password ->
vint (mbedtls_pk_parse_key (as_pk_context this) (decode_bytes key) (match password with VNull -> None | _ -> Some (decode_string password)));
"parse_key",vifun3 (fun this key password rng ->
vint (mbedtls_pk_parse_key (as_pk_context this) (decode_bytes key) (match password with VNull -> None | _ -> Some (decode_string password)) (as_ctr_drbg rng));
);
"parse_keyfile",vifun2 (fun this path password ->
vint (mbedtls_pk_parse_keyfile (as_pk_context this) (decode_string path) (match password with VNull -> None | _ -> Some (decode_string password)));
"parse_keyfile",vifun3 (fun this path password rng ->
vint (mbedtls_pk_parse_keyfile (as_pk_context this) (decode_string path) (match password with VNull -> None | _ -> Some (decode_string password)) (as_ctr_drbg rng));
);
"parse_public_key",vifun1 (fun this key ->
vint (mbedtls_pk_parse_public_key (as_pk_context this) (decode_bytes key));
Expand Down
4 changes: 2 additions & 2 deletions std/eval/_std/mbedtls/PkContext.hx
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ import haxe.io.Bytes;
extern class PkContext {
function new():Void;

function parse_key(key:Bytes, ?pwd:String):Int;
function parse_keyfile(path:String, ?password:String):Int;
function parse_key(key:Bytes, ?pwd:String, ctr_dbg: CtrDrbg):Int;
function parse_keyfile(path:String, ?password:String, ctr_dbg: CtrDrbg):Int;
function parse_public_key(key:Bytes):Int;
function parse_public_keyfile(path:String):Int;
}
4 changes: 2 additions & 2 deletions std/eval/_std/sys/ssl/Key.hx
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class Key {
var code = if (isPublic) {
key.native.parse_public_keyfile(file);
} else {
key.native.parse_keyfile(file, pass);
key.native.parse_keyfile(file, pass, Mbedtls.getDefaultCtrDrbg());
}
if (code != 0) {
throw(mbedtls.Error.strerror(code));
Expand All @@ -51,7 +51,7 @@ class Key {
var code = if (isPublic) {
key.native.parse_public_key(data);
} else {
key.native.parse_key(data);
key.native.parse_key(data, null, Mbedtls.getDefaultCtrDrbg());
}
if (code != 0) {
throw(mbedtls.Error.strerror(code));
Expand Down