diff --git a/evp.go b/evp.go index ef68bbfb..f6a2d3ed 100644 --- a/evp.go +++ b/evp.go @@ -89,78 +89,78 @@ func hashFuncToMD(fn func() hash.Hash) (C.GO_EVP_MD_PTR, error) { return md, nil } -// cryptoHashToMD converts a crypto.Hash to a GO_EVP_MD_PTR. -func cryptoHashToMD(ch crypto.Hash) (md C.GO_EVP_MD_PTR) { +// cryptoHashToMD converts a crypto.Hash to a EVP_MD. +func cryptoHashToMD(ch crypto.Hash) C.GO_EVP_MD_PTR { if v, ok := cacheMD.Load(ch); ok { return v.(C.GO_EVP_MD_PTR) } - defer func() { - if md != nil { - switch vMajor { - case 1: - // On OpenSSL 1 EVP_MD objects can be not-nil even - // when they are not supported. We need to pass the md - // to a EVP_MD_CTX to really know if they can be used. - ctx := C.go_openssl_EVP_MD_CTX_new() - if C.go_openssl_EVP_DigestInit_ex(ctx, md, nil) != 1 { - md = nil - } - C.go_openssl_EVP_MD_CTX_free(ctx) - case 3: - // On OpenSSL 3, directly operating on a EVP_MD object - // not created by EVP_MD_fetch has negative performance - // implications, as digest operations will have - // to fetch it on every call. Better to just fetch it once here. - md = C.go_openssl_EVP_MD_fetch(nil, C.go_openssl_EVP_MD_get0_name(md), nil) - default: - panic(errUnsupportedVersion()) - } - } - cacheMD.Store(ch, md) - }() - // SupportsHash returns false for MD5SHA1 because we don't - // provide a hash.Hash implementation for it. Yet, it can - // still be used when signing/verifying with an RSA key. - if ch == crypto.MD5SHA1 { - if vMajor == 1 && vMinor == 0 { - return C.go_openssl_EVP_md5_sha1_backport() - } else { - return C.go_openssl_EVP_md5_sha1() - } - } + var md C.GO_EVP_MD_PTR switch ch { + case crypto.RIPEMD160: + md = C.go_openssl_EVP_ripemd160() case crypto.MD4: - return C.go_openssl_EVP_md4() + md = C.go_openssl_EVP_md4() case crypto.MD5: - return C.go_openssl_EVP_md5() + md = C.go_openssl_EVP_md5() + case crypto.MD5SHA1: + if vMajor == 1 && vMinor == 0 { + md = C.go_openssl_EVP_md5_sha1_backport() + } else { + md = C.go_openssl_EVP_md5_sha1() + } case crypto.SHA1: - return C.go_openssl_EVP_sha1() + md = C.go_openssl_EVP_sha1() case crypto.SHA224: - return C.go_openssl_EVP_sha224() + md = C.go_openssl_EVP_sha224() case crypto.SHA256: - return C.go_openssl_EVP_sha256() + md = C.go_openssl_EVP_sha256() case crypto.SHA384: - return C.go_openssl_EVP_sha384() + md = C.go_openssl_EVP_sha384() case crypto.SHA512: - return C.go_openssl_EVP_sha512() + md = C.go_openssl_EVP_sha512() + case crypto.SHA512_224: + if versionAtOrAbove(1, 1, 1) { + md = C.go_openssl_EVP_sha512_224() + } + case crypto.SHA512_256: + if versionAtOrAbove(1, 1, 1) { + md = C.go_openssl_EVP_sha512_256() + } case crypto.SHA3_224: if versionAtOrAbove(1, 1, 1) { - return C.go_openssl_EVP_sha3_224() + md = C.go_openssl_EVP_sha3_224() } case crypto.SHA3_256: if versionAtOrAbove(1, 1, 1) { - return C.go_openssl_EVP_sha3_256() + md = C.go_openssl_EVP_sha3_256() } case crypto.SHA3_384: if versionAtOrAbove(1, 1, 1) { - return C.go_openssl_EVP_sha3_384() + md = C.go_openssl_EVP_sha3_384() } case crypto.SHA3_512: if versionAtOrAbove(1, 1, 1) { - return C.go_openssl_EVP_sha3_512() + md = C.go_openssl_EVP_sha3_512() } } - return nil + if md == nil { + cacheMD.Store(ch, nil) + return nil + } + if vMajor == 3 { + // On OpenSSL 3, directly operating on a EVP_MD object + // not created by EVP_MD_fetch has negative performance + // implications, as digest operations will have + // to fetch it on every call. Better to just fetch it once here. + md1 := C.go_openssl_EVP_MD_fetch(nil, C.go_openssl_EVP_MD_get0_name(md), nil) + // Don't overwrite md in case it can't be fetched, as the md may still be used + // outside of EVP_MD_CTX, for example to sign and verify RSA signatures. + if md1 != nil { + md = md1 + } + } + cacheMD.Store(ch, md) + return md } // generateEVPPKey generates a new EVP_PKEY with the given id and properties. diff --git a/hash.go b/hash.go index 6fd3a518..379dd1b8 100644 --- a/hash.go +++ b/hash.go @@ -78,9 +78,29 @@ func SHA512(p []byte) (sum [64]byte) { return } -// SupportsHash returns true if a hash.Hash implementation is supported for h. +// cacheHashSupported is a cache of crypto.Hash support. +var cacheHashSupported sync.Map + +// SupportsHash reports whether the current OpenSSL version supports the given hash. func SupportsHash(h crypto.Hash) bool { - return cryptoHashToMD(h) != nil + if v, ok := cacheHashSupported.Load(h); ok { + return v.(bool) + } + md := cryptoHashToMD(h) + if md == nil { + cacheHashSupported.Store(h, false) + return false + } + // EVP_MD objects can be non-nil even when they can't be used + // in a EVP_MD_CTX, e.g. MD5 in FIPS mode. We need to prove + // if they can be used by passing them to a EVP_MD_CTX. + var supported bool + if ctx := C.go_openssl_EVP_MD_CTX_new(); ctx != nil { + supported = C.go_openssl_EVP_DigestInit_ex(ctx, md, nil) == 1 + C.go_openssl_EVP_MD_CTX_free(ctx) + } + cacheHashSupported.Store(h, supported) + return supported } func SHA3_224(p []byte) (sum [28]byte) { diff --git a/rsa_test.go b/rsa_test.go index 5b92025e..4b383384 100644 --- a/rsa_test.go +++ b/rsa_test.go @@ -7,6 +7,7 @@ import ( "fmt" "math/big" "strconv" + "strings" "testing" "github.com/golang-fips/openssl/v2" @@ -193,7 +194,61 @@ func TestRSAEncryptDecryptOAEP_WrongLabel(t *testing.T) { } } +// These are all the hashes supported by Go's crypto/rsa package +// as of Go 1.24. +var stdHashes = [...]crypto.Hash{ + crypto.MD5SHA1, + crypto.MD5, + crypto.SHA1, + crypto.SHA224, + crypto.SHA256, + crypto.SHA512, + crypto.SHA512_224, + crypto.SHA512_256, + crypto.SHA3_224, + crypto.SHA3_256, + crypto.SHA3_512, + crypto.RIPEMD160, +} + func TestRSASignVerifyPKCS1v15(t *testing.T) { + priv, pub := newRSAKey(t, 2048) + for _, hash := range append([]crypto.Hash{0}, stdHashes[:]...) { + var name string + if hash == 0 { + name = "unhashed" + } else { + name = hash.String() + } + t.Run(name, func(t *testing.T) { + if hash != 0 && !openssl.SupportsHash(hash) { + t.Skip("skipping test because hash is not supported") + } + // Construct a fake hashed data. + size := 1 + if hash != 0 { + size = hash.Size() + } + hashed := make([]byte, size) + hashed[0] = 0x30 + signed, err := openssl.SignRSAPKCS1v15(priv, hash, hashed) + if err != nil { + if strings.Contains(err.Error(), "invalid digest") || strings.Contains(err.Error(), "digest not allowed") { + // Can happen if the hash is supported by EVP_MD_CTX but not by EVP_PKEY_CTX. + // There is nothing we can do about it. + t.Skip("skipping test because hash is not supported") + } + t.Fatal(err) + } + err = openssl.VerifyRSAPKCS1v15(pub, hash, hashed, signed) + if err != nil { + t.Fatal(err) + } + }) + } +} + +func TestRSAHashSignVerifyPKCS1v15(t *testing.T) { sha256 := openssl.NewSHA256() priv, pub := newRSAKey(t, 2048) msg := []byte("hi!") @@ -220,23 +275,6 @@ func TestRSASignVerifyPKCS1v15(t *testing.T) { } } -func TestRSASignVerifyPKCS1v15_Unhashed(t *testing.T) { - if openssl.SymCryptProviderAvailable() { - t.Skip("SymCrypt provider does not support unhashed PKCS1v15") - } - - msg := []byte("hi!") - priv, pub := newRSAKey(t, 2048) - signed, err := openssl.SignRSAPKCS1v15(priv, 0, msg) - if err != nil { - t.Fatal(err) - } - err = openssl.VerifyRSAPKCS1v15(pub, 0, msg, signed) - if err != nil { - t.Fatal(err) - } -} - func TestRSASignVerifyPKCS1v15_Invalid(t *testing.T) { sha256 := openssl.NewSHA256() msg := []byte("hi!") @@ -254,6 +292,37 @@ func TestRSASignVerifyPKCS1v15_Invalid(t *testing.T) { } func TestRSASignVerifyRSAPSS(t *testing.T) { + priv, pub := newRSAKey(t, 2048) + for _, hash := range stdHashes { + t.Run(hash.String(), func(t *testing.T) { + if !openssl.SupportsHash(hash) { + t.Skip("skipping test because hash is not supported") + } + // Construct a fake hashed data. + size := 1 + if hash != 0 { + size = hash.Size() + } + hashed := make([]byte, size) + hashed[0] = 0x30 + signed, err := openssl.SignRSAPSS(priv, hash, hashed, rsa.PSSSaltLengthEqualsHash) + if err != nil { + if strings.Contains(err.Error(), "invalid digest") || strings.Contains(err.Error(), "digest not allowed") { + // Can happen if the hash is supported by EVP_MD_CTX but not by EVP_PKEY_CTX. + // There is nothing we can do about it. + t.Skip("skipping test because hash is not supported") + } + t.Fatal(err) + } + err = openssl.VerifyRSAPSS(pub, hash, hashed, signed, rsa.PSSSaltLengthEqualsHash) + if err != nil { + t.Fatal(err) + } + }) + } +} + +func TestRSASignVerifyRSAPSS_SaltLength(t *testing.T) { // Test cases taken from // https://github.com/golang/go/blob/54182ff54a687272dd7632c3a963e036ce03cb7c/src/crypto/rsa/pss_test.go#L200. const keyBits = 2048 diff --git a/shims.h b/shims.h index c8f599f7..cfb65ff5 100644 --- a/shims.h +++ b/shims.h @@ -237,6 +237,7 @@ DEFINEFUNC_LEGACY_1_0(int, SHA1_Init, (GO_SHA_CTX_PTR c), (c)) \ DEFINEFUNC_LEGACY_1_0(int, SHA1_Update, (GO_SHA_CTX_PTR c, const void *data, size_t len), (c, data, len)) \ DEFINEFUNC_LEGACY_1_0(int, SHA1_Final, (unsigned char *md, GO_SHA_CTX_PTR c), (md, c)) \ DEFINEFUNC_1_1(const GO_EVP_MD_PTR, EVP_md5_sha1, (void), ()) \ +DEFINEFUNC(const GO_EVP_MD_PTR, EVP_ripemd160, (void), ()) \ DEFINEFUNC(const GO_EVP_MD_PTR, EVP_md4, (void), ()) \ DEFINEFUNC(const GO_EVP_MD_PTR, EVP_md5, (void), ()) \ DEFINEFUNC(const GO_EVP_MD_PTR, EVP_sha1, (void), ()) \ @@ -244,6 +245,8 @@ DEFINEFUNC(const GO_EVP_MD_PTR, EVP_sha224, (void), ()) \ DEFINEFUNC(const GO_EVP_MD_PTR, EVP_sha256, (void), ()) \ DEFINEFUNC(const GO_EVP_MD_PTR, EVP_sha384, (void), ()) \ DEFINEFUNC(const GO_EVP_MD_PTR, EVP_sha512, (void), ()) \ +DEFINEFUNC_1_1_1(const GO_EVP_MD_PTR, EVP_sha512_224, (void), ()) \ +DEFINEFUNC_1_1_1(const GO_EVP_MD_PTR, EVP_sha512_256, (void), ()) \ DEFINEFUNC_1_1_1(const GO_EVP_MD_PTR, EVP_sha3_224, (void), ()) \ DEFINEFUNC_1_1_1(const GO_EVP_MD_PTR, EVP_sha3_256, (void), ()) \ DEFINEFUNC_1_1_1(const GO_EVP_MD_PTR, EVP_sha3_384, (void), ()) \