This repository has been archived by the owner on Apr 26, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 40
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Showing
3 changed files
with
301 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
--- | ||
description: Worker integration with the AWS Glue Job | ||
notes: This documentation page is generated from source file docstrings. | ||
--- | ||
|
||
::: prefect_aws.workers.glue_job_worker |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,156 @@ | ||
import logging | ||
import time | ||
from typing import Any, Optional | ||
|
||
import anyio | ||
from prefect.server.schemas.core import FlowRun | ||
from prefect.utilities.asyncutils import run_sync_in_worker_thread | ||
from prefect.workers.base import ( | ||
BaseJobConfiguration, | ||
BaseVariables, | ||
BaseWorker, | ||
BaseWorkerResult, | ||
) | ||
from pydantic import Field | ||
|
||
from prefect_aws import AwsCredentials | ||
|
||
_GlueJobClient = Any | ||
|
||
|
||
class GlueJobWorkerConfiguration(BaseJobConfiguration): | ||
""" | ||
Job configuration for a Glue Job. | ||
""" | ||
|
||
job_name: str = Field( | ||
..., | ||
title="AWS Glue Job Name", | ||
description="The name of the job definition to use.", | ||
) | ||
arguments: Optional[dict] = Field( | ||
default=None, | ||
title="AWS Glue Job Arguments", | ||
description="The job arguments associated with this run.", | ||
) | ||
job_watch_poll_interval: float = Field( | ||
default=60.0, | ||
description=( | ||
"The amount of time to wait between AWS API calls while monitoring the " | ||
"state of an Glue Job." | ||
), | ||
) | ||
aws_credentials: Optional[AwsCredentials] = Field(default_factory=AwsCredentials) | ||
error_states = ["FAILED", "STOPPED", "ERROR", "TIMEOUT"] | ||
|
||
|
||
class GlueJobWorkerResult(BaseWorkerResult): | ||
""" | ||
The result of Glue job. | ||
""" | ||
|
||
|
||
class GlueJobWorker(BaseWorker): | ||
type = "glue-job" | ||
job_configuration = GlueJobWorkerConfiguration | ||
job_configuration_variables = BaseVariables | ||
_description = "Execute flow runs Glue Job." | ||
_display_name = "AWS Glue Job" | ||
_documentation_url = "https://prefecthq.github.io/prefect-aws/glue_job/" | ||
_logo_url = "https://images.ctfassets.net/gm98wzqotmnx/1jbV4lceHOjGgunX15lUwT/db88e184d727f721575aeb054a37e277/aws.png?h=250" # noqa | ||
|
||
async def run( | ||
self, | ||
flow_run: FlowRun, | ||
configuration: GlueJobWorkerConfiguration, | ||
task_status: Optional[anyio.abc.TaskStatus] = None, | ||
) -> GlueJobWorkerResult: | ||
"""Run the Glue Job.""" | ||
glue_job_client = await run_sync_in_worker_thread( | ||
self._get_client, configuration | ||
) | ||
return await run_sync_in_worker_thread( | ||
self.run_with_client, glue_job_client, configuration | ||
) | ||
|
||
async def run_with_client( | ||
self, | ||
flow_run: FlowRun, | ||
glue_job_client: _GlueJobClient, | ||
configuration: GlueJobWorkerConfiguration, | ||
) -> GlueJobWorkerResult: | ||
"""Run the Glue Job with Glue Client.""" | ||
logger = self.get_flow_run_logger(flow_run) | ||
run_job_id = await run_sync_in_worker_thread( | ||
self._start_job, logger, glue_job_client, configuration | ||
) | ||
exit_code = await run_sync_in_worker_thread( | ||
self._watch_job_and_get_exit_code, | ||
logger, | ||
glue_job_client, | ||
run_job_id, | ||
configuration, | ||
) | ||
return GlueJobWorkerResult(identifier=run_job_id, status_code=exit_code) | ||
|
||
@staticmethod | ||
def _get_client(configuration: GlueJobWorkerConfiguration) -> _GlueJobClient: | ||
""" | ||
Retrieve a Glue Job Client | ||
""" | ||
boto_session = configuration.aws_credentials.get_boto3_session() | ||
return boto_session.client("glue") | ||
|
||
@staticmethod | ||
def _start_job( | ||
logger: logging.Logger, | ||
glue_job_client: _GlueJobClient, | ||
configuration: GlueJobWorkerConfiguration, | ||
) -> str: | ||
""" | ||
Start the AWS Glue Job | ||
[doc](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/glue/client/start_job_run.html) | ||
""" | ||
logger.info( | ||
f"starting job {configuration.job_name} with arguments" | ||
f" {configuration.arguments}" | ||
) | ||
try: | ||
response = glue_job_client.start_job_run( | ||
JobName=configuration.job_name, | ||
Arguments=configuration.arguments, | ||
) | ||
job_run_id = str(response["JobRunId"]) | ||
logger.info(f"job started with job run id: {job_run_id}") | ||
return job_run_id | ||
except Exception as e: | ||
logger.error(f"failed to start job: {e}") | ||
raise RuntimeError | ||
|
||
@staticmethod | ||
def _watch_job_and_get_exit_code( | ||
logger: logging.Logger, | ||
glue_job_client: _GlueJobClient, | ||
job_run_id: str, | ||
configuration: GlueJobWorkerConfiguration, | ||
) -> Optional[int]: | ||
""" | ||
Wait for the job run to complete and get exit code | ||
""" | ||
logger.info(f"watching job {configuration.job_name} with run id {job_run_id}") | ||
exit_code = 0 | ||
while True: | ||
job = glue_job_client.get_job_run( | ||
JobName=configuration.job_name, RunId=job_run_id | ||
) | ||
job_state = job["JobRun"]["JobRunState"] | ||
if job_state in configuration.error_states: | ||
# Generate a dynamic exception type from the AWS name | ||
logger.error(f"job failed: {job['JobRun']['ErrorMessage']}") | ||
raise RuntimeError(job["JobRun"]["ErrorMessage"]) | ||
elif job_state == "SUCCEEDED": | ||
logger.info(f"job succeeded: {job_run_id}") | ||
break | ||
|
||
time.sleep(configuration.job_watch_poll_interval) | ||
return exit_code |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,139 @@ | ||
from unittest.mock import MagicMock | ||
from uuid import uuid4 | ||
|
||
import boto3 | ||
import pytest | ||
from moto import mock_glue | ||
from prefect.server.schemas.core import FlowRun | ||
|
||
from prefect_aws.workers.glue_job_worker import ( | ||
GlueJobWorker, | ||
GlueJobWorkerConfiguration, | ||
GlueJobWorkerResult, | ||
) | ||
|
||
|
||
@pytest.fixture(scope="function") | ||
def glue_job_client(aws_credentials): | ||
with mock_glue(): | ||
yield boto3.client("glue", region_name="us-east-1") | ||
|
||
|
||
@pytest.fixture | ||
def flow_run(): | ||
return FlowRun(flow_id=uuid4(), deployment_id=uuid4()) | ||
|
||
|
||
def test_get_client(aws_credentials): | ||
with mock_glue(): | ||
job_worker_configuration = GlueJobWorkerConfiguration( | ||
job_name="test_glue_job_name" | ||
) | ||
glue_job_worker = GlueJobWorker(work_pool_name="test") | ||
glue_client = glue_job_worker._get_client(job_worker_configuration) | ||
assert hasattr(glue_client, "get_job_run") | ||
|
||
|
||
async def test_start_job(aws_credentials, glue_job_client, flow_run): | ||
with mock_glue(): | ||
glue_job_client.create_job( | ||
Name="test_job_name", Role="test-role", Command={}, DefaultArguments={} | ||
) | ||
|
||
job_worker_configuration = GlueJobWorkerConfiguration( | ||
job_name="test_job_name", arguments={} | ||
) | ||
async with GlueJobWorker(work_pool_name="test") as worker: | ||
logger = worker.get_flow_run_logger(flow_run) | ||
res_job_id = worker._start_job( | ||
logger, glue_job_client, job_worker_configuration | ||
) | ||
assert res_job_id == "01" | ||
|
||
|
||
async def test_start_job_fail_because_not_exist_job( | ||
aws_credentials, glue_job_client, flow_run | ||
): | ||
with mock_glue(): | ||
job_worker_configuration = GlueJobWorkerConfiguration( | ||
job_name="test_job_name", arguments={} | ||
) | ||
async with GlueJobWorker(work_pool_name="test") as worker: | ||
logger = worker.get_flow_run_logger(flow_run) | ||
with pytest.raises(RuntimeError): | ||
worker._start_job(logger, glue_job_client, job_worker_configuration) | ||
|
||
|
||
async def test_watch_job_and_get_exit_code(aws_credentials, glue_job_client, flow_run): | ||
with mock_glue(): | ||
glue_job_client.create_job( | ||
Name="test_job_name", Role="test-role", Command={}, DefaultArguments={} | ||
) | ||
job_run_id = glue_job_client.start_job_run( | ||
JobName="test_job_name", | ||
Arguments={}, | ||
)["JobRunId"] | ||
|
||
job_worker_configuration = GlueJobWorkerConfiguration( | ||
job_name="test_job_name", arguments={}, job_watch_poll_interval=1.0 | ||
) | ||
async with GlueJobWorker(work_pool_name="test") as worker: | ||
glue_job_client.get_job_run = MagicMock( | ||
side_effect=[ | ||
{"JobRun": {"JobName": "test_job", "JobRunState": "RUNNING"}}, | ||
{"JobRun": {"JobName": "test_job", "JobRunState": "SUCCEEDED"}}, | ||
] | ||
) | ||
logger = worker.get_flow_run_logger(flow_run) | ||
exist_code = worker._watch_job_and_get_exit_code( | ||
logger, glue_job_client, job_run_id, job_worker_configuration | ||
) | ||
assert exist_code == 0 | ||
|
||
|
||
async def test_watch_job_and_get_exit_fail(aws_credentials, glue_job_client, flow_run): | ||
with mock_glue(): | ||
glue_job_client.create_job( | ||
Name="test_job_name", Role="test-role", Command={}, DefaultArguments={} | ||
) | ||
job_run_id = glue_job_client.start_job_run( | ||
JobName="test_job_name", | ||
Arguments={}, | ||
)["JobRunId"] | ||
|
||
job_worker_configuration = GlueJobWorkerConfiguration( | ||
job_name="test_job_name", arguments={}, job_watch_poll_interval=1.0 | ||
) | ||
async with GlueJobWorker(work_pool_name="test") as worker: | ||
glue_job_client.get_job_run = MagicMock( | ||
side_effect=[ | ||
{ | ||
"JobRun": { | ||
"JobName": "test_job_name", | ||
"JobRunState": "FAILED", | ||
"ErrorMessage": "err", | ||
} | ||
}, | ||
] | ||
) | ||
logger = worker.get_flow_run_logger(flow_run) | ||
with pytest.raises(RuntimeError): | ||
worker._watch_job_and_get_exit_code( | ||
logger, glue_job_client, job_run_id, job_worker_configuration | ||
) | ||
|
||
|
||
async def test_run_with_client(aws_credentials, glue_job_client, flow_run): | ||
with mock_glue(): | ||
async with GlueJobWorker(work_pool_name="test") as worker: | ||
glue_job_client.create_job( | ||
Name="test_job_name1", Role="test-role", Command={}, DefaultArguments={} | ||
) | ||
job_worker_configuration = GlueJobWorkerConfiguration( | ||
job_name="test_job_name1", arguments={}, job_watch_poll_interval=1.0 | ||
) | ||
res = await worker.run_with_client( | ||
flow_run, glue_job_client, job_worker_configuration | ||
) | ||
|
||
assert res == GlueJobWorkerResult(identifier="01", status_code=0) |