Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Refactor: Moved responsibilities of JobRun to JobBlock
Browse files Browse the repository at this point in the history
  • Loading branch information
knakazawa99 committed Mar 17, 2024
1 parent 32a43c0 commit f3c7d2c
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 84 deletions.
84 changes: 38 additions & 46 deletions prefect_aws/glue_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from typing import Any, Optional

from prefect.blocks.abstract import JobBlock, JobRun
from prefect.utilities.asyncutils import run_sync_in_worker_thread
from pydantic import VERSION as PYDANTIC_VERSION
from pydantic import BaseModel

Expand All @@ -29,11 +28,12 @@ class GlueJobRun(JobRun, BaseModel):
description="The name of the job definition to use.",
)

arguments: Optional[dict] = Field(
default=None,
title="AWS Glue Job Arguments",
description="The job arguments associated with this run.",
job_id: str = Field(
...,
title="AWS Glue Job ID",
description="The ID of the job run.",
)

job_watch_poll_interval: float = Field(
default=60.0,
description=(
Expand All @@ -50,44 +50,14 @@ class GlueJobRun(JobRun, BaseModel):
description="The AWS credentials to use to connect to Glue.",
)

client: Any = Field(default=None, description="")
job_id: str = Field(
default="",
)

async def wait_for_completion(self) -> None:
"""run and wait for completion"""
await run_sync_in_worker_thread(self._get_client)
await run_sync_in_worker_thread(self._start_job)
await run_sync_in_worker_thread(self._watch_job)
client: _GlueJobClient = Field(default=None, description="")

async def fetch_result(self) -> str:
"""fetch glue job result"""
"""fetch glue job state"""
job = self._get_job_run()
return job["JobRun"]["JobRunState"]

def _start_job(self) -> str:
"""
Start the AWS Glue Job
[doc](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/glue/client/start_job_run.html)
"""
self.logger.info(
f"starting job {self.job_name} with arguments {self.arguments}"
)
try:
response = self.client.start_job_run(
JobName=self.job_name,
Arguments=self.arguments,
)
job_run_id = str(response["JobRunId"])
self.logger.info(f"job started with job run id: {job_run_id}")
self.job_id = job_run_id
return job_run_id
except Exception as e:
self.logger.error(f"failed to start job: {e}")
raise RuntimeError

def _watch_job(self) -> None:
def wait_for_completion(self) -> None:
"""
Wait for the job run to complete and get exit code
"""
Expand All @@ -105,13 +75,6 @@ def _watch_job(self) -> None:

time.sleep(self.job_watch_poll_interval)

def _get_client(self) -> None:
"""
Retrieve a Glue Job Client
"""
boto_session = self.aws_credentials.get_boto3_session()
self.client = boto_session.client("glue")

def _get_job_run(self):
"""get glue job"""
return self.client.get_job_run(JobName=self.job_name, RunId=self.job_id)
Expand Down Expand Up @@ -188,8 +151,37 @@ def example_run_glue_job():

async def trigger(self) -> GlueJobRun:
"""trigger for GlueJobRun"""
client = self._get_client()
job_run_id = self._start_job(client)
return GlueJobRun(
job_name=self.job_name,
arguments=self.arguments,
job_id=job_run_id,
job_watch_poll_interval=self.job_watch_poll_interval,
)

def _start_job(self, client: _GlueJobClient) -> str:
"""
Start the AWS Glue Job
[doc](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/glue/client/start_job_run.html)
"""
self.logger.info(
f"starting job {self.job_name} with arguments {self.arguments}"
)
try:
response = client.start_job_run(
JobName=self.job_name,
Arguments=self.arguments,
)
job_run_id = str(response["JobRunId"])
self.logger.info(f"job started with job run id: {job_run_id}")
return job_run_id
except Exception as e:
self.logger.error(f"failed to start job: {e}")
raise RuntimeError

def _get_client(self) -> _GlueJobClient:
"""
Retrieve a Glue Job Client
"""
boto_session = self.aws_credentials.get_boto3_session()
return boto_session.client("glue")
82 changes: 44 additions & 38 deletions tests/test_glue_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,30 +28,7 @@ async def test_fetch_result(aws_credentials, glue_job_client):
assert result == "SUCCEEDED"


def test_start_job(aws_credentials, glue_job_client):
with mock_glue():
glue_job_client.create_job(
Name="test_job_name", Role="test-role", Command={}, DefaultArguments={}
)
glue_job_run = GlueJobRun(
job_name="test_job_name", arguments={"arg1": "value1"}
)
glue_job_run.client = glue_job_client
glue_job_run._start_job()
assert glue_job_run.job_id != ""


def test_start_job_fail_because_not_exist_job(aws_credentials, glue_job_client):
with mock_glue():
glue_job_run = GlueJobRun(
job_name="test_job_name", arguments={"arg1": "value1"}
)
glue_job_run.client = glue_job_client
with pytest.raises(RuntimeError):
glue_job_run._start_job()


def test_watch_job(aws_credentials, glue_job_client):
def test_wait_for_completion(aws_credentials, glue_job_client):
with mock_glue():
glue_job_client.create_job(
Name="test_job_name", Role="test-role", Command={}, DefaultArguments={}
Expand Down Expand Up @@ -84,10 +61,10 @@ def test_watch_job(aws_credentials, glue_job_client):
)
glue_job_run.client = glue_job_client
glue_job_run.job_id = job_run_id
glue_job_run._watch_job()
glue_job_run.wait_for_completion()


def test_watch_job_fail(aws_credentials, glue_job_client):
def test_wait_for_completion_fail(aws_credentials, glue_job_client):
with mock_glue():
glue_job_client.create_job(
Name="test_job_name", Role="test-role", Command={}, DefaultArguments={}
Expand All @@ -113,16 +90,7 @@ def test_watch_job_fail(aws_credentials, glue_job_client):
]
)
with pytest.raises(RuntimeError):
glue_job_run._watch_job()


def test_get_client(aws_credentials):
with mock_glue():
glue_job_run = GlueJobRun(
job_name="test_job_name", aws_credentials=aws_credentials
)
glue_job_run._get_client()
assert hasattr(glue_job_run.client, "get_job_run")
glue_job_run.wait_for_completion()


def test__get_job_run(aws_credentials, glue_job_client):
Expand All @@ -144,7 +112,45 @@ def test__get_job_run(aws_credentials, glue_job_client):
assert response["JobRun"]["JobRunState"] == "SUCCEEDED"


async def test_trigger():
glue_job = GlueJobBlock(job_name="test_job_name", arguments={"arg1": "value1"})
async def test_trigger(aws_credentials, glue_job_client):
glue_job_client.create_job(
Name="test_job_name", Role="test-role", Command={}, DefaultArguments={}
)
glue_job = GlueJobBlock(
job_name="test_job_name",
arguments={"arg1": "value1"},
aws_credential=aws_credentials,
)
glue_job._start_job = MagicMock(side_effect=["test_job_id"])
glue_job_run = await glue_job.trigger()
assert isinstance(glue_job_run, GlueJobRun)


def test_start_job(aws_credentials, glue_job_client):
with mock_glue():
glue_job_client.create_job(
Name="test_job_name", Role="test-role", Command={}, DefaultArguments={}
)
glue_job = GlueJobBlock(job_name="test_job_name", arguments={"arg1": "value1"})

glue_job_client.start_job_run = MagicMock(
side_effect=[{"JobRunId": "test_job_run_id"}]
)
job_run_id = glue_job._start_job(glue_job_client)
assert job_run_id == "test_job_run_id"


def test_start_job_fail_because_not_exist_job(aws_credentials, glue_job_client):
with mock_glue():
glue_job = GlueJobBlock(job_name="test_job_name", arguments={"arg1": "value1"})
with pytest.raises(RuntimeError):
glue_job._start_job(glue_job_client)


def test_get_client(aws_credentials):
with mock_glue():
glue_job_run = GlueJobBlock(
job_name="test_job_name", aws_credentials=aws_credentials
)
client = glue_job_run._get_client()
assert hasattr(client, "get_job_run")

0 comments on commit f3c7d2c

Please sign in to comment.