-
Notifications
You must be signed in to change notification settings - Fork 546
/
Copy path onnxErrorRecorder.cpp
121 lines (108 loc) · 2.5 KB
/
onnxErrorRecorder.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
/*
* SPDX-License-Identifier: Apache-2.0
*/
#include "onnxErrorRecorder.hpp"
#include <exception>
namespace onnx2trt
{
/// Factory for ONNXParserErrorRecorder.
///
/// Allocates a new recorder and takes one reference on it on the caller's behalf. That
/// reference can later be dropped with destroy() (which calls decRefCount()).
///
/// @param logger         Logger used to report a construction failure (may be null).
/// @param otherRecorder  Optional user recorder handed to the constructor.
/// @return The new recorder, or nullptr if construction threw a std::exception (the
///         exception message is logged through @p logger in that case).
ONNXParserErrorRecorder* ONNXParserErrorRecorder::create(
    nvinfer1::ILogger* logger, nvinfer1::IErrorRecorder* otherRecorder)
{
    ONNXParserErrorRecorder* created{nullptr};
    try
    {
        created = new ONNXParserErrorRecorder(logger, otherRecorder);
    }
    catch (const std::exception& e)
    {
        logError(logger, e.what());
        return nullptr;
    }

    // incRefCount() is noexcept, so it is safe to call outside the try block.
    if (created != nullptr)
    {
        created->incRefCount();
    }
    return created;
}
/// Drop the caller's reference to @p recorder and null out the caller's handle.
///
/// A null handle is accepted and left untouched.
///
/// @param recorder  In/out handle; set to nullptr once its reference has been released.
void ONNXParserErrorRecorder::destroy(ONNXParserErrorRecorder*& recorder)
{
    if (recorder == nullptr)
    {
        return;
    }
    recorder->decRefCount(); // may delete the object if this was the last reference
    recorder = nullptr;      // keep the caller from reusing a possibly dangling pointer
}
/// Emit @p str at error severity through @p logger.
///
/// The logger is optional: when it is null the message is silently dropped.
///
/// @param logger  Destination logger, or nullptr.
/// @param str     Message text passed unchanged to ILogger::log.
void ONNXParserErrorRecorder::logError(nvinfer1::ILogger* logger, const char* str)
{
    if (logger == nullptr)
    {
        return;
    }
    logger->log(ILogger::Severity::kERROR, str);
}
/// Build a recorder that keeps its own error stack and optionally forwards errors to a
/// user-supplied recorder.
///
/// @param logger         Logger used to print errors when no user recorder is attached,
///                       and to report internal failures (may be null).
/// @param otherRecorder  Optional user recorder. A reference is taken on it here and
///                       released again in the destructor.
ONNXParserErrorRecorder::ONNXParserErrorRecorder(
    nvinfer1::ILogger* logger, nvinfer1::IErrorRecorder* otherRecorder)
    : mUserRecorder(otherRecorder)
    , mLogger(logger)
{
    // Keep the user's recorder alive for as long as this wrapper references it.
    if (mUserRecorder != nullptr)
    {
        mUserRecorder->incRefCount();
    }
}
/// Release the reference the constructor took on the user recorder, if one was supplied.
ONNXParserErrorRecorder::~ONNXParserErrorRecorder() noexcept
{
    if (mUserRecorder == nullptr)
    {
        return;
    }
    mUserRecorder->decRefCount();
}
/// Discard every error recorded so far.
///
/// The stack is cleared under mStackLock, the same mutex reportError() takes before
/// appending, so no error can be added mid-clear. Any std::exception raised here
/// (e.g. std::system_error from locking the mutex) is logged rather than propagated,
/// keeping the function noexcept.
void ONNXParserErrorRecorder::clear() noexcept
{
    try
    {
        // Grab the lock so that no error is added while clearing.
        std::lock_guard<std::mutex> guard(mStackLock);
        mErrorStack.clear();
    }
    catch (const std::exception& e)
    {
        logError(mLogger, e.what());
    }
}
/// Record an error and forward it.
///
/// The (code, description) pair is appended to the local error stack. It is then handed
/// to the user recorder when one is attached; otherwise it is printed through the logger.
/// A std::exception raised along the way (e.g. allocation failure while appending) is
/// logged and swallowed so the function stays noexcept. In that case the error is not
/// forwarded.
///
/// @param val   Error code to record.
/// @param desc  Human-readable description of the error.
/// @return Always true: every error is treated as fatal.
bool ONNXParserErrorRecorder::reportError(
    nvinfer1::ErrorCode val, nvinfer1::IErrorRecorder::ErrorDesc desc) noexcept
{
    try
    {
        {
            // Only the stack mutation needs the lock; keep the critical section minimal.
            std::lock_guard<std::mutex> guard(mStackLock);
            mErrorStack.push_back(errorPair(val, desc));
        }
        // Forward outside the lock. The user recorder and the logger are external code:
        // calling them while holding mStackLock could deadlock if they call back into this
        // recorder, and it would serialize every caller behind that code.
        if (mUserRecorder)
        {
            mUserRecorder->reportError(val, desc);
        }
        else
        {
            logError(mLogger, desc);
        }
    }
    catch (const std::exception& e)
    {
        logError(mLogger, e.what());
    }
    // All errors are considered fatal.
    return true;
}
/// Take one additional reference on this recorder.
///
/// @return The reference count after incrementing.
nvinfer1::IErrorRecorder::RefCount ONNXParserErrorRecorder::incRefCount() noexcept
{
    // Pre-increment and hand back the updated value. NOTE(review): this is atomic only if
    // mRefCount is declared as an atomic in the header (not visible here) — confirm.
    auto const updated = ++mRefCount;
    return updated;
}
/// Drop one reference; the object deletes itself when the count reaches zero.
///
/// @return The reference count after decrementing. It is captured in a local before any
///         self-deletion, so returning it after `delete this` touches no member state.
nvinfer1::IErrorRecorder::RefCount ONNXParserErrorRecorder::decRefCount() noexcept
{
    auto const remaining = --mRefCount;
    if (remaining != 0)
    {
        return remaining;
    }
    // Last reference released: no member may be accessed past this point.
    delete this;
    return remaining;
}
} // namespace onnx2trt