diff --git a/src/Init/System/IO.lean b/src/Init/System/IO.lean index 45eba794fb7d..0c6805cd7b97 100644 --- a/src/Init/System/IO.lean +++ b/src/Init/System/IO.lean @@ -210,8 +210,44 @@ def sleep (ms : UInt32) : BaseIO Unit := /-- Request cooperative cancellation of the task. The task must explicitly call `IO.checkCanceled` to react to the cancellation. -/ @[extern "lean_io_cancel"] opaque cancel : @& Task α → BaseIO Unit +/-- The current state of a `Task` in the Lean runtime's task manager. -/ +inductive TaskState + /-- + The `Task` is waiting to be run. + It can be waiting for dependencies to complete or + sitting in the task manager queue waiting for a thread to run on. + -/ + | waiting + /-- + The `Task` is actively running on a thread or, + in the case of a `Promise`, waiting for a call to `IO.Promise.resolve`. + -/ + | running + /-- + The `Task` has finished running and its result is available. + Calling `Task.get` or `IO.wait` on the task will not block. + -/ + | finished + deriving Inhabited, Repr, DecidableEq, Ord + +instance : LT TaskState := ltOfOrd +instance : LE TaskState := leOfOrd +instance : Min TaskState := minOfLe +instance : Max TaskState := maxOfLe + +protected def TaskState.toString : TaskState → String + | .waiting => "waiting" + | .running => "running" + | .finished => "finished" + +instance : ToString TaskState := ⟨TaskState.toString⟩ + +/-- Returns current state of the `Task` in the Lean runtime's task manager. -/ +@[extern "lean_io_get_task_state"] opaque getTaskState : @& Task α → BaseIO TaskState + /-- Check if the task has finished execution, at which point calling `Task.get` will return immediately. -/ -@[extern "lean_io_has_finished"] opaque hasFinished : @& Task α → BaseIO Bool +@[inline] def hasFinished (task : Task α) : BaseIO Bool := do + return (← getTaskState task) matches .finished /-- Wait for the task to finish, then return its result. -/ @[extern "lean_io_wait"] opaque wait (t : Task α) : BaseIO α := diff --git a/src/include/lean/lean.h b/src/include/lean/lean.h index 80120a38b89d..c189c6235e2d 100644 --- a/src/include/lean/lean.h +++ b/src/include/lean/lean.h @@ -1110,8 +1110,8 @@ static inline lean_obj_res lean_task_get_own(lean_obj_arg t) { LEAN_EXPORT bool lean_io_check_canceled_core(void); /* primitive for implementing `IO.cancel : Task a -> IO Unit` */ LEAN_EXPORT void lean_io_cancel_core(b_lean_obj_arg t); -/* primitive for implementing `IO.hasFinished : Task a -> IO Unit` */ -LEAN_EXPORT bool lean_io_has_finished_core(b_lean_obj_arg t); +/* primitive for implementing `IO.getTaskState : Task a -> IO TaskState` */ +LEAN_EXPORT uint8_t lean_io_get_task_state_core(b_lean_obj_arg t); /* primitive for implementing `IO.waitAny : List (Task a) -> IO (Task a)` */ LEAN_EXPORT b_lean_obj_res lean_io_wait_any_core(b_lean_obj_arg task_list); diff --git a/src/runtime/io.cpp b/src/runtime/io.cpp index 4dddd7afc0b5..fe4a232d9c71 100644 --- a/src/runtime/io.cpp +++ b/src/runtime/io.cpp @@ -1066,8 +1066,8 @@ extern "C" LEAN_EXPORT obj_res lean_io_cancel(b_obj_arg t, obj_arg) { return io_result_mk_ok(box(0)); } -extern "C" LEAN_EXPORT obj_res lean_io_has_finished(b_obj_arg t, obj_arg) { - return io_result_mk_ok(box(lean_io_has_finished_core(t))); +extern "C" LEAN_EXPORT obj_res lean_io_get_task_state(b_obj_arg t, obj_arg) { + return io_result_mk_ok(box(lean_io_get_task_state_core(t))); } extern "C" LEAN_EXPORT obj_res lean_io_wait(obj_arg t, obj_arg) { diff --git a/src/runtime/object.cpp b/src/runtime/object.cpp index 89af9b2604fa..262227eb5e8f 100644 --- a/src/runtime/object.cpp +++ b/src/runtime/object.cpp @@ -1030,8 +1030,17 @@ extern "C" LEAN_EXPORT void lean_io_cancel_core(b_obj_arg t) { g_task_manager->cancel(lean_to_task(t)); } -extern "C" LEAN_EXPORT bool lean_io_has_finished_core(b_obj_arg t) { - return lean_to_task(t)->m_value != nullptr; +extern "C" LEAN_EXPORT uint8_t lean_io_get_task_state_core(b_obj_arg t) { + lean_task_object * o = lean_to_task(t); + if (o->m_imp) { + if (o->m_imp->m_closure) { + return 0; // waiting (waiting/queued) + } else { + return 1; // running (running/promised) + } + } else { + return 2; // finished + } } extern "C" LEAN_EXPORT b_obj_res lean_io_wait_any_core(b_obj_arg task_list) { diff --git a/src/runtime/object.h b/src/runtime/object.h index 04f13fe1b642..f9bf12d6c2da 100644 --- a/src/runtime/object.h +++ b/src/runtime/object.h @@ -288,7 +288,7 @@ inline b_obj_res task_get(b_obj_arg t) { return lean_task_get(t); } inline bool io_check_canceled_core() { return lean_io_check_canceled_core(); } inline void io_cancel_core(b_obj_arg t) { return lean_io_cancel_core(t); } -inline bool io_has_finished_core(b_obj_arg t) { return lean_io_has_finished_core(t); } +inline bool io_get_task_state_core(b_obj_arg t) { return lean_io_get_task_state_core(t); } inline b_obj_res io_wait_any_core(b_obj_arg task_list) { return lean_io_wait_any_core(task_list); } // ======================================= diff --git a/tests/lean/run/taskState.lean b/tests/lean/run/taskState.lean new file mode 100644 index 000000000000..e5e045edec06 --- /dev/null +++ b/tests/lean/run/taskState.lean @@ -0,0 +1,25 @@ +def assertBEq [BEq α] [ToString α] (caption : String) (actual expected : α) : IO Unit := do + unless actual == expected do + throw <| IO.userError <| + s!"{caption}: expected '{expected}', got '{actual}'" + +def test : IO Unit := do + let p1 : IO.Promise Unit ← IO.Promise.new -- resolving queues the task + let p2 : IO.Promise Unit ← IO.Promise.new -- resolved once task is running + let p3 : IO.Promise Unit ← IO.Promise.new -- resolving finishes the task + let t ← BaseIO.mapTask (fun () => do p2.resolve (); IO.wait p3.result) p1.result + assertBEq "p1" (← IO.getTaskState p1.result) .running + assertBEq "p2" (← IO.getTaskState p2.result) .running + assertBEq "p3" (← IO.getTaskState p3.result) .running + assertBEq "t" (← IO.getTaskState t) .waiting + p1.resolve () + assertBEq "p1" (← IO.getTaskState p1.result) .finished + IO.wait p2.result + assertBEq "p2" (← IO.getTaskState p2.result) .finished + assertBEq "t" (← IO.getTaskState t) .running + p3.resolve () + assertBEq "p3" (← IO.getTaskState p3.result) .finished + IO.wait t + assertBEq "t" (← IO.getTaskState t) .finished + +#eval test