Skip to content

Commit

Permalink
feat: support context in language
Browse files Browse the repository at this point in the history
  • Loading branch information
siyul-park committed Oct 12, 2024
1 parent 59212d0 commit 240ea61
Show file tree
Hide file tree
Showing 21 changed files with 159 additions and 77 deletions.
23 changes: 16 additions & 7 deletions ext/pkg/control/if.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package control

import (
"context"
"reflect"
"time"

"github.com/siyul-park/uniflow/ext/pkg/language"
"github.com/siyul-park/uniflow/pkg/node"
Expand All @@ -15,13 +17,14 @@ import (
// IfNodeSpec defines the specifications for creating an IfNode.
type IfNodeSpec struct {
spec.Meta `map:",inline"`
When string `map:"when"`
When string `map:"when"`
Timeout time.Duration `map:"timeout,omitempty"`
}

// IfNode evaluates a condition and routes packets based on the result.
type IfNode struct {
*node.OneToManyNode
condition func(any) (bool, error)
condition func(context.Context, any) (bool, error)
}

const KindIf = "if"
Expand All @@ -33,8 +36,14 @@ func NewIfNodeCodec(compiler language.Compiler) scheme.Codec {
if err != nil {
return nil, err
}
return NewIfNode(func(env any) (bool, error) {
res, err := program.Run(env)
return NewIfNode(func(ctx context.Context, env any) (bool, error) {
if spec.Timeout != 0 {
var cancel func()
ctx, cancel = context.WithTimeout(ctx, spec.Timeout)
defer cancel()
}

res, err := program.Run(ctx, env)
if err != nil {
return false, err
}
Expand All @@ -44,17 +53,17 @@ func NewIfNodeCodec(compiler language.Compiler) scheme.Codec {
}

// NewIfNode creates a new IfNode instance.
func NewIfNode(condition func(any) (bool, error)) *IfNode {
func NewIfNode(condition func(context.Context, any) (bool, error)) *IfNode {
n := &IfNode{condition: condition}
n.OneToManyNode = node.NewOneToManyNode(n.action)
return n
}

func (n *IfNode) action(_ *process.Process, inPck *packet.Packet) ([]*packet.Packet, *packet.Packet) {
func (n *IfNode) action(proc *process.Process, inPck *packet.Packet) ([]*packet.Packet, *packet.Packet) {
inPayload := inPck.Payload()
input := types.InterfaceOf(inPayload)

ok, err := n.condition(input)
ok, err := n.condition(proc.Context(), input)
if err != nil {
return nil, packet.New(types.NewError(err))
}
Expand Down
8 changes: 4 additions & 4 deletions ext/pkg/control/if_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func TestIfNode_SendAndReceive(t *testing.T) {
ctx, cancel := context.WithTimeout(context.TODO(), time.Second)
defer cancel()

n := NewIfNode(func(_ any) (bool, error) {
n := NewIfNode(func(_ context.Context, _ any) (bool, error) {
return true, nil
})
defer n.Close()
Expand Down Expand Up @@ -82,7 +82,7 @@ func TestIfNode_SendAndReceive(t *testing.T) {
ctx, cancel := context.WithTimeout(context.TODO(), time.Second)
defer cancel()

n := NewIfNode(func(_ any) (bool, error) {
n := NewIfNode(func(_ context.Context, _ any) (bool, error) {
return false, nil
})
defer n.Close()
Expand Down Expand Up @@ -124,7 +124,7 @@ func TestIfNode_SendAndReceive(t *testing.T) {
ctx, cancel := context.WithTimeout(context.TODO(), time.Second)
defer cancel()

n := NewIfNode(func(_ any) (bool, error) {
n := NewIfNode(func(_ context.Context, _ any) (bool, error) {
return false, errors.New(faker.Sentence())
})
defer n.Close()
Expand Down Expand Up @@ -163,7 +163,7 @@ func TestIfNode_SendAndReceive(t *testing.T) {
}

func BenchmarkIfNode_SendAndReceive(b *testing.B) {
n := NewIfNode(func(_ any) (bool, error) {
n := NewIfNode(func(_ context.Context, _ any) (bool, error) {
return true, nil
})
defer n.Close()
Expand Down
24 changes: 17 additions & 7 deletions ext/pkg/control/reduce.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package control

import (
"context"
"time"

"github.com/siyul-park/uniflow/ext/pkg/language"
"github.com/siyul-park/uniflow/pkg/node"
"github.com/siyul-park/uniflow/pkg/packet"
Expand All @@ -14,13 +17,14 @@ import (
// ReduceNodeSpec defines the specifications for creating a ReduceNode.
type ReduceNodeSpec struct {
spec.Meta `map:",inline"`
Action string `map:"action"`
Init any `map:"init,omitempty"`
Action string `map:"action"`
Init any `map:"init,omitempty"`
Timeout time.Duration `map:"timeout,omitempty"`
}

// ReduceNode performs a reduction operation using the provided action.
type ReduceNode struct {
action func(any, any, int) (any, error)
action func(context.Context, any, any, int) (any, error)
init any
tracer *packet.Tracer
inPort *port.InPort
Expand All @@ -38,14 +42,20 @@ func NewReduceNodeCodec(compiler language.Compiler) scheme.Codec {
return nil, err
}

return NewReduceNode(func(acc, cur any, index int) (any, error) {
return program.Run(acc, cur, index)
return NewReduceNode(func(ctx context.Context, acc, cur any, index int) (any, error) {
if spec.Timeout != 0 {
var cancel func()
ctx, cancel = context.WithTimeout(ctx, spec.Timeout)
defer cancel()
}

return program.Run(ctx, acc, cur, index)
}, spec.Init), nil
})
}

// NewReduceNode creates a new ReduceNode with the provided action and initial value.
func NewReduceNode(action func(any, any, int) (any, error), init any) *ReduceNode {
func NewReduceNode(action func(context.Context, any, any, int) (any, error), init any) *ReduceNode {
n := &ReduceNode{
action: action,
init: init,
Expand Down Expand Up @@ -107,7 +117,7 @@ func (n *ReduceNode) forward(proc *process.Process) {
n.tracer.Read(inReader, inPck)
cur := types.InterfaceOf(inPck.Payload())

if v, err := n.action(acc, cur, i); err != nil {
if v, err := n.action(proc.Context(), acc, cur, i); err != nil {
errPck := packet.New(types.NewError(err))
n.tracer.Transform(inPck, errPck)
n.tracer.Write(errWriter, errPck)
Expand Down
8 changes: 4 additions & 4 deletions ext/pkg/control/reduce_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,15 @@ func TestReduceNodeCodec_Decode(t *testing.T) {
}

func TestNewReduceNode(t *testing.T) {
n := NewReduceNode(func(a1, a2 any, i int) (any, error) {
n := NewReduceNode(func(_ context.Context, a1, a2 any, i int) (any, error) {
return a1, nil
}, nil)
assert.NotNil(t, n)
assert.NoError(t, n.Close())
}

func TestReduceNode_Port(t *testing.T) {
n := NewReduceNode(func(a1, a2 any, i int) (any, error) {
n := NewReduceNode(func(_ context.Context, a1, a2 any, i int) (any, error) {
return a1, nil
}, nil)
defer n.Close()
Expand All @@ -48,7 +48,7 @@ func TestReduceNode_SendAndReceive(t *testing.T) {
ctx, cancel := context.WithTimeout(context.TODO(), time.Second)
defer cancel()

n := NewReduceNode(func(a1, a2 any, i int) (any, error) {
n := NewReduceNode(func(_ context.Context, a1, a2 any, i int) (any, error) {
return a2, nil
}, nil)
defer n.Close()
Expand Down Expand Up @@ -107,7 +107,7 @@ func TestReduceNode_SendAndReceive(t *testing.T) {
}

func BenchmarkReduceNode_SendAndReceive(b *testing.B) {
n := NewReduceNode(func(a1, a2 any, i int) (any, error) {
n := NewReduceNode(func(_ context.Context, a1, a2 any, i int) (any, error) {
return a2, nil
}, nil)
defer n.Close()
Expand Down
26 changes: 18 additions & 8 deletions ext/pkg/control/snippet.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package control

import (
"context"
"time"

"github.com/siyul-park/uniflow/ext/pkg/language"
"github.com/siyul-park/uniflow/pkg/node"
"github.com/siyul-park/uniflow/pkg/packet"
Expand All @@ -13,14 +16,15 @@ import (
// SnippetNodeSpec defines the specifications for creating a SnippetNode.
type SnippetNodeSpec struct {
spec.Meta `map:",inline"`
Language string `map:"language,omitempty"`
Code string `map:"code"`
Language string `map:"language,omitempty"`
Code string `map:"code"`
Timeout time.Duration `map:"timeout,omitempty"`
}

// SnippetNode represents a node that executes code snippets in various languages.
type SnippetNode struct {
*node.OneToOneNode
fn func(any) (any, error)
fn func(context.Context, any) (any, error)
}

const KindSnippet = "snippet"
Expand All @@ -38,24 +42,30 @@ func NewSnippetNodeCodec(module *language.Module) scheme.Codec {
return nil, err
}

return NewSnippetNode(func(arg any) (any, error) {
return program.Run(arg)
return NewSnippetNode(func(ctx context.Context, arg any) (any, error) {
if spec.Timeout != 0 {
var cancel func()
ctx, cancel = context.WithTimeout(ctx, spec.Timeout)
defer cancel()
}

return program.Run(ctx, arg)
}), nil
})
}

// NewSnippetNode creates a new SnippetNode with the specified language.Language and code.
func NewSnippetNode(fn func(any) (any, error)) *SnippetNode {
func NewSnippetNode(fn func(context.Context, any) (any, error)) *SnippetNode {
n := &SnippetNode{fn: fn}
n.OneToOneNode = node.NewOneToOneNode(n.action)
return n
}

func (n *SnippetNode) action(_ *process.Process, inPck *packet.Packet) (*packet.Packet, *packet.Packet) {
func (n *SnippetNode) action(proc *process.Process, inPck *packet.Packet) (*packet.Packet, *packet.Packet) {
inPayload := inPck.Payload()
input := types.InterfaceOf(inPayload)

if output, err := n.fn(input); err != nil {
if output, err := n.fn(proc.Context(), input); err != nil {
return nil, packet.New(types.NewError(err))
} else if outPayload, err := types.Marshal(output); err != nil {
return nil, packet.New(types.NewError(err))
Expand Down
4 changes: 2 additions & 2 deletions ext/pkg/control/snippet_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ func TestSnippetNode_SendAndReceive(t *testing.T) {
ctx, cancel := context.WithTimeout(context.TODO(), time.Second)
defer cancel()

n := NewSnippetNode(func(input any) (any, error) {
n := NewSnippetNode(func(_ context.Context, input any) (any, error) {
return input, nil
})
defer n.Close()
Expand All @@ -70,7 +70,7 @@ func TestSnippetNode_SendAndReceive(t *testing.T) {
}

func BenchmarkSnippetNode_SendAndReceive(b *testing.B) {
n := NewSnippetNode(func(input any) (any, error) {
n := NewSnippetNode(func(_ context.Context, input any) (any, error) {
return input, nil
})
defer n.Close()
Expand Down
27 changes: 18 additions & 9 deletions ext/pkg/control/switch.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
package control

import (
"context"
"reflect"
"sync"
"time"

"github.com/siyul-park/uniflow/ext/pkg/language"
"github.com/siyul-park/uniflow/pkg/node"
Expand All @@ -16,7 +18,8 @@ import (
// SwitchNodeSpec holds specifications for creating a SwitchNode.
type SwitchNodeSpec struct {
spec.Meta `map:",inline"`
Matches []Condition `map:"matches"`
Matches []Condition `map:"matches"`
Timeout time.Duration `map:"timeout,omitempty"`
}

// Condition represents a condition for directing packets to specific ports.
Expand All @@ -28,7 +31,7 @@ type Condition struct {
// SwitchNode directs packets to different ports based on specified conditions.
type SwitchNode struct {
*node.OneToManyNode
conditions []func(any) (bool, error)
conditions []func(context.Context, any) (bool, error)
ports []int
mu sync.RWMutex
}
Expand All @@ -38,15 +41,21 @@ const KindSwitch = "switch"
// NewSwitchNodeCodec creates a new codec for SwitchNodeSpec.
func NewSwitchNodeCodec(compiler language.Compiler) scheme.Codec {
return scheme.CodecWithType(func(spec *SwitchNodeSpec) (node.Node, error) {
conditions := make([]func(any) (bool, error), len(spec.Matches))
conditions := make([]func(context.Context, any) (bool, error), len(spec.Matches))
for i, condition := range spec.Matches {
program, err := compiler.Compile(condition.When)
if err != nil {
return nil, err
}

conditions[i] = func(env any) (bool, error) {
res, err := program.Run(env)
conditions[i] = func(ctx context.Context, env any) (bool, error) {
if spec.Timeout != 0 {
var cancel func()
ctx, cancel = context.WithTimeout(ctx, spec.Timeout)
defer cancel()
}

res, err := program.Run(ctx, env)
if err != nil {
return false, err
}
Expand All @@ -56,7 +65,7 @@ func NewSwitchNodeCodec(compiler language.Compiler) scheme.Codec {

n := NewSwitchNode()
for i, condition := range spec.Matches {
n.Match(conditions[i], condition.Port)
n.Match(condition.Port, conditions[i])
}
return n, nil
})
Expand All @@ -70,7 +79,7 @@ func NewSwitchNode() *SwitchNode {
}

// Match associates a condition with a specific output port in the SwitchNode.
func (n *SwitchNode) Match(condition func(any) (bool, error), port string) {
func (n *SwitchNode) Match(port string, condition func(context.Context, any) (bool, error)) {
n.mu.Lock()
defer n.mu.Unlock()

Expand All @@ -83,7 +92,7 @@ func (n *SwitchNode) Match(condition func(any) (bool, error), port string) {
n.ports = append(n.ports, index)
}

func (n *SwitchNode) action(_ *process.Process, inPck *packet.Packet) ([]*packet.Packet, *packet.Packet) {
func (n *SwitchNode) action(proc *process.Process, inPck *packet.Packet) ([]*packet.Packet, *packet.Packet) {
n.mu.RLock()
defer n.mu.RUnlock()

Expand All @@ -93,7 +102,7 @@ func (n *SwitchNode) action(_ *process.Process, inPck *packet.Packet) ([]*packet
outPcks := make([]*packet.Packet, len(n.conditions))
for i, condition := range n.conditions {
port := n.ports[i]
if ok, err := condition(input); err != nil {
if ok, err := condition(proc.Context(), input); err != nil {
return nil, packet.New(types.NewError(err))
} else if ok {
outPcks[port] = inPck
Expand Down
4 changes: 2 additions & 2 deletions ext/pkg/control/switch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ func TestSwitchNode_SendAndReceive(t *testing.T) {
n := NewSwitchNode()
defer n.Close()

n.Match(func(_ any) (bool, error) { return true, nil }, node.PortWithIndex(node.PortOut, 0))
n.Match(node.PortWithIndex(node.PortOut, 0), func(_ context.Context, _ any) (bool, error) { return true, nil })

in := port.NewOut()
in.Link(n.In(node.PortIn))
Expand Down Expand Up @@ -84,7 +84,7 @@ func BenchmarkSwitchNode_SendAndReceive(b *testing.B) {
n := NewSwitchNode()
defer n.Close()

n.Match(func(_ any) (bool, error) { return true, nil }, node.PortWithIndex(node.PortOut, 0))
n.Match(node.PortWithIndex(node.PortOut, 0), func(_ context.Context, _ any) (bool, error) { return true, nil })

in := port.NewOut()
in.Link(n.In(node.PortIn))
Expand Down
Loading

0 comments on commit 240ea61

Please sign in to comment.