Skip to content

Commit

Permalink
Validate ECDH keys (#226)
Browse files Browse the repository at this point in the history
* validate ECDH keys

* remove unused functions

* fix TestECDHVectors for OpenSSL 1

* fix newECDHPkey3

* add ECDH key validations for OpenSSL 3

* move key validation into newEvpFromParams

* fix openssl 3

* fix memory leak

* fix memory leak

* fix memory leak

* remove unnecessary test check

* Update evp.go

Co-authored-by: Davis Goodin <dagood@users.noreply.github.com>

---------

Co-authored-by: Davis Goodin <dagood@users.noreply.github.com>
  • Loading branch information
qmuntal and dagood authored Dec 10, 2024
1 parent d5122d3 commit bdc3592
Show file tree
Hide file tree
Showing 6 changed files with 240 additions and 87 deletions.
29 changes: 23 additions & 6 deletions ec.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,35 @@ package openssl
// #include "goopenssl.h"
import "C"

func curveNID(curve string) (C.int, error) {
func curveNID(curve string) C.int {
switch curve {
case "P-224":
return C.GO_NID_secp224r1, nil
return C.GO_NID_secp224r1
case "P-256":
return C.GO_NID_X9_62_prime256v1, nil
return C.GO_NID_X9_62_prime256v1
case "P-384":
return C.GO_NID_secp384r1, nil
return C.GO_NID_secp384r1
case "P-521":
return C.GO_NID_secp521r1, nil
return C.GO_NID_secp521r1
default:
panic("openssl: unknown curve " + curve)
}
}

// curveSize returns the size of the curve in bytes.
func curveSize(curve string) int {
switch curve {
case "P-224":
return 224 / 8
case "P-256":
return 256 / 8
case "P-384":
return 384 / 8
case "P-521":
return (521 + 7) / 8
default:
panic("openssl: unknown curve " + curve)
}
return 0, errUnknownCurve
}

// encodeEcPoint encodes pt.
Expand Down
136 changes: 63 additions & 73 deletions ecdh.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import "C"
import (
"errors"
"runtime"
"slices"
"unsafe"
)

Expand All @@ -20,49 +21,44 @@ func (k *PublicKeyECDH) finalize() {
}

type PrivateKeyECDH struct {
_pkey C.GO_EVP_PKEY_PTR
curve string
hasPublicKey bool
_pkey C.GO_EVP_PKEY_PTR
curve string
}

func (k *PrivateKeyECDH) finalize() {
C.go_openssl_EVP_PKEY_free(k._pkey)
}

func NewPublicKeyECDH(curve string, bytes []byte) (*PublicKeyECDH, error) {
if len(bytes) < 1 {
return nil, errors.New("NewPublicKeyECDH: missing key")
if len(bytes) != 1+2*curveSize(curve) {
return nil, errors.New("NewPublicKeyECDH: wrong key length")
}
pkey, err := newECDHPkey(curve, bytes, false)
if err != nil {
return nil, err
}
k := &PublicKeyECDH{pkey, append([]byte(nil), bytes...)}
k := &PublicKeyECDH{pkey, slices.Clone(bytes)}
runtime.SetFinalizer(k, (*PublicKeyECDH).finalize)
return k, nil
}

func (k *PublicKeyECDH) Bytes() []byte { return k.bytes }

func NewPrivateKeyECDH(curve string, bytes []byte) (*PrivateKeyECDH, error) {
if len(bytes) != curveSize(curve) {
return nil, errors.New("NewPrivateKeyECDH: wrong key length")
}
pkey, err := newECDHPkey(curve, bytes, true)
if err != nil {
return nil, err
}
k := &PrivateKeyECDH{pkey, curve, false}
k := &PrivateKeyECDH{pkey, curve}
runtime.SetFinalizer(k, (*PrivateKeyECDH).finalize)
return k, nil
}

func (k *PrivateKeyECDH) PublicKey() (*PublicKeyECDH, error) {
defer runtime.KeepAlive(k)
if !k.hasPublicKey {
err := deriveEcdhPublicKey(k._pkey, k.curve)
if err != nil {
return nil, err
}
k.hasPublicKey = true
}
var pkey C.GO_EVP_PKEY_PTR
defer func() {
C.go_openssl_EVP_PKEY_free(pkey)
Expand Down Expand Up @@ -112,10 +108,7 @@ func (k *PrivateKeyECDH) PublicKey() (*PublicKeyECDH, error) {
}

func newECDHPkey(curve string, bytes []byte, isPrivate bool) (C.GO_EVP_PKEY_PTR, error) {
nid, err := curveNID(curve)
if err != nil {
return nil, err
}
nid := curveNID(curve)
switch vMajor {
case 1:
return newECDHPkey1(nid, bytes, isPrivate)
Expand All @@ -138,6 +131,7 @@ func newECDHPkey1(nid C.int, bytes []byte, isPrivate bool) (pkey C.GO_EVP_PKEY_P
C.go_openssl_EC_KEY_free(key)
}
}()
group := C.go_openssl_EC_KEY_get0_group(key)
if isPrivate {
priv := C.go_openssl_BN_bin2bn(base(bytes), C.int(len(bytes)), nil)
if priv == nil {
Expand All @@ -147,8 +141,15 @@ func newECDHPkey1(nid C.int, bytes []byte, isPrivate bool) (pkey C.GO_EVP_PKEY_P
if C.go_openssl_EC_KEY_set_private_key(key, priv) != 1 {
return nil, newOpenSSLError("EC_KEY_set_private_key")
}
pub, err := pointMult(group, priv)
if err != nil {
return nil, err
}
defer C.go_openssl_EC_POINT_free(pub)
if C.go_openssl_EC_KEY_set_public_key(key, pub) != 1 {
return nil, newOpenSSLError("EC_KEY_set_public_key")
}
} else {
group := C.go_openssl_EC_KEY_get0_group(key)
pub := C.go_openssl_EC_POINT_new(group)
if pub == nil {
return nil, newOpenSSLError("EC_POINT_new")
Expand All @@ -161,6 +162,14 @@ func newECDHPkey1(nid C.int, bytes []byte, isPrivate bool) (pkey C.GO_EVP_PKEY_P
return nil, newOpenSSLError("EC_KEY_set_public_key")
}
}
if C.go_openssl_EC_KEY_check_key(key) != 1 {
// Match upstream error message.
if isPrivate {
return nil, errors.New("crypto/ecdh: invalid private key")
} else {
return nil, errors.New("crypto/ecdh: invalid public key")
}
}
return newEVPPKEY(key)
}

Expand All @@ -175,7 +184,19 @@ func newECDHPkey3(nid C.int, bytes []byte, isPrivate bool) (C.GO_EVP_PKEY_PTR, e
bld.addUTF8String(_OSSL_PKEY_PARAM_GROUP_NAME, C.go_openssl_OBJ_nid2sn(nid), 0)
var selection C.int
if isPrivate {
bld.addBin(_OSSL_PKEY_PARAM_PRIV_KEY, bytes, true)
priv := C.go_openssl_BN_bin2bn(base(bytes), C.int(len(bytes)), nil)
if priv == nil {
return nil, newOpenSSLError("BN_bin2bn")
}
defer C.go_openssl_BN_clear_free(priv)
pubBytes, err := generateAndEncodeEcPublicKey(nid, func(group C.GO_EC_GROUP_PTR) (C.GO_EC_POINT_PTR, error) {
return pointMult(group, priv)
})
if err != nil {
return nil, err
}
bld.addOctetString(_OSSL_PKEY_PARAM_PUB_KEY, pubBytes)
bld.addBN(_OSSL_PKEY_PARAM_PRIV_KEY, priv)
selection = C.GO_EVP_PKEY_KEYPAIR
} else {
bld.addOctetString(_OSSL_PKEY_PARAM_PUB_KEY, bytes)
Expand All @@ -187,62 +208,31 @@ func newECDHPkey3(nid C.int, bytes []byte, isPrivate bool) (C.GO_EVP_PKEY_PTR, e
return nil, err
}
defer C.go_openssl_OSSL_PARAM_free(params)
return newEvpFromParams(C.GO_EVP_PKEY_EC, selection, params)
pkey, err := newEvpFromParams(C.GO_EVP_PKEY_EC, selection, params)
if err != nil {
return nil, err
}

if err := checkPkey(pkey, isPrivate); err != nil {
C.go_openssl_EVP_PKEY_free(pkey)
return nil, errors.New("crypto/ecdh: " + err.Error())
}
return pkey, nil
}

// deriveEcdhPublicKey sets the raw public key of pkey by deriving it from
// the raw private key.
func deriveEcdhPublicKey(pkey C.GO_EVP_PKEY_PTR, curve string) error {
derive := func(group C.GO_EC_GROUP_PTR, priv C.GO_BIGNUM_PTR) (C.GO_EC_POINT_PTR, error) {
// OpenSSL does not expose any method to generate the public
// key from the private key [1], so we have to calculate it here.
// [1] https://github.com/openssl/openssl/issues/18437#issuecomment-1144717206
pt := C.go_openssl_EC_POINT_new(group)
if pt == nil {
return nil, newOpenSSLError("EC_POINT_new")
}
if C.go_openssl_EC_POINT_mul(group, pt, priv, nil, nil, nil) == 0 {
C.go_openssl_EC_POINT_free(pt)
return nil, newOpenSSLError("EC_POINT_mul")
}
return pt, nil
func pointMult(group C.GO_EC_GROUP_PTR, priv C.GO_BIGNUM_PTR) (C.GO_EC_POINT_PTR, error) {
// OpenSSL does not expose any method to generate the public
// key from the private key [1], so we have to calculate it here.
// [1] https://github.com/openssl/openssl/issues/18437#issuecomment-1144717206
pt := C.go_openssl_EC_POINT_new(group)
if pt == nil {
return nil, newOpenSSLError("EC_POINT_new")
}
switch vMajor {
case 1:
key := getECKey(pkey)
priv := C.go_openssl_EC_KEY_get0_private_key(key)
if priv == nil {
return newOpenSSLError("EC_KEY_get0_private_key")
}
group := C.go_openssl_EC_KEY_get0_group(key)
pub, err := derive(group, priv)
if err != nil {
return err
}
defer C.go_openssl_EC_POINT_free(pub)
if C.go_openssl_EC_KEY_set_public_key(key, pub) != 1 {
return newOpenSSLError("EC_KEY_set_public_key")
}
case 3:
var priv C.GO_BIGNUM_PTR
if C.go_openssl_EVP_PKEY_get_bn_param(pkey, _OSSL_PKEY_PARAM_PRIV_KEY, &priv) != 1 {
return newOpenSSLError("EVP_PKEY_get_bn_param")
}
defer C.go_openssl_BN_clear_free(priv)
nid, _ := curveNID(curve)
pubBytes, err := generateAndEncodeEcPublicKey(nid, func(group C.GO_EC_GROUP_PTR) (C.GO_EC_POINT_PTR, error) {
return derive(group, priv)
})
if err != nil {
return err
}
if C.go_openssl_EVP_PKEY_set1_encoded_public_key(pkey, base(pubBytes), C.size_t(len(pubBytes))) != 1 {
return newOpenSSLError("EVP_PKEY_set1_encoded_public_key")
}
default:
panic(errUnsupportedVersion())
if C.go_openssl_EC_POINT_mul(group, pt, priv, nil, nil, nil) == 0 {
C.go_openssl_EC_POINT_free(pt)
return nil, newOpenSSLError("EC_POINT_mul")
}
return nil
return pt, nil
}

func ECDH(priv *PrivateKeyECDH, pub *PublicKeyECDH) ([]byte, error) {
Expand Down Expand Up @@ -307,7 +297,7 @@ func GenerateKeyECDH(curve string) (*PrivateKeyECDH, []byte, error) {
if err := bnToBinPad(priv, bytes); err != nil {
return nil, nil, err
}
k = &PrivateKeyECDH{pkey, curve, true}
k = &PrivateKeyECDH{pkey, curve}
runtime.SetFinalizer(k, (*PrivateKeyECDH).finalize)
return k, bytes, nil
}
Loading

0 comments on commit bdc3592

Please sign in to comment.