From d6ad19e7012b46b3d9c6b14d7b4ba3191b6364ba Mon Sep 17 00:00:00 2001 From: Mahmoud Abuzaina Date: Tue, 30 May 2023 18:46:05 -0700 Subject: [PATCH] Fixing performance regression when user enables running on caller thread --- tensorflow/core/util/mkl_threadpool.h | 56 +++++++++++++++------------ 1 file changed, 31 insertions(+), 25 deletions(-) diff --git a/tensorflow/core/util/mkl_threadpool.h b/tensorflow/core/util/mkl_threadpool.h index cedc828b98b9d3..488c9f6ad29b45 100644 --- a/tensorflow/core/util/mkl_threadpool.h +++ b/tensorflow/core/util/mkl_threadpool.h @@ -25,8 +25,8 @@ limitations under the License. #include #include -#include "dnnl_threadpool.hpp" #include "dnnl.hpp" +#include "dnnl_threadpool.hpp" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/platform/blocking_counter.h" #include "tensorflow/core/platform/cpu_info.h" @@ -116,33 +116,39 @@ struct MklDnnThreadPool : public threadpool_iface { const bool use_caller_thread = ThreadPoolUseCallerThread() && nthr == port::NumSchedulableCPUs(); const int njobs_to_schedule = use_caller_thread ? njobs - 1 : njobs; - - BlockingCounter counter(njobs_to_schedule); - std::function handle_range = [=, &handle_range, &counter]( - int first, int last) { - while (last - first > 1) { - const auto mid = first + (last - first) / 2; - // Find something near the midpoint which is a multiple of block size. - eigen_interface_->ScheduleWithHint([=]() { handle_range(mid, last); }, - mid, mid + 1); - last = mid; - } - counter.DecrementCount(); - run_jobs(balance, first, n, njobs, fn); - }; - - // Eigen avoids a thread hop by running the root of the tree on the main - // thread. We have disabled this because it actually slows things down - // relative to base because base cheats and uses n threads while letting - // main continue doing other work - eigen_interface_->ScheduleWithHint( - [=]() { handle_range(0, njobs_to_schedule); }, 0, 1); - if (use_caller_thread) { + for (int i = 0; i < njobs_to_schedule; i++) { + eigen_interface_->ScheduleWithHint( + [balance, i, n, njobs, fn]() { + run_jobs(balance, i, n, njobs, fn); + }, + i, i + 1); + } run_jobs(balance, njobs - 1, n, njobs, fn); + } else { + BlockingCounter counter(njobs); + std::function handle_range = [=, &handle_range, &counter]( + int first, int last) { + while (last - first > 1) { + const auto mid = first + (last - first) / 2; + // Find something near the midpoint which is a multiple of block size. + eigen_interface_->ScheduleWithHint([=]() { handle_range(mid, last); }, + mid, mid + 1); + last = mid; + } + counter.DecrementCount(); + run_jobs(balance, first, n, njobs, fn); + }; + + // Eigen avoids a thread hop by running the root of the tree on the main + // thread. We have disabled this because it actually slows things down + // relative to base because base cheats and uses n threads while letting + // main continue doing other work + eigen_interface_->ScheduleWithHint([=]() { handle_range(0, njobs); }, 0, + 1); + + counter.Wait(); } - - counter.Wait(); } ~MklDnnThreadPool() {}