Skip to content

Commit

Permalink
multi: Change all ocurrences to use NewMessage
Browse files Browse the repository at this point in the history
This changes all ocurrences of message instantiation to use
NewMessage().

This unifies all code for message init under a single code path.

In the future, it may be possible to make all message fields unexported
in order to better enforce message invariants.
  • Loading branch information
matheusd committed Sep 12, 2024
1 parent 25dbad9 commit dcb3395
Show file tree
Hide file tree
Showing 7 changed files with 145 additions and 175 deletions.
15 changes: 13 additions & 2 deletions codec.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,15 @@ func (d *Decoder) Decode() (*Message, error) {
if err != nil {
return nil, exc.WrapError("decode", err)
}

// Special case an empty message to return a new MultiSegment message
// ready for writing. This maintains compatibility to tests and older
// implementation of message and arenas.
if hdr.maxSegment() == 0 && total == 0 {
msg, _ := NewMultiSegmentMessage(nil)
return msg, nil
}

// TODO(someday): if total size is greater than can fit in one buffer,
// attempt to allocate buffer per segment.
if total > maxSize-uint64(len(hdr)) || total > uint64(maxInt) {
Expand All @@ -77,7 +86,8 @@ func (d *Decoder) Decode() (*Message, error) {
return nil, exc.WrapError("decode", err)
}

return &Message{Arena: arena}, nil
msg, _, err := NewMessage(arena)
return msg, err
}

func (d *Decoder) readHeader(maxSize uint64) (streamHeader, error) {
Expand Down Expand Up @@ -167,7 +177,8 @@ func Unmarshal(data []byte) (*Message, error) {
return nil, exc.WrapError("unmarshal", err)
}

return &Message{Arena: arena}, nil
msg, _, err := NewMessage(arena)
return msg, err
}

// UnmarshalPacked reads a packed serialized stream into a message.
Expand Down
25 changes: 13 additions & 12 deletions codec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,14 @@ func TestEncoder(t *testing.T) {
t.Parallel()

for i, test := range serializeTests {
if test.decodeFails {
if test.decodeFails || test.newMessageFails {
continue
}
msg := &Message{Arena: test.arena()}
msg, _, err := NewMessage(test.arena())
require.NoError(t, err)
var buf bytes.Buffer
enc := NewEncoder(&buf)
err := enc.Encode(msg)
err = enc.Encode(msg)
out := buf.Bytes()
if err != nil {
if !test.encodeFails {
Expand Down Expand Up @@ -198,26 +199,26 @@ func TestDecoder_MaxMessageSize(t *testing.T) {
func TestStreamHeaderPadding(t *testing.T) {
t.Parallel()

msg := &Message{
Arena: MultiSegment([][]byte{
msg, _, err := NewMessage(
MultiSegment([][]byte{
incrementingData(8),
incrementingData(8),
incrementingData(8),
}),
}
}))
require.NoError(t, err)
var buf bytes.Buffer
enc := NewEncoder(&buf)
err := enc.Encode(msg)
err = enc.Encode(msg)
buf.Reset()
if err != nil {
t.Fatalf("Encode error: %v", err)
}
msg = &Message{
Arena: MultiSegment([][]byte{
msg, _, err = NewMessage(
MultiSegment([][]byte{
incrementingData(8),
incrementingData(8),
}),
}
}))
require.NoError(t, err)
err = enc.Encode(msg)
out := buf.Bytes()
if err != nil {
Expand Down
47 changes: 20 additions & 27 deletions integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1832,13 +1832,11 @@ func BenchmarkDecode(b *testing.B) {
func TestPointerTraverseDefense(t *testing.T) {
t.Parallel()
const limit = 128
msg := &capnp.Message{
Arena: capnp.SingleSegment([]byte{
0, 0, 0, 0, 1, 0, 0, 0, // root 1-word struct pointer to next word
0, 0, 0, 0, 0, 0, 0, 0, // struct's data
}),
TraverseLimit: limit * 8,
}
msg, _ := capnp.NewSingleSegmentMessage([]byte{
0, 0, 0, 0, 1, 0, 0, 0, // root 1-word struct pointer to next word
0, 0, 0, 0, 0, 0, 0, 0, // struct's data
})
msg.TraverseLimit = limit * 8

for i := 0; i < limit; i++ {
_, err := msg.Root()
Expand All @@ -1855,13 +1853,11 @@ func TestPointerTraverseDefense(t *testing.T) {
func TestPointerDepthDefense(t *testing.T) {
t.Parallel()
const limit = 64
msg := &capnp.Message{
Arena: capnp.SingleSegment([]byte{
0, 0, 0, 0, 0, 0, 1, 0, // root 1-pointer struct pointer to next word
0xfc, 0xff, 0xff, 0xff, 0, 0, 1, 0, // root struct pointer that points back to itself
}),
DepthLimit: limit,
}
msg, _ := capnp.NewSingleSegmentMessage([]byte{
0, 0, 0, 0, 0, 0, 1, 0, // root 1-pointer struct pointer to next word
0xfc, 0xff, 0xff, 0xff, 0, 0, 1, 0, // root struct pointer that points back to itself
})
msg.DepthLimit = limit
root, err := msg.Root()
if err != nil {
t.Fatal("Root:", err)
Expand Down Expand Up @@ -1894,14 +1890,12 @@ func TestPointerDepthDefense(t *testing.T) {
func TestPointerDepthDefenseAcrossStructsAndLists(t *testing.T) {
t.Parallel()
const limit = 63
msg := &capnp.Message{
Arena: capnp.SingleSegment([]byte{
0, 0, 0, 0, 0, 0, 1, 0, // root 1-pointer struct pointer to next word
0x01, 0, 0, 0, 0x0e, 0, 0, 0, // list pointer to 1-element list of pointer (next word)
0xf8, 0xff, 0xff, 0xff, 0, 0, 1, 0, // struct pointer to previous word
}),
DepthLimit: limit,
}
msg, _ := capnp.NewSingleSegmentMessage([]byte{
0, 0, 0, 0, 0, 0, 1, 0, // root 1-pointer struct pointer to next word
0x01, 0, 0, 0, 0x0e, 0, 0, 0, // list pointer to 1-element list of pointer (next word)
0xf8, 0xff, 0xff, 0xff, 0, 0, 1, 0, // struct pointer to previous word
})
msg.DepthLimit = limit

toStruct := func(p capnp.Ptr, err error) (capnp.Struct, error) {
if err != nil {
Expand Down Expand Up @@ -2083,11 +2077,10 @@ func TestSetEmptyTextWithDefault(t *testing.T) {

func TestFuzzedListOutOfBounds(t *testing.T) {
t.Parallel()
msg := &capnp.Message{
Arena: capnp.SingleSegment([]byte(
"\x00\x00\x00\x00\x03\x00\x01\x00\x0f\x000000000000" +
"000000000000\x01\x00\x00\x00\x13\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00")),
}
msg, _ := capnp.NewSingleSegmentMessage([]byte(
"\x00\x00\x00\x00\x03\x00\x01\x00\x0f\x000000000000" +
"000000000000\x01\x00\x00\x00\x13\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00"))

z, err := air.ReadRootZ(msg)
if err != nil {
t.Fatal("ReadRootZ:", err)
Expand Down
2 changes: 1 addition & 1 deletion internal/fuzztest/fuzztest.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ func Fuzz(data []byte) int {
data = append(data, 0)
}
}
msg := &capnp.Message{Arena: capnp.SingleSegment(data)}
msg, _ := capnp.NewSingleSegmentMessage(data)
z, err := air.ReadRootZ(msg)
if err != nil {
return 0
Expand Down
16 changes: 4 additions & 12 deletions list_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,10 @@ import (
)

func TestToListDefault(t *testing.T) {
msg := &Message{Arena: SingleSegment([]byte{
_, seg := NewSingleSegmentMessage([]byte{
0, 0, 0, 0, 0, 0, 0, 0,
42, 0, 0, 0, 0, 0, 0, 0,
})}
seg, err := msg.Segment(0)
if err != nil {
t.Fatal(err)
}
})
tests := []struct {
ptr Ptr
def []byte
Expand Down Expand Up @@ -56,15 +52,11 @@ func TestToListDefault(t *testing.T) {
}

func TestTextListBytesAt(t *testing.T) {
msg := &Message{Arena: SingleSegment([]byte{
_, seg := NewSingleSegmentMessage([]byte{
0, 0, 0, 0, 0, 0, 0, 0,
0x01, 0, 0, 0, 0x22, 0, 0, 0,
'f', 'o', 'o', 0, 0, 0, 0, 0,
})}
seg, err := msg.Segment(0)
if err != nil {
t.Fatal(err)
}
})
list := TextList{
seg: seg,
off: 8,
Expand Down
40 changes: 17 additions & 23 deletions message_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,15 +91,11 @@ func TestAlloc(t *testing.T) {
})
}
{
msg := &Message{Arena: MultiSegment([][]byte{
_, seg := NewMultiSegmentMessage([][]byte{
incrementingData(24)[:8:8],
incrementingData(24)[:8],
incrementingData(24)[:8],
})}
seg, err := msg.Segment(1)
if err != nil {
t.Fatal(err)
}
})
tests = append(tests, allocTest{
name: "prefers given segment",
seg: seg,
Expand All @@ -109,14 +105,10 @@ func TestAlloc(t *testing.T) {
})
}
{
msg := &Message{Arena: MultiSegment([][]byte{
_, seg := NewMultiSegmentMessage([][]byte{
incrementingData(24)[:8],
incrementingData(24),
})}
seg, err := msg.Segment(1)
if err != nil {
t.Fatal(err)
}
})
tests = append(tests, allocTest{
name: "given segment full with another available",
seg: seg,
Expand All @@ -126,14 +118,10 @@ func TestAlloc(t *testing.T) {
})
}
{
msg := &Message{Arena: MultiSegment([][]byte{
msg, seg := NewMultiSegmentMessage([][]byte{
incrementingData(24),
incrementingData(24),
})}
seg, err := msg.Segment(1)
if err != nil {
t.Fatal(err)
}
})

// Make arena not read-only again.
msg.Arena.(*MultiSegmentArena).bp = &bufferpool.Default
Expand Down Expand Up @@ -308,7 +296,14 @@ func TestMarshal(t *testing.T) {
if test.decodeFails {
continue
}
msg := &Message{Arena: test.arena()}
msg, _, err := NewMessage(test.arena())
if err != nil != test.newMessageFails {
t.Errorf("serializeTests[%d] %s: NewMessage unexpected error: %v", i, test.name, err)
continue
}
if err != nil {
continue
}
out, err := msg.Marshal()
if err != nil {
if !test.encodeFails {
Expand Down Expand Up @@ -373,7 +368,8 @@ func TestWriteTo(t *testing.T) {
continue
}

msg := &Message{Arena: test.arena()}
msg, _, err := NewMessage(test.arena())
require.NoError(t, err)
n, err := msg.WriteTo(&buf)
if test.encodeFails {
require.Error(t, err, test.name)
Expand Down Expand Up @@ -566,9 +562,7 @@ func TestTotalSize(t *testing.T) {
}
}

msg := &Message{
Arena: MultiSegment(segs),
}
msg, _ := NewMultiSegmentMessage(segs)

size, err := msg.TotalSize()
assert.Nil(t, err, "TotalSize() returned an error")
Expand Down
Loading

0 comments on commit dcb3395

Please sign in to comment.