Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PL-133119] Topological sort before rewriting aggregate jobs #1

Merged
merged 1 commit into from
Nov 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
335 changes: 217 additions & 118 deletions src/hydra-eval-jobs/hydra-eval-jobs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,212 @@ static void worker(
writeLine(to.get(), "restart");
}

/* Exception raised when two aggregate jobs are found to depend on each
   other while sorting them.

   `a` and `b` are the names of the two jobs between which the cycle was
   detected; `remainingAggregates` holds the names of the other
   aggregates that the thrower reports alongside the cycle.

   The description is available in two ways:
     - `std::string what()` on a non-const object (the historical
       accessor, kept so existing callers continue to work), and
     - `const char * what() const noexcept`, which overrides
       `std::exception::what()` so that code catching this as a plain
       `std::exception` sees the real message instead of the generic
       one supplied by the base class. */
struct DependencyCycle : public std::exception {
    std::string a;
    std::string b;
    std::set<std::string> remainingAggregates;

    DependencyCycle(const std::string & a, const std::string & b, const std::set<std::string> & remainingAggregates)
        : a(a)
        , b(b)
        , remainingAggregates(remainingAggregates)
        , message(describe(a, b))
    {}

    /* Returns the description as an owned string, built from the
       current values of `a` and `b`. */
    std::string what() {
        return describe(a, b);
    }

    /* Override of std::exception::what(); returns the description that
       was computed when the exception was constructed. */
    const char * what() const noexcept override {
        return message.c_str();
    }

private:
    /* Description captured at construction time. It has to be stored
       because the override above returns a pointer that must stay
       valid after the call returns. */
    std::string message;

    static std::string describe(const std::string & first, const std::string & second) {
        return "Dependency cycle: " + first + " <-> " + second;
    }
};

/* An aggregate job together with what is known about its constituents.

   This is a plain aggregate: it is built by listing `name`,
   `dependencies` and `brokenJobs` in that order. */
struct AggregateJob
{
    // Name of the aggregate job.
    std::string name;
    // Names of the jobs this aggregate depends on.
    std::set<std::string> dependencies;
    // Constituent name (or pattern) -> reason it is considered broken.
    std::unordered_map<std::string, std::string> brokenJobs;

    /* Ordering is by `name` only, so an ordered container of
       AggregateJob holds at most one entry per name. */
    bool operator<(const AggregateJob & other) const
    {
        return name.compare(other.name) < 0;
    }
};

// This is copied from `libutil/topo-sort.hh` in CppNix and slightly modified.
// However, I needed a way to use strings as identifiers to sort, but still be able
// to put AggregateJob objects into this function since I'd rather not
// have to transform back and forth between a list of strings and AggregateJobs
// in resolveNamedConstituents.
std::vector<AggregateJob> topoSort(std::set<AggregateJob> items)
{
std::vector<AggregateJob> sorted;
std::set<std::string> visited, parents;

std::map<std::string, AggregateJob> dictIdentToObject;
for (auto & it : items) {
dictIdentToObject.insert({it.name, it});
}

std::function<void(const std::string & path, const std::string * parent)> dfsVisit;

dfsVisit = [&](const std::string & path, const std::string * parent) {
if (parents.count(path)) {
dictIdentToObject.erase(path);
dictIdentToObject.erase(*parent);
std::set<std::string> remaining;
for (auto & [k, _] : dictIdentToObject) {
remaining.insert(k);
}
throw DependencyCycle(path, *parent, remaining);
}

if (!visited.insert(path).second) return;
parents.insert(path);

std::set<std::string> references = dictIdentToObject[path].dependencies;

for (auto & i : references)
/* Don't traverse into items that don't exist in our starting set. */
if (i != path && dictIdentToObject.find(i) != dictIdentToObject.end())
dfsVisit(i, &path);

sorted.push_back(dictIdentToObject[path]);
parents.erase(path);
};

for (auto & [i, _] : dictIdentToObject)
dfsVisit(i, nullptr);

return sorted;
}

/* Expand the glob pattern `childJobName` against the keys of `jobs`.

   Every job whose name matches the pattern (per fnmatch with no flags)
   and for which `isBroken` returns false is inserted into `results`.
   `isBroken` is only invoked for jobs whose name matched.

   `jobName` is the aggregate the pattern belongs to. It is never
   selected, whatever the pattern is: earlier this exclusion was applied
   only for the pattern "*", so a pattern such as "agg*" could make an
   aggregate a constituent of itself. Review feedback on the original
   change was: "I think we can drop the "*" comparison and just always
   filter out the parent job."

   Returns true if at least one matching, non-broken job was found. */
static bool insertMatchingConstituents(const std::string & childJobName,
    const std::string & jobName,
    std::function<bool(const std::string &, nlohmann::json &)> isBroken,
    nlohmann::json & jobs,
    std::set<std::string> & results)
{
    bool expansionFound = false;
    for (auto job = jobs.begin(); job != jobs.end(); ++job) {
        const std::string & candidateName = job.key();

        // Never select the aggregate itself as its own constituent.
        if (candidateName == jobName)
            continue;

        if (fnmatch(childJobName.c_str(), candidateName.c_str(), 0) == 0 && !isBroken(candidateName, *job)) {
            results.insert(candidateName);
            expansionFound = true;
        }
    }

    return expansionFound;
}

/* Build one AggregateJob for every entry of `jobs` that carries a
   "namedConstituents" attribute and return them topologically sorted.

   Each named constituent is resolved as follows:
     - if a job with exactly that name exists, it becomes a dependency
       unless it has an "error" attribute;
     - otherwise, if the aggregate has "globConstituents" set to true,
       the name is treated as a glob pattern and expanded via
       insertMatchingConstituents;
     - otherwise it is recorded as a non-existent job.
   Everything that could not be used is recorded, with a reason, in the
   AggregateJob's `brokenJobs`.

   Exceptions thrown by topoSort propagate to the caller. */
static std::vector<AggregateJob> resolveNamedConstituents(nlohmann::json & jobs)
{
    std::set<AggregateJob> collected;

    for (auto entry = jobs.begin(); entry != jobs.end(); ++entry) {
        auto jobName = entry.key();
        auto & job = entry.value();

        auto named = job.find("namedConstituents");
        if (named == job.end())
            continue;

        bool globConstituents = job.value<bool>("globConstituents", false);
        std::unordered_map<std::string, std::string> brokenJobs;
        std::set<std::string> results;

        // Reports and records a candidate that carries an "error"
        // attribute; returns whether it did.
        auto isBroken = [&brokenJobs, &jobName](
            const std::string & candidateName, nlohmann::json & candidate) -> bool {
            if (candidate.find("error") == candidate.end())
                return false;

            std::string error = candidate["error"];
            printError("aggregate job '%s' references broken job '%s': %s", jobName, candidateName, error);
            brokenJobs[candidateName] = error;
            return true;
        };

        for (const std::string & childJobName : *named) {
            auto childJob = jobs.find(childJobName);

            // Exact name match: use it directly.
            if (childJob != jobs.end()) {
                if (!isBroken(childJobName, *childJob))
                    results.insert(childJobName);
                continue;
            }

            if (globConstituents) {
                if (!insertMatchingConstituents(childJobName, jobName, isBroken, jobs, results)) {
                    warn("aggregate job '%s' references constituent glob pattern '%s' with no matches", jobName, childJobName);
                    brokenJobs[childJobName] = "constituent glob pattern had no matches";
                }
            } else {
                printError("aggregate job '%s' references non-existent job '%s'", jobName, childJobName);
                brokenJobs[childJobName] = "does not exist";
            }
        }

        collected.insert(AggregateJob{jobName, results, brokenJobs});
    }

    return topoSort(collected);
}

/* Fill in the "constituents" of every aggregate in `aggregateJobs` and,
   outside of dry-run mode, rewrite the aggregate's derivation so that
   it has its constituents as inputs.

   jobs           JSON object of all jobs, keyed by job name; modified
                  in place.
   aggregateJobs  the aggregates to process, handled in the given order.
                  A constituent's "drvPath" is read from `jobs` at the
                  time its aggregate is processed, so an aggregate that
                  was rewritten earlier in the loop is seen with its new
                  path. NOTE(review): presumably this is why callers
                  pass a topologically sorted list -- confirm.
   dryRun         if true, only the constituent drvPath strings are
                  appended; `store` is not used to read or write
                  derivations.
   store          store used to parse paths and read/write derivations. */
static void rewriteAggregates(nlohmann::json & jobs,
std::vector<AggregateJob> aggregateJobs,
bool dryRun,
ref<Store> store)
{
for (auto & aggregateJob : aggregateJobs) {
auto & job = jobs.find(aggregateJob.name).value();
if (dryRun) {
// Dry run: just list the constituents' derivation paths as recorded in `jobs`.
for (auto & childJobName : aggregateJob.dependencies) {
std::string constituentDrvPath = jobs[childJobName]["drvPath"];
job["constituents"].push_back(constituentDrvPath);
}
} else {
auto drvPath = store->parseStorePath((std::string) job["drvPath"]);
auto drv = store->readDerivation(drvPath);

// Record each constituent and add it as an input of the aggregate's
// derivation, selecting the first entry of the constituent's outputs.
for (auto & childJobName : aggregateJob.dependencies) {
auto childDrvPath = store->parseStorePath((std::string) jobs[childJobName]["drvPath"]);
auto childDrv = store->readDerivation(childDrvPath);
job["constituents"].push_back(store->printStorePath(childDrvPath));
drv.inputDrvs.map[childDrvPath].value = {childDrv.outputs.begin()->first};
}

// The derivation is only re-hashed and written back when no constituent was broken.
if (aggregateJob.brokenJobs.empty()) {
std::string drvName(drvPath.name());
assert(hasSuffix(drvName, drvExtension));
drvName.resize(drvName.size() - drvExtension.size());

// Note: both `continue`s below move on to the next aggregate, which also
// skips the "namedConstituents" removal and GC root registration further down.
auto hashModulo = hashDerivationModulo(*store, drv, true);
if (hashModulo.kind != DrvHash::Kind::Regular) continue;
auto h = hashModulo.hashes.find("out");
if (h == hashModulo.hashes.end()) continue;
auto outPath = store->makeOutputPath("out", h->second, drvName);
drv.env["out"] = store->printStorePath(outPath);
drv.outputs.insert_or_assign("out", DerivationOutput::InputAddressed { .path = outPath });
auto newDrvPath = store->printStorePath(writeDerivation(*store, drv));

debug("rewrote aggregate derivation %s -> %s", store->printStorePath(drvPath), newDrvPath);

// Publish the rewritten paths in the job's JSON entry.
job["drvPath"] = newDrvPath;
job["outputs"]["out"] = store->printStorePath(outPath);
}
}

job.erase("namedConstituents");

/* Register the derivation as a GC root. !!! This
registers roots for jobs that we may have already
done. */
auto localStore = store.dynamic_pointer_cast<LocalFSStore>();
if (gcRootsDir != "" && localStore) {
// Re-read "drvPath" from the job: it may have been replaced above.
auto drvPath = job["drvPath"].get<std::string>();
Path root = gcRootsDir + "/" + std::string(baseNameOf(drvPath));
if (!pathExists(root))
localStore->addPermRoot(localStore->parseStorePath(drvPath), root);
}

// Report broken constituents on the aggregate itself, one "name: reason" line each.
if (!aggregateJob.brokenJobs.empty()) {
std::stringstream ss;
for (const auto& [jobName, error] : aggregateJob.brokenJobs) {
ss << jobName << ": " << error << "\n";
}
job["error"] = ss.str();
}
}
}

int main(int argc, char * * argv)
{
/* Prevent undeclared dependencies in the evaluation via
Expand Down Expand Up @@ -494,129 +700,22 @@ int main(int argc, char * * argv)
if (state->exc)
std::rethrow_exception(state->exc);

/* For aggregate jobs that have named consistuents
/* For aggregate jobs that have named constituents
(i.e. constituents that are a job name rather than a
derivation), look up the referenced job and add it to the
dependencies of the aggregate derivation. */
auto store = openStore();

for (auto i = state->jobs.begin(); i != state->jobs.end(); ++i) {
auto jobName = i.key();
auto & job = i.value();

auto named = job.find("namedConstituents");
if (named == job.end()) continue;

bool globConstituents = job.value<bool>("globConstituents", false);

std::unordered_map<std::string, std::string> brokenJobs;
auto isBroken = [&brokenJobs, &jobName](
const std::string & childJobName, nlohmann::json & job) -> bool {
if (job.find("error") != job.end()) {
std::string error = job["error"];
printError("aggregate job '%s' references broken job '%s': %s", jobName, childJobName, error);
brokenJobs[childJobName] = error;
return true;
} else {
return false;
}
};
auto getNonBrokenJobsOrRecordError = [&state, &isBroken, &jobName, &brokenJobs, &globConstituents](
const std::string & childJobName) -> std::vector<nlohmann::json> {
auto childJob = state->jobs.find(childJobName);
std::vector<nlohmann::json> results;
if (childJob == state->jobs.end()) {
if (!globConstituents) {
printError("aggregate job '%s' references non-existent job '%s'", jobName, childJobName);
brokenJobs[childJobName] = "does not exist";
} else {
for (auto job = state->jobs.begin(); job != state->jobs.end(); job++) {
auto jobName = job.key();
if (fnmatch(childJobName.c_str(), jobName.c_str(), 0) == 0
&& !isBroken(jobName, *job)
) {
results.push_back(*job);
}
}
if (results.empty()) {
warn("aggregate job '%s' references constituent glob pattern '%s' with no matches", jobName, childJobName);
brokenJobs[childJobName] = "constituent glob pattern had no matches";
}
}
} else if (!isBroken(childJobName, *childJob)) {
results.push_back(*childJob);
}
return results;
};

if (myArgs.dryRun) {
for (std::string jobName2 : *named) {
auto foundJobs = getNonBrokenJobsOrRecordError(jobName2);
if (foundJobs.empty()) {
continue;
}
for (auto & childJob : foundJobs) {
std::string constituentDrvPath = childJob["drvPath"];
job["constituents"].push_back(constituentDrvPath);
}
}
} else {
auto drvPath = store->parseStorePath((std::string) job["drvPath"]);
auto drv = store->readDerivation(drvPath);

for (std::string jobName2 : *named) {
auto foundJobs = getNonBrokenJobsOrRecordError(jobName2);
if (foundJobs.empty()) {
continue;
}
for (auto & childJob : foundJobs) {
auto childDrvPath = store->parseStorePath((std::string) childJob["drvPath"]);
auto childDrv = store->readDerivation(childDrvPath);
job["constituents"].push_back(store->printStorePath(childDrvPath));
drv.inputDrvs.map[childDrvPath].value = {childDrv.outputs.begin()->first};
}
}

if (brokenJobs.empty()) {
std::string drvName(drvPath.name());
assert(hasSuffix(drvName, drvExtension));
drvName.resize(drvName.size() - drvExtension.size());

auto hashModulo = hashDerivationModulo(*store, drv, true);
if (hashModulo.kind != DrvHash::Kind::Regular) continue;
auto h = hashModulo.hashes.find("out");
if (h == hashModulo.hashes.end()) continue;
auto outPath = store->makeOutputPath("out", h->second, drvName);
drv.env["out"] = store->printStorePath(outPath);
drv.outputs.insert_or_assign("out", DerivationOutput::InputAddressed { .path = outPath });
auto newDrvPath = store->printStorePath(writeDerivation(*store, drv));

debug("rewrote aggregate derivation %s -> %s", store->printStorePath(drvPath), newDrvPath);

job["drvPath"] = newDrvPath;
job["outputs"]["out"] = store->printStorePath(outPath);
}
}

job.erase("namedConstituents");

/* Register the derivation as a GC root. !!! This
registers roots for jobs that we may have already
done. */
auto localStore = store.dynamic_pointer_cast<LocalFSStore>();
if (gcRootsDir != "" && localStore) {
auto drvPath = job["drvPath"].get<std::string>();
Path root = gcRootsDir + "/" + std::string(baseNameOf(drvPath));
if (!pathExists(root))
localStore->addPermRoot(localStore->parseStorePath(drvPath), root);
}

if (!brokenJobs.empty()) {
std::stringstream ss;
for (const auto& [jobName, error] : brokenJobs) {
ss << jobName << ": " << error << "\n";
}
job["error"] = ss.str();
try {
auto namedConstituents = resolveNamedConstituents(state->jobs);
rewriteAggregates(state->jobs, namedConstituents, myArgs.dryRun, store);
} catch (DependencyCycle & e) {
printError("Found dependency cycle between jobs '%s' and '%s'", e.a, e.b);
state->jobs[e.a]["error"] = e.what();
state->jobs[e.b]["error"] = e.what();

for (auto & jobName : e.remainingAggregates) {
state->jobs[jobName]["error"] = "Skipping aggregate because of a dependency cycle";
}
}

Expand Down
Loading
Loading