Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify hash implementation #237

Merged
merged 3 commits into from
Dec 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 72 additions & 62 deletions evp.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,29 +48,8 @@ func hashFuncHash(fn func() hash.Hash) (h hash.Hash, err error) {

// hashToMD converts a hash.Hash implementation from this package to a GO_EVP_MD_PTR.
func hashToMD(h hash.Hash) C.GO_EVP_MD_PTR {
var ch crypto.Hash
switch h.(type) {
case *sha1Hash, *sha1Marshal:
ch = crypto.SHA1
case *sha224Hash, *sha224Marshal:
ch = crypto.SHA224
case *sha256Hash, *sha256Marshal:
ch = crypto.SHA256
case *sha384Hash, *sha384Marshal:
ch = crypto.SHA384
case *sha512Hash, *sha512Marshal:
ch = crypto.SHA512
case *sha3_224Hash:
ch = crypto.SHA3_224
case *sha3_256Hash:
ch = crypto.SHA3_256
case *sha3_384Hash:
ch = crypto.SHA3_384
case *sha3_512Hash:
ch = crypto.SHA3_512
}
if ch != 0 {
return cryptoHashToMD(ch)
if h, ok := h.(*evpHash); ok {
return h.alg.md
}
return nil
}
Expand All @@ -89,78 +68,109 @@ func hashFuncToMD(fn func() hash.Hash) (C.GO_EVP_MD_PTR, error) {
return md, nil
}

// cryptoHashToMD converts a crypto.Hash to a EVP_MD.
func cryptoHashToMD(ch crypto.Hash) C.GO_EVP_MD_PTR {
type hashAlgorithm struct {
md C.GO_EVP_MD_PTR
ch crypto.Hash
size int
blockSize int
marshallable bool
magic string
marshalledSize int
}

// loadHash converts a crypto.Hash to a EVP_MD.
func loadHash(ch crypto.Hash) *hashAlgorithm {
if v, ok := cacheMD.Load(ch); ok {
return v.(C.GO_EVP_MD_PTR)
return v.(*hashAlgorithm)
}
var md C.GO_EVP_MD_PTR

var hash hashAlgorithm
switch ch {
case crypto.RIPEMD160:
md = C.go_openssl_EVP_ripemd160()
hash.md = C.go_openssl_EVP_ripemd160()
case crypto.MD4:
md = C.go_openssl_EVP_md4()
hash.md = C.go_openssl_EVP_md4()
case crypto.MD5:
md = C.go_openssl_EVP_md5()
hash.md = C.go_openssl_EVP_md5()
hash.magic = md5Magic
hash.marshalledSize = md5MarshaledSize
case crypto.MD5SHA1:
if vMajor == 1 && vMinor == 0 {
md = C.go_openssl_EVP_md5_sha1_backport()
hash.md = C.go_openssl_EVP_md5_sha1_backport()
} else {
md = C.go_openssl_EVP_md5_sha1()
hash.md = C.go_openssl_EVP_md5_sha1()
}
case crypto.SHA1:
md = C.go_openssl_EVP_sha1()
hash.md = C.go_openssl_EVP_sha1()
hash.magic = sha1Magic
hash.marshalledSize = sha1MarshaledSize
case crypto.SHA224:
md = C.go_openssl_EVP_sha224()
hash.md = C.go_openssl_EVP_sha224()
hash.magic = magic224
hash.marshalledSize = marshaledSize256
case crypto.SHA256:
md = C.go_openssl_EVP_sha256()
hash.md = C.go_openssl_EVP_sha256()
hash.magic = magic256
hash.marshalledSize = marshaledSize256
case crypto.SHA384:
md = C.go_openssl_EVP_sha384()
hash.md = C.go_openssl_EVP_sha384()
hash.magic = magic384
hash.marshalledSize = marshaledSize512
case crypto.SHA512:
md = C.go_openssl_EVP_sha512()
hash.md = C.go_openssl_EVP_sha512()
hash.magic = magic512
hash.marshalledSize = marshaledSize512
case crypto.SHA512_224:
if versionAtOrAbove(1, 1, 1) {
md = C.go_openssl_EVP_sha512_224()
hash.md = C.go_openssl_EVP_sha512_224()
hash.magic = magic512_224
hash.marshalledSize = marshaledSize512
}
case crypto.SHA512_256:
if versionAtOrAbove(1, 1, 1) {
md = C.go_openssl_EVP_sha512_256()
hash.md = C.go_openssl_EVP_sha512_256()
hash.magic = magic512_256
hash.marshalledSize = marshaledSize512
}
case crypto.SHA3_224:
if versionAtOrAbove(1, 1, 1) {
md = C.go_openssl_EVP_sha3_224()
hash.md = C.go_openssl_EVP_sha3_224()
}
case crypto.SHA3_256:
if versionAtOrAbove(1, 1, 1) {
md = C.go_openssl_EVP_sha3_256()
hash.md = C.go_openssl_EVP_sha3_256()
}
case crypto.SHA3_384:
if versionAtOrAbove(1, 1, 1) {
md = C.go_openssl_EVP_sha3_384()
hash.md = C.go_openssl_EVP_sha3_384()
}
case crypto.SHA3_512:
if versionAtOrAbove(1, 1, 1) {
md = C.go_openssl_EVP_sha3_512()
hash.md = C.go_openssl_EVP_sha3_512()
}
}
if md == nil {
cacheMD.Store(ch, nil)
if hash.md == nil {
cacheMD.Store(ch, (*hashAlgorithm)(nil))
return nil
}
hash.ch = ch
hash.size = int(C.go_openssl_EVP_MD_get_size(hash.md))
hash.blockSize = int(C.go_openssl_EVP_MD_get_block_size(hash.md))
if vMajor == 3 {
// On OpenSSL 3, directly operating on a EVP_MD object
// not created by EVP_MD_fetch has negative performance
// implications, as digest operations will have
// to fetch it on every call. Better to just fetch it once here.
md1 := C.go_openssl_EVP_MD_fetch(nil, C.go_openssl_EVP_MD_get0_name(md), nil)
md := C.go_openssl_EVP_MD_fetch(nil, C.go_openssl_EVP_MD_get0_name(hash.md), nil)
// Don't overwrite md in case it can't be fetched, as the md may still be used
// outside of EVP_MD_CTX, for example to sign and verify RSA signatures.
if md1 != nil {
md = md1
if md != nil {
hash.md = md
}
}
cacheMD.Store(ch, md)
return md
hash.marshallable = hash.magic != "" && isHashMarshallable(hash.md)
cacheMD.Store(ch, &hash)
return &hash
}

// generateEVPPKey generates a new EVP_PKEY with the given id and properties.
Expand Down Expand Up @@ -302,11 +312,11 @@ func setupEVP(withKey withKeyFunc, padding C.int,
}
}
case C.GO_RSA_PKCS1_PSS_PADDING:
md := cryptoHashToMD(ch)
if md == nil {
alg := loadHash(ch)
if alg == nil {
return nil, errors.New("crypto/rsa: unsupported hash function")
}
if C.go_openssl_EVP_PKEY_CTX_ctrl(ctx, C.GO_EVP_PKEY_RSA, -1, C.GO_EVP_PKEY_CTRL_MD, 0, unsafe.Pointer(md)) != 1 {
if C.go_openssl_EVP_PKEY_CTX_ctrl(ctx, C.GO_EVP_PKEY_RSA, -1, C.GO_EVP_PKEY_CTRL_MD, 0, unsafe.Pointer(alg.md)) != 1 {
return nil, newOpenSSLError("EVP_PKEY_CTX_ctrl failed")
}
// setPadding must happen after setting EVP_PKEY_CTRL_MD.
Expand All @@ -322,11 +332,11 @@ func setupEVP(withKey withKeyFunc, padding C.int,
case C.GO_RSA_PKCS1_PADDING:
if ch != 0 {
// We support unhashed messages.
md := cryptoHashToMD(ch)
if md == nil {
alg := loadHash(ch)
if alg == nil {
return nil, errors.New("crypto/rsa: unsupported hash function")
}
if C.go_openssl_EVP_PKEY_CTX_ctrl(ctx, -1, -1, C.GO_EVP_PKEY_CTRL_MD, 0, unsafe.Pointer(md)) != 1 {
if C.go_openssl_EVP_PKEY_CTX_ctrl(ctx, -1, -1, C.GO_EVP_PKEY_CTRL_MD, 0, unsafe.Pointer(alg.md)) != 1 {
return nil, newOpenSSLError("EVP_PKEY_CTX_ctrl failed")
}
if err := setPadding(); err != nil {
Expand Down Expand Up @@ -441,8 +451,8 @@ func evpVerify(withKey withKeyFunc, padding C.int, saltLen C.int, h crypto.Hash,
}

func evpHashSign(withKey withKeyFunc, h crypto.Hash, msg []byte) ([]byte, error) {
md := cryptoHashToMD(h)
if md == nil {
alg := loadHash(h)
if alg == nil {
return nil, errors.New("unsupported hash function: " + strconv.Itoa(int(h)))
}
var out []byte
Expand All @@ -453,7 +463,7 @@ func evpHashSign(withKey withKeyFunc, h crypto.Hash, msg []byte) ([]byte, error)
}
defer C.go_openssl_EVP_MD_CTX_free(ctx)
if withKey(func(key C.GO_EVP_PKEY_PTR) C.int {
return C.go_openssl_EVP_DigestSignInit(ctx, nil, md, nil, key)
return C.go_openssl_EVP_DigestSignInit(ctx, nil, alg.md, nil, key)
}) != 1 {
return nil, newOpenSSLError("EVP_DigestSignInit failed")
}
Expand All @@ -473,8 +483,8 @@ func evpHashSign(withKey withKeyFunc, h crypto.Hash, msg []byte) ([]byte, error)
}

func evpHashVerify(withKey withKeyFunc, h crypto.Hash, msg, sig []byte) error {
md := cryptoHashToMD(h)
if md == nil {
alg := loadHash(h)
if alg == nil {
return errors.New("unsupported hash function: " + strconv.Itoa(int(h)))
}
ctx := C.go_openssl_EVP_MD_CTX_new()
Expand All @@ -483,7 +493,7 @@ func evpHashVerify(withKey withKeyFunc, h crypto.Hash, msg, sig []byte) error {
}
defer C.go_openssl_EVP_MD_CTX_free(ctx)
if withKey(func(key C.GO_EVP_PKEY_PTR) C.int {
return C.go_openssl_EVP_DigestVerifyInit(ctx, nil, md, nil, key)
return C.go_openssl_EVP_DigestVerifyInit(ctx, nil, alg.md, nil, key)
}) != 1 {
return newOpenSSLError("EVP_DigestVerifyInit failed")
}
Expand Down
Loading
Loading