PR tensorflow#17170: Code dedup in execution_trace_utils LiteralToValue
Imported from GitHub PR openxla/xla#17170

Code dedup in execution_trace_utils LiteralToValue

### Issue description:
To avoid having to modify this switch every time a new FP8 type is added, remove all of the per-type FP8 cases and instead check `IsF8Type(literal.shape().element_type())` once before the switch statement.

Copybara import of the project:

--
2b50deff921dd98b530df1994b3317073ac528f7 by Alexander Pivovarov <pivovaa@amazon.com>:

Code dedup in execution_trace_utils LiteralToValue

Merging this change closes tensorflow#17170

PiperOrigin-RevId: 675361983
apivovarov authored and tensorflower-gardener committed Sep 17, 2024
1 parent 3ecd4ca commit 6ed0d07
Showing 2 changed files with 9 additions and 11 deletions.
1 change: 1 addition & 0 deletions third_party/xla/xla/mlir/tools/mlir_replay/public/BUILD
@@ -34,6 +34,7 @@ cc_library(
         ":execution_trace_proto_cc",
         ":execution_trace_proto_cc_impl",
         "//xla:literal",
+        "//xla:shape_util",
         "//xla:xla_data_proto_cc",
         "//xla/mlir/tools/mlir_interpreter/framework",
         "@com_google_absl//absl/status",
@@ -42,6 +42,7 @@ limitations under the License.
 #include "xla/mlir/tools/mlir_interpreter/framework/interpreter_value.h"
 #include "xla/mlir/tools/mlir_interpreter/framework/tensor_or_memref.h"
 #include "xla/mlir/tools/mlir_replay/public/execution_trace.pb.h"
+#include "xla/primitive_util.h"
 #include "tsl/platform/statusor.h"

 namespace mlir {
@@ -251,7 +252,13 @@ absl::StatusOr<InterpreterValue> LiteralToValue(const xla::Literal& literal) {
   }

   if (literal.shape().IsArray()) {
-    switch (literal.shape().element_type()) {
+    auto type = literal.shape().element_type();
+    if (xla::primitive_util::IsF8Type(type)) {
+      return absl::UnimplementedError(
+          absl::StrCat(xla::primitive_util::LowercasePrimitiveTypeName(type),
+                       " not implemented"));
+    }
+    switch (type) {
       case xla::PRED:
         return {{ArrayLiteralToTensor<bool>(literal)}};
       case xla::S8:
@@ -278,16 +285,6 @@ absl::StatusOr<InterpreterValue> LiteralToValue(const xla::Literal& literal) {
         return absl::UnimplementedError("BF16 not implemented");
       case xla::F64:
         return {{ArrayLiteralToTensor<double>(literal)}};
-      case xla::F8E5M2:
-        return absl::UnimplementedError("F8E5M2 not implemented");
-      case xla::F8E4M3FN:
-        return absl::UnimplementedError("F8E4M3FN not implemented");
-      case xla::F8E4M3B11FNUZ:
-        return absl::UnimplementedError("F8E4M3B11FNUZ not implemented");
-      case xla::F8E5M2FNUZ:
-        return absl::UnimplementedError("F8E5M2FNUZ not implemented");
-      case xla::F8E4M3FNUZ:
-        return absl::UnimplementedError("F8E4M3FNUZ not implemented");
       case xla::C64:
         return {{ArrayLiteralToTensor<std::complex<float>>(literal)}};
       case xla::C128:
