Skip to content

Commit

Permalink
add threading to updateParameters (#332)
Browse files Browse the repository at this point in the history
Adds threading support to the updateParameters method used in
bindArguments. This will create threads for the patch point expression
conversion from string to a calculated value.
  • Loading branch information
bcdonovan committed May 28, 2024
1 parent 250c634 commit cb4fe84
Show file tree
Hide file tree
Showing 7 changed files with 147 additions and 31 deletions.
8 changes: 7 additions & 1 deletion include/API/api.h
Original file line number Diff line number Diff line change
Expand Up @@ -125,14 +125,20 @@ inline int asMainReturnCode(llvm::Error err) {
/// @param treatWarningsAsErrors return errors in place of warnings
/// @param diagnosticCb an optional callback that will receive emitted
/// diagnostics
/// @param numberOfThreads number of threads to use in binding
/// -1 : number of cpus
/// 0
/// > 0: limit
/// defaults to -1
/// @return 0 on success
int bindArguments(std::string_view target, qssc::config::EmitAction action,
std::string_view configPath, std::string_view moduleInput,
std::string_view payloadOutputPath,
std::unordered_map<std::string, double> const &arguments,
bool treatWarningsAsErrors, bool enableInMemoryInput,
std::string *inMemoryOutput,
const OptDiagnosticCallback &onDiagnostic);
const OptDiagnosticCallback &onDiagnostic,
int numberOfThreads = -1);

} // namespace qssc
#endif // QSS_COMPILER_LIB_H
13 changes: 6 additions & 7 deletions include/Arguments/Arguments.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,13 +79,12 @@ class BindArgumentsImplementationFactory {
};

// TODO generalize type of arguments
llvm::Error bindArguments(llvm::StringRef moduleInput,
llvm::StringRef payloadOutputPath,
ArgumentSource const &arguments,
bool treatWarningsAsErrors, bool enableInMemoryInput,
std::string *inMemoryOutput,
BindArgumentsImplementationFactory &factory,
const OptDiagnosticCallback &onDiagnostic);
llvm::Error
bindArguments(llvm::StringRef moduleInput, llvm::StringRef payloadOutputPath,
ArgumentSource const &arguments, bool treatWarningsAsErrors,
bool enableInMemoryInput, std::string *inMemoryOutput,
BindArgumentsImplementationFactory &factory,
const OptDiagnosticCallback &onDiagnostic, int numberOfThreads);

} // namespace qssc::arguments

Expand Down
16 changes: 9 additions & 7 deletions lib/API/api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -811,7 +811,8 @@ bindArguments_(std::string_view target, qssc::config::EmitAction action,
std::unordered_map<std::string, double> const &arguments,
bool treatWarningsAsErrors, bool enableInMemoryInput,
std::string *inMemoryOutput,
const qssc::OptDiagnosticCallback &onDiagnostic) {
const qssc::OptDiagnosticCallback &onDiagnostic,
int numberOfThreads) {

MLIRContext context{};

Expand Down Expand Up @@ -850,7 +851,8 @@ bindArguments_(std::string_view target, qssc::config::EmitAction action,
*factory.value();
return qssc::arguments::bindArguments(
moduleInput, payloadOutputPath, source, treatWarningsAsErrors,
enableInMemoryInput, inMemoryOutput, factoryRef, onDiagnostic);
enableInMemoryInput, inMemoryOutput, factoryRef, onDiagnostic,
numberOfThreads);
}

} // anonymous namespace
Expand All @@ -862,12 +864,12 @@ int qssc::bindArguments(
std::unordered_map<std::string, double> const &arguments,
bool treatWarningsAsErrors, bool enableInMemoryInput,
std::string *inMemoryOutput,
const qssc::OptDiagnosticCallback &onDiagnostic) {
const qssc::OptDiagnosticCallback &onDiagnostic, int numberOfThreads) {

if (auto err =
bindArguments_(target, action, configPath, moduleInput,
payloadOutputPath, arguments, treatWarningsAsErrors,
enableInMemoryInput, inMemoryOutput, onDiagnostic)) {
if (auto err = bindArguments_(
target, action, configPath, moduleInput, payloadOutputPath, arguments,
treatWarningsAsErrors, enableInMemoryInput, inMemoryOutput,
onDiagnostic, numberOfThreads)) {
llvm::logAllUnhandledErrors(std::move(err), llvm::errs());
return 1;
}
Expand Down
123 changes: 109 additions & 14 deletions lib/Arguments/Arguments.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,20 @@
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/raw_ostream.h"

#include <deque>
#include <fstream>
#include <functional>
#include <iostream>
#include <memory>
#include <mutex>
#include <optional>
#include <sstream>
#include <string>
#include <sys/types.h>
#include <system_error>
#include <thread>
#include <utility>
#include <vector>

namespace qssc::arguments {

Expand All @@ -44,9 +51,43 @@ llvm::Error updateParameters(qssc::payload::PatchablePayload *payload,
Signature &sig, ArgumentSource const &arguments,
bool treatWarningsAsErrors,
BindArgumentsImplementationFactory &factory,
const OptDiagnosticCallback &onDiagnostic) {
const OptDiagnosticCallback &onDiagnostic,
int numberOfThreads) {

for (const auto &[binaryName, patchPoints] : sig.patchPointsByBinary) {
std::deque<std::thread> threads;
std::vector<std::shared_ptr<BindArgumentsImplementation>> binaries;

bool const enableThreads = (numberOfThreads != 0);
uint MAX_NUM_THREADS = (numberOfThreads > 0)
? numberOfThreads
: std::thread::hardware_concurrency();

// if failed to detect number of CPUs default to 10
if (MAX_NUM_THREADS == 0)
MAX_NUM_THREADS = 10;

std::mutex errorMutex;
bool errorSet = false;
llvm::Error firstError = llvm::Error::success();

// the onDiagnastic method used to emit diagnostics to python
// is not thread safe
// setup of local callback to capture the highest level diagnostic
// and re-emit from the main thread if threading is being used
std::optional<qssc::Diagnostic> localDiagValue = std::nullopt;
std::optional<DiagnosticCallback> const localCallback =
std::optional(std::function([&](const Diagnostic &diag) {
if (!localDiagValue.has_value() ||
localDiagValue.value().severity < diag.severity) {
localDiagValue = diag;
}
}));

uint numThreads = 0;
for (const auto &entry : sig.patchPointsByBinary) {

const auto &binaryName = entry.first;
const auto &patchPoints = entry.second;

if (patchPoints.size() == 0) // no patch points
continue;
Expand All @@ -63,25 +104,78 @@ llvm::Error updateParameters(qssc::payload::PatchablePayload *payload,

auto &binaryData = binaryDataOrErr.get();

// onDiagnostic callback is not thread safe
auto localDiagnostic = (enableThreads) ? localCallback : onDiagnostic;

auto binary = std::shared_ptr<BindArgumentsImplementation>(
factory.create(binaryData, onDiagnostic));
factory.create(binaryData, localDiagnostic));
binary->setTreatWarningsAsErrors(treatWarningsAsErrors);

for (auto const &patchPoint : patchPoints)
if (auto err = binary->patch(patchPoint, arguments))
return err;
if (enableThreads) {
// save shared point in vector to ensure lifetime exceeds the thread
binaries.emplace_back(binary);

numThreads++;
if (numThreads > MAX_NUM_THREADS) {
// wait for a thread to finish before starting another
auto &t = threads[0];
t.join();
threads.pop_front();
}
threads.emplace_back([&, binary] {
if (errorSet)
return;

for (auto const &patchPoint : patchPoints) {
auto err = binary->patch(patchPoint, arguments);
if (err && !errorSet) {
const std::lock_guard<std::mutex> lock(errorMutex);
firstError = std::move(err);
errorSet = true;
return;
}
}
});
} else {
// processing patch points on main thread
for (auto const &patchPoint : patchPoints)
if (auto err = binary->patch(patchPoint, arguments))
return err;
}
}

if (enableThreads) {
for (auto &t : threads)
t.join();

binaries.clear();

if (errorSet || localDiagValue.has_value()) {
// emit error or warning via onDiagnostic if
// one was set
auto *diagnosticCallback =
onDiagnostic.has_value() ? &onDiagnostic.value() : nullptr;
if (diagnosticCallback && localDiagValue.has_value())
(*diagnosticCallback)(localDiagValue.value());
// possibly return the error
auto minLevel =
(treatWarningsAsErrors) ? Severity::Info : Severity::Warning;
if (localDiagValue.has_value() &&
(localDiagValue.value().severity > minLevel)) {
return firstError;
}
}
}

return llvm::Error::success();
}

llvm::Error bindArguments(llvm::StringRef moduleInput,
llvm::StringRef payloadOutputPath,
ArgumentSource const &arguments,
bool treatWarningsAsErrors, bool enableInMemoryInput,
std::string *inMemoryOutput,
BindArgumentsImplementationFactory &factory,
const OptDiagnosticCallback &onDiagnostic) {
llvm::Error
bindArguments(llvm::StringRef moduleInput, llvm::StringRef payloadOutputPath,
ArgumentSource const &arguments, bool treatWarningsAsErrors,
bool enableInMemoryInput, std::string *inMemoryOutput,
BindArgumentsImplementationFactory &factory,
const OptDiagnosticCallback &onDiagnostic, int numberOfThreads) {

bool const enableInMemoryOutput = payloadOutputPath == "";

Expand Down Expand Up @@ -134,7 +228,8 @@ llvm::Error bindArguments(llvm::StringRef moduleInput,
return err;

if (auto err = updateParameters(payload.get(), sigOrError.get(), arguments,
treatWarningsAsErrors, factory, onDiagnostic))
treatWarningsAsErrors, factory, onDiagnostic,
numberOfThreads))
return err;

// setup linked payload I/O
Expand Down
5 changes: 3 additions & 2 deletions python_lib/qss_compiler/lib.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -227,14 +227,15 @@ py::tuple py_link_file(const std::string &input, const bool enableInMemoryInput,
const std::string &configPath,
const std::unordered_map<std::string, double> &arguments,
bool treatWarningsAsErrors,
qssc::DiagnosticCallback onDiagnostic) {
qssc::DiagnosticCallback onDiagnostic,
int numberOfThreads = -1) {

std::string inMemoryOutput("");

int const status = qssc::bindArguments(
target, qssc::config::EmitAction::QEM, configPath, input, outputPath,
arguments, treatWarningsAsErrors, enableInMemoryInput, &inMemoryOutput,
std::move(onDiagnostic));
std::move(onDiagnostic), numberOfThreads);

bool const success = status == 0;
#ifndef NDEBUG
Expand Down
8 changes: 8 additions & 0 deletions python_lib/qss_compiler/link.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,13 @@ class LinkOptions:
"""Treat link warnings as errors"""
on_diagnostic: Optional[Callable[[Diagnostic], Any]] = None
"""Optional callback for processing diagnostic messages from the linker."""
number_of_threads: int = -1
"""Number of threads to use in linking
-1 = number of cpus reported,
0 disabled,
> 1 = limit,
defaults = -1
"""


def _prepare_link_options(link_options: Optional[LinkOptions] = None, **kwargs) -> LinkOptions:
Expand Down Expand Up @@ -133,6 +140,7 @@ def on_diagnostic(diag):
link_options.arguments,
link_options.treat_warnings_as_errors,
link_options.on_diagnostic,
link_options.number_of_threads,
)
if not success:
exception_mapping = {
Expand Down
5 changes: 5 additions & 0 deletions releasenotes/notes/link-threading-d295374d01595205.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
features:
- |
Adds `number_of_threads` parameter to qss_compiler.link_file
interface to control number of threads used during linking.

0 comments on commit cb4fe84

Please sign in to comment.