Skip to content

Commit

Permalink
fix: more minimal lock
Browse files Browse the repository at this point in the history
  • Loading branch information
siyul-park committed Dec 28, 2024
1 parent e4e4320 commit 8c94611
Show file tree
Hide file tree
Showing 8 changed files with 58 additions and 41 deletions.
3 changes: 0 additions & 3 deletions ext/pkg/network/listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -181,9 +181,6 @@ func (n *HTTPListenNode) Shutdown() error {

// ServeHTTP handles HTTP requests.
func (n *HTTPListenNode) ServeHTTP(w http.ResponseWriter, r *http.Request) {
n.mu.RLock()
defer n.mu.RUnlock()

proc := process.New()

proc.Store(KeyHTTPResponseWriter, w)
Expand Down
25 changes: 14 additions & 11 deletions pkg/node/manytoone.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,15 +91,17 @@ func (n *ManyToOneNode) Close() error {
}

func (n *ManyToOneNode) forward(index int) port.Listener {
return port.ListenFunc(func(proc *process.Process) {
n.mu.RLock()
defer n.mu.RUnlock()
inPort := n.inPorts[index]

inReader := n.inPorts[index].Open(proc)
return port.ListenFunc(func(proc *process.Process) {
inReader := inPort.Open(proc)
var outWriter *packet.Writer
var errWriter *packet.Writer

readGroup, _ := n.readGroups.LoadOrStore(proc, func() (*packet.ReadGroup, error) {
n.mu.RLock()
defer n.mu.RUnlock()

inReaders := make([]*packet.Reader, len(n.inPorts))
for i, inPort := range n.inPorts {
inReaders[i] = inPort.Open(proc)
Expand All @@ -110,19 +112,20 @@ func (n *ManyToOneNode) forward(index int) port.Listener {
for inPck := range inReader.Read() {
n.tracer.Read(inReader, inPck)

if outWriter == nil {
outWriter = n.outPort.Open(proc)
}
if errWriter == nil {
errWriter = n.errPort.Open(proc)
}

if inPcks := readGroup.Read(inReader, inPck); len(inPcks) < len(n.inPorts) {
n.tracer.Reduce(inPck)
} else if outPck, errPck := n.action(proc, inPcks); errPck != nil {
if errWriter == nil {
errWriter = n.errPort.Open(proc)
}

n.tracer.Transform(inPck, errPck)
n.tracer.Write(errWriter, errPck)
} else if outPck != nil {
if outWriter == nil {
outWriter = n.outPort.Open(proc)
}

n.tracer.Transform(inPck, outPck)
n.tracer.Write(outWriter, outPck)
} else {
Expand Down
24 changes: 10 additions & 14 deletions pkg/node/onetomany.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,22 +94,17 @@ func (n *OneToManyNode) forward(proc *process.Process) {
defer n.mu.RUnlock()

inReader := n.inPort.Open(proc)
outWriters := make([]*packet.Writer, 0, len(n.outPorts))
outWriters := make([]*packet.Writer, len(n.outPorts))
var errWriter *packet.Writer

for inPck := range inReader.Read() {
n.tracer.Read(inReader, inPck)

if len(outWriters) == 0 {
for _, outPort := range n.outPorts {
outWriters = append(outWriters, outPort.Open(proc))
if outPcks, errPck := n.action(proc, inPck); errPck != nil {
if errWriter == nil {
errWriter = n.errPort.Open(proc)
}
}
if errWriter == nil {
errWriter = n.errPort.Open(proc)
}

if outPcks, errPck := n.action(proc, inPck); errPck != nil {
n.tracer.Transform(inPck, errPck)
n.tracer.Write(errWriter, errPck)
} else {
Expand All @@ -122,6 +117,10 @@ func (n *OneToManyNode) forward(proc *process.Process) {
count := 0
for i, outPck := range outPcks {
if i < len(outWriters) && outPck != nil {
if outWriters[i] == nil {
outWriters[i] = n.outPorts[i].Open(proc)
}

n.tracer.Write(outWriters[i], outPck)
count++
}
Expand All @@ -135,12 +134,9 @@ func (n *OneToManyNode) forward(proc *process.Process) {
}

func (n *OneToManyNode) backward(index int) port.Listener {
return port.ListenFunc(func(proc *process.Process) {
n.mu.RLock()
defer n.mu.RUnlock()

outPort := n.outPorts[index]
outPort := n.outPorts[index]

return port.ListenFunc(func(proc *process.Process) {
outWriter := outPort.Open(proc)

for backPck := range outWriter.Receive() {
Expand Down
15 changes: 8 additions & 7 deletions pkg/node/onetoone.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,17 +75,18 @@ func (n *OneToOneNode) forward(proc *process.Process) {
for inPck := range inReader.Read() {
n.tracer.Read(inReader, inPck)

if outWriter == nil {
outWriter = n.outPort.Open(proc)
}
if errWriter == nil {
errWriter = n.errPort.Open(proc)
}

if outPck, errPck := n.action(proc, inPck); errPck != nil {
if errWriter == nil {
errWriter = n.errPort.Open(proc)
}

n.tracer.Transform(inPck, errPck)
n.tracer.Write(errWriter, errPck)
} else {
if outWriter == nil {
outWriter = n.outPort.Open(proc)
}

n.tracer.Transform(inPck, outPck)
n.tracer.Write(outWriter, outPck)
}
Expand Down
9 changes: 8 additions & 1 deletion pkg/port/inport.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,16 @@ func (p *InPort) AddListener(listener Listener) bool {

// Open prepares the input port for a given process and returns a reader.
func (p *InPort) Open(proc *process.Process) *packet.Reader {
p.mu.RLock()
reader, ok := p.readers[proc]
p.mu.RUnlock()
if ok {
return reader
}

p.mu.Lock()

reader, ok := p.readers[proc]
reader, ok = p.readers[proc]
if ok {
p.mu.Unlock()
return reader
Expand Down
9 changes: 8 additions & 1 deletion pkg/port/outport.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,9 +129,16 @@ func (p *OutPort) Unlink(in *InPort) {

// Open opens the output port for the given process and returns a writer.
func (p *OutPort) Open(proc *process.Process) *packet.Writer {
p.mu.RLock()
writer, ok := p.writers[proc]
p.mu.RUnlock()
if ok {
return writer
}

p.mu.Lock()

writer, ok := p.writers[proc]
writer, ok = p.writers[proc]
if ok {
p.mu.Unlock()
return writer
Expand Down
12 changes: 9 additions & 3 deletions pkg/process/local.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,15 +115,23 @@ func (l *Local[T]) Delete(proc *Process) bool {

// LoadOrStore retrieves or stores a value for the given process.
func (l *Local[T]) LoadOrStore(proc *Process, val func() (T, error)) (T, error) {
l.mu.RLock()
v, ok := l.data[proc]
l.mu.RUnlock()
if ok {
return v, nil
}

l.mu.Lock()
defer l.mu.Unlock()

if v, ok := l.data[proc]; ok {
l.mu.Unlock()
return v, nil
}

v, err := val()
if err != nil {
l.mu.Unlock()
return v, err
}

Expand All @@ -139,8 +147,6 @@ func (l *Local[T]) LoadOrStore(proc *Process, val func() (T, error)) (T, error)

storeHooks.Store(v)

l.mu.Lock()

return v, nil
}

Expand Down
2 changes: 1 addition & 1 deletion pkg/symbol/cluster.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
package symbol

import (
"github.com/siyul-park/uniflow/pkg/packet"
"sync"

"github.com/siyul-park/uniflow/pkg/node"
"github.com/siyul-park/uniflow/pkg/packet"
"github.com/siyul-park/uniflow/pkg/port"
"github.com/siyul-park/uniflow/pkg/process"
"github.com/siyul-park/uniflow/pkg/spec"
Expand Down

0 comments on commit 8c94611

Please sign in to comment.