diff --git a/crypto/ec_extra/ec_derive.c b/crypto/ec_extra/ec_derive.c index 6904d7bd09..7e666d5f69 100644 --- a/crypto/ec_extra/ec_derive.c +++ b/crypto/ec_extra/ec_derive.c @@ -55,7 +55,8 @@ EC_KEY *EC_KEY_derive_from_secret(const EC_GROUP *group, const uint8_t *secret, } uint8_t derived[EC_KEY_DERIVE_EXTRA_BYTES + EC_MAX_BYTES]; - size_t derived_len = BN_num_bytes(&group->order) + EC_KEY_DERIVE_EXTRA_BYTES; + size_t derived_len = + BN_num_bytes(EC_GROUP_get0_order(group)) + EC_KEY_DERIVE_EXTRA_BYTES; assert(derived_len <= sizeof(derived)); if (!HKDF(derived, derived_len, EVP_sha256(), secret, secret_len, /*salt=*/NULL, /*salt_len=*/0, (const uint8_t *)info, @@ -74,10 +75,10 @@ EC_KEY *EC_KEY_derive_from_secret(const EC_GROUP *group, const uint8_t *secret, // enough. 2^(num_bytes(order)) < 2^8 * order, so: // // priv < 2^8 * order * 2^128 < order * order < order * R - !BN_from_montgomery(priv, priv, group->order_mont, ctx) || + !BN_from_montgomery(priv, priv, group->order, ctx) || // Multiply by R^2 and do another Montgomery reduction to compute // priv * R^-1 * R^2 * R^-1 = priv mod order. - !BN_to_montgomery(priv, priv, group->order_mont, ctx) || + !BN_to_montgomery(priv, priv, group->order, ctx) || !EC_POINT_mul(group, pub, priv, NULL, NULL, ctx) || !EC_KEY_set_group(key, group) || !EC_KEY_set_public_key(key, pub) || !EC_KEY_set_private_key(key, priv)) { diff --git a/crypto/ec_extra/hash_to_curve.c b/crypto/ec_extra/hash_to_curve.c index 7ddd41ec42..0cc046edad 100644 --- a/crypto/ec_extra/hash_to_curve.c +++ b/crypto/ec_extra/hash_to_curve.c @@ -185,15 +185,16 @@ static int hash_to_field2(const EC_GROUP *group, const EVP_MD *md, static int hash_to_scalar(const EC_GROUP *group, const EVP_MD *md, EC_SCALAR *out, const uint8_t *dst, size_t dst_len, unsigned k, const uint8_t *msg, size_t msg_len) { + const BIGNUM *order = EC_GROUP_get0_order(group); size_t L; uint8_t buf[EC_MAX_BYTES * 2]; - if (!num_bytes_to_derive(&L, &group->order, k) || + if (!num_bytes_to_derive(&L, order, k) || !expand_message_xmd(md, buf, L, msg, msg_len, dst, dst_len)) { return 0; } BN_ULONG words[2 * EC_MAX_WORDS]; - size_t num_words = 2 * group->order.width; + size_t num_words = 2 * order->width; bn_big_endian_to_words(words, num_words, buf, L); ec_scalar_reduce(group, out, words, num_words); return 1; diff --git a/crypto/evp_extra/p_ec_asn1.c b/crypto/evp_extra/p_ec_asn1.c index f9d873f450..daaff9c334 100644 --- a/crypto/evp_extra/p_ec_asn1.c +++ b/crypto/evp_extra/p_ec_asn1.c @@ -204,7 +204,7 @@ static int ec_bits(const EVP_PKEY *pkey) { ERR_clear_error(); return 0; } - return BN_num_bits(EC_GROUP_get0_order(group)); + return EC_GROUP_order_bits(group); } static int ec_missing_parameters(const EVP_PKEY *pkey) { diff --git a/crypto/fipsmodule/ec/ec.c b/crypto/fipsmodule/ec/ec.c index d05e29e1d9..7dfd3c0c92 100644 --- a/crypto/fipsmodule/ec/ec.c +++ b/crypto/fipsmodule/ec/ec.c @@ -338,7 +338,6 @@ EC_GROUP *ec_group_new(const EC_METHOD *meth) { ret->references = 1; ret->meth = meth; - BN_init(&ret->order); if (!meth->group_init(ret)) { OPENSSL_free(ret); @@ -352,30 +351,13 @@ static int ec_group_set_generator(EC_GROUP *group, const EC_AFFINE *generator, const BIGNUM *order) { assert(group->generator == NULL); - if (!BN_copy(&group->order, order)) { - return 0; - } - // Store the order in minimal form, so it can be used with |BN_ULONG| arrays. - bn_set_minimal_width(&group->order); - - BN_MONT_CTX_free(group->order_mont); - group->order_mont = BN_MONT_CTX_new_for_modulus(&group->order, NULL); - if (group->order_mont == NULL) { + BN_MONT_CTX_free(group->order); + group->order = BN_MONT_CTX_new_for_modulus(order, NULL); + if (group->order == NULL) { return 0; } group->field_greater_than_order = BN_cmp(&group->field, order) > 0; - if (group->field_greater_than_order) { - BIGNUM tmp; - BN_init(&tmp); - int ok = - BN_sub(&tmp, &group->field, order) && - bn_copy_words(group->field_minus_order.words, group->field.width, &tmp); - BN_free(&tmp); - if (!ok) { - return 0; - } - } group->generator = EC_POINT_new(group); if (group->generator == NULL) { @@ -609,8 +591,7 @@ void EC_GROUP_free(EC_GROUP *group) { } ec_point_free(group->generator, 0 /* don't free group */); - BN_free(&group->order); - BN_MONT_CTX_free(group->order_mont); + BN_MONT_CTX_free(group->order); OPENSSL_free(group); } @@ -649,7 +630,7 @@ int EC_GROUP_cmp(const EC_GROUP *a, const EC_GROUP *b, BN_CTX *ignored) { return a->meth != b->meth || a->generator == NULL || b->generator == NULL || - BN_cmp(&a->order, &b->order) != 0 || + BN_cmp(&a->order->N, &b->order->N) != 0 || BN_cmp(&a->field, &b->field) != 0 || !ec_felem_equal(a, &a->a, &b->a) || !ec_felem_equal(a, &a->b, &b->b) || @@ -661,8 +642,8 @@ const EC_POINT *EC_GROUP_get0_generator(const EC_GROUP *group) { } const BIGNUM *EC_GROUP_get0_order(const EC_GROUP *group) { - assert(!BN_is_zero(&group->order)); - return &group->order; + assert(group->order != NULL); + return &group->order->N; } int EC_GROUP_get_order(const EC_GROUP *group, BIGNUM *order, BN_CTX *ctx) { @@ -673,7 +654,7 @@ int EC_GROUP_get_order(const EC_GROUP *group, BIGNUM *order, BN_CTX *ctx) { } int EC_GROUP_order_bits(const EC_GROUP *group) { - return BN_num_bits(&group->order); + return BN_num_bits(&group->order->N); } int EC_GROUP_get_cofactor(const EC_GROUP *group, BIGNUM *cofactor, @@ -979,11 +960,10 @@ static int arbitrary_bignum_to_scalar(const EC_GROUP *group, EC_SCALAR *out, ERR_clear_error(); // This is an unusual input, so we do not guarantee constant-time processing. - const BIGNUM *order = &group->order; BN_CTX_start(ctx); BIGNUM *tmp = BN_CTX_get(ctx); int ok = tmp != NULL && - BN_nnmod(tmp, in, order, ctx) && + BN_nnmod(tmp, in, EC_GROUP_get0_order(group), ctx) && ec_bignum_to_scalar(group, out, tmp); BN_CTX_end(ctx); return ok; diff --git a/crypto/fipsmodule/ec/ec_key.c b/crypto/fipsmodule/ec/ec_key.c index 3cdd39d37e..bd163badc2 100644 --- a/crypto/fipsmodule/ec/ec_key.c +++ b/crypto/fipsmodule/ec/ec_key.c @@ -93,8 +93,8 @@ static EC_WRAPPED_SCALAR *ec_wrapped_scalar_new(const EC_GROUP *group) { OPENSSL_memset(wrapped, 0, sizeof(EC_WRAPPED_SCALAR)); wrapped->bignum.d = wrapped->scalar.words; - wrapped->bignum.width = group->order.width; - wrapped->bignum.dmax = group->order.width; + wrapped->bignum.width = group->order->N.width; + wrapped->bignum.dmax = group->order->N.width; wrapped->bignum.flags = BN_FLG_STATIC_DATA; return wrapped; } @@ -486,7 +486,7 @@ int EC_KEY_generate_key(EC_KEY *key) { } // Check that the group order is FIPS compliant (FIPS 186-4 B.4.2). - if (BN_num_bits(EC_GROUP_get0_order(key->group)) < 160) { + if (EC_GROUP_order_bits(key->group) < 160) { OPENSSL_PUT_ERROR(EC, EC_R_INVALID_GROUP_ORDER); return 0; } diff --git a/crypto/fipsmodule/ec/ec_montgomery.c b/crypto/fipsmodule/ec/ec_montgomery.c index 78e0507699..8d99238ad7 100644 --- a/crypto/fipsmodule/ec/ec_montgomery.c +++ b/crypto/fipsmodule/ec/ec_montgomery.c @@ -457,7 +457,7 @@ static int ec_GFp_mont_cmp_x_coordinate(const EC_GROUP *group, const EC_JACOBIAN *p, const EC_SCALAR *r) { if (!group->field_greater_than_order || - group->field.width != group->order.width) { + group->field.width != group->order->N.width) { // Do not bother optimizing this case. p > order in all commonly-used // curves. return ec_GFp_simple_cmp_x_coordinate(group, p, r); @@ -485,10 +485,11 @@ static int ec_GFp_mont_cmp_x_coordinate(const EC_GROUP *group, // Therefore there is a small possibility, less than 1/2^128, that group_order // < p.x < P. in that case we need not only to compare against |r| but also to // compare against r+group_order. - if (bn_less_than_words(r->words, group->field_minus_order.words, - group->field.width)) { - // We can ignore the carry because: r + group_order < p < 2^256. - bn_add_words(r_Z2.words, r->words, group->order.d, group->field.width); + BN_ULONG carry = + bn_add_words(r_Z2.words, r->words, group->order->N.d, group->field.width); + if (carry == 0 && + bn_less_than_words(r_Z2.words, group->field.d, group->field.width)) { + // r + group_order < p, so compare (r + group_order) * Z^2 against X. ec_GFp_mont_felem_mul(group, &r_Z2, &r_Z2, &Z2_mont); if (ec_felem_equal(group, &r_Z2, &X)) { return 1; diff --git a/crypto/fipsmodule/ec/ec_test.cc b/crypto/fipsmodule/ec/ec_test.cc index d797c733c5..e2507cb4fc 100644 --- a/crypto/fipsmodule/ec/ec_test.cc +++ b/crypto/fipsmodule/ec/ec_test.cc @@ -1267,8 +1267,8 @@ TEST(ECTest, SmallGroupOrder) { bssl::UniquePtr key2(EC_KEY_new()); ASSERT_TRUE(key2); ASSERT_TRUE(EC_KEY_set_group(key2.get(), group.get())); - BN_clear(&key2.get()->group->order); - ASSERT_TRUE(BN_set_word(&key2.get()->group->order, 7)); + BN_clear(&key2.get()->group->order->N); + ASSERT_TRUE(BN_set_word(&key2.get()->group->order->N, 7)); ASSERT_FALSE(EC_KEY_generate_key_fips(key2.get())); } @@ -1325,8 +1325,8 @@ TEST(ECDeathTest, SmallGroupOrderAndDie) { bssl::UniquePtr key2(EC_KEY_new()); ASSERT_TRUE(key2); ASSERT_TRUE(EC_KEY_set_group(key2.get(), group.get())); - BN_clear(&key2.get()->group->order); - ASSERT_TRUE(BN_set_word(&key2.get()->group->order, 7)); + BN_clear(&key2.get()->group->order->N); + ASSERT_TRUE(BN_set_word(&key2.get()->group->order->N, 7)); ASSERT_DEATH_IF_SUPPORTED(EC_KEY_generate_key_fips(key2.get()), ""); } diff --git a/crypto/fipsmodule/ec/internal.h b/crypto/fipsmodule/ec/internal.h index 31c28349bd..5a2dee0485 100644 --- a/crypto/fipsmodule/ec/internal.h +++ b/crypto/fipsmodule/ec/internal.h @@ -441,7 +441,7 @@ void ec_precomp_select(const EC_GROUP *group, EC_PRECOMP *out, BN_ULONG mask, // ec_cmp_x_coordinate compares the x (affine) coordinate of |p|, mod the group // order, with |r|. It returns one if the values match and zero if |p| is the -// point at infinity of the values do not match. +// point at infinity of the values do not match. |p| is treated as public. int ec_cmp_x_coordinate(const EC_GROUP *group, const EC_JACOBIAN *p, const EC_SCALAR *r); @@ -619,11 +619,10 @@ struct ec_group_st { // to avoid a reference cycle. Additionally, Z is guaranteed to be one, so X // and Y are suitable for use as an |EC_AFFINE|. EC_POINT *generator; - BIGNUM order; int curve_name; // optional NID for named curve - BN_MONT_CTX *order_mont; // data for ECDSA inverse + BN_MONT_CTX *order; // The following members are handled by the method functions, // even if they appear generic @@ -640,13 +639,6 @@ struct ec_group_st { // otherwise. int field_greater_than_order; - // field_minus_order, if |field_greater_than_order| is true, is |field| minus - // |order| represented as an |EC_FELEM|. Otherwise, it is zero. - // - // Note: unlike |EC_FELEM|s used as intermediate values internal to the - // |EC_METHOD|, this value is not encoded in Montgomery form. - EC_FELEM field_minus_order; - CRYPTO_refcount_t references; BN_MONT_CTX *mont; // Montgomery structure. diff --git a/crypto/fipsmodule/ec/p256-nistz.c b/crypto/fipsmodule/ec/p256-nistz.c index 0229a1aaea..19d0306a43 100644 --- a/crypto/fipsmodule/ec/p256-nistz.c +++ b/crypto/fipsmodule/ec/p256-nistz.c @@ -583,8 +583,8 @@ static int ecp_nistz256_scalar_to_montgomery_inv_vartime(const EC_GROUP *group, } #endif - assert(group->order.width == P256_LIMBS); - if (!beeu_mod_inverse_vartime(out->words, in->words, group->order.d)) { + assert(group->order->N.width == P256_LIMBS); + if (!beeu_mod_inverse_vartime(out->words, in->words, group->order->N.d)) { return 0; } @@ -602,7 +602,7 @@ static int ecp_nistz256_cmp_x_coordinate(const EC_GROUP *group, return 0; } - assert(group->order.width == P256_LIMBS); + assert(group->order->N.width == P256_LIMBS); assert(group->field.width == P256_LIMBS); // We wish to compare X/Z^2 with r. This is equivalent to comparing X with @@ -621,10 +621,9 @@ static int ecp_nistz256_cmp_x_coordinate(const EC_GROUP *group, // Therefore there is a small possibility, less than 1/2^128, that group_order // < p.x < P. in that case we need not only to compare against |r| but also to // compare against r+group_order. - if (bn_less_than_words(r->words, group->field_minus_order.words, - P256_LIMBS)) { - // We can ignore the carry because: r + group_order < p < 2^256. - bn_add_words(r_Z2, r->words, group->order.d, P256_LIMBS); + BN_ULONG carry = bn_add_words(r_Z2, r->words, group->order->N.d, P256_LIMBS); + if (carry == 0 && bn_less_than_words(r_Z2, group->field.d, P256_LIMBS)) { + // r + group_order < p, so compare (r + group_order) * Z^2 against X. ecp_nistz256_mul_mont(r_Z2, r_Z2, Z2_mont); if (OPENSSL_memcmp(r_Z2, X, sizeof(r_Z2)) == 0) { return 1; diff --git a/crypto/fipsmodule/ec/p256.c b/crypto/fipsmodule/ec/p256.c index 8684d1e1b0..561416cdb6 100644 --- a/crypto/fipsmodule/ec/p256.c +++ b/crypto/fipsmodule/ec/p256.c @@ -710,12 +710,12 @@ static int ec_GFp_nistp256_cmp_x_coordinate(const EC_GROUP *group, // Therefore there is a small possibility, less than 1/2^128, that group_order // < p.x < P. in that case we need not only to compare against |r| but also to // compare against r+group_order. - assert(group->field.width == group->order.width); - if (bn_less_than_words(r->words, group->field_minus_order.words, - group->field.width)) { - // We can ignore the carry because: r + group_order < p < 2^256. - EC_FELEM tmp; - bn_add_words(tmp.words, r->words, group->order.d, group->order.width); + assert(group->field.width == group->order->N.width); + EC_FELEM tmp; + BN_ULONG carry = + bn_add_words(tmp.words, r->words, group->order->N.d, group->field.width); + if (carry == 0 && + bn_less_than_words(tmp.words, group->field.d, group->field.width)) { fiat_p256_from_generic(r_Z2, &tmp); fiat_p256_mul(r_Z2, r_Z2, Z2_mont); if (OPENSSL_memcmp(&r_Z2, &X, sizeof(r_Z2)) == 0) { diff --git a/crypto/fipsmodule/ec/p384.c b/crypto/fipsmodule/ec/p384.c index 966729f458..f8a64a91a4 100644 --- a/crypto/fipsmodule/ec/p384.c +++ b/crypto/fipsmodule/ec/p384.c @@ -568,7 +568,7 @@ static void ec_GFp_nistp384_mont_felem_to_bytes( p384_felem_from_mont(tmp, tmp); p384_to_generic(&felem_tmp, tmp); - bn_words_to_big_endian(out, len, felem_tmp.words, group->order.width); + bn_words_to_big_endian(out, len, felem_tmp.words, group->order->N.width); *out_len = len; } @@ -619,12 +619,12 @@ static int ec_GFp_nistp384_cmp_x_coordinate(const EC_GROUP *group, // that group_order < p.x < p. // In that case, we need not only to compare against |r| but also to // compare against r+group_order. - assert(group->field.width == group->order.width); - if (bn_less_than_words(r->words, group->field_minus_order.words, - group->field.width)) { - // We can ignore the carry because: r + group_order < p < 2^384. - EC_FELEM tmp; - bn_add_words(tmp.words, r->words, group->order.d, group->order.width); + assert(group->field.width == group->order->N.width); + EC_FELEM tmp; + BN_ULONG carry = + bn_add_words(tmp.words, r->words, group->order->N.d, group->field.width); + if (carry == 0 && + bn_less_than_words(tmp.words, group->field.d, group->field.width)) { p384_from_generic(r_Z2, &tmp); p384_felem_mul(r_Z2, r_Z2, Z2_mont); if (OPENSSL_memcmp(&r_Z2, &X, sizeof(r_Z2)) == 0) { diff --git a/crypto/fipsmodule/ec/scalar.c b/crypto/fipsmodule/ec/scalar.c index 036049e090..5c6f664b78 100644 --- a/crypto/fipsmodule/ec/scalar.c +++ b/crypto/fipsmodule/ec/scalar.c @@ -23,8 +23,9 @@ int ec_bignum_to_scalar(const EC_GROUP *group, EC_SCALAR *out, const BIGNUM *in) { - if (!bn_copy_words(out->words, group->order.width, in) || - !bn_less_than_words(out->words, group->order.d, group->order.width)) { + if (!bn_copy_words(out->words, group->order->N.width, in) || + !bn_less_than_words(out->words, group->order->N.d, + group->order->N.width)) { OPENSSL_PUT_ERROR(EC, EC_R_INVALID_SCALAR); return 0; } @@ -34,12 +35,12 @@ int ec_bignum_to_scalar(const EC_GROUP *group, EC_SCALAR *out, int ec_scalar_equal_vartime(const EC_GROUP *group, const EC_SCALAR *a, const EC_SCALAR *b) { return OPENSSL_memcmp(a->words, b->words, - group->order.width * sizeof(BN_ULONG)) == 0; + group->order->N.width * sizeof(BN_ULONG)) == 0; } int ec_scalar_is_zero(const EC_GROUP *group, const EC_SCALAR *a) { BN_ULONG mask = 0; - for (int i = 0; i < group->order.width; i++) { + for (int i = 0; i < group->order->N.width; i++) { mask |= a->words[i]; } return mask == 0; @@ -47,27 +48,28 @@ int ec_scalar_is_zero(const EC_GROUP *group, const EC_SCALAR *a) { int ec_random_nonzero_scalar(const EC_GROUP *group, EC_SCALAR *out, const uint8_t additional_data[32]) { - return bn_rand_range_words(out->words, 1, group->order.d, group->order.width, - additional_data); + return bn_rand_range_words(out->words, 1, group->order->N.d, + group->order->N.width, additional_data); } void ec_scalar_to_bytes(const EC_GROUP *group, uint8_t *out, size_t *out_len, const EC_SCALAR *in) { - size_t len = BN_num_bytes(&group->order); - bn_words_to_big_endian(out, len, in->words, group->order.width); + size_t len = BN_num_bytes(&group->order->N); + bn_words_to_big_endian(out, len, in->words, group->order->N.width); *out_len = len; } int ec_scalar_from_bytes(const EC_GROUP *group, EC_SCALAR *out, const uint8_t *in, size_t len) { - if (len != BN_num_bytes(&group->order)) { + if (len != BN_num_bytes(&group->order->N)) { OPENSSL_PUT_ERROR(EC, EC_R_INVALID_SCALAR); return 0; } - bn_big_endian_to_words(out->words, group->order.width, in, len); + bn_big_endian_to_words(out->words, group->order->N.width, in, len); - if (!bn_less_than_words(out->words, group->order.d, group->order.width)) { + if (!bn_less_than_words(out->words, group->order->N.d, + group->order->N.width)) { OPENSSL_PUT_ERROR(EC, EC_R_INVALID_SCALAR); return 0; } @@ -78,15 +80,15 @@ int ec_scalar_from_bytes(const EC_GROUP *group, EC_SCALAR *out, void ec_scalar_reduce(const EC_GROUP *group, EC_SCALAR *out, const BN_ULONG *words, size_t num) { // Convert "from" Montgomery form so the value is reduced modulo the order. - bn_from_montgomery_small(out->words, group->order.width, words, num, - group->order_mont); + bn_from_montgomery_small(out->words, group->order->N.width, words, num, + group->order); // Convert "to" Montgomery form to remove the R^-1 factor added. ec_scalar_to_montgomery(group, out, out); } void ec_scalar_add(const EC_GROUP *group, EC_SCALAR *r, const EC_SCALAR *a, const EC_SCALAR *b) { - const BIGNUM *order = &group->order; + const BIGNUM *order = &group->order->N; BN_ULONG tmp[EC_MAX_WORDS]; bn_mod_add_words(r->words, a->words, b->words, order->d, tmp, order->width); OPENSSL_cleanse(tmp, sizeof(tmp)); @@ -94,7 +96,7 @@ void ec_scalar_add(const EC_GROUP *group, EC_SCALAR *r, const EC_SCALAR *a, void ec_scalar_sub(const EC_GROUP *group, EC_SCALAR *r, const EC_SCALAR *a, const EC_SCALAR *b) { - const BIGNUM *order = &group->order; + const BIGNUM *order = &group->order->N; BN_ULONG tmp[EC_MAX_WORDS]; bn_mod_sub_words(r->words, a->words, b->words, order->d, tmp, order->width); OPENSSL_cleanse(tmp, sizeof(tmp)); @@ -108,35 +110,35 @@ void ec_scalar_neg(const EC_GROUP *group, EC_SCALAR *r, const EC_SCALAR *a) { void ec_scalar_select(const EC_GROUP *group, EC_SCALAR *out, BN_ULONG mask, const EC_SCALAR *a, const EC_SCALAR *b) { - const BIGNUM *order = &group->order; + const BIGNUM *order = &group->order->N; bn_select_words(out->words, mask, a->words, b->words, order->width); } void ec_scalar_to_montgomery(const EC_GROUP *group, EC_SCALAR *r, const EC_SCALAR *a) { - const BIGNUM *order = &group->order; - bn_to_montgomery_small(r->words, a->words, order->width, group->order_mont); + const BIGNUM *order = &group->order->N; + bn_to_montgomery_small(r->words, a->words, order->width, group->order); } void ec_scalar_from_montgomery(const EC_GROUP *group, EC_SCALAR *r, const EC_SCALAR *a) { - const BIGNUM *order = &group->order; + const BIGNUM *order = &group->order->N; bn_from_montgomery_small(r->words, order->width, a->words, order->width, - group->order_mont); + group->order); } void ec_scalar_mul_montgomery(const EC_GROUP *group, EC_SCALAR *r, const EC_SCALAR *a, const EC_SCALAR *b) { - const BIGNUM *order = &group->order; + const BIGNUM *order = &group->order->N; bn_mod_mul_montgomery_small(r->words, a->words, b->words, order->width, - group->order_mont); + group->order); } void ec_simple_scalar_inv0_montgomery(const EC_GROUP *group, EC_SCALAR *r, const EC_SCALAR *a) { - const BIGNUM *order = &group->order; + const BIGNUM *order = &group->order->N; bn_mod_inverse0_prime_mont_small(r->words, a->words, order->width, - group->order_mont); + group->order); } int ec_simple_scalar_to_montgomery_inv_vartime(const EC_GROUP *group, diff --git a/crypto/fipsmodule/ec/simple_mul.c b/crypto/fipsmodule/ec/simple_mul.c index b7ce6f1358..487c27c722 100644 --- a/crypto/fipsmodule/ec/simple_mul.c +++ b/crypto/fipsmodule/ec/simple_mul.c @@ -40,7 +40,7 @@ void ec_GFp_mont_mul(const EC_GROUP *group, EC_JACOBIAN *r, } // Divide bits in |scalar| into windows. - unsigned bits = BN_num_bits(&group->order); + unsigned bits = EC_GROUP_order_bits(group); int r_is_at_infinity = 1; for (unsigned i = bits - 1; i < bits; i--) { if (!r_is_at_infinity) { @@ -48,7 +48,7 @@ void ec_GFp_mont_mul(const EC_GROUP *group, EC_JACOBIAN *r, } if (i % 5 == 0) { // Compute the next window value. - const size_t width = group->order.width; + const size_t width = group->order->N.width; uint8_t window = bn_is_bit_set_words(scalar->words, width, i + 4) << 4; window |= bn_is_bit_set_words(scalar->words, width, i + 3) << 3; window |= bn_is_bit_set_words(scalar->words, width, i + 2) << 2; @@ -99,7 +99,7 @@ static void ec_GFp_mont_batch_get_window(const EC_GROUP *group, EC_JACOBIAN *out, const EC_JACOBIAN precomp[17], const EC_SCALAR *scalar, unsigned i) { - const size_t width = group->order.width; + const size_t width = group->order->N.width; uint8_t window = bn_is_bit_set_words(scalar->words, width, i + 4) << 5; window |= bn_is_bit_set_words(scalar->words, width, i + 3) << 4; window |= bn_is_bit_set_words(scalar->words, width, i + 2) << 3; @@ -141,7 +141,7 @@ void ec_GFp_mont_mul_batch(const EC_GROUP *group, EC_JACOBIAN *r, } // Divide bits in |scalar| into windows. - unsigned bits = BN_num_bits(&group->order); + unsigned bits = EC_GROUP_order_bits(group); int r_is_at_infinity = 1; for (unsigned i = bits; i <= bits; i--) { if (!r_is_at_infinity) { @@ -216,7 +216,7 @@ static void ec_GFp_mont_get_comb_window(const EC_GROUP *group, EC_JACOBIAN *out, const EC_PRECOMP *precomp, const EC_SCALAR *scalar, unsigned i) { - const size_t width = group->order.width; + const size_t width = group->order->N.width; unsigned stride = ec_GFp_mont_comb_stride(group); // Select the bits corresponding to the comb shifted up by |i|. unsigned window = 0; diff --git a/crypto/fipsmodule/ec/wnaf.c b/crypto/fipsmodule/ec/wnaf.c index beb9295434..c9935fcb40 100644 --- a/crypto/fipsmodule/ec/wnaf.c +++ b/crypto/fipsmodule/ec/wnaf.c @@ -138,8 +138,8 @@ void ec_compute_wNAF(const EC_GROUP *group, int8_t *out, // we shift and add at most one copy of |bit|, this will continue to hold // afterwards. window_val >>= 1; - window_val += - bit * bn_is_bit_set_words(scalar->words, group->order.width, j + w + 1); + window_val += bit * bn_is_bit_set_words(scalar->words, + group->order->N.width, j + w + 1); assert(window_val <= next_bit); } @@ -183,7 +183,7 @@ int ec_GFp_mont_mul_public_batch(const EC_GROUP *group, EC_JACOBIAN *r, const EC_SCALAR *g_scalar, const EC_JACOBIAN *points, const EC_SCALAR *scalars, size_t num) { - size_t bits = BN_num_bits(&group->order); + size_t bits = EC_GROUP_order_bits(group); size_t wNAF_len = bits + 1; int ret = 0; diff --git a/crypto/fipsmodule/ecdsa/ecdsa.c b/crypto/fipsmodule/ecdsa/ecdsa.c index 3d38ef3472..67d820b411 100644 --- a/crypto/fipsmodule/ecdsa/ecdsa.c +++ b/crypto/fipsmodule/ecdsa/ecdsa.c @@ -72,7 +72,7 @@ // ECDSA. static void digest_to_scalar(const EC_GROUP *group, EC_SCALAR *out, const uint8_t *digest, size_t digest_len) { - const BIGNUM *order = &group->order; + const BIGNUM *order = EC_GROUP_get0_order(group); size_t num_bits = BN_num_bits(order); // Need to truncate digest if it is too long: first truncate whole bytes. size_t num_bytes = (num_bits + 7) / 8; diff --git a/crypto/fipsmodule/rand/internal.h b/crypto/fipsmodule/rand/internal.h index 839b5718b0..eb0c7c192b 100644 --- a/crypto/fipsmodule/rand/internal.h +++ b/crypto/fipsmodule/rand/internal.h @@ -26,9 +26,16 @@ extern "C" { #endif -#if !defined(OPENSSL_WINDOWS) && !defined(OPENSSL_FUCHSIA) && \ - !defined(BORINGSSL_UNSAFE_DETERMINISTIC_MODE) && !defined(OPENSSL_TRUSTY) -#define OPENSSL_URANDOM +#if defined(BORINGSSL_UNSAFE_DETERMINISTIC_MODE) +#define OPENSSL_RAND_DETERMINISTIC +#elif defined(OPENSSL_FUCHSIA) +#define OPENSSL_RAND_FUCHSIA +#elif defined(OPENSSL_TRUSTY) +// Trusty's PRNG file is, for now, maintained outside the tree. +#elif defined(OPENSSL_WINDOWS) +#define OPENSSL_RAND_WINDOWS +#else +#define OPENSSL_RAND_URANDOM #endif // RAND_bytes_with_additional_data samples from the RNG after mixing 32 bytes @@ -45,15 +52,15 @@ void CRYPTO_sysrand(uint8_t *buf, size_t len); // depending on the vendor's configuration. void CRYPTO_sysrand_for_seed(uint8_t *buf, size_t len); -#if defined(OPENSSL_URANDOM) || defined(OPENSSL_WINDOWS) +#if defined(OPENSSL_RAND_URANDOM) || defined(OPENSSL_RAND_WINDOWS) // CRYPTO_init_sysrand initializes long-lived resources needed to draw entropy // from the operating system. void CRYPTO_init_sysrand(void); #else OPENSSL_INLINE void CRYPTO_init_sysrand(void) {} -#endif // defined(OPENSSL_URANDOM) || defined(OPENSSL_WINDOWS) +#endif // defined(OPENSSL_RAND_URANDOM) || defined(OPENSSL_RAND_WINDOWS) -#if defined(OPENSSL_URANDOM) +#if defined(OPENSSL_RAND_URANDOM) // CRYPTO_sysrand_if_available fills |len| bytes at |buf| with entropy from the // operating system, or early /dev/urandom data, and returns 1, _if_ the entropy // pool is initialized or if getrandom() is not available and not in FIPS mode. @@ -65,7 +72,7 @@ OPENSSL_INLINE int CRYPTO_sysrand_if_available(uint8_t *buf, size_t len) { CRYPTO_sysrand(buf, len); return 1; } -#endif // defined(OPENSSL_URANDOM) +#endif // defined(OPENSSL_RAND_URANDOM) // rand_fork_unsafe_buffering_enabled returns whether fork-unsafe buffering has // been enabled via |RAND_enable_fork_unsafe_buffering|. diff --git a/crypto/fipsmodule/rand/urandom.c b/crypto/fipsmodule/rand/urandom.c index 386cc742c2..fcc3dc09ef 100644 --- a/crypto/fipsmodule/rand/urandom.c +++ b/crypto/fipsmodule/rand/urandom.c @@ -20,7 +20,7 @@ #include "internal.h" -#if defined(OPENSSL_URANDOM) +#if defined(OPENSSL_RAND_URANDOM) #include #include @@ -480,4 +480,4 @@ int CRYPTO_sysrand_if_available(uint8_t *out, size_t requested) { } } -#endif // defined(OPENSSL_URANDOM) +#endif // OPENSSL_RAND_URANDOM diff --git a/crypto/rand_extra/deterministic.c b/crypto/rand_extra/deterministic.c index 435f063382..f4ede82093 100644 --- a/crypto/rand_extra/deterministic.c +++ b/crypto/rand_extra/deterministic.c @@ -14,14 +14,15 @@ #include -#if defined(BORINGSSL_UNSAFE_DETERMINISTIC_MODE) +#include "../fipsmodule/rand/internal.h" + +#if defined(OPENSSL_RAND_DETERMINISTIC) #include #include #include "../internal.h" -#include "../fipsmodule/rand/internal.h" // g_num_calls is the number of calls to |CRYPTO_sysrand| that have occurred. @@ -53,4 +54,4 @@ void CRYPTO_sysrand_for_seed(uint8_t *out, size_t requested) { CRYPTO_sysrand(out, requested); } -#endif // BORINGSSL_UNSAFE_DETERMINISTIC_MODE +#endif // OPENSSL_RAND_DETERMINISTIC diff --git a/crypto/rand_extra/fuchsia.c b/crypto/rand_extra/fuchsia.c index ee6cfdbac7..d4fb9797ad 100644 --- a/crypto/rand_extra/fuchsia.c +++ b/crypto/rand_extra/fuchsia.c @@ -14,15 +14,15 @@ #include -#if defined(OPENSSL_FUCHSIA) && !defined(BORINGSSL_UNSAFE_DETERMINISTIC_MODE) +#include "../fipsmodule/rand/internal.h" + +#if defined(OPENSSL_RAND_FUCHSIA) #include #include #include -#include "../fipsmodule/rand/internal.h" - void CRYPTO_sysrand(uint8_t *out, size_t requested) { zx_cprng_draw(out, requested); } @@ -31,4 +31,4 @@ void CRYPTO_sysrand_for_seed(uint8_t *out, size_t requested) { CRYPTO_sysrand(out, requested); } -#endif // OPENSSL_FUCHSIA && !BORINGSSL_UNSAFE_DETERMINISTIC_MODE +#endif // OPENSSL_RAND_FUCHSIA diff --git a/crypto/rand_extra/windows.c b/crypto/rand_extra/windows.c index 0dbc0e3601..6b407b7c0f 100644 --- a/crypto/rand_extra/windows.c +++ b/crypto/rand_extra/windows.c @@ -14,7 +14,9 @@ #include -#if defined(OPENSSL_WINDOWS) && !defined(BORINGSSL_UNSAFE_DETERMINISTIC_MODE) +#include "../fipsmodule/rand/internal.h" + +#if defined(OPENSSL_RAND_WINDOWS) #include #include @@ -31,12 +33,29 @@ OPENSSL_MSVC_PRAGMA(comment(lib, "bcrypt.lib")) OPENSSL_MSVC_PRAGMA(warning(pop)) -#include "../fipsmodule/rand/internal.h" - #if WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_APP) && \ !WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_DESKTOP) + void CRYPTO_init_sysrand(void) {} + +void CRYPTO_sysrand(uint8_t *out, size_t requested) { + while (requested > 0) { + ULONG output_bytes_this_pass = ULONG_MAX; + if (requested < output_bytes_this_pass) { + output_bytes_this_pass = (ULONG)requested; + } + if (!BCRYPT_SUCCESS(BCryptGenRandom( + /*hAlgorithm=*/NULL, out, output_bytes_this_pass, + BCRYPT_USE_SYSTEM_PREFERRED_RNG))) { + abort(); + } + requested -= output_bytes_this_pass; + out += output_bytes_this_pass; + } +} + #else + // See: https://learn.microsoft.com/en-us/windows/win32/seccng/processprng typedef BOOL (WINAPI *ProcessPrngFunction)(PBYTE pbData, SIZE_T cbData); static ProcessPrngFunction g_processprng_fn = NULL; @@ -56,26 +75,8 @@ void CRYPTO_init_sysrand(void) { static CRYPTO_once_t once = CRYPTO_ONCE_INIT; CRYPTO_once(&once, init_processprng); } -#endif // WINAPI_PARTITION_APP && !WINAPI_PARTITION_DESKTOP void CRYPTO_sysrand(uint8_t *out, size_t requested) { -#if WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_APP) && \ - !WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_DESKTOP) - while (requested > 0) { - ULONG output_bytes_this_pass = ULONG_MAX; - if (requested < output_bytes_this_pass) { - output_bytes_this_pass = (ULONG)requested; - } - if (!BCRYPT_SUCCESS(BCryptGenRandom( - /*hAlgorithm=*/NULL, out, output_bytes_this_pass, - BCRYPT_USE_SYSTEM_PREFERRED_RNG))) { - abort(); - } - requested -= output_bytes_this_pass; - out += output_bytes_this_pass; - } - return; -#else CRYPTO_init_sysrand(); // On non-UWP configurations, use ProcessPrng instead of BCryptGenRandom // to avoid accessing resources that may be unavailable inside the @@ -83,11 +84,12 @@ void CRYPTO_sysrand(uint8_t *out, size_t requested) { if (!g_processprng_fn(out, requested)) { abort(); } -#endif // WINAPI_PARTITION_APP && !WINAPI_PARTITION_DESKTOP } +#endif // WINAPI_PARTITION_APP && !WINAPI_PARTITION_DESKTOP + void CRYPTO_sysrand_for_seed(uint8_t *out, size_t requested) { CRYPTO_sysrand(out, requested); } -#endif // OPENSSL_WINDOWS && !BORINGSSL_UNSAFE_DETERMINISTIC_MODE +#endif // OPENSSL_RAND_WINDOWS diff --git a/crypto/trust_token/pmbtoken.c b/crypto/trust_token/pmbtoken.c index d49a2b86bb..93a0191f3a 100644 --- a/crypto/trust_token/pmbtoken.c +++ b/crypto/trust_token/pmbtoken.c @@ -201,7 +201,7 @@ static int pmbtoken_compute_keys(const PMBTOKEN_METHOD *method, } const EC_SCALAR *scalars[] = {x0, y0, x1, y1, xs, ys}; - size_t scalar_len = BN_num_bytes(&group->order); + size_t scalar_len = BN_num_bytes(EC_GROUP_get0_order(group)); for (size_t i = 0; i < OPENSSL_ARRAY_SIZE(scalars); i++) { uint8_t *buf; if (!CBB_add_space(out_private, &buf, scalar_len)) { @@ -290,7 +290,7 @@ static int pmbtoken_issuer_key_from_bytes(const PMBTOKEN_METHOD *method, const EC_GROUP *group = method->group; CBS cbs, tmp; CBS_init(&cbs, in, len); - size_t scalar_len = BN_num_bytes(&group->order); + size_t scalar_len = BN_num_bytes(EC_GROUP_get0_order(group)); EC_SCALAR *scalars[] = {&key->x0, &key->y0, &key->x1, &key->y1, &key->xs, &key->ys}; for (size_t i = 0; i < OPENSSL_ARRAY_SIZE(scalars); i++) { @@ -390,7 +390,7 @@ static STACK_OF(TRUST_TOKEN_PRETOKEN) *pmbtoken_blind( static int scalar_to_cbb(CBB *out, const EC_GROUP *group, const EC_SCALAR *scalar) { uint8_t *buf; - size_t scalar_len = BN_num_bytes(&group->order); + size_t scalar_len = BN_num_bytes(EC_GROUP_get0_order(group)); if (!CBB_add_space(out, &buf, scalar_len)) { return 0; } @@ -399,7 +399,7 @@ static int scalar_to_cbb(CBB *out, const EC_GROUP *group, } static int scalar_from_cbs(CBS *cbs, const EC_GROUP *group, EC_SCALAR *out) { - size_t scalar_len = BN_num_bytes(&group->order); + size_t scalar_len = BN_num_bytes(EC_GROUP_get0_order(group)); CBS tmp; if (!CBS_get_bytes(cbs, &tmp, scalar_len)) { OPENSSL_PUT_ERROR(TRUST_TOKEN, TRUST_TOKEN_R_DECODE_FAILURE); diff --git a/crypto/trust_token/voprf.c b/crypto/trust_token/voprf.c index ea7c193182..e62aca59b5 100644 --- a/crypto/trust_token/voprf.c +++ b/crypto/trust_token/voprf.c @@ -95,7 +95,7 @@ static int cbs_get_point(CBS *cbs, const EC_GROUP *group, EC_AFFINE *out) { static int scalar_to_cbb(CBB *out, const EC_GROUP *group, const EC_SCALAR *scalar) { uint8_t *buf; - size_t scalar_len = BN_num_bytes(&group->order); + size_t scalar_len = BN_num_bytes(EC_GROUP_get0_order(group)); if (!CBB_add_space(out, &buf, scalar_len)) { return 0; } @@ -104,7 +104,7 @@ static int scalar_to_cbb(CBB *out, const EC_GROUP *group, } static int scalar_from_cbs(CBS *cbs, const EC_GROUP *group, EC_SCALAR *out) { - size_t scalar_len = BN_num_bytes(&group->order); + size_t scalar_len = BN_num_bytes(EC_GROUP_get0_order(group)); CBS tmp; if (!CBS_get_bytes(cbs, &tmp, scalar_len)) { OPENSSL_PUT_ERROR(TRUST_TOKEN, TRUST_TOKEN_R_DECODE_FAILURE); diff --git a/crypto/x509/by_dir.c b/crypto/x509/by_dir.c index a9236fe3c8..8e55edff47 100644 --- a/crypto/x509/by_dir.c +++ b/crypto/x509/by_dir.c @@ -81,7 +81,6 @@ typedef struct lookup_dir_entry_st { } BY_DIR_ENTRY; typedef struct lookup_dir_st { - BUF_MEM *buffer; STACK_OF(BY_DIR_ENTRY) *dirs; } BY_DIR; @@ -141,10 +140,6 @@ static int new_dir(X509_LOOKUP *lu) { if ((a = (BY_DIR *)OPENSSL_malloc(sizeof(BY_DIR))) == NULL) { return 0; } - if ((a->buffer = BUF_MEM_new()) == NULL) { - OPENSSL_free(a); - return 0; - } a->dirs = NULL; lu->method_data = a; return 1; @@ -175,7 +170,6 @@ static void free_dir(X509_LOOKUP *lu) { BY_DIR *a = lu->method_data; if (a != NULL) { sk_BY_DIR_ENTRY_pop_free(a->dirs, by_dir_entry_free); - BUF_MEM_free(a->buffer); OPENSSL_free(a); } } diff --git a/crypto/x509/internal.h b/crypto/x509/internal.h index e6eff5a676..420e61cd67 100644 --- a/crypto/x509/internal.h +++ b/crypto/x509/internal.h @@ -275,7 +275,6 @@ struct x509_lookup_method_st { // function is then called to actually check the cert chain. struct x509_store_st { // The following is a cache of trusted certs - int cache; // if true, stash any hits STACK_OF(X509_OBJECT) *objs; // Cache of all objects CRYPTO_MUTEX objs_lock; diff --git a/crypto/x509/x509_lu.c b/crypto/x509/x509_lu.c index 4d859807db..e32aab7838 100644 --- a/crypto/x509/x509_lu.c +++ b/crypto/x509/x509_lu.c @@ -173,7 +173,6 @@ X509_STORE *X509_STORE_new(void) { if (ret->objs == NULL) { goto err; } - ret->cache = 1; ret->get_cert_methods = sk_X509_LOOKUP_new_null(); if (ret->get_cert_methods == NULL) { goto err; diff --git a/go.mod b/go.mod index 1e0fedd7e5..48feb26be3 100644 --- a/go.mod +++ b/go.mod @@ -3,13 +3,13 @@ module boringssl.googlesource.com/boringssl go 1.18 require ( - golang.org/x/crypto v0.6.0 - golang.org/x/net v0.7.0 + golang.org/x/crypto v0.10.0 + golang.org/x/net v0.11.0 ) require github.com/ethereum/go-ethereum v1.11.5 require ( - golang.org/x/sys v0.5.0 // indirect - golang.org/x/term v0.5.0 // indirect + golang.org/x/sys v0.9.0 // indirect + golang.org/x/term v0.9.0 // indirect ) diff --git a/go.sum b/go.sum index 8e1aae3731..05c20c32f2 100644 --- a/go.sum +++ b/go.sum @@ -1,10 +1,8 @@ -github.com/ethereum/go-ethereum v1.11.5 h1:3M1uan+LAUvdn+7wCEFrcMM4LJTeuxDrPTg/f31a5QQ= -github.com/ethereum/go-ethereum v1.11.5/go.mod h1:it7x0DWnTDMfVFdXcU6Ti4KEFQynLHVRarcSlPr0HBo= -golang.org/x/crypto v0.6.0 h1:qfktjS5LUO+fFKeJXZ+ikTRijMmljikvG68fpMMruSc= -golang.org/x/crypto v0.6.0/go.mod h1:OFC/31mSvZgRz0V1QTNCzfAI1aIRzbiufJtkMIlEp58= -golang.org/x/net v0.7.0 h1:rJrUqqhjsgNp7KqAIc25s9pZnjU7TUcSY7HcVZjdn1g= -golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= -golang.org/x/sys v0.5.0 h1:MUK/U/4lj1t1oPg0HfuXDN/Z1wv31ZJ/YcPiGccS4DU= -golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/term v0.5.0 h1:n2a8QNdAb0sZNpU9R1ALUXBbY+w51fCQDN+7EdxNBsY= -golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= +golang.org/x/crypto v0.10.0 h1:LKqV2xt9+kDzSTfOhx4FrkEBcMrAgHSYgzywV9zcGmM= +golang.org/x/crypto v0.10.0/go.mod h1:o4eNf7Ede1fv+hwOwZsTHl9EsPFO6q6ZvYR8vYfY45I= +golang.org/x/net v0.11.0 h1:Gi2tvZIJyBtO9SDr1q9h5hEQCp/4L2RQ+ar0qjx2oNU= +golang.org/x/net v0.11.0/go.mod h1:2L/ixqYpgIVXmeoSA/4Lu7BzTG4KIyPIryS4IsOd1oQ= +golang.org/x/sys v0.9.0 h1:KS/R3tvhPqvJvwcKfnBHJwwthS11LRhmM5D59eEXa0s= +golang.org/x/sys v0.9.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/term v0.9.0 h1:GRRCnKYhdQrD8kfRAdQ6Zcw1P0OcELxGLKJvtjVMZ28= +golang.org/x/term v0.9.0/go.mod h1:M6DEAAIenWoTxdKrOltXcmDY3rSplQUkrvaDU5FcQyo= diff --git a/include/openssl/ssl.h b/include/openssl/ssl.h index 1aa56de80a..4d21da5f0f 100644 --- a/include/openssl/ssl.h +++ b/include/openssl/ssl.h @@ -5601,16 +5601,7 @@ OPENSSL_EXPORT int SSL_CTX_set_tlsext_status_arg(SSL_CTX *ctx, void *arg); SSL_R_TLSV1_ALERT_BAD_CERTIFICATE_HASH_VALUE #define SSL_R_TLSV1_CERTIFICATE_REQUIRED SSL_R_TLSV1_ALERT_CERTIFICATE_REQUIRED -// The following symbols are compatibility aliases for equivalent functions that -// use the newer "group" terminology. New code should use the new functions for -// consistency, but we do not plan to remove these aliases. -#define SSL_CTX_set1_curves SSL_CTX_set1_groups -#define SSL_set1_curves SSL_set1_groups -#define SSL_CTX_set1_curves_list SSL_CTX_set1_groups_list -#define SSL_set1_curves_list SSL_set1_groups_list -#define SSL_get_curve_id SSL_get_group_id -#define SSL_get_curve_name SSL_get_group_name -#define SSL_get_all_curve_names SSL_get_all_group_names +// The following symbols are compatibility aliases for |SSL_GROUP_*|. #define SSL_CURVE_SECP224R1 SSL_GROUP_SECP224R1 #define SSL_CURVE_SECP256R1 SSL_GROUP_SECP256R1 #define SSL_CURVE_SECP384R1 SSL_GROUP_SECP384R1 @@ -5619,6 +5610,29 @@ OPENSSL_EXPORT int SSL_CTX_set_tlsext_status_arg(SSL_CTX *ctx, void *arg); #define SSL_CURVE_SECP256R1_KYBER768_DRAFT00 SSL_GROUP_SECP256R1_KYBER768_DRAFT00 #define SSL_CURVE_X25519_KYBER768_DRAFT00 SSL_GROUP_X25519_KYBER768_DRAFT00 +// SSL_get_curve_id calls |SSL_get_group_id|. +OPENSSL_EXPORT uint16_t SSL_get_curve_id(const SSL *ssl); + +// SSL_get_curve_name calls |SSL_get_group_name|. +OPENSSL_EXPORT const char *SSL_get_curve_name(uint16_t curve_id); + +// SSL_get_all_curve_names calls |SSL_get_all_group_names|. +OPENSSL_EXPORT size_t SSL_get_all_curve_names(const char **out, size_t max_out); + +// SSL_CTX_set1_curves calls |SSL_CTX_set1_groups|. +OPENSSL_EXPORT int SSL_CTX_set1_curves(SSL_CTX *ctx, const int *curves, + size_t num_curves); + +// SSL_set1_curves calls |SSL_set1_groups|. +OPENSSL_EXPORT int SSL_set1_curves(SSL *ssl, const int *curves, + size_t num_curves); + +// SSL_CTX_set1_curves_list calls |SSL_CTX_set1_groups_list|. +OPENSSL_EXPORT int SSL_CTX_set1_curves_list(SSL_CTX *ctx, const char *curves); + +// SSL_set1_curves_list calls |SSL_set1_groups_list|. +OPENSSL_EXPORT int SSL_set1_curves_list(SSL *ssl, const char *curves); + // Nodejs compatibility section (hidden). // @@ -5725,6 +5739,7 @@ OPENSSL_EXPORT int SSL_CTX_set_tlsext_status_arg(SSL_CTX *ctx, void *arg); #define SSL_CTX_sess_set_cache_size SSL_CTX_sess_set_cache_size #define SSL_CTX_set0_chain SSL_CTX_set0_chain #define SSL_CTX_set1_chain SSL_CTX_set1_chain +#define SSL_CTX_set1_curves SSL_CTX_set1_curves #define SSL_CTX_set1_groups SSL_CTX_set1_groups #define SSL_CTX_set_max_cert_list SSL_CTX_set_max_cert_list #define SSL_CTX_set_max_send_fragment SSL_CTX_set_max_send_fragment @@ -5760,6 +5775,7 @@ OPENSSL_EXPORT int SSL_CTX_set_tlsext_status_arg(SSL_CTX *ctx, void *arg); #define SSL_session_reused SSL_session_reused #define SSL_set0_chain SSL_set0_chain #define SSL_set1_chain SSL_set1_chain +#define SSL_set1_curves SSL_set1_curves #define SSL_set1_groups SSL_set1_groups #define SSL_set_max_cert_list SSL_set_max_cert_list #define SSL_set_max_send_fragment SSL_set_max_send_fragment diff --git a/ssl/ssl_lib.cc b/ssl/ssl_lib.cc index d93b345c74..909d3a97ed 100644 --- a/ssl/ssl_lib.cc +++ b/ssl/ssl_lib.cc @@ -3267,3 +3267,29 @@ int SSL_CTX_set_tlsext_status_arg(SSL_CTX *ctx, void *arg) { ctx->legacy_ocsp_callback_arg = arg; return 1; } + +uint16_t SSL_get_curve_id(const SSL *ssl) { return SSL_get_group_id(ssl); } + +const char *SSL_get_curve_name(uint16_t curve_id) { + return SSL_get_group_name(curve_id); +} + +size_t SSL_get_all_curve_names(const char **out, size_t max_out) { + return SSL_get_all_group_names(out, max_out); +} + +int SSL_CTX_set1_curves(SSL_CTX *ctx, const int *curves, size_t num_curves) { + return SSL_CTX_set1_groups(ctx, curves, num_curves); +} + +int SSL_set1_curves(SSL *ssl, const int *curves, size_t num_curves) { + return SSL_set1_groups(ssl, curves, num_curves); +} + +int SSL_CTX_set1_curves_list(SSL_CTX *ctx, const char *curves) { + return SSL_CTX_set1_groups_list(ctx, curves); +} + +int SSL_set1_curves_list(SSL *ssl, const char *curves) { + return SSL_set1_groups_list(ssl, curves); +} diff --git a/ssl/test/runner/handshake_client.go b/ssl/test/runner/handshake_client.go index a972d3ebea..bb59581e6f 100644 --- a/ssl/test/runner/handshake_client.go +++ b/ssl/test/runner/handshake_client.go @@ -21,6 +21,7 @@ import ( "time" "boringssl.googlesource.com/boringssl/ssl/test/runner/hpke" + "golang.org/x/crypto/cryptobyte" ) const echBadPayloadByte = 0xff @@ -71,9 +72,12 @@ func replaceClientHello(hello *clientHelloMsg, in []byte) (*clientHelloMsg, erro // Replace |newHellos|'s key shares with those of |hello|. For simplicity, // we require their lengths match, which is satisfied by matching the // DefaultCurves setting to the selection in the replacement ClientHello. - bb := newByteBuilder() + bb := cryptobyte.NewBuilder(nil) hello.marshalKeyShares(bb) - keyShares := bb.finish() + keyShares, err := bb.Bytes() + if err != nil { + return nil, err + } if len(keyShares) != len(newHello.keySharesRaw) { return nil, errors.New("tls: ClientHello key share length is inconsistent with DefaultCurves setting") } diff --git a/ssl/test/runner/handshake_messages.go b/ssl/test/runner/handshake_messages.go index b253b0c849..6ea7faaa85 100644 --- a/ssl/test/runner/handshake_messages.go +++ b/ssl/test/runner/handshake_messages.go @@ -5,239 +5,49 @@ package runner import ( - "encoding/binary" "errors" "fmt" -) - -func writeLen(buf []byte, v, size int) { - for i := 0; i < size; i++ { - buf[size-i-1] = byte(v) - v >>= 8 - } - if v != 0 { - panic("length is too long") - } -} - -type byteBuilder struct { - buf *[]byte - start int - prefixLen int - child *byteBuilder -} - -func newByteBuilder() *byteBuilder { - buf := make([]byte, 0, 32) - return &byteBuilder{buf: &buf} -} - -func (bb *byteBuilder) len() int { - return len(*bb.buf) - bb.start - bb.prefixLen -} - -func (bb *byteBuilder) data() []byte { - bb.flush() - return (*bb.buf)[bb.start+bb.prefixLen:] -} - -func (bb *byteBuilder) flush() { - if bb.child == nil { - return - } - bb.child.flush() - writeLen((*bb.buf)[bb.child.start:], bb.child.len(), bb.child.prefixLen) - bb.child = nil - return -} - -func (bb *byteBuilder) finish() []byte { - bb.flush() - return *bb.buf -} - -func (bb *byteBuilder) addU8(u uint8) { - bb.flush() - *bb.buf = append(*bb.buf, u) -} - -func (bb *byteBuilder) addU16(u uint16) { - bb.flush() - *bb.buf = append(*bb.buf, byte(u>>8), byte(u)) -} - -func (bb *byteBuilder) addU24(u int) { - bb.flush() - *bb.buf = append(*bb.buf, byte(u>>16), byte(u>>8), byte(u)) -} - -func (bb *byteBuilder) addU32(u uint32) { - bb.flush() - *bb.buf = append(*bb.buf, byte(u>>24), byte(u>>16), byte(u>>8), byte(u)) -} - -func (bb *byteBuilder) addU64(u uint64) { - bb.flush() - var b [8]byte - binary.BigEndian.PutUint64(b[:], u) - *bb.buf = append(*bb.buf, b[:]...) -} - -func (bb *byteBuilder) addU8LengthPrefixed() *byteBuilder { - return bb.createChild(1) -} - -func (bb *byteBuilder) addU16LengthPrefixed() *byteBuilder { - return bb.createChild(2) -} - -func (bb *byteBuilder) addU24LengthPrefixed() *byteBuilder { - return bb.createChild(3) -} - -func (bb *byteBuilder) addU32LengthPrefixed() *byteBuilder { - return bb.createChild(4) -} - -func (bb *byteBuilder) addBytes(b []byte) { - bb.flush() - *bb.buf = append(*bb.buf, b...) -} - -func (bb *byteBuilder) createChild(lengthPrefixSize int) *byteBuilder { - bb.flush() - bb.child = &byteBuilder{ - buf: bb.buf, - start: len(*bb.buf), - prefixLen: lengthPrefixSize, - } - for i := 0; i < lengthPrefixSize; i++ { - *bb.buf = append(*bb.buf, 0) - } - return bb.child -} -func (bb *byteBuilder) discardChild() { - if bb.child == nil { - return - } - *bb.buf = (*bb.buf)[:bb.child.start] - bb.child = nil -} - -type byteReader []byte - -func (br *byteReader) readInternal(out *byteReader, n int) bool { - if len(*br) < n { - return false - } - *out = (*br)[:n] - *br = (*br)[n:] - return true -} - -func (br *byteReader) readBytes(out *[]byte, n int) bool { - var child byteReader - if !br.readInternal(&child, n) { - return false - } - *out = []byte(child) - return true -} - -func (br *byteReader) readUint(out *uint64, n int) bool { - var b []byte - if !br.readBytes(&b, n) { - return false - } - *out = 0 - for _, v := range b { - *out <<= 8 - *out |= uint64(v) - } - return true -} - -func (br *byteReader) readU8(out *uint8) bool { - var b []byte - if !br.readBytes(&b, 1) { - return false - } - *out = b[0] - return true -} + "golang.org/x/crypto/cryptobyte" +) -func (br *byteReader) readU16(out *uint16) bool { - var v uint64 - if !br.readUint(&v, 2) { +func readUint8LengthPrefixedBytes(s *cryptobyte.String, out *[]byte) bool { + var child cryptobyte.String + if !s.ReadUint8LengthPrefixed(&child) { return false } - *out = uint16(v) + *out = child return true } -func (br *byteReader) readU24(out *uint32) bool { - var v uint64 - if !br.readUint(&v, 3) { +func readUint16LengthPrefixedBytes(s *cryptobyte.String, out *[]byte) bool { + var child cryptobyte.String + if !s.ReadUint16LengthPrefixed(&child) { return false } - *out = uint32(v) + *out = child return true } -func (br *byteReader) readU32(out *uint32) bool { - var v uint64 - if !br.readUint(&v, 4) { +func readUint24LengthPrefixedBytes(s *cryptobyte.String, out *[]byte) bool { + var child cryptobyte.String + if !s.ReadUint24LengthPrefixed(&child) { return false } - *out = uint32(v) + *out = child return true } -func (br *byteReader) readU64(out *uint64) bool { - return br.readUint(out, 8) +func addUint8LengthPrefixedBytes(b *cryptobyte.Builder, v []byte) { + b.AddUint8LengthPrefixed(func(child *cryptobyte.Builder) { child.AddBytes(v) }) } -func (br *byteReader) readLengthPrefixed(out *byteReader, n int) bool { - var length uint64 - return br.readUint(&length, n) && - uint64(len(*br)) >= length && - br.readInternal(out, int(length)) +func addUint16LengthPrefixedBytes(b *cryptobyte.Builder, v []byte) { + b.AddUint16LengthPrefixed(func(child *cryptobyte.Builder) { child.AddBytes(v) }) } -func (br *byteReader) readLengthPrefixedBytes(out *[]byte, n int) bool { - var length uint64 - return br.readUint(&length, n) && - uint64(len(*br)) >= length && - br.readBytes(out, int(length)) -} - -func (br *byteReader) readU8LengthPrefixed(out *byteReader) bool { - return br.readLengthPrefixed(out, 1) -} -func (br *byteReader) readU8LengthPrefixedBytes(out *[]byte) bool { - return br.readLengthPrefixedBytes(out, 1) -} - -func (br *byteReader) readU16LengthPrefixed(out *byteReader) bool { - return br.readLengthPrefixed(out, 2) -} -func (br *byteReader) readU16LengthPrefixedBytes(out *[]byte) bool { - return br.readLengthPrefixedBytes(out, 2) -} - -func (br *byteReader) readU24LengthPrefixed(out *byteReader) bool { - return br.readLengthPrefixed(out, 3) -} -func (br *byteReader) readU24LengthPrefixedBytes(out *[]byte) bool { - return br.readLengthPrefixedBytes(out, 3) -} - -func (br *byteReader) readU32LengthPrefixed(out *byteReader) bool { - return br.readLengthPrefixed(out, 4) -} -func (br *byteReader) readU32LengthPrefixedBytes(out *[]byte) bool { - return br.readLengthPrefixedBytes(out, 4) +func addUint24LengthPrefixedBytes(b *cryptobyte.Builder, v []byte) { + b.AddUint24LengthPrefixed(func(child *cryptobyte.Builder) { child.AddBytes(v) }) } type keyShareEntry struct { @@ -269,48 +79,52 @@ type ECHConfig struct { } func CreateECHConfig(template *ECHConfig) *ECHConfig { - bb := newByteBuilder() + bb := cryptobyte.NewBuilder(nil) // ECHConfig reuses the encrypted_client_hello extension codepoint as a // version identifier. - bb.addU16(extensionEncryptedClientHello) - contents := bb.addU16LengthPrefixed() - contents.addU8(template.ConfigID) - contents.addU16(template.KEM) - contents.addU16LengthPrefixed().addBytes(template.PublicKey) - cipherSuites := contents.addU16LengthPrefixed() - for _, suite := range template.CipherSuites { - cipherSuites.addU16(suite.KDF) - cipherSuites.addU16(suite.AEAD) - } - contents.addU8(template.MaxNameLen) - contents.addU8LengthPrefixed().addBytes([]byte(template.PublicName)) - extensions := contents.addU16LengthPrefixed() - // Mandatory extensions have the high bit set. - if template.UnsupportedExtension { - extensions.addU16(0x1111) - extensions.addU16LengthPrefixed().addBytes([]byte("test")) - } - if template.UnsupportedMandatoryExtension { - extensions.addU16(0xaaaa) - extensions.addU16LengthPrefixed().addBytes([]byte("test")) - } - - // This ought to be a call to a function like ParseECHConfig(bb.finish()), + bb.AddUint16(extensionEncryptedClientHello) + bb.AddUint16LengthPrefixed(func(contents *cryptobyte.Builder) { + contents.AddUint8(template.ConfigID) + contents.AddUint16(template.KEM) + addUint16LengthPrefixedBytes(contents, template.PublicKey) + contents.AddUint16LengthPrefixed(func(cipherSuites *cryptobyte.Builder) { + for _, suite := range template.CipherSuites { + cipherSuites.AddUint16(suite.KDF) + cipherSuites.AddUint16(suite.AEAD) + } + }) + contents.AddUint8(template.MaxNameLen) + addUint8LengthPrefixedBytes(contents, []byte(template.PublicName)) + contents.AddUint16LengthPrefixed(func(extensions *cryptobyte.Builder) { + // Mandatory extensions have the high bit set. + if template.UnsupportedExtension { + extensions.AddUint16(0x1111) + addUint16LengthPrefixedBytes(extensions, []byte("test")) + } + if template.UnsupportedMandatoryExtension { + extensions.AddUint16(0xaaaa) + addUint16LengthPrefixedBytes(extensions, []byte("test")) + } + }) + }) + + // This ought to be a call to a function like ParseECHConfig(bb.BytesOrPanic()), // but this constrains us to constructing ECHConfigs we are willing to // support. We need to test the client's behavior in response to unparsable // or unsupported ECHConfigs, so populate fields from the template directly. ret := *template - ret.Raw = bb.finish() + ret.Raw = bb.BytesOrPanic() return &ret } func CreateECHConfigList(configs ...[]byte) []byte { - bb := newByteBuilder() - list := bb.addU16LengthPrefixed() - for _, config := range configs { - list.addBytes(config) - } - return bb.finish() + bb := cryptobyte.NewBuilder(nil) + bb.AddUint16LengthPrefixed(func(list *cryptobyte.Builder) { + for _, config := range configs { + list.AddBytes(config) + } + }) + return bb.BytesOrPanic() } type ServerECHConfig struct { @@ -392,16 +206,16 @@ type clientHelloMsg struct { rawExtensions []byte } -func (m *clientHelloMsg) marshalKeyShares(bb *byteBuilder) { - keyShares := bb.addU16LengthPrefixed() - for _, keyShare := range m.keyShares { - keyShares.addU16(uint16(keyShare.group)) - keyExchange := keyShares.addU16LengthPrefixed() - keyExchange.addBytes(keyShare.keyExchange) - } - if m.trailingKeyShareData { - keyShares.addU8(0) - } +func (m *clientHelloMsg) marshalKeyShares(bb *cryptobyte.Builder) { + bb.AddUint16LengthPrefixed(func(keyShares *cryptobyte.Builder) { + for _, keyShare := range m.keyShares { + keyShares.AddUint16(uint16(keyShare.group)) + addUint16LengthPrefixedBytes(keyShares, keyShare.keyExchange) + } + if m.trailingKeyShareData { + keyShares.AddUint8(0) + } + }) } type clientHelloType int @@ -411,23 +225,27 @@ const ( clientHelloEncodedInner ) -func (m *clientHelloMsg) marshalBody(hello *byteBuilder, typ clientHelloType) { - hello.addU16(m.vers) - hello.addBytes(m.random) - sessionID := hello.addU8LengthPrefixed() - if typ != clientHelloEncodedInner { - sessionID.addBytes(m.sessionID) - } +func (m *clientHelloMsg) marshalBody(hello *cryptobyte.Builder, typ clientHelloType) { + hello.AddUint16(m.vers) + hello.AddBytes(m.random) + hello.AddUint8LengthPrefixed(func(sessionID *cryptobyte.Builder) { + if typ != clientHelloEncodedInner { + sessionID.AddBytes(m.sessionID) + } + }) if m.isDTLS { - cookie := hello.addU8LengthPrefixed() - cookie.addBytes(m.cookie) - } - cipherSuites := hello.addU16LengthPrefixed() - for _, suite := range m.cipherSuites { - cipherSuites.addU16(suite) + hello.AddUint8LengthPrefixed(func(cookie *cryptobyte.Builder) { + cookie.AddBytes(m.cookie) + }) } - compressionMethods := hello.addU8LengthPrefixed() - compressionMethods.addBytes(m.compressionMethods) + hello.AddUint16LengthPrefixed(func(cipherSuites *cryptobyte.Builder) { + for _, suite := range m.cipherSuites { + cipherSuites.AddUint16(suite) + } + }) + hello.AddUint8LengthPrefixed(func(compressionMethods *cryptobyte.Builder) { + compressionMethods.AddBytes(m.compressionMethods) + }) type extension struct { id uint16 @@ -462,99 +280,99 @@ func (m *clientHelloMsg) marshalBody(hello *byteBuilder, typ clientHelloType) { // ServerName server_name_list<1..2^16-1> // } ServerNameList; - serverNameList := newByteBuilder() - serverName := serverNameList.addU16LengthPrefixed() - serverName.addU8(0) // NameType host_name(0) - hostName := serverName.addU16LengthPrefixed() - hostName.addBytes([]byte(m.serverName)) + serverNameList := cryptobyte.NewBuilder(nil) + serverNameList.AddUint16LengthPrefixed(func(serverName *cryptobyte.Builder) { + serverName.AddUint8(0) // NameType host_name(0) + addUint16LengthPrefixedBytes(serverName, []byte(m.serverName)) + }) extensions = append(extensions, extension{ id: extensionServerName, - body: serverNameList.finish(), + body: serverNameList.BytesOrPanic(), }) } if m.echOuter != nil { - body := newByteBuilder() - body.addU8(echClientTypeOuter) - body.addU16(m.echOuter.kdfID) - body.addU16(m.echOuter.aeadID) - body.addU8(m.echOuter.configID) - body.addU16LengthPrefixed().addBytes(m.echOuter.enc) - body.addU16LengthPrefixed().addBytes(m.echOuter.payload) + body := cryptobyte.NewBuilder(nil) + body.AddUint8(echClientTypeOuter) + body.AddUint16(m.echOuter.kdfID) + body.AddUint16(m.echOuter.aeadID) + body.AddUint8(m.echOuter.configID) + addUint16LengthPrefixedBytes(body, m.echOuter.enc) + addUint16LengthPrefixedBytes(body, m.echOuter.payload) extensions = append(extensions, extension{ id: extensionEncryptedClientHello, - body: body.finish(), + body: body.BytesOrPanic(), }) } if m.echInner { - body := newByteBuilder() - body.addU8(echClientTypeInner) + body := cryptobyte.NewBuilder(nil) + body.AddUint8(echClientTypeInner) // If unset, invalidECHInner is empty, which is the correct serialization. - body.addBytes(m.invalidECHInner) + body.AddBytes(m.invalidECHInner) extensions = append(extensions, extension{ id: extensionEncryptedClientHello, - body: body.finish(), + body: body.BytesOrPanic(), }) } if m.ocspStapling { - certificateStatusRequest := newByteBuilder() + certificateStatusRequest := cryptobyte.NewBuilder(nil) // RFC 4366, section 3.6 - certificateStatusRequest.addU8(1) // OCSP type + certificateStatusRequest.AddUint8(1) // OCSP type // Two zero valued uint16s for the two lengths. - certificateStatusRequest.addU16(0) // ResponderID length - certificateStatusRequest.addU16(0) // Extensions length + certificateStatusRequest.AddUint16(0) // ResponderID length + certificateStatusRequest.AddUint16(0) // Extensions length extensions = append(extensions, extension{ id: extensionStatusRequest, - body: certificateStatusRequest.finish(), + body: certificateStatusRequest.BytesOrPanic(), }) } if len(m.supportedCurves) > 0 { // http://tools.ietf.org/html/rfc4492#section-5.1.1 - supportedCurvesList := newByteBuilder() - supportedCurves := supportedCurvesList.addU16LengthPrefixed() - for _, curve := range m.supportedCurves { - supportedCurves.addU16(uint16(curve)) - } + supportedCurvesList := cryptobyte.NewBuilder(nil) + supportedCurvesList.AddUint16LengthPrefixed(func(supportedCurves *cryptobyte.Builder) { + for _, curve := range m.supportedCurves { + supportedCurves.AddUint16(uint16(curve)) + } + }) extensions = append(extensions, extension{ id: extensionSupportedCurves, - body: supportedCurvesList.finish(), + body: supportedCurvesList.BytesOrPanic(), }) } if len(m.supportedPoints) > 0 { // http://tools.ietf.org/html/rfc4492#section-5.1.2 - supportedPointsList := newByteBuilder() - supportedPoints := supportedPointsList.addU8LengthPrefixed() - supportedPoints.addBytes(m.supportedPoints) + supportedPointsList := cryptobyte.NewBuilder(nil) + addUint8LengthPrefixedBytes(supportedPointsList, m.supportedPoints) extensions = append(extensions, extension{ id: extensionSupportedPoints, - body: supportedPointsList.finish(), + body: supportedPointsList.BytesOrPanic(), }) } if m.hasKeyShares { - keyShareList := newByteBuilder() + keyShareList := cryptobyte.NewBuilder(nil) m.marshalKeyShares(keyShareList) extensions = append(extensions, extension{ id: extensionKeyShare, - body: keyShareList.finish(), + body: keyShareList.BytesOrPanic(), }) } if len(m.pskKEModes) > 0 { - pskModesExtension := newByteBuilder() - pskModesExtension.addU8LengthPrefixed().addBytes(m.pskKEModes) + pskModesExtension := cryptobyte.NewBuilder(nil) + addUint8LengthPrefixedBytes(pskModesExtension, m.pskKEModes) extensions = append(extensions, extension{ id: extensionPSKKeyExchangeModes, - body: pskModesExtension.finish(), + body: pskModesExtension.BytesOrPanic(), }) } if m.hasEarlyData { extensions = append(extensions, extension{id: extensionEarlyData}) } if len(m.tls13Cookie) > 0 { - body := newByteBuilder() - body.addU16LengthPrefixed().addBytes(m.tls13Cookie) + body := cryptobyte.NewBuilder(nil) + addUint16LengthPrefixedBytes(body, m.tls13Cookie) extensions = append(extensions, extension{ id: extensionCookie, - body: body.finish(), + body: body.BytesOrPanic(), }) } if m.ticketSupported { @@ -566,57 +384,60 @@ func (m *clientHelloMsg) marshalBody(hello *byteBuilder, typ clientHelloType) { } if len(m.signatureAlgorithms) > 0 { // https://tools.ietf.org/html/rfc5246#section-7.4.1.4.1 - signatureAlgorithmsExtension := newByteBuilder() - signatureAlgorithms := signatureAlgorithmsExtension.addU16LengthPrefixed() - for _, sigAlg := range m.signatureAlgorithms { - signatureAlgorithms.addU16(uint16(sigAlg)) - } + signatureAlgorithmsExtension := cryptobyte.NewBuilder(nil) + signatureAlgorithmsExtension.AddUint16LengthPrefixed(func(signatureAlgorithms *cryptobyte.Builder) { + for _, sigAlg := range m.signatureAlgorithms { + signatureAlgorithms.AddUint16(uint16(sigAlg)) + } + }) extensions = append(extensions, extension{ id: extensionSignatureAlgorithms, - body: signatureAlgorithmsExtension.finish(), + body: signatureAlgorithmsExtension.BytesOrPanic(), }) } if len(m.signatureAlgorithmsCert) > 0 { - signatureAlgorithmsCertExtension := newByteBuilder() - signatureAlgorithmsCert := signatureAlgorithmsCertExtension.addU16LengthPrefixed() - for _, sigAlg := range m.signatureAlgorithmsCert { - signatureAlgorithmsCert.addU16(uint16(sigAlg)) - } + signatureAlgorithmsCertExtension := cryptobyte.NewBuilder(nil) + signatureAlgorithmsCertExtension.AddUint16LengthPrefixed(func(signatureAlgorithmsCert *cryptobyte.Builder) { + for _, sigAlg := range m.signatureAlgorithmsCert { + signatureAlgorithmsCert.AddUint16(uint16(sigAlg)) + } + }) extensions = append(extensions, extension{ id: extensionSignatureAlgorithmsCert, - body: signatureAlgorithmsCertExtension.finish(), + body: signatureAlgorithmsCertExtension.BytesOrPanic(), }) } if len(m.supportedVersions) > 0 { - supportedVersionsExtension := newByteBuilder() - supportedVersions := supportedVersionsExtension.addU8LengthPrefixed() - for _, version := range m.supportedVersions { - supportedVersions.addU16(uint16(version)) - } + supportedVersionsExtension := cryptobyte.NewBuilder(nil) + supportedVersionsExtension.AddUint8LengthPrefixed(func(supportedVersions *cryptobyte.Builder) { + for _, version := range m.supportedVersions { + supportedVersions.AddUint16(uint16(version)) + } + }) extensions = append(extensions, extension{ id: extensionSupportedVersions, - body: supportedVersionsExtension.finish(), + body: supportedVersionsExtension.BytesOrPanic(), }) } if m.secureRenegotiation != nil { - secureRenegoExt := newByteBuilder() - secureRenegoExt.addU8LengthPrefixed().addBytes(m.secureRenegotiation) + secureRenegoExt := cryptobyte.NewBuilder(nil) + addUint8LengthPrefixedBytes(secureRenegoExt, m.secureRenegotiation) extensions = append(extensions, extension{ id: extensionRenegotiationInfo, - body: secureRenegoExt.finish(), + body: secureRenegoExt.BytesOrPanic(), }) } if len(m.alpnProtocols) > 0 { // https://tools.ietf.org/html/rfc7301#section-3.1 - alpnExtension := newByteBuilder() - protocolNameList := alpnExtension.addU16LengthPrefixed() - for _, s := range m.alpnProtocols { - protocolName := protocolNameList.addU8LengthPrefixed() - protocolName.addBytes([]byte(s)) - } + alpnExtension := cryptobyte.NewBuilder(nil) + alpnExtension.AddUint16LengthPrefixed(func(protocolNameList *cryptobyte.Builder) { + for _, s := range m.alpnProtocols { + addUint8LengthPrefixedBytes(protocolNameList, []byte(s)) + } + }) extensions = append(extensions, extension{ id: extensionALPN, - body: alpnExtension.finish(), + body: alpnExtension.BytesOrPanic(), }) } if len(m.quicTransportParams) > 0 { @@ -644,18 +465,18 @@ func (m *clientHelloMsg) marshalBody(hello *byteBuilder, typ clientHelloType) { } if len(m.srtpProtectionProfiles) > 0 { // https://tools.ietf.org/html/rfc5764#section-4.1.1 - useSrtpExt := newByteBuilder() + useSrtpExt := cryptobyte.NewBuilder(nil) - srtpProtectionProfiles := useSrtpExt.addU16LengthPrefixed() - for _, p := range m.srtpProtectionProfiles { - srtpProtectionProfiles.addU16(p) - } - srtpMki := useSrtpExt.addU8LengthPrefixed() - srtpMki.addBytes([]byte(m.srtpMasterKeyIdentifier)) + useSrtpExt.AddUint16LengthPrefixed(func(srtpProtectionProfiles *cryptobyte.Builder) { + for _, p := range m.srtpProtectionProfiles { + srtpProtectionProfiles.AddUint16(p) + } + }) + addUint8LengthPrefixedBytes(useSrtpExt, []byte(m.srtpMasterKeyIdentifier)) extensions = append(extensions, extension{ id: extensionUseSRTP, - body: useSrtpExt.finish(), + body: useSrtpExt.BytesOrPanic(), }) } if m.sctListSupported { @@ -668,130 +489,138 @@ func (m *clientHelloMsg) marshalBody(hello *byteBuilder, typ clientHelloType) { }) } if len(m.compressedCertAlgs) > 0 { - body := newByteBuilder() - algIDs := body.addU8LengthPrefixed() - for _, v := range m.compressedCertAlgs { - algIDs.addU16(v) - } + body := cryptobyte.NewBuilder(nil) + body.AddUint8LengthPrefixed(func(algIDs *cryptobyte.Builder) { + for _, v := range m.compressedCertAlgs { + algIDs.AddUint16(v) + } + }) extensions = append(extensions, extension{ id: extensionCompressedCertAlgs, - body: body.finish(), + body: body.BytesOrPanic(), }) } if m.delegatedCredentials { - body := newByteBuilder() - signatureSchemeList := body.addU16LengthPrefixed() - for _, sigAlg := range m.signatureAlgorithms { - signatureSchemeList.addU16(uint16(sigAlg)) - } + body := cryptobyte.NewBuilder(nil) + body.AddUint16LengthPrefixed(func(signatureSchemeList *cryptobyte.Builder) { + for _, sigAlg := range m.signatureAlgorithms { + signatureSchemeList.AddUint16(uint16(sigAlg)) + } + }) extensions = append(extensions, extension{ id: extensionDelegatedCredentials, - body: body.finish(), + body: body.BytesOrPanic(), }) } if len(m.alpsProtocols) > 0 { - body := newByteBuilder() - protocolNameList := body.addU16LengthPrefixed() - for _, s := range m.alpsProtocols { - protocolNameList.addU8LengthPrefixed().addBytes([]byte(s)) - } + body := cryptobyte.NewBuilder(nil) + body.AddUint16LengthPrefixed(func(protocolNameList *cryptobyte.Builder) { + for _, s := range m.alpsProtocols { + addUint8LengthPrefixedBytes(protocolNameList, []byte(s)) + } + }) extensions = append(extensions, extension{ id: extensionApplicationSettings, - body: body.finish(), + body: body.BytesOrPanic(), }) } // The PSK extension must be last. See https://tools.ietf.org/html/rfc8446#section-4.2.11 if len(m.pskIdentities) > 0 { - pskExtension := newByteBuilder() - pskIdentities := pskExtension.addU16LengthPrefixed() - for _, psk := range m.pskIdentities { - pskIdentities.addU16LengthPrefixed().addBytes(psk.ticket) - pskIdentities.addU32(psk.obfuscatedTicketAge) - } - pskBinders := pskExtension.addU16LengthPrefixed() - for _, binder := range m.pskBinders { - pskBinders.addU8LengthPrefixed().addBytes(binder) - } + pskExtension := cryptobyte.NewBuilder(nil) + pskExtension.AddUint16LengthPrefixed(func(pskIdentities *cryptobyte.Builder) { + for _, psk := range m.pskIdentities { + addUint16LengthPrefixedBytes(pskIdentities, psk.ticket) + pskIdentities.AddUint32(psk.obfuscatedTicketAge) + } + }) + pskExtension.AddUint16LengthPrefixed(func(pskBinders *cryptobyte.Builder) { + for _, binder := range m.pskBinders { + addUint8LengthPrefixedBytes(pskBinders, binder) + } + }) extensions = append(extensions, extension{ id: extensionPreSharedKey, - body: pskExtension.finish(), + body: pskExtension.BytesOrPanic(), }) } - extensionsBB := hello.addU16LengthPrefixed() - extMap := make(map[uint16][]byte) - extsWritten := make(map[uint16]struct{}) - for _, ext := range extensions { - extMap[ext.id] = ext.body - } - // Write each of the prefix extensions, if we have it. - for _, extID := range m.prefixExtensions { - if body, ok := extMap[extID]; ok { - extensionsBB.addU16(extID) - extensionsBB.addU16LengthPrefixed().addBytes(body) - extsWritten[extID] = struct{}{} - } + if m.omitExtensions { + return } - // Write outer extensions, possibly in compressed form. - if m.outerExtensions != nil { - if typ == clientHelloEncodedInner && !m.reorderOuterExtensionsWithoutCompressing { - extensionsBB.addU16(extensionECHOuterExtensions) - list := extensionsBB.addU16LengthPrefixed().addU8LengthPrefixed() - for _, extID := range m.outerExtensions { - list.addU16(extID) + hello.AddUint16LengthPrefixed(func(extensionsBB *cryptobyte.Builder) { + if m.emptyExtensions { + return + } + extMap := make(map[uint16][]byte) + extsWritten := make(map[uint16]struct{}) + for _, ext := range extensions { + extMap[ext.id] = ext.body + } + // Write each of the prefix extensions, if we have it. + for _, extID := range m.prefixExtensions { + if body, ok := extMap[extID]; ok { + extensionsBB.AddUint16(extID) + addUint16LengthPrefixedBytes(extensionsBB, body) extsWritten[extID] = struct{}{} } - } else { - for _, extID := range m.outerExtensions { - // m.outerExtensions may intentionally contain duplicates to test the - // server's reaction. If m.reorderOuterExtensionsWithoutCompressing - // is set, we are targetting the second ClientHello and wish to send a - // valid first ClientHello. In that case, deduplicate so the error - // only appears later. - if _, written := extsWritten[extID]; m.reorderOuterExtensionsWithoutCompressing && written { - continue - } - if body, ok := extMap[extID]; ok { - extensionsBB.addU16(extID) - extensionsBB.addU16LengthPrefixed().addBytes(body) - extsWritten[extID] = struct{}{} + } + // Write outer extensions, possibly in compressed form. + if m.outerExtensions != nil { + if typ == clientHelloEncodedInner && !m.reorderOuterExtensionsWithoutCompressing { + extensionsBB.AddUint16(extensionECHOuterExtensions) + extensionsBB.AddUint16LengthPrefixed(func(child *cryptobyte.Builder) { + child.AddUint8LengthPrefixed(func(list *cryptobyte.Builder) { + for _, extID := range m.outerExtensions { + list.AddUint16(extID) + extsWritten[extID] = struct{}{} + } + }) + }) + } else { + for _, extID := range m.outerExtensions { + // m.outerExtensions may intentionally contain duplicates to test the + // server's reaction. If m.reorderOuterExtensionsWithoutCompressing + // is set, we are targetting the second ClientHello and wish to send a + // valid first ClientHello. In that case, deduplicate so the error + // only appears later. + if _, written := extsWritten[extID]; m.reorderOuterExtensionsWithoutCompressing && written { + continue + } + if body, ok := extMap[extID]; ok { + extensionsBB.AddUint16(extID) + addUint16LengthPrefixedBytes(extensionsBB, body) + extsWritten[extID] = struct{}{} + } } } } - } - - // Write each of the remaining extensions in their original order. - for _, ext := range extensions { - if _, written := extsWritten[ext.id]; !written { - extensionsBB.addU16(ext.id) - extensionsBB.addU16LengthPrefixed().addBytes(ext.body) - } - } - if m.pad != 0 && hello.len()%m.pad != 0 { - extensionsBB.addU16(extensionPadding) - padding := extensionsBB.addU16LengthPrefixed() - // Note hello.len() has changed at this point from the length - // prefix. - if l := hello.len() % m.pad; l != 0 { - padding.addBytes(make([]byte, m.pad-l)) + // Write each of the remaining extensions in their original order. + for _, ext := range extensions { + if _, written := extsWritten[ext.id]; !written { + extensionsBB.AddUint16(ext.id) + addUint16LengthPrefixedBytes(extensionsBB, ext.body) + } } - } - if m.omitExtensions || m.emptyExtensions { - // Silently erase any extensions which were sent. - hello.discardChild() - if m.emptyExtensions { - hello.addU16(0) + if m.pad != 0 && len(hello.BytesOrPanic())%m.pad != 0 { + extensionsBB.AddUint16(extensionPadding) + extensionsBB.AddUint16LengthPrefixed(func(padding *cryptobyte.Builder) { + // Note hello.len() has changed at this point from the length + // prefix. + if l := len(hello.BytesOrPanic()) % m.pad; l != 0 { + padding.AddBytes(make([]byte, m.pad-l)) + } + }) } - } + }) } func (m *clientHelloMsg) marshalForEncodedInner() []byte { - hello := newByteBuilder() + hello := cryptobyte.NewBuilder(nil) m.marshalBody(hello, clientHelloEncodedInner) - return hello.finish() + return hello.BytesOrPanic() } func (m *clientHelloMsg) marshal() []byte { @@ -800,26 +629,27 @@ func (m *clientHelloMsg) marshal() []byte { } if m.isV2ClientHello { - v2Msg := newByteBuilder() - v2Msg.addU8(1) - v2Msg.addU16(m.vers) - v2Msg.addU16(uint16(len(m.cipherSuites) * 3)) - v2Msg.addU16(uint16(len(m.sessionID))) - v2Msg.addU16(uint16(len(m.v2Challenge))) + v2Msg := cryptobyte.NewBuilder(nil) + v2Msg.AddUint8(1) + v2Msg.AddUint16(m.vers) + v2Msg.AddUint16(uint16(len(m.cipherSuites) * 3)) + v2Msg.AddUint16(uint16(len(m.sessionID))) + v2Msg.AddUint16(uint16(len(m.v2Challenge))) for _, spec := range m.cipherSuites { - v2Msg.addU24(int(spec)) + v2Msg.AddUint24(uint32(spec)) } - v2Msg.addBytes(m.sessionID) - v2Msg.addBytes(m.v2Challenge) - m.raw = v2Msg.finish() + v2Msg.AddBytes(m.sessionID) + v2Msg.AddBytes(m.v2Challenge) + m.raw = v2Msg.BytesOrPanic() return m.raw } - handshakeMsg := newByteBuilder() - handshakeMsg.addU8(typeClientHello) - hello := handshakeMsg.addU24LengthPrefixed() - m.marshalBody(hello, clientHelloNormal) - m.raw = handshakeMsg.finish() + handshakeMsg := cryptobyte.NewBuilder(nil) + handshakeMsg.AddUint8(typeClientHello) + handshakeMsg.AddUint24LengthPrefixed(func(hello *cryptobyte.Builder) { + m.marshalBody(hello, clientHelloNormal) + }) + m.raw = handshakeMsg.BytesOrPanic() // Sanity-check padding. if m.pad != 0 && (len(m.raw)-4)%m.pad != 0 { panic(fmt.Sprintf("%d is not a multiple of %d", len(m.raw)-4, m.pad)) @@ -827,9 +657,9 @@ func (m *clientHelloMsg) marshal() []byte { return m.raw } -func parseSignatureAlgorithms(reader *byteReader, out *[]signatureAlgorithm, allowEmpty bool) bool { - var sigAlgs byteReader - if !reader.readU16LengthPrefixed(&sigAlgs) { +func parseSignatureAlgorithms(reader *cryptobyte.String, out *[]signatureAlgorithm, allowEmpty bool) bool { + var sigAlgs cryptobyte.String + if !reader.ReadUint16LengthPrefixed(&sigAlgs) { return false } if !allowEmpty && len(sigAlgs) == 0 { @@ -838,7 +668,7 @@ func parseSignatureAlgorithms(reader *byteReader, out *[]signatureAlgorithm, all *out = make([]signatureAlgorithm, 0, len(sigAlgs)/2) for len(sigAlgs) > 0 { var v uint16 - if !sigAlgs.readU16(&v) { + if !sigAlgs.ReadUint16(&v) { return false } if signatureAlgorithm(v) == signatureRSAPKCS1WithMD5AndSHA1 { @@ -852,13 +682,13 @@ func parseSignatureAlgorithms(reader *byteReader, out *[]signatureAlgorithm, all return true } -func checkDuplicateExtensions(extensions byteReader) bool { +func checkDuplicateExtensions(extensions cryptobyte.String) bool { seen := make(map[uint16]struct{}) for len(extensions) > 0 { var extension uint16 - var body byteReader - if !extensions.readU16(&extension) || - !extensions.readU16LengthPrefixed(&body) { + var body cryptobyte.String + if !extensions.ReadUint16(&extension) || + !extensions.ReadUint16LengthPrefixed(&body) { return false } if _, ok := seen[extension]; ok { @@ -871,26 +701,26 @@ func checkDuplicateExtensions(extensions byteReader) bool { func (m *clientHelloMsg) unmarshal(data []byte) bool { m.raw = data - reader := byteReader(data[4:]) - if !reader.readU16(&m.vers) || - !reader.readBytes(&m.random, 32) || - !reader.readU8LengthPrefixedBytes(&m.sessionID) || + reader := cryptobyte.String(data[4:]) + if !reader.ReadUint16(&m.vers) || + !reader.ReadBytes(&m.random, 32) || + !readUint8LengthPrefixedBytes(&reader, &m.sessionID) || len(m.sessionID) > 32 { return false } - if m.isDTLS && !reader.readU8LengthPrefixedBytes(&m.cookie) { + if m.isDTLS && !readUint8LengthPrefixedBytes(&reader, &m.cookie) { return false } - var cipherSuites byteReader - if !reader.readU16LengthPrefixed(&cipherSuites) || - !reader.readU8LengthPrefixedBytes(&m.compressionMethods) { + var cipherSuites cryptobyte.String + if !reader.ReadUint16LengthPrefixed(&cipherSuites) || + !readUint8LengthPrefixedBytes(&reader, &m.compressionMethods) { return false } m.cipherSuites = make([]uint16, 0, len(cipherSuites)/2) for len(cipherSuites) > 0 { var v uint16 - if !cipherSuites.readU16(&v) { + if !cipherSuites.ReadUint16(&v) { return false } m.cipherSuites = append(m.cipherSuites, v) @@ -921,29 +751,29 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool { return true } - var extensions byteReader - if !reader.readU16LengthPrefixed(&extensions) || len(reader) != 0 || !checkDuplicateExtensions(extensions) { + var extensions cryptobyte.String + if !reader.ReadUint16LengthPrefixed(&extensions) || len(reader) != 0 || !checkDuplicateExtensions(extensions) { return false } m.rawExtensions = extensions for len(extensions) > 0 { var extension uint16 - var body byteReader - if !extensions.readU16(&extension) || - !extensions.readU16LengthPrefixed(&body) { + var body cryptobyte.String + if !extensions.ReadUint16(&extension) || + !extensions.ReadUint16LengthPrefixed(&body) { return false } switch extension { case extensionServerName: - var names byteReader - if !body.readU16LengthPrefixed(&names) || len(body) != 0 { + var names cryptobyte.String + if !body.ReadUint16LengthPrefixed(&names) || len(body) != 0 { return false } for len(names) > 0 { var nameType byte var name []byte - if !names.readU8(&nameType) || - !names.readU16LengthPrefixedBytes(&name) { + if !names.ReadUint8(&nameType) || + !readUint16LengthPrefixedBytes(&names, &name) { return false } if nameType == 0 { @@ -952,17 +782,17 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool { } case extensionEncryptedClientHello: var typ byte - if !body.readU8(&typ) { + if !body.ReadUint8(&typ) { return false } switch typ { case echClientTypeOuter: var echOuter echClientOuter - if !body.readU16(&echOuter.kdfID) || - !body.readU16(&echOuter.aeadID) || - !body.readU8(&echOuter.configID) || - !body.readU16LengthPrefixedBytes(&echOuter.enc) || - !body.readU16LengthPrefixedBytes(&echOuter.payload) || + if !body.ReadUint16(&echOuter.kdfID) || + !body.ReadUint16(&echOuter.aeadID) || + !body.ReadUint8(&echOuter.configID) || + !readUint16LengthPrefixedBytes(&body, &echOuter.enc) || + !readUint16LengthPrefixedBytes(&body, &echOuter.payload) || len(echOuter.payload) == 0 || len(body) > 0 { return false @@ -989,11 +819,11 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool { // extensibility, but we expect our client to only send empty // requests of type OCSP. var statusType uint8 - var responderIDList, innerExtensions byteReader - if !body.readU8(&statusType) || + var responderIDList, innerExtensions cryptobyte.String + if !body.ReadUint8(&statusType) || statusType != statusTypeOCSP || - !body.readU16LengthPrefixed(&responderIDList) || - !body.readU16LengthPrefixed(&innerExtensions) || + !body.ReadUint16LengthPrefixed(&responderIDList) || + !body.ReadUint16LengthPrefixed(&innerExtensions) || len(responderIDList) != 0 || len(innerExtensions) != 0 || len(body) != 0 { @@ -1002,21 +832,21 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool { m.ocspStapling = true case extensionSupportedCurves: // http://tools.ietf.org/html/rfc4492#section-5.5.1 - var curves byteReader - if !body.readU16LengthPrefixed(&curves) || len(body) != 0 { + var curves cryptobyte.String + if !body.ReadUint16LengthPrefixed(&curves) || len(body) != 0 { return false } m.supportedCurves = make([]CurveID, 0, len(curves)/2) for len(curves) > 0 { var v uint16 - if !curves.readU16(&v) { + if !curves.ReadUint16(&v) { return false } m.supportedCurves = append(m.supportedCurves, CurveID(v)) } case extensionSupportedPoints: // http://tools.ietf.org/html/rfc4492#section-5.1.2 - if !body.readU8LengthPrefixedBytes(&m.supportedPoints) || len(m.supportedPoints) == 0 || len(body) != 0 { + if !readUint8LengthPrefixedBytes(&body, &m.supportedPoints) || len(m.supportedPoints) == 0 || len(body) != 0 { return false } case extensionSessionTicket: @@ -1027,15 +857,15 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool { // https://tools.ietf.org/html/rfc8446#section-4.2.8 m.hasKeyShares = true m.keySharesRaw = body - var keyShares byteReader - if !body.readU16LengthPrefixed(&keyShares) || len(body) != 0 { + var keyShares cryptobyte.String + if !body.ReadUint16LengthPrefixed(&keyShares) || len(body) != 0 { return false } for len(keyShares) > 0 { var entry keyShareEntry var group uint16 - if !keyShares.readU16(&group) || - !keyShares.readU16LengthPrefixedBytes(&entry.keyExchange) { + if !keyShares.ReadUint16(&group) || + !readUint16LengthPrefixedBytes(&keyShares, &entry.keyExchange) { return false } entry.group = CurveID(group) @@ -1043,23 +873,23 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool { } case extensionPreSharedKey: // https://tools.ietf.org/html/rfc8446#section-4.2.11 - var psks, binders byteReader - if !body.readU16LengthPrefixed(&psks) || - !body.readU16LengthPrefixed(&binders) || + var psks, binders cryptobyte.String + if !body.ReadUint16LengthPrefixed(&psks) || + !body.ReadUint16LengthPrefixed(&binders) || len(body) != 0 { return false } for len(psks) > 0 { var psk pskIdentity - if !psks.readU16LengthPrefixedBytes(&psk.ticket) || - !psks.readU32(&psk.obfuscatedTicketAge) { + if !readUint16LengthPrefixedBytes(&psks, &psk.ticket) || + !psks.ReadUint32(&psk.obfuscatedTicketAge) { return false } m.pskIdentities = append(m.pskIdentities, psk) } for len(binders) > 0 { var binder []byte - if !binders.readU8LengthPrefixedBytes(&binder) { + if !readUint8LengthPrefixedBytes(&binders, &binder) { return false } m.pskBinders = append(m.pskBinders, binder) @@ -1071,7 +901,7 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool { } case extensionPSKKeyExchangeModes: // https://tools.ietf.org/html/rfc8446#section-4.2.9 - if !body.readU8LengthPrefixedBytes(&m.pskKEModes) || len(body) != 0 { + if !readUint8LengthPrefixedBytes(&body, &m.pskKEModes) || len(body) != 0 { return false } case extensionEarlyData: @@ -1081,7 +911,7 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool { } m.hasEarlyData = true case extensionCookie: - if !body.readU16LengthPrefixedBytes(&m.tls13Cookie) || len(body) != 0 { + if !readUint16LengthPrefixedBytes(&body, &m.tls13Cookie) || len(body) != 0 { return false } case extensionSignatureAlgorithms: @@ -1094,30 +924,30 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool { return false } case extensionSupportedVersions: - var versions byteReader - if !body.readU8LengthPrefixed(&versions) || len(body) != 0 { + var versions cryptobyte.String + if !body.ReadUint8LengthPrefixed(&versions) || len(body) != 0 { return false } m.supportedVersions = make([]uint16, 0, len(versions)/2) for len(versions) > 0 { var v uint16 - if !versions.readU16(&v) { + if !versions.ReadUint16(&v) { return false } m.supportedVersions = append(m.supportedVersions, v) } case extensionRenegotiationInfo: - if !body.readU8LengthPrefixedBytes(&m.secureRenegotiation) || len(body) != 0 { + if !readUint8LengthPrefixedBytes(&body, &m.secureRenegotiation) || len(body) != 0 { return false } case extensionALPN: - var protocols byteReader - if !body.readU16LengthPrefixed(&protocols) || len(body) != 0 { + var protocols cryptobyte.String + if !body.ReadUint16LengthPrefixed(&protocols) || len(body) != 0 { return false } for len(protocols) > 0 { var protocol []byte - if !protocols.readU8LengthPrefixedBytes(&protocol) || len(protocol) == 0 { + if !readUint8LengthPrefixedBytes(&protocols, &protocol) || len(protocol) == 0 { return false } m.alpnProtocols = append(m.alpnProtocols, string(protocol)) @@ -1137,17 +967,17 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool { } m.extendedMasterSecret = true case extensionUseSRTP: - var profiles byteReader + var profiles cryptobyte.String var mki []byte - if !body.readU16LengthPrefixed(&profiles) || - !body.readU8LengthPrefixedBytes(&mki) || + if !body.ReadUint16LengthPrefixed(&profiles) || + !readUint8LengthPrefixedBytes(&body, &mki) || len(body) != 0 { return false } m.srtpProtectionProfiles = make([]uint16, 0, len(profiles)/2) for len(profiles) > 0 { var v uint16 - if !profiles.readU16(&v) { + if !profiles.ReadUint16(&v) { return false } m.srtpProtectionProfiles = append(m.srtpProtectionProfiles, v) @@ -1161,15 +991,15 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool { case extensionCustom: m.customExtension = string(body) case extensionCompressedCertAlgs: - var algIDs byteReader - if !body.readU8LengthPrefixed(&algIDs) { + var algIDs cryptobyte.String + if !body.ReadUint8LengthPrefixed(&algIDs) { return false } seen := make(map[uint16]struct{}) for len(algIDs) > 0 { var algID uint16 - if !algIDs.readU16(&algID) { + if !algIDs.ReadUint16(&algID) { return false } if _, ok := seen[algID]; ok { @@ -1191,13 +1021,13 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool { } m.delegatedCredentials = true case extensionApplicationSettings: - var protocols byteReader - if !body.readU16LengthPrefixed(&protocols) || len(body) != 0 { + var protocols cryptobyte.String + if !body.ReadUint16LengthPrefixed(&protocols) || len(body) != 0 { return false } for len(protocols) > 0 { var protocol []byte - if !protocols.readU8LengthPrefixedBytes(&protocol) || len(protocol) == 0 { + if !readUint8LengthPrefixedBytes(&protocols, &protocol) || len(protocol) == 0 { return false } m.alpsProtocols = append(m.alpsProtocols, string(protocol)) @@ -1213,15 +1043,15 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool { } func decodeClientHelloInner(config *Config, encoded []byte, helloOuter *clientHelloMsg) (*clientHelloMsg, error) { - reader := byteReader(encoded) + reader := cryptobyte.String(encoded) var versAndRandom, sessionID, cipherSuites, compressionMethods []byte - var extensions byteReader - if !reader.readBytes(&versAndRandom, 2+32) || - !reader.readU8LengthPrefixedBytes(&sessionID) || + var extensions cryptobyte.String + if !reader.ReadBytes(&versAndRandom, 2+32) || + !readUint8LengthPrefixedBytes(&reader, &sessionID) || len(sessionID) != 0 || // Copied from |helloOuter| - !reader.readU16LengthPrefixedBytes(&cipherSuites) || - !reader.readU8LengthPrefixedBytes(&compressionMethods) || - !reader.readU16LengthPrefixed(&extensions) { + !readUint16LengthPrefixedBytes(&reader, &cipherSuites) || + !readUint8LengthPrefixedBytes(&reader, &compressionMethods) || + !reader.ReadUint16LengthPrefixed(&extensions) { return nil, errors.New("tls: error parsing EncodedClientHelloInner") } @@ -1232,64 +1062,77 @@ func decodeClientHelloInner(config *Config, encoded []byte, helloOuter *clientHe } } - builder := newByteBuilder() - builder.addU8(typeClientHello) - body := builder.addU24LengthPrefixed() - body.addBytes(versAndRandom) - body.addU8LengthPrefixed().addBytes(helloOuter.sessionID) - body.addU16LengthPrefixed().addBytes(cipherSuites) - body.addU8LengthPrefixed().addBytes(compressionMethods) - newExtensions := body.addU16LengthPrefixed() - - var seenOuterExtensions bool - outerExtensions := byteReader(helloOuter.rawExtensions) copied := make(map[uint16]struct{}) - for len(extensions) > 0 { - var extType uint16 - var extBody byteReader - if !extensions.readU16(&extType) || - !extensions.readU16LengthPrefixed(&extBody) { - return nil, errors.New("tls: error parsing EncodedClientHelloInner") - } - if extType != extensionECHOuterExtensions { - newExtensions.addU16(extType) - newExtensions.addU16LengthPrefixed().addBytes(extBody) - continue - } - if seenOuterExtensions { - return nil, errors.New("tls: duplicate ech_outer_extensions extension") - } - seenOuterExtensions = true - var extList byteReader - if !extBody.readU8LengthPrefixed(&extList) || len(extList) == 0 || len(extBody) != 0 { - return nil, errors.New("tls: error parsing ech_outer_extensions") - } - for len(extList) != 0 { - var newExtType uint16 - if !extList.readU16(&newExtType) { - return nil, errors.New("tls: error parsing ech_outer_extensions") - } - if newExtType == extensionEncryptedClientHello { - return nil, errors.New("tls: error parsing ech_outer_extensions") - } - for { - if len(outerExtensions) == 0 { - return nil, fmt.Errorf("tls: extension %d not found in ClientHelloOuter", newExtType) + builder := cryptobyte.NewBuilder(nil) + builder.AddUint8(typeClientHello) + builder.AddUint24LengthPrefixed(func(body *cryptobyte.Builder) { + body.AddBytes(versAndRandom) + addUint8LengthPrefixedBytes(body, helloOuter.sessionID) + addUint16LengthPrefixedBytes(body, cipherSuites) + addUint8LengthPrefixedBytes(body, compressionMethods) + body.AddUint16LengthPrefixed(func(newExtensions *cryptobyte.Builder) { + var seenOuterExtensions bool + outerExtensions := cryptobyte.String(helloOuter.rawExtensions) + for len(extensions) > 0 { + var extType uint16 + var extBody cryptobyte.String + if !extensions.ReadUint16(&extType) || + !extensions.ReadUint16LengthPrefixed(&extBody) { + newExtensions.SetError(errors.New("tls: error parsing EncodedClientHelloInner")) + return + } + if extType != extensionECHOuterExtensions { + newExtensions.AddUint16(extType) + addUint16LengthPrefixedBytes(newExtensions, extBody) + continue } - var foundExt uint16 - var newExtBody []byte - if !outerExtensions.readU16(&foundExt) || - !outerExtensions.readU16LengthPrefixedBytes(&newExtBody) { - return nil, errors.New("tls: error parsing ClientHelloOuter") + if seenOuterExtensions { + newExtensions.SetError(errors.New("tls: duplicate ech_outer_extensions extension")) + return } - if foundExt == newExtType { - newExtensions.addU16(newExtType) - newExtensions.addU16LengthPrefixed().addBytes(newExtBody) - copied[newExtType] = struct{}{} - break + seenOuterExtensions = true + var extList cryptobyte.String + if !extBody.ReadUint8LengthPrefixed(&extList) || len(extList) == 0 || len(extBody) != 0 { + newExtensions.SetError(errors.New("tls: error parsing ech_outer_extensions")) + return + } + for len(extList) != 0 { + var newExtType uint16 + if !extList.ReadUint16(&newExtType) { + newExtensions.SetError(errors.New("tls: error parsing ech_outer_extensions")) + return + } + if newExtType == extensionEncryptedClientHello { + newExtensions.SetError(errors.New("tls: error parsing ech_outer_extensions")) + return + } + for { + if len(outerExtensions) == 0 { + newExtensions.SetError(fmt.Errorf("tls: extension %d not found in ClientHelloOuter", newExtType)) + return + } + var foundExt uint16 + var newExtBody []byte + if !outerExtensions.ReadUint16(&foundExt) || + !readUint16LengthPrefixedBytes(&outerExtensions, &newExtBody) { + newExtensions.SetError(errors.New("tls: error parsing ClientHelloOuter")) + return + } + if foundExt == newExtType { + newExtensions.AddUint16(newExtType) + addUint16LengthPrefixedBytes(newExtensions, newExtBody) + copied[newExtType] = struct{}{} + break + } + } } } - } + }) + }) + + bytes, err := builder.Bytes() + if err != nil { + return nil, err } for _, expected := range config.Bugs.ExpectECHOuterExtensions { @@ -1304,9 +1147,10 @@ func decodeClientHelloInner(config *Config, encoded []byte, helloOuter *clientHe } ret := new(clientHelloMsg) - if !ret.unmarshal(builder.finish()) { + if !ret.unmarshal(bytes) { return nil, errors.New("tls: error parsing reconstructed ClientHello") } + return ret, nil } @@ -1337,102 +1181,100 @@ func (m *serverHelloMsg) marshal() []byte { return m.raw } - handshakeMsg := newByteBuilder() - handshakeMsg.addU8(typeServerHello) - hello := handshakeMsg.addU24LengthPrefixed() - - // m.vers is used both to determine the format of the rest of the - // ServerHello and to override the value, so include a second version - // field. - vers, ok := wireToVersion(m.vers, m.isDTLS) - if !ok { - panic("unknown version") - } - if m.versOverride != 0 { - hello.addU16(m.versOverride) - } else if vers >= VersionTLS13 { - hello.addU16(VersionTLS12) - } else { - hello.addU16(m.vers) - } - - hello.addBytes(m.random) - sessionID := hello.addU8LengthPrefixed() - sessionID.addBytes(m.sessionID) - hello.addU16(m.cipherSuite) - hello.addU8(m.compressionMethod) - - extensions := hello.addU16LengthPrefixed() - - if vers >= VersionTLS13 { - if m.hasKeyShare { - extensions.addU16(extensionKeyShare) - keyShare := extensions.addU16LengthPrefixed() - keyShare.addU16(uint16(m.keyShare.group)) - keyExchange := keyShare.addU16LengthPrefixed() - keyExchange.addBytes(m.keyShare.keyExchange) - } - if m.hasPSKIdentity { - extensions.addU16(extensionPreSharedKey) - extensions.addU16(2) // Length - extensions.addU16(m.pskIdentity) - } - if !m.omitSupportedVers { - extensions.addU16(extensionSupportedVersions) - extensions.addU16(2) // Length - if m.supportedVersOverride != 0 { - extensions.addU16(m.supportedVersOverride) + handshakeMsg := cryptobyte.NewBuilder(nil) + handshakeMsg.AddUint8(typeServerHello) + handshakeMsg.AddUint24LengthPrefixed(func(hello *cryptobyte.Builder) { + // m.vers is used both to determine the format of the rest of the + // ServerHello and to override the value, so include a second version + // field. + vers, ok := wireToVersion(m.vers, m.isDTLS) + if !ok { + panic("unknown version") + } + if m.versOverride != 0 { + hello.AddUint16(m.versOverride) + } else if vers >= VersionTLS13 { + hello.AddUint16(VersionTLS12) + } else { + hello.AddUint16(m.vers) + } + + hello.AddBytes(m.random) + addUint8LengthPrefixedBytes(hello, m.sessionID) + hello.AddUint16(m.cipherSuite) + hello.AddUint8(m.compressionMethod) + + hello.AddUint16LengthPrefixed(func(extensions *cryptobyte.Builder) { + if vers >= VersionTLS13 { + if m.hasKeyShare { + extensions.AddUint16(extensionKeyShare) + extensions.AddUint16LengthPrefixed(func(keyShare *cryptobyte.Builder) { + keyShare.AddUint16(uint16(m.keyShare.group)) + addUint16LengthPrefixedBytes(keyShare, m.keyShare.keyExchange) + }) + } + if m.hasPSKIdentity { + extensions.AddUint16(extensionPreSharedKey) + extensions.AddUint16(2) // Length + extensions.AddUint16(m.pskIdentity) + } + if !m.omitSupportedVers { + extensions.AddUint16(extensionSupportedVersions) + extensions.AddUint16(2) // Length + if m.supportedVersOverride != 0 { + extensions.AddUint16(m.supportedVersOverride) + } else { + extensions.AddUint16(m.vers) + } + } + if len(m.customExtension) > 0 { + extensions.AddUint16(extensionCustom) + addUint16LengthPrefixedBytes(extensions, []byte(m.customExtension)) + } + if len(m.unencryptedALPN) > 0 { + extensions.AddUint16(extensionALPN) + extensions.AddUint16LengthPrefixed(func(extension *cryptobyte.Builder) { + extension.AddUint16LengthPrefixed(func(protocolNameList *cryptobyte.Builder) { + addUint8LengthPrefixedBytes(protocolNameList, []byte(m.unencryptedALPN)) + }) + }) + } } else { - extensions.addU16(m.vers) - } - } - if len(m.customExtension) > 0 { - extensions.addU16(extensionCustom) - customExt := extensions.addU16LengthPrefixed() - customExt.addBytes([]byte(m.customExtension)) - } - if len(m.unencryptedALPN) > 0 { - extensions.addU16(extensionALPN) - extension := extensions.addU16LengthPrefixed() - - protocolNameList := extension.addU16LengthPrefixed() - protocolName := protocolNameList.addU8LengthPrefixed() - protocolName.addBytes([]byte(m.unencryptedALPN)) - } - } else { - m.extensions.marshal(extensions) - if m.omitExtensions || m.emptyExtensions { - // Silently erasing server extensions will break the handshake. Instead, - // assert that tests which use this field also disable all features which - // would write an extension. - if extensions.len() != 0 { - panic(fmt.Sprintf("ServerHello unexpectedly contained extensions: %x, %+v", extensions.data(), m)) - } - hello.discardChild() - if m.emptyExtensions { - hello.addU16(0) + m.extensions.marshal(extensions) + } + if m.omitExtensions || m.emptyExtensions { + // Silently erasing server extensions will break the handshake. Instead, + // assert that tests which use this field also disable all features which + // would write an extension. Note the length includes the length prefix. + if b := extensions.BytesOrPanic(); len(b) != 2 { + panic(fmt.Sprintf("ServerHello unexpectedly contained extensions: %x, %+v", b, m)) + } } + }) + // Remove the length prefix. + if m.omitExtensions { + hello.Unwrite(2) } - } + }) - m.raw = handshakeMsg.finish() + m.raw = handshakeMsg.BytesOrPanic() return m.raw } func (m *serverHelloMsg) unmarshal(data []byte) bool { m.raw = data - reader := byteReader(data[4:]) - if !reader.readU16(&m.vers) || - !reader.readBytes(&m.random, 32) { + reader := cryptobyte.String(data[4:]) + if !reader.ReadUint16(&m.vers) || + !reader.ReadBytes(&m.random, 32) { return false } vers, ok := wireToVersion(m.vers, m.isDTLS) if !ok { return false } - if !reader.readU8LengthPrefixedBytes(&m.sessionID) || - !reader.readU16(&m.cipherSuite) || - !reader.readU8(&m.compressionMethod) { + if !readUint8LengthPrefixedBytes(&reader, &m.sessionID) || + !reader.ReadUint16(&m.cipherSuite) || + !reader.ReadUint8(&m.compressionMethod) { return false } @@ -1443,8 +1285,8 @@ func (m *serverHelloMsg) unmarshal(data []byte) bool { return true } - var extensions byteReader - if !reader.readU16LengthPrefixed(&extensions) || len(reader) != 0 || !checkDuplicateExtensions(extensions) { + var extensions cryptobyte.String + if !reader.ReadUint16LengthPrefixed(&extensions) || len(reader) != 0 || !checkDuplicateExtensions(extensions) { return false } @@ -1453,13 +1295,13 @@ func (m *serverHelloMsg) unmarshal(data []byte) bool { extensionsCopy := extensions for len(extensionsCopy) > 0 { var extension uint16 - var body byteReader - if !extensionsCopy.readU16(&extension) || - !extensionsCopy.readU16LengthPrefixed(&body) { + var body cryptobyte.String + if !extensionsCopy.ReadUint16(&extension) || + !extensionsCopy.ReadUint16LengthPrefixed(&body) { return false } if extension == extensionSupportedVersions { - if !body.readU16(&m.vers) || len(body) != 0 { + if !body.ReadUint16(&m.vers) || len(body) != 0 { return false } vers, ok = wireToVersion(m.vers, m.isDTLS) @@ -1473,23 +1315,23 @@ func (m *serverHelloMsg) unmarshal(data []byte) bool { if vers >= VersionTLS13 { for len(extensions) > 0 { var extension uint16 - var body byteReader - if !extensions.readU16(&extension) || - !extensions.readU16LengthPrefixed(&body) { + var body cryptobyte.String + if !extensions.ReadUint16(&extension) || + !extensions.ReadUint16LengthPrefixed(&body) { return false } switch extension { case extensionKeyShare: m.hasKeyShare = true var group uint16 - if !body.readU16(&group) || - !body.readU16LengthPrefixedBytes(&m.keyShare.keyExchange) || + if !body.ReadUint16(&group) || + !readUint16LengthPrefixedBytes(&body, &m.keyShare.keyExchange) || len(body) != 0 { return false } m.keyShare.group = CurveID(group) case extensionPreSharedKey: - if !body.readU16(&m.pskIdentity) || len(body) != 0 { + if !body.ReadUint16(&m.pskIdentity) || len(body) != 0 { return false } m.hasPSKIdentity = true @@ -1519,23 +1361,25 @@ func (m *encryptedExtensionsMsg) marshal() []byte { return m.raw } - encryptedExtensionsMsg := newByteBuilder() - encryptedExtensionsMsg.addU8(typeEncryptedExtensions) - encryptedExtensions := encryptedExtensionsMsg.addU24LengthPrefixed() - if !m.empty { - extensions := encryptedExtensions.addU16LengthPrefixed() - m.extensions.marshal(extensions) - } + encryptedExtensionsMsg := cryptobyte.NewBuilder(nil) + encryptedExtensionsMsg.AddUint8(typeEncryptedExtensions) + encryptedExtensionsMsg.AddUint24LengthPrefixed(func(encryptedExtensions *cryptobyte.Builder) { + if !m.empty { + encryptedExtensions.AddUint16LengthPrefixed(func(extensions *cryptobyte.Builder) { + m.extensions.marshal(extensions) + }) + } + }) - m.raw = encryptedExtensionsMsg.finish() + m.raw = encryptedExtensionsMsg.BytesOrPanic() return m.raw } func (m *encryptedExtensionsMsg) unmarshal(data []byte) bool { m.raw = data - reader := byteReader(data[4:]) - var extensions byteReader - if !reader.readU16LengthPrefixed(&extensions) || len(reader) != 0 { + reader := cryptobyte.String(data[4:]) + var extensions cryptobyte.String + if !reader.ReadUint16LengthPrefixed(&extensions) || len(reader) != 0 { return false } return m.extensions.unmarshal(extensions, VersionTLS13) @@ -1571,147 +1415,137 @@ type serverExtensions struct { echRetryConfigs []byte } -func (m *serverExtensions) marshal(extensions *byteBuilder) { +func (m *serverExtensions) marshal(extensions *cryptobyte.Builder) { if m.duplicateExtension { // Add a duplicate bogus extension at the beginning and end. - extensions.addU16(extensionDuplicate) - extensions.addU16(0) // length = 0 for empty extension + extensions.AddUint16(extensionDuplicate) + extensions.AddUint16(0) // length = 0 for empty extension } if m.nextProtoNeg && !m.npnAfterAlpn { - extensions.addU16(extensionNextProtoNeg) - extension := extensions.addU16LengthPrefixed() - - for _, v := range m.nextProtos { - if len(v) > 255 { - v = v[:255] + extensions.AddUint16(extensionNextProtoNeg) + extensions.AddUint16LengthPrefixed(func(extension *cryptobyte.Builder) { + for _, v := range m.nextProtos { + addUint8LengthPrefixedBytes(extension, []byte(v)) } - npn := extension.addU8LengthPrefixed() - npn.addBytes([]byte(v)) - } + }) } if m.ocspStapling { - extensions.addU16(extensionStatusRequest) - extensions.addU16(0) + extensions.AddUint16(extensionStatusRequest) + extensions.AddUint16(0) } if m.ticketSupported { - extensions.addU16(extensionSessionTicket) - extensions.addU16(0) + extensions.AddUint16(extensionSessionTicket) + extensions.AddUint16(0) } if m.secureRenegotiation != nil { - extensions.addU16(extensionRenegotiationInfo) - extension := extensions.addU16LengthPrefixed() - secureRenego := extension.addU8LengthPrefixed() - secureRenego.addBytes(m.secureRenegotiation) + extensions.AddUint16(extensionRenegotiationInfo) + extensions.AddUint16LengthPrefixed(func(extension *cryptobyte.Builder) { + addUint8LengthPrefixedBytes(extension, m.secureRenegotiation) + }) } if len(m.alpnProtocol) > 0 || m.alpnProtocolEmpty { - extensions.addU16(extensionALPN) - extension := extensions.addU16LengthPrefixed() - - protocolNameList := extension.addU16LengthPrefixed() - protocolName := protocolNameList.addU8LengthPrefixed() - protocolName.addBytes([]byte(m.alpnProtocol)) + extensions.AddUint16(extensionALPN) + extensions.AddUint16LengthPrefixed(func(extension *cryptobyte.Builder) { + extension.AddUint16LengthPrefixed(func(protocolNameList *cryptobyte.Builder) { + addUint8LengthPrefixedBytes(protocolNameList, []byte(m.alpnProtocol)) + }) + }) } if m.channelIDRequested { - extensions.addU16(extensionChannelID) - extensions.addU16(0) + extensions.AddUint16(extensionChannelID) + extensions.AddUint16(0) } if m.duplicateExtension { // Add a duplicate bogus extension at the beginning and end. - extensions.addU16(extensionDuplicate) - extensions.addU16(0) + extensions.AddUint16(extensionDuplicate) + extensions.AddUint16(0) } if m.extendedMasterSecret { - extensions.addU16(extensionExtendedMasterSecret) - extensions.addU16(0) + extensions.AddUint16(extensionExtendedMasterSecret) + extensions.AddUint16(0) } if m.srtpProtectionProfile != 0 { - extensions.addU16(extensionUseSRTP) - extension := extensions.addU16LengthPrefixed() - - srtpProtectionProfiles := extension.addU16LengthPrefixed() - srtpProtectionProfiles.addU16(m.srtpProtectionProfile) - srtpMki := extension.addU8LengthPrefixed() - srtpMki.addBytes([]byte(m.srtpMasterKeyIdentifier)) + extensions.AddUint16(extensionUseSRTP) + extensions.AddUint16LengthPrefixed(func(extension *cryptobyte.Builder) { + extension.AddUint16LengthPrefixed(func(srtpProtectionProfiles *cryptobyte.Builder) { + srtpProtectionProfiles.AddUint16(m.srtpProtectionProfile) + }) + addUint8LengthPrefixedBytes(extension, []byte(m.srtpMasterKeyIdentifier)) + }) } if m.sctList != nil { - extensions.addU16(extensionSignedCertificateTimestamp) - extension := extensions.addU16LengthPrefixed() - extension.addBytes(m.sctList) + extensions.AddUint16(extensionSignedCertificateTimestamp) + addUint16LengthPrefixedBytes(extensions, m.sctList) } if l := len(m.customExtension); l > 0 { - extensions.addU16(extensionCustom) - customExt := extensions.addU16LengthPrefixed() - customExt.addBytes([]byte(m.customExtension)) + extensions.AddUint16(extensionCustom) + addUint16LengthPrefixedBytes(extensions, []byte(m.customExtension)) } if m.nextProtoNeg && m.npnAfterAlpn { - extensions.addU16(extensionNextProtoNeg) - extension := extensions.addU16LengthPrefixed() - - for _, v := range m.nextProtos { - if len(v) > 255 { - v = v[0:255] + extensions.AddUint16(extensionNextProtoNeg) + extensions.AddUint16LengthPrefixed(func(extension *cryptobyte.Builder) { + for _, v := range m.nextProtos { + addUint8LengthPrefixedBytes(extension, []byte(v)) } - npn := extension.addU8LengthPrefixed() - npn.addBytes([]byte(v)) - } + }) } if m.hasKeyShare { - extensions.addU16(extensionKeyShare) - keyShare := extensions.addU16LengthPrefixed() - keyShare.addU16(uint16(m.keyShare.group)) - keyExchange := keyShare.addU16LengthPrefixed() - keyExchange.addBytes(m.keyShare.keyExchange) + extensions.AddUint16(extensionKeyShare) + extensions.AddUint16LengthPrefixed(func(keyShare *cryptobyte.Builder) { + keyShare.AddUint16(uint16(m.keyShare.group)) + addUint16LengthPrefixedBytes(keyShare, m.keyShare.keyExchange) + }) } if m.supportedVersion != 0 { - extensions.addU16(extensionSupportedVersions) - extensions.addU16(2) // Length - extensions.addU16(m.supportedVersion) + extensions.AddUint16(extensionSupportedVersions) + extensions.AddUint16(2) // Length + extensions.AddUint16(m.supportedVersion) } if len(m.supportedPoints) > 0 { // http://tools.ietf.org/html/rfc4492#section-5.1.2 - extensions.addU16(extensionSupportedPoints) - supportedPointsList := extensions.addU16LengthPrefixed() - supportedPoints := supportedPointsList.addU8LengthPrefixed() - supportedPoints.addBytes(m.supportedPoints) + extensions.AddUint16(extensionSupportedPoints) + extensions.AddUint16LengthPrefixed(func(supportedPointsList *cryptobyte.Builder) { + addUint8LengthPrefixedBytes(supportedPointsList, m.supportedPoints) + }) } if len(m.supportedCurves) > 0 { // https://tools.ietf.org/html/rfc8446#section-4.2.7 - extensions.addU16(extensionSupportedCurves) - supportedCurvesList := extensions.addU16LengthPrefixed() - supportedCurves := supportedCurvesList.addU16LengthPrefixed() - for _, curve := range m.supportedCurves { - supportedCurves.addU16(uint16(curve)) - } + extensions.AddUint16(extensionSupportedCurves) + extensions.AddUint16LengthPrefixed(func(supportedCurvesList *cryptobyte.Builder) { + supportedCurvesList.AddUint16LengthPrefixed(func(supportedCurves *cryptobyte.Builder) { + for _, curve := range m.supportedCurves { + supportedCurves.AddUint16(uint16(curve)) + } + }) + }) } if len(m.quicTransportParams) > 0 { - extensions.addU16(extensionQUICTransportParams) - params := extensions.addU16LengthPrefixed() - params.addBytes(m.quicTransportParams) + extensions.AddUint16(extensionQUICTransportParams) + addUint16LengthPrefixedBytes(extensions, m.quicTransportParams) } if len(m.quicTransportParamsLegacy) > 0 { - extensions.addU16(extensionQUICTransportParamsLegacy) - params := extensions.addU16LengthPrefixed() - params.addBytes(m.quicTransportParamsLegacy) + extensions.AddUint16(extensionQUICTransportParamsLegacy) + addUint16LengthPrefixedBytes(extensions, m.quicTransportParamsLegacy) } if m.hasEarlyData { - extensions.addU16(extensionEarlyData) - extensions.addBytes([]byte{0, 0}) + extensions.AddUint16(extensionEarlyData) + extensions.AddBytes([]byte{0, 0}) } if m.serverNameAck { - extensions.addU16(extensionServerName) - extensions.addU16(0) // zero length + extensions.AddUint16(extensionServerName) + extensions.AddUint16(0) // zero length } if m.hasApplicationSettings { - extensions.addU16(extensionApplicationSettings) - extensions.addU16LengthPrefixed().addBytes(m.applicationSettings) + extensions.AddUint16(extensionApplicationSettings) + addUint16LengthPrefixedBytes(extensions, m.applicationSettings) } if len(m.echRetryConfigs) > 0 { - extensions.addU16(extensionEncryptedClientHello) - extensions.addU16LengthPrefixed().addBytes(m.echRetryConfigs) + extensions.AddUint16(extensionEncryptedClientHello) + addUint16LengthPrefixedBytes(extensions, m.echRetryConfigs) } } -func (m *serverExtensions) unmarshal(data byteReader, version uint16) bool { +func (m *serverExtensions) unmarshal(data cryptobyte.String, version uint16) bool { // Reset all fields. *m = serverExtensions{} @@ -1721,9 +1555,9 @@ func (m *serverExtensions) unmarshal(data byteReader, version uint16) bool { for len(data) > 0 { var extension uint16 - var body byteReader - if !data.readU16(&extension) || - !data.readU16LengthPrefixed(&body) { + var body cryptobyte.String + if !data.ReadUint16(&extension) || + !data.ReadUint16LengthPrefixed(&body) { return false } switch extension { @@ -1731,7 +1565,7 @@ func (m *serverExtensions) unmarshal(data byteReader, version uint16) bool { m.nextProtoNeg = true for len(body) > 0 { var protocol []byte - if !body.readU8LengthPrefixedBytes(&protocol) { + if !readUint8LengthPrefixedBytes(&body, &protocol) { return false } m.nextProtos = append(m.nextProtos, string(protocol)) @@ -1747,14 +1581,14 @@ func (m *serverExtensions) unmarshal(data byteReader, version uint16) bool { } m.ticketSupported = true case extensionRenegotiationInfo: - if !body.readU8LengthPrefixedBytes(&m.secureRenegotiation) || len(body) != 0 { + if !readUint8LengthPrefixedBytes(&body, &m.secureRenegotiation) || len(body) != 0 { return false } case extensionALPN: - var protocols, protocol byteReader - if !body.readU16LengthPrefixed(&protocols) || + var protocols, protocol cryptobyte.String + if !body.ReadUint16LengthPrefixed(&protocols) || len(body) != 0 || - !protocols.readU8LengthPrefixed(&protocol) || + !protocols.ReadUint8LengthPrefixed(&protocol) || len(protocols) != 0 { return false } @@ -1771,11 +1605,11 @@ func (m *serverExtensions) unmarshal(data byteReader, version uint16) bool { } m.extendedMasterSecret = true case extensionUseSRTP: - var profiles, mki byteReader - if !body.readU16LengthPrefixed(&profiles) || - !profiles.readU16(&m.srtpProtectionProfile) || + var profiles, mki cryptobyte.String + if !body.ReadUint16LengthPrefixed(&profiles) || + !profiles.ReadUint16(&m.srtpProtectionProfile) || len(profiles) != 0 || - !body.readU8LengthPrefixed(&mki) || + !body.ReadUint8LengthPrefixed(&mki) || len(body) != 0 { return false } @@ -1795,7 +1629,7 @@ func (m *serverExtensions) unmarshal(data byteReader, version uint16) bool { return false } // http://tools.ietf.org/html/rfc4492#section-5.5.2 - if !body.readU8LengthPrefixedBytes(&m.supportedPoints) || len(body) != 0 { + if !readUint8LengthPrefixedBytes(&body, &m.supportedPoints) || len(body) != 0 { return false } case extensionSupportedCurves: @@ -1822,15 +1656,15 @@ func (m *serverExtensions) unmarshal(data byteReader, version uint16) bool { m.echRetryConfigs = body // Validate the ECHConfig with a top-level parse. - var echConfigs byteReader - if !body.readU16LengthPrefixed(&echConfigs) { + var echConfigs cryptobyte.String + if !body.ReadUint16LengthPrefixed(&echConfigs) { return false } for len(echConfigs) > 0 { var version uint16 - var contents byteReader - if !echConfigs.readU16(&version) || - !echConfigs.readU16LengthPrefixed(&contents) { + var contents cryptobyte.String + if !echConfigs.ReadUint16(&version) || + !echConfigs.ReadUint16LengthPrefixed(&contents) { return false } } @@ -1858,29 +1692,31 @@ func (m *clientEncryptedExtensionsMsg) marshal() (x []byte) { return m.raw } - builder := newByteBuilder() - builder.addU8(typeEncryptedExtensions) - body := builder.addU24LengthPrefixed() - extensions := body.addU16LengthPrefixed() - if m.hasApplicationSettings { - extensions.addU16(extensionApplicationSettings) - extensions.addU16LengthPrefixed().addBytes(m.applicationSettings) - } - if len(m.customExtension) > 0 { - extensions.addU16(extensionCustom) - extensions.addU16LengthPrefixed().addBytes(m.customExtension) - } + builder := cryptobyte.NewBuilder(nil) + builder.AddUint8(typeEncryptedExtensions) + builder.AddUint24LengthPrefixed(func(body *cryptobyte.Builder) { + body.AddUint16LengthPrefixed(func(extensions *cryptobyte.Builder) { + if m.hasApplicationSettings { + extensions.AddUint16(extensionApplicationSettings) + addUint16LengthPrefixedBytes(extensions, m.applicationSettings) + } + if len(m.customExtension) > 0 { + extensions.AddUint16(extensionCustom) + addUint16LengthPrefixedBytes(extensions, m.customExtension) + } + }) + }) - m.raw = builder.finish() + m.raw = builder.BytesOrPanic() return m.raw } func (m *clientEncryptedExtensionsMsg) unmarshal(data []byte) bool { m.raw = data - reader := byteReader(data[4:]) + reader := cryptobyte.String(data[4:]) - var extensions byteReader - if !reader.readU16LengthPrefixed(&extensions) || + var extensions cryptobyte.String + if !reader.ReadUint16LengthPrefixed(&extensions) || len(reader) != 0 { return false } @@ -1891,9 +1727,9 @@ func (m *clientEncryptedExtensionsMsg) unmarshal(data []byte) bool { for len(extensions) > 0 { var extension uint16 - var body byteReader - if !extensions.readU16(&extension) || - !extensions.readU16LengthPrefixed(&body) { + var body cryptobyte.String + if !extensions.ReadUint16(&extension) || + !extensions.ReadUint16LengthPrefixed(&body) { return false } switch extension { @@ -1928,92 +1764,93 @@ func (m *helloRetryRequestMsg) marshal() []byte { return m.raw } - retryRequestMsg := newByteBuilder() - retryRequestMsg.addU8(typeServerHello) - retryRequest := retryRequestMsg.addU24LengthPrefixed() - retryRequest.addU16(VersionTLS12) - retryRequest.addBytes(tls13HelloRetryRequest) - sessionID := retryRequest.addU8LengthPrefixed() - sessionID.addBytes(m.sessionID) - retryRequest.addU16(m.cipherSuite) - retryRequest.addU8(m.compressionMethod) - - extensions := retryRequest.addU16LengthPrefixed() + retryRequestMsg := cryptobyte.NewBuilder(nil) + retryRequestMsg.AddUint8(typeServerHello) + retryRequestMsg.AddUint24LengthPrefixed(func(retryRequest *cryptobyte.Builder) { + retryRequest.AddUint16(VersionTLS12) + retryRequest.AddBytes(tls13HelloRetryRequest) + addUint8LengthPrefixedBytes(retryRequest, m.sessionID) + retryRequest.AddUint16(m.cipherSuite) + retryRequest.AddUint8(m.compressionMethod) - count := 1 - if m.duplicateExtensions { - count = 2 - } + retryRequest.AddUint16LengthPrefixed(func(extensions *cryptobyte.Builder) { + count := 1 + if m.duplicateExtensions { + count = 2 + } - for i := 0; i < count; i++ { - extensions.addU16(extensionSupportedVersions) - extensions.addU16(2) // Length - extensions.addU16(m.vers) - if m.hasSelectedGroup { - extensions.addU16(extensionKeyShare) - extensions.addU16(2) // length - extensions.addU16(uint16(m.selectedGroup)) - } - // m.cookie may be a non-nil empty slice for empty cookie tests. - if m.cookie != nil { - extensions.addU16(extensionCookie) - body := extensions.addU16LengthPrefixed() - body.addU16LengthPrefixed().addBytes(m.cookie) - } - if len(m.customExtension) > 0 { - extensions.addU16(extensionCustom) - extensions.addU16LengthPrefixed().addBytes([]byte(m.customExtension)) - } - if len(m.echConfirmation) > 0 { - extensions.addU16(extensionEncryptedClientHello) - extensions.addU16LengthPrefixed().addBytes(m.echConfirmation) - } - } + for i := 0; i < count; i++ { + extensions.AddUint16(extensionSupportedVersions) + extensions.AddUint16(2) // Length + extensions.AddUint16(m.vers) + if m.hasSelectedGroup { + extensions.AddUint16(extensionKeyShare) + extensions.AddUint16(2) // length + extensions.AddUint16(uint16(m.selectedGroup)) + } + // m.cookie may be a non-nil empty slice for empty cookie tests. + if m.cookie != nil { + extensions.AddUint16(extensionCookie) + extensions.AddUint16LengthPrefixed(func(body *cryptobyte.Builder) { + addUint16LengthPrefixedBytes(body, m.cookie) + }) + } + if len(m.customExtension) > 0 { + extensions.AddUint16(extensionCustom) + addUint16LengthPrefixedBytes(extensions, []byte(m.customExtension)) + } + if len(m.echConfirmation) > 0 { + extensions.AddUint16(extensionEncryptedClientHello) + addUint16LengthPrefixedBytes(extensions, m.echConfirmation) + } + } + }) + }) - m.raw = retryRequestMsg.finish() + m.raw = retryRequestMsg.BytesOrPanic() return m.raw } func (m *helloRetryRequestMsg) unmarshal(data []byte) bool { m.raw = data - reader := byteReader(data[4:]) + reader := cryptobyte.String(data[4:]) var legacyVers uint16 var random []byte var compressionMethod byte - var extensions byteReader - if !reader.readU16(&legacyVers) || + var extensions cryptobyte.String + if !reader.ReadUint16(&legacyVers) || legacyVers != VersionTLS12 || - !reader.readBytes(&random, 32) || - !reader.readU8LengthPrefixedBytes(&m.sessionID) || - !reader.readU16(&m.cipherSuite) || - !reader.readU8(&compressionMethod) || + !reader.ReadBytes(&random, 32) || + !readUint8LengthPrefixedBytes(&reader, &m.sessionID) || + !reader.ReadUint16(&m.cipherSuite) || + !reader.ReadUint8(&compressionMethod) || compressionMethod != 0 || - !reader.readU16LengthPrefixed(&extensions) || + !reader.ReadUint16LengthPrefixed(&extensions) || len(reader) != 0 { return false } for len(extensions) > 0 { var extension uint16 - var body byteReader - if !extensions.readU16(&extension) || - !extensions.readU16LengthPrefixed(&body) { + var body cryptobyte.String + if !extensions.ReadUint16(&extension) || + !extensions.ReadUint16LengthPrefixed(&body) { return false } switch extension { case extensionSupportedVersions: - if !body.readU16(&m.vers) || + if !body.ReadUint16(&m.vers) || len(body) != 0 { return false } case extensionKeyShare: var v uint16 - if !body.readU16(&v) || len(body) != 0 { + if !body.ReadUint16(&v) || len(body) != 0 { return false } m.hasSelectedGroup = true m.selectedGroup = CurveID(v) case extensionCookie: - if !body.readU16LengthPrefixedBytes(&m.cookie) || + if !readUint16LengthPrefixedBytes(&body, &m.cookie) || len(m.cookie) == 0 || len(body) != 0 { return false @@ -2063,85 +1900,86 @@ func (m *certificateMsg) marshal() (x []byte) { return m.raw } - certMsg := newByteBuilder() - certMsg.addU8(typeCertificate) - certificate := certMsg.addU24LengthPrefixed() - if m.hasRequestContext { - context := certificate.addU8LengthPrefixed() - context.addBytes(m.requestContext) - } - certificateList := certificate.addU24LengthPrefixed() - for _, cert := range m.certificates { - certEntry := certificateList.addU24LengthPrefixed() - certEntry.addBytes(cert.data) + certMsg := cryptobyte.NewBuilder(nil) + certMsg.AddUint8(typeCertificate) + certMsg.AddUint24LengthPrefixed(func(certificate *cryptobyte.Builder) { if m.hasRequestContext { - extensions := certificateList.addU16LengthPrefixed() - count := 1 - if cert.duplicateExtensions { - count = 2 - } - - for i := 0; i < count; i++ { - if cert.ocspResponse != nil { - extensions.addU16(extensionStatusRequest) - body := extensions.addU16LengthPrefixed() - body.addU8(statusTypeOCSP) - response := body.addU24LengthPrefixed() - response.addBytes(cert.ocspResponse) - } - - if cert.sctList != nil { - extensions.addU16(extensionSignedCertificateTimestamp) - extension := extensions.addU16LengthPrefixed() - extension.addBytes(cert.sctList) + addUint8LengthPrefixedBytes(certificate, m.requestContext) + } + certificate.AddUint24LengthPrefixed(func(certificateList *cryptobyte.Builder) { + for _, cert := range m.certificates { + addUint24LengthPrefixedBytes(certificateList, cert.data) + if m.hasRequestContext { + certificateList.AddUint16LengthPrefixed(func(extensions *cryptobyte.Builder) { + count := 1 + if cert.duplicateExtensions { + count = 2 + } + + for i := 0; i < count; i++ { + if cert.ocspResponse != nil { + extensions.AddUint16(extensionStatusRequest) + extensions.AddUint16LengthPrefixed(func(body *cryptobyte.Builder) { + body.AddUint8(statusTypeOCSP) + addUint24LengthPrefixedBytes(body, cert.ocspResponse) + }) + } + + if cert.sctList != nil { + extensions.AddUint16(extensionSignedCertificateTimestamp) + addUint16LengthPrefixedBytes(extensions, cert.sctList) + } + } + if cert.extraExtension != nil { + extensions.AddBytes(cert.extraExtension) + } + }) } } - if cert.extraExtension != nil { - extensions.addBytes(cert.extraExtension) - } - } - } + }) - m.raw = certMsg.finish() + }) + + m.raw = certMsg.BytesOrPanic() return m.raw } func (m *certificateMsg) unmarshal(data []byte) bool { m.raw = data - reader := byteReader(data[4:]) + reader := cryptobyte.String(data[4:]) - if m.hasRequestContext && !reader.readU8LengthPrefixedBytes(&m.requestContext) { + if m.hasRequestContext && !readUint8LengthPrefixedBytes(&reader, &m.requestContext) { return false } - var certs byteReader - if !reader.readU24LengthPrefixed(&certs) || len(reader) != 0 { + var certs cryptobyte.String + if !reader.ReadUint24LengthPrefixed(&certs) || len(reader) != 0 { return false } m.certificates = nil for len(certs) > 0 { var cert certificateEntry - if !certs.readU24LengthPrefixedBytes(&cert.data) { + if !readUint24LengthPrefixedBytes(&certs, &cert.data) { return false } if m.hasRequestContext { - var extensions byteReader - if !certs.readU16LengthPrefixed(&extensions) || !checkDuplicateExtensions(extensions) { + var extensions cryptobyte.String + if !certs.ReadUint16LengthPrefixed(&extensions) || !checkDuplicateExtensions(extensions) { return false } for len(extensions) > 0 { var extension uint16 - var body byteReader - if !extensions.readU16(&extension) || - !extensions.readU16LengthPrefixed(&body) { + var body cryptobyte.String + if !extensions.ReadUint16(&extension) || + !extensions.ReadUint16LengthPrefixed(&body) { return false } switch extension { case extensionStatusRequest: var statusType byte - if !body.readU8(&statusType) || + if !body.ReadUint8(&statusType) || statusType != statusTypeOCSP || - !body.readU24LengthPrefixedBytes(&cert.ocspResponse) || + !readUint24LengthPrefixedBytes(&body, &cert.ocspResponse) || len(body) != 0 { return false } @@ -2157,11 +1995,11 @@ func (m *certificateMsg) unmarshal(data []byte) bool { origBody := body var expectedCertVerifyAlgo, algorithm uint16 - if !body.readU32(&dc.lifetimeSecs) || - !body.readU16(&expectedCertVerifyAlgo) || - !body.readU24LengthPrefixedBytes(&dc.pkixPublicKey) || - !body.readU16(&algorithm) || - !body.readU16LengthPrefixedBytes(&dc.signature) || + if !body.ReadUint32(&dc.lifetimeSecs) || + !body.ReadUint16(&expectedCertVerifyAlgo) || + !readUint24LengthPrefixedBytes(&body, &dc.pkixPublicKey) || + !body.ReadUint16(&algorithm) || + !readUint16LengthPrefixedBytes(&body, &dc.signature) || len(body) != 0 { return false } @@ -2193,25 +2031,25 @@ func (m *compressedCertificateMsg) marshal() (x []byte) { return m.raw } - certMsg := newByteBuilder() - certMsg.addU8(typeCompressedCertificate) - certificate := certMsg.addU24LengthPrefixed() - certificate.addU16(m.algID) - certificate.addU24(int(m.uncompressedLength)) - compressed := certificate.addU24LengthPrefixed() - compressed.addBytes(m.compressed) + certMsg := cryptobyte.NewBuilder(nil) + certMsg.AddUint8(typeCompressedCertificate) + certMsg.AddUint24LengthPrefixed(func(certificate *cryptobyte.Builder) { + certificate.AddUint16(m.algID) + certificate.AddUint24(m.uncompressedLength) + addUint24LengthPrefixedBytes(certificate, m.compressed) + }) - m.raw = certMsg.finish() + m.raw = certMsg.BytesOrPanic() return m.raw } func (m *compressedCertificateMsg) unmarshal(data []byte) bool { m.raw = data - reader := byteReader(data[4:]) + reader := cryptobyte.String(data[4:]) - if !reader.readU16(&m.algID) || - !reader.readU24(&m.uncompressedLength) || - !reader.readU24LengthPrefixedBytes(&m.compressed) || + if !reader.ReadUint16(&m.algID) || + !reader.ReadUint24(&m.uncompressedLength) || + !readUint24LengthPrefixedBytes(&reader, &m.compressed) || len(reader) != 0 { return false } @@ -2232,10 +2070,10 @@ func (m *serverKeyExchangeMsg) marshal() []byte { if m.raw != nil { return m.raw } - msg := newByteBuilder() - msg.addU8(typeServerKeyExchange) - msg.addU24LengthPrefixed().addBytes(m.key) - m.raw = msg.finish() + msg := cryptobyte.NewBuilder(nil) + msg.AddUint8(typeServerKeyExchange) + addUint24LengthPrefixedBytes(msg, m.key) + m.raw = msg.BytesOrPanic() return m.raw } @@ -2261,12 +2099,13 @@ func (m *certificateStatusMsg) marshal() []byte { var x []byte if m.statusType == statusTypeOCSP { - msg := newByteBuilder() - msg.addU8(typeCertificateStatus) - body := msg.addU24LengthPrefixed() - body.addU8(statusTypeOCSP) - body.addU24LengthPrefixed().addBytes(m.response) - x = msg.finish() + msg := cryptobyte.NewBuilder(nil) + msg.AddUint8(typeCertificateStatus) + msg.AddUint24LengthPrefixed(func(body *cryptobyte.Builder) { + body.AddUint8(statusTypeOCSP) + addUint24LengthPrefixedBytes(body, m.response) + }) + x = msg.BytesOrPanic() } else { x = []byte{typeCertificateStatus, 0, 0, 1, m.statusType} } @@ -2277,10 +2116,10 @@ func (m *certificateStatusMsg) marshal() []byte { func (m *certificateStatusMsg) unmarshal(data []byte) bool { m.raw = data - reader := byteReader(data[4:]) - if !reader.readU8(&m.statusType) || + reader := cryptobyte.String(data[4:]) + if !reader.ReadUint8(&m.statusType) || m.statusType != statusTypeOCSP || - !reader.readU24LengthPrefixedBytes(&m.response) || + !readUint24LengthPrefixedBytes(&reader, &m.response) || len(reader) != 0 { return false } @@ -2308,10 +2147,10 @@ func (m *clientKeyExchangeMsg) marshal() []byte { if m.raw != nil { return m.raw } - msg := newByteBuilder() - msg.addU8(typeClientKeyExchange) - msg.addU24LengthPrefixed().addBytes(m.ciphertext) - m.raw = msg.finish() + msg := cryptobyte.NewBuilder(nil) + msg.AddUint8(typeClientKeyExchange) + addUint24LengthPrefixedBytes(msg, m.ciphertext) + m.raw = msg.BytesOrPanic() return m.raw } @@ -2338,10 +2177,10 @@ func (m *finishedMsg) marshal() []byte { return m.raw } - msg := newByteBuilder() - msg.addU8(typeFinished) - msg.addU24LengthPrefixed().addBytes(m.verifyData) - m.raw = msg.finish() + msg := cryptobyte.NewBuilder(nil) + msg.AddUint8(typeFinished) + addUint24LengthPrefixedBytes(msg, m.verifyData) + m.raw = msg.BytesOrPanic() return m.raw } @@ -2366,21 +2205,22 @@ func (m *nextProtoMsg) marshal() []byte { padding := 32 - (len(m.proto)+2)%32 - msg := newByteBuilder() - msg.addU8(typeNextProtocol) - body := msg.addU24LengthPrefixed() - body.addU8LengthPrefixed().addBytes([]byte(m.proto)) - body.addU8LengthPrefixed().addBytes(make([]byte, padding)) - m.raw = msg.finish() + msg := cryptobyte.NewBuilder(nil) + msg.AddUint8(typeNextProtocol) + msg.AddUint24LengthPrefixed(func(body *cryptobyte.Builder) { + addUint8LengthPrefixedBytes(body, []byte(m.proto)) + addUint8LengthPrefixedBytes(body, make([]byte, padding)) + }) + m.raw = msg.BytesOrPanic() return m.raw } func (m *nextProtoMsg) unmarshal(data []byte) bool { m.raw = data - reader := byteReader(data[4:]) + reader := cryptobyte.String(data[4:]) var proto, padding []byte - if !reader.readU8LengthPrefixedBytes(&proto) || - !reader.readU8LengthPrefixedBytes(&padding) || + if !readUint8LengthPrefixedBytes(&reader, &proto) || + !readUint8LengthPrefixedBytes(&reader, &padding) || len(reader) != 0 { return false } @@ -2427,72 +2267,79 @@ func (m *certificateRequestMsg) marshal() []byte { } // See http://tools.ietf.org/html/rfc4346#section-7.4.4 - builder := newByteBuilder() - builder.addU8(typeCertificateRequest) - body := builder.addU24LengthPrefixed() - - if m.hasRequestContext { - requestContext := body.addU8LengthPrefixed() - requestContext.addBytes(m.requestContext) - extensions := newByteBuilder() - extensions = body.addU16LengthPrefixed() - if m.hasSignatureAlgorithm { - extensions.addU16(extensionSignatureAlgorithms) - signatureAlgorithms := extensions.addU16LengthPrefixed().addU16LengthPrefixed() - for _, sigAlg := range m.signatureAlgorithms { - signatureAlgorithms.addU16(uint16(sigAlg)) - } - } - if len(m.signatureAlgorithmsCert) > 0 { - extensions.addU16(extensionSignatureAlgorithmsCert) - signatureAlgorithmsCert := extensions.addU16LengthPrefixed().addU16LengthPrefixed() - for _, sigAlg := range m.signatureAlgorithmsCert { - signatureAlgorithmsCert.addU16(uint16(sigAlg)) - } - } - if len(m.certificateAuthorities) > 0 { - extensions.addU16(extensionCertificateAuthorities) - certificateAuthorities := extensions.addU16LengthPrefixed().addU16LengthPrefixed() - for _, ca := range m.certificateAuthorities { - caEntry := certificateAuthorities.addU16LengthPrefixed() - caEntry.addBytes(ca) - } - } + builder := cryptobyte.NewBuilder(nil) + builder.AddUint8(typeCertificateRequest) + builder.AddUint24LengthPrefixed(func(body *cryptobyte.Builder) { + if m.hasRequestContext { + addUint8LengthPrefixedBytes(body, m.requestContext) + body.AddUint16LengthPrefixed(func(extensions *cryptobyte.Builder) { + if m.hasSignatureAlgorithm { + extensions.AddUint16(extensionSignatureAlgorithms) + extensions.AddUint16LengthPrefixed(func(extension *cryptobyte.Builder) { + extension.AddUint16LengthPrefixed(func(signatureAlgorithms *cryptobyte.Builder) { + for _, sigAlg := range m.signatureAlgorithms { + signatureAlgorithms.AddUint16(uint16(sigAlg)) + } + }) + }) + } + if len(m.signatureAlgorithmsCert) > 0 { + extensions.AddUint16(extensionSignatureAlgorithmsCert) + extensions.AddUint16LengthPrefixed(func(extension *cryptobyte.Builder) { + extension.AddUint16LengthPrefixed(func(signatureAlgorithmsCert *cryptobyte.Builder) { + for _, sigAlg := range m.signatureAlgorithmsCert { + signatureAlgorithmsCert.AddUint16(uint16(sigAlg)) + } + }) + }) + } + if len(m.certificateAuthorities) > 0 { + extensions.AddUint16(extensionCertificateAuthorities) + extensions.AddUint16LengthPrefixed(func(extension *cryptobyte.Builder) { + extension.AddUint16LengthPrefixed(func(certificateAuthorities *cryptobyte.Builder) { + for _, ca := range m.certificateAuthorities { + addUint16LengthPrefixedBytes(certificateAuthorities, ca) + } + }) + }) + } - if m.customExtension > 0 { - extensions.addU16(m.customExtension) - extensions.addU16LengthPrefixed() - } - } else { - certificateTypes := body.addU8LengthPrefixed() - certificateTypes.addBytes(m.certificateTypes) + if m.customExtension > 0 { + extensions.AddUint16(m.customExtension) + extensions.AddUint16(0) // Empty extension + } + }) + } else { + addUint8LengthPrefixedBytes(body, m.certificateTypes) - if m.hasSignatureAlgorithm { - signatureAlgorithms := body.addU16LengthPrefixed() - for _, sigAlg := range m.signatureAlgorithms { - signatureAlgorithms.addU16(uint16(sigAlg)) + if m.hasSignatureAlgorithm { + body.AddUint16LengthPrefixed(func(signatureAlgorithms *cryptobyte.Builder) { + for _, sigAlg := range m.signatureAlgorithms { + signatureAlgorithms.AddUint16(uint16(sigAlg)) + } + }) } - } - certificateAuthorities := body.addU16LengthPrefixed() - for _, ca := range m.certificateAuthorities { - caEntry := certificateAuthorities.addU16LengthPrefixed() - caEntry.addBytes(ca) + body.AddUint16LengthPrefixed(func(certificateAuthorities *cryptobyte.Builder) { + for _, ca := range m.certificateAuthorities { + addUint16LengthPrefixedBytes(certificateAuthorities, ca) + } + }) } - } + }) - m.raw = builder.finish() + m.raw = builder.BytesOrPanic() return m.raw } -func parseCAs(reader *byteReader, out *[][]byte) bool { - var cas byteReader - if !reader.readU16LengthPrefixed(&cas) { +func parseCAs(reader *cryptobyte.String, out *[][]byte) bool { + var cas cryptobyte.String + if !reader.ReadUint16LengthPrefixed(&cas) { return false } for len(cas) > 0 { var ca []byte - if !cas.readU16LengthPrefixedBytes(&ca) { + if !readUint16LengthPrefixedBytes(&cas, &ca) { return false } *out = append(*out, ca) @@ -2502,21 +2349,21 @@ func parseCAs(reader *byteReader, out *[][]byte) bool { func (m *certificateRequestMsg) unmarshal(data []byte) bool { m.raw = data - reader := byteReader(data[4:]) + reader := cryptobyte.String(data[4:]) if m.hasRequestContext { - var extensions byteReader - if !reader.readU8LengthPrefixedBytes(&m.requestContext) || - !reader.readU16LengthPrefixed(&extensions) || + var extensions cryptobyte.String + if !readUint8LengthPrefixedBytes(&reader, &m.requestContext) || + !reader.ReadUint16LengthPrefixed(&extensions) || len(reader) != 0 || !checkDuplicateExtensions(extensions) { return false } for len(extensions) > 0 { var extension uint16 - var body byteReader - if !extensions.readU16(&extension) || - !extensions.readU16LengthPrefixed(&body) { + var body cryptobyte.String + if !extensions.ReadUint16(&extension) || + !extensions.ReadUint16LengthPrefixed(&body) { return false } switch extension { @@ -2536,7 +2383,7 @@ func (m *certificateRequestMsg) unmarshal(data []byte) bool { } } } else { - if !reader.readU8LengthPrefixedBytes(&m.certificateTypes) { + if !readUint8LengthPrefixedBytes(&reader, &m.certificateTypes) { return false } // In TLS 1.2, the supported_signature_algorithms field in @@ -2648,35 +2495,40 @@ func (m *newSessionTicketMsg) marshal() []byte { } // See http://tools.ietf.org/html/rfc5077#section-3.3 - ticketMsg := newByteBuilder() - ticketMsg.addU8(typeNewSessionTicket) - body := ticketMsg.addU24LengthPrefixed() - body.addU32(m.ticketLifetime) - if version >= VersionTLS13 { - body.addU32(m.ticketAgeAdd) - body.addU8LengthPrefixed().addBytes(m.ticketNonce) - } - - ticket := body.addU16LengthPrefixed() - ticket.addBytes(m.ticket) - - if version >= VersionTLS13 { - extensions := body.addU16LengthPrefixed() - if m.maxEarlyDataSize > 0 { - extensions.addU16(extensionEarlyData) - extensions.addU16LengthPrefixed().addU32(m.maxEarlyDataSize) - if m.duplicateEarlyDataExtension { - extensions.addU16(extensionEarlyData) - extensions.addU16LengthPrefixed().addU32(m.maxEarlyDataSize) - } - } - if len(m.customExtension) > 0 { - extensions.addU16(extensionCustom) - extensions.addU16LengthPrefixed().addBytes([]byte(m.customExtension)) + ticketMsg := cryptobyte.NewBuilder(nil) + ticketMsg.AddUint8(typeNewSessionTicket) + ticketMsg.AddUint24LengthPrefixed(func(body *cryptobyte.Builder) { + body.AddUint32(m.ticketLifetime) + if version >= VersionTLS13 { + body.AddUint32(m.ticketAgeAdd) + addUint8LengthPrefixedBytes(body, m.ticketNonce) + } + + addUint16LengthPrefixedBytes(body, m.ticket) + + if version >= VersionTLS13 { + body.AddUint16LengthPrefixed(func(extensions *cryptobyte.Builder) { + if m.maxEarlyDataSize > 0 { + extensions.AddUint16(extensionEarlyData) + extensions.AddUint16LengthPrefixed(func(child *cryptobyte.Builder) { + child.AddUint32(m.maxEarlyDataSize) + }) + if m.duplicateEarlyDataExtension { + extensions.AddUint16(extensionEarlyData) + extensions.AddUint16LengthPrefixed(func(child *cryptobyte.Builder) { + child.AddUint32(m.maxEarlyDataSize) + }) + } + } + if len(m.customExtension) > 0 { + extensions.AddUint16(extensionCustom) + addUint16LengthPrefixedBytes(extensions, []byte(m.customExtension)) + } + }) } - } + }) - m.raw = ticketMsg.finish() + m.raw = ticketMsg.BytesOrPanic() return m.raw } diff --git a/ssl/test/runner/handshake_server.go b/ssl/test/runner/handshake_server.go index d3ecf3bbd7..da39432868 100644 --- a/ssl/test/runner/handshake_server.go +++ b/ssl/test/runner/handshake_server.go @@ -20,6 +20,7 @@ import ( "time" "boringssl.googlesource.com/boringssl/ssl/test/runner/hpke" + "golang.org/x/crypto/cryptobyte" ) // serverHandshakeState contains details of a server handshake in progress. @@ -2435,18 +2436,18 @@ func checkClientHellosEqual(a, b []byte, isDTLS bool, ignoreExtensions []uint16) } // Skip the handshake message header. - aReader := byteReader(a[4:]) - bReader := byteReader(b[4:]) + aReader := cryptobyte.String(a[4:]) + bReader := cryptobyte.String(b[4:]) var aVers, bVers uint16 var aRandom, bRandom []byte var aSessionID, bSessionID []byte - if !aReader.readU16(&aVers) || - !bReader.readU16(&bVers) || - !aReader.readBytes(&aRandom, 32) || - !bReader.readBytes(&bRandom, 32) || - !aReader.readU8LengthPrefixedBytes(&aSessionID) || - !bReader.readU8LengthPrefixedBytes(&bSessionID) { + if !aReader.ReadUint16(&aVers) || + !bReader.ReadUint16(&bVers) || + !aReader.ReadBytes(&aRandom, 32) || + !bReader.ReadBytes(&bRandom, 32) || + !readUint8LengthPrefixedBytes(&aReader, &aSessionID) || + !readUint8LengthPrefixedBytes(&bReader, &bSessionID) { return errors.New("tls: could not parse ClientHello") } @@ -2466,17 +2467,17 @@ func checkClientHellosEqual(a, b []byte, isDTLS bool, ignoreExtensions []uint16) // cookie altogether. If we implement DTLS 1.3, we'll need to ensure // that parsing logic above this function rejects this cookie. var aCookie, bCookie []byte - if !aReader.readU8LengthPrefixedBytes(&aCookie) || - !bReader.readU8LengthPrefixedBytes(&bCookie) { + if !readUint8LengthPrefixedBytes(&aReader, &aCookie) || + !readUint8LengthPrefixedBytes(&bReader, &bCookie) { return errors.New("tls: could not parse ClientHello") } } var aCipherSuites, bCipherSuites, aCompressionMethods, bCompressionMethods []byte - if !aReader.readU16LengthPrefixedBytes(&aCipherSuites) || - !bReader.readU16LengthPrefixedBytes(&bCipherSuites) || - !aReader.readU8LengthPrefixedBytes(&aCompressionMethods) || - !bReader.readU8LengthPrefixedBytes(&bCompressionMethods) { + if !readUint16LengthPrefixedBytes(&aReader, &aCipherSuites) || + !readUint16LengthPrefixedBytes(&bReader, &bCipherSuites) || + !readUint8LengthPrefixedBytes(&aReader, &aCompressionMethods) || + !readUint8LengthPrefixedBytes(&bReader, &bCompressionMethods) { return errors.New("tls: could not parse ClientHello") } if !bytes.Equal(aCipherSuites, bCipherSuites) { @@ -2491,9 +2492,9 @@ func checkClientHellosEqual(a, b []byte, isDTLS bool, ignoreExtensions []uint16) return nil } - var aExtensions, bExtensions byteReader - if !aReader.readU16LengthPrefixed(&aExtensions) || - !bReader.readU16LengthPrefixed(&bExtensions) || + var aExtensions, bExtensions cryptobyte.String + if !aReader.ReadUint16LengthPrefixed(&aExtensions) || + !bReader.ReadUint16LengthPrefixed(&bExtensions) || len(aReader) != 0 || len(bReader) != 0 { return errors.New("tls: could not parse ClientHello") @@ -2502,8 +2503,8 @@ func checkClientHellosEqual(a, b []byte, isDTLS bool, ignoreExtensions []uint16) for len(aExtensions) != 0 { var aID uint16 var aBody []byte - if !aExtensions.readU16(&aID) || - !aExtensions.readU16LengthPrefixedBytes(&aBody) { + if !aExtensions.ReadUint16(&aID) || + !readUint16LengthPrefixedBytes(&aExtensions, &aBody) { return errors.New("tls: could not parse ClientHello") } if _, ok := ignoreExtensionsSet[aID]; ok { @@ -2516,8 +2517,8 @@ func checkClientHellosEqual(a, b []byte, isDTLS bool, ignoreExtensions []uint16) } var bID uint16 var bBody []byte - if !bExtensions.readU16(&bID) || - !bExtensions.readU16LengthPrefixedBytes(&bBody) { + if !bExtensions.ReadUint16(&bID) || + !readUint16LengthPrefixedBytes(&bExtensions, &bBody) { return errors.New("tls: could not parse ClientHello") } if _, ok := ignoreExtensionsSet[bID]; ok { @@ -2538,8 +2539,8 @@ func checkClientHellosEqual(a, b []byte, isDTLS bool, ignoreExtensions []uint16) for len(bExtensions) != 0 { var id uint16 var body []byte - if !bExtensions.readU16(&id) || - !bExtensions.readU16LengthPrefixedBytes(&body) { + if !bExtensions.ReadUint16(&id) || + !readUint16LengthPrefixedBytes(&bExtensions, &body) { return errors.New("tls: could not parse ClientHello") } if _, ok := ignoreExtensionsSet[id]; !ok { diff --git a/ssl/test/runner/prf.go b/ssl/test/runner/prf.go index fc67d7503b..4cdc7c8c0b 100644 --- a/ssl/test/runner/prf.go +++ b/ssl/test/runner/prf.go @@ -13,6 +13,7 @@ import ( "encoding" "hash" + "golang.org/x/crypto/cryptobyte" "golang.org/x/crypto/hkdf" ) @@ -228,15 +229,15 @@ type finishedHash struct { } func (h *finishedHash) UpdateForHelloRetryRequest() { - data := newByteBuilder() - data.addU8(typeMessageHash) - data.addU24(h.hash.Size()) - data.addBytes(h.Sum()) + data := cryptobyte.NewBuilder(nil) + data.AddUint8(typeMessageHash) + data.AddUint24(uint32(h.hash.Size())) + data.AddBytes(h.Sum()) h.hash = h.suite.hash().New() if h.buffer != nil { h.buffer = []byte{} } - h.Write(data.finish()) + h.Write(data.BytesOrPanic()) } func (h *finishedHash) Write(msg []byte) (n int, err error) { diff --git a/ssl/test/runner/runner.go b/ssl/test/runner/runner.go index 0110580e44..9c42c16a65 100644 --- a/ssl/test/runner/runner.go +++ b/ssl/test/runner/runner.go @@ -48,6 +48,7 @@ import ( "boringssl.googlesource.com/boringssl/ssl/test/runner/hpke" "boringssl.googlesource.com/boringssl/ssl/test/runner/ssl_transfer" "boringssl.googlesource.com/boringssl/util/testresult" + "golang.org/x/crypto/cryptobyte" ) var ( @@ -860,10 +861,10 @@ func doExchange(test *testCase, config *Config, conn net.Conn, isResume bool, tr if err := os.MkdirAll(dir, 0755); err != nil { return err } - bb := newByteBuilder() - bb.addU24LengthPrefixed().addBytes(encodedInner) - bb.addBytes(outer) - return os.WriteFile(filepath.Join(dir, name), bb.finish(), 0644) + bb := cryptobyte.NewBuilder(nil) + addUint24LengthPrefixedBytes(bb, encodedInner) + bb.AddBytes(outer) + return os.WriteFile(filepath.Join(dir, name), bb.BytesOrPanic(), 0644) } } diff --git a/ssl/test/runner/ticket.go b/ssl/test/runner/ticket.go index 46a6b3579b..f0a8bf18ad 100644 --- a/ssl/test/runner/ticket.go +++ b/ssl/test/runner/ticket.go @@ -13,6 +13,8 @@ import ( "errors" "io" "time" + + "golang.org/x/crypto/cryptobyte" ) // sessionState contains the information that is serialized into a session @@ -35,49 +37,45 @@ type sessionState struct { } func (s *sessionState) marshal() []byte { - msg := newByteBuilder() - msg.addU16(s.vers) - msg.addU16(s.cipherSuite) - secret := msg.addU16LengthPrefixed() - secret.addBytes(s.secret) - handshakeHash := msg.addU16LengthPrefixed() - handshakeHash.addBytes(s.handshakeHash) - msg.addU16(uint16(len(s.certificates))) + msg := cryptobyte.NewBuilder(nil) + msg.AddUint16(s.vers) + msg.AddUint16(s.cipherSuite) + addUint16LengthPrefixedBytes(msg, s.secret) + addUint16LengthPrefixedBytes(msg, s.handshakeHash) + msg.AddUint16(uint16(len(s.certificates))) for _, cert := range s.certificates { - certMsg := msg.addU32LengthPrefixed() - certMsg.addBytes(cert) + addUint24LengthPrefixedBytes(msg, cert) } if s.extendedMasterSecret { - msg.addU8(1) + msg.AddUint8(1) } else { - msg.addU8(0) + msg.AddUint8(0) } if s.vers >= VersionTLS13 { - msg.addU64(uint64(s.ticketCreationTime.UnixNano())) - msg.addU64(uint64(s.ticketExpiration.UnixNano())) - msg.addU32(s.ticketFlags) - msg.addU32(s.ticketAgeAdd) + msg.AddUint64(uint64(s.ticketCreationTime.UnixNano())) + msg.AddUint64(uint64(s.ticketExpiration.UnixNano())) + msg.AddUint32(s.ticketFlags) + msg.AddUint32(s.ticketAgeAdd) } - earlyALPN := msg.addU16LengthPrefixed() - earlyALPN.addBytes(s.earlyALPN) + addUint16LengthPrefixedBytes(msg, s.earlyALPN) if s.hasApplicationSettings { - msg.addU8(1) - msg.addU16LengthPrefixed().addBytes(s.localApplicationSettings) - msg.addU16LengthPrefixed().addBytes(s.peerApplicationSettings) + msg.AddUint8(1) + addUint16LengthPrefixedBytes(msg, s.localApplicationSettings) + addUint16LengthPrefixedBytes(msg, s.peerApplicationSettings) } else { - msg.addU8(0) + msg.AddUint8(0) } - return msg.finish() + return msg.BytesOrPanic() } -func readBool(reader *byteReader, out *bool) bool { +func readBool(reader *cryptobyte.String, out *bool) bool { var value uint8 - if !reader.readU8(&value) { + if !reader.ReadUint8(&value) { return false } if value == 0 { @@ -92,19 +90,19 @@ func readBool(reader *byteReader, out *bool) bool { } func (s *sessionState) unmarshal(data []byte) bool { - reader := byteReader(data) + reader := cryptobyte.String(data) var numCerts uint16 - if !reader.readU16(&s.vers) || - !reader.readU16(&s.cipherSuite) || - !reader.readU16LengthPrefixedBytes(&s.secret) || - !reader.readU16LengthPrefixedBytes(&s.handshakeHash) || - !reader.readU16(&numCerts) { + if !reader.ReadUint16(&s.vers) || + !reader.ReadUint16(&s.cipherSuite) || + !readUint16LengthPrefixedBytes(&reader, &s.secret) || + !readUint16LengthPrefixedBytes(&reader, &s.handshakeHash) || + !reader.ReadUint16(&numCerts) { return false } s.certificates = make([][]byte, int(numCerts)) for i := range s.certificates { - if !reader.readU32LengthPrefixedBytes(&s.certificates[i]) { + if !readUint24LengthPrefixedBytes(&reader, &s.certificates[i]) { return false } } @@ -115,24 +113,24 @@ func (s *sessionState) unmarshal(data []byte) bool { if s.vers >= VersionTLS13 { var ticketCreationTime, ticketExpiration uint64 - if !reader.readU64(&ticketCreationTime) || - !reader.readU64(&ticketExpiration) || - !reader.readU32(&s.ticketFlags) || - !reader.readU32(&s.ticketAgeAdd) { + if !reader.ReadUint64(&ticketCreationTime) || + !reader.ReadUint64(&ticketExpiration) || + !reader.ReadUint32(&s.ticketFlags) || + !reader.ReadUint32(&s.ticketAgeAdd) { return false } s.ticketCreationTime = time.Unix(0, int64(ticketCreationTime)) s.ticketExpiration = time.Unix(0, int64(ticketExpiration)) } - if !reader.readU16LengthPrefixedBytes(&s.earlyALPN) || + if !readUint16LengthPrefixedBytes(&reader, &s.earlyALPN) || !readBool(&reader, &s.hasApplicationSettings) { return false } if s.hasApplicationSettings { - if !reader.readU16LengthPrefixedBytes(&s.localApplicationSettings) || - !reader.readU16LengthPrefixedBytes(&s.peerApplicationSettings) { + if !readUint16LengthPrefixedBytes(&reader, &s.localApplicationSettings) || + !readUint16LengthPrefixedBytes(&reader, &s.peerApplicationSettings) { return false } } diff --git a/util/all_tests.go b/util/all_tests.go index 5136222d64..a931c755d5 100644 --- a/util/all_tests.go +++ b/util/all_tests.go @@ -340,12 +340,7 @@ func (t test) envMsg() string { } func (t test) getGTestShards() ([]test, error) { - if *numWorkers == 1 || len(t.Cmd) != 1 { - return []test{t}, nil - } - - // Only shard the three GTest-based tests. - if t.Cmd[0] != "crypto/crypto_test" && t.Cmd[0] != "ssl/ssl_test" { + if *numWorkers == 1 || !t.Shard { return []test{t}, nil } diff --git a/util/all_tests.json b/util/all_tests.json index 76434d240f..c2baaa4551 100644 --- a/util/all_tests.json +++ b/util/all_tests.json @@ -1,7 +1,8 @@ [ { "cmd": ["crypto/crypto_test"], - "valgrind_supp": ["valgrind_suppressions_crypto_test.supp"] + "valgrind_supp": ["valgrind_suppressions_crypto_test.supp"], + "shard": true }, { "cmd": ["crypto/crypto_test", "--gtest_also_run_disabled_tests", "--gtest_filter=BNTest.DISABLED_WycheproofPrimality"], @@ -16,19 +17,22 @@ "cmd": ["crypto/crypto_test"], "env": ["OPENSSL_armcap=0x0"], "target_arch": "arm", - "skip_valgrind": true + "skip_valgrind": true, + "shard": true }, { "cmd": ["crypto/crypto_test"], "env": ["OPENSSL_armcap=0x1"], "target_arch": "arm", - "skip_valgrind": true + "skip_valgrind": true, + "shard": true }, { "cmd": ["crypto/crypto_test"], "env": ["OPENSSL_armcap=0x3D"], "target_arch": "arm", - "skip_valgrind": true + "skip_valgrind": true, + "shard": true }, { "comment": "Test OPENSSL_ia32cap on crypto_test for x86, as urandom_test is disabled for shared builds on x86", @@ -93,7 +97,8 @@ "skip_valgrind": true }, { - "cmd": ["ssl/ssl_test"] + "cmd": ["ssl/ssl_test"], + "shard": true }, { "cmd": ["ssl/integration_test"] diff --git a/util/testconfig/testconfig.go b/util/testconfig/testconfig.go index e41c7afc65..999c70abba 100644 --- a/util/testconfig/testconfig.go +++ b/util/testconfig/testconfig.go @@ -25,7 +25,8 @@ type Test struct { SkipSDE bool `json:"skip_sde"` SkipValgrind bool `json:"skip_valgrind"` ValgrindSupp []string `json:"valgrind_supp"` - TargetArch string `json:"target_arch"` + TargetArch string `json:"target_arch"` + Shard bool `json:"shard"` } func ParseTestConfig(filename string) ([]Test, error) {