From 4257b0f72e14c6cd1dba43a871f6ca8ccae98b5d Mon Sep 17 00:00:00 2001 From: siyual-park Date: Tue, 5 Nov 2024 13:33:42 +0900 Subject: [PATCH] fix: escape race condition --- ext/pkg/control/split.go | 4 ++-- ext/pkg/io/print.go | 12 ++++++++---- ext/pkg/mime/encoding.go | 10 ++++------ ext/pkg/network/listener_test.go | 3 --- pkg/scheme/scheme.go | 3 +++ pkg/types/map.go | 26 ++++++++++++++++++-------- pkg/types/map_test.go | 12 ++++++++++++ pkg/types/slice.go | 16 ++++++++++++++-- pkg/types/slice_test.go | 10 ++++++++++ 9 files changed, 71 insertions(+), 25 deletions(-) diff --git a/ext/pkg/control/split.go b/ext/pkg/control/split.go index d3bcd9e9..e2ad4a1d 100644 --- a/ext/pkg/control/split.go +++ b/ext/pkg/control/split.go @@ -39,8 +39,8 @@ func (n *SplitNode) action(_ *process.Process, inPck *packet.Packet) ([]*packet. switch inPayload := inPck.Payload().(type) { case types.Slice: outPcks := make([]*packet.Packet, 0, inPayload.Len()) - for i := 0; i < inPayload.Len(); i++ { - outPck := packet.New(inPayload.Get(i)) + for _, v := range inPayload.Range() { + outPck := packet.New(v) outPcks = append(outPcks, outPck) } return outPcks, nil diff --git a/ext/pkg/io/print.go b/ext/pkg/io/print.go index 97e37503..5a6ad558 100644 --- a/ext/pkg/io/print.go +++ b/ext/pkg/io/print.go @@ -92,8 +92,10 @@ func (n *PrintNode) action(_ *process.Process, inPck *packet.Packet) (*packet.Pa if !ok { return nil, packet.New(types.NewError(encoding.ErrUnsupportedType)) } - for i := 1; i < payload.Len(); i++ { - args = append(args, types.InterfaceOf(payload.Get(i))) + for i, v := range payload.Range() { + if i > 0 { + args = append(args, types.InterfaceOf(v)) + } } } @@ -124,8 +126,10 @@ func (n *DynPrintNode) action(_ *process.Process, inPck *packet.Packet) (*packet } var args []any - for i := 2; i < payload.Len(); i++ { - args = append(args, types.InterfaceOf(payload.Get(i))) + for i, v := range payload.Range() { + if i > 1 { + args = append(args, types.InterfaceOf(v)) + } } writer, err := n.fs.Open(filename, os.O_WRONLY|os.O_APPEND|os.O_CREATE) diff --git a/ext/pkg/mime/encoding.go b/ext/pkg/mime/encoding.go index 834ac4fe..2c598672 100644 --- a/ext/pkg/mime/encoding.go +++ b/ext/pkg/mime/encoding.go @@ -105,7 +105,7 @@ func Encode(writer io.Writer, value types.Value, header textproto.MIMEHeader) er elements = types.NewSlice(value) } - for _, element := range elements.Values() { + for _, element := range elements.Range() { h := textproto.MIMEHeader{} h.Set(HeaderContentDisposition, fmt.Sprintf(`form-data; name="%s"`, quoteEscaper.Replace(key.String()))) @@ -121,7 +121,7 @@ func Encode(writer io.Writer, value types.Value, header textproto.MIMEHeader) er writeFields := func(value types.Value) error { if value, ok := value.(types.Map); ok { - for _, key := range value.Keys() { + for key := range value.Range() { if err := writeField(value, key); err != nil { return err } @@ -132,7 +132,7 @@ func Encode(writer io.Writer, value types.Value, header textproto.MIMEHeader) er writeFiles := func(value types.Value) error { if value, ok := value.(types.Map); ok { - for _, key := range value.Keys() { + for key := range value.Range() { if key, ok := key.(types.String); ok { value := value.GetOr(key, nil) @@ -195,9 +195,7 @@ func Encode(writer io.Writer, value types.Value, header textproto.MIMEHeader) er } if v, ok := value.(types.Map); ok { - for _, key := range v.Keys() { - value := v.GetOr(key, nil) - + for key, value := range v.Range() { if key.Equal(keyValues) { if err := writeFields(value); err != nil { return err diff --git a/ext/pkg/network/listener_test.go b/ext/pkg/network/listener_test.go index c038203e..93b2d774 100644 --- a/ext/pkg/network/listener_test.go +++ b/ext/pkg/network/listener_test.go @@ -339,9 +339,6 @@ func BenchmarkHTTPListenNode_ServeHTTP(b *testing.B) { n := NewHTTPListenNode("") defer n.Close() - in := port.NewOut() - in.Link(n.In(node.PortIn)) - out := port.NewIn() n.Out(node.PortOut).Link(out) diff --git a/pkg/scheme/scheme.go b/pkg/scheme/scheme.go index 75012b5f..a9b8100a 100644 --- a/pkg/scheme/scheme.go +++ b/pkg/scheme/scheme.go @@ -92,6 +92,9 @@ func (s *Scheme) AddCodec(kind string, codec Codec) bool { // RemoveCodec removes the codec associated with a kind. func (s *Scheme) RemoveCodec(kind string) bool { + s.mu.Lock() + defer s.mu.Unlock() + if _, ok := s.codecs[kind]; ok { delete(s.codecs, kind) return true diff --git a/pkg/types/map.go b/pkg/types/map.go index b1519083..0627e262 100644 --- a/pkg/types/map.go +++ b/pkg/types/map.go @@ -103,6 +103,18 @@ func (m Map) Pairs() []Value { return pairs } +// Range returns a function that iterates over all key-value pairs in the map. +func (m Map) Range() func(func(key, value Value) bool) { + return func(yield func(key Value, value Value) bool) { + for itr := m.value.Iterator(); !itr.Done(); { + k, v, _ := itr.Next() + if !yield(k, v) { + return + } + } + } +} + // Len returns the number of key-value pairs in the map. func (m Map) Len() int { return m.value.Len() @@ -234,6 +246,10 @@ func (m *mapProxy) Delete(key Value) { m.Map = m.Map.Delete(key) } +func (m *mapProxy) Close() { + m.Map = NewMap() +} + func (*comparer) Compare(x, y Value) int { return Compare(x, y) } @@ -375,14 +391,7 @@ func newMapDecoder(decoder *encoding.DecodeAssembler[Value, any]) encoding.Decod t.Set(reflect.MakeMapWithSize(t.Type(), proxy.Len())) } - for _, key := range proxy.Keys() { - value, ok := proxy.Get(key) - if !ok { - continue - } - - proxy.Delete(key) - + for key, value := range proxy.Range() { k := reflect.New(keyType) v := reflect.New(valueType) @@ -394,6 +403,7 @@ func newMapDecoder(decoder *encoding.DecodeAssembler[Value, any]) encoding.Decod t.SetMapIndex(k.Elem(), v.Elem()) } } + proxy.Close() return nil }), nil } else if typ.Elem().Kind() == reflect.Struct { diff --git a/pkg/types/map_test.go b/pkg/types/map_test.go index 02406cb6..14e457f3 100644 --- a/pkg/types/map_test.go +++ b/pkg/types/map_test.go @@ -72,6 +72,18 @@ func TestMap_Pairs(t *testing.T) { assert.Contains(t, pairs, v1) } +func TestMap_Range(t *testing.T) { + k1 := NewString(faker.UUIDHyphenated()) + v1 := NewString(faker.UUIDHyphenated()) + + o := NewMap(k1, v1) + + for k, v := range o.Range() { + assert.Equal(t, k1, k) + assert.Equal(t, v1, v) + } +} + func TestMap_Len(t *testing.T) { k1 := NewString(faker.UUIDHyphenated()) v1 := NewString(faker.UUIDHyphenated()) diff --git a/pkg/types/slice.go b/pkg/types/slice.go index 5816effd..dd20ac7e 100644 --- a/pkg/types/slice.go +++ b/pkg/types/slice.go @@ -66,6 +66,18 @@ func (s Slice) Values() []Value { return elements } +// Range returns a function that iterates over all key-value pairs of the slice. +func (s Slice) Range() func(func(key int, value Value) bool) { + return func(yield func(key int, value Value) bool) { + for itr := s.value.Iterator(); !itr.Done(); { + i, v := itr.Next() + if !yield(i, v) { + return + } + } + } +} + // Len returns the length of the slice. func (s Slice) Len() int { return s.value.Len() @@ -214,8 +226,8 @@ func newSliceDecoder(decoder *encoding.DecodeAssembler[Value, any]) encoding.Dec return encoding.DecodeFunc(func(source Value, target unsafe.Pointer) error { t := reflect.NewAt(typ.Elem(), target).Elem() if s, ok := source.(Slice); ok { - for i := 0; i < s.Len(); i++ { - if err := setElement(s.Get(i), t, i); err != nil { + for i, v := range s.Range() { + if err := setElement(v, t, i); err != nil { return err } } diff --git a/pkg/types/slice_test.go b/pkg/types/slice_test.go index 7e794b8a..64acbad3 100644 --- a/pkg/types/slice_test.go +++ b/pkg/types/slice_test.go @@ -75,6 +75,16 @@ func TestSlice_Values(t *testing.T) { assert.Equal(t, []Value{v1, v2}, o.Values()) } +func TestSlice_Range(t *testing.T) { + v1 := NewString(faker.UUIDHyphenated()) + + o := NewSlice(v1) + + for _, v := range o.Range() { + assert.Equal(t, v1, v) + } +} + func TestSlice_Len(t *testing.T) { v1 := NewString(faker.UUIDHyphenated()) v2 := NewString(faker.UUIDHyphenated())