Skip to content

Commit

Permalink
crypto/goolm/message: improve encode/decode to use buffers
Browse files Browse the repository at this point in the history
Signed-off-by: Sumner Evans <sumner.evans@automattic.com>
  • Loading branch information
sumnerevans committed Oct 27, 2024
1 parent 3e4b1a3 commit 0b814b1
Show file tree
Hide file tree
Showing 5 changed files with 159 additions and 215 deletions.
79 changes: 29 additions & 50 deletions crypto/goolm/message/decoder.go
Original file line number Diff line number Diff line change
@@ -1,70 +1,49 @@
package message

import (
"bytes"
"encoding/binary"

"maunium.net/go/mautrix/crypto/olm"
)

// checkDecodeErr checks if there was an error during decode.
func checkDecodeErr(readBytes int) error {
if readBytes == 0 {
//end reached
return olm.ErrInputToSmall
}
if readBytes < 0 {
return olm.ErrOverflow
}
return nil
type Decoder struct {
*bytes.Buffer
}

// decodeVarInt decodes a single big-endian encoded varint.
func decodeVarInt(input []byte) (uint32, int) {
value, readBytes := binary.Uvarint(input)
return uint32(value), readBytes
func NewDecoder(buf []byte) *Decoder {
return &Decoder{bytes.NewBuffer(buf)}
}

// decodeVarString decodes the length of the string (varint) and returns the actual string
func decodeVarString(input []byte) ([]byte, int) {
stringLen, readBytes := decodeVarInt(input)
if readBytes <= 0 {
return nil, readBytes
}
input = input[readBytes:]
value := input[:stringLen]
readBytes += int(stringLen)
return value, readBytes
func (d *Decoder) ReadVarInt() (uint64, error) {
return binary.ReadUvarint(d)
}

// encodeVarIntByteLength returns the number of bytes needed to encode the uint32.
func encodeVarIntByteLength(input uint32) int {
result := 1
for input >= 128 {
result++
input >>= 7
func (d *Decoder) ReadVarBytes() ([]byte, error) {
if n, err := d.ReadVarInt(); err != nil {
return nil, err
} else {
out := make([]byte, n)
_, err = d.Read(out)
return out, err
}
return result
}

// encodeVarStringByteLength returns the number of bytes needed to encode the input.
func encodeVarStringByteLength(input []byte) int {
result := encodeVarIntByteLength(uint32(len(input)))
result += len(input)
return result
type Encoder struct {
buf []byte
}

func (e *Encoder) Bytes() []byte {
return e.buf
}

func (e *Encoder) PutByte(val byte) {
e.buf = append(e.buf, val)
}

// encodeVarInt encodes a single uint32
func encodeVarInt(input uint32) []byte {
out := make([]byte, encodeVarIntByteLength(input))
binary.PutUvarint(out, uint64(input))
return out
func (e *Encoder) PutVarInt(val uint64) {
e.buf = binary.AppendUvarint(e.buf, val)
}

// encodeVarString encodes the length of the input (varint) and appends the actual input
func encodeVarString(input []byte) []byte {
out := make([]byte, encodeVarStringByteLength(input))
length := encodeVarInt(uint32(len(input)))
copy(out, length)
copy(out[len(length):], input)
return out
func (e *Encoder) PutVarBytes(data []byte) {
e.PutVarInt(uint64(len(data)))
e.buf = append(e.buf, data...)
}
34 changes: 9 additions & 25 deletions crypto/goolm/message/decoder_test.go
Original file line number Diff line number Diff line change
@@ -1,32 +1,12 @@
package message
package message_test

import (
"testing"

"github.com/stretchr/testify/assert"
)

func TestEncodeLengthInt(t *testing.T) {
numbers := []uint32{127, 128, 16383, 16384, 32767}
expected := []int{1, 2, 2, 3, 3}
for curIndex := range numbers {
assert.Equal(t, expected[curIndex], encodeVarIntByteLength(numbers[curIndex]))
}
}

func TestEncodeLengthString(t *testing.T) {
var strings [][]byte
var expected []int
strings = append(strings, []byte("test"))
expected = append(expected, 1+4)
strings = append(strings, []byte("this is a long message with a length of 127 so that the varint of the length is just one byte. just needs some padding---------"))
expected = append(expected, 1+127)
strings = append(strings, []byte("this is an even longer message with a length between 128 and 16383 so that the varint of the length needs two byte. just needs some padding again ---------"))
expected = append(expected, 2+155)
for curIndex := range strings {
assert.Equal(t, expected[curIndex], encodeVarStringByteLength(strings[curIndex]))
}
}
"maunium.net/go/mautrix/crypto/goolm/message"
)

func TestEncodeInt(t *testing.T) {
var ints []uint32
Expand All @@ -40,7 +20,9 @@ func TestEncodeInt(t *testing.T) {
ints = append(ints, 16383)
expected = append(expected, []byte{0b11111111, 0b01111111})
for curIndex := range ints {
assert.Equal(t, expected[curIndex], encodeVarInt(ints[curIndex]))
var encoder message.Encoder
encoder.PutVarInt(uint64(ints[curIndex]))
assert.Equal(t, expected[curIndex], encoder.Bytes())
}
}

Expand Down Expand Up @@ -70,6 +52,8 @@ func TestEncodeString(t *testing.T) {
res = append(res, curTest...) //Add string itself
expected = append(expected, res)
for curIndex := range strings {
assert.Equal(t, expected[curIndex], encodeVarString(strings[curIndex]))
var encoder message.Encoder
encoder.PutVarBytes(strings[curIndex])
assert.Equal(t, expected[curIndex], encoder.Bytes())
}
}
68 changes: 32 additions & 36 deletions crypto/goolm/message/group_message.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package message

import (
"bytes"
"io"

"maunium.net/go/mautrix/crypto/aessha2"
"maunium.net/go/mautrix/crypto/goolm/crypto"
Expand All @@ -22,66 +23,61 @@ type GroupMessage struct {
}

// Decodes decodes the input and populates the corresponding fileds. MAC and signature are ignored but have to be present.
func (r *GroupMessage) Decode(input []byte) error {
func (r *GroupMessage) Decode(input []byte) (err error) {
r.Version = 0
r.MessageIndex = 0
r.Ciphertext = nil
if len(input) == 0 {
return nil
}
//first Byte is always version
r.Version = input[0]
curPos := 1
for curPos < len(input)-countMACBytesGroupMessage-crypto.Ed25519SignatureSize {
//Read Key
curKey, readBytes := decodeVarInt(input[curPos:])
if err := checkDecodeErr(readBytes); err != nil {

decoder := NewDecoder(input[:len(input)-countMACBytesGroupMessage-crypto.Ed25519SignatureSize])
r.Version, err = decoder.ReadByte() // First byte is the version
if err != nil {
return
}

for {
// Read Key
if curKey, err := decoder.ReadVarInt(); err != nil {
if err == io.EOF {
// No more keys to read
return nil
}
return err
}
curPos += readBytes
if (curKey & 0b111) == 0 {
//The value is of type varint
value, readBytes := decodeVarInt(input[curPos:])
if err := checkDecodeErr(readBytes); err != nil {
} else if (curKey & 0b111) == 0 {
// The value is of type varint
if value, err := decoder.ReadVarInt(); err != nil {
return err
}
curPos += readBytes
switch curKey {
case messageIndexTag:
r.MessageIndex = value
} else if curKey == messageIndexTag {
r.MessageIndex = uint32(value)
r.HasMessageIndex = true
}
} else if (curKey & 0b111) == 2 {
//The value is of type string
value, readBytes := decodeVarString(input[curPos:])
if err := checkDecodeErr(readBytes); err != nil {
// The value is of type string
if value, err := decoder.ReadVarBytes(); err != nil {
return err
}
curPos += readBytes
switch curKey {
case cipherTextTag:
} else if curKey == cipherTextTag {
r.Ciphertext = value
}
}
}

return nil
}

// EncodeAndMACAndSign encodes the message, creates the mac with the key and the cipher and signs the message.
// If macKey or cipher is nil, no mac is appended. If signKey is nil, no signature is appended.
func (r *GroupMessage) EncodeAndMACAndSign(cipher aessha2.AESSHA2, signKey crypto.Ed25519KeyPair) ([]byte, error) {
var buf bytes.Buffer
buf.WriteByte(r.Version)
buf.Write(encodeVarInt(messageIndexTag))
buf.Write(encodeVarInt(r.MessageIndex))
buf.Write(encodeVarInt(cipherTextTag))
buf.Write(encodeVarString(r.Ciphertext))
mac, err := r.MAC(cipher, buf.Bytes())
var encoder Encoder
encoder.PutByte(r.Version)
encoder.PutVarInt(messageIndexTag)
encoder.PutVarInt(uint64(r.MessageIndex))
encoder.PutVarInt(cipherTextTag)
encoder.PutVarBytes(r.Ciphertext)
mac, err := r.MAC(cipher, encoder.Bytes())
if err != nil {
return nil, err
}
ciphertextWithMAC := append(buf.Bytes(), mac[:countMACBytesGroupMessage]...)
ciphertextWithMAC := append(encoder.Bytes(), mac[:countMACBytesGroupMessage]...)
signature, err := signKey.Sign(ciphertextWithMAC)
return append(ciphertextWithMAC, signature...), err
}
Expand Down
74 changes: 35 additions & 39 deletions crypto/goolm/message/message.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package message

import (
"bytes"
"io"

"maunium.net/go/mautrix/crypto/aessha2"
"maunium.net/go/mautrix/crypto/goolm/crypto"
Expand All @@ -24,7 +25,7 @@ type Message struct {
}

// Decodes decodes the input and populates the corresponding fileds. MAC is ignored but has to be present.
func (r *Message) Decode(input []byte) error {
func (r *Message) Decode(input []byte) (err error) {
r.Version = 0
r.HasCounter = false
r.Counter = 0
Expand All @@ -33,60 +34,55 @@ func (r *Message) Decode(input []byte) error {
if len(input) == 0 {
return nil
}
//first Byte is always version
r.Version = input[0]
curPos := 1
for curPos < len(input)-countMACBytesMessage {
//Read Key
curKey, readBytes := decodeVarInt(input[curPos:])
if err := checkDecodeErr(readBytes); err != nil {

decoder := NewDecoder(input[:len(input)-countMACBytesMessage])
r.Version, err = decoder.ReadByte() // first byte is always version
if err != nil {
return
}

for {
// Read Key
if curKey, err := decoder.ReadVarInt(); err != nil {
if err == io.EOF {
// No more keys to read
return nil
}
return err
}
curPos += readBytes
if (curKey & 0b111) == 0 {
//The value is of type varint
value, readBytes := decodeVarInt(input[curPos:])
if err := checkDecodeErr(readBytes); err != nil {
} else if (curKey & 0b111) == 0 {
// The value is of type varint
if value, err := decoder.ReadVarInt(); err != nil {
return err
}
curPos += readBytes
switch curKey {
case counterTag:
} else if curKey == counterTag {
r.Counter = uint32(value)
r.HasCounter = true
r.Counter = value
}
} else if (curKey & 0b111) == 2 {
//The value is of type string
value, readBytes := decodeVarString(input[curPos:])
if err := checkDecodeErr(readBytes); err != nil {
// The value is of type string
if value, err := decoder.ReadVarBytes(); err != nil {
return err
}
curPos += readBytes
switch curKey {
case ratchetKeyTag:
} else if curKey == ratchetKeyTag {
r.RatchetKey = value
case cipherTextKeyTag:
} else if curKey == cipherTextKeyTag {
r.Ciphertext = value
}
}
}

return nil
}

// EncodeAndMAC encodes the message and creates the MAC with the key and the cipher.
// If key or cipher is nil, no MAC is appended.
func (r *Message) EncodeAndMAC(cipher aessha2.AESSHA2) ([]byte, error) {
var buf bytes.Buffer
buf.WriteByte(r.Version)
buf.Write(encodeVarInt(ratchetKeyTag))
buf.Write(encodeVarString(r.RatchetKey))
buf.Write(encodeVarInt(counterTag))
buf.Write(encodeVarInt(r.Counter))
buf.Write(encodeVarInt(cipherTextKeyTag))
buf.Write(encodeVarString(r.Ciphertext))
mac, err := cipher.MAC(buf.Bytes())
return append(buf.Bytes(), mac[:countMACBytesMessage]...), err
var encoder Encoder
encoder.PutByte(r.Version)
encoder.PutVarInt(ratchetKeyTag)
encoder.PutVarBytes(r.RatchetKey)
encoder.PutVarInt(counterTag)
encoder.PutVarInt(uint64(r.Counter))
encoder.PutVarInt(cipherTextKeyTag)
encoder.PutVarBytes(r.Ciphertext)
mac, err := cipher.MAC(encoder.Bytes())
return append(encoder.Bytes(), mac[:countMACBytesMessage]...), err
}

// VerifyMAC verifies the givenMAC to the calculated MAC of the message.
Expand Down
Loading

0 comments on commit 0b814b1

Please sign in to comment.