Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Provider package Edge: Edge worker supports queue handling #43115

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions providers/src/airflow/providers/edge/CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,14 @@
Changelog
---------

0.2.0pre0
.........

jscheffl marked this conversation as resolved.
Show resolved Hide resolved
Misc
~~~~

* ``Edge Worker can add or remove queues in the queue field in the DB (#43115)``

0.1.0pre0
.........

Expand Down
24 changes: 12 additions & 12 deletions providers/src/airflow/providers/edge/cli/edge_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,14 +91,6 @@ def _hostname() -> str:
return os.uname()[1]


def _get_sysinfo() -> dict:
    """Assemble the version details the worker posts to the central site."""
    sysinfo = {}
    sysinfo["airflow_version"] = airflow_version
    sysinfo["edge_provider_version"] = edge_provider_version
    return sysinfo


def _pid_file_path(pid_file: str | None) -> str:
    """Return the first location reported by ``cli_utils.setup_locations`` for the edge worker process."""
    # NOTE(review): the first entry is presumably the resolved PID file path - confirm against setup_locations.
    locations = cli_utils.setup_locations(process=EDGE_WORKER_PROCESS_NAME, pid=pid_file)
    return locations[0]

Expand Down Expand Up @@ -145,11 +137,19 @@ def signal_handler(sig, frame):
logger.info("Request to show down Edge Worker received, waiting for jobs to complete.")
_EdgeWorkerCli.drain = True

def _get_sysinfo(self) -> dict:
    """Assemble the sysinfo the worker posts to the central site, including its concurrency."""
    sysinfo = {
        "airflow_version": airflow_version,
        "edge_provider_version": edge_provider_version,
    }
    sysinfo["concurrency"] = self.concurrency
    return sysinfo

def start(self):
"""Start the execution in a loop until terminated."""
try:
self.last_hb = EdgeWorker.register_worker(
self.hostname, EdgeWorkerState.STARTING, self.queues, _get_sysinfo()
self.hostname, EdgeWorkerState.STARTING, self.queues, self._get_sysinfo()
).last_update
except AirflowException as e:
if "404:NOT FOUND" in str(e):
Expand All @@ -162,7 +162,7 @@ def start(self):
self.loop()

logger.info("Quitting worker, signal being offline.")
EdgeWorker.set_state(self.hostname, EdgeWorkerState.OFFLINE, 0, _get_sysinfo())
EdgeWorker.set_state(self.hostname, EdgeWorkerState.OFFLINE, 0, self._get_sysinfo())
finally:
remove_existing_pidfile(self.pid_file_path)

Expand Down Expand Up @@ -230,8 +230,8 @@ def heartbeat(self) -> None:
if self.jobs
else EdgeWorkerState.IDLE
)
sysinfo = _get_sysinfo()
EdgeWorker.set_state(self.hostname, state, len(self.jobs), sysinfo)
sysinfo = self._get_sysinfo()
self.queues = EdgeWorker.set_state(self.hostname, state, len(self.jobs), sysinfo)

def interruptible_sleep(self):
"""Sleeps but stops sleeping if drain is made."""
Expand Down
55 changes: 51 additions & 4 deletions providers/src/airflow/providers/edge/models/edge_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.
from __future__ import annotations

import ast
import json
from datetime import datetime
from enum import Enum
Expand Down Expand Up @@ -71,7 +72,7 @@ class EdgeWorkerModel(Base, LoggingMixin):
__tablename__ = "edge_worker"
worker_name = Column(String(64), primary_key=True, nullable=False)
state = Column(String(20))
queues = Column(String(256))
_queues = Column("queues", String(256))
first_online = Column(UtcDateTime)
last_update = Column(UtcDateTime)
jobs_active = Column(Integer, default=0)
Expand All @@ -90,7 +91,7 @@ def __init__(
):
self.worker_name = worker_name
self.state = state
self.queues = ", ".join(queues) if queues else None
self.queues = queues
self.first_online = first_online or timezone.utcnow()
self.last_update = last_update
super().__init__()
Expand All @@ -99,6 +100,33 @@ def __init__(
def sysinfo_json(self) -> dict:
return json.loads(self.sysinfo) if self.sysinfo else None

@property
def queues(self) -> list[str] | None:
"""Return list of queues which are stored in queues field."""
if self._queues:
return ast.literal_eval(self._queues)
return None

@queues.setter
def queues(self, queues: list[str] | None) -> None:
"""Set all queues of list into queues field."""
self._queues = str(queues) if queues else None

def add_queues(self, new_queues: list[str]) -> None:
    """
    Add queues to the queues field, skipping names that are already present.

    Existing queues keep their position and new names are appended in the order
    given, so the stored value is deterministic.

    :param new_queues: Queue names to add.
    """
    current = self.queues or []
    # dict.fromkeys drops duplicates while preserving first-seen order.
    self.queues = list(dict.fromkeys([*current, *new_queues]))

def remove_queues(self, remove_queues: list[str]) -> None:
    """
    Remove queues from the queues field.

    Every occurrence of a listed name is removed; names that are not present are ignored.

    :param remove_queues: Queue names to remove.
    """
    unwanted = set(remove_queues)
    self.queues = [name for name in (self.queues or []) if name not in unwanted]


class EdgeWorker(BaseModel, LoggingMixin):
"""Accessor for Edge Worker instances as logical model."""
Expand Down Expand Up @@ -168,7 +196,7 @@ def register_worker(
return EdgeWorker(
worker_name=worker_name,
state=state,
queues=worker.queues,
queues=queues,
first_online=worker.first_online,
last_update=worker.last_update,
jobs_active=worker.jobs_active or 0,
Expand All @@ -187,7 +215,8 @@ def set_state(
jobs_active: int,
sysinfo: dict[str, str],
session: Session = NEW_SESSION,
):
) -> list[str] | None:
"""Set state of worker and returns the current assigned queues."""
EdgeWorker.assert_version(sysinfo)
query = select(EdgeWorkerModel).where(EdgeWorkerModel.worker_name == worker_name)
worker: EdgeWorkerModel = session.scalar(query)
Expand All @@ -196,6 +225,24 @@ def set_state(
worker.sysinfo = json.dumps(sysinfo)
worker.last_update = timezone.utcnow()
session.commit()
return worker.queues

@staticmethod
@provide_session
def add_and_remove_queues(
    worker_name: str,
    new_queues: list[str] | None = None,
    remove_queues: list[str] | None = None,
    session: Session = NEW_SESSION,
) -> None:
    """
    Add and/or remove queues on the stored worker record and commit the change.

    Additions are applied before removals, so a name present in both lists ends up removed.

    :param worker_name: Name of the worker whose queues are changed.
    :param new_queues: Queue names to add; skipped when empty or ``None``.
    :param remove_queues: Queue names to remove; skipped when empty or ``None``.
    :param session: Database session used for the lookup and the commit.
    :raises ValueError: If no worker with ``worker_name`` exists.
    """
    query = select(EdgeWorkerModel).where(EdgeWorkerModel.worker_name == worker_name)
    worker: EdgeWorkerModel | None = session.scalar(query)
    if worker is None:
        # Fail with a clear message instead of an AttributeError on None below.
        raise ValueError(f"Edge Worker {worker_name} not found in list of registered workers")
    if new_queues:
        worker.add_queues(new_queues)
    if remove_queues:
        worker.remove_queues(remove_queues)
    session.add(worker)
    session.commit()


EdgeWorker.model_rebuild()
Expand Down
2 changes: 1 addition & 1 deletion providers/src/airflow/providers/edge/provider.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ state: not-ready
source-date-epoch: 1720863625
# note that those versions are maintained by release manager - do not update them manually
versions:
- 0.1.0pre0
- 0.2.0pre0

dependencies:
- apache-airflow>=2.10.0
Expand Down
21 changes: 14 additions & 7 deletions providers/tests/edge/cli/test_edge_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
from airflow.exceptions import AirflowException
from airflow.providers.edge.cli.edge_command import (
_EdgeWorkerCli,
_get_sysinfo,
_Job,
)
from airflow.providers.edge.models.edge_job import EdgeJob
Expand All @@ -42,12 +41,6 @@
# mypy: disable-error-code="attr-defined"


def test_get_sysinfo():
    """Check that the sysinfo payload carries both version keys."""
    sysinfo = _get_sysinfo()
    for expected_key in ("airflow_version", "edge_provider_version"):
        assert expected_key in sysinfo


class TestEdgeWorkerCli:
@pytest.fixture
def dummy_joblist(self, tmp_path: Path) -> list[_Job]:
Expand Down Expand Up @@ -208,9 +201,14 @@ def test_heartbeat(self, mock_set_state, drain, jobs, expected_state, worker_wit
if not jobs:
worker_with_job.jobs = []
_EdgeWorkerCli.drain = drain
mock_set_state.return_value = ["queue1", "queue2"]
with conf_vars({("edge", "api_url"): "https://mock.server"}):
worker_with_job.heartbeat()
assert mock_set_state.call_args.args[1] == expected_state
queue_list = worker_with_job.queues or []
assert len(queue_list) == 2
assert "queue1" in (queue_list)
assert "queue2" in (queue_list)

@patch("airflow.providers.edge.models.edge_worker.EdgeWorker.register_worker")
def test_start_missing_apiserver(self, mock_register_worker, worker_with_job: _EdgeWorkerCli):
Expand Down Expand Up @@ -258,3 +256,12 @@ def stop_running():
mock_register_worker.assert_called_once()
mock_loop.assert_called_once()
mock_set_state.assert_called_once()

def test_get_sysinfo(self, worker_with_job: _EdgeWorkerCli):
    """Check that sysinfo carries the version keys and the worker's concurrency."""
    expected_concurrency = 8
    worker_with_job.concurrency = expected_concurrency
    sysinfo = worker_with_job._get_sysinfo()
    for required_key in ("airflow_version", "edge_provider_version", "concurrency"):
        assert required_key in sysinfo
    assert sysinfo["concurrency"] == expected_concurrency
70 changes: 64 additions & 6 deletions providers/tests/edge/models/test_edge_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,12 @@
# under the License.
from __future__ import annotations

from pathlib import Path
from typing import TYPE_CHECKING

import pytest

from airflow.providers.edge.cli.edge_command import _get_sysinfo
from airflow.providers.edge.cli.edge_command import _EdgeWorkerCli
from airflow.providers.edge.models.edge_worker import (
EdgeWorker,
EdgeWorkerModel,
Expand All @@ -36,6 +37,11 @@


class TestEdgeWorker:
@pytest.fixture
def cli_worker(self, tmp_path: Path) -> _EdgeWorkerCli:
    """Provide an ``_EdgeWorkerCli`` whose PID file is placed under ``tmp_path``."""
    pid_file = tmp_path / "dummy.pid"
    return _EdgeWorkerCli(pid_file, "dummy", None, 8, 5, 5)

@pytest.fixture(autouse=True)
def setup_test_cases(self, session: Session):
session.query(EdgeWorkerModel).delete()
Expand Down Expand Up @@ -67,28 +73,80 @@ def test_assert_version(self):
{"airflow_version": airflow_version, "edge_provider_version": edge_provider_version}
)

def test_register_worker(self, session: Session):
@pytest.mark.parametrize(
"input_queues",
[
pytest.param(None, id="empty-queues"),
pytest.param(["default", "default2"], id="with-queues"),
],
)
def test_register_worker(
self, session: Session, input_queues: list[str] | None, cli_worker: _EdgeWorkerCli
):
EdgeWorker.register_worker(
"test_worker", EdgeWorkerState.STARTING, queues=None, sysinfo=_get_sysinfo()
"test_worker", EdgeWorkerState.STARTING, queues=input_queues, sysinfo=cli_worker._get_sysinfo()
)

worker: list[EdgeWorkerModel] = session.query(EdgeWorkerModel).all()
assert len(worker) == 1
assert worker[0].worker_name == "test_worker"
if input_queues:
assert worker[0].queues == input_queues
else:
assert worker[0].queues is None

def test_set_state(self, session: Session):
def test_set_state(self, session: Session, cli_worker: _EdgeWorkerCli):
queues = ["default", "default2"]
rwm = EdgeWorkerModel(
worker_name="test2_worker",
state=EdgeWorkerState.IDLE,
queues=["default"],
queues=queues,
first_online=timezone.utcnow(),
)
session.add(rwm)
session.commit()

EdgeWorker.set_state("test2_worker", EdgeWorkerState.RUNNING, 1, _get_sysinfo())
return_queues = EdgeWorker.set_state(
"test2_worker", EdgeWorkerState.RUNNING, 1, cli_worker._get_sysinfo()
)

worker: list[EdgeWorkerModel] = session.query(EdgeWorkerModel).all()
assert len(worker) == 1
assert worker[0].worker_name == "test2_worker"
assert worker[0].state == EdgeWorkerState.RUNNING
assert worker[0].queues == queues
assert return_queues == ["default", "default2"]

@pytest.mark.parametrize(
    "add_queues, remove_queues, expected_queues",
    [
        pytest.param(None, None, ["init"], id="no-changes"),
        pytest.param(
            ["queue1", "queue2"], ["queue1", "queue_not_existing"], ["init", "queue2"], id="add-remove"
        ),
        pytest.param(["init"], None, ["init"], id="check-duplicated"),
    ],
)
def test_add_and_remove_queues(
    self,
    session: Session,
    add_queues: list[str] | None,
    remove_queues: list[str] | None,
    expected_queues: list[str],
    cli_worker: _EdgeWorkerCli,
):
    """Check that queues are added to and removed from a stored worker that starts with ``["init"]``."""
    # Arrange: persist a worker that starts with a single queue.
    session.add(
        EdgeWorkerModel(
            worker_name="test2_worker",
            state=EdgeWorkerState.IDLE,
            queues=["init"],
            first_online=timezone.utcnow(),
        )
    )
    session.commit()

    # Act
    EdgeWorker.add_and_remove_queues("test2_worker", add_queues, remove_queues, session)

    # Assert: still exactly one worker, holding exactly the expected queue names.
    stored_workers: list[EdgeWorkerModel] = session.query(EdgeWorkerModel).all()
    assert len(stored_workers) == 1
    updated = stored_workers[0]
    assert updated.worker_name == "test2_worker"
    actual_queues = updated.queues or []
    assert len(expected_queues) == len(actual_queues)
    for expected_queue in expected_queues:
        assert expected_queue in actual_queues