diff --git a/dsa.go b/dsa.go index c56071f5..72cbb7a4 100644 --- a/dsa.go +++ b/dsa.go @@ -280,7 +280,7 @@ func newDSA3(params DSAParameters, x, y BigInt) (C.GO_EVP_PKEY_PTR, error) { return nil, err } defer C.go_openssl_OSSL_PARAM_free(bldparams) - pkey, err := newEvpFromParams(C.GO_EVP_PKEY_DSA, selection, bldparams) + pkey, err := newEvpFromParams(C.GO_EVP_PKEY_DSA, selection, bldparams, false) if err != nil { return nil, err } diff --git a/ecdh.go b/ecdh.go index 1cf02318..495529c7 100644 --- a/ecdh.go +++ b/ecdh.go @@ -165,7 +165,12 @@ func newECDHPkey1(nid C.int, bytes []byte, isPrivate bool) (pkey C.GO_EVP_PKEY_P } } if C.go_openssl_EC_KEY_check_key(key) != 1 { - return nil, newOpenSSLError("EC_KEY_check_key") + // Match upstream error message. + if isPrivate { + return nil, errors.New("crypto/ecdh: invalid private key") + } else { + return nil, errors.New("crypto/ecdh: invalid public key") + } } return newEVPPKEY(key) } @@ -205,20 +210,7 @@ func newECDHPkey3(nid C.int, bytes []byte, isPrivate bool) (C.GO_EVP_PKEY_PTR, e return nil, err } defer C.go_openssl_OSSL_PARAM_free(params) - pkey, err := newEvpFromParams(C.GO_EVP_PKEY_EC, selection, params) - if err != nil { - return nil, err - } - if isPrivate { - if C.go_openssl_EVP_PKEY_private_check(pkey) != 1 { - return nil, errors.New("crypto/ecdh: invalid private key") - } - } else { - if C.go_openssl_EVP_PKEY_public_check_quick(pkey) != 1 { - return nil, errors.New("crypto/ecdh: invalid public key") - } - } - return pkey, nil + return newEvpFromParams(C.GO_EVP_PKEY_EC, selection, params, true) } func pointMult(group C.GO_EC_GROUP_PTR, priv C.GO_BIGNUM_PTR) (C.GO_EC_POINT_PTR, error) { diff --git a/ecdsa.go b/ecdsa.go index f85782a6..16b360a1 100644 --- a/ecdsa.go +++ b/ecdsa.go @@ -207,5 +207,5 @@ func newECDSAKey3(nid C.int, bx, by, bd C.GO_BIGNUM_PTR) (C.GO_EVP_PKEY_PTR, err return nil, err } defer C.go_openssl_OSSL_PARAM_free(params) - return newEvpFromParams(C.GO_EVP_PKEY_EC, selection, params) + return newEvpFromParams(C.GO_EVP_PKEY_EC, selection, params, false) } diff --git a/evp.go b/evp.go index 91296a93..90900d2d 100644 --- a/evp.go +++ b/evp.go @@ -502,7 +502,7 @@ func getECKey(pkey C.GO_EVP_PKEY_PTR) (key C.GO_EC_KEY_PTR) { return key } -func newEvpFromParams(id C.int, selection C.int, params C.GO_OSSL_PARAM_PTR) (C.GO_EVP_PKEY_PTR, error) { +func newEvpFromParams(id C.int, selection C.int, params C.GO_OSSL_PARAM_PTR, validate bool) (C.GO_EVP_PKEY_PTR, error) { ctx := C.go_openssl_EVP_PKEY_CTX_new_id(id, nil) if ctx == nil { return nil, newOpenSSLError("EVP_PKEY_CTX_new_id") @@ -515,5 +515,18 @@ func newEvpFromParams(id C.int, selection C.int, params C.GO_OSSL_PARAM_PTR) (C. if C.go_openssl_EVP_PKEY_fromdata(ctx, &pkey, selection, params) != 1 { return nil, newOpenSSLError("EVP_PKEY_fromdata") } + if validate { + if selection == C.GO_EVP_PKEY_KEYPAIR { // Private key + if C.go_openssl_EVP_PKEY_private_check(ctx) != 1 { + // Match upstream error message. + return nil, errors.New("crypto/ecdh: invalid private key") + } + } else { // Public key + if C.go_openssl_EVP_PKEY_public_check_quick(ctx) != 1 { + // Match upstream error message. + return nil, errors.New("crypto/ecdh: invalid public key") + } + } + } return pkey, nil } diff --git a/rsa.go b/rsa.go index cd5b3b8e..2694e2a2 100644 --- a/rsa.go +++ b/rsa.go @@ -404,5 +404,5 @@ func newRSAKey3(isPriv bool, n, e, d, p, q, dp, dq, qinv BigInt) (C.GO_EVP_PKEY_ if isPriv { selection = C.GO_EVP_PKEY_KEYPAIR } - return newEvpFromParams(C.GO_EVP_PKEY_RSA, C.int(selection), params) + return newEvpFromParams(C.GO_EVP_PKEY_RSA, C.int(selection), params, false) } diff --git a/shims.h b/shims.h index e4381754..d16759d6 100644 --- a/shims.h +++ b/shims.h @@ -310,8 +310,8 @@ DEFINEFUNC(int, EVP_PKEY_sign, (GO_EVP_PKEY_CTX_PTR arg0, unsigned char *arg1, s DEFINEFUNC(int, EVP_PKEY_derive_init, (GO_EVP_PKEY_CTX_PTR ctx), (ctx)) \ DEFINEFUNC(int, EVP_PKEY_derive_set_peer, (GO_EVP_PKEY_CTX_PTR ctx, GO_EVP_PKEY_PTR peer), (ctx, peer)) \ DEFINEFUNC(int, EVP_PKEY_derive, (GO_EVP_PKEY_CTX_PTR ctx, unsigned char *key, size_t *keylen), (ctx, key, keylen)) \ -DEFINEFUNC(int, EVP_PKEY_public_check_quick, (GO_EVP_PKEY_CTX_PTR ctx), (ctx)) \ -DEFINEFUNC(int, EVP_PKEY_private_check, (GO_EVP_PKEY_CTX_PTR ctx), (ctx)) \ +DEFINEFUNC_3_0(int, EVP_PKEY_public_check_quick, (GO_EVP_PKEY_CTX_PTR ctx), (ctx)) \ +DEFINEFUNC_3_0(int, EVP_PKEY_private_check, (GO_EVP_PKEY_CTX_PTR ctx), (ctx)) \ DEFINEFUNC_LEGACY_1_0(void*, EVP_PKEY_get0, (GO_EVP_PKEY_PTR pkey), (pkey)) \ DEFINEFUNC_LEGACY_1_1(GO_EC_KEY_PTR, EVP_PKEY_get0_EC_KEY, (GO_EVP_PKEY_PTR pkey), (pkey)) \ DEFINEFUNC_LEGACY_1_1(GO_DSA_PTR, EVP_PKEY_get0_DSA, (GO_EVP_PKEY_PTR pkey), (pkey)) \