diff --git a/ec.go b/ec.go index 03c51e5..734c14b 100644 --- a/ec.go +++ b/ec.go @@ -5,18 +5,35 @@ package openssl // #include "goopenssl.h" import "C" -func curveNID(curve string) (C.int, error) { +func curveNID(curve string) C.int { switch curve { case "P-224": - return C.GO_NID_secp224r1, nil + return C.GO_NID_secp224r1 case "P-256": - return C.GO_NID_X9_62_prime256v1, nil + return C.GO_NID_X9_62_prime256v1 case "P-384": - return C.GO_NID_secp384r1, nil + return C.GO_NID_secp384r1 case "P-521": - return C.GO_NID_secp521r1, nil + return C.GO_NID_secp521r1 + default: + panic("openssl: unknown curve " + curve) + } +} + +// curveSize returns the size of the curve in bytes. +func curveSize(curve string) int { + switch curve { + case "P-224": + return 224 / 8 + case "P-256": + return 256 / 8 + case "P-384": + return 384 / 8 + case "P-521": + return (521 + 7) / 8 + default: + panic("openssl: unknown curve " + curve) } - return 0, errUnknownCurve } // encodeEcPoint encodes pt. diff --git a/ecdh.go b/ecdh.go index 5b14674..ad392dc 100644 --- a/ecdh.go +++ b/ecdh.go @@ -7,6 +7,7 @@ import "C" import ( "errors" "runtime" + "slices" "unsafe" ) @@ -20,9 +21,8 @@ func (k *PublicKeyECDH) finalize() { } type PrivateKeyECDH struct { - _pkey C.GO_EVP_PKEY_PTR - curve string - hasPublicKey bool + _pkey C.GO_EVP_PKEY_PTR + curve string } func (k *PrivateKeyECDH) finalize() { @@ -30,14 +30,14 @@ func (k *PrivateKeyECDH) finalize() { } func NewPublicKeyECDH(curve string, bytes []byte) (*PublicKeyECDH, error) { - if len(bytes) < 1 { - return nil, errors.New("NewPublicKeyECDH: missing key") + if len(bytes) != 1+2*curveSize(curve) { + return nil, errors.New("NewPublicKeyECDH: wrong key length") } pkey, err := newECDHPkey(curve, bytes, false) if err != nil { return nil, err } - k := &PublicKeyECDH{pkey, append([]byte(nil), bytes...)} + k := &PublicKeyECDH{pkey, slices.Clone(bytes)} runtime.SetFinalizer(k, (*PublicKeyECDH).finalize) return k, nil } @@ -45,24 +45,20 @@ func NewPublicKeyECDH(curve string, bytes []byte) (*PublicKeyECDH, error) { func (k *PublicKeyECDH) Bytes() []byte { return k.bytes } func NewPrivateKeyECDH(curve string, bytes []byte) (*PrivateKeyECDH, error) { + if len(bytes) != curveSize(curve) { + return nil, errors.New("NewPrivateKeyECDH: wrong key length") + } pkey, err := newECDHPkey(curve, bytes, true) if err != nil { return nil, err } - k := &PrivateKeyECDH{pkey, curve, false} + k := &PrivateKeyECDH{pkey, curve} runtime.SetFinalizer(k, (*PrivateKeyECDH).finalize) return k, nil } func (k *PrivateKeyECDH) PublicKey() (*PublicKeyECDH, error) { defer runtime.KeepAlive(k) - if !k.hasPublicKey { - err := deriveEcdhPublicKey(k._pkey, k.curve) - if err != nil { - return nil, err - } - k.hasPublicKey = true - } var pkey C.GO_EVP_PKEY_PTR defer func() { C.go_openssl_EVP_PKEY_free(pkey) @@ -112,10 +108,7 @@ func (k *PrivateKeyECDH) PublicKey() (*PublicKeyECDH, error) { } func newECDHPkey(curve string, bytes []byte, isPrivate bool) (C.GO_EVP_PKEY_PTR, error) { - nid, err := curveNID(curve) - if err != nil { - return nil, err - } + nid := curveNID(curve) switch vMajor { case 1: return newECDHPkey1(nid, bytes, isPrivate) @@ -138,6 +131,7 @@ func newECDHPkey1(nid C.int, bytes []byte, isPrivate bool) (pkey C.GO_EVP_PKEY_P C.go_openssl_EC_KEY_free(key) } }() + group := C.go_openssl_EC_KEY_get0_group(key) if isPrivate { priv := C.go_openssl_BN_bin2bn(base(bytes), C.int(len(bytes)), nil) if priv == nil { @@ -147,8 +141,15 @@ func newECDHPkey1(nid C.int, bytes []byte, isPrivate bool) (pkey C.GO_EVP_PKEY_P if C.go_openssl_EC_KEY_set_private_key(key, priv) != 1 { return nil, newOpenSSLError("EC_KEY_set_private_key") } + pub, err := pointMult(group, priv) + if err != nil { + return nil, err + } + defer C.go_openssl_EC_POINT_free(pub) + if C.go_openssl_EC_KEY_set_public_key(key, pub) != 1 { + return nil, newOpenSSLError("EC_KEY_set_public_key") + } } else { - group := C.go_openssl_EC_KEY_get0_group(key) pub := C.go_openssl_EC_POINT_new(group) if pub == nil { return nil, newOpenSSLError("EC_POINT_new") @@ -161,6 +162,14 @@ func newECDHPkey1(nid C.int, bytes []byte, isPrivate bool) (pkey C.GO_EVP_PKEY_P return nil, newOpenSSLError("EC_KEY_set_public_key") } } + if C.go_openssl_EC_KEY_check_key(key) != 1 { + // Match upstream error message. + if isPrivate { + return nil, errors.New("crypto/ecdh: invalid private key") + } else { + return nil, errors.New("crypto/ecdh: invalid public key") + } + } return newEVPPKEY(key) } @@ -175,7 +184,19 @@ func newECDHPkey3(nid C.int, bytes []byte, isPrivate bool) (C.GO_EVP_PKEY_PTR, e bld.addUTF8String(_OSSL_PKEY_PARAM_GROUP_NAME, C.go_openssl_OBJ_nid2sn(nid), 0) var selection C.int if isPrivate { - bld.addBin(_OSSL_PKEY_PARAM_PRIV_KEY, bytes, true) + priv := C.go_openssl_BN_bin2bn(base(bytes), C.int(len(bytes)), nil) + if priv == nil { + return nil, newOpenSSLError("BN_bin2bn") + } + defer C.go_openssl_BN_clear_free(priv) + pubBytes, err := generateAndEncodeEcPublicKey(nid, func(group C.GO_EC_GROUP_PTR) (C.GO_EC_POINT_PTR, error) { + return pointMult(group, priv) + }) + if err != nil { + return nil, err + } + bld.addOctetString(_OSSL_PKEY_PARAM_PUB_KEY, pubBytes) + bld.addBN(_OSSL_PKEY_PARAM_PRIV_KEY, priv) selection = C.GO_EVP_PKEY_KEYPAIR } else { bld.addOctetString(_OSSL_PKEY_PARAM_PUB_KEY, bytes) @@ -187,62 +208,31 @@ func newECDHPkey3(nid C.int, bytes []byte, isPrivate bool) (C.GO_EVP_PKEY_PTR, e return nil, err } defer C.go_openssl_OSSL_PARAM_free(params) - return newEvpFromParams(C.GO_EVP_PKEY_EC, selection, params) + pkey, err := newEvpFromParams(C.GO_EVP_PKEY_EC, selection, params) + if err != nil { + return nil, err + } + + if err := checkPkey(pkey, isPrivate); err != nil { + C.go_openssl_EVP_PKEY_free(pkey) + return nil, errors.New("crypto/ecdh: " + err.Error()) + } + return pkey, nil } -// deriveEcdhPublicKey sets the raw public key of pkey by deriving it from -// the raw private key. -func deriveEcdhPublicKey(pkey C.GO_EVP_PKEY_PTR, curve string) error { - derive := func(group C.GO_EC_GROUP_PTR, priv C.GO_BIGNUM_PTR) (C.GO_EC_POINT_PTR, error) { - // OpenSSL does not expose any method to generate the public - // key from the private key [1], so we have to calculate it here. - // [1] https://github.com/openssl/openssl/issues/18437#issuecomment-1144717206 - pt := C.go_openssl_EC_POINT_new(group) - if pt == nil { - return nil, newOpenSSLError("EC_POINT_new") - } - if C.go_openssl_EC_POINT_mul(group, pt, priv, nil, nil, nil) == 0 { - C.go_openssl_EC_POINT_free(pt) - return nil, newOpenSSLError("EC_POINT_mul") - } - return pt, nil +func pointMult(group C.GO_EC_GROUP_PTR, priv C.GO_BIGNUM_PTR) (C.GO_EC_POINT_PTR, error) { + // OpenSSL does not expose any method to generate the public + // key from the private key [1], so we have to calculate it here. + // [1] https://github.com/openssl/openssl/issues/18437#issuecomment-1144717206 + pt := C.go_openssl_EC_POINT_new(group) + if pt == nil { + return nil, newOpenSSLError("EC_POINT_new") } - switch vMajor { - case 1: - key := getECKey(pkey) - priv := C.go_openssl_EC_KEY_get0_private_key(key) - if priv == nil { - return newOpenSSLError("EC_KEY_get0_private_key") - } - group := C.go_openssl_EC_KEY_get0_group(key) - pub, err := derive(group, priv) - if err != nil { - return err - } - defer C.go_openssl_EC_POINT_free(pub) - if C.go_openssl_EC_KEY_set_public_key(key, pub) != 1 { - return newOpenSSLError("EC_KEY_set_public_key") - } - case 3: - var priv C.GO_BIGNUM_PTR - if C.go_openssl_EVP_PKEY_get_bn_param(pkey, _OSSL_PKEY_PARAM_PRIV_KEY, &priv) != 1 { - return newOpenSSLError("EVP_PKEY_get_bn_param") - } - defer C.go_openssl_BN_clear_free(priv) - nid, _ := curveNID(curve) - pubBytes, err := generateAndEncodeEcPublicKey(nid, func(group C.GO_EC_GROUP_PTR) (C.GO_EC_POINT_PTR, error) { - return derive(group, priv) - }) - if err != nil { - return err - } - if C.go_openssl_EVP_PKEY_set1_encoded_public_key(pkey, base(pubBytes), C.size_t(len(pubBytes))) != 1 { - return newOpenSSLError("EVP_PKEY_set1_encoded_public_key") - } - default: - panic(errUnsupportedVersion()) + if C.go_openssl_EC_POINT_mul(group, pt, priv, nil, nil, nil) == 0 { + C.go_openssl_EC_POINT_free(pt) + return nil, newOpenSSLError("EC_POINT_mul") } - return nil + return pt, nil } func ECDH(priv *PrivateKeyECDH, pub *PublicKeyECDH) ([]byte, error) { @@ -307,7 +297,7 @@ func GenerateKeyECDH(curve string) (*PrivateKeyECDH, []byte, error) { if err := bnToBinPad(priv, bytes); err != nil { return nil, nil, err } - k = &PrivateKeyECDH{pkey, curve, true} + k = &PrivateKeyECDH{pkey, curve} runtime.SetFinalizer(k, (*PrivateKeyECDH).finalize) return k, bytes, nil } diff --git a/ecdh_test.go b/ecdh_test.go index c83a347..da1e37f 100644 --- a/ecdh_test.go +++ b/ecdh_test.go @@ -3,6 +3,7 @@ package openssl_test import ( "bytes" "encoding/hex" + "strings" "testing" "github.com/golang-fips/openssl/v2" @@ -171,3 +172,124 @@ func BenchmarkECDH(b *testing.B) { } } } + +var invalidECDHPrivateKeys = map[string][]string{ + "P-256": { + // Bad lengths. + "", + "01", + "01010101010101010101010101010101010101010101010101010101010101", + "000101010101010101010101010101010101010101010101010101010101010101", + strings.Repeat("01", 200), + // Zero. + "0000000000000000000000000000000000000000000000000000000000000000", + // Order of the curve and above. + "ffffffff00000000ffffffffffffffffbce6faada7179e84f3b9cac2fc632551", + "ffffffff00000000ffffffffffffffffbce6faada7179e84f3b9cac2fc632552", + "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", + }, + "P-384": { + // Bad lengths. + "", + "01", + "0101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101", + "00010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101", + strings.Repeat("01", 200), + // Zero. + "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000", + // Order of the curve and above. + "ffffffffffffffffffffffffffffffffffffffffffffffffc7634d81f4372ddf581a0db248b0a77aecec196accc52973", + "ffffffffffffffffffffffffffffffffffffffffffffffffc7634d81f4372ddf581a0db248b0a77aecec196accc52974", + "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", + }, + "P-521": { + // Bad lengths. + "", + "01", + "0101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101", + "00010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101", + strings.Repeat("01", 200), + // Zero. + "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000", + // Order of the curve and above. + "01fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffa51868783bf2f966b7fcc0148f709a5d03bb5c9b8899c47aebb6fb71e91386409", + "01fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffa51868783bf2f966b7fcc0148f709a5d03bb5c9b8899c47aebb6fb71e9138640a", + "11fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffa51868783bf2f966b7fcc0148f709a5d03bb5c9b8899c47aebb6fb71e91386409", + "03fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff4a30d0f077e5f2cd6ff980291ee134ba0776b937113388f5d76df6e3d2270c812", + }, +} + +var invalidECDHPublicKeys = map[string][]string{ + "P-256": { + // Bad lengths. + "", + "04", + strings.Repeat("04", 200), + // Infinity. + "00", + // Compressed encodings. + "036b17d1f2e12c4247f8bce6e563a440f277037d812deb33a0f4a13945d898c296", + "02e2534a3532d08fbba02dde659ee62bd0031fe2db785596ef509302446b030852", + // Points not on the curve. + "046b17d1f2e12c4247f8bce6e563a440f277037d812deb33a0f4a13945d898c2964fe342e2fe1a7f9b8ee7eb4a7c0f9e162bce33576b315ececbb6406837bf51f6", + "0400000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000", + }, + "P-384": { + // Bad lengths. + "", + "04", + strings.Repeat("04", 200), + // Infinity. + "00", + // Compressed encodings. + "03aa87ca22be8b05378eb1c71ef320ad746e1d3b628ba79b9859f741e082542a385502f25dbf55296c3a545e3872760ab7", + "0208d999057ba3d2d969260045c55b97f089025959a6f434d651d207d19fb96e9e4fe0e86ebe0e64f85b96a9c75295df61", + // Points not on the curve. + "04aa87ca22be8b05378eb1c71ef320ad746e1d3b628ba79b9859f741e082542a385502f25dbf55296c3a545e3872760ab73617de4a96262c6f5d9e98bf9292dc29f8f41dbd289a147ce9da3113b5f0b8c00a60b1ce1d7e819d7a431d7c90ea0e60", + "04000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000", + }, + "P-521": { + // Bad lengths. + "", + "04", + strings.Repeat("04", 200), + // Infinity. + "00", + // Compressed encodings. + "030035b5df64ae2ac204c354b483487c9070cdc61c891c5ff39afc06c5d55541d3ceac8659e24afe3d0750e8b88e9f078af066a1d5025b08e5a5e2fbc87412871902f3", + "0200c6858e06b70404e9cd9e3ecb662395b4429c648139053fb521f828af606b4d3dbaa14b5e77efe75928fe1dc127a2ffa8de3348b3c1856a429bf97e7e31c2e5bd66", + // Points not on the curve. + "0400c6858e06b70404e9cd9e3ecb662395b4429c648139053fb521f828af606b4d3dbaa14b5e77efe75928fe1dc127a2ffa8de3348b3c1856a429bf97e7e31c2e5bd66011839296a789a3bc0045c8a5fb42c7d1bd998f54449579b446817afbd17273e662c97ee72995ef42640c550b9013fad0761353c7086a272c24088be94769fd16651", + "04000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000", + }, +} + +func TestECDHNewPrivateKeyECDH_Invalid(t *testing.T) { + for _, curve := range []string{"P-256", "P-384", "P-521"} { + t.Run(curve, func(t *testing.T) { + for _, input := range invalidECDHPrivateKeys[curve] { + k, err := openssl.NewPrivateKeyECDH(curve, hexDecode(t, input)) + if err == nil { + t.Errorf("unexpectedly accepted %q", input) + } else if k != nil { + t.Error("PrivateKey was not nil on error") + } + } + }) + } +} + +func TestECDHNewPublicKeyECDH_Invalid(t *testing.T) { + for _, curve := range []string{"P-256", "P-384", "P-521"} { + t.Run(curve, func(t *testing.T) { + for _, input := range invalidECDHPublicKeys[curve] { + k, err := openssl.NewPublicKeyECDH(curve, hexDecode(t, input)) + if err == nil { + t.Errorf("unexpectedly accepted %q", input) + } else if k != nil { + t.Error("PublicKey was not nil on error") + } + } + }) + } +} diff --git a/ecdsa.go b/ecdsa.go index f85782a..bc5f111 100644 --- a/ecdsa.go +++ b/ecdsa.go @@ -122,10 +122,7 @@ func HashVerifyECDSA(pub *PublicKeyECDSA, h crypto.Hash, msg, sig []byte) bool { } func newECDSAKey(curve string, x, y, d BigInt) (C.GO_EVP_PKEY_PTR, error) { - nid, err := curveNID(curve) - if err != nil { - return nil, err - } + nid := curveNID(curve) var bx, by, bd C.GO_BIGNUM_PTR defer func() { C.go_openssl_BN_free(bx) diff --git a/evp.go b/evp.go index 91296a9..85c84b8 100644 --- a/evp.go +++ b/evp.go @@ -175,10 +175,7 @@ func generateEVPPKey(id C.int, bits int, curve string) (C.GO_EVP_PKEY_PTR, error } } if curve != "" { - nid, err := curveNID(curve) - if err != nil { - return nil, err - } + nid := curveNID(curve) if C.go_openssl_EVP_PKEY_CTX_ctrl(ctx, id, -1, C.GO_EVP_PKEY_CTRL_EC_PARAMGEN_CURVE_NID, nid, nil) != 1 { return nil, newOpenSSLError("EVP_PKEY_CTX_ctrl failed") } @@ -513,7 +510,34 @@ func newEvpFromParams(id C.int, selection C.int, params C.GO_OSSL_PARAM_PTR) (C. } var pkey C.GO_EVP_PKEY_PTR if C.go_openssl_EVP_PKEY_fromdata(ctx, &pkey, selection, params) != 1 { + if vMajor == 3 && vMinor <= 2 { + // OpenSSL 3.0.1 and 3.0.2 have a bug where EVP_PKEY_fromdata + // does not free the internally allocated EVP_PKEY on error. + // See https://github.com/openssl/openssl/issues/17407. + C.go_openssl_EVP_PKEY_free(pkey) + } return nil, newOpenSSLError("EVP_PKEY_fromdata") } return pkey, nil } + +func checkPkey(pkey C.GO_EVP_PKEY_PTR, isPrivate bool) error { + ctx := C.go_openssl_EVP_PKEY_CTX_new(pkey, nil) + if ctx == nil { + return newOpenSSLError("EVP_PKEY_CTX_new") + } + defer C.go_openssl_EVP_PKEY_CTX_free(ctx) + if isPrivate { + if C.go_openssl_EVP_PKEY_private_check(ctx) != 1 { + // Match upstream error message. + return errors.New("invalid private key") + } + } else { + // Upstream Go does a partial check here, so do we. + if C.go_openssl_EVP_PKEY_public_check_quick(ctx) != 1 { + // Match upstream error message. + return errors.New("invalid public key") + } + } + return nil +} diff --git a/shims.h b/shims.h index 156d8e8..d16759d 100644 --- a/shims.h +++ b/shims.h @@ -310,6 +310,8 @@ DEFINEFUNC(int, EVP_PKEY_sign, (GO_EVP_PKEY_CTX_PTR arg0, unsigned char *arg1, s DEFINEFUNC(int, EVP_PKEY_derive_init, (GO_EVP_PKEY_CTX_PTR ctx), (ctx)) \ DEFINEFUNC(int, EVP_PKEY_derive_set_peer, (GO_EVP_PKEY_CTX_PTR ctx, GO_EVP_PKEY_PTR peer), (ctx, peer)) \ DEFINEFUNC(int, EVP_PKEY_derive, (GO_EVP_PKEY_CTX_PTR ctx, unsigned char *key, size_t *keylen), (ctx, key, keylen)) \ +DEFINEFUNC_3_0(int, EVP_PKEY_public_check_quick, (GO_EVP_PKEY_CTX_PTR ctx), (ctx)) \ +DEFINEFUNC_3_0(int, EVP_PKEY_private_check, (GO_EVP_PKEY_CTX_PTR ctx), (ctx)) \ DEFINEFUNC_LEGACY_1_0(void*, EVP_PKEY_get0, (GO_EVP_PKEY_PTR pkey), (pkey)) \ DEFINEFUNC_LEGACY_1_1(GO_EC_KEY_PTR, EVP_PKEY_get0_EC_KEY, (GO_EVP_PKEY_PTR pkey), (pkey)) \ DEFINEFUNC_LEGACY_1_1(GO_DSA_PTR, EVP_PKEY_get0_DSA, (GO_EVP_PKEY_PTR pkey), (pkey)) \ @@ -345,6 +347,7 @@ DEFINEFUNC_LEGACY_1(const GO_BIGNUM_PTR, EC_KEY_get0_private_key, (const GO_EC_K DEFINEFUNC_LEGACY_1(const GO_EC_POINT_PTR, EC_KEY_get0_public_key, (const GO_EC_KEY_PTR arg0), (arg0)) \ DEFINEFUNC_LEGACY_1(GO_EC_KEY_PTR, EC_KEY_new_by_curve_name, (int arg0), (arg0)) \ DEFINEFUNC_LEGACY_1(int, EC_KEY_set_private_key, (GO_EC_KEY_PTR arg0, const GO_BIGNUM_PTR arg1), (arg0, arg1)) \ +DEFINEFUNC_LEGACY_1(int, EC_KEY_check_key, (const GO_EC_KEY_PTR key), (key)) \ DEFINEFUNC(GO_EC_POINT_PTR, EC_POINT_new, (const GO_EC_GROUP_PTR arg0), (arg0)) \ DEFINEFUNC(void, EC_POINT_free, (GO_EC_POINT_PTR arg0), (arg0)) \ DEFINEFUNC(int, EC_POINT_mul, (const GO_EC_GROUP_PTR group, GO_EC_POINT_PTR r, const GO_BIGNUM_PTR n, const GO_EC_POINT_PTR q, const GO_BIGNUM_PTR m, GO_BN_CTX_PTR ctx), (group, r, n, q, m, ctx)) \