diff --git a/flytekit/remote/remote.py b/flytekit/remote/remote.py index 7e321dd153..60a69187bf 100644 --- a/flytekit/remote/remote.py +++ b/flytekit/remote/remote.py @@ -632,6 +632,93 @@ def list_signals( s = resp.signals return s + def approve(self, signal_id: str, execution_name: str, project: str = None, domain: str = None): + """ + :param signal_id: The name of the signal, this is the key used in the approve() or wait_for_input() call. + :param execution_name: The name of the execution. This is the tail-end of the URL when looking + at the workflow execution. + :param project: The execution project, will default to the Remote's default project. + :param domain: The execution domain, will default to the Remote's default domain. + """ + + wf_exec_id = WorkflowExecutionIdentifier( + project=project or self.default_project, domain=domain or self.default_domain, name=execution_name + ) + + lt = TypeEngine.to_literal_type(bool) + true_literal = TypeEngine.to_literal(self.context, True, bool, lt) + + req = SignalSetRequest( + id=SignalIdentifier(signal_id, wf_exec_id).to_flyte_idl(), value=true_literal.to_flyte_idl() + ) + + # Response is empty currently, nothing to give back to the user. + self.client.set_signal(req) + + def reject(self, signal_id: str, execution_name: str, project: str = None, domain: str = None): + """ + :param signal_id: The name of the signal, this is the key used in the approve() or wait_for_input() call. + :param execution_name: The name of the execution. This is the tail-end of the URL when looking + at the workflow execution. + :param project: The execution project, will default to the Remote's default project. + :param domain: The execution domain, will default to the Remote's default domain. + """ + + wf_exec_id = WorkflowExecutionIdentifier( + project=project or self.default_project, domain=domain or self.default_domain, name=execution_name + ) + + lt = TypeEngine.to_literal_type(bool) + false_literal = TypeEngine.to_literal(self.context, False, bool, lt) + + req = SignalSetRequest( + id=SignalIdentifier(signal_id, wf_exec_id).to_flyte_idl(), value=false_literal.to_flyte_idl() + ) + + # Response is empty currently, nothing to give back to the user. + self.client.set_signal(req) + + def set_input( + self, + signal_id: str, + execution_name: str, + value: typing.Union[literal_models.Literal, typing.Any], + project=None, + domain=None, + python_type=None, + literal_type=None, + ): + """ + :param signal_id: The name of the signal, this is the key used in the approve() or wait_for_input() call. + :param execution_name: The name of the execution. This is the tail-end of the URL when looking + at the workflow execution. + :param value: This is either a Literal or a Python value which FlyteRemote will invoke the TypeEngine to + convert into a Literal. This argument is only value for wait_for_input type signals. + :param project: The execution project, will default to the Remote's default project. + :param domain: The execution domain, will default to the Remote's default domain. + :param python_type: Provide a python type to help with conversion if the value you provided is not a Literal. + :param literal_type: Provide a Flyte literal type to help with conversion if the value you provided + is not a Literal + """ + + wf_exec_id = WorkflowExecutionIdentifier( + project=project or self.default_project, domain=domain or self.default_domain, name=execution_name + ) + if isinstance(value, Literal): + logger.debug(f"Using provided {value} as existing Literal value") + lit = value + else: + lt = literal_type or ( + TypeEngine.to_literal_type(python_type) if python_type else TypeEngine.to_literal_type(type(value)) + ) + lit = TypeEngine.to_literal(self.context, value, python_type or type(value), lt) + logger.debug(f"Converted {value} to literal {lit} using literal type {lt}") + + req = SignalSetRequest(id=SignalIdentifier(signal_id, wf_exec_id).to_flyte_idl(), value=lit.to_flyte_idl()) + + # Response is empty currently, nothing to give back to the user. + self.client.set_signal(req) + def set_signal( self, signal_id: str, diff --git a/tests/flytekit/integration/remote/test_remote.py b/tests/flytekit/integration/remote/test_remote.py index a559ffe09f..82c18b3c50 100644 --- a/tests/flytekit/integration/remote/test_remote.py +++ b/tests/flytekit/integration/remote/test_remote.py @@ -866,3 +866,36 @@ def test_attr_access_sd(): url = urlparse(remote_file_path) bucket, key = url.netloc, url.path.lstrip("/") file_transfer.delete_file(bucket=bucket, key=key) + +def test_signal_approve_reject(register): + from flytekit.models.types import LiteralType, SimpleType + from time import sleep + + remote = FlyteRemote(Config.auto(config_file=CONFIG), PROJECT, DOMAIN) + conditional_wf = remote.fetch_workflow(name="basic.signal_test.signal_test_wf", version=VERSION) + + execution = remote.execute(conditional_wf, inputs={"data": [1.0, 2.0, 3.0, 4.0, 5.0]}) + + def retry_operation(operation): + max_retries = 10 + for _ in range(max_retries): + try: + operation() + break + except Exception: + sleep(1) + + retry_operation(lambda: remote.set_input("title-input", execution.id.name, value="my report", project=PROJECT, domain=DOMAIN, python_type=str, literal_type=LiteralType(simple=SimpleType.STRING))) + retry_operation(lambda: remote.approve("review-passes", execution.id.name, project=PROJECT, domain=DOMAIN)) + + remote.wait(execution=execution, timeout=datetime.timedelta(minutes=5)) + assert execution.outputs["o0"] == {"title": "my report", "data": [1.0, 2.0, 3.0, 4.0, 5.0]} + + with pytest.raises(FlyteAssertion, match="Outputs could not be found because the execution ended in failure"): + execution = remote.execute(conditional_wf, inputs={"data": [1.0, 2.0, 3.0, 4.0, 5.0]}) + + retry_operation(lambda: remote.set_input("title-input", execution.id.name, value="my report", project=PROJECT, domain=DOMAIN, python_type=str, literal_type=LiteralType(simple=SimpleType.STRING))) + retry_operation(lambda: remote.reject("review-passes", execution.id.name, project=PROJECT, domain=DOMAIN)) + + remote.wait(execution=execution, timeout=datetime.timedelta(minutes=5)) + assert execution.outputs["o0"] == {"title": "my report", "data": [1.0, 2.0, 3.0, 4.0, 5.0]} diff --git a/tests/flytekit/integration/remote/workflows/basic/signal_test.py b/tests/flytekit/integration/remote/workflows/basic/signal_test.py new file mode 100644 index 0000000000..c2771ffdfd --- /dev/null +++ b/tests/flytekit/integration/remote/workflows/basic/signal_test.py @@ -0,0 +1,17 @@ +from datetime import timedelta +from flytekit import task, workflow, wait_for_input, approve, conditional +import typing + +@task +def reporting_wf(title_input: str, data: typing.List[float]) -> dict: + return {"title": title_input, "data": data} + +@workflow +def signal_test_wf(data: typing.List[float]) -> dict: + title_input = wait_for_input(name="title-input", timeout=timedelta(hours=1), expected_type=str) + + # Define a "review-passes" approve node so that a human can review + # the title before finalizing it. + approve(upstream_item=title_input, name="review-passes", timeout=timedelta(hours=1)) + + return reporting_wf(title_input, data)