Skip to content

Commit

Permalink
Update ShapeTensor methods
Browse files Browse the repository at this point in the history
  • Loading branch information
pskiran1 committed Jun 12, 2024
1 parent c3829b4 commit 91f1b29
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 54 deletions.
17 changes: 1 addition & 16 deletions src/instance_state.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1357,24 +1357,9 @@ ModelInstanceState::GetRequestShapeValues(
element_cnt /= shape[0];
}

const size_t datatype_size = TRITONSERVER_DataTypeByteSize(datatype);
const size_t expected_byte_size = element_cnt * datatype_size;

if ((expected_byte_size != data_byte_size) &&
(expected_byte_size != (data_byte_size - datatype_size))) {
return TRITONSERVER_ErrorNew(
TRITONSERVER_ERROR_INVALID_ARG,
(std::string("shape tensor for input '") + input_name +
"' expected byte size is " + std::to_string(expected_byte_size) +
" [ or " + std::to_string(expected_byte_size + datatype_size) +
" if input includes batch shape value] " + ", got " +
std::to_string(data_byte_size))
.c_str());
}

auto it = request_shape_values->emplace(io_index, ShapeTensor()).first;
RETURN_IF_ERROR(it->second.SetDataFromBuffer(
data_buffer, datatype, element_cnt, support_batching_,
data_buffer, data_byte_size, datatype, element_cnt, input_name, support_batching_,
total_batch_size));
}
}
Expand Down
73 changes: 48 additions & 25 deletions src/shape_tensor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,75 +30,79 @@ namespace triton { namespace backend { namespace tensorrt {

TRITONSERVER_Error*
ShapeTensor::SetDataFromBuffer(
    const char* data_buffer, size_t data_byte_size,
    TRITONSERVER_DataType datatype, size_t nb_shape_values,
    const char* input_name, bool support_batching, size_t total_batch_size)
{
  // Populate this shape tensor from a raw request buffer.
  //
  // data_buffer      : raw bytes holding the shape values from the request.
  // data_byte_size   : number of bytes available in 'data_buffer'.
  // datatype         : Triton datatype of the values (INT32 or INT64 only).
  // nb_shape_values  : number of shape values, excluding the batch entry.
  // input_name       : input name, used only for error messages.
  // support_batching : when true, a leading batch-size element is prepended.
  // total_batch_size : value written into the leading batch element.
  //
  // Returns nullptr on success, or a TRITONSERVER_ERROR_INVALID_ARG error on
  // null buffer, unsupported datatype, or byte-size mismatch.
  if (data_buffer == nullptr) {
    return TRITONSERVER_ErrorNew(
        TRITONSERVER_ERROR_INVALID_ARG,
        "Null data pointer received for Shape tensor");
  }

  // Only 32-bit and 64-bit integer shape tensors are supported.
  if (datatype == TRITONSERVER_DataType::TRITONSERVER_TYPE_INT32) {
    datatype_ = ShapeTensorDataType::INT32;
  } else if (datatype == TRITONSERVER_DataType::TRITONSERVER_TYPE_INT64) {
    datatype_ = ShapeTensorDataType::INT64;
  } else {
    return TRITONSERVER_ErrorNew(
        TRITONSERVER_ERROR_INVALID_ARG,
        "Unsupported data type received for Shape tensor");
  }

  nb_shape_values_ = nb_shape_values;
  if (support_batching) {
    nb_shape_values_++;  // Account for batch size
  }
  const size_t datatype_size = TRITONSERVER_DataTypeByteSize(datatype);
  size_ = nb_shape_values_ * datatype_size;

  // Reject buffers whose size matches neither the batched nor the
  // unbatched layout before allocating or copying anything.
  TRITONSERVER_Error* err =
      ValidateDataByteSize(data_byte_size, input_name, datatype_size);
  if (err != nullptr) {
    return err;
  }

  data_ = std::make_unique<char[]>(size_);
  if (support_batching) {
    // First element carries the batch size; the request buffer supplies
    // the remaining shape values.
    if (datatype_ == ShapeTensorDataType::INT32) {
      *reinterpret_cast<int32_t*>(data_.get()) =
          static_cast<int32_t>(total_batch_size);
    } else if (datatype_ == ShapeTensorDataType::INT64) {
      *reinterpret_cast<int64_t*>(data_.get()) =
          static_cast<int64_t>(total_batch_size);
    }
    std::memcpy(
        data_.get() + datatype_size, data_buffer, size_ - datatype_size);
  } else {
    std::memcpy(data_.get(), data_buffer, size_);
  }

  return nullptr;
}

TRITONSERVER_Error*
ShapeTensor::SetDataFromShapeValues(
const int32_t* shape_values, const TRITONSERVER_DataType datatype,
const size_t element_cnt)
const int32_t* shape_values, TRITONSERVER_DataType datatype,
size_t nb_shape_values)
{
element_cnt_ = element_cnt;
nb_shape_values_ = nb_shape_values;
size_t datatype_size;

if (datatype == TRITONSERVER_DataType::TRITONSERVER_TYPE_INT32) {
datatype_size = sizeof(int32_t);
datatype_ = ShapeTensorDataType::INT32;
size_ = element_cnt_ * datatype_size;
size_ = nb_shape_values_ * datatype_size;
data_.reset(new char[size_]);
int32_t* data_ptr = reinterpret_cast<int32_t*>(data_.get());
std::memcpy(data_ptr, shape_values, size_);
} else if (datatype == TRITONSERVER_DataType::TRITONSERVER_TYPE_INT64) {
datatype_size = sizeof(int64_t);
datatype_ = ShapeTensorDataType::INT64;
size_ = element_cnt_ * datatype_size;
size_ = nb_shape_values_ * datatype_size;
data_.reset(new char[size_]);
int64_t* data_ptr = reinterpret_cast<int64_t*>(data_.get());
for (size_t i = 0; i < element_cnt_; ++i) {
for (size_t i = 0; i < nb_shape_values_; ++i) {
data_ptr[i] = static_cast<int64_t>(shape_values[i]);
}
} else {
Expand All @@ -112,23 +116,23 @@ ShapeTensor::SetDataFromShapeValues(

int64_t
ShapeTensor::GetDistance(
const ShapeTensor& other, const int64_t total_batch_size) const
const ShapeTensor& other, int64_t total_batch_size) const
{
int64_t distance = 0;
if (datatype_ == ShapeTensorDataType::INT32) {
const auto* shape_values = reinterpret_cast<const int32_t*>(data_.get());
const auto* opt_shape_values =
reinterpret_cast<const int32_t*>(other.GetData());
distance += std::abs(*opt_shape_values - total_batch_size);
for (size_t idx = 1; idx < other.GetElementCount(); idx++) {
for (size_t idx = 1; idx < other.GetNbShapeValues(); idx++) {
distance += std::abs(*(opt_shape_values + idx) - shape_values[idx - 1]);
}
} else {
const auto* shape_values = reinterpret_cast<const int64_t*>(data_.get());
const auto* opt_shape_values =
reinterpret_cast<const int64_t*>(other.GetData());
distance += std::abs(*opt_shape_values - total_batch_size);
for (size_t idx = 1; idx < other.GetElementCount(); idx++) {
for (size_t idx = 1; idx < other.GetNbShapeValues(); idx++) {
distance += std::abs(*(opt_shape_values + idx) - shape_values[idx - 1]);
}
}
Expand All @@ -149,4 +153,23 @@ ShapeTensor::GetDataTypeString() const
return nullptr;
}

TRITONSERVER_Error*
ShapeTensor::ValidateDataByteSize(
    size_t data_byte_size, const char* input_name,
    size_t datatype_size) const
{
  // Validate the byte size of an incoming shape-tensor buffer against the
  // size this tensor expects. Two layouts are accepted:
  //   size_ - datatype_size : shape values only (no batch entry), or
  //   size_                 : shape values plus the leading batch entry.
  //
  // data_byte_size : actual byte size received in the request.
  // input_name     : input name, used only for the error message.
  // datatype_size  : byte size of one element of this tensor's datatype.
  //
  // Returns nullptr when the size is acceptable, otherwise an
  // TRITONSERVER_ERROR_INVALID_ARG error naming the input.
  //
  // NOTE: the previous version printed the received size in both the
  // "expected" and "got" slots of the message (and named the parameter
  // 'expected_byte_size'), making the diagnostic self-contradictory; the
  // message now reports the true expected sizes.
  if ((data_byte_size != (size_ - datatype_size)) &&
      (data_byte_size != size_)) {
    return TRITONSERVER_ErrorNew(
        TRITONSERVER_ERROR_INVALID_ARG,
        (std::string("shape tensor for input '") + input_name +
         "' expected byte size is " + std::to_string(size_ - datatype_size) +
         " [ or " + std::to_string(size_) +
         " if input includes batch shape value] " + ", got " +
         std::to_string(data_byte_size))
            .c_str());
  }
  return nullptr;
}

}}} // namespace triton::backend::tensorrt
25 changes: 14 additions & 11 deletions src/shape_tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,34 +41,37 @@ enum class ShapeTensorDataType { INT32, INT64 };
// Owns the value of a TensorRT shape tensor: the shape values themselves
// (optionally prefixed with a batch-size element), their datatype, and the
// total byte size of the stored data.
class ShapeTensor {
 public:
  ShapeTensor()
      : size_(0), nb_shape_values_(0), datatype_(ShapeTensorDataType::INT32)
  {
  }

  // Copies shape values out of a raw request buffer, validating datatype and
  // byte size. When 'support_batching' is true, a leading element holding
  // 'total_batch_size' is prepended. 'input_name' is used for error messages.
  TRITONSERVER_Error* SetDataFromBuffer(
      const char* data_buffer, size_t data_byte_size,
      TRITONSERVER_DataType datatype, size_t nb_shape_values,
      const char* input_name, bool support_batching, size_t total_batch_size);

  // Copies 'nb_shape_values' INT32 shape values, converting to INT64 storage
  // when 'datatype' is INT64.
  TRITONSERVER_Error* SetDataFromShapeValues(
      const int32_t* shape_values, TRITONSERVER_DataType datatype,
      size_t nb_shape_values);

  // Returns the accumulated element-wise absolute difference between this
  // tensor's values and 'other', using 'total_batch_size' for the leading
  // batch element.
  int64_t GetDistance(const ShapeTensor& other, int64_t total_batch_size) const;

  const char* GetDataTypeString() const;

  // Total byte size of the stored data.
  size_t GetSize() const { return size_; }
  // Number of stored shape values (includes the batch element when present).
  size_t GetNbShapeValues() const { return nb_shape_values_; }
  ShapeTensorDataType GetDataType() const { return datatype_; }
  const void* GetData() const { return static_cast<const void*>(data_.get()); }

 private:
  size_t size_;
  size_t nb_shape_values_;
  ShapeTensorDataType datatype_;
  std::unique_ptr<char[]> data_;

  // Checks 'data_byte_size' against the two acceptable layouts
  // (with or without the leading batch element); see shape_tensor.cc.
  TRITONSERVER_Error* ValidateDataByteSize(
      size_t data_byte_size, const char* input_name,
      size_t datatype_size) const;
};

}}} // namespace triton::backend::tensorrt
4 changes: 2 additions & 2 deletions src/tensorrt_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -371,13 +371,13 @@ ValidateShapeValues(
const ShapeTensor& min_shape_values, const ShapeTensor& max_shape_values,
size_t nb_shape_values)
{
if (request_shape_values.GetElementCount() != nb_shape_values) {
if (request_shape_values.GetNbShapeValues() != nb_shape_values) {
return TRITONSERVER_ErrorNew(
TRITONSERVER_ERROR_INVALID_ARG,
(std::string(
"mismatch between the number of shape values. Expecting ") +
std::to_string(nb_shape_values) + ". Got " +
std::to_string(request_shape_values.GetElementCount()))
std::to_string(request_shape_values.GetNbShapeValues()))
.c_str());
}

Expand Down

0 comments on commit 91f1b29

Please sign in to comment.