Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add type annotations to JobRunnerMapper and related code #19115

Merged
merged 1 commit into from
Nov 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 5 additions & 7 deletions lib/galaxy/jobs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,7 @@ def __init__(self, app: MinimalManagerApp):
"""Parse the job configuration XML."""
self.app = app
self.runner_plugins = []
self.dynamic_params = None
self.dynamic_params: Optional[Dict[str, Any]] = None
self.handlers = {}
self.handler_runner_plugins = {}
self.default_handler_id = None
Expand Down Expand Up @@ -432,7 +432,7 @@ def _configure_from_dict(self, job_config_dict):
continue
self.runner_plugins.append(runner_info)
if "dynamic" in job_config_dict:
self.dynamic_params = job_config_dict.get("dynamic", None)
self.dynamic_params = job_config_dict["dynamic"]

# Parse handlers
handling_config_dict = job_config_dict.get("handling", {})
Expand Down Expand Up @@ -830,12 +830,12 @@ def get_destinations(self, id_or_tag) -> Iterable[JobDestination]:
"""
return self.destinations.get(id_or_tag, [])

def get_job_runner_plugins(self, handler_id):
def get_job_runner_plugins(self, handler_id: str):
"""Load all configured job runner plugins

:returns: list of job runner plugins
"""
rval = {}
rval: Dict[str, BaseJobRunner] = {}
if handler_id in self.handler_runner_plugins:
plugins_to_load = [rp for rp in self.runner_plugins if rp["id"] in self.handler_runner_plugins[handler_id]]
log.info(
Expand Down Expand Up @@ -871,11 +871,9 @@ def get_job_runner_plugins(self, handler_id):
# If the name included a '.' or loading from the static runners path failed, try the original name
module = __import__(load)
module_name = load
if module is None:
# Module couldn't be loaded, error should have already been displayed
continue
for comp in module_name.split(".")[1:]:
module = getattr(module, comp)
assert module # make mypy happy
if not class_names:
# If there's not a ':', we check <module>.__all__ for class names
try:
Expand Down
7 changes: 4 additions & 3 deletions lib/galaxy/jobs/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ class StopSignalException(Exception):
class BaseJobHandlerQueue(Monitors):
STOP_SIGNAL = object()

def __init__(self, app: MinimalManagerApp, dispatcher):
def __init__(self, app: MinimalManagerApp, dispatcher: "DefaultJobDispatcher"):
"""
Initializes the Queue, creates (unstarted) monitoring thread.
"""
Expand Down Expand Up @@ -309,7 +309,8 @@ def __check_jobs_at_startup(self):
with transaction(session):
session.commit()

def _check_job_at_startup(self, job):
def _check_job_at_startup(self, job: model.Job):
assert job.tool_id is not None
if not self.app.toolbox.has_tool(job.tool_id, job.tool_version, exact=True):
log.warning(f"({job.id}) Tool '{job.tool_id}' removed from tool config, unable to recover job")
self.job_wrapper(job).fail(
Expand Down Expand Up @@ -1207,7 +1208,7 @@ def start(self):
for runner in self.job_runners.values():
runner.start()

def url_to_destination(self, url):
def url_to_destination(self, url: str):
"""This is used by the runner mapper (a.k.a. dynamic runner) and
recovery methods to have runners convert URLs to destinations.

Expand Down
20 changes: 18 additions & 2 deletions lib/galaxy/jobs/mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,24 @@
import logging
from inspect import getfullargspec
from types import ModuleType
from typing import (
Callable,
TYPE_CHECKING,
)

import galaxy.jobs.rules
from galaxy.jobs import stock_rules
from galaxy.jobs.dynamic_tool_destination import map_tool_to_destination
from galaxy.util.submodules import import_submodules
from .rule_helper import RuleHelper

if TYPE_CHECKING:
from galaxy.jobs import (
JobConfiguration,
JobDestination,
JobWrapper,
)

log = logging.getLogger(__name__)

DYNAMIC_RUNNER_NAME = "dynamic"
Expand Down Expand Up @@ -52,7 +63,12 @@ class JobRunnerMapper:

rules_module: ModuleType

def __init__(self, job_wrapper, url_to_destination, job_config):
def __init__(
self,
job_wrapper: "JobWrapper",
url_to_destination: Callable[[str], "JobDestination"],
job_config: "JobConfiguration",
):
self.job_wrapper = job_wrapper
self.url_to_destination = url_to_destination
self.job_config = job_config
Expand Down Expand Up @@ -129,7 +145,7 @@ def __job_params(self, job):
param_values = job.get_param_values(app, ignore_errors=True)
return param_values

def __convert_url_to_destination(self, url):
def __convert_url_to_destination(self, url: str):
"""
Job runner URLs are deprecated, but dynamic mapper functions may still
be returning them. Runners are expected to be able to convert these to
Expand Down
9 changes: 6 additions & 3 deletions test/unit/app/jobs/test_mapper.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import uuid
from typing import cast

from galaxy.jobs import (
HasResourceParameters,
JobConfiguration,
JobDestination,
JobWrapper,
)
from galaxy.jobs.mapper import (
ERROR_MESSAGE_NO_RULE_FUNCTION,
Expand Down Expand Up @@ -134,10 +137,10 @@ def __assert_mapper_errors_with_message(mapper, message):


def __mapper(tool_job_destination=TOOL_JOB_DESTINATION):
job_wrapper = MockJobWrapper(tool_job_destination)
job_config = MockJobConfig()
job_wrapper = cast(JobWrapper, MockJobWrapper(tool_job_destination))
job_config = cast(JobConfiguration, MockJobConfig())

mapper = JobRunnerMapper(job_wrapper, {}, job_config)
mapper = JobRunnerMapper(job_wrapper, lambda url: JobDestination(), job_config)
mapper.rules_module = test_rules
return mapper

Expand Down
Loading