diff --git a/evp.go b/evp.go index f6a2d3e..b595ce9 100644 --- a/evp.go +++ b/evp.go @@ -48,29 +48,8 @@ func hashFuncHash(fn func() hash.Hash) (h hash.Hash, err error) { // hashToMD converts a hash.Hash implementation from this package to a GO_EVP_MD_PTR. func hashToMD(h hash.Hash) C.GO_EVP_MD_PTR { - var ch crypto.Hash - switch h.(type) { - case *sha1Hash, *sha1Marshal: - ch = crypto.SHA1 - case *sha224Hash, *sha224Marshal: - ch = crypto.SHA224 - case *sha256Hash, *sha256Marshal: - ch = crypto.SHA256 - case *sha384Hash, *sha384Marshal: - ch = crypto.SHA384 - case *sha512Hash, *sha512Marshal: - ch = crypto.SHA512 - case *sha3_224Hash: - ch = crypto.SHA3_224 - case *sha3_256Hash: - ch = crypto.SHA3_256 - case *sha3_384Hash: - ch = crypto.SHA3_384 - case *sha3_512Hash: - ch = crypto.SHA3_512 - } - if ch != 0 { - return cryptoHashToMD(ch) + if h, ok := h.(*evpHash); ok { + return h.alg.md } return nil } @@ -89,78 +68,109 @@ func hashFuncToMD(fn func() hash.Hash) (C.GO_EVP_MD_PTR, error) { return md, nil } -// cryptoHashToMD converts a crypto.Hash to a EVP_MD. -func cryptoHashToMD(ch crypto.Hash) C.GO_EVP_MD_PTR { +type hashAlgorithm struct { + md C.GO_EVP_MD_PTR + ch crypto.Hash + size int + blockSize int + marshallable bool + magic string + marshalledSize int +} + +// loadHash converts a crypto.Hash to a EVP_MD. +func loadHash(ch crypto.Hash) *hashAlgorithm { if v, ok := cacheMD.Load(ch); ok { - return v.(C.GO_EVP_MD_PTR) + return v.(*hashAlgorithm) } - var md C.GO_EVP_MD_PTR + + var hash hashAlgorithm switch ch { case crypto.RIPEMD160: - md = C.go_openssl_EVP_ripemd160() + hash.md = C.go_openssl_EVP_ripemd160() case crypto.MD4: - md = C.go_openssl_EVP_md4() + hash.md = C.go_openssl_EVP_md4() case crypto.MD5: - md = C.go_openssl_EVP_md5() + hash.md = C.go_openssl_EVP_md5() + hash.magic = md5Magic + hash.marshalledSize = md5MarshaledSize case crypto.MD5SHA1: if vMajor == 1 && vMinor == 0 { - md = C.go_openssl_EVP_md5_sha1_backport() + hash.md = C.go_openssl_EVP_md5_sha1_backport() } else { - md = C.go_openssl_EVP_md5_sha1() + hash.md = C.go_openssl_EVP_md5_sha1() } case crypto.SHA1: - md = C.go_openssl_EVP_sha1() + hash.md = C.go_openssl_EVP_sha1() + hash.magic = sha1Magic + hash.marshalledSize = sha1MarshaledSize case crypto.SHA224: - md = C.go_openssl_EVP_sha224() + hash.md = C.go_openssl_EVP_sha224() + hash.magic = magic224 + hash.marshalledSize = marshaledSize256 case crypto.SHA256: - md = C.go_openssl_EVP_sha256() + hash.md = C.go_openssl_EVP_sha256() + hash.magic = magic256 + hash.marshalledSize = marshaledSize256 case crypto.SHA384: - md = C.go_openssl_EVP_sha384() + hash.md = C.go_openssl_EVP_sha384() + hash.magic = magic384 + hash.marshalledSize = marshaledSize512 case crypto.SHA512: - md = C.go_openssl_EVP_sha512() + hash.md = C.go_openssl_EVP_sha512() + hash.magic = magic512 + hash.marshalledSize = marshaledSize512 case crypto.SHA512_224: if versionAtOrAbove(1, 1, 1) { - md = C.go_openssl_EVP_sha512_224() + hash.md = C.go_openssl_EVP_sha512_224() + hash.magic = magic512_224 + hash.marshalledSize = marshaledSize512 } case crypto.SHA512_256: if versionAtOrAbove(1, 1, 1) { - md = C.go_openssl_EVP_sha512_256() + hash.md = C.go_openssl_EVP_sha512_256() + hash.magic = magic512_256 + hash.marshalledSize = marshaledSize512 } case crypto.SHA3_224: if versionAtOrAbove(1, 1, 1) { - md = C.go_openssl_EVP_sha3_224() + hash.md = C.go_openssl_EVP_sha3_224() } case crypto.SHA3_256: if versionAtOrAbove(1, 1, 1) { - md = C.go_openssl_EVP_sha3_256() + hash.md = C.go_openssl_EVP_sha3_256() } case crypto.SHA3_384: if versionAtOrAbove(1, 1, 1) { - md = C.go_openssl_EVP_sha3_384() + hash.md = C.go_openssl_EVP_sha3_384() } case crypto.SHA3_512: if versionAtOrAbove(1, 1, 1) { - md = C.go_openssl_EVP_sha3_512() + hash.md = C.go_openssl_EVP_sha3_512() } } - if md == nil { - cacheMD.Store(ch, nil) + if hash.md == nil { + cacheMD.Store(ch, (*hashAlgorithm)(nil)) return nil } + hash.ch = ch + hash.size = int(C.go_openssl_EVP_MD_get_size(hash.md)) + hash.blockSize = int(C.go_openssl_EVP_MD_get_block_size(hash.md)) if vMajor == 3 { // On OpenSSL 3, directly operating on a EVP_MD object // not created by EVP_MD_fetch has negative performance // implications, as digest operations will have // to fetch it on every call. Better to just fetch it once here. - md1 := C.go_openssl_EVP_MD_fetch(nil, C.go_openssl_EVP_MD_get0_name(md), nil) + md := C.go_openssl_EVP_MD_fetch(nil, C.go_openssl_EVP_MD_get0_name(hash.md), nil) // Don't overwrite md in case it can't be fetched, as the md may still be used // outside of EVP_MD_CTX, for example to sign and verify RSA signatures. - if md1 != nil { - md = md1 + if md != nil { + hash.md = md } } - cacheMD.Store(ch, md) - return md + hash.marshallable = hash.magic != "" && isHashMarshallable(hash.md) + cacheMD.Store(ch, &hash) + return &hash } // generateEVPPKey generates a new EVP_PKEY with the given id and properties. @@ -302,11 +312,11 @@ func setupEVP(withKey withKeyFunc, padding C.int, } } case C.GO_RSA_PKCS1_PSS_PADDING: - md := cryptoHashToMD(ch) - if md == nil { + alg := loadHash(ch) + if alg == nil { return nil, errors.New("crypto/rsa: unsupported hash function") } - if C.go_openssl_EVP_PKEY_CTX_ctrl(ctx, C.GO_EVP_PKEY_RSA, -1, C.GO_EVP_PKEY_CTRL_MD, 0, unsafe.Pointer(md)) != 1 { + if C.go_openssl_EVP_PKEY_CTX_ctrl(ctx, C.GO_EVP_PKEY_RSA, -1, C.GO_EVP_PKEY_CTRL_MD, 0, unsafe.Pointer(alg.md)) != 1 { return nil, newOpenSSLError("EVP_PKEY_CTX_ctrl failed") } // setPadding must happen after setting EVP_PKEY_CTRL_MD. @@ -322,11 +332,11 @@ func setupEVP(withKey withKeyFunc, padding C.int, case C.GO_RSA_PKCS1_PADDING: if ch != 0 { // We support unhashed messages. - md := cryptoHashToMD(ch) - if md == nil { + alg := loadHash(ch) + if alg == nil { return nil, errors.New("crypto/rsa: unsupported hash function") } - if C.go_openssl_EVP_PKEY_CTX_ctrl(ctx, -1, -1, C.GO_EVP_PKEY_CTRL_MD, 0, unsafe.Pointer(md)) != 1 { + if C.go_openssl_EVP_PKEY_CTX_ctrl(ctx, -1, -1, C.GO_EVP_PKEY_CTRL_MD, 0, unsafe.Pointer(alg.md)) != 1 { return nil, newOpenSSLError("EVP_PKEY_CTX_ctrl failed") } if err := setPadding(); err != nil { @@ -441,8 +451,8 @@ func evpVerify(withKey withKeyFunc, padding C.int, saltLen C.int, h crypto.Hash, } func evpHashSign(withKey withKeyFunc, h crypto.Hash, msg []byte) ([]byte, error) { - md := cryptoHashToMD(h) - if md == nil { + alg := loadHash(h) + if alg == nil { return nil, errors.New("unsupported hash function: " + strconv.Itoa(int(h))) } var out []byte @@ -453,7 +463,7 @@ func evpHashSign(withKey withKeyFunc, h crypto.Hash, msg []byte) ([]byte, error) } defer C.go_openssl_EVP_MD_CTX_free(ctx) if withKey(func(key C.GO_EVP_PKEY_PTR) C.int { - return C.go_openssl_EVP_DigestSignInit(ctx, nil, md, nil, key) + return C.go_openssl_EVP_DigestSignInit(ctx, nil, alg.md, nil, key) }) != 1 { return nil, newOpenSSLError("EVP_DigestSignInit failed") } @@ -473,8 +483,8 @@ func evpHashSign(withKey withKeyFunc, h crypto.Hash, msg []byte) ([]byte, error) } func evpHashVerify(withKey withKeyFunc, h crypto.Hash, msg, sig []byte) error { - md := cryptoHashToMD(h) - if md == nil { + alg := loadHash(h) + if alg == nil { return errors.New("unsupported hash function: " + strconv.Itoa(int(h))) } ctx := C.go_openssl_EVP_MD_CTX_new() @@ -483,7 +493,7 @@ func evpHashVerify(withKey withKeyFunc, h crypto.Hash, msg, sig []byte) error { } defer C.go_openssl_EVP_MD_CTX_free(ctx) if withKey(func(key C.GO_EVP_PKEY_PTR) C.int { - return C.go_openssl_EVP_DigestVerifyInit(ctx, nil, md, nil, key) + return C.go_openssl_EVP_DigestVerifyInit(ctx, nil, alg.md, nil, key) }) != 1 { return newOpenSSLError("EVP_DigestVerifyInit failed") } diff --git a/hash.go b/hash.go index 379dd1b..d01dd68 100644 --- a/hash.go +++ b/hash.go @@ -14,6 +14,9 @@ import ( "unsafe" ) +// maxHashSize is the size of SHA52 and SHA3_512, the largest hashes we support. +const maxHashSize = 64 + // NOTE: Implementation ported from https://go-review.googlesource.com/c/go/+/404295. // The cgo calls in this file are arranged to avoid marking the parameters as escaping. // To do that, we call noescape (including via addr). @@ -26,7 +29,7 @@ import ( // This is all to preserve compatibility with the allocation behavior of the non-openssl implementations. func hashOneShot(ch crypto.Hash, p []byte, sum []byte) bool { - return C.go_openssl_EVP_Digest(unsafe.Pointer(&*addr(p)), C.size_t(len(p)), (*C.uchar)(unsafe.Pointer(&*addr(sum))), nil, cryptoHashToMD(ch), nil) != 0 + return C.go_openssl_EVP_Digest(unsafe.Pointer(&*addr(p)), C.size_t(len(p)), (*C.uchar)(unsafe.Pointer(&*addr(sum))), nil, loadHash(ch).md, nil) != 0 } func MD4(p []byte) (sum [16]byte) { @@ -86,8 +89,8 @@ func SupportsHash(h crypto.Hash) bool { if v, ok := cacheHashSupported.Load(h); ok { return v.(bool) } - md := cryptoHashToMD(h) - if md == nil { + alg := loadHash(h) + if alg == nil { cacheHashSupported.Store(h, false) return false } @@ -96,7 +99,7 @@ func SupportsHash(h crypto.Hash) bool { // if they can be used by passing them to a EVP_MD_CTX. var supported bool if ctx := C.go_openssl_EVP_MD_CTX_new(); ctx != nil { - supported = C.go_openssl_EVP_DigestInit_ex(ctx, md, nil) == 1 + supported = C.go_openssl_EVP_DigestInit_ex(ctx, alg.md, nil) == 1 C.go_openssl_EVP_MD_CTX_free(ctx) } cacheHashSupported.Store(h, supported) @@ -131,21 +134,69 @@ func SHA3_512(p []byte) (sum [64]byte) { return } -var isMarshallableCache sync.Map +// NewMD4 returns a new MD4 hash. +// The returned hash doesn't implement encoding.BinaryMarshaler and +// encoding.BinaryUnmarshaler. +func NewMD4() hash.Hash { + return newEvpHash(crypto.MD4) +} + +// NewMD5 returns a new MD5 hash. +func NewMD5() hash.Hash { + return newEvpHash(crypto.MD5) +} + +// NewSHA1 returns a new SHA1 hash. +func NewSHA1() hash.Hash { + return newEvpHash(crypto.SHA1) +} + +// NewSHA224 returns a new SHA224 hash. +func NewSHA224() hash.Hash { + return newEvpHash(crypto.SHA224) +} + +// NewSHA256 returns a new SHA256 hash. +func NewSHA256() hash.Hash { + return newEvpHash(crypto.SHA256) +} + +// NewSHA384 returns a new SHA384 hash. +func NewSHA384() hash.Hash { + return newEvpHash(crypto.SHA384) +} + +// NewSHA512 returns a new SHA512 hash. +func NewSHA512() hash.Hash { + return newEvpHash(crypto.SHA512) +} + +// NewSHA3_224 returns a new SHA3-224 hash. +func NewSHA3_224() hash.Hash { + return newEvpHash(crypto.SHA3_224) +} + +// NewSHA3_256 returns a new SHA3-256 hash. +func NewSHA3_256() hash.Hash { + return newEvpHash(crypto.SHA3_256) +} -// isHashMarshallable returns true if the memory layout of cb +// NewSHA3_384 returns a new SHA3-384 hash. +func NewSHA3_384() hash.Hash { + return newEvpHash(crypto.SHA3_384) +} + +// NewSHA3_512 returns a new SHA3-512 hash. +func NewSHA3_512() hash.Hash { + return newEvpHash(crypto.SHA3_512) +} + +// isHashMarshallable returns true if the memory layout of md // is known by this library and can therefore be marshalled. -func isHashMarshallable(ch crypto.Hash) bool { +func isHashMarshallable(md C.GO_EVP_MD_PTR) bool { if vMajor == 1 { return true } - if v, ok := isMarshallableCache.Load(ch); ok { - return v.(bool) - } - md := cryptoHashToMD(ch) - if md == nil { - return false - } prov := C.go_openssl_EVP_MD_get0_provider(md) if prov == nil { return false @@ -158,40 +209,34 @@ func isHashMarshallable(ch crypto.Hash) bool { // We only know the memory layout of the built-in providers. // See evpHash.hashState for more details. marshallable := name == "default" || name == "fips" - isMarshallableCache.Store(ch, marshallable) return marshallable } // evpHash implements generic hash methods. type evpHash struct { + alg *hashAlgorithm ctx C.GO_EVP_MD_CTX_PTR // ctx2 is used in evpHash.sum to avoid changing // the state of ctx. Having it here allows reusing the // same allocated object multiple times. - ctx2 C.GO_EVP_MD_CTX_PTR - size int - blockSize int - marshallable bool + ctx2 C.GO_EVP_MD_CTX_PTR } func newEvpHash(ch crypto.Hash) *evpHash { - md := cryptoHashToMD(ch) - if md == nil { + alg := loadHash(ch) + if alg == nil { panic("openssl: unsupported hash function: " + strconv.Itoa(int(ch))) } ctx := C.go_openssl_EVP_MD_CTX_new() - if C.go_openssl_EVP_DigestInit_ex(ctx, md, nil) != 1 { + if C.go_openssl_EVP_DigestInit_ex(ctx, alg.md, nil) != 1 { C.go_openssl_EVP_MD_CTX_free(ctx) panic(newOpenSSLError("EVP_DigestInit_ex")) } ctx2 := C.go_openssl_EVP_MD_CTX_new() - blockSize := int(C.go_openssl_EVP_MD_get_block_size(md)) h := &evpHash{ - ctx: ctx, - ctx2: ctx2, - size: ch.Size(), - blockSize: blockSize, - marshallable: isHashMarshallable(ch), + alg: alg, + ctx: ctx, + ctx2: ctx2, } runtime.SetFinalizer(h, (*evpHash).finalize) return h @@ -236,24 +281,26 @@ func (h *evpHash) WriteByte(c byte) error { } func (h *evpHash) Size() int { - return h.size + return h.alg.size } func (h *evpHash) BlockSize() int { - return h.blockSize + return h.alg.blockSize } -func (h *evpHash) sum(out []byte) { +func (h *evpHash) Sum(in []byte) []byte { + defer runtime.KeepAlive(h) + out := make([]byte, h.Size(), maxHashSize) // explicit cap to allow stack allocation if C.go_hash_sum(h.ctx, h.ctx2, base(out)) != 1 { panic(newOpenSSLError("go_hash_sum")) } - runtime.KeepAlive(h) + return append(in, out...) } -// clone returns a new evpHash object that is a deep clone of itself. +// Clone returns a new evpHash object that is a deep clone of itself. // The duplicate object contains all state and data contained in the // original object at the point of duplication. -func (h *evpHash) clone() (*evpHash, error) { +func (h *evpHash) Clone() (hash.Hash, error) { ctx := C.go_openssl_EVP_MD_CTX_new() if ctx == nil { return nil, newOpenSSLError("EVP_MD_CTX_new") @@ -268,11 +315,9 @@ func (h *evpHash) clone() (*evpHash, error) { return nil, newOpenSSLError("EVP_MD_CTX_new") } cloned := &evpHash{ - ctx: ctx, - ctx2: ctx2, - size: h.size, - blockSize: h.blockSize, - marshallable: h.marshallable, + alg: h.alg, + ctx: ctx, + ctx2: ctx2, } runtime.SetFinalizer(cloned, (*evpHash).finalize) return cloned, nil @@ -282,10 +327,7 @@ func (h *evpHash) clone() (*evpHash, error) { // // The EVP_MD_CTX memory layout has changed in OpenSSL 3 // and the property holding the internal structure is no longer md_data but algctx. -func (h *evpHash) hashState() unsafe.Pointer { - if !h.marshallable { - panic("openssl: hash state is not marshallable") - } +func hashState(ctx C.GO_EVP_MD_CTX_PTR) unsafe.Pointer { switch vMajor { case 1: // https://github.com/openssl/openssl/blob/0418e993c717a6863f206feaa40673a261de7395/crypto/evp/evp_local.h#L12. @@ -294,7 +336,7 @@ func (h *evpHash) hashState() unsafe.Pointer { _ C.ulong md_data unsafe.Pointer } - return (*mdCtx)(unsafe.Pointer(h.ctx)).md_data + return (*mdCtx)(unsafe.Pointer(ctx)).md_data case 3: // https://github.com/openssl/openssl/blob/5675a5aaf6a2e489022bcfc18330dae9263e598e/crypto/evp/evp_local.h#L16. type mdCtx struct { @@ -303,49 +345,86 @@ func (h *evpHash) hashState() unsafe.Pointer { _ [3]unsafe.Pointer algctx unsafe.Pointer } - return (*mdCtx)(unsafe.Pointer(h.ctx)).algctx + return (*mdCtx)(unsafe.Pointer(ctx)).algctx default: panic(errUnsupportedVersion()) } } -// NewMD4 returns a new MD4 hash. -// The returned hash doesn't implement encoding.BinaryMarshaler and -// encoding.BinaryUnmarshaler. -func NewMD4() hash.Hash { - return &md4Hash{ - evpHash: newEvpHash(crypto.MD4), +func (d *evpHash) MarshalBinary() ([]byte, error) { + if !d.alg.marshallable { + return nil, errors.New("openssl: hash state is not marshallable") } + buf := make([]byte, 0, d.alg.marshalledSize) + return d.AppendBinary(buf) } -type md4Hash struct { - *evpHash - out [16]byte -} - -func (h *md4Hash) Sum(in []byte) []byte { - h.sum(h.out[:]) - return append(in, h.out[:]...) -} - -// Clone returns a new [hash.Hash] object that is a deep clone of itself. -// The duplicate object contains all state and data contained in the -// original object at the point of duplication. -func (h *md4Hash) Clone() (hash.Hash, error) { - c, err := h.clone() - if err != nil { - return nil, err +func (d *evpHash) AppendBinary(buf []byte) ([]byte, error) { + if !d.alg.marshallable { + return nil, errors.New("openssl: hash state is not marshallable") } - return &md4Hash{evpHash: c}, nil -} - -// NewMD5 returns a new MD5 hash. -func NewMD5() hash.Hash { - h := md5Hash{evpHash: newEvpHash(crypto.MD5)} - if h.marshallable { - return &md5Marshal{h} + state := hashState(d.ctx) + if state == nil { + return nil, errors.New("openssl: can't retrieve hash state") + } + var appender interface { + AppendBinary([]byte) ([]byte, error) + } + switch d.alg.ch { + case crypto.MD5: + appender = (*md5State)(state) + case crypto.SHA1: + appender = (*sha1State)(state) + case crypto.SHA224: + appender = (*sha256State)(state) + case crypto.SHA256: + appender = (*sha256State)(state) + case crypto.SHA384: + appender = (*sha512State)(state) + case crypto.SHA512: + appender = (*sha512State)(state) + default: + panic("openssl: unsupported hash function: " + strconv.Itoa(int(d.alg.ch))) + } + buf = append(buf, d.alg.magic[:]...) + return appender.AppendBinary(buf) +} + +func (d *evpHash) UnmarshalBinary(b []byte) error { + if !d.alg.marshallable { + return errors.New("openssl: hash state is not marshallable") + } + if len(b) < len(d.alg.magic) || string(b[:len(d.alg.magic)]) != string(d.alg.magic[:]) { + return errors.New("openssl: invalid hash state identifier") + } + if len(b) != d.alg.marshalledSize { + return errors.New("openssl: invalid hash state size") + } + state := hashState(d.ctx) + if state == nil { + return errors.New("openssl: can't retrieve hash state") + } + b = b[len(d.alg.magic):] + var unmarshaler interface { + UnmarshalBinary([]byte) error + } + switch d.alg.ch { + case crypto.MD5: + unmarshaler = (*md5State)(state) + case crypto.SHA1: + unmarshaler = (*sha1State)(state) + case crypto.SHA224: + unmarshaler = (*sha256State)(state) + case crypto.SHA256: + unmarshaler = (*sha256State)(state) + case crypto.SHA384: + unmarshaler = (*sha512State)(state) + case crypto.SHA512: + unmarshaler = (*sha512State)(state) + default: + panic("openssl: unsupported hash function: " + strconv.Itoa(int(d.alg.ch))) } - return &h + return unmarshaler.UnmarshalBinary(b) } // md5State layout is taken from @@ -357,53 +436,12 @@ type md5State struct { nx uint32 } -type md5Hash struct { - *evpHash - out [16]byte -} - -func (h *md5Hash) Sum(in []byte) []byte { - h.sum(h.out[:]) - return append(in, h.out[:]...) -} - -// Clone returns a new [hash.Hash] object that is a deep clone of itself. -// The duplicate object contains all state and data contained in the -// original object at the point of duplication. -func (h *md5Hash) Clone() (hash.Hash, error) { - c, err := h.clone() - if err != nil { - return nil, err - } - return &md5Hash{evpHash: c}, nil -} - const ( md5Magic = "md5\x01" md5MarshaledSize = len(md5Magic) + 4*4 + 64 + 8 ) -type md5Marshal struct { - md5Hash -} - -func (h *md5Marshal) MarshalBinary() ([]byte, error) { - buf := make([]byte, 0, md5MarshaledSize) - return h.AppendBinary(buf) -} - -func (h *md5Marshal) UnmarshalBinary(b []byte) error { - if len(b) < len(md5Magic) || string(b[:len(md5Magic)]) != md5Magic { - return errors.New("crypto/md5: invalid hash state identifier") - } - if len(b) != md5MarshaledSize { - return errors.New("crypto/md5: invalid hash state size") - } - d := (*md5State)(h.hashState()) - if d == nil { - return errors.New("crypto/md5: can't retrieve hash state") - } - b = b[len(md5Magic):] +func (d *md5State) UnmarshalBinary(b []byte) error { b, d.h[0] = consumeUint32(b) b, d.h[1] = consumeUint32(b) b, d.h[2] = consumeUint32(b) @@ -416,13 +454,7 @@ func (h *md5Marshal) UnmarshalBinary(b []byte) error { return nil } -func (h *md5Marshal) AppendBinary(buf []byte) ([]byte, error) { - d := (*md5State)(h.hashState()) - if d == nil { - return nil, errors.New("crypto/md5: can't retrieve hash state") - } - - buf = append(buf, md5Magic...) +func (d *md5State) AppendBinary(buf []byte) ([]byte, error) { buf = appendUint32(buf, d.h[0]) buf = appendUint32(buf, d.h[1]) buf = appendUint32(buf, d.h[2]) @@ -433,36 +465,6 @@ func (h *md5Marshal) AppendBinary(buf []byte) ([]byte, error) { return buf, nil } -// NewSHA1 returns a new SHA1 hash. -func NewSHA1() hash.Hash { - h := sha1Hash{evpHash: newEvpHash(crypto.SHA1)} - if h.marshallable { - return &sha1Marshal{h} - } - return &h -} - -type sha1Hash struct { - *evpHash - out [20]byte -} - -func (h *sha1Hash) Sum(in []byte) []byte { - h.sum(h.out[:]) - return append(in, h.out[:]...) -} - -// Clone returns a new [hash.Hash] object that is a deep clone of itself. -// The duplicate object contains all state and data contained in the -// original object at the point of duplication. -func (h *sha1Hash) Clone() (hash.Hash, error) { - c, err := h.clone() - if err != nil { - return nil, err - } - return &sha1Hash{evpHash: c}, nil -} - // sha1State layout is taken from // https://github.com/openssl/openssl/blob/0418e993c717a6863f206feaa40673a261de7395/include/openssl/sha.h#L34. type sha1State struct { @@ -477,27 +479,7 @@ const ( sha1MarshaledSize = len(sha1Magic) + 5*4 + 64 + 8 ) -type sha1Marshal struct { - sha1Hash -} - -func (h *sha1Marshal) MarshalBinary() ([]byte, error) { - buf := make([]byte, 0, sha1MarshaledSize) - return h.AppendBinary(buf) -} - -func (h *sha1Marshal) UnmarshalBinary(b []byte) error { - if len(b) < len(sha1Magic) || string(b[:len(sha1Magic)]) != sha1Magic { - return errors.New("crypto/sha1: invalid hash state identifier") - } - if len(b) != sha1MarshaledSize { - return errors.New("crypto/sha1: invalid hash state size") - } - d := (*sha1State)(h.hashState()) - if d == nil { - return errors.New("crypto/sha1: can't retrieve hash state") - } - b = b[len(sha1Magic):] +func (d *sha1State) UnmarshalBinary(b []byte) error { b, d.h[0] = consumeUint32(b) b, d.h[1] = consumeUint32(b) b, d.h[2] = consumeUint32(b) @@ -511,12 +493,7 @@ func (h *sha1Marshal) UnmarshalBinary(b []byte) error { return nil } -func (h *sha1Marshal) AppendBinary(buf []byte) ([]byte, error) { - d := (*sha1State)(h.hashState()) - if d == nil { - return nil, errors.New("crypto/sha1: can't retrieve hash state") - } - buf = append(buf, sha1Magic...) +func (d *sha1State) AppendBinary(buf []byte) ([]byte, error) { buf = appendUint32(buf, d.h[0]) buf = appendUint32(buf, d.h[1]) buf = appendUint32(buf, d.h[2]) @@ -528,66 +505,6 @@ func (h *sha1Marshal) AppendBinary(buf []byte) ([]byte, error) { return buf, nil } -// NewSHA224 returns a new SHA224 hash. -func NewSHA224() hash.Hash { - h := sha224Hash{evpHash: newEvpHash(crypto.SHA224)} - if h.marshallable { - return &sha224Marshal{h} - } - return &h -} - -type sha224Hash struct { - *evpHash - out [224 / 8]byte -} - -func (h *sha224Hash) Sum(in []byte) []byte { - h.sum(h.out[:]) - return append(in, h.out[:]...) -} - -// Clone returns a new [hash.Hash] object that is a deep clone of itself. -// The duplicate object contains all state and data contained in the -// original object at the point of duplication. -func (h *sha224Hash) Clone() (hash.Hash, error) { - c, err := h.clone() - if err != nil { - return nil, err - } - return &sha224Hash{evpHash: c}, nil -} - -// NewSHA256 returns a new SHA256 hash. -func NewSHA256() hash.Hash { - h := sha256Hash{evpHash: newEvpHash(crypto.SHA256)} - if h.marshallable { - return &sha256Marshal{h} - } - return &h -} - -type sha256Hash struct { - *evpHash - out [256 / 8]byte -} - -func (h *sha256Hash) Sum(in []byte) []byte { - h.sum(h.out[:]) - return append(in, h.out[:]...) -} - -// Clone returns a new [hash.Hash] object that is a deep clone of itself. -// The duplicate object contains all state and data contained in the -// original object at the point of duplication. -func (h *sha256Hash) Clone() (hash.Hash, error) { - c, err := h.clone() - if err != nil { - return nil, err - } - return &sha256Hash{evpHash: c}, nil -} - const ( magic224 = "sha\x02" magic256 = "sha\x03" @@ -603,36 +520,7 @@ type sha256State struct { nx uint32 } -type sha224Marshal struct { - sha224Hash -} - -type sha256Marshal struct { - sha256Hash -} - -func (h *sha224Marshal) MarshalBinary() ([]byte, error) { - buf := make([]byte, 0, marshaledSize256) - return h.AppendBinary(buf) -} - -func (h *sha256Marshal) MarshalBinary() ([]byte, error) { - buf := make([]byte, 0, marshaledSize256) - return h.AppendBinary(buf) -} - -func (h *sha224Marshal) UnmarshalBinary(b []byte) error { - if len(b) < len(magic224) || string(b[:len(magic224)]) != magic224 { - return errors.New("crypto/sha256: invalid hash state identifier") - } - if len(b) != marshaledSize256 { - return errors.New("crypto/sha256: invalid hash state size") - } - d := (*sha256State)(h.hashState()) - if d == nil { - return errors.New("crypto/sha256: can't retrieve hash state") - } - b = b[len(magic224):] +func (d *sha256State) UnmarshalBinary(b []byte) error { b, d.h[0] = consumeUint32(b) b, d.h[1] = consumeUint32(b) b, d.h[2] = consumeUint32(b) @@ -649,60 +537,7 @@ func (h *sha224Marshal) UnmarshalBinary(b []byte) error { return nil } -func (h *sha256Marshal) UnmarshalBinary(b []byte) error { - if len(b) < len(magic256) || string(b[:len(magic256)]) != magic256 { - return errors.New("crypto/sha256: invalid hash state identifier") - } - if len(b) != marshaledSize256 { - return errors.New("crypto/sha256: invalid hash state size") - } - d := (*sha256State)(h.hashState()) - if d == nil { - return errors.New("crypto/sha256: can't retrieve hash state") - } - b = b[len(magic256):] - b, d.h[0] = consumeUint32(b) - b, d.h[1] = consumeUint32(b) - b, d.h[2] = consumeUint32(b) - b, d.h[3] = consumeUint32(b) - b, d.h[4] = consumeUint32(b) - b, d.h[5] = consumeUint32(b) - b, d.h[6] = consumeUint32(b) - b, d.h[7] = consumeUint32(b) - b = b[copy(d.x[:], b):] - _, n := consumeUint64(b) - d.nl = uint32(n << 3) - d.nh = uint32(n >> 29) - d.nx = uint32(n) % 64 - return nil -} - -func (h *sha224Marshal) AppendBinary(buf []byte) ([]byte, error) { - d := (*sha256State)(h.hashState()) - if d == nil { - return nil, errors.New("crypto/sha256: can't retrieve hash state") - } - buf = append(buf, magic224...) - buf = appendUint32(buf, d.h[0]) - buf = appendUint32(buf, d.h[1]) - buf = appendUint32(buf, d.h[2]) - buf = appendUint32(buf, d.h[3]) - buf = appendUint32(buf, d.h[4]) - buf = appendUint32(buf, d.h[5]) - buf = appendUint32(buf, d.h[6]) - buf = appendUint32(buf, d.h[7]) - buf = append(buf, d.x[:d.nx]...) - buf = append(buf, make([]byte, len(d.x)-int(d.nx))...) - buf = appendUint64(buf, uint64(d.nl)>>3|uint64(d.nh)<<29) - return buf, nil -} - -func (h *sha256Marshal) AppendBinary(buf []byte) ([]byte, error) { - d := (*sha256State)(h.hashState()) - if d == nil { - return nil, errors.New("crypto/sha256: can't retrieve hash state") - } - buf = append(buf, magic256...) +func (d *sha256State) AppendBinary(buf []byte) ([]byte, error) { buf = appendUint32(buf, d.h[0]) buf = appendUint32(buf, d.h[1]) buf = appendUint32(buf, d.h[2]) @@ -717,66 +552,6 @@ func (h *sha256Marshal) AppendBinary(buf []byte) ([]byte, error) { return buf, nil } -// NewSHA384 returns a new SHA384 hash. -func NewSHA384() hash.Hash { - h := sha384Hash{evpHash: newEvpHash(crypto.SHA384)} - if h.marshallable { - return &sha384Marshal{h} - } - return &h -} - -type sha384Hash struct { - *evpHash - out [384 / 8]byte -} - -func (h *sha384Hash) Sum(in []byte) []byte { - h.sum(h.out[:]) - return append(in, h.out[:]...) -} - -// Clone returns a new [hash.Hash] object that is a deep clone of itself. -// The duplicate object contains all state and data contained in the -// original object at the point of duplication. -func (h *sha384Hash) Clone() (hash.Hash, error) { - c, err := h.clone() - if err != nil { - return nil, err - } - return &sha384Hash{evpHash: c}, nil -} - -// NewSHA512 returns a new SHA512 hash. -func NewSHA512() hash.Hash { - h := sha512Hash{evpHash: newEvpHash(crypto.SHA512)} - if h.marshallable { - return &sha512Marshal{h} - } - return &h -} - -type sha512Hash struct { - *evpHash - out [512 / 8]byte -} - -func (h *sha512Hash) Sum(in []byte) []byte { - h.sum(h.out[:]) - return append(in, h.out[:]...) -} - -// Clone returns a new [hash.Hash] object that is a deep clone of itself. -// The duplicate object contains all state and data contained in the -// original object at the point of duplication. -func (h *sha512Hash) Clone() (hash.Hash, error) { - c, err := h.clone() - if err != nil { - return nil, err - } - return &sha512Hash{evpHash: c}, nil -} - // sha512State layout is taken from // https://github.com/openssl/openssl/blob/0418e993c717a6863f206feaa40673a261de7395/include/openssl/sha.h#L95. type sha512State struct { @@ -794,39 +569,12 @@ const ( marshaledSize512 = len(magic512) + 8*8 + 128 + 8 ) -type sha384Marshal struct { - sha384Hash -} - -type sha512Marshal struct { - sha512Hash -} - -func (h *sha384Marshal) MarshalBinary() ([]byte, error) { +func (d *sha512State) MarshalBinary() ([]byte, error) { buf := make([]byte, 0, marshaledSize512) - return h.AppendBinary(buf) + return d.AppendBinary(buf) } -func (h *sha512Marshal) MarshalBinary() ([]byte, error) { - buf := make([]byte, 0, marshaledSize512) - return h.AppendBinary(buf) -} - -func (h *sha384Marshal) UnmarshalBinary(b []byte) error { - if len(b) < len(magic512) { - return errors.New("crypto/sha512: invalid hash state identifier") - } - if string(b[:len(magic384)]) != magic384 { - return errors.New("crypto/sha512: invalid hash state identifier") - } - if len(b) != marshaledSize512 { - return errors.New("crypto/sha512: invalid hash state size") - } - d := (*sha512State)(h.hashState()) - if d == nil { - return errors.New("crypto/sha512: can't retrieve hash state") - } - b = b[len(magic512):] +func (d *sha512State) UnmarshalBinary(b []byte) error { b, d.h[0] = consumeUint64(b) b, d.h[1] = consumeUint64(b) b, d.h[2] = consumeUint64(b) @@ -843,43 +591,7 @@ func (h *sha384Marshal) UnmarshalBinary(b []byte) error { return nil } -func (h *sha512Marshal) UnmarshalBinary(b []byte) error { - if len(b) < len(magic512) { - return errors.New("crypto/sha512: invalid hash state identifier") - } - if string(b[:len(magic512)]) != magic512 { - return errors.New("crypto/sha512: invalid hash state identifier") - } - if len(b) != marshaledSize512 { - return errors.New("crypto/sha512: invalid hash state size") - } - d := (*sha512State)(h.hashState()) - if d == nil { - return errors.New("crypto/sha512: can't retrieve hash state") - } - b = b[len(magic512):] - b, d.h[0] = consumeUint64(b) - b, d.h[1] = consumeUint64(b) - b, d.h[2] = consumeUint64(b) - b, d.h[3] = consumeUint64(b) - b, d.h[4] = consumeUint64(b) - b, d.h[5] = consumeUint64(b) - b, d.h[6] = consumeUint64(b) - b, d.h[7] = consumeUint64(b) - b = b[copy(d.x[:], b):] - _, n := consumeUint64(b) - d.nl = n << 3 - d.nh = n >> 61 - d.nx = uint32(n) % 128 - return nil -} - -func (h *sha384Marshal) AppendBinary(buf []byte) ([]byte, error) { - d := (*sha512State)(h.hashState()) - if d == nil { - return nil, errors.New("crypto/sha512: can't retrieve hash state") - } - buf = append(buf, magic384...) +func (d *sha512State) AppendBinary(buf []byte) ([]byte, error) { buf = appendUint64(buf, d.h[0]) buf = appendUint64(buf, d.h[1]) buf = appendUint64(buf, d.h[2]) @@ -894,138 +606,6 @@ func (h *sha384Marshal) AppendBinary(buf []byte) ([]byte, error) { return buf, nil } -func (h *sha512Marshal) AppendBinary(buf []byte) ([]byte, error) { - d := (*sha512State)(h.hashState()) - if d == nil { - return nil, errors.New("crypto/sha512: can't retrieve hash state") - } - buf = append(buf, magic512...) - buf = appendUint64(buf, d.h[0]) - buf = appendUint64(buf, d.h[1]) - buf = appendUint64(buf, d.h[2]) - buf = appendUint64(buf, d.h[3]) - buf = appendUint64(buf, d.h[4]) - buf = appendUint64(buf, d.h[5]) - buf = appendUint64(buf, d.h[6]) - buf = appendUint64(buf, d.h[7]) - buf = append(buf, d.x[:d.nx]...) - buf = append(buf, make([]byte, len(d.x)-int(d.nx))...) - buf = appendUint64(buf, d.nl>>3|d.nh<<61) - return buf, nil -} - -// NewSHA3_224 returns a new SHA3-224 hash. -func NewSHA3_224() hash.Hash { - return &sha3_224Hash{ - evpHash: newEvpHash(crypto.SHA3_224), - } -} - -type sha3_224Hash struct { - *evpHash - out [224 / 8]byte -} - -func (h *sha3_224Hash) Sum(in []byte) []byte { - h.sum(h.out[:]) - return append(in, h.out[:]...) -} - -// Clone returns a new [hash.Hash] object that is a deep clone of itself. -// The duplicate object contains all state and data contained in the -// original object at the point of duplication. -func (h *sha3_224Hash) Clone() (hash.Hash, error) { - c, err := h.clone() - if err != nil { - return nil, err - } - return &sha3_224Hash{evpHash: c}, nil -} - -// NewSHA3_256 returns a new SHA3-256 hash. -func NewSHA3_256() hash.Hash { - return &sha3_256Hash{ - evpHash: newEvpHash(crypto.SHA3_256), - } -} - -type sha3_256Hash struct { - *evpHash - out [256 / 8]byte -} - -func (h *sha3_256Hash) Sum(in []byte) []byte { - h.sum(h.out[:]) - return append(in, h.out[:]...) -} - -// Clone returns a new [hash.Hash] object that is a deep clone of itself. -// The duplicate object contains all state and data contained in the -// original object at the point of duplication. -func (h *sha3_256Hash) Clone() (hash.Hash, error) { - c, err := h.clone() - if err != nil { - return nil, err - } - return &sha3_256Hash{evpHash: c}, nil -} - -// NewSHA3_384 returns a new SHA3-384 hash. -func NewSHA3_384() hash.Hash { - return &sha3_384Hash{ - evpHash: newEvpHash(crypto.SHA3_384), - } -} - -type sha3_384Hash struct { - *evpHash - out [384 / 8]byte -} - -func (h *sha3_384Hash) Sum(in []byte) []byte { - h.sum(h.out[:]) - return append(in, h.out[:]...) -} - -// Clone returns a new [hash.Hash] object that is a deep clone of itself. -// The duplicate object contains all state and data contained in the -// original object at the point of duplication. -func (h *sha3_384Hash) Clone() (hash.Hash, error) { - c, err := h.clone() - if err != nil { - return nil, err - } - return &sha3_384Hash{evpHash: c}, nil -} - -// NewSHA3_512 returns a new SHA3-512 hash. -func NewSHA3_512() hash.Hash { - return &sha3_512Hash{ - evpHash: newEvpHash(crypto.SHA3_512), - } -} - -type sha3_512Hash struct { - *evpHash - out [512 / 8]byte -} - -func (h *sha3_512Hash) Sum(in []byte) []byte { - h.sum(h.out[:]) - return append(in, h.out[:]...) -} - -// Clone returns a new [hash.Hash] object that is a deep clone of itself. -// The duplicate object contains all state and data contained in the -// original object at the point of duplication. -func (h *sha3_512Hash) Clone() (hash.Hash, error) { - c, err := h.clone() - if err != nil { - return nil, err - } - return &sha3_512Hash{evpHash: c}, nil -} - // appendUint64 appends x into b as a big endian byte sequence. func appendUint64(b []byte, x uint64) []byte { return append(b, diff --git a/hash_test.go b/hash_test.go index 18d4707..948be94 100644 --- a/hash_test.go +++ b/hash_test.go @@ -6,6 +6,7 @@ import ( "encoding" "hash" "io" + "strings" "testing" "github.com/golang-fips/openssl/v2" @@ -100,7 +101,7 @@ func TestHash_BinaryMarshaler(t *testing.T) { encoding.BinaryMarshaler }) if !ok { - t.Skip("BinaryMarshaler not supported") + t.Fatal("BinaryMarshaler not supported") } if _, err := hashMarshaler.Write(msg); err != nil { @@ -109,6 +110,9 @@ func TestHash_BinaryMarshaler(t *testing.T) { state, err := hashMarshaler.MarshalBinary() if err != nil { + if strings.Contains(err.Error(), "hash state is not marshallable") { + t.Skip("BinaryMarshaler not supported") + } t.Fatalf("MarshalBinary failed: %v", err) } @@ -140,7 +144,7 @@ func TestHash_BinaryAppender(t *testing.T) { AppendBinary(b []byte) ([]byte, error) }) if !ok { - t.Skip("not supported") + t.Fatal("AppendBinary not supported") } // Create a slice with 10 elements @@ -156,6 +160,9 @@ func TestHash_BinaryAppender(t *testing.T) { // Append binary data to the prebuilt slice state, err := hashWithBinaryAppender.AppendBinary(prebuiltSlice) if err != nil { + if strings.Contains(err.Error(), "hash state is not marshallable") { + t.Skip("AppendBinary not supported") + } t.Errorf("could not append binary: %v", err) } diff --git a/tls1prf.go b/tls1prf.go index f342f22..3313454 100644 --- a/tls1prf.go +++ b/tls1prf.go @@ -35,7 +35,7 @@ func TLS1PRF(result, secret, label, seed []byte, fh func() hash.Hash) error { // that the caller wants to use TLS 1.0/1.1 PRF. // OpenSSL detects this case by checking if the hash // function is MD5SHA1. - md = cryptoHashToMD(crypto.MD5SHA1) + md = loadHash(crypto.MD5SHA1).md } else { h, err := hashFuncHash(fh) if err != nil {