From 8dc568e66c82d68ab9cd610895e27e522e2f2b6b Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Mon, 23 Sep 2024 19:00:07 -0700
Subject: [PATCH] When checking for the shapes, we should take care of the dynamic shapes.

PiperOrigin-RevId: 678026956
---
 .../tests/tpu_sharding_identification.mlir    | 44 +++++++++++++++++++
 .../tpu_sharding_identification_pass.cc       | 27 ++++++++++--
 2 files changed, 68 insertions(+), 3 deletions(-)

diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_sharding_identification.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_sharding_identification.mlir
index abb68b92146ba7..1aa574d12bf1ba 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_sharding_identification.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_sharding_identification.mlir
@@ -957,3 +957,47 @@ func.func @func(%arg0: tensor) -> tensor<4xf32> {
   func.return %1 : tensor<4xf32>
 }
 
+
+// -----
+// CHECK-LABEL: func @check_AddV2_variant_shape_with_input_sharding_propagation
+func.func @check_AddV2_variant_shape_with_input_sharding_propagation(%arg0: tensor, %arg1: tensor<12x384xbf16>) {
+  // CHECK: tf_device.cluster_func
+  // CHECK-SAME: input_sharding_configuration = ["sharding_info_1", "sharding_info_1"]
+  // CHECK-SAME: output_sharding_configuration = ["sharding_info_1"]
+  "tf_device.cluster_func"(%arg0, %arg1) {
+      func = @func,
+      use_spmd_for_xla_partitioning = true, num_cores_per_replica = 1 : i64
+  } : (tensor, tensor<12x384xbf16>) -> tensor
+  func.return
+}
+
+// CHECK-LABEL: func @func
+// CHECK: {{.*}}mhlo.sharding = "sharding_info_1"{{.*}}mhlo.sharding = "sharding_info_1"{{.*}}->{{.*}}mhlo.sharding = "sharding_info_1"
+func.func @func(%arg0: tensor, %arg1: tensor<12x384xbf16>) -> tensor {
+  %add = "tf.AddV2"(%arg0, %arg1) : (tensor, tensor<12x384xbf16>) -> tensor
+  %0 = "tf.XlaSharding"(%add) { _XlaSharding = "sharding_info_1"} : (tensor) -> tensor
+  func.return %0 : tensor
+}
+
+
+
+// -----
+// CHECK-LABEL: func @check_BatchMatMul_variant_shape_without_input_sharding_propagation
+func.func @check_BatchMatMul_variant_shape_without_input_sharding_propagation(%arg0: tensor, %arg1: tensor<256x384xbf16>) {
+  // CHECK: tf_device.cluster_func
+  // CHECK-SAME: input_sharding_configuration = ["", ""]
+  // CHECK-SAME: output_sharding_configuration = ["sharding_info_1"]
+  "tf_device.cluster_func"(%arg0, %arg1) {
+      func = @func,
+      use_spmd_for_xla_partitioning = true, num_cores_per_replica = 1 : i64
+  } : (tensor, tensor<256x384xbf16>) -> tensor
+  func.return
+}
+
+// CHECK-LABEL: func @func
+// CHECK: {{.*}}mhlo.sharding = ""{{.*}}mhlo.sharding = ""{{.*}}->{{.*}}mhlo.sharding = "sharding_info_1"
+func.func @func(%arg0: tensor, %arg1: tensor<256x384xbf16>) -> tensor {
+  %mul = "tf.BatchMatMul"(%arg0, %arg1) : (tensor, tensor<256x384xbf16>) -> tensor
+  %0 = "tf.XlaSharding"(%mul) { _XlaSharding = "sharding_info_1"} : (tensor) -> tensor
+  func.return %0 : tensor
+}
diff --git a/tensorflow/compiler/mlir/tf2xla/internal/passes/tpu_sharding_identification_pass.cc b/tensorflow/compiler/mlir/tf2xla/internal/passes/tpu_sharding_identification_pass.cc
index 0d2475c5be5433..2d3bb7a5a3fc3b 100644
--- a/tensorflow/compiler/mlir/tf2xla/internal/passes/tpu_sharding_identification_pass.cc
+++ b/tensorflow/compiler/mlir/tf2xla/internal/passes/tpu_sharding_identification_pass.cc
@@ -129,13 +129,33 @@ bool BinaryOpHasTraitsForSharding(Operation* op) {
   return false;
 }
 
-bool DoTypesHaveSameShape(Value value_0, Value value_1) {
+bool DoTypesHavePartialSameShape(Value value_0, Value value_1) {
   auto shape_0 = mlir::dyn_cast_or_null<RankedTensorType>(value_0.getType());
   auto shape_1 = mlir::dyn_cast_or_null<RankedTensorType>(value_1.getType());
   if (shape_0 && shape_1) {
-    return shape_0.getShape() == shape_1.getShape();
+    if (shape_0.hasStaticShape() && shape_1.hasStaticShape())
+      return shape_0.getShape() == shape_1.getShape();
+    int i = 0, j = 0;
+    while (i < shape_0.getShape().size() && j < shape_1.getShape().size()) {
+      if (shape_0.getShape()[i] != shape_1.getShape()[j] &&
+          !shape_0.isDynamicDim(i) && !shape_1.isDynamicDim(j)) {
+        return false;
+      }
+      if (shape_0.getShape()[i] == shape_1.getShape()[j]) {
+        i++;
+        j++;
+      } else {
+        if (shape_0.isDynamicDim(i)) {
+          i++;
+        }
+        if (shape_1.isDynamicDim(j)) {
+          j++;
+        }
+      }
+    }
+    return i == shape_0.getShape().size() && j == shape_1.getShape().size();
   }
   return false;
 }
@@ -337,7 +357,8 @@ std::optional<llvm::StringRef> GetXlaShardingFromArg(
       }
 
       if (BinaryOpHasTraitsForSharding(owner)) {
-        if (DoTypesHaveSameShape(value_to_visit, owner->getResult(0))) {
+        if (DoTypesHavePartialSameShape(value_to_visit,
+                                        owner->getResult(0))) {
           next_values_to_visit.push_back(use.getOwner()->getResult(0));
           continue;
         }
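
Note on the new check: the sketch below replays the dimension walk from DoTypesHavePartialSameShape over plain integer shapes so its behavior can be checked by hand. It is a minimal standalone example, not part of the patch and not TensorFlow or MLIR API: the PartialSameShape helper, the kDynamic placeholder, and the sample shapes are illustrative stand-ins for MLIR's shaped-type machinery.

// Standalone sketch (illustrative only) of the partial-shape compatibility
// walk added by this patch, using -1 as a stand-in for a dynamic '?' extent.
#include <cstdint>
#include <iostream>
#include <vector>

constexpr int64_t kDynamic = -1;  // stand-in for a dynamic ('?') dimension

// Two fully static shapes must match exactly. Otherwise the shapes are walked
// in parallel: a mismatching pair of dimensions is fatal only if both are
// static; a mismatching dynamic dimension is skipped on its own side. The
// shapes are compatible only if both cursors reach the end together.
bool PartialSameShape(const std::vector<int64_t>& a,
                      const std::vector<int64_t>& b) {
  auto is_static = [](const std::vector<int64_t>& s) {
    for (int64_t d : s)
      if (d == kDynamic) return false;
    return true;
  };
  if (is_static(a) && is_static(b)) return a == b;

  size_t i = 0, j = 0;
  while (i < a.size() && j < b.size()) {
    if (a[i] != b[j] && a[i] != kDynamic && b[j] != kDynamic) return false;
    if (a[i] == b[j]) {
      ++i;
      ++j;
    } else {
      if (a[i] == kDynamic) ++i;
      if (b[j] == kDynamic) ++j;
    }
  }
  return i == a.size() && j == b.size();
}

int main() {
  std::cout << std::boolalpha;
  // Identical static shapes are compatible, exactly as before the patch.
  std::cout << PartialSameShape({12, 384}, {12, 384}) << "\n";              // true
  // Dynamic dimensions that line up with each other are accepted.
  std::cout << PartialSameShape({kDynamic, 384}, {kDynamic, 384}) << "\n";  // true
  // An extra dynamic dimension on one side is skipped over.
  std::cout << PartialSameShape({kDynamic, 12, 384}, {12, 384}) << "\n";    // true
  // A genuine static mismatch still rejects the pair.
  std::cout << PartialSameShape({256, 384}, {12, 384}) << "\n";             // false
  return 0;
}

Under this rule an operand whose shape differs from the op result only in dynamic extents is still treated as shape-compatible, which is what the two new tests exercise: the AddV2 test expects the XlaSharding annotation to propagate to both inputs, while the BatchMatMul test expects both inputs to stay unsharded.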