Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix AES-GCM decryption on OpenSSL 1.0.2-fips #111

Merged
merged 6 commits into from
Oct 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 46 additions & 19 deletions aes_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,25 +81,52 @@ func TestNewGCMNonce(t *testing.T) {
}

func TestSealAndOpen(t *testing.T) {
key := []byte("D249BF6DEC97B1EBD69BC4D6B3A3C49D")
ci, err := openssl.NewAESCipher(key)
if err != nil {
t.Fatal(err)
}
gcm, err := cipher.NewGCM(ci)
if err != nil {
t.Fatal(err)
}
nonce := []byte{0x91, 0xc7, 0xa7, 0x54, 0x52, 0xef, 0x10, 0xdb, 0x91, 0xa8, 0x6c, 0xf9}
plainText := []byte{0x01, 0x02, 0x03}
additionalData := []byte{0x05, 0x05, 0x07}
sealed := gcm.Seal(nil, nonce, plainText, additionalData)
decrypted, err := gcm.Open(nil, nonce, sealed, additionalData)
if err != nil {
t.Error(err)
}
if !bytes.Equal(decrypted, plainText) {
t.Errorf("unexpected decrypted result\ngot: %#v\nexp: %#v", decrypted, plainText)
for _, tt := range aesGCMTests {
t.Run(tt.description, func(t *testing.T) {
ci, err := openssl.NewAESCipher(tt.key)
if err != nil {
t.Fatalf("NewAESCipher() err = %v", err)
}
gcm, err := cipher.NewGCM(ci)
if err != nil {
t.Fatalf("cipher.NewGCM() err = %v", err)
}

sealed := gcm.Seal(nil, tt.nonce, tt.plaintext, tt.aad)
if !bytes.Equal(sealed, tt.ciphertext) {
t.Errorf("unexpected sealed result\ngot: %#v\nexp: %#v", sealed, tt.ciphertext)
}

decrypted, err := gcm.Open(nil, tt.nonce, tt.ciphertext, tt.aad)
if err != nil {
t.Errorf("gcm.Open() err = %v", err)
}
if !bytes.Equal(decrypted, tt.plaintext) {
t.Errorf("unexpected decrypted result\ngot: %#v\nexp: %#v", decrypted, tt.plaintext)
}

// Test that open fails if the ciphertext is modified.
tt.ciphertext[0] ^= 0x80
_, err = gcm.Open(nil, tt.nonce, tt.ciphertext, tt.aad)
if err != openssl.ErrOpen {
t.Errorf("expected authentication error for tampered message\ngot: %#v", err)
}
tt.ciphertext[0] ^= 0x80

// Test that the ciphertext can be opened using a fresh context
// which was not previously used to seal the same message.
gcm, err = cipher.NewGCM(ci)
if err != nil {
t.Fatalf("cipher.NewGCM() err = %v", err)
}
decrypted, err = gcm.Open(nil, tt.nonce, tt.ciphertext, tt.aad)
if err != nil {
t.Errorf("fresh GCM instance: gcm.Open() err = %v", err)
}
if !bytes.Equal(decrypted, tt.plaintext) {
t.Errorf("fresh GCM instance: unexpected decrypted result\ngot: %#v\nexp: %#v", decrypted, tt.plaintext)
}
})
}
}

Expand Down
2 changes: 1 addition & 1 deletion cipher.go
Original file line number Diff line number Diff line change
Expand Up @@ -513,7 +513,7 @@ func newCipherCtx(kind cipherKind, mode cipherMode, encrypt cipherOp, key, iv []
cipher = nil
}
if C.go_openssl_EVP_CipherInit_ex(ctx, cipher, nil, base(key), base(iv), C.int(encrypt)) != 1 {
return nil, fail("unable to initialize EVP cipher ctx")
return nil, newOpenSSLError("unable to initialize EVP cipher ctx")
}
return ctx, nil
}
Expand Down
116 changes: 116 additions & 0 deletions cmd/gentestvectors/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
// gentestvectors emits cryptographic test vectors using the Go standard library
// cryptographic routines to test the OpenSSL bindings.
package main

import (
"bytes"
"crypto/aes"
"crypto/cipher"
"flag"
"fmt"
"go/format"
"io"
"log"
"math/rand"
"os"
"path/filepath"
)

var outputPath = flag.String("out", "", "output path (default stdout)")

func init() {
log.SetFlags(log.Llongfile)
log.SetOutput(os.Stderr)
}

func main() {
flag.Parse()

var b bytes.Buffer
fmt.Fprint(&b, "// Code generated by cmd/gentestvectors. DO NOT EDIT.\n\n")
if *outputPath != "" {
fmt.Fprintf(&b, "//go"+":generate go run github.com/golang-fips/openssl/v2/cmd/gentestvectors -out %s\n\n", filepath.Base(*outputPath))
}

pkg := "openssl_test"
if gopackage := os.Getenv("GOPACKAGE"); gopackage != "" {
pkg = gopackage + "_test"
}
fmt.Fprintf(&b, "package %s\n\n", pkg)

aesGCM(&b)

generated, err := format.Source(b.Bytes())
if err != nil {
log.Fatalf("failed to format generated code: %v", err)
}

if *outputPath != "" {
err := os.WriteFile(*outputPath, generated, 0o644)
if err != nil {
log.Fatalf("failed to write output file: %v\n", err)
}
} else {
_, _ = os.Stdout.Write(generated)
}
}

func aesGCM(w io.Writer) {
r := rand.New(rand.NewSource(0))

fmt.Fprintln(w, `var aesGCMTests = []struct {
description string
key, nonce, plaintext, aad, ciphertext []byte
}{`)

for _, keyLen := range []int{16, 24, 32} {
for _, aadLen := range []int{0, 1, 3, 13, 30} {
for _, plaintextLen := range []int{0, 1, 3, 13, 16, 51} {
if aadLen == 0 && plaintextLen == 0 {
continue
}

key := randbytes(r, keyLen)
nonce := randbytes(r, 12)
plaintext := randbytes(r, plaintextLen)
aad := randbytes(r, aadLen)

c, err := aes.NewCipher(key)
if err != nil {
panic(err)
}
aead, err := cipher.NewGCM(c)
if err != nil {
panic(err)
}
ciphertext := aead.Seal(nil, nonce, plaintext, aad)

fmt.Fprint(w, "\t{\n")
fmt.Fprintf(w, "\t\tdescription: \"AES-%d/AAD=%d/Plaintext=%d\",\n", keyLen*8, aadLen, plaintextLen)
printBytesField(w, "key", key)
printBytesField(w, "nonce", nonce)
printBytesField(w, "plaintext", plaintext)
printBytesField(w, "aad", aad)
printBytesField(w, "ciphertext", ciphertext)
fmt.Fprint(w, "\t},\n")
}
}
}
fmt.Fprintln(w, "}")
}

func randbytes(r *rand.Rand, n int) []byte {
if n == 0 {
return nil
}
b := make([]byte, n)
r.Read(b)
return b
}

func printBytesField(w io.Writer, name string, b []byte) {
if len(b) == 0 {
return
}
fmt.Fprintf(w, "\t\t%s: %#v,\n", name, b)
}
4 changes: 2 additions & 2 deletions des.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@ import (
// If CBC is also supported, then the returned cipher.Block
// will also implement NewCBCEncrypter and NewCBCDecrypter.
func SupportsDESCipher() bool {
// True for stock OpenSSL 1.
// True for stock OpenSSL 1 w/o FIPS.
// False for stock OpenSSL 3 unless the legacy provider is available.
return loadCipher(cipherDES, cipherModeECB) != nil
return (versionAtOrAbove(1, 1, 0) || !FIPS()) && loadCipher(cipherDES, cipherModeECB) != nil
}

// SupportsTripleDESCipher returns true if NewTripleDESCipher is supported,
Expand Down
2 changes: 1 addition & 1 deletion ed25519.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func SupportsEd25519() bool {
onceSupportsEd25519.Do(func() {
switch vMajor {
case 1:
supportsEd25519 = version1_1_1_or_above()
supportsEd25519 = versionAtOrAbove(1, 1, 1)
case 3:
name := C.CString("ED25519")
defer C.free(unsafe.Pointer(name))
Expand Down
16 changes: 10 additions & 6 deletions evp.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,13 @@ func cryptoHashToMD(ch crypto.Hash) (md C.GO_EVP_MD_PTR) {
}
switch ch {
case crypto.MD4:
return C.go_openssl_EVP_md4()
if versionAtOrAbove(1, 1, 0) || !FIPS() {
return C.go_openssl_EVP_md4()
}
case crypto.MD5:
return C.go_openssl_EVP_md5()
if versionAtOrAbove(1, 1, 0) || !FIPS() {
return C.go_openssl_EVP_md5()
}
case crypto.SHA1:
return C.go_openssl_EVP_sha1()
case crypto.SHA224:
Expand All @@ -86,19 +90,19 @@ func cryptoHashToMD(ch crypto.Hash) (md C.GO_EVP_MD_PTR) {
case crypto.SHA512:
return C.go_openssl_EVP_sha512()
case crypto.SHA3_224:
if version1_1_1_or_above() {
if versionAtOrAbove(1, 1, 1) {
return C.go_openssl_EVP_sha3_224()
}
case crypto.SHA3_256:
if version1_1_1_or_above() {
if versionAtOrAbove(1, 1, 1) {
return C.go_openssl_EVP_sha3_256()
}
case crypto.SHA3_384:
if version1_1_1_or_above() {
if versionAtOrAbove(1, 1, 1) {
return C.go_openssl_EVP_sha3_384()
}
case crypto.SHA3_512:
if version1_1_1_or_above() {
if versionAtOrAbove(1, 1, 1) {
return C.go_openssl_EVP_sha3_512()
}
}
Expand Down
4 changes: 2 additions & 2 deletions goopenssl.c
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,12 @@ go_openssl_fips_enabled(void* handle)
// and assign them to their corresponding function pointer
// defined in goopenssl.h.
void
go_openssl_load_functions(void* handle, int major, int minor, int patch)
go_openssl_load_functions(void* handle, unsigned int major, unsigned int minor, unsigned int patch)
{
#define DEFINEFUNC_INTERNAL(name, func) \
_g_##name = dlsym(handle, func); \
if (_g_##name == NULL) { \
fprintf(stderr, "Cannot get required symbol " #func " from libcrypto version %d.%d\n", major, minor); \
fprintf(stderr, "Cannot get required symbol " #func " from libcrypto version %u.%u\n", major, minor); \
abort(); \
}
#define DEFINEFUNC(ret, func, args, argscall) \
Expand Down
28 changes: 23 additions & 5 deletions goopenssl.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ int go_openssl_version_major(void* handle);
int go_openssl_version_minor(void* handle);
int go_openssl_version_patch(void* handle);
int go_openssl_thread_setup(void);
void go_openssl_load_functions(void* handle, int major, int minor, int patch);
void go_openssl_load_functions(void* handle, unsigned int major, unsigned int minor, unsigned int patch);
const GO_EVP_MD_PTR go_openssl_EVP_md5_sha1_backport(void);

// Define pointers to all the used OpenSSL functions.
Expand Down Expand Up @@ -144,22 +144,40 @@ go_openssl_EVP_CIPHER_CTX_open_wrapper(const GO_EVP_CIPHER_CTX_PTR ctx,
const unsigned char *aad, int aad_len,
const unsigned char *tag)
{
if (in_len == 0) in = (const unsigned char *)"";
if (in_len == 0) {
in = (const unsigned char *)"";
// OpenSSL 1.0.2 in FIPS mode contains a bug: it will fail to verify
// unless EVP_DecryptUpdate is called at least once with a non-NULL
// output buffer. OpenSSL will not dereference the output buffer when
// the input length is zero, so set it to an arbitrary non-NULL pointer
// to satisfy OpenSSL when the caller only has authenticated additional
// data (AAD) to verify. While a stack-allocated buffer could be used,
// that would risk a stack-corrupting buffer overflow if OpenSSL
// unexpectedly dereferenced it. Instead pass a value which would
// segfault if dereferenced on any modern platform where a NULL-pointer
// dereference would also segfault.
if (out == NULL) out = (unsigned char *)1;
}
if (aad_len == 0) aad = (const unsigned char *)"";

if (go_openssl_EVP_DecryptInit_ex(ctx, NULL, NULL, NULL, nonce) != 1)
return 0;

// OpenSSL 1.0.x FIPS Object Module 2.0 versions below 2.0.5 require that
// the tag be set before the ciphertext, otherwise EVP_DecryptUpdate returns
// an error. At least one extant commercially-supported, FIPS validated
// build of OpenSSL 1.0.2 uses FIPS module version 2.0.1. Set the tag first
// to maximize compatibility with all OpenSSL version combinations.
if (go_openssl_EVP_CIPHER_CTX_ctrl(ctx, GO_EVP_CTRL_GCM_SET_TAG, 16, (unsigned char *)(tag)) != 1)
return 0;

int discard_len, out_len;
if (go_openssl_EVP_DecryptUpdate(ctx, NULL, &discard_len, aad, aad_len) != 1
|| go_openssl_EVP_DecryptUpdate(ctx, out, &out_len, in, in_len) != 1)
{
return 0;
}

if (go_openssl_EVP_CIPHER_CTX_ctrl(ctx, GO_EVP_CTRL_GCM_SET_TAG, 16, (unsigned char *)(tag)) != 1)
return 0;

if (go_openssl_EVP_DecryptFinal_ex(ctx, out + out_len, &discard_len) != 1)
return 0;

Expand Down
2 changes: 1 addition & 1 deletion hkdf.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import (
)

func SupportsHKDF() bool {
return version1_1_1_or_above()
return versionAtOrAbove(1, 1, 1)
}

func newHKDF(h func() hash.Hash, mode C.int) (*hkdf, error) {
Expand Down
13 changes: 7 additions & 6 deletions init.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import (
// as reported by the OpenSSL API.
//
// See Init() for details about file.
func opensslInit(file string) (major, minor, patch int, err error) {
func opensslInit(file string) (major, minor, patch uint, err error) {
// Load the OpenSSL shared library using dlopen.
handle, err := dlopen(file)
if err != nil {
Expand All @@ -24,12 +24,13 @@ func opensslInit(file string) (major, minor, patch int, err error) {
// Notice that major and minor could not match with the version parameter
// in case the name of the shared library file differs from the OpenSSL
// version it contains.
major = int(C.go_openssl_version_major(handle))
minor = int(C.go_openssl_version_minor(handle))
patch = int(C.go_openssl_version_patch(handle))
if major == -1 || minor == -1 || patch == -1 {
imajor := int(C.go_openssl_version_major(handle))
iminor := int(C.go_openssl_version_minor(handle))
ipatch := int(C.go_openssl_version_patch(handle))
if imajor < 0 || iminor < 0 || ipatch < 0 {
return 0, 0, 0, errors.New("openssl: can't retrieve OpenSSL version")
}
major, minor, patch = uint(imajor), uint(iminor), uint(ipatch)
var supported bool
if major == 1 {
supported = minor == 0 || minor == 1
Expand All @@ -43,7 +44,7 @@ func opensslInit(file string) (major, minor, patch int, err error) {

// Load the OpenSSL functions.
// See shims.go for the complete list of supported functions.
C.go_openssl_load_functions(handle, C.int(major), C.int(minor), C.int(patch))
C.go_openssl_load_functions(handle, C.uint(major), C.uint(minor), C.uint(patch))

// Initialize OpenSSL.
C.go_openssl_OPENSSL_init()
Expand Down
Loading