-
Notifications
You must be signed in to change notification settings - Fork 38
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
6a09923
commit a087ea8
Showing
10 changed files
with
152 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
cmake_minimum_required(VERSION 3.13) | ||
|
||
project(output_processing_lib CXX) | ||
|
||
set(CMAKE_CXX_STANDARD 17) | ||
set(CMAKE_CXX_STANDARD_REQUIRED ON) | ||
set(CMAKE_POSITION_INDEPENDENT_CODE ON) | ||
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/bin") | ||
|
||
file(GLOB_RECURSE HEADERS "${CMAKE_CURRENT_SOURCE_DIR}/include/*") | ||
file(GLOB_RECURSE SOURCES "${CMAKE_CURRENT_SOURCE_DIR}/src/*") | ||
|
||
file(GLOB_RECURSE PY_SOURCES ${HEADERS} ${SOURCES} output_processing.cpp wrappers.hpp) | ||
|
||
add_subdirectory(thirdparty/pybind11) | ||
|
||
pybind11_add_module(${PROJECT_NAME}_py MODULE ${PY_SOURCES}) | ||
add_library(${PROJECT_NAME} ${HEADERS} ${SOURCES}) | ||
|
||
target_include_directories(${PROJECT_NAME}_py PUBLIC "${CMAKE_CURRENT_SOURCE_DIR}/include") | ||
target_include_directories(${PROJECT_NAME} PUBLIC "${CMAKE_CURRENT_SOURCE_DIR}/include") |
5 changes: 5 additions & 0 deletions
5
src/output_processing/include/output_processing/exception_handler.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
#pragma once | ||
|
||
#define DLB_ASSERT(cond, msg) \ | ||
if (!(cond)) \ | ||
return false; |
9 changes: 9 additions & 0 deletions
9
src/output_processing/include/output_processing/output_handlers.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
#pragma once | ||
|
||
#include <map> | ||
#include <string> | ||
#include <vector> | ||
|
||
|
||
bool ClassificationTask(const std::map<std::string, std::vector<std::vector<float>>>& output_tensors, | ||
const size_t number_top, const std::string& labels); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
#pragma once | ||
|
||
#include <string> | ||
#include <vector> | ||
#include <algorithm> | ||
#include <numeric> | ||
|
||
|
||
std::vector<std::string> read_labels(const std::string& label_path); | ||
|
||
template<typename T> | ||
std::vector<size_t> argsort(const std::vector<T> &array) { | ||
std::vector<size_t> indices(array.size()); | ||
std::iota(indices.begin(), indices.end(), 0); | ||
std::sort(indices.begin(), indices.end(), | ||
[&array](int left, int right) -> bool { | ||
return array[left] < array[right]; | ||
}); | ||
|
||
return indices; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
#include <pybind11/pybind11.h> | ||
|
||
#include "wrappers.hpp" | ||
|
||
PYBIND11_MODULE(output_processing, m) { | ||
m.def("ClassificationTask", &ClassificationTaskPy); | ||
}; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
#include "output_processing/output_handlers.hpp" | ||
|
||
#include "output_processing/exception_handler.hpp" | ||
#include "output_processing/utils.hpp" | ||
|
||
#include <iostream> | ||
#include <iomanip> | ||
|
||
|
||
bool ClassificationTask(const std::map<std::string, std::vector<std::vector<float>>>& output_tensors, | ||
const size_t number_top, const std::string& label_file) { | ||
DLB_ASSERT(output_tensors.size() == 1) | ||
|
||
const auto result_tensor = *output_tensors.cbegin(); | ||
const auto layer_name = result_tensor.first; | ||
const auto data = result_tensor.second; | ||
const auto labels = read_labels(label_file); | ||
const auto batch = data.size(); | ||
|
||
std::cout << "[ INFO ] Top " + std::to_string(number_top) + " results:\n"; | ||
|
||
for (size_t i = 0; i < batch; ++i) { | ||
const auto batch_result = data[i]; | ||
const auto top_idxs = argsort(batch_result); | ||
|
||
std::cout << "[ INFO ] Result for image " + std::to_string(i + 1) + ":\n"; | ||
for (size_t idx = 0; idx < number_top; ++idx) { | ||
std::cout.precision(2); | ||
const auto sorted_index = top_idxs[idx]; | ||
std::cout << std::setw(10) << batch_result[sorted_index] << std::setw(5) << labels[sorted_index] << "\n"; | ||
} | ||
} | ||
return true; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
#include "output_processing/utils.hpp" | ||
|
||
#include <fstream> | ||
#include <filesystem> | ||
|
||
|
||
std::vector<std::string> read_labels(const std::string& label_path) { | ||
std::vector<std::string> labels; | ||
|
||
std::filesystem::path path = label_path; | ||
std::ifstream in(path); | ||
if (in.is_open()) { | ||
std::string label; | ||
while (std::getline(in, label)) { | ||
labels.push_back(label); | ||
} | ||
} | ||
in.close(); | ||
if (path.extension() == ".json") { | ||
// Remove parenthless | ||
labels.pop_back(); | ||
labels.erase(labels.begin()); | ||
} | ||
return labels; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
#pragma once | ||
|
||
#include <pybind11/pybind11.h> | ||
#include <pybind11/stl.h> | ||
#include <pybind11/numpy.h> | ||
|
||
#include "output_processing/output_handlers.hpp" | ||
#include "output_processing/exception_handler.hpp" | ||
|
||
#include <string> | ||
#include <vector> | ||
#include <iostream> | ||
|
||
|
||
namespace py = pybind11; | ||
|
||
|
||
void ClassificationTaskPy(const py::dict& map, const size_t number_top, | ||
const std::string& label_file) { | ||
std::map<std::string, std::vector<std::vector<float>>> tensors; | ||
for (const auto& it : map) { | ||
const std::string& layer_name = py::cast<const std::string>(it.first); | ||
const py::array& py_tensor = py::cast<const py::array>(it.second); | ||
} | ||
} |