Skip to content

Commit

Permalink
crypto/goolm: use golang.org/x/crypto/hkdf and remove crypto.HKDFSHA256
Browse files Browse the repository at this point in the history
Signed-off-by: Sumner Evans <sumner.evans@automattic.com>
  • Loading branch information
sumnerevans committed Oct 25, 2024
1 parent 165c1ed commit d5e0ed4
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 31 deletions.
7 changes: 5 additions & 2 deletions crypto/goolm/cipher/aes_sha256.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@ package cipher

import (
"bytes"
"crypto/sha256"
"io"

"golang.org/x/crypto/hkdf"

"maunium.net/go/mautrix/crypto/aescbc"
"maunium.net/go/mautrix/crypto/goolm/crypto"
)
Expand All @@ -17,9 +20,9 @@ type derivedAESKeys struct {

// deriveAESKeys derives three keys for the AESSHA256 cipher
func deriveAESKeys(kdfInfo []byte, key []byte) (derivedAESKeys, error) {
hkdf := crypto.HKDFSHA256(key, nil, kdfInfo)
kdf := hkdf.New(sha256.New, key, nil, kdfInfo)
keymatter := make([]byte, 80)
_, err := io.ReadFull(hkdf, keymatter)
_, err := io.ReadFull(kdf, keymatter)
return derivedAESKeys{
key: keymatter[:32],
hmacKey: keymatter[32:64],
Expand Down
9 changes: 0 additions & 9 deletions crypto/goolm/crypto/hmac.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,6 @@ package crypto
import (
"crypto/hmac"
"crypto/sha256"
"io"

"golang.org/x/crypto/hkdf"
)

// HMACSHA256 returns the hash message authentication code with SHA-256 of the input with the key.
Expand All @@ -21,9 +18,3 @@ func SHA256(value []byte) []byte {
hash.Write(value)
return hash.Sum(nil)
}

// HKDFSHA256 is the key deivation function based on HMAC and returns a reader based on input. salt and info can both be nil.
// The reader can be used to read an arbitary length of bytes which are based on all parameters.
func HKDFSHA256(input, salt, info []byte) io.Reader {
return hkdf.New(sha256.New, input, salt, info)
}
20 changes: 3 additions & 17 deletions crypto/goolm/crypto/hmac_test.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
package crypto_test

import (
"crypto/sha256"
"encoding/base64"
"io"
"testing"

"github.com/stretchr/testify/assert"
"golang.org/x/crypto/hkdf"

"maunium.net/go/mautrix/crypto/goolm/crypto"
)
Expand All @@ -22,22 +24,6 @@ func TestHMACSHA256(t *testing.T) {
assert.Equal(t, result, hash)
}

func TestHKDFSHA256(t *testing.T) {
message := []byte("test content")

hkdf := crypto.HKDFSHA256(message, nil, nil)
result := make([]byte, 32)
_, err := io.ReadFull(hkdf, result)
assert.NoError(t, err)

hkdf2 := crypto.HKDFSHA256(message, nil, nil)
result2 := make([]byte, 32)
_, err = io.ReadFull(hkdf2, result2)
assert.NoError(t, err)

assert.Equal(t, result, result2)
}

func TestSHA256Case1(t *testing.T) {
input := make([]byte, 0)
expected := []byte{
Expand Down Expand Up @@ -93,7 +79,7 @@ func TestHDKFCase1(t *testing.T) {
0x34, 0x00, 0x72, 0x08, 0xd5, 0xb8, 0x87, 0x18,
0x58, 0x65,
}
resultReader := crypto.HKDFSHA256(input, salt, info)
resultReader := hkdf.New(sha256.New, input, salt, info)
result = make([]byte, len(expectedHDKF))
_, err := io.ReadFull(resultReader, result)
assert.NoError(t, err)
Expand Down
9 changes: 6 additions & 3 deletions crypto/goolm/ratchet/olm.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@
package ratchet

import (
"crypto/sha256"
"fmt"
"io"

"golang.org/x/crypto/hkdf"

"maunium.net/go/mautrix/crypto/goolm/cipher"
"maunium.net/go/mautrix/crypto/goolm/crypto"
"maunium.net/go/mautrix/crypto/goolm/libolmpickle"
Expand Down Expand Up @@ -70,7 +73,7 @@ func New() *Ratchet {

// InitializeAsBob initializes this ratchet from a receiving point of view (only first message).
func (r *Ratchet) InitializeAsBob(sharedSecret []byte, theirRatchetKey crypto.Curve25519PublicKey) error {
derivedSecretsReader := crypto.HKDFSHA256(sharedSecret, nil, KdfInfo.Root)
derivedSecretsReader := hkdf.New(sha256.New, sharedSecret, nil, KdfInfo.Root)
derivedSecrets := make([]byte, 2*sharedKeyLength)
if _, err := io.ReadFull(derivedSecretsReader, derivedSecrets); err != nil {
return err
Expand All @@ -83,7 +86,7 @@ func (r *Ratchet) InitializeAsBob(sharedSecret []byte, theirRatchetKey crypto.Cu

// InitializeAsAlice initializes this ratchet from a sending point of view (only first message).
func (r *Ratchet) InitializeAsAlice(sharedSecret []byte, ourRatchetKey crypto.Curve25519KeyPair) error {
derivedSecretsReader := crypto.HKDFSHA256(sharedSecret, nil, KdfInfo.Root)
derivedSecretsReader := hkdf.New(sha256.New, sharedSecret, nil, KdfInfo.Root)
derivedSecrets := make([]byte, 2*sharedKeyLength)
if _, err := io.ReadFull(derivedSecretsReader, derivedSecrets); err != nil {
return err
Expand Down Expand Up @@ -192,7 +195,7 @@ func (r *Ratchet) advanceRootKey(newRatchetKey crypto.Curve25519KeyPair, oldRatc
if err != nil {
return nil, err
}
derivedSecretsReader := crypto.HKDFSHA256(sharedSecret, r.RootKey, KdfInfo.Ratchet)
derivedSecretsReader := hkdf.New(sha256.New, sharedSecret, r.RootKey, KdfInfo.Ratchet)
derivedSecrets := make([]byte, 2*sharedKeyLength)
if _, err := io.ReadFull(derivedSecretsReader, derivedSecrets); err != nil {
return nil, err
Expand Down

0 comments on commit d5e0ed4

Please sign in to comment.