Skip to content

Commit

Permalink
Log dtype names on input dtype mismatch (#7537)
Browse files Browse the repository at this point in the history
Summary:
Update the error message when input tensor scalar type is incorrect. We've seen this get hit a few times and it should be easier to debug than it is.

New Message:
```
[method.cpp:834] Input 0 has unexpected scalar type: expected Float but was Byte.
```
Old Message:
```
[method.cpp:826] The 0-th input tensor's scalartype does not meet requirement: found 0 but expected 6
```

Test Plan:
Built executorch bento kernel locally and tested with an incorrect scalar type to view the new error message.
```
[method.cpp:834] Input 0 has unexpected scalar type: expected Float but was Byte.
```
I also locally patched and built the bento kernel with ET_ENABLE_ENUM_STRINGS=0.
```
[method.cpp:834] Input 0 has unexpected scalar type: expected 6 but was 0.
```

Reviewed By: digantdesai, SS-JIA

Differential Revision: D67887770

Pulled By: GregoryComer
  • Loading branch information
GregoryComer committed Jan 8, 2025
1 parent 39e8538 commit 4563528
Show file tree
Hide file tree
Showing 7 changed files with 90 additions and 5 deletions.
46 changes: 46 additions & 0 deletions runtime/core/exec_aten/util/scalar_type_util.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
#include <executorch/runtime/platform/log.h>

namespace executorch {
namespace runtime {

/**
* Convert a scalar type value to a string representation. If
* ET_ENABLE_ENUM_STRINGS is set (it is on bby default), this will return a
* string name (for example, "Float"). Otherwise, it will return a string
* representation of the index value ("6").
*
* If the user buffer is not large enough to hold the string representation, the
* string will be truncated.
*
* The return value is the number of characters written, or in the case of
* truncation, the number of characters that would be written if the buffer was
* large enough.
*/
size_t scalar_type_to_string(
::executorch::aten::ScalarType t,
char* buffer,
size_t buffer_size) {
#if ET_ENABLE_ENUM_STRINGS
const char* name_str;
#define DEFINE_CASE(unused, name) \
case ::executorch::aten::ScalarType::name: \
name_str = #name; \
break;

switch (t) {
ET_FORALL_SCALAR_TYPES(DEFINE_CASE)
default:
name_str = "Unknown";
break;
}

return snprintf(buffer, buffer_size, "%s", name_str);
#undef DEFINE_CASE
#else
return snprintf(buffer, buffer_size, "%d", static_cast<int>(t));
#endif // ET_ENABLE_ENUM_TO_STRING
}

} // namespace runtime
} // namespace executorch
18 changes: 18 additions & 0 deletions runtime/core/exec_aten/util/scalar_type_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -1294,6 +1294,24 @@ struct promote_types {
CTYPE_ALIAS, \
__VA_ARGS__))

/**
* Convert a scalar type value to a string representation. If
* ET_ENABLE_ENUM_STRINGS is set (it is on bby default), this will return a
* string name (for example, "Float"). Otherwise, it will return a string
* representation of the index value ("6").
*
* If the user buffer is not large enough to hold the string representation, the
* string will be truncated.
*
* The return value is the number of characters written, or in the case of
* truncation, the number of characters that would be written if the buffer was
* large enough.
*/
size_t scalar_type_to_string(
::executorch::aten::ScalarType t,
char* buffer,
size_t buffer_size);

} // namespace runtime
} // namespace executorch

Expand Down
5 changes: 4 additions & 1 deletion runtime/core/exec_aten/util/targets.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,17 @@ def define_common_targets():

runtime.cxx_library(
name = "scalar_type_util" + aten_suffix,
srcs = [],
srcs = ["scalar_type_util.cpp"],
exported_headers = [
"scalar_type_util.h",
],
visibility = [
"//executorch/...",
"@EXECUTORCH_CLIENTS",
],
deps = [
"//executorch/runtime/platform:platform",
],
exported_preprocessor_flags = exported_preprocessor_flags_,
exported_deps = exported_deps_,
exported_external_deps = ["libtorch"] if aten_mode else [],
Expand Down
1 change: 1 addition & 0 deletions runtime/core/portable_type/targets.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def define_common_targets():
"scalar_type.h",
"qint_types.h",
"bits_types.h",
"string_view.h",
],
visibility = [
"//executorch/extension/...",
Expand Down
18 changes: 14 additions & 4 deletions runtime/executor/method.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -816,14 +816,24 @@ Method::set_input(const EValue& input_evalue, size_t input_idx) {
if (e.isTensor()) {
const auto& t_dst = e.toTensor();
const auto& t_src = input_evalue.toTensor();

#if ET_LOG_ENABLED
char dst_type_name[16];
char src_type_name[16];

scalar_type_to_string(
t_dst.scalar_type(), dst_type_name, sizeof(dst_type_name));
scalar_type_to_string(
t_src.scalar_type(), src_type_name, sizeof(src_type_name));
#endif

ET_CHECK_OR_RETURN_ERROR(
t_dst.scalar_type() == t_src.scalar_type(),
InvalidArgument,
"The %zu-th input tensor's scalartype does not meet requirement: found %" PRId8
" but expected %" PRId8,
"Input %zu has unexpected scalar type: expected %s but was %s.",
input_idx,
static_cast<int8_t>(t_src.scalar_type()),
static_cast<int8_t>(t_dst.scalar_type()));
dst_type_name,
src_type_name);
// Reset the shape for the Method's input as the size of forwarded input
// tensor for shape dynamism. Also is a safety check if need memcpy.
Error err = resize_tensor(t_dst, t_src.sizes());
Expand Down
1 change: 1 addition & 0 deletions runtime/executor/targets.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ def define_common_targets():
"//executorch/runtime/core:evalue" + aten_suffix,
"//executorch/runtime/core:event_tracer" + aten_suffix,
"//executorch/runtime/core/exec_aten:lib" + aten_suffix,
"//executorch/runtime/core/exec_aten/util:scalar_type_util" + aten_suffix,
"//executorch/runtime/core/exec_aten/util:tensor_util" + aten_suffix,
"//executorch/runtime/kernel:kernel_runtime_context" + aten_suffix,
"//executorch/runtime/kernel:operator_registry",
Expand Down
6 changes: 6 additions & 0 deletions runtime/platform/log.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,12 @@
#define ET_LOG_ENABLED 1
#endif // !defined(ET_LOG_ENABLED)

// Enable ET_ENABLE_ENUM_STRINGS by default. This option gates inclusion of
// enum string names and can be disabled by explicitly setting it to 0.
// #ifndef ET_ENABLE_ENUM_STRINGS
// #define ET_ENABLE_ENUM_STRINGS 1
#// endif

namespace executorch {
namespace runtime {

Expand Down

0 comments on commit 4563528

Please sign in to comment.