diff --git a/export_test.go b/export_test.go index 9ed2fa05..d3d10ec7 100644 --- a/export_test.go +++ b/export_test.go @@ -1,3 +1,5 @@ package openssl var ErrOpen = errOpen + +var TestNotMarshalable = &testNotMarshalable diff --git a/hash.go b/hash.go index 646b4ce2..e0ea56d8 100644 --- a/hash.go +++ b/hash.go @@ -10,6 +10,7 @@ import ( "hash" "runtime" "strconv" + "sync" "unsafe" ) @@ -110,6 +111,37 @@ func SHA3_512(p []byte) (sum [64]byte) { return } +var isMarshallableCache sync.Map + +// isHashMarshallable returns true if the memory layout of cb +// is known by this library and can therefore be marshalled. +func isHashMarshallable(ch crypto.Hash) bool { + if vMajor == 1 { + return true + } + if v, ok := isMarshallableCache.Load(ch); ok { + return v.(bool) + } + md := cryptoHashToMD(ch) + if md == nil { + return false + } + prov := C.go_openssl_EVP_MD_get0_provider(md) + if prov == nil { + return false + } + cname := C.go_openssl_OSSL_PROVIDER_get0_name(prov) + if cname == nil { + return false + } + name := C.GoString(cname) + // We only know the memory layout of the built-in providers. + // See evpHash.hashState for more details. + marshallable := name == "default" || name == "fips" + isMarshallableCache.Store(ch, marshallable) + return marshallable +} + // evpHash implements generic hash methods. type evpHash struct { ctx C.GO_EVP_MD_CTX_PTR @@ -119,6 +151,8 @@ type evpHash struct { ctx2 C.GO_EVP_MD_CTX_PTR size int blockSize int + + marshallable bool } func newEvpHash(ch crypto.Hash, size, blockSize int) *evpHash { @@ -137,6 +171,8 @@ func newEvpHash(ch crypto.Hash, size, blockSize int) *evpHash { ctx2: ctx2, size: size, blockSize: blockSize, + + marshallable: isHashMarshallable(ch), } runtime.SetFinalizer(h, (*evpHash).finalize) return h @@ -195,11 +231,16 @@ func (h *evpHash) sum(out []byte) { runtime.KeepAlive(h) } +var testNotMarshalable bool // Used in tests. + // hashState returns a pointer to the internal hash structure. // // The EVP_MD_CTX memory layout has changed in OpenSSL 3 // and the property holding the internal structure is no longer md_data but algctx. func (h *evpHash) hashState() unsafe.Pointer { + if !h.marshallable || testNotMarshalable { + return nil + } switch vMajor { case 1: // https://github.com/openssl/openssl/blob/0418e993c717a6863f206feaa40673a261de7395/crypto/evp/evp_local.h#L12. diff --git a/hash_test.go b/hash_test.go index 7244038a..dcf01c01 100644 --- a/hash_test.go +++ b/hash_test.go @@ -39,6 +39,29 @@ func cryptoToHash(h crypto.Hash) func() hash.Hash { return nil } +func TestHashNotMarshalable(t *testing.T) { + h := openssl.NewSHA256() + state, err := h.(encoding.BinaryMarshaler).MarshalBinary() + if err != nil { + // In the go1.23 support we only test using the built-in providers, + // which are all marshalable, so this should never happen. + t.Fatal(err) + } + *openssl.TestNotMarshalable = true + defer func() { + *openssl.TestNotMarshalable = false + }() + + _, err = h.(encoding.BinaryMarshaler).MarshalBinary() + if err == nil { + t.Error("expected error") + } + err = h.(encoding.BinaryUnmarshaler).UnmarshalBinary(state) + if err == nil { + t.Error("expected error") + } +} + func TestHash(t *testing.T) { msg := []byte("testing") var tests = []struct { diff --git a/shims.h b/shims.h index 99656f0c..6ea5bc6d 100644 --- a/shims.h +++ b/shims.h @@ -190,9 +190,11 @@ DEFINEFUNC_3_0(int, EVP_default_properties_is_fips_enabled, (GO_OSSL_LIB_CTX_PTR DEFINEFUNC_3_0(int, EVP_default_properties_enable_fips, (GO_OSSL_LIB_CTX_PTR libctx, int enable), (libctx, enable)) \ DEFINEFUNC_3_0(int, OSSL_PROVIDER_available, (GO_OSSL_LIB_CTX_PTR libctx, const char *name), (libctx, name)) \ DEFINEFUNC_3_0(GO_OSSL_PROVIDER_PTR, OSSL_PROVIDER_load, (GO_OSSL_LIB_CTX_PTR libctx, const char *name), (libctx, name)) \ +DEFINEFUNC_3_0(const char *, OSSL_PROVIDER_get0_name, (const GO_OSSL_PROVIDER_PTR prov), (prov)) \ DEFINEFUNC_3_0(GO_EVP_MD_PTR, EVP_MD_fetch, (GO_OSSL_LIB_CTX_PTR ctx, const char *algorithm, const char *properties), (ctx, algorithm, properties)) \ DEFINEFUNC_3_0(void, EVP_MD_free, (GO_EVP_MD_PTR md), (md)) \ DEFINEFUNC_3_0(const char *, EVP_MD_get0_name, (const GO_EVP_MD_PTR md), (md)) \ +DEFINEFUNC_3_0(const GO_OSSL_PROVIDER_PTR, EVP_MD_get0_provider, (const GO_EVP_MD_PTR md), (md)) \ DEFINEFUNC(int, RAND_bytes, (unsigned char *arg0, int arg1), (arg0, arg1)) \ DEFINEFUNC_RENAMED_1_1(GO_EVP_MD_CTX_PTR, EVP_MD_CTX_new, EVP_MD_CTX_create, (void), ()) \ DEFINEFUNC_RENAMED_1_1(void, EVP_MD_CTX_free, EVP_MD_CTX_destroy, (GO_EVP_MD_CTX_PTR ctx), (ctx)) \