Skip to content

Commit

Permalink
refactor(rust): Allow non-scoped tasks to be spawned (#18163)
Browse files Browse the repository at this point in the history
  • Loading branch information
orlp authored Aug 13, 2024
1 parent 41318c3 commit 9dc4106
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 10 deletions.
41 changes: 34 additions & 7 deletions crates/polars-stream/src/async_executor/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,18 +42,23 @@ pub enum TaskPriority {
}

/// Metadata associated with a task to help schedule it and clean it up.
struct ScopedTaskMetadata {
task_key: TaskKey,
completed_tasks: Weak<Mutex<Vec<TaskKey>>>,
}

struct TaskMetadata {
priority: TaskPriority,
freshly_spawned: AtomicBool,

task_key: TaskKey,
completed_tasks: Weak<Mutex<Vec<TaskKey>>>,
scoped: Option<ScopedTaskMetadata>,
}

impl Drop for TaskMetadata {
fn drop(&mut self) {
if let Some(completed_tasks) = self.completed_tasks.upgrade() {
completed_tasks.lock().push(self.task_key);
if let Some(scoped) = &self.scoped {
if let Some(completed_tasks) = scoped.completed_tasks.upgrade() {
completed_tasks.lock().push(scoped.task_key);
}
}
}
}
Expand Down Expand Up @@ -296,10 +301,12 @@ impl<'scope, 'env> TaskScope<'scope, 'env> {
fut,
on_wake,
TaskMetadata {
task_key,
priority,
freshly_spawned: AtomicBool::new(true),
completed_tasks: Arc::downgrade(&self.completed_tasks),
scoped: Some(ScopedTaskMetadata {
task_key,
completed_tasks: Arc::downgrade(&self.completed_tasks),
}),
},
)
};
Expand Down Expand Up @@ -338,6 +345,26 @@ where
}
}

#[allow(unused)]
pub fn spawn<F: Future + Send + 'static>(priority: TaskPriority, fut: F) -> JoinHandle<F::Output>
where
<F as Future>::Output: Send + 'static,
{
let executor = Executor::global();
let on_wake = move |task| executor.schedule_task(task);
let (runnable, join_handle) = task::spawn(
fut,
on_wake,
TaskMetadata {
priority,
freshly_spawned: AtomicBool::new(true),
scoped: None,
},
);
runnable.schedule();
join_handle
}

fn random_permutation<R: Rng>(len: u32, rng: &mut R) -> impl Iterator<Item = u32> {
let modulus = len.next_power_of_two();
let halfwidth = modulus.trailing_zeros() / 2;
Expand Down
5 changes: 2 additions & 3 deletions crates/polars-stream/src/async_executor/task.rs
Original file line number Diff line number Diff line change
Expand Up @@ -312,16 +312,15 @@ impl CancelHandle {
}
}

#[allow(unused)]
pub fn spawn<F, S, M>(future: F, schedule: S, metadata: M) -> JoinHandle<F::Output>
pub fn spawn<F, S, M>(future: F, schedule: S, metadata: M) -> (Runnable<M>, JoinHandle<F::Output>)
where
F: Future + Send + 'static,
F::Output: Send + 'static,
S: Fn(Runnable<M>) + Send + Sync + Copy + 'static,
M: Send + Sync + 'static,
{
let task = unsafe { Task::spawn(future, schedule, metadata) };
JoinHandle(Some(task))
(task.clone().into_runnable(), task.into_join_handle())
}

/// Takes a future and turns it into a runnable task with associated metadata.
Expand Down

0 comments on commit 9dc4106

Please sign in to comment.