Skip to content

Commit

Permalink
add Catch support
Browse files Browse the repository at this point in the history
  • Loading branch information
Ben-Wormald committed Sep 18, 2024
1 parent d56cd09 commit 74d4802
Show file tree
Hide file tree
Showing 5 changed files with 242 additions and 52 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,6 @@ This should be an object which can be used to override the default configuration
This library is currently a work-in-progress and does not support every feature of Step Functions.
Some functionality yet to be implemented:

* `Retry` and `Catch` fields
* `Retry` fields
* Some AWS resources in `Task` steps
* Some runtime error handling and data validation
2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "sfn-sim",
"version": "0.7.0",
"version": "0.8.0",
"description": "AWS Step Functions simulator for unit testing state machines",
"keywords": [
"aws",
Expand Down
10 changes: 10 additions & 0 deletions src/errors.js
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,13 @@ class RuntimeError extends Error {
super(message);
this.name = 'RuntimeError';
}

toErrorOutput() {
return {
Error: this.name,
Cause: this.message,
};
}
}

class FailError extends RuntimeError {
Expand Down Expand Up @@ -96,6 +103,8 @@ class ResultWriterFailedError extends RuntimeError {
}
}

const ERROR_WILDCARD = 'States.ALL';

export {
ValidationError,
SimulatorError,
Expand All @@ -104,4 +113,5 @@ export {
TaskFailedError,
NoChoiceMatchedError,
IntrinsicFailureError,
ERROR_WILDCARD,
};
140 changes: 91 additions & 49 deletions src/index.js
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { StateLint } from '@wmfs/statelint';
import { v4 as uuidV4 } from 'uuid';
import runChoice from './choice.js';
import { ValidationError, RuntimeError, FailError } from './errors.js';
import { ValidationError, RuntimeError, FailError, ERROR_WILDCARD } from './errors.js';
import { defaultOptions } from './options.js';
import runTask from './task.js';
import { getValue, applyPayloadTemplate, getStateResult } from './utils.js';
Expand All @@ -17,7 +17,7 @@ const load = (definition, resources = [], overrideOptions = {}) => {
if (options.validateDefinition) {
const stateLint = new StateLint();
const problems = stateLint.validate(definition);

if (problems.length) {
const message = problems.join('\n');
throw new ValidationError(message);
Expand Down Expand Up @@ -53,7 +53,7 @@ const load = (definition, resources = [], overrideOptions = {}) => {
const execute = async (definition, data) => {
let rawInput = data.context.Execution.Input || {};

while (true) {
main: while (true) {
data.context.State.EnteredTime = new Date().toISOString();
const state = definition.States[data.context.State.Name];

Expand Down Expand Up @@ -82,65 +82,93 @@ const execute = async (definition, data) => {
if (state.Type === 'Parallel') {
const effectiveInput = applyPayloadTemplate(stateInput, data, state.Parameters);

const branches = state.Branches.map((branch) => {
const branchData = {
...data,
context: {
...data.context,
Execution: {
...data.context.Execution,
Input: effectiveInput,
},
State: {
...data.context.State,
Name: branch.StartAt,
try {
const branches = state.Branches.map((branch) => {
const branchData = {
...data,
context: {
...data.context,
Execution: {
...data.context.Execution,
Input: effectiveInput,
},
State: {
...data.context.State,
Name: branch.StartAt,
},
},
},
};
};

return execute(branch, branchData);
});
const result = await Promise.all(branches);

return execute(branch, branchData);
});
const result = await Promise.all(branches);
const effectiveResult = applyPayloadTemplate(result, data, state.ResultSelector);

const effectiveResult = applyPayloadTemplate(result, data, state.ResultSelector);
stateResult = getStateResult(rawInput, effectiveResult, state.ResultPath);
} catch (error) {
for (const catcher of state?.Catch || []) {
if (catcher.ErrorEquals.includes(error.name) || catcher.ErrorEquals.includes(ERROR_WILDCARD)) {
rawInput = getStateResult(rawInput, error.toErrorOutput(), catcher.ResultPath);

stateResult = getStateResult(rawInput, effectiveResult, state.ResultPath);
data.context.State.Name = catcher.Next;

continue main;
}
}

throw error;
}
}

if (state.Type === 'Map') {
const effectiveInput = applyPayloadTemplate(stateInput, data, state.Parameters);

const items = getValue(effectiveInput, state.ItemsPath);

const executions = items.map((Value, Index) => {
const itemData = {
...data,
context: {
...data.context,
Execution: {
...data.context.Execution,
Input: Value,
},
State: {
...data.context.State,
Name: state.ItemProcessor.StartAt,
},
Map: {
Item: {
Index,
Value,
try {
const executions = items.map((Value, Index) => {
const itemData = {
...data,
context: {
...data.context,
Execution: {
...data.context.Execution,
Input: Value,
},
State: {
...data.context.State,
Name: state.ItemProcessor.StartAt,
},
Map: {
Item: {
Index,
Value,
},
},
},
},
};
};

return execute(state.ItemProcessor, itemData);
});
const result = await Promise.all(executions);

const effectiveResult = applyPayloadTemplate(result, data, state.ResultSelector);

stateResult = getStateResult(rawInput, effectiveResult, state.ResultPath);
} catch (error) {
for (const catcher of state?.Catch || []) {
if (catcher.ErrorEquals.includes(error.name) || catcher.ErrorEquals.includes(ERROR_WILDCARD)) {
rawInput = getStateResult(rawInput, error.toErrorOutput(), catcher.ResultPath);

return execute(state.ItemProcessor, itemData);
});
const result = await Promise.all(executions);
data.context.State.Name = catcher.Next;

const effectiveResult = applyPayloadTemplate(result, data, state.ResultSelector);
continue main;
}
}

stateResult = getStateResult(rawInput, effectiveResult, state.ResultPath);
throw error;
}
}

if (state.Type === 'Wait') {
Expand All @@ -161,11 +189,25 @@ const execute = async (definition, data) => {
if (state.Type === 'Task') {
const effectiveInput = applyPayloadTemplate(stateInput, data, state.Parameters);

const result = await runTask(state, data, effectiveInput);
try {
const result = await runTask(state, data, effectiveInput);

const effectiveResult = applyPayloadTemplate(result, data, state.ResultSelector);

stateResult = getStateResult(rawInput, effectiveResult, state.ResultPath);
} catch (error) {
for (const catcher of state?.Catch || []) {
if (catcher.ErrorEquals.includes(error.name) || catcher.ErrorEquals.includes(ERROR_WILDCARD)) {
rawInput = getStateResult(rawInput, error.toErrorOutput(), catcher.ResultPath);

const effectiveResult = applyPayloadTemplate(result, data, state.ResultSelector);
data.context.State.Name = catcher.Next;

stateResult = getStateResult(rawInput, effectiveResult, state.ResultPath);
continue main;
}
}

throw error;
}
}

if (state.Type === 'Pass') {
Expand Down
140 changes: 139 additions & 1 deletion tests/index.test.js
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { vi, describe, test, expect } from 'vitest';
import { FailError, ValidationError } from '../src/errors.js';
import { FailError, TaskFailedError, ValidationError } from '../src/errors.js';
import { load } from '../src/index.js';

describe('Pass', () => {
Expand Down Expand Up @@ -230,6 +230,144 @@ test('executes a Wait step', async () => {
expect(result).toEqual({ someString: 'hello' });
});

describe('Catch', () => {
test('catches a matching error', async () => {
const definition = {
StartAt: 'TaskStep',
States: {
TaskStep: {
Type: 'Task',
Resource: 'arn:aws:lambda:::function:my-function',
Catch: [
{
ErrorEquals: [
'States.SomeOtherError',
'States.TaskFailed',
],
ResultPath: '$.error',
Next: 'CaughtStep',
},
],
End: true,
},
CaughtStep: {
Type: 'Succeed',
},
},
};

const resources = [
{
service: 'lambda',
name: 'my-function',
function: () => {
throw new Error('Oh no!');
},
},
];

const stateMachine = load(definition, resources);
const result = await stateMachine.execute({ someKey: 'someValue' });

expect(result).toEqual({
someKey: 'someValue',
error: {
Error: 'States.TaskFailed',
Cause: 'Error: Oh no!'
},
});
});

test('catches any error with a wildcard', async () => {
const definition = {
StartAt: 'TaskStep',
States: {
TaskStep: {
Type: 'Task',
Resource: 'arn:aws:lambda:::function:my-function',
Catch: [
{
ErrorEquals: [
'States.SomeOtherError',
],
Next: 'CaughtOtherStep',
},
{
ErrorEquals: [
'States.ALL',
],
Next: 'CaughtStep',
},
],
End: true,
},
CaughtOtherStep: {
Type: 'Fail',
},
CaughtStep: {
Type: 'Succeed',
},
},
};

const resources = [
{
service: 'lambda',
name: 'my-function',
function: () => {
throw new Error('Oh no!');
},
},
];

const stateMachine = load(definition, resources);
const result = await stateMachine.execute({ someKey: 'someValue' });

expect(result).toEqual({
Error: 'States.TaskFailed',
Cause: 'Error: Oh no!'
});
});

test('throws again if no catchers match the error', async () => {
const definition = {
StartAt: 'TaskStep',
States: {
TaskStep: {
Type: 'Task',
Resource: 'arn:aws:lambda:::function:my-function',
Catch: [
{
ErrorEquals: [
'States.SomeOtherError',
],
Next: 'CaughtStep',
},
],
End: true,
},
CaughtStep: {
Type: 'Succeed',
},
},
};

const resources = [
{
service: 'lambda',
name: 'my-function',
function: () => {
throw new Error('Oh no!');
},
},
];

const stateMachine = load(definition, resources);

expect(() => stateMachine.execute({ someKey: 'someValue' })).rejects.toThrowError(TaskFailedError);
});
});

test('throws a ValidationError for an invalid definition', () => {
const invalidDefinition = {
StartAt: 'NonexistentState',
Expand Down

0 comments on commit 74d4802

Please sign in to comment.