Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Feature/inherit config #36

Closed
wants to merge 31 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion example/run_child.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
""" This script requires to be defined by user. """
import logging
from typing import List

import os
import fire
import gokart
import luigi
Expand All @@ -9,9 +11,15 @@
logging.basicConfig(level=logging.INFO)


def main(task_pkl_path: str) -> None:
def main(task_pkl_path: str, remote_config_paths: List[str]) -> None:
# Load luigi config
luigi.configuration.LuigiConfigParser.add_config_path("./conf/base.ini")
for remote_config_path in remote_config_paths:
conf = make_target(remote_config_path).load()
# copy to pod local
local_path = os.path.join("./conf", os.path.basename(remote_config_path))
make_target(local_path).dump(conf)
luigi.configuration.LuigiConfigParser.add_config_path(local_path)

# Parse a serialized gokart.TaskOnKart
task: gokart.TaskOnKart = make_target(task_pkl_path).load()
Expand Down
50 changes: 50 additions & 0 deletions kannon/config_params.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from typing import Any, Dict, Optional, Type

import gokart
import luigi


class inherits_config_params:
    """Class decorator that fills a task's parameters from a luigi config class."""

    def __init__(self, config_class: Type[luigi.Config], parameter_alias: Optional[Dict[str, str]] = None):
        """
        Decorates a task so that it inherits parameter values from `config_class`.

        * config_class: Config class (not an instance) whose parameter values are inherited by the decorated task.
            It is instantiated with no arguments each time values are injected.
            Only parameters that exist in both the config class and the task are inherited.
        * parameter_alias: Dictionary to map parameter names between `config_class` and the decorated task.
            key: config_class's parameter name. value: decorated task's parameter name.
        """

        self._config_class: Type[luigi.Config] = config_class
        self._parameter_alias: Dict[str, str] = parameter_alias if parameter_alias is not None else {}

    def __call__(self, task_class: Type[gokart.TaskOnKart]):  # type: ignore
        """Return a subclass of `task_class` whose parameter resolution is seeded from the config class."""

        # wrap task to prevent task name from being changed
        @luigi.task._task_wraps(task_class)
        class Wrapped(task_class):
            # Marker attribute identifying classes produced by this decorator.
            is_decorated_inherits_config_params = True
            # Cache of {task parameter name: value} read from the config class.
            # The double underscore is name-mangled to `_Wrapped__config_params`; the dict
            # object itself is shared by `Wrapped` and any subclass that does not rebind it.
            __config_params: Dict[str, Any] = dict()
            # When True, the cache is rebuilt from the config class on every `get_param_values` call.
            __do_injection = True

            @classmethod
            def inject_config_params(cls) -> None:
                """Rebuild the cached config values from a fresh `config_class()` instance."""
                cls.__config_params.clear()  # clear config param cache
                for param_key, param_value in self._config_class().param_kwargs.items():
                    # Translate the config-side name to the task-side name (identity when no alias is given).
                    task_param_key = self._parameter_alias.get(param_key, param_key)
                    # Keep only values for attributes the task class actually has.
                    if hasattr(cls, task_param_key):
                        cls.__config_params[task_param_key] = param_value

            @classmethod
            def get_param_values(cls, params, args, kwargs):  # type: ignore
                """
                Add cached config values to `kwargs`, then delegate to the parent implementation.

                Keys already present in `kwargs` are never overwritten.
                NOTE: `kwargs` is modified in place.
                """
                if cls.__do_injection:
                    cls.inject_config_params()
                for param_key, param_value in cls.__config_params.items():
                    # `hasattr` is re-checked here because the cache may have been filled through another class
                    # sharing the same dict.
                    if hasattr(cls, param_key) and param_key not in kwargs:
                        kwargs[param_key] = param_value
                return super(Wrapped, cls).get_param_values(params, args, kwargs)

            @classmethod
            def set_injection_flag(cls, flag: bool):  # type: ignore
                """
                Enable or disable re-reading the config class in `get_param_values`.

                With `flag=False` the values cached by the last injection keep being used.
                The flag is assigned on `cls`, so calling this on a subclass does not change `Wrapped` itself.
                """
                cls.__do_injection = flag

        return Wrapped
77 changes: 68 additions & 9 deletions kannon/master.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import Deque, Dict, List, Optional, Set

import gokart
import luigi
from gokart.target import make_target
from kubernetes import client
from luigi.task import flatten
Expand All @@ -27,6 +28,9 @@ def __init__(
job_prefix: str,
path_child_script: str = "./run_child.py",
env_to_inherit: Optional[List[str]] = None,
master_pod_name: Optional[str] = None,
master_pod_uid: Optional[str] = None,
dynamic_config_paths: Optional[List[str]] = None,
) -> None:
# validation
if not os.path.exists(path_child_script):
Expand All @@ -40,25 +44,59 @@ def __init__(
if env_to_inherit is None:
env_to_inherit = ["TASK_WORKSPACE_DIRECTORY"]
self.env_to_inherit = env_to_inherit
self.master_pod_name = master_pod_name
self.master_pod_uid = master_pod_uid
self.dynamic_config_paths = dynamic_config_paths

self.task_id_to_job_name: Dict[str, str] = dict()

def build(self, root_task: gokart.TaskOnKart) -> None:
# check all config file paths exists
# luigi_parser_instance = luigi.configuration.get_config()
# config_paths = luigi_parser_instance._config_paths
if self.dynamic_config_paths:
logger.info("Handling dynamic config files...")
for config_path in self.dynamic_config_paths:
logger.info(f"Config file {config_path} is registered")
# save configs to remote cache
remote_config_dir = os.path.join(os.environ.get("TASK_WORKSPACE_DIRECTORY"), "kannon", "conf")
added_remote_config_paths: List[str] = []
for local_config_path in self.dynamic_config_paths:
if not local_config_path.endswith(".ini"):
logger.warning(f"Format {local_config_path} is not supported, so skipped")
continue
# load local config and save it to remote cache
local_conf_content = make_target(local_config_path).load()
remote_config_path = os.path.join(remote_config_dir, os.path.basename(local_config_path))
make_target(remote_config_path).dump(local_conf_content)
logger.info(f"local config file {local_config_path} is saved at remote {remote_config_path}.")
added_remote_config_paths.append(remote_config_path)
else:
logger.info("No dynamic config files are given.")

# push tasks into queue
logger.info("Creating task queue...")
task_queue = self._create_task_queue(root_task)

# consume task queue
launched_task_ids: Set[str] = set()
logger.info("Consuming task queue...")
while task_queue:
task = task_queue.popleft()
if task.complete():
logger.info(f"Task {self._gen_task_info(task)} is already completed.")
continue
if task.make_unique_id() in launched_task_ids:
if task.make_unique_id() in self.task_id_to_job_name:
# check if task is still running on child job
job_name = self.task_id_to_job_name[task.make_unique_id()]
job_status = get_job_status(
self.api_instance,
job_name,
self.namespace,
)
if job_status == JobStatus.FAILED:
raise RuntimeError(f"Task {self._gen_task_info(task)} on job {job_name} has failed.")
logger.info(f"Task {self._gen_task_info(task)} is still running on child job.")
task_queue.append(task)
task_queue.append(task) # re-enqueue task to check if it is done
continue

# TODO: enable user to specify duration to sleep for each task
Expand All @@ -71,8 +109,7 @@ def build(self, root_task: gokart.TaskOnKart) -> None:
# execute task
if isinstance(task, TaskOnBullet):
logger.info(f"Trying to run task {self._gen_task_info(task)} on child job...")
self._exec_bullet_task(task)
launched_task_ids.add(task.make_unique_id()) # mark as already launched task
self._exec_bullet_task(task, added_remote_config_paths)
task_queue.append(task) # re-enqueue task to check if it is done
elif isinstance(task, gokart.TaskOnKart):
logger.info(f"Executing task {self._gen_task_info(task)} on master job...")
Expand Down Expand Up @@ -113,7 +150,8 @@ def _exec_gokart_task(self, task: gokart.TaskOnKart) -> None:
except Exception:
raise RuntimeError(f"Task {self._gen_task_info(task)} on job master has failed.")

def _exec_bullet_task(self, task: TaskOnBullet) -> None:
def _exec_bullet_task(self, task: TaskOnBullet, remote_config_paths: list) -> None:
logger.info(f"Task on bullet type = {type(task)}")
# Save task instance as pickle object
pkl_path = self._gen_pkl_path(task)
make_target(pkl_path).dump(task)
Expand All @@ -122,19 +160,26 @@ def _exec_bullet_task(self, task: TaskOnBullet) -> None:
job = self._create_child_job_object(
job_name=job_name,
task_pkl_path=pkl_path,
remote_config_paths=remote_config_paths,
)
create_job(self.api_instance, job, self.namespace)
logger.info(f"Created child job {job_name} with task {self._gen_task_info(task)}")
task_unique_id = task.make_unique_id()
self.task_id_to_job_name[task_unique_id] = job_name
self.task_id_to_job_name[task.make_unique_id()] = job_name

def _create_child_job_object(self, job_name: str, task_pkl_path: str) -> client.V1Job:
def _create_child_job_object(
self,
job_name: str,
task_pkl_path: str,
remote_config_paths: str,
) -> client.V1Job:
# TODO: use python -c to avoid dependency to execute_task.py
cmd = [
"python",
self.path_child_script,
"--task-pkl-path",
f"'{task_pkl_path}'",
"--remote-config-path",
remote_config_paths[0], # TODO: support multiple paths
]
job = deepcopy(self.template_job)
# replace command
Expand All @@ -152,6 +197,20 @@ def _create_child_job_object(self, job_name: str, task_pkl_path: str) -> client.
job.spec.template.spec.containers[0].env = child_envs
# replace job name
job.metadata.name = job_name
# add owner reference from child to parent if master pod info is available
if self.master_pod_name and self.master_pod_uid:
owner_reference = client.V1OwnerReference(
api_version="batch/v1",
kind="Pod",
name=self.master_pod_name, # owner pod name
uid=self.master_pod_uid, # owner pod uid
)
if job.metadata.owner_references:
job.metadata.owner_references.append(owner_reference)
else:
job.metadata.owner_references = [owner_reference]
else:
logger.warning("Owner reference is not set because master pod info is not provided.")

return job

Expand Down
Loading
Loading