Skip to content

Commit

Permalink
multi: Use NewXXXMessage(nil) where possible
Browse files Browse the repository at this point in the history
This commit changes all occurrences of
NewMessage(NewSingleSegmentArena(nil)) to NewSingleSegmentArena(nil) and
all occurrences of NewMessage(NewMultiSegmentArena(nil)) to
NewMultiSegmentArena(nil).

Also, occurrences of Message{Arena: XXX} (where XXX is either a single
or multi segment arena) are changed when possible as well.

In the future, this will allow protecting Message values from wrong
usage by enforcing the use of NewMessage to initialize message objects
(instead of use of zero valued messages).
  • Loading branch information
matheusd committed Aug 30, 2024
1 parent 46ccd63 commit 153d699
Show file tree
Hide file tree
Showing 22 changed files with 125 additions and 411 deletions.
6 changes: 3 additions & 3 deletions answer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ func TestPromiseFulfill(t *testing.T) {
t.Run("Done", func(t *testing.T) {
p := NewPromise(dummyMethod, dummyPipelineCaller{}, nil)
done := p.Answer().Done()
msg, seg, _ := NewMessage(SingleSegment(nil))
msg, seg := NewSingleSegmentMessage(nil)
defer msg.Release()

res, _ := NewStruct(seg, ObjectSize{DataSize: 8})
Expand All @@ -75,7 +75,7 @@ func TestPromiseFulfill(t *testing.T) {
p := NewPromise(dummyMethod, dummyPipelineCaller{}, nil)
defer p.ReleaseClients()
ans := p.Answer()
msg, seg, _ := NewMessage(SingleSegment(nil))
msg, seg := NewSingleSegmentMessage(nil)
defer msg.Release()

res, _ := NewStruct(seg, ObjectSize{DataSize: 8})
Expand All @@ -99,7 +99,7 @@ func TestPromiseFulfill(t *testing.T) {
h := new(dummyHook)
c := NewClient(h)
defer c.Release()
msg, seg, _ := NewMessage(SingleSegment(nil))
msg, seg := NewSingleSegmentMessage(nil)
defer msg.Release()

res, _ := NewStruct(seg, ObjectSize{PointerCount: 3})
Expand Down
2 changes: 1 addition & 1 deletion canonical.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (
// for equivalent structs, even as the schema evolves. The blob is
// suitable for hashing or signing.
func Canonicalize(s Struct) ([]byte, error) {
msg, seg, _ := NewMessage(SingleSegment(nil))
msg, seg := NewSingleSegmentMessage(nil)
if !s.IsValid() {
return seg.Data(), nil
}
Expand Down
20 changes: 10 additions & 10 deletions canonical_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,23 +19,23 @@ func TestCanonicalize(t *testing.T) {
}, {
name: "empty struct",
f: func() Struct {
_, seg, _ := NewMessage(SingleSegment(nil))
_, seg := NewSingleSegmentMessage(nil)
s, _ := NewStruct(seg, ObjectSize{})
return s
},
want: []byte{0xfc, 0xff, 0xff, 0xff, 0, 0, 0, 0},
}, {
name: "zero data, zero pointer struct",
f: func() Struct {
_, seg, _ := NewMessage(SingleSegment(nil))
_, seg := NewSingleSegmentMessage(nil)
s, _ := NewStruct(seg, ObjectSize{DataSize: 8, PointerCount: 1})
return s
},
want: []byte{0xfc, 0xff, 0xff, 0xff, 0, 0, 0, 0},
}, {
name: "one word data struct",
f: func() Struct {
_, seg, _ := NewMessage(SingleSegment(nil))
_, seg := NewSingleSegmentMessage(nil)
s, _ := NewStruct(seg, ObjectSize{DataSize: 8, PointerCount: 1})
s.SetUint16(0, 0xbeef)
return s
Expand All @@ -47,7 +47,7 @@ func TestCanonicalize(t *testing.T) {
}, {
name: "two pointers to zero structs",
f: func() Struct {
_, seg, _ := NewMessage(SingleSegment(nil))
_, seg := NewSingleSegmentMessage(nil)
s, _ := NewStruct(seg, ObjectSize{PointerCount: 2})
e1, _ := NewStruct(seg, ObjectSize{DataSize: 8})
e2, _ := NewStruct(seg, ObjectSize{DataSize: 8})
Expand All @@ -63,7 +63,7 @@ func TestCanonicalize(t *testing.T) {
}, {
name: "pointer to interface",
f: func() Struct {
_, seg, _ := NewMessage(SingleSegment(nil))
_, seg := NewSingleSegmentMessage(nil)
s, _ := NewStruct(seg, ObjectSize{PointerCount: 2})
iface := NewInterface(seg, 1)
s.SetPtr(0, iface.ToPtr())
Expand All @@ -76,7 +76,7 @@ func TestCanonicalize(t *testing.T) {
}, {
name: "int list",
f: func() Struct {
_, seg, _ := NewMessage(SingleSegment(nil))
_, seg := NewSingleSegmentMessage(nil)
s, _ := NewStruct(seg, ObjectSize{PointerCount: 1})
l, _ := NewInt8List(seg, 5)
s.SetPtr(0, l.ToPtr())
Expand All @@ -95,7 +95,7 @@ func TestCanonicalize(t *testing.T) {
}, {
name: "zero int list",
f: func() Struct {
_, seg, _ := NewMessage(SingleSegment(nil))
_, seg := NewSingleSegmentMessage(nil)
s, _ := NewStruct(seg, ObjectSize{PointerCount: 1})
l, _ := NewInt8List(seg, 5)
s.SetPtr(0, l.ToPtr())
Expand All @@ -110,7 +110,7 @@ func TestCanonicalize(t *testing.T) {
}, {
name: "struct list",
f: func() Struct {
_, seg, _ := NewMessage(SingleSegment(nil))
_, seg := NewSingleSegmentMessage(nil)
s, _ := NewStruct(seg, ObjectSize{PointerCount: 1})
l, _ := NewCompositeList(seg, ObjectSize{DataSize: 8, PointerCount: 1}, 2)
s.SetPtr(0, l.ToPtr())
Expand All @@ -133,7 +133,7 @@ func TestCanonicalize(t *testing.T) {
}, {
name: "zero struct list",
f: func() Struct {
_, seg, _ := NewMessage(SingleSegment(nil))
_, seg := NewSingleSegmentMessage(nil)
s, _ := NewStruct(seg, ObjectSize{PointerCount: 1})
l, _ := NewCompositeList(seg, ObjectSize{DataSize: 16, PointerCount: 2}, 3)
s.SetPtr(0, l.ToPtr())
Expand All @@ -148,7 +148,7 @@ func TestCanonicalize(t *testing.T) {
}, {
name: "zero-length struct list",
f: func() Struct {
_, seg, _ := NewMessage(SingleSegment(nil))
_, seg := NewSingleSegmentMessage(nil)
s, _ := NewStruct(seg, ObjectSize{PointerCount: 1})
l, _ := NewCompositeList(seg, ObjectSize{DataSize: 16, PointerCount: 2}, 0)
s.SetPtr(0, l.ToPtr())
Expand Down
32 changes: 9 additions & 23 deletions capability_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -357,10 +357,8 @@ func (dr *dummyReturner) AllocResults(sz ObjectSize) (Struct, error) {
if dr.s.IsValid() {
return Struct{}, errors.New("AllocResults called multiple times")
}
_, seg, err := NewMessage(SingleSegment(nil))
if err != nil {
return Struct{}, err
}
_, seg := NewSingleSegmentMessage(nil)
var err error
dr.s, err = NewRootStruct(seg, sz)
return dr.s, err
}
Expand All @@ -377,10 +375,7 @@ func (dr *dummyReturner) ReleaseResults() {
}

func TestToInterface(t *testing.T) {
_, seg, err := NewMessage(SingleSegment(nil))
if err != nil {
t.Fatal(err)
}
_, seg := NewSingleSegmentMessage(nil)
tests := []struct {
ptr Ptr
in Interface
Expand All @@ -399,10 +394,7 @@ func TestToInterface(t *testing.T) {
}

func TestInterface_value(t *testing.T) {
_, seg, err := NewMessage(SingleSegment(nil))
if err != nil {
t.Fatal(err)
}
_, seg := NewSingleSegmentMessage(nil)
tests := []struct {
in Interface
val rawPointer
Expand All @@ -421,10 +413,7 @@ func TestInterface_value(t *testing.T) {
}

func TestTransform(t *testing.T) {
_, s, err := NewMessage(SingleSegment(nil))
if err != nil {
t.Fatal(err)
}
_, s := NewSingleSegmentMessage(nil)
root, err := NewStruct(s, ObjectSize{PointerCount: 2})
if err != nil {
t.Fatal(err)
Expand All @@ -442,7 +431,7 @@ func TestTransform(t *testing.T) {
b.SetUint64(0, 2)
a.SetPtr(0, b.ToPtr())

dmsg, d, err := NewMessage(SingleSegment(nil))
dmsg, d := NewSingleSegmentMessage(nil)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -675,20 +664,17 @@ func deepPointerEqual(a, b Ptr) bool {
if !a.IsValid() || !b.IsValid() {
return false
}
msgA, _, _ := NewMessage(SingleSegment(nil))
msgA, _ := NewSingleSegmentMessage(nil)
msgA.SetRoot(a)
abytes, _ := msgA.Marshal()
msgB, _, _ := NewMessage(SingleSegment(nil))
msgB, _ := NewSingleSegmentMessage(nil)
msgB.SetRoot(b)
bbytes, _ := msgB.Marshal()
return bytes.Equal(abytes, bbytes)
}

func newEmptyStruct() Struct {
_, seg, err := NewMessage(SingleSegment(nil))
if err != nil {
panic(err)
}
_, seg := NewSingleSegmentMessage(nil)
s, err := NewRootStruct(seg, ObjectSize{})
if err != nil {
panic(err)
Expand Down
2 changes: 1 addition & 1 deletion capnpc-go/capnpc-go.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ func (g *generator) defineSchemaVar() error {
}
sort.Sort(uint64Slice(ids))

msg, seg, _ := capnp.NewMessage(capnp.SingleSegment(nil))
msg, seg := capnp.NewSingleSegmentMessage(nil)
req, _ := schema.NewRootCodeGeneratorRequest(seg)
// TODO(light): find largest object size and use that to allocate list
nodes, _ := req.NewNodes(int32(len(g.nodes)))
Expand Down
7 changes: 2 additions & 5 deletions capnpc-go/fileparts.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,8 @@ func (sd *staticData) init(fileID uint64) {
}

func (sd *staticData) copyData(obj capnp.Ptr) (staticDataRef, error) {
m, _, err := capnp.NewMessage(capnp.SingleSegment(nil))
if err != nil {
return staticDataRef{}, err
}
err = m.SetRoot(obj)
m, _ := capnp.NewSingleSegmentMessage(nil)
err := m.SetRoot(obj)
if err != nil {
return staticDataRef{}, err
}
Expand Down
2 changes: 1 addition & 1 deletion encoding/text/marshal.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ func (enc *Encoder) Encode(typeID uint64, s capnp.Struct) error {

// EncodeList writes the text representation of struct list l to the stream.
func (enc *Encoder) EncodeList(typeID uint64, l capnp.List) error {
_, seg, _ := capnp.NewMessage(capnp.SingleSegment(nil))
_, seg := capnp.NewSingleSegmentMessage(nil)
typ, _ := schema.NewRootType(seg)
typ.SetStructType()
typ.StructType().SetTypeId(typeID)
Expand Down
5 changes: 1 addition & 4 deletions example/books/ex1/books1.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,7 @@ import (

func main() {
// Make a brand new empty message. A Message allocates Cap'n Proto structs.
msg, seg, err := capnp.NewMessage(capnp.SingleSegment(nil))
if err != nil {
panic(err)
}
msg, seg := capnp.NewSingleSegmentMessage(nil)

// Create a new Book struct. Every message must have a root struct.
book, err := books.NewRootBook(seg)
Expand Down
11 changes: 2 additions & 9 deletions example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,7 @@ import (

func Example() {
// Make a brand new empty message.
msg, seg, err := capnp.NewMessage(capnp.SingleSegment(nil))
if err != nil {
panic(err)
}
msg, seg := capnp.NewSingleSegmentMessage(nil)

// If you want runtime-type identification, this is easily obtained. Just
// wrap everything in a struct that contains a single anoymous union (e.g. struct Z).
Expand Down Expand Up @@ -87,11 +84,7 @@ func Example() {
}

func ExampleUnmarshal() {
msg, s, err := capnp.NewMessage(capnp.SingleSegment(nil))
if err != nil {
fmt.Printf("allocation error %v\n", err)
return
}
msg, s := capnp.NewSingleSegmentMessage(nil)
d, err := air.NewRootZdate(s)
if err != nil {
fmt.Printf("root error %v\n", err)
Expand Down
Loading

0 comments on commit 153d699

Please sign in to comment.