Skip to content

Commit

Permalink
Fix handling of unsigned index operands for Scatter.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 553476353
  • Loading branch information
akuegel authored and tensorflower-gardener committed Aug 3, 2023
1 parent d624575 commit b15e348
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 3 deletions.
6 changes: 3 additions & 3 deletions tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2701,9 +2701,9 @@ Status IrEmitterUnnested::EmitScatter(
desc.scatter_indices_gen(raw_scatter_index_index.SourceIndexOfReshape(
scatter_indices_shape_fixed, desc.scatter_indices_shape, &b_)));
// And add the index to our window index. This yields the output index.
llvm::Value* casted_scatter_index =
IntCast(loaded_scatter_index, index.GetType(),
/*isSigned=*/true);
llvm::Value* casted_scatter_index = IntCast(
loaded_scatter_index, index.GetType(),
/*isSigned=*/ShapeUtil::ElementIsSigned(desc.scatter_indices_shape));
llvm::Value* dim_offset =
Add(input_window_multidim[operand_dim], casted_scatter_index);
input_window_multidim[operand_dim] = dim_offset;
Expand Down
32 changes: 32 additions & 0 deletions tensorflow/compiler/xla/tests/scatter_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -639,6 +639,38 @@ ENTRY main {
RunTest(hlo_text, &operand, &scatter_indices, &updates);
}

XLA_TEST_F(ScatterTest, U8Index) {
const std::string hlo_text = R"(
HloModule BatchDynamicSlice
update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
lhs = s32[] parameter(0)
ROOT rhs = s32[] parameter(1)
}
ENTRY main {
operand = s32[129,3]{1,0} parameter(0)
indices = u8[6,2]{1,0} parameter(1)
updates = s32[6,1,1]{2,1,0} parameter(2)
ROOT scatter = s32[129,3]{1,0} scatter(operand, indices, updates),
to_apply=update_s32,
update_window_dims={1,2},
inserted_window_dims={},
scatter_dims_to_operand_dims={0,1},
index_vector_dim=1
}
)";
Literal operand =
LiteralUtil::CreateRandomLiteral<S32>(ShapeUtil::MakeShape(S32, {129, 3}),
/*mean=*/500, /*stddev=*/100)
.value();
Literal scatter_indices = LiteralUtil::CreateR2<uint8_t>(
{{2, 7}, {2, 1}, {1, 1}, {5, 1}, {0x80, 1}, {1, 2}});
Literal updates = LiteralUtil::CreateR3<int32_t>(
{{{10}}, {{20}}, {{30}}, {{40}}, {{50}}, {{60}}});
RunTest(hlo_text, &operand, &scatter_indices, &updates);
}

XLA_TEST_F(ScatterTest, NegativeIndex) {
const std::string hlo_text = R"(
HloModule BatchDynamicSlice
Expand Down

0 comments on commit b15e348

Please sign in to comment.