Skip to content

Commit

Permalink
Merge pull request #20 from ibois-epfl/tool_classification_label_check
Browse files Browse the repository at this point in the history
Defensive checks for ml labels in config
  • Loading branch information
9and3 authored Jan 22, 2024
2 parents dcc8f3a + 3fb4404 commit 3ae0673
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 1 deletion.
80 changes: 79 additions & 1 deletion include/config.hh
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@
#include <variant>
#include <functional>
#include <unordered_map>
#include <filesystem>
#include <algorithm>
#include <unordered_set>
#include <fstream>

namespace ttool
{
Expand Down Expand Up @@ -203,6 +207,77 @@ namespace ttool
LoadConfigFile();
}

/**
* @brief Check if the acit names match the folder names
*
*/
void CheckAcitFiles(const std::string& TToolRootPath)
{

std::filesystem::path rootPath = std::filesystem::current_path() / TToolRootPath;

std::string line;
std::string toolheadNameTagStart = "<toolhead name=\"";
std::string toolheadNameTagEnd = "\"";

for (const auto& acitFile : m_ConfigData.AcitFiles) {
std::string acitFileR = acitFile.substr(1);
std::filesystem::path acitFilePath = rootPath / acitFileR;
std::string toolheadName = "";

std::ifstream fs(acitFilePath);
if (!fs.is_open()) {
throw std::runtime_error("Could not open file: " + acitFilePath.string());
}

while (std::getline(fs, line)) {
size_t start = line.find(toolheadNameTagStart);
if (start != std::string::npos) {
start += toolheadNameTagStart.length();
size_t end = line.find(toolheadNameTagEnd, start);
if (end != std::string::npos) {
toolheadName = line.substr(start, end - start);
break;
}
}
}
fs.close();

if (acitFile.find(toolheadName) == std::string::npos) {
throw std::runtime_error("Toolhead name mismatch error: Toolhead name \"" + toolheadName +
"\" does not match the folder name \"" + acitFile + "\"");
}
}
}

/**
* @brief Check if the labels in the config file match the file paths of model files and acit files
*
*/
void CheckClassifierLabelsConfig()
{
std::unordered_set<std::string> filePaths;

for (const auto& modelFile : m_ConfigData.ModelFiles) {
filePaths.insert( modelFile);
}
for (const auto& acitFile : m_ConfigData.AcitFiles) {
filePaths.insert(acitFile);
}

for (const auto& label : m_ConfigData.ClassifierLabels) {
bool labelMatches = std::any_of(filePaths.begin(), filePaths.end(),
[&label](const std::string& filePath) {
return filePath.find(label) != std::string::npos;
});

if (!labelMatches) {
throw std::runtime_error("Label mismatch error: Label \"" + label + "\" does not match any file paths");
}
}

}

/**
* @brief Read the config file and set the values to the ConfigData object
*
Expand Down Expand Up @@ -248,7 +323,7 @@ namespace ttool
return fs.release();
}

/**
/**
* @brief Print the config file to the console
*
*/
Expand Down Expand Up @@ -341,17 +416,20 @@ namespace ttool
*/
ConfigData GetConfigData()
{
std::vector<std::string> fileNames;
// Create a copy of the ConfigData object
ConfigData configData = this->m_ConfigData;
// Prefix the model files with the m_TToolRootPath
for (auto& modelFile : configData.ModelFiles)
{
modelFile = std::string(m_TToolRootPath) + "/" + modelFile;
fileNames.push_back(modelFile);
}
// Prefix the acit files with the m_TToolRootPath
for (auto& acitFile : configData.AcitFiles)
{
acitFile = std::string(m_TToolRootPath) + "/" + acitFile;
fileNames.push_back(acitFile);
}
// Prefix the classifier model path with the m_TToolRootPath
configData.ClassifierModelPath = std::string(m_TToolRootPath) + "/" + configData.ClassifierModelPath;
Expand Down
2 changes: 2 additions & 0 deletions include/ttool.hh
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,8 @@ namespace ttool
m_ConfigFile = configFile;
m_ConfigPtr = std::make_shared<ttool::Config>(configFile);
m_ConfigPtr->SetTToolRootPath(ttoolRootPath);
m_ConfigPtr->CheckAcitFiles(ttoolRootPath);
m_ConfigPtr->CheckClassifierLabelsConfig();
}

/**
Expand Down

0 comments on commit 3ae0673

Please sign in to comment.