From b15e3481faa1c5e19608368d67625d491f431a42 Mon Sep 17 00:00:00 2001 From: Adrian Kuegel Date: Thu, 3 Aug 2023 07:54:59 -0700 Subject: [PATCH] Fix handling of unsigned index operands for Scatter. PiperOrigin-RevId: 553476353 --- .../xla/service/gpu/ir_emitter_unnested.cc | 6 ++-- tensorflow/compiler/xla/tests/scatter_test.cc | 32 +++++++++++++++++++ 2 files changed, 35 insertions(+), 3 deletions(-) diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index 8fbf0c5b87acca..aaa82dada374ac 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -2701,9 +2701,9 @@ Status IrEmitterUnnested::EmitScatter( desc.scatter_indices_gen(raw_scatter_index_index.SourceIndexOfReshape( scatter_indices_shape_fixed, desc.scatter_indices_shape, &b_))); // And add the index to our window index. This yields the output index. - llvm::Value* casted_scatter_index = - IntCast(loaded_scatter_index, index.GetType(), - /*isSigned=*/true); + llvm::Value* casted_scatter_index = IntCast( + loaded_scatter_index, index.GetType(), + /*isSigned=*/ShapeUtil::ElementIsSigned(desc.scatter_indices_shape)); llvm::Value* dim_offset = Add(input_window_multidim[operand_dim], casted_scatter_index); input_window_multidim[operand_dim] = dim_offset; diff --git a/tensorflow/compiler/xla/tests/scatter_test.cc b/tensorflow/compiler/xla/tests/scatter_test.cc index 4af119cbb1a9bc..49f9018fc754d3 100644 --- a/tensorflow/compiler/xla/tests/scatter_test.cc +++ b/tensorflow/compiler/xla/tests/scatter_test.cc @@ -639,6 +639,38 @@ ENTRY main { RunTest(hlo_text, &operand, &scatter_indices, &updates); } +XLA_TEST_F(ScatterTest, U8Index) { + const std::string hlo_text = R"( +HloModule BatchDynamicSlice + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + operand = s32[129,3]{1,0} parameter(0) + indices = u8[6,2]{1,0} parameter(1) + updates = s32[6,1,1]{2,1,0} parameter(2) + ROOT scatter = s32[129,3]{1,0} scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={1,2}, + inserted_window_dims={}, + scatter_dims_to_operand_dims={0,1}, + index_vector_dim=1 +} +)"; + Literal operand = + LiteralUtil::CreateRandomLiteral(ShapeUtil::MakeShape(S32, {129, 3}), + /*mean=*/500, /*stddev=*/100) + .value(); + Literal scatter_indices = LiteralUtil::CreateR2( + {{2, 7}, {2, 1}, {1, 1}, {5, 1}, {0x80, 1}, {1, 2}}); + Literal updates = LiteralUtil::CreateR3( + {{{10}}, {{20}}, {{30}}, {{40}}, {{50}}, {{60}}}); + RunTest(hlo_text, &operand, &scatter_indices, &updates); +} + XLA_TEST_F(ScatterTest, NegativeIndex) { const std::string hlo_text = R"( HloModule BatchDynamicSlice