Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Validate ECDH keys #226

Merged
merged 12 commits into from
Dec 10, 2024
29 changes: 23 additions & 6 deletions ec.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,35 @@ package openssl
// #include "goopenssl.h"
import "C"

func curveNID(curve string) (C.int, error) {
func curveNID(curve string) C.int {
switch curve {
case "P-224":
return C.GO_NID_secp224r1, nil
return C.GO_NID_secp224r1
case "P-256":
return C.GO_NID_X9_62_prime256v1, nil
return C.GO_NID_X9_62_prime256v1
case "P-384":
return C.GO_NID_secp384r1, nil
return C.GO_NID_secp384r1
case "P-521":
return C.GO_NID_secp521r1, nil
return C.GO_NID_secp521r1
default:
panic("openssl: unknown curve " + curve)
}
}

// curveSize returns the size of the curve in bytes.
func curveSize(curve string) int {
switch curve {
case "P-224":
return 224 / 8
case "P-256":
return 256 / 8
case "P-384":
return 384 / 8
case "P-521":
return (521 + 7) / 8
default:
panic("openssl: unknown curve " + curve)
}
return 0, errUnknownCurve
}

// encodeEcPoint encodes pt.
Expand Down
136 changes: 63 additions & 73 deletions ecdh.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import "C"
import (
"errors"
"runtime"
"slices"
"unsafe"
)

Expand All @@ -20,49 +21,44 @@ func (k *PublicKeyECDH) finalize() {
}

type PrivateKeyECDH struct {
_pkey C.GO_EVP_PKEY_PTR
curve string
hasPublicKey bool
_pkey C.GO_EVP_PKEY_PTR
curve string
}

func (k *PrivateKeyECDH) finalize() {
C.go_openssl_EVP_PKEY_free(k._pkey)
}

func NewPublicKeyECDH(curve string, bytes []byte) (*PublicKeyECDH, error) {
if len(bytes) < 1 {
return nil, errors.New("NewPublicKeyECDH: missing key")
if len(bytes) != 1+2*curveSize(curve) {
return nil, errors.New("NewPublicKeyECDH: wrong key length")
}
pkey, err := newECDHPkey(curve, bytes, false)
if err != nil {
return nil, err
}
k := &PublicKeyECDH{pkey, append([]byte(nil), bytes...)}
k := &PublicKeyECDH{pkey, slices.Clone(bytes)}
qmuntal marked this conversation as resolved.
Show resolved Hide resolved
runtime.SetFinalizer(k, (*PublicKeyECDH).finalize)
return k, nil
}

func (k *PublicKeyECDH) Bytes() []byte { return k.bytes }

func NewPrivateKeyECDH(curve string, bytes []byte) (*PrivateKeyECDH, error) {
if len(bytes) != curveSize(curve) {
return nil, errors.New("NewPrivateKeyECDH: wrong key length")
}
pkey, err := newECDHPkey(curve, bytes, true)
if err != nil {
return nil, err
}
k := &PrivateKeyECDH{pkey, curve, false}
k := &PrivateKeyECDH{pkey, curve}
runtime.SetFinalizer(k, (*PrivateKeyECDH).finalize)
return k, nil
}

func (k *PrivateKeyECDH) PublicKey() (*PublicKeyECDH, error) {
defer runtime.KeepAlive(k)
if !k.hasPublicKey {
err := deriveEcdhPublicKey(k._pkey, k.curve)
if err != nil {
return nil, err
}
k.hasPublicKey = true
}
var pkey C.GO_EVP_PKEY_PTR
defer func() {
C.go_openssl_EVP_PKEY_free(pkey)
Expand Down Expand Up @@ -112,10 +108,7 @@ func (k *PrivateKeyECDH) PublicKey() (*PublicKeyECDH, error) {
}

func newECDHPkey(curve string, bytes []byte, isPrivate bool) (C.GO_EVP_PKEY_PTR, error) {
nid, err := curveNID(curve)
if err != nil {
return nil, err
}
nid := curveNID(curve)
switch vMajor {
case 1:
return newECDHPkey1(nid, bytes, isPrivate)
Expand All @@ -138,6 +131,7 @@ func newECDHPkey1(nid C.int, bytes []byte, isPrivate bool) (pkey C.GO_EVP_PKEY_P
C.go_openssl_EC_KEY_free(key)
}
}()
group := C.go_openssl_EC_KEY_get0_group(key)
if isPrivate {
priv := C.go_openssl_BN_bin2bn(base(bytes), C.int(len(bytes)), nil)
if priv == nil {
Expand All @@ -147,8 +141,15 @@ func newECDHPkey1(nid C.int, bytes []byte, isPrivate bool) (pkey C.GO_EVP_PKEY_P
if C.go_openssl_EC_KEY_set_private_key(key, priv) != 1 {
return nil, newOpenSSLError("EC_KEY_set_private_key")
}
pub, err := pointMult(group, priv)
if err != nil {
return nil, err
}
defer C.go_openssl_EC_POINT_free(pub)
if C.go_openssl_EC_KEY_set_public_key(key, pub) != 1 {
return nil, newOpenSSLError("EC_KEY_set_public_key")
}
} else {
group := C.go_openssl_EC_KEY_get0_group(key)
pub := C.go_openssl_EC_POINT_new(group)
if pub == nil {
return nil, newOpenSSLError("EC_POINT_new")
Expand All @@ -161,6 +162,14 @@ func newECDHPkey1(nid C.int, bytes []byte, isPrivate bool) (pkey C.GO_EVP_PKEY_P
return nil, newOpenSSLError("EC_KEY_set_public_key")
}
}
if C.go_openssl_EC_KEY_check_key(key) != 1 {
// Match upstream error message.
if isPrivate {
return nil, errors.New("crypto/ecdh: invalid private key")
} else {
return nil, errors.New("crypto/ecdh: invalid public key")
}
}
return newEVPPKEY(key)
}

Expand All @@ -175,7 +184,19 @@ func newECDHPkey3(nid C.int, bytes []byte, isPrivate bool) (C.GO_EVP_PKEY_PTR, e
bld.addUTF8String(_OSSL_PKEY_PARAM_GROUP_NAME, C.go_openssl_OBJ_nid2sn(nid), 0)
var selection C.int
if isPrivate {
bld.addBin(_OSSL_PKEY_PARAM_PRIV_KEY, bytes, true)
priv := C.go_openssl_BN_bin2bn(base(bytes), C.int(len(bytes)), nil)
if priv == nil {
return nil, newOpenSSLError("BN_bin2bn")
}
defer C.go_openssl_BN_clear_free(priv)
pubBytes, err := generateAndEncodeEcPublicKey(nid, func(group C.GO_EC_GROUP_PTR) (C.GO_EC_POINT_PTR, error) {
return pointMult(group, priv)
})
if err != nil {
return nil, err
}
bld.addOctetString(_OSSL_PKEY_PARAM_PUB_KEY, pubBytes)
bld.addBN(_OSSL_PKEY_PARAM_PRIV_KEY, priv)
selection = C.GO_EVP_PKEY_KEYPAIR
} else {
bld.addOctetString(_OSSL_PKEY_PARAM_PUB_KEY, bytes)
Expand All @@ -187,62 +208,31 @@ func newECDHPkey3(nid C.int, bytes []byte, isPrivate bool) (C.GO_EVP_PKEY_PTR, e
return nil, err
}
defer C.go_openssl_OSSL_PARAM_free(params)
return newEvpFromParams(C.GO_EVP_PKEY_EC, selection, params)
pkey, err := newEvpFromParams(C.GO_EVP_PKEY_EC, selection, params)
if err != nil {
return nil, err
}

if err := checkPkey(pkey, isPrivate); err != nil {
C.go_openssl_EVP_PKEY_free(pkey)
return nil, errors.New("crypto/ecdh: " + err.Error())
}
return pkey, nil
}

// deriveEcdhPublicKey sets the raw public key of pkey by deriving it from
// the raw private key.
func deriveEcdhPublicKey(pkey C.GO_EVP_PKEY_PTR, curve string) error {
derive := func(group C.GO_EC_GROUP_PTR, priv C.GO_BIGNUM_PTR) (C.GO_EC_POINT_PTR, error) {
// OpenSSL does not expose any method to generate the public
// key from the private key [1], so we have to calculate it here.
// [1] https://github.com/openssl/openssl/issues/18437#issuecomment-1144717206
pt := C.go_openssl_EC_POINT_new(group)
if pt == nil {
return nil, newOpenSSLError("EC_POINT_new")
}
if C.go_openssl_EC_POINT_mul(group, pt, priv, nil, nil, nil) == 0 {
C.go_openssl_EC_POINT_free(pt)
return nil, newOpenSSLError("EC_POINT_mul")
}
return pt, nil
func pointMult(group C.GO_EC_GROUP_PTR, priv C.GO_BIGNUM_PTR) (C.GO_EC_POINT_PTR, error) {
// OpenSSL does not expose any method to generate the public
// key from the private key [1], so we have to calculate it here.
// [1] https://github.com/openssl/openssl/issues/18437#issuecomment-1144717206
pt := C.go_openssl_EC_POINT_new(group)
if pt == nil {
return nil, newOpenSSLError("EC_POINT_new")
}
switch vMajor {
case 1:
key := getECKey(pkey)
priv := C.go_openssl_EC_KEY_get0_private_key(key)
if priv == nil {
return newOpenSSLError("EC_KEY_get0_private_key")
}
group := C.go_openssl_EC_KEY_get0_group(key)
pub, err := derive(group, priv)
if err != nil {
return err
}
defer C.go_openssl_EC_POINT_free(pub)
if C.go_openssl_EC_KEY_set_public_key(key, pub) != 1 {
return newOpenSSLError("EC_KEY_set_public_key")
}
case 3:
var priv C.GO_BIGNUM_PTR
if C.go_openssl_EVP_PKEY_get_bn_param(pkey, _OSSL_PKEY_PARAM_PRIV_KEY, &priv) != 1 {
return newOpenSSLError("EVP_PKEY_get_bn_param")
}
defer C.go_openssl_BN_clear_free(priv)
nid, _ := curveNID(curve)
pubBytes, err := generateAndEncodeEcPublicKey(nid, func(group C.GO_EC_GROUP_PTR) (C.GO_EC_POINT_PTR, error) {
return derive(group, priv)
})
if err != nil {
return err
}
if C.go_openssl_EVP_PKEY_set1_encoded_public_key(pkey, base(pubBytes), C.size_t(len(pubBytes))) != 1 {
return newOpenSSLError("EVP_PKEY_set1_encoded_public_key")
}
default:
panic(errUnsupportedVersion())
if C.go_openssl_EC_POINT_mul(group, pt, priv, nil, nil, nil) == 0 {
C.go_openssl_EC_POINT_free(pt)
return nil, newOpenSSLError("EC_POINT_mul")
}
return nil
return pt, nil
}

func ECDH(priv *PrivateKeyECDH, pub *PublicKeyECDH) ([]byte, error) {
Expand Down Expand Up @@ -307,7 +297,7 @@ func GenerateKeyECDH(curve string) (*PrivateKeyECDH, []byte, error) {
if err := bnToBinPad(priv, bytes); err != nil {
return nil, nil, err
}
k = &PrivateKeyECDH{pkey, curve, true}
k = &PrivateKeyECDH{pkey, curve}
runtime.SetFinalizer(k, (*PrivateKeyECDH).finalize)
return k, bytes, nil
}
Loading
Loading