Skip to content

Commit

Permalink
Enable reuse of state container for repeated parallel_for invocations
Browse files Browse the repository at this point in the history
Summary: As title.

Reviewed By: EscapeZero

Differential Revision: D50068872

fbshipit-source-id: c413bc756395e8e94f6d3050e3347d8b8bfacc7f
  • Loading branch information
graphicsMan authored and facebook-github-bot committed Oct 11, 2023
1 parent d94708d commit 1fa2e2f
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 5 deletions.
54 changes: 49 additions & 5 deletions dispenso/parallel_for.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,13 @@ struct ParForOptions {
* size is provided to ChunkedRange.
**/
uint32_t minItemsPerChunk = 1;

/**
* When set to false and a StateContainer is supplied to parallel_for, the container is cleared and
* its states are re-created from scratch on each call to parallel_for. When true, existing states
* are reused as much as possible (new states are only created if more are required than are
* already available in the container).
**/
bool reuseExistingState = false;
};

/**
Expand Down Expand Up @@ -203,6 +210,16 @@ struct NoOpIter {
};

struct NoOpContainer {
size_t size() const {
return 0;
}

bool empty() const {
return true;
}

void clear() {}

NoOpIter begin() {
return {};
}
Expand Down Expand Up @@ -234,14 +251,22 @@ void parallel_for_staticImpl(
const ChunkedRange<IntegerT>& range,
F&& f,
ssize_t maxThreads,
bool wait) {
bool wait,
bool reuseExistingState) {
using size_type = typename ChunkedRange<IntegerT>::size_type;

size_type numThreads = std::min<size_type>(taskSet.numPoolThreads(), maxThreads);
// Reduce threads used if they exceed work to be done.
numThreads = std::min(numThreads, range.size()) + wait;

for (size_type i = 0; i < numThreads; ++i) {
if (!reuseExistingState) {
states.clear();
}

size_t numToEmplace =
states.size() < static_cast<size_t>(numThreads) ? numThreads - states.size() : 0;

for (; numToEmplace--;) {
states.emplace_back(defaultState());
}

Expand Down Expand Up @@ -338,7 +363,12 @@ void parallel_for(
const size_type N = taskSet.numPoolThreads();
if (N == 0 || !options.maxThreads || range.size() <= minItemsPerChunk ||
detail::PerPoolPerThreadInfo::isParForRecursive(&taskSet.pool())) {
states.emplace_back(defaultState());
if (!options.reuseExistingState) {
states.clear();
}
if (states.empty()) {
states.emplace_back(defaultState());
}
f(*states.begin(), range.start, range.end);
if (options.wait) {
taskSet.wait();
Expand All @@ -365,13 +395,27 @@ void parallel_for(

if (isStatic) {
detail::parallel_for_staticImpl(
taskSet, states, defaultState, range, std::forward<F>(f), maxThreads, options.wait);
taskSet,
states,
defaultState,
range,
std::forward<F>(f),
maxThreads,
options.wait,
options.reuseExistingState);
return;
}

const size_type numToLaunch = std::min<size_type>(maxThreads, N);

for (size_type i = 0; i < numToLaunch + options.wait; ++i) {
if (!options.reuseExistingState) {
states.clear();
}

size_t numToEmplace = static_cast<size_type>(states.size()) < (numToLaunch + options.wait)
? (numToLaunch + options.wait) - states.size()
: 0;
for (; numToEmplace--;) {
states.emplace_back(defaultState());
}

Expand Down
48 changes: 48 additions & 0 deletions tests/chunked_for_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -308,3 +308,51 @@ TEST(ChunkedFor, MinChunkSizeLoopStatic) {
minChunkSize(dispenso::ParForChunking::kStatic, 1000000, 10000000, 20000);
minChunkSize(dispenso::ParForChunking::kStatic, -10000000, -1000000, 20000);
}

// Checks ParForOptions::reuseExistingState for the given StateContainer type.
//
// A 1024x1024 image filled with 7s is summed three times by a stateful parallel_for over a chunked
// range, with the same state container passed to every invocation. The values left in the
// container are then added up and expected to equal three full-image sums.
template <typename StateContainer>
void loopWithStateImplReuseState() {
  int w = 1024;
  int h = 1024;
  std::vector<int> image(static_cast<size_t>(w * h), 7);

  dispenso::ParForOptions options;
  options.reuseExistingState = true;

  // Declared outside the loop so that one container is shared by all three invocations.
  StateContainer state;

  for (int pass = 0; pass != 3; ++pass) {
    dispenso::parallel_for(
        state,
        []() { return int64_t{0}; },
        dispenso::makeChunkedRange(0, h, 16),
        [w, &image](int64_t& sum, int ystart, int yend) {
          EXPECT_EQ(yend - ystart, 16);
          // Accumulate locally, then fold into the state with a single update.
          int64_t chunkTotal = 0;
          for (int y = ystart; y < yend; ++y) {
            const int* rowBegin = image.data() + y * w;
            const int* rowEnd = rowBegin + w;
            for (const int* px = rowBegin; px != rowEnd; ++px) {
              chunkTotal += *px;
            }
          }
          sum += chunkTotal;
        },
        options);
  }

  // Combine whatever each state accumulated over the three passes.
  int64_t sum = 0;
  for (auto it = state.begin(); it != state.end(); ++it) {
    sum += *it;
  }

  EXPECT_EQ(sum, 3 * w * h * 7);
}

// Run the state-reuse check with std::deque, std::vector, and std::list as the state container.
TEST(ChunkedFor, LoopWithDequeStateReuse) {
loopWithStateImplReuseState<std::deque<int64_t>>();
}
TEST(ChunkedFor, LoopWithVectorStateReuse) {
loopWithStateImplReuseState<std::vector<int64_t>>();
}
TEST(ChunkedFor, LoopWithListStateReuse) {
loopWithStateImplReuseState<std::list<int64_t>>();
}

0 comments on commit 1fa2e2f

Please sign in to comment.