Skip to content

Commit

Permalink
add use_graphlib param with TopologicalSorter
Browse files Browse the repository at this point in the history
  • Loading branch information
minwook-shin committed Aug 4, 2024
1 parent b139e6d commit c2e5504
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 10 deletions.
4 changes: 2 additions & 2 deletions example.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ def example_return_func(text):
return text


dag = DAG()
dag = DAG(use_graphlib=True)

dag.add_task(DefaultFunctionOperator(function=print, param=(['hello']), task_id='hello_task'))
dag.add_task(DefaultFunctionOperator(function=example_return_func, param=(['bye']), task_id='bye_task'))
Expand All @@ -20,7 +20,7 @@ def example_return_func(text):

dag.update_task(task_order[2], ['where are you from?'])

converter.convert_list_to_dag(task_order).run(task_order[0])
converter.convert_list_to_dag(task_order).run()

# print return value of iter_task
print(dag.get_return_value('iter_task'))
Expand Down
36 changes: 28 additions & 8 deletions f_scheduler/modules/dag.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,35 @@
from graphlib import TopologicalSorter


class DAG:
def __init__(self):
def __init__(self, use_graphlib=False):
self.tasks = {}
self.graph = TopologicalSorter()
self.use_graphlib = use_graphlib

def add_task(self, task):
self.tasks[task.task_id] = task
if self.use_graphlib:
self.graph.add(task.task_id)

def set_downstream(self, task_id, next_task_id):
task = self.tasks[task_id]
next_task = self.tasks[next_task_id]
task.next(next_task)
if not self.use_graphlib:
task = self.tasks[task_id]
next_task = self.tasks[next_task_id]
task.next(next_task)
else:
self.graph.add(task_id, next_task_id)

def run(self, start_task_id):
start_task = self.tasks[start_task_id]
start_task.run()
def run(self, start_task_id=None):
if not self.use_graphlib:
if start_task_id is None:
start_task_id = list(self.tasks.keys())[0]
start_task = self.tasks[start_task_id]
start_task.run()
else:
order = list(self.graph.static_order())
for task_id in order:
self.tasks[task_id].run()

def clear(self):
self.tasks.clear()
Expand All @@ -24,7 +41,10 @@ def get_all_tasks(self):
return self.tasks

def update_task(self, task_id, new_param):
task = self.tasks[task_id]
if not self.use_graphlib:
task = self.tasks[task_id]
else:
task = self.tasks.get(task_id)
if task:
task.param = new_param
else:
Expand Down

0 comments on commit c2e5504

Please sign in to comment.