diff --git a/ecdh.go b/ecdh.go index 666e6c7..5d79db6 100644 --- a/ecdh.go +++ b/ecdh.go @@ -20,9 +20,8 @@ func (k *PublicKeyECDH) finalize() { } type PrivateKeyECDH struct { - _pkey C.GO_EVP_PKEY_PTR - curve string - hasPublicKey bool + _pkey C.GO_EVP_PKEY_PTR + curve string } func (k *PrivateKeyECDH) finalize() { @@ -52,20 +51,13 @@ func NewPrivateKeyECDH(curve string, bytes []byte) (*PrivateKeyECDH, error) { if err != nil { return nil, err } - k := &PrivateKeyECDH{pkey, curve, false} + k := &PrivateKeyECDH{pkey, curve} runtime.SetFinalizer(k, (*PrivateKeyECDH).finalize) return k, nil } func (k *PrivateKeyECDH) PublicKey() (*PublicKeyECDH, error) { defer runtime.KeepAlive(k) - if !k.hasPublicKey { - err := deriveEcdhPublicKey(k._pkey, k.curve) - if err != nil { - return nil, err - } - k.hasPublicKey = true - } var pkey C.GO_EVP_PKEY_PTR defer func() { C.go_openssl_EVP_PKEY_free(pkey) @@ -141,6 +133,7 @@ func newECDHPkey1(nid C.int, bytes []byte, isPrivate bool) (pkey C.GO_EVP_PKEY_P C.go_openssl_EC_KEY_free(key) } }() + group := C.go_openssl_EC_KEY_get0_group(key) if isPrivate { priv := C.go_openssl_BN_bin2bn(base(bytes), C.int(len(bytes)), nil) if priv == nil { @@ -150,8 +143,15 @@ func newECDHPkey1(nid C.int, bytes []byte, isPrivate bool) (pkey C.GO_EVP_PKEY_P if C.go_openssl_EC_KEY_set_private_key(key, priv) != 1 { return nil, newOpenSSLError("EC_KEY_set_private_key") } + pub, err := pointMult(group, priv) + if err != nil { + return nil, err + } + defer C.go_openssl_EC_POINT_free(pub) + if C.go_openssl_EC_KEY_set_public_key(key, pub) != 1 { + return nil, newOpenSSLError("EC_KEY_set_public_key") + } } else { - group := C.go_openssl_EC_KEY_get0_group(key) pub := C.go_openssl_EC_POINT_new(group) if pub == nil { return nil, newOpenSSLError("EC_POINT_new") @@ -196,39 +196,25 @@ func newECDHPkey3(nid C.int, bytes []byte, isPrivate bool) (C.GO_EVP_PKEY_PTR, e return newEvpFromParams(C.GO_EVP_PKEY_EC, selection, params) } +func pointMult(group C.GO_EC_GROUP_PTR, priv C.GO_BIGNUM_PTR) (C.GO_EC_POINT_PTR, error) { + // OpenSSL does not expose any method to generate the public + // key from the private key [1], so we have to calculate it here. + // [1] https://github.com/openssl/openssl/issues/18437#issuecomment-1144717206 + pt := C.go_openssl_EC_POINT_new(group) + if pt == nil { + return nil, newOpenSSLError("EC_POINT_new") + } + if C.go_openssl_EC_POINT_mul(group, pt, priv, nil, nil, nil) == 0 { + C.go_openssl_EC_POINT_free(pt) + return nil, newOpenSSLError("EC_POINT_mul") + } + return pt, nil +} + // deriveEcdhPublicKey sets the raw public key of pkey by deriving it from // the raw private key. func deriveEcdhPublicKey(pkey C.GO_EVP_PKEY_PTR, curve string) error { - derive := func(group C.GO_EC_GROUP_PTR, priv C.GO_BIGNUM_PTR) (C.GO_EC_POINT_PTR, error) { - // OpenSSL does not expose any method to generate the public - // key from the private key [1], so we have to calculate it here. - // [1] https://github.com/openssl/openssl/issues/18437#issuecomment-1144717206 - pt := C.go_openssl_EC_POINT_new(group) - if pt == nil { - return nil, newOpenSSLError("EC_POINT_new") - } - if C.go_openssl_EC_POINT_mul(group, pt, priv, nil, nil, nil) == 0 { - C.go_openssl_EC_POINT_free(pt) - return nil, newOpenSSLError("EC_POINT_mul") - } - return pt, nil - } switch vMajor { - case 1: - key := getECKey(pkey) - priv := C.go_openssl_EC_KEY_get0_private_key(key) - if priv == nil { - return newOpenSSLError("EC_KEY_get0_private_key") - } - group := C.go_openssl_EC_KEY_get0_group(key) - pub, err := derive(group, priv) - if err != nil { - return err - } - defer C.go_openssl_EC_POINT_free(pub) - if C.go_openssl_EC_KEY_set_public_key(key, pub) != 1 { - return newOpenSSLError("EC_KEY_set_public_key") - } case 3: var priv C.GO_BIGNUM_PTR if C.go_openssl_EVP_PKEY_get_bn_param(pkey, _OSSL_PKEY_PARAM_PRIV_KEY, &priv) != 1 { @@ -237,7 +223,7 @@ func deriveEcdhPublicKey(pkey C.GO_EVP_PKEY_PTR, curve string) error { defer C.go_openssl_BN_clear_free(priv) nid, _ := curveNID(curve) pubBytes, err := generateAndEncodeEcPublicKey(nid, func(group C.GO_EC_GROUP_PTR) (C.GO_EC_POINT_PTR, error) { - return derive(group, priv) + return pointMult(group, priv) }) if err != nil { return err @@ -313,7 +299,7 @@ func GenerateKeyECDH(curve string) (*PrivateKeyECDH, []byte, error) { if err := bnToBinPad(priv, bytes); err != nil { return nil, nil, err } - k = &PrivateKeyECDH{pkey, curve, true} + k = &PrivateKeyECDH{pkey, curve} runtime.SetFinalizer(k, (*PrivateKeyECDH).finalize) return k, bytes, nil }