From ba97e30f0c69424b6634b209997903788a5716d8 Mon Sep 17 00:00:00 2001 From: siyul-park Date: Sat, 28 Dec 2024 09:06:32 +0900 Subject: [PATCH] fix: more minimal lock --- ext/pkg/network/listener.go | 3 --- pkg/node/manytoone.go | 25 ++++++++++++++----------- pkg/node/onetomany.go | 24 ++++++++++-------------- pkg/node/onetoone.go | 15 ++++++++------- pkg/port/inport.go | 9 ++++++++- pkg/port/outport.go | 9 ++++++++- pkg/symbol/cluster.go | 2 +- 7 files changed, 49 insertions(+), 38 deletions(-) diff --git a/ext/pkg/network/listener.go b/ext/pkg/network/listener.go index 61f6f483..9a8a4684 100644 --- a/ext/pkg/network/listener.go +++ b/ext/pkg/network/listener.go @@ -181,9 +181,6 @@ func (n *HTTPListenNode) Shutdown() error { // ServeHTTP handles HTTP requests. func (n *HTTPListenNode) ServeHTTP(w http.ResponseWriter, r *http.Request) { - n.mu.RLock() - defer n.mu.RUnlock() - proc := process.New() proc.Store(KeyHTTPResponseWriter, w) diff --git a/pkg/node/manytoone.go b/pkg/node/manytoone.go index 2b5bbd33..7ef8ab53 100644 --- a/pkg/node/manytoone.go +++ b/pkg/node/manytoone.go @@ -91,15 +91,17 @@ func (n *ManyToOneNode) Close() error { } func (n *ManyToOneNode) forward(index int) port.Listener { - return port.ListenFunc(func(proc *process.Process) { - n.mu.RLock() - defer n.mu.RUnlock() + inPort := n.inPorts[index] - inReader := n.inPorts[index].Open(proc) + return port.ListenFunc(func(proc *process.Process) { + inReader := inPort.Open(proc) var outWriter *packet.Writer var errWriter *packet.Writer readGroup, _ := n.readGroups.LoadOrStore(proc, func() (*packet.ReadGroup, error) { + n.mu.RLock() + defer n.mu.RUnlock() + inReaders := make([]*packet.Reader, len(n.inPorts)) for i, inPort := range n.inPorts { inReaders[i] = inPort.Open(proc) @@ -110,19 +112,20 @@ func (n *ManyToOneNode) forward(index int) port.Listener { for inPck := range inReader.Read() { n.tracer.Read(inReader, inPck) - if outWriter == nil { - outWriter = n.outPort.Open(proc) - } - if errWriter == nil { - errWriter = n.errPort.Open(proc) - } - if inPcks := readGroup.Read(inReader, inPck); len(inPcks) < len(n.inPorts) { n.tracer.Reduce(inPck) } else if outPck, errPck := n.action(proc, inPcks); errPck != nil { + if errWriter == nil { + errWriter = n.errPort.Open(proc) + } + n.tracer.Transform(inPck, errPck) n.tracer.Write(errWriter, errPck) } else if outPck != nil { + if outWriter == nil { + outWriter = n.outPort.Open(proc) + } + n.tracer.Transform(inPck, outPck) n.tracer.Write(outWriter, outPck) } else { diff --git a/pkg/node/onetomany.go b/pkg/node/onetomany.go index 8e4d1cd5..b8a7b27f 100644 --- a/pkg/node/onetomany.go +++ b/pkg/node/onetomany.go @@ -94,22 +94,17 @@ func (n *OneToManyNode) forward(proc *process.Process) { defer n.mu.RUnlock() inReader := n.inPort.Open(proc) - outWriters := make([]*packet.Writer, 0, len(n.outPorts)) + outWriters := make([]*packet.Writer, len(n.outPorts)) var errWriter *packet.Writer for inPck := range inReader.Read() { n.tracer.Read(inReader, inPck) - if len(outWriters) == 0 { - for _, outPort := range n.outPorts { - outWriters = append(outWriters, outPort.Open(proc)) + if outPcks, errPck := n.action(proc, inPck); errPck != nil { + if errWriter == nil { + errWriter = n.errPort.Open(proc) } - } - if errWriter == nil { - errWriter = n.errPort.Open(proc) - } - if outPcks, errPck := n.action(proc, inPck); errPck != nil { n.tracer.Transform(inPck, errPck) n.tracer.Write(errWriter, errPck) } else { @@ -122,6 +117,10 @@ func (n *OneToManyNode) forward(proc *process.Process) { count := 0 for i, outPck := range outPcks { if i < len(outWriters) && outPck != nil { + if outWriters[i] == nil { + outWriters[i] = n.outPorts[i].Open(proc) + } + n.tracer.Write(outWriters[i], outPck) count++ } @@ -135,12 +134,9 @@ func (n *OneToManyNode) forward(proc *process.Process) { } func (n *OneToManyNode) backward(index int) port.Listener { - return port.ListenFunc(func(proc *process.Process) { - n.mu.RLock() - defer n.mu.RUnlock() - - outPort := n.outPorts[index] + outPort := n.outPorts[index] + return port.ListenFunc(func(proc *process.Process) { outWriter := outPort.Open(proc) for backPck := range outWriter.Receive() { diff --git a/pkg/node/onetoone.go b/pkg/node/onetoone.go index a78f933b..5736e74c 100644 --- a/pkg/node/onetoone.go +++ b/pkg/node/onetoone.go @@ -75,17 +75,18 @@ func (n *OneToOneNode) forward(proc *process.Process) { for inPck := range inReader.Read() { n.tracer.Read(inReader, inPck) - if outWriter == nil { - outWriter = n.outPort.Open(proc) - } - if errWriter == nil { - errWriter = n.errPort.Open(proc) - } - if outPck, errPck := n.action(proc, inPck); errPck != nil { + if errWriter == nil { + errWriter = n.errPort.Open(proc) + } + n.tracer.Transform(inPck, errPck) n.tracer.Write(errWriter, errPck) } else { + if outWriter == nil { + outWriter = n.outPort.Open(proc) + } + n.tracer.Transform(inPck, outPck) n.tracer.Write(outWriter, outPck) } diff --git a/pkg/port/inport.go b/pkg/port/inport.go index 4cd6a4ff..86e4fa47 100644 --- a/pkg/port/inport.go +++ b/pkg/port/inport.go @@ -95,9 +95,16 @@ func (p *InPort) AddListener(listener Listener) bool { // Open prepares the input port for a given process and returns a reader. func (p *InPort) Open(proc *process.Process) *packet.Reader { + p.mu.RLock() + reader, ok := p.readers[proc] + p.mu.RUnlock() + if ok { + return reader + } + p.mu.Lock() - reader, ok := p.readers[proc] + reader, ok = p.readers[proc] if ok { p.mu.Unlock() return reader diff --git a/pkg/port/outport.go b/pkg/port/outport.go index b195079c..f7d814a0 100644 --- a/pkg/port/outport.go +++ b/pkg/port/outport.go @@ -129,9 +129,16 @@ func (p *OutPort) Unlink(in *InPort) { // Open opens the output port for the given process and returns a writer. func (p *OutPort) Open(proc *process.Process) *packet.Writer { + p.mu.RLock() + writer, ok := p.writers[proc] + p.mu.RUnlock() + if ok { + return writer + } + p.mu.Lock() - writer, ok := p.writers[proc] + writer, ok = p.writers[proc] if ok { p.mu.Unlock() return writer diff --git a/pkg/symbol/cluster.go b/pkg/symbol/cluster.go index b0fb9c8c..5c35b871 100644 --- a/pkg/symbol/cluster.go +++ b/pkg/symbol/cluster.go @@ -1,10 +1,10 @@ package symbol import ( - "github.com/siyul-park/uniflow/pkg/packet" "sync" "github.com/siyul-park/uniflow/pkg/node" + "github.com/siyul-park/uniflow/pkg/packet" "github.com/siyul-park/uniflow/pkg/port" "github.com/siyul-park/uniflow/pkg/process" "github.com/siyul-park/uniflow/pkg/spec"