diff --git a/include/API/api.h b/include/API/api.h index 1173b5ff4..745067beb 100644 --- a/include/API/api.h +++ b/include/API/api.h @@ -125,6 +125,11 @@ inline int asMainReturnCode(llvm::Error err) { /// @param treatWarningsAsErrors return errors in place of warnings /// @param diagnosticCb an optional callback that will receive emitted /// diagnostics +/// @param numberOfThreads number of threads to use in binding +/// -1 : number of cpus +/// 0 : threading disabled +/// > 0: limit +/// defaults to -1 /// @return 0 on success int bindArguments(std::string_view target, qssc::config::EmitAction action, std::string_view configPath, std::string_view moduleInput, @@ -132,7 +137,8 @@ int bindArguments(std::string_view target, qssc::config::EmitAction action, std::unordered_map const &arguments, bool treatWarningsAsErrors, bool enableInMemoryInput, std::string *inMemoryOutput, - const OptDiagnosticCallback &onDiagnostic); + const OptDiagnosticCallback &onDiagnostic, + int numberOfThreads = -1); } // namespace qssc #endif // QSS_COMPILER_LIB_H diff --git a/include/Arguments/Arguments.h b/include/Arguments/Arguments.h index 503111ebe..d674c9b59 100644 --- a/include/Arguments/Arguments.h +++ b/include/Arguments/Arguments.h @@ -79,13 +79,12 @@ class BindArgumentsImplementationFactory { }; // TODO generalize type of arguments -llvm::Error bindArguments(llvm::StringRef moduleInput, - llvm::StringRef payloadOutputPath, - ArgumentSource const &arguments, - bool treatWarningsAsErrors, bool enableInMemoryInput, - std::string *inMemoryOutput, - BindArgumentsImplementationFactory &factory, - const OptDiagnosticCallback &onDiagnostic); +llvm::Error +bindArguments(llvm::StringRef moduleInput, llvm::StringRef payloadOutputPath, + ArgumentSource const &arguments, bool treatWarningsAsErrors, + bool enableInMemoryInput, std::string *inMemoryOutput, + BindArgumentsImplementationFactory &factory, + const OptDiagnosticCallback &onDiagnostic, int numberOfThreads); } // namespace qssc::arguments diff --git
a/lib/API/api.cpp b/lib/API/api.cpp index 281a5aaf2..2d30b31da 100644 --- a/lib/API/api.cpp +++ b/lib/API/api.cpp @@ -811,7 +811,8 @@ bindArguments_(std::string_view target, qssc::config::EmitAction action, std::unordered_map const &arguments, bool treatWarningsAsErrors, bool enableInMemoryInput, std::string *inMemoryOutput, - const qssc::OptDiagnosticCallback &onDiagnostic) { + const qssc::OptDiagnosticCallback &onDiagnostic, + int numberOfThreads) { MLIRContext context{}; @@ -850,7 +851,8 @@ bindArguments_(std::string_view target, qssc::config::EmitAction action, *factory.value(); return qssc::arguments::bindArguments( moduleInput, payloadOutputPath, source, treatWarningsAsErrors, - enableInMemoryInput, inMemoryOutput, factoryRef, onDiagnostic); + enableInMemoryInput, inMemoryOutput, factoryRef, onDiagnostic, + numberOfThreads); } } // anonymous namespace @@ -862,12 +864,12 @@ int qssc::bindArguments( std::unordered_map const &arguments, bool treatWarningsAsErrors, bool enableInMemoryInput, std::string *inMemoryOutput, - const qssc::OptDiagnosticCallback &onDiagnostic) { + const qssc::OptDiagnosticCallback &onDiagnostic, int numberOfThreads) { - if (auto err = - bindArguments_(target, action, configPath, moduleInput, - payloadOutputPath, arguments, treatWarningsAsErrors, - enableInMemoryInput, inMemoryOutput, onDiagnostic)) { + if (auto err = bindArguments_( + target, action, configPath, moduleInput, payloadOutputPath, arguments, + treatWarningsAsErrors, enableInMemoryInput, inMemoryOutput, + onDiagnostic, numberOfThreads)) { llvm::logAllUnhandledErrors(std::move(err), llvm::errs()); return 1; } diff --git a/lib/Arguments/Arguments.cpp b/lib/Arguments/Arguments.cpp index e42047e87..1175d5be2 100644 --- a/lib/Arguments/Arguments.cpp +++ b/lib/Arguments/Arguments.cpp @@ -28,13 +28,20 @@ #include "llvm/Support/FileSystem.h" #include "llvm/Support/raw_ostream.h" +#include #include +#include #include #include +#include +#include #include #include +#include #include 
+#include #include +#include +#include namespace qssc::arguments { @@ -44,9 +51,43 @@ llvm::Error updateParameters(qssc::payload::PatchablePayload *payload, Signature &sig, ArgumentSource const &arguments, bool treatWarningsAsErrors, BindArgumentsImplementationFactory &factory, - const OptDiagnosticCallback &onDiagnostic) { + const OptDiagnosticCallback &onDiagnostic, + int numberOfThreads) { - for (const auto &[binaryName, patchPoints] : sig.patchPointsByBinary) { + std::deque threads; + std::vector> binaries; + + bool const enableThreads = (numberOfThreads != 0); + uint MAX_NUM_THREADS = (numberOfThreads > 0) + ? numberOfThreads + : std::thread::hardware_concurrency(); + + // if failed to detect number of CPUs default to 10 + if (MAX_NUM_THREADS == 0) + MAX_NUM_THREADS = 10; + + std::mutex errorMutex; + bool errorSet = false; + llvm::Error firstError = llvm::Error::success(); + + // the onDiagnostic method used to emit diagnostics to python + // is not thread safe + // set up a local callback to capture the highest level diagnostic + // and re-emit from the main thread if threading is being used + std::optional localDiagValue = std::nullopt; + std::optional const localCallback = + std::optional(std::function([&](const Diagnostic &diag) { + if (!localDiagValue.has_value() || + localDiagValue.value().severity < diag.severity) { + localDiagValue = diag; + } + })); + + uint numThreads = 0; + for (const auto &entry : sig.patchPointsByBinary) { + + const auto &binaryName = entry.first; + const auto &patchPoints = entry.second; if (patchPoints.size() == 0) // no patch points continue; @@ -63,25 +104,78 @@ llvm::Error updateParameters(qssc::payload::PatchablePayload *payload, auto &binaryData = binaryDataOrErr.get(); + // onDiagnostic callback is not thread safe + auto localDiagnostic = (enableThreads) ?
localCallback : onDiagnostic; + auto binary = std::shared_ptr( - factory.create(binaryData, onDiagnostic)); + factory.create(binaryData, localDiagnostic)); binary->setTreatWarningsAsErrors(treatWarningsAsErrors); - for (auto const &patchPoint : patchPoints) - if (auto err = binary->patch(patchPoint, arguments)) - return err; + if (enableThreads) { + // save shared point in vector to ensure lifetime exceeds the thread + binaries.emplace_back(binary); + + numThreads++; + if (numThreads > MAX_NUM_THREADS) { + // wait for a thread to finish before starting another + auto &t = threads[0]; + t.join(); + threads.pop_front(); + } + threads.emplace_back([&, binary] { + if (errorSet) + return; + + for (auto const &patchPoint : patchPoints) { + auto err = binary->patch(patchPoint, arguments); + if (err && !errorSet) { + const std::lock_guard lock(errorMutex); + firstError = std::move(err); + errorSet = true; + return; + } + } + }); + } else { + // processing patch points on main thread + for (auto const &patchPoint : patchPoints) + if (auto err = binary->patch(patchPoint, arguments)) + return err; + } + } + + if (enableThreads) { + for (auto &t : threads) + t.join(); + + binaries.clear(); + + if (errorSet || localDiagValue.has_value()) { + // emit error or warning via onDiagnostic if + // one was set + auto *diagnosticCallback = + onDiagnostic.has_value() ? &onDiagnostic.value() : nullptr; + if (diagnosticCallback && localDiagValue.has_value()) + (*diagnosticCallback)(localDiagValue.value()); + // possibly return the error + auto minLevel = + (treatWarningsAsErrors) ? 
Severity::Info : Severity::Warning; + if (localDiagValue.has_value() && + (localDiagValue.value().severity > minLevel)) { + return firstError; + } + } } return llvm::Error::success(); } -llvm::Error bindArguments(llvm::StringRef moduleInput, - llvm::StringRef payloadOutputPath, - ArgumentSource const &arguments, - bool treatWarningsAsErrors, bool enableInMemoryInput, - std::string *inMemoryOutput, - BindArgumentsImplementationFactory &factory, - const OptDiagnosticCallback &onDiagnostic) { +llvm::Error +bindArguments(llvm::StringRef moduleInput, llvm::StringRef payloadOutputPath, + ArgumentSource const &arguments, bool treatWarningsAsErrors, + bool enableInMemoryInput, std::string *inMemoryOutput, + BindArgumentsImplementationFactory &factory, + const OptDiagnosticCallback &onDiagnostic, int numberOfThreads) { bool const enableInMemoryOutput = payloadOutputPath == ""; @@ -134,7 +228,8 @@ llvm::Error bindArguments(llvm::StringRef moduleInput, return err; if (auto err = updateParameters(payload.get(), sigOrError.get(), arguments, - treatWarningsAsErrors, factory, onDiagnostic)) + treatWarningsAsErrors, factory, onDiagnostic, + numberOfThreads)) return err; // setup linked payload I/O diff --git a/python_lib/qss_compiler/lib.cpp b/python_lib/qss_compiler/lib.cpp index 865a112ae..451608af0 100644 --- a/python_lib/qss_compiler/lib.cpp +++ b/python_lib/qss_compiler/lib.cpp @@ -227,14 +227,15 @@ py::tuple py_link_file(const std::string &input, const bool enableInMemoryInput, const std::string &configPath, const std::unordered_map &arguments, bool treatWarningsAsErrors, - qssc::DiagnosticCallback onDiagnostic) { + qssc::DiagnosticCallback onDiagnostic, + int numberOfThreads = -1) { std::string inMemoryOutput(""); int const status = qssc::bindArguments( target, qssc::config::EmitAction::QEM, configPath, input, outputPath, arguments, treatWarningsAsErrors, enableInMemoryInput, &inMemoryOutput, - std::move(onDiagnostic)); + std::move(onDiagnostic), numberOfThreads); bool 
const success = status == 0; #ifndef NDEBUG diff --git a/python_lib/qss_compiler/link.py b/python_lib/qss_compiler/link.py index 7244614e1..fc00ccbfd 100644 --- a/python_lib/qss_compiler/link.py +++ b/python_lib/qss_compiler/link.py @@ -53,6 +53,13 @@ class LinkOptions: """Treat link warnings as errors""" on_diagnostic: Optional[Callable[[Diagnostic], Any]] = None """Optional callback for processing diagnostic messages from the linker.""" + number_of_threads: int = -1 + """Number of threads to use in linking + -1 = number of cpus reported, + 0 = threading disabled, + > 0 = limit, + default = -1 + """ def _prepare_link_options(link_options: Optional[LinkOptions] = None, **kwargs) -> LinkOptions: @@ -133,6 +140,7 @@ def on_diagnostic(diag): link_options.arguments, link_options.treat_warnings_as_errors, link_options.on_diagnostic, + link_options.number_of_threads, ) if not success: exception_mapping = { diff --git a/releasenotes/notes/link-threading-d295374d01595205.yaml b/releasenotes/notes/link-threading-d295374d01595205.yaml new file mode 100644 index 000000000..de2038826 --- /dev/null +++ b/releasenotes/notes/link-threading-d295374d01595205.yaml @@ -0,0 +1,5 @@ +--- +features: + - | + Adds `number_of_threads` parameter to qss_compiler.link_file + interface to control the number of threads used during linking.