Skip to content

Commit

Permalink
Fixing linting / CI errors (#202)
Browse files Browse the repository at this point in the history
  • Loading branch information
Martynas Asipauskas authored and GitHub Enterprise committed Aug 21, 2024
1 parent ebe64c5 commit eb02b82
Show file tree
Hide file tree
Showing 7 changed files with 27 additions and 24 deletions.
12 changes: 7 additions & 5 deletions docs/python_airflow_operator.md
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,8 @@ Args:


#### template_fields(_: Sequence[str_ _ = ('job_request', 'job_set_prefix'_ )

#### template_fields_renderers(_: Dict[str, str_ _ = {'job_request': 'py'_ )
Initializes a new ArmadaOperator.


Expand Down Expand Up @@ -225,7 +227,7 @@ Bases: `object`



### _class_ armada.model.RunningJobContext(armada_queue: 'str', job_id: 'str', job_set_id: 'str', cluster: 'Optional[str]' = None, start_time: 'DateTime' = DateTime(2024, 8, 21, 14, 36, 0, 499682, tzinfo=Timezone('UTC')), last_log_time: 'Optional[DateTime]' = None, job_state: 'str' = 'UNKNOWN')
### _class_ armada.model.RunningJobContext(armada_queue: 'str', job_id: 'str', job_set_id: 'str', submit_time: 'DateTime', cluster: 'Optional[str]' = None, last_log_time: 'Optional[DateTime]' = None, job_state: 'str' = 'UNKNOWN')
Bases: `object`


Expand All @@ -241,10 +243,10 @@ Bases: `object`
* **job_set_id** (*str*) –


* **cluster** (*str** | **None*) –
* **submit_time** (*DateTime*) –


* **start_time** (*DateTime*) –
* **cluster** (*str** | **None*) –


* **last_log_time** (*DateTime** | **None*) –
Expand All @@ -266,6 +268,6 @@ Bases: `object`

#### last_log_time(_: DateTime | Non_ _ = Non_ )

#### start_time(_: DateTim_ _ = DateTime(2024, 8, 21, 14, 36, 0, 499682, tzinfo=Timezone('UTC')_ )

#### _property_ state(_: JobStat_ )

#### submit_time(_: DateTim_ )
9 changes: 5 additions & 4 deletions third_party/airflow/armada/hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from armada_client.armada.submit_pb2 import JobSubmitRequestItem
from armada_client.client import ArmadaClient
from armada_client.typings import JobState
from pendulum import DateTime

from .model import RunningJobContext

Expand All @@ -33,12 +34,12 @@ def client_for(args: GrpcChannelArgs) -> ArmadaClient:
with ArmadaClientFactory.CLIENTS_LOCK:
if channel_args_key not in ArmadaClientFactory.CLIENTS:
ArmadaClientFactory.CLIENTS[channel_args_key] = ArmadaClient(
channel=ArmadaClientFactory.create_channel(args)
channel=ArmadaClientFactory._create_channel(args)
)
return ArmadaClientFactory.CLIENTS[channel_args_key]

@staticmethod
def create_channel(args: GrpcChannelArgs) -> grpc.Channel:
def _create_channel(args: GrpcChannelArgs) -> grpc.Channel:
if args.auth is None:
return grpc.insecure_channel(
target=args.target, options=args.options, compression=args.compression
Expand Down Expand Up @@ -97,9 +98,9 @@ def submit_job(
if job.error:
raise AirflowException(f"Error submitting job to Armada: {job.error}")

return RunningJobContext(queue, job.job_id, job_set_id, None)
return RunningJobContext(queue, job.job_id, job_set_id, DateTime.utcnow())

def update_context(
def refresh_context(
self, job_context: RunningJobContext, tracking_url: str
) -> RunningJobContext:
response = self.client.get_job_status([job_context.job_id])
Expand Down
3 changes: 1 addition & 2 deletions third_party/airflow/armada/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ def __init__(
options: Optional[Sequence[Tuple[str, Any]]] = None,
compression: Optional[grpc.Compression] = None,
auth: Optional[grpc.AuthMetadataPlugin] = None,
# auth_details: Optional[Dict[str, Any]] = None,
):
self.target = target
self.options = options
Expand Down Expand Up @@ -59,8 +58,8 @@ class RunningJobContext:
armada_queue: str
job_id: str
job_set_id: str
submit_time: DateTime
cluster: Optional[str] = None
start_time: DateTime = DateTime.utcnow()
last_log_time: Optional[DateTime] = None
job_state: str = JobState.UNKNOWN.name

Expand Down
5 changes: 3 additions & 2 deletions third_party/airflow/armada/operators/armada.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,7 @@ def _reattach_or_submit_job(
armada_queue=existing_run["armada_queue"],
job_id=existing_run["armada_job_id"],
job_set_id=existing_run["armada_job_set_id"],
submit_time=DateTime.utcnow(),
)

# We haven't got a running job, submit a new one and persist state to xcom.
Expand Down Expand Up @@ -276,15 +277,15 @@ def _running_job_terminated(self, context: RunningJobContext):
def _not_acknowledged_within_timeout(self) -> bool:
if self.job_context.state == JobState.UNKNOWN:
if (
DateTime.utcnow().diff(self.job_context.start_time).in_seconds()
DateTime.utcnow().diff(self.job_context.submit_time).in_seconds()
> self.job_acknowledgement_timeout
):
return True
return False

@log_exceptions
def _check_job_status_and_fetch_logs(self) -> None:
self.job_context = self.hook.update_context(
self.job_context = self.hook.refresh_context(
self.job_context, self._trigger_tracking_message(self.job_context.job_id)
)

Expand Down
18 changes: 9 additions & 9 deletions third_party/airflow/test/operators/test_armada.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def default_hook() -> MagicMock:
mock = MagicMock()
job_context = running_job_context()
mock.submit_job.return_value = job_context
mock.update_context.return_value = dataclasses.replace(
mock.refresh_context.return_value = dataclasses.replace(
job_context, job_state=JobState.SUCCEEDED.name, cluster=DEFAULT_CLUSTER
)
mock.cancel_job.return_value = dataclasses.replace(
Expand Down Expand Up @@ -88,15 +88,15 @@ def operator(

def running_job_context(
cluster: str = None,
start_time: DateTime = DateTime.now(),
submit_time: DateTime = DateTime.now(),
job_state: str = JobState.UNKNOWN.name,
) -> RunningJobContext:
return RunningJobContext(
DEFAULT_QUEUE,
DEFAULT_JOB_ID,
DEFAULT_JOB_SET,
submit_time,
cluster,
start_time,
job_state=job_state,
)

Expand All @@ -118,7 +118,7 @@ def running_job_context(
def test_execute(job_states, context):
op = operator(JobSubmitRequestItem())

op.hook.update_context.side_effect = [
op.hook.refresh_context.side_effect = [
running_job_context(cluster="cluster-1", job_state=s.name) for s in job_states
]

Expand All @@ -127,7 +127,7 @@ def test_execute(job_states, context):
op.hook.submit_job.assert_called_once_with(
DEFAULT_QUEUE, DEFAULT_JOB_SET, op.job_request
)
assert op.hook.update_context.call_count == len(job_states)
assert op.hook.refresh_context.call_count == len(job_states)

# We're not polling for logs
op.pod_manager.fetch_container_logs.assert_not_called()
Expand All @@ -136,7 +136,7 @@ def test_execute(job_states, context):
@patch("pendulum.DateTime.utcnow", return_value=DEFAULT_CURRENT_TIME)
def test_execute_in_deferrable(_, context):
op = operator(JobSubmitRequestItem(), deferrable=True)
op.hook.update_context.side_effect = [
op.hook.refresh_context.side_effect = [
running_job_context(cluster="cluster-1", job_state=s.name)
for s in [JobState.QUEUED, JobState.QUEUED]
]
Expand Down Expand Up @@ -164,7 +164,7 @@ def test_execute_in_deferrable(_, context):
def test_execute_fail(terminal_state, context):
op = operator(JobSubmitRequestItem())

op.hook.update_context.side_effect = [
op.hook.refresh_context.side_effect = [
running_job_context(cluster="cluster-1", job_state=s.name)
for s in [JobState.RUNNING, terminal_state]
]
Expand All @@ -179,7 +179,7 @@ def test_execute_fail(terminal_state, context):
op.hook.submit_job.assert_called_once_with(
DEFAULT_QUEUE, DEFAULT_JOB_SET, op.job_request
)
assert op.hook.update_context.call_count == 2
assert op.hook.refresh_context.call_count == 2

# We're not polling for logs
op.pod_manager.fetch_container_logs.assert_not_called()
Expand All @@ -200,7 +200,7 @@ def test_on_kill_terminates_running_job():
def test_not_acknowledged_within_timeout_terminates_running_job(context):
job_context = running_job_context()
op = operator(JobSubmitRequestItem(), job_acknowledgement_timeout_s=-1)
op.hook.update_context.return_value = job_context
op.hook.refresh_context.return_value = job_context

with pytest.raises(AirflowException) as exec_info:
op.execute(context)
Expand Down
2 changes: 1 addition & 1 deletion third_party/airflow/test/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ def test_roundtrip_running_job_context():
"queue_123",
"job_id_123",
"job_set_id_123",
"cluster-1.armada.localhost",
DateTime.utcnow(),
"cluster-1.armada.localhost",
DateTime.utcnow().add(minutes=-2),
JobState.RUNNING.name,
)
Expand Down
2 changes: 1 addition & 1 deletion third_party/airflow/tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ allowlist_externals =
find
xargs
commands =
coverage run -m unittest discover
coverage run -m pytest test/
coverage xml
# This executes the dag files in examples but really only checks for imports and python errors
bash -c "find examples/ -maxdepth 1 -type f -name *.py | xargs python3"
Expand Down

0 comments on commit eb02b82

Please sign in to comment.