Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reduce TaskExecutor overhead #13861

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
212 changes: 99 additions & 113 deletions lucene/core/src/java/org/apache/lucene/search/TaskExecutor.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.Callable;
Expand Down Expand Up @@ -73,144 +72,131 @@ public TaskExecutor(Executor executor) {
/**
* Execute all the callables provided as an argument, wait for them to complete and return the
* obtained results. If an exception is thrown by more than one callable, the subsequent ones will
* be added as suppressed exceptions to the first one that was caught.
* be added as suppressed exceptions to the first one that was caught. Additionally, if one task
* throws an exception, all other tasks from the same group are cancelled, to avoid needless
* computation as their results would not be exposed anyways.
*
* @param callables the callables to execute
* @return a list containing the results from the tasks execution
* @param <T> the return type of the task execution
*/
public <T> List<T> invokeAll(Collection<Callable<T>> callables) throws IOException {
TaskGroup<T> taskGroup = new TaskGroup<>(callables);
return taskGroup.invokeAll(executor);
List<RunnableFuture<T>> futures = new ArrayList<>(callables.size());
for (Callable<T> callable : callables) {
futures.add(new Task<>(callable, futures));
}
final int count = futures.size();
// taskId provides the first index of an un-executed task in #futures
final AtomicInteger taskId = new AtomicInteger(0);
// we fork execution count - 1 tasks to execute at least one task on the current thread to
// minimize needless forking and blocking of the current thread
if (count > 1) {
final Runnable work =
() -> {
int id = taskId.getAndIncrement();
if (id < count) {
futures.get(id).run();
}
};
for (int j = 0; j < count - 1; j++) {
executor.execute(work);
}
}
// try to execute as many tasks as possible on the current thread to minimize context
// switching in case of long running concurrent
// tasks as well as dead-locking if the current thread is part of #executor for executors that
// have limited or no parallelism
int id;
while ((id = taskId.getAndIncrement()) < count) {
futures.get(id).run();
if (id >= count - 1) {
// save redundant CAS in case this was the last task
break;
}
}
return collectResults(futures);
}

private static <T> List<T> collectResults(List<RunnableFuture<T>> futures) throws IOException {
Throwable exc = null;
List<T> results = new ArrayList<>(futures.size());
for (Future<T> future : futures) {
try {
results.add(future.get());
} catch (InterruptedException e) {
exc = IOUtils.useOrSuppress(exc, new ThreadInterruptedException(e));
} catch (ExecutionException e) {
exc = IOUtils.useOrSuppress(exc, e.getCause());
}
}
assert assertAllFuturesCompleted(futures) : "Some tasks are still running?";
if (exc != null) {
throw IOUtils.rethrowAlways(exc);
}
return results;
}

@Override
public String toString() {
return "TaskExecutor(" + "executor=" + executor + ')';
}

/**
* Holds all the sub-tasks that a certain operation gets split into as it gets parallelized and
* exposes the ability to invoke such tasks and wait for them all to complete their execution and
* provide their results. Additionally, if one task throws an exception, all other tasks from the
* same group are cancelled, to avoid needless computation as their results would not be exposed
* anyways. Creates one {@link FutureTask} for each {@link Callable} provided
*
* @param <T> the return type of all the callables
*/
private static final class TaskGroup<T> {
private final List<RunnableFuture<T>> futures;

TaskGroup(Collection<Callable<T>> callables) {
List<RunnableFuture<T>> tasks = new ArrayList<>(callables.size());
for (Callable<T> callable : callables) {
tasks.add(createTask(callable));
private static boolean assertAllFuturesCompleted(Collection<? extends Future<?>> futures) {
for (Future<?> future : futures) {
if (future.isDone() == false) {
return false;
}
this.futures = Collections.unmodifiableList(tasks);
}
return true;
}

RunnableFuture<T> createTask(Callable<T> callable) {
return new FutureTask<>(callable) {
private static <T> void cancelAll(Collection<? extends Future<T>> futures) {
for (Future<?> future : futures) {
future.cancel(false);
}
}

private final AtomicBoolean startedOrCancelled = new AtomicBoolean(false);
private static class Task<T> extends FutureTask<T> {

@Override
public void run() {
if (startedOrCancelled.compareAndSet(false, true)) {
super.run();
}
}

@Override
protected void setException(Throwable t) {
super.setException(t);
cancelAll();
}

@Override
public boolean cancel(boolean mayInterruptIfRunning) {
assert mayInterruptIfRunning == false
: "cancelling tasks that are running is not supported";
/*
Future#get (called in invokeAll) throws CancellationException when invoked against a running task that has been cancelled but
leaves the task running. We rather want to make sure that invokeAll does not leave any running tasks behind when it returns.
Overriding cancel ensures that tasks that are already started will complete normally once cancelled, and Future#get will
wait for them to finish instead of throwing CancellationException. A cleaner way would have been to override FutureTask#get and
make it wait for cancelled tasks, but FutureTask#awaitDone is private. Tasks that are cancelled before they are started will be no-op.
*/
if (startedOrCancelled.compareAndSet(false, true)) {
// task is cancelled hence it has no results to return. That's fine: they would be
// ignored anyway.
set(null);
return true;
}
return false;
}
};
private final AtomicBoolean startedOrCancelled = new AtomicBoolean(false);

private final Collection<? extends Future<T>> futures;

public Task(Callable<T> callable, Collection<? extends Future<T>> futures) {
super(callable);
this.futures = futures;
}

List<T> invokeAll(Executor executor) throws IOException {
final int count = futures.size();
// taskId provides the first index of an un-executed task in #futures
final AtomicInteger taskId = new AtomicInteger(0);
// we fork execution count - 1 tasks to execute at least one task on the current thread to
// minimize needless forking and blocking of the current thread
if (count > 1) {
final Runnable work =
() -> {
int id = taskId.getAndIncrement();
if (id < count) {
futures.get(id).run();
}
};
for (int j = 0; j < count - 1; j++) {
executor.execute(work);
}
}
// try to execute as many tasks as possible on the current thread to minimize context
// switching in case of long running concurrent
// tasks as well as dead-locking if the current thread is part of #executor for executors that
// have limited or no parallelism
int id;
while ((id = taskId.getAndIncrement()) < count) {
futures.get(id).run();
if (id >= count - 1) {
// save redundant CAS in case this was the last task
break;
}
@Override
public void run() {
if (startedOrCancelled.compareAndSet(false, true)) {
super.run();
}
Throwable exc = null;
List<T> results = new ArrayList<>(count);
for (int i = 0; i < count; i++) {
Future<T> future = futures.get(i);
try {
results.add(future.get());
} catch (InterruptedException e) {
exc = IOUtils.useOrSuppress(exc, new ThreadInterruptedException(e));
} catch (ExecutionException e) {
exc = IOUtils.useOrSuppress(exc, e.getCause());
}
}
assert assertAllFuturesCompleted() : "Some tasks are still running?";
if (exc != null) {
throw IOUtils.rethrowAlways(exc);
}
return results;
}

private boolean assertAllFuturesCompleted() {
for (RunnableFuture<T> future : futures) {
if (future.isDone() == false) {
return false;
}
}
return true;
@Override
protected void setException(Throwable t) {
super.setException(t);
cancelAll(futures);
}

private void cancelAll() {
for (Future<T> future : futures) {
future.cancel(false);
@Override
public boolean cancel(boolean mayInterruptIfRunning) {
assert mayInterruptIfRunning == false : "cancelling tasks that are running is not supported";
/*
Future#get (called in #collectResults) throws CancellationException when invoked against a running task that has been cancelled but
leaves the task running. We rather want to make sure that invokeAll does not leave any running tasks behind when it returns.
Overriding cancel ensures that tasks that are already started will complete normally once cancelled, and Future#get will
wait for them to finish instead of throwing CancellationException. A cleaner way would have been to override FutureTask#get and
make it wait for cancelled tasks, but FutureTask#awaitDone is private. Tasks that are cancelled before they are started will be no-op.
*/
if (startedOrCancelled.compareAndSet(false, true)) {
// task is cancelled hence it has no results to return. That's fine: they would be
// ignored anyway.
set(null);
return true;
}
return false;
}
}
}