diff --git a/flytepropeller/pkg/controller/nodes/branch/handler.go b/flytepropeller/pkg/controller/nodes/branch/handler.go index 9789b65c22..a869569680 100644 --- a/flytepropeller/pkg/controller/nodes/branch/handler.go +++ b/flytepropeller/pkg/controller/nodes/branch/handler.go @@ -3,6 +3,7 @@ package branch import ( "context" "fmt" + "strconv" "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flyte/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" @@ -15,6 +16,7 @@ import ( stdErrors "github.com/flyteorg/flyte/flytestdlib/errors" "github.com/flyteorg/flyte/flytestdlib/logger" "github.com/flyteorg/flyte/flytestdlib/promutils" + "github.com/flyteorg/flyte/flytestdlib/storage" ) type metrics struct { @@ -74,8 +76,7 @@ func (b *branchHandler) HandleBranchNode(ctx context.Context, branchNode v1alpha childNodeStatus.SetParentNodeID(&i) logger.Debugf(ctx, "Recursively executing branchNode's chosen path") - nodeStatus := nl.GetNodeExecutionStatus(ctx, nCtx.NodeID()) - return b.recurseDownstream(ctx, nCtx, nodeStatus, finalNode) + return b.recurseDownstream(ctx, nCtx, finalNode) } // If the branchNodestatus was already evaluated i.e, Node is in Running status @@ -99,8 +100,7 @@ func (b *branchHandler) HandleBranchNode(ctx context.Context, branchNode v1alpha } // Recurse downstream - nodeStatus := nl.GetNodeExecutionStatus(ctx, nCtx.NodeID()) - return b.recurseDownstream(ctx, nCtx, nodeStatus, branchTakenNode) + return b.recurseDownstream(ctx, nCtx, branchTakenNode) } func (b *branchHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecutionContext) (handler.Transition, error) { @@ -123,7 +123,7 @@ func (b *branchHandler) getExecutionContextForDownstream(nCtx interfaces.NodeExe return executors.NewExecutionContextWithParentInfo(nCtx.ExecutionContext(), newParentInfo), nil } -func (b *branchHandler) recurseDownstream(ctx context.Context, nCtx interfaces.NodeExecutionContext, nodeStatus v1alpha1.ExecutableNodeStatus, branchTakenNode v1alpha1.ExecutableNode) (handler.Transition, error) { +func (b *branchHandler) recurseDownstream(ctx context.Context, nCtx interfaces.NodeExecutionContext, branchTakenNode v1alpha1.ExecutableNode) (handler.Transition, error) { // TODO we should replace the call to RecursiveNodeHandler with a call to SingleNode Handler. The inputs are also already known ahead of time // There is no DAGStructure for the branch nodes, the branch taken node is the leaf node. The node itself may be arbitrarily complex, but in that case the node should reference a subworkflow etc // The parent of the BranchTaken Node is the actual Branch Node and all the data is just forwarded from the Branch to the executed node. @@ -134,8 +134,16 @@ func (b *branchHandler) recurseDownstream(ctx context.Context, nCtx interfaces.N } childNodeStatus := nl.GetNodeExecutionStatus(ctx, branchTakenNode.GetID()) - childNodeStatus.SetDataDir(nodeStatus.GetDataDir()) - childNodeStatus.SetOutputDir(nodeStatus.GetOutputDir()) + childDataDir, err := nCtx.DataStore().ConstructReference(ctx, nCtx.NodeStatus().GetOutputDir(), branchTakenNode.GetID()) + if err != nil { + return handler.UnknownTransition, err + } + childOutputDir, err := nCtx.DataStore().ConstructReference(ctx, childDataDir, strconv.Itoa(int(childNodeStatus.GetAttempts()))) + if err != nil { + return handler.UnknownTransition, err + } + childNodeStatus.SetDataDir(childDataDir) + childNodeStatus.SetOutputDir(childOutputDir) upstreamNodeIds, err := nCtx.ContextualNodeLookup().ToNode(branchTakenNode.GetID()) if err != nil { return handler.UnknownTransition, err @@ -151,9 +159,14 @@ func (b *branchHandler) recurseDownstream(ctx context.Context, nCtx interfaces.N } if downstreamStatus.IsComplete() { - // For branch node we set the output node to be the same as the child nodes output + childOutputsPath := v1alpha1.GetOutputsFile(childOutputDir) + outputsPath := v1alpha1.GetOutputsFile(nCtx.NodeStatus().GetOutputDir()) + if err := nCtx.DataStore().CopyRaw(ctx, childOutputsPath, outputsPath, storage.Options{}); err != nil { + errMsg := fmt.Sprintf("Failed to copy child node outputs from [%v] to [%v]", childOutputsPath, outputsPath) + return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoFailure(core.ExecutionError_SYSTEM, errors.OutputsNotFoundError, errMsg, nil)), nil + } phase := handler.PhaseInfoSuccess(&handler.ExecutionInfo{ - OutputInfo: &handler.OutputInfo{OutputURI: v1alpha1.GetOutputsFile(childNodeStatus.GetOutputDir())}, + OutputInfo: &handler.OutputInfo{OutputURI: outputsPath}, }) return handler.DoTransition(handler.TransitionTypeEphemeral, phase), nil diff --git a/flytepropeller/pkg/controller/nodes/branch/handler_test.go b/flytepropeller/pkg/controller/nodes/branch/handler_test.go index a48344020d..96d20b1710 100644 --- a/flytepropeller/pkg/controller/nodes/branch/handler_test.go +++ b/flytepropeller/pkg/controller/nodes/branch/handler_test.go @@ -1,6 +1,7 @@ package branch import ( + "bytes" "context" "fmt" "testing" @@ -110,6 +111,7 @@ func createNodeContext(phase v1alpha1.BranchNodePhase, childNodeID *v1alpha1.Nod ns := &mocks2.ExecutableNodeStatus{} ns.OnGetDataDir().Return(storage.DataReference("data-dir")) + ns.OnGetOutputDir().Return(storage.DataReference("output-dir")) ns.OnGetPhase().Return(v1alpha1.NodePhaseNotYetStarted) ir := &mocks3.InputReader{} @@ -162,7 +164,6 @@ func TestBranchHandler_RecurseDownstream(t *testing.T) { name string ns interfaces.NodeStatus err error - nodeStatus *mocks2.ExecutableNodeStatus branchTakenNode v1alpha1.ExecutableNode isErr bool expectedPhase handler.EPhase @@ -170,17 +171,17 @@ func TestBranchHandler_RecurseDownstream(t *testing.T) { upstreamNodeID string }{ {"upstreamNodeExists", interfaces.NodeStatusPending, nil, - &mocks2.ExecutableNodeStatus{}, bn, false, handler.EPhaseRunning, v1alpha1.NodePhaseQueued, "n2"}, + bn, false, handler.EPhaseRunning, v1alpha1.NodePhaseQueued, "n2"}, {"childNodeError", interfaces.NodeStatusUndefined, fmt.Errorf("err"), - &mocks2.ExecutableNodeStatus{}, bn, true, handler.EPhaseUndefined, v1alpha1.NodePhaseFailed, ""}, + bn, true, handler.EPhaseUndefined, v1alpha1.NodePhaseFailed, ""}, {"childPending", interfaces.NodeStatusPending, nil, - &mocks2.ExecutableNodeStatus{}, bn, false, handler.EPhaseRunning, v1alpha1.NodePhaseQueued, ""}, + bn, false, handler.EPhaseRunning, v1alpha1.NodePhaseQueued, ""}, {"childStillRunning", interfaces.NodeStatusRunning, nil, - &mocks2.ExecutableNodeStatus{}, bn, false, handler.EPhaseRunning, v1alpha1.NodePhaseRunning, ""}, + bn, false, handler.EPhaseRunning, v1alpha1.NodePhaseRunning, ""}, {"childFailure", interfaces.NodeStatusFailed(expectedError), nil, - &mocks2.ExecutableNodeStatus{}, bn, false, handler.EPhaseFailed, v1alpha1.NodePhaseFailed, ""}, + bn, false, handler.EPhaseFailed, v1alpha1.NodePhaseFailed, ""}, {"childComplete", interfaces.NodeStatusComplete, nil, - &mocks2.ExecutableNodeStatus{}, bn, false, handler.EPhaseSuccess, v1alpha1.NodePhaseSucceeded, ""}, + bn, false, handler.EPhaseSuccess, v1alpha1.NodePhaseSucceeded, ""}, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { @@ -222,16 +223,16 @@ func TestBranchHandler_RecurseDownstream(t *testing.T) { ).Return(test.ns, test.err) childNodeStatus := &mocks2.ExecutableNodeStatus{} - if mockNodeLookup != nil { - childNodeStatus.OnGetOutputDir().Return("parent-output-dir") - test.nodeStatus.OnGetDataDir().Return("parent-data-dir") - test.nodeStatus.OnGetOutputDir().Return("parent-output-dir") - mockNodeLookup.OnGetNodeExecutionStatus(ctx, childNodeID).Return(childNodeStatus) - childNodeStatus.On("SetDataDir", storage.DataReference("parent-data-dir")).Once() - childNodeStatus.On("SetOutputDir", storage.DataReference("parent-output-dir")).Once() + childNodeStatus.OnGetAttempts().Return(0) + childNodeStatus.On("SetDataDir", storage.DataReference("/output-dir/child")).Once() + childNodeStatus.On("SetOutputDir", storage.DataReference("/output-dir/child/0")).Once() + mockNodeLookup.OnGetNodeExecutionStatus(ctx, childNodeID).Return(childNodeStatus) + if test.childPhase == v1alpha1.NodePhaseSucceeded { + _ = nCtx.DataStore().WriteRaw(ctx, storage.DataReference("/output-dir/child/0/outputs.pb"), 0, storage.Options{}, bytes.NewReader([]byte{})) } + branch := New(mockNodeExecutor, eventConfig, promutils.NewTestScope()).(*branchHandler) - h, err := branch.recurseDownstream(ctx, nCtx, test.nodeStatus, test.branchTakenNode) + h, err := branch.recurseDownstream(ctx, nCtx, test.branchTakenNode) if test.isErr { assert.Error(t, err) } else {