Skip to content

Commit

Permalink
fix TestECDHVectors for OpenSSL 1
Browse files Browse the repository at this point in the history
  • Loading branch information
qmuntal committed Nov 25, 2024
1 parent 97ac15f commit 888eba4
Showing 1 changed file with 29 additions and 43 deletions.
72 changes: 29 additions & 43 deletions ecdh.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,8 @@ func (k *PublicKeyECDH) finalize() {
}

type PrivateKeyECDH struct {
_pkey C.GO_EVP_PKEY_PTR
curve string
hasPublicKey bool
_pkey C.GO_EVP_PKEY_PTR
curve string
}

func (k *PrivateKeyECDH) finalize() {
Expand Down Expand Up @@ -52,20 +51,13 @@ func NewPrivateKeyECDH(curve string, bytes []byte) (*PrivateKeyECDH, error) {
if err != nil {
return nil, err
}
k := &PrivateKeyECDH{pkey, curve, false}
k := &PrivateKeyECDH{pkey, curve}
runtime.SetFinalizer(k, (*PrivateKeyECDH).finalize)
return k, nil
}

func (k *PrivateKeyECDH) PublicKey() (*PublicKeyECDH, error) {
defer runtime.KeepAlive(k)
if !k.hasPublicKey {
err := deriveEcdhPublicKey(k._pkey, k.curve)
if err != nil {
return nil, err
}
k.hasPublicKey = true
}
var pkey C.GO_EVP_PKEY_PTR
defer func() {
C.go_openssl_EVP_PKEY_free(pkey)
Expand Down Expand Up @@ -141,6 +133,7 @@ func newECDHPkey1(nid C.int, bytes []byte, isPrivate bool) (pkey C.GO_EVP_PKEY_P
C.go_openssl_EC_KEY_free(key)
}
}()
group := C.go_openssl_EC_KEY_get0_group(key)
if isPrivate {
priv := C.go_openssl_BN_bin2bn(base(bytes), C.int(len(bytes)), nil)
if priv == nil {
Expand All @@ -150,8 +143,15 @@ func newECDHPkey1(nid C.int, bytes []byte, isPrivate bool) (pkey C.GO_EVP_PKEY_P
if C.go_openssl_EC_KEY_set_private_key(key, priv) != 1 {
return nil, newOpenSSLError("EC_KEY_set_private_key")
}
pub, err := pointMult(group, priv)
if err != nil {
return nil, err
}
defer C.go_openssl_EC_POINT_free(pub)
if C.go_openssl_EC_KEY_set_public_key(key, pub) != 1 {
return nil, newOpenSSLError("EC_KEY_set_public_key")
}
} else {
group := C.go_openssl_EC_KEY_get0_group(key)
pub := C.go_openssl_EC_POINT_new(group)
if pub == nil {
return nil, newOpenSSLError("EC_POINT_new")
Expand Down Expand Up @@ -196,39 +196,25 @@ func newECDHPkey3(nid C.int, bytes []byte, isPrivate bool) (C.GO_EVP_PKEY_PTR, e
return newEvpFromParams(C.GO_EVP_PKEY_EC, selection, params)
}

func pointMult(group C.GO_EC_GROUP_PTR, priv C.GO_BIGNUM_PTR) (C.GO_EC_POINT_PTR, error) {
// OpenSSL does not expose any method to generate the public
// key from the private key [1], so we have to calculate it here.
// [1] https://github.com/openssl/openssl/issues/18437#issuecomment-1144717206
pt := C.go_openssl_EC_POINT_new(group)
if pt == nil {
return nil, newOpenSSLError("EC_POINT_new")
}
if C.go_openssl_EC_POINT_mul(group, pt, priv, nil, nil, nil) == 0 {
C.go_openssl_EC_POINT_free(pt)
return nil, newOpenSSLError("EC_POINT_mul")
}
return pt, nil
}

// deriveEcdhPublicKey sets the raw public key of pkey by deriving it from
// the raw private key.
func deriveEcdhPublicKey(pkey C.GO_EVP_PKEY_PTR, curve string) error {
derive := func(group C.GO_EC_GROUP_PTR, priv C.GO_BIGNUM_PTR) (C.GO_EC_POINT_PTR, error) {
// OpenSSL does not expose any method to generate the public
// key from the private key [1], so we have to calculate it here.
// [1] https://github.com/openssl/openssl/issues/18437#issuecomment-1144717206
pt := C.go_openssl_EC_POINT_new(group)
if pt == nil {
return nil, newOpenSSLError("EC_POINT_new")
}
if C.go_openssl_EC_POINT_mul(group, pt, priv, nil, nil, nil) == 0 {
C.go_openssl_EC_POINT_free(pt)
return nil, newOpenSSLError("EC_POINT_mul")
}
return pt, nil
}
switch vMajor {
case 1:
key := getECKey(pkey)
priv := C.go_openssl_EC_KEY_get0_private_key(key)
if priv == nil {
return newOpenSSLError("EC_KEY_get0_private_key")
}
group := C.go_openssl_EC_KEY_get0_group(key)
pub, err := derive(group, priv)
if err != nil {
return err
}
defer C.go_openssl_EC_POINT_free(pub)
if C.go_openssl_EC_KEY_set_public_key(key, pub) != 1 {
return newOpenSSLError("EC_KEY_set_public_key")
}
case 3:
var priv C.GO_BIGNUM_PTR
if C.go_openssl_EVP_PKEY_get_bn_param(pkey, _OSSL_PKEY_PARAM_PRIV_KEY, &priv) != 1 {
Expand All @@ -237,7 +223,7 @@ func deriveEcdhPublicKey(pkey C.GO_EVP_PKEY_PTR, curve string) error {
defer C.go_openssl_BN_clear_free(priv)
nid, _ := curveNID(curve)
pubBytes, err := generateAndEncodeEcPublicKey(nid, func(group C.GO_EC_GROUP_PTR) (C.GO_EC_POINT_PTR, error) {
return derive(group, priv)
return pointMult(group, priv)
})
if err != nil {
return err
Expand Down Expand Up @@ -313,7 +299,7 @@ func GenerateKeyECDH(curve string) (*PrivateKeyECDH, []byte, error) {
if err := bnToBinPad(priv, bytes); err != nil {
return nil, nil, err
}
k = &PrivateKeyECDH{pkey, curve, true}
k = &PrivateKeyECDH{pkey, curve}
runtime.SetFinalizer(k, (*PrivateKeyECDH).finalize)
return k, bytes, nil
}

0 comments on commit 888eba4

Please sign in to comment.