From 1d1decc3f5fa93c6738addb0616edc39210d7676 Mon Sep 17 00:00:00 2001 From: Davis Goodin Date: Thu, 21 Dec 2023 17:32:37 -0600 Subject: [PATCH] Add NewGCMTLS13 --- aes.go | 12 +++++- aes_test.go | 115 ++++++++++++++++++++++++++++++++-------------------- cipher.go | 53 ++++++++++++++++++++---- 3 files changed, 127 insertions(+), 53 deletions(-) diff --git a/aes.go b/aes.go index 1fc11f00..231b75e2 100644 --- a/aes.go +++ b/aes.go @@ -47,6 +47,12 @@ func NewGCMTLS(c cipher.Block) (cipher.AEAD, error) { return c.(*aesCipher).NewGCMTLS() } +// NewGCMTLS13 returns a GCM cipher specific to TLS 1.3 and should not be used +// for non-TLS purposes. +func NewGCMTLS13(c cipher.Block) (cipher.AEAD, error) { + return c.(*aesCipher).NewGCMTLS13() +} + type aesCipher struct { *evpCipher } @@ -86,5 +92,9 @@ func (c *aesCipher) NewGCM(nonceSize, tagSize int) (cipher.AEAD, error) { } func (c *aesCipher) NewGCMTLS() (cipher.AEAD, error) { - return c.newGCM(true) + return c.newGCM(cipherGCMTLS12) +} + +func (c *aesCipher) NewGCMTLS13() (cipher.AEAD, error) { + return c.newGCM(cipherGCMTLS13) } diff --git a/aes_test.go b/aes_test.go index 3125d61f..142ac878 100644 --- a/aes_test.go +++ b/aes_test.go @@ -153,51 +153,76 @@ func TestSealAndOpen_Empty(t *testing.T) { func TestSealAndOpenTLS(t *testing.T) { key := []byte("D249BF6DEC97B1EBD69BC4D6B3A3C49D") - ci, err := openssl.NewAESCipher(key) - if err != nil { - t.Fatal(err) - } - gcm, err := openssl.NewGCMTLS(ci) - if err != nil { - t.Fatal(err) - } - nonce := [12]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0} - nonce1 := [12]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1} - nonce9 := [12]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 9} - nonce10 := [12]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 10} - nonceMax := [12]byte{0, 0, 0, 0, 255, 255, 255, 255, 255, 255, 255, 255} - plainText := []byte{0x01, 0x02, 0x03} - additionalData := make([]byte, 13) - additionalData[11] = byte(len(plainText) >> 8) - additionalData[12] = byte(len(plainText)) - sealed := gcm.Seal(nil, nonce[:], plainText, additionalData) - assertPanic(t, func() { - gcm.Seal(nil, nonce[:], plainText, additionalData) - }) - sealed1 := gcm.Seal(nil, nonce1[:], plainText, additionalData) - gcm.Seal(nil, nonce10[:], plainText, additionalData) - assertPanic(t, func() { - gcm.Seal(nil, nonce9[:], plainText, additionalData) - }) - assertPanic(t, func() { - gcm.Seal(nil, nonceMax[:], plainText, additionalData) - }) - if bytes.Equal(sealed, sealed1) { - t.Errorf("different nonces should produce different outputs\ngot: %#v\nexp: %#v", sealed, sealed1) - } - decrypted, err := gcm.Open(nil, nonce[:], sealed, additionalData) - if err != nil { - t.Error(err) - } - decrypted1, err := gcm.Open(nil, nonce1[:], sealed1, additionalData) - if err != nil { - t.Error(err) - } - if !bytes.Equal(decrypted, plainText) { - t.Errorf("unexpected decrypted result\ngot: %#v\nexp: %#v", decrypted, plainText) - } - if !bytes.Equal(decrypted, decrypted1) { - t.Errorf("unexpected decrypted result\ngot: %#v\nexp: %#v", decrypted, decrypted1) + tests := []struct { + name string + new func(c cipher.Block) (cipher.AEAD, error) + mask func(n *[12]byte) + }{ + {"1.2", openssl.NewGCMTLS, nil}, + {"1.3", openssl.NewGCMTLS13, nil}, + {"1.3_masked", openssl.NewGCMTLS13, func(n *[12]byte) { + // Arbitrary mask in the high bits. + n[9] ^= 0x42 + // Mask the very first bit. This makes sure that if Seal doesn't + // handle the mask, the counter appears to go backwards and panics + // when it shouldn't. + n[11] ^= 0x1 + }}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ci, err := openssl.NewAESCipher(key) + if err != nil { + t.Fatal(err) + } + gcm, err := tt.new(ci) + if err != nil { + t.Fatal(err) + } + nonce := [12]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0} + nonce1 := [12]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1} + nonce9 := [12]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 9} + nonce10 := [12]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 10} + nonceMax := [12]byte{0, 0, 0, 0, 255, 255, 255, 255, 255, 255, 255, 255} + if tt.mask != nil { + for _, m := range []*[12]byte{&nonce, &nonce1, &nonce9, &nonce10, &nonceMax} { + tt.mask(m) + } + } + plainText := []byte{0x01, 0x02, 0x03} + additionalData := make([]byte, 13) + additionalData[11] = byte(len(plainText) >> 8) + additionalData[12] = byte(len(plainText)) + sealed := gcm.Seal(nil, nonce[:], plainText, additionalData) + assertPanic(t, func() { + gcm.Seal(nil, nonce[:], plainText, additionalData) + }) + sealed1 := gcm.Seal(nil, nonce1[:], plainText, additionalData) + gcm.Seal(nil, nonce10[:], plainText, additionalData) + assertPanic(t, func() { + gcm.Seal(nil, nonce9[:], plainText, additionalData) + }) + assertPanic(t, func() { + gcm.Seal(nil, nonceMax[:], plainText, additionalData) + }) + if bytes.Equal(sealed, sealed1) { + t.Errorf("different nonces should produce different outputs\ngot: %#v\nexp: %#v", sealed, sealed1) + } + decrypted, err := gcm.Open(nil, nonce[:], sealed, additionalData) + if err != nil { + t.Error(err) + } + decrypted1, err := gcm.Open(nil, nonce1[:], sealed1, additionalData) + if err != nil { + t.Error(err) + } + if !bytes.Equal(decrypted, plainText) { + t.Errorf("unexpected decrypted result\ngot: %#v\nexp: %#v", decrypted, plainText) + } + if !bytes.Equal(decrypted, decrypted1) { + t.Errorf("unexpected decrypted result\ngot: %#v\nexp: %#v", decrypted, decrypted1) + } + }) } } diff --git a/cipher.go b/cipher.go index b56de6a7..1113d067 100644 --- a/cipher.go +++ b/cipher.go @@ -312,11 +312,26 @@ func (c *cipherCTR) finalize() { C.go_openssl_EVP_CIPHER_CTX_free(c.ctx) } +type cipherGCMTLS uint8 + +const ( + cipherGCMTLSNone cipherGCMTLS = iota + cipherGCMTLS12 + cipherGCMTLS13 +) + type cipherGCM struct { - ctx C.GO_EVP_CIPHER_CTX_PTR - tls bool + ctx C.GO_EVP_CIPHER_CTX_PTR + tls cipherGCMTLS + // minNextNonce is the minimum value that the next nonce can be, enforced by + // all TLS modes. minNextNonce uint64 - blockSize int + // mask is the nonce mask used in TLS 1.3 mode. + mask uint64 + // maskInitialized is true if mask has been initialized. This happens during + // the first Seal. The initialized mask may be 0. Used by TLS 1.3 mode. + maskInitialized bool + blockSize int } const ( @@ -353,10 +368,10 @@ func (c *evpCipher) newGCMChecked(nonceSize, tagSize int) (cipher.AEAD, error) { if tagSize != gcmTagSize { return cipher.NewGCMWithTagSize(&noGCM{c}, tagSize) } - return c.newGCM(false) + return c.newGCM(cipherGCMTLSNone) } -func (c *evpCipher) newGCM(tls bool) (cipher.AEAD, error) { +func (c *evpCipher) newGCM(tls cipherGCMTLS) (cipher.AEAD, error) { ctx, err := newCipherCtx(c.kind, cipherModeGCM, cipherOpNone, c.key, nil) if err != nil { return nil, err @@ -388,15 +403,39 @@ func (g *cipherGCM) Seal(dst, nonce, plaintext, additionalData []byte) []byte { if len(dst)+len(plaintext)+gcmTagSize < len(dst) { panic("cipher: message too large for buffer") } - if g.tls { + if g.tls != cipherGCMTLSNone { if len(additionalData) != gcmTlsAddSize { panic("cipher: incorrect additional data length given to GCM TLS") } + counter := binary.BigEndian.Uint64(nonce[gcmTlsFixedNonceSize:]) + if g.tls == cipherGCMTLS13 { + // In TLS 1.3, the counter in the nonce has a mask and requires + // further decoding. + if !g.maskInitialized { + // According to TLS 1.3 nonce construction details at + // https://tools.ietf.org/html/rfc8446#section-5.3: + // + // the first record transmitted under a particular traffic + // key MUST use sequence number 0. + // + // The padded sequence number is XORed with [a mask]. + // + // The resulting quantity (of length iv_length) is used as + // the per-record nonce. + // + // We need to convert from the given nonce to sequence numbers + // to keep track of minNextNonce and enforce the counter + // maximum. On the first call, we know counter^mask is 0^mask, + // so we can simply store it as the mask. + g.mask = counter + g.maskInitialized = true + } + counter ^= g.mask + } // BoringCrypto enforces strictly monotonically increasing explicit nonces // and to fail after 2^64 - 1 keys as per FIPS 140-2 IG A.5, // but OpenSSL does not perform this check, so it is implemented here. const maxUint64 = 1<<64 - 1 - counter := binary.BigEndian.Uint64(nonce[gcmTlsFixedNonceSize:]) if counter == maxUint64 { panic("cipher: nonce counter must be less than 2^64 - 1") }