When checking whether shapes match, we should take dynamic shapes into account.
PiperOrigin-RevId: 678026956
tensorflower-gardener committed Sep 24, 2024
1 parent e1b495d commit 8dc568e
Showing 2 changed files with 68 additions and 3 deletions.
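
For context, here is a minimal standalone sketch (not part of the commit) of the relaxed comparison: static dimensions must still match, while a dynamic dimension on either side may be skipped. The helper name, the plain-vector representation, and the -1 sentinel are illustrative only; the actual change operates on mlir::RankedTensorType and also keeps a fast path for fully static shapes, as shown in the diff below.

#include <cassert>
#include <cstdint>
#include <vector>

// Illustrative sentinel for a dynamic dimension (the real code queries
// RankedTensorType::isDynamicDim instead of comparing against a constant).
constexpr int64_t kDyn = -1;

// Simplified two-pointer comparison: static dims must agree; a dynamic dim on
// either side may be skipped when the dims do not line up.
bool PartialSameShape(const std::vector<int64_t>& a,
                      const std::vector<int64_t>& b) {
  size_t i = 0, j = 0;
  while (i < a.size() && j < b.size()) {
    if (a[i] != b[j] && a[i] != kDyn && b[j] != kDyn) return false;
    if (a[i] == b[j]) {
      ++i;
      ++j;
    } else {
      if (a[i] == kDyn) ++i;
      if (b[j] == kDyn) ++j;
    }
  }
  return i == a.size() && j == b.size();
}

int main() {
  // Shapes from the new AddV2 test: tensor<12x384> vs tensor<?x12x384> are now
  // treated as matching, so input sharding can propagate.
  assert(PartialSameShape({12, 384}, {kDyn, 12, 384}));
  assert(PartialSameShape({kDyn, 12, 384}, {kDyn, 12, 384}));
  // Shapes from the new BatchMatMul test: 256 and 384 are both static and
  // differ, so there is no match and no input sharding propagation.
  assert(!PartialSameShape({kDyn, 12, 256}, {kDyn, 12, 384}));
  return 0;
}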
@@ -957,3 +957,47 @@ func.func @func(%arg0: tensor<f32>) -> tensor<4xf32> {
func.return %1 : tensor<4xf32>
}


// -----
// CHECK-LABEL: func @check_AddV2_variant_shape_with_input_sharding_propagation
func.func @check_AddV2_variant_shape_with_input_sharding_propagation(%arg0: tensor<?x12x384xbf16>, %arg1: tensor<12x384xbf16>) {
// CHECK: tf_device.cluster_func
// CHECK-SAME: input_sharding_configuration = ["sharding_info_1", "sharding_info_1"]
// CHECK-SAME: output_sharding_configuration = ["sharding_info_1"]
"tf_device.cluster_func"(%arg0, %arg1) {
func = @func,
use_spmd_for_xla_partitioning = true, num_cores_per_replica = 1 : i64
} : (tensor<?x12x384xbf16>, tensor<12x384xbf16>) -> tensor<?x12x384xbf16>
func.return
}

// CHECK-LABEL: func @func
// CHECK: {{.*}}mhlo.sharding = "sharding_info_1"{{.*}}mhlo.sharding = "sharding_info_1"{{.*}}->{{.*}}mhlo.sharding = "sharding_info_1"
func.func @func(%arg0: tensor<?x12x384xbf16>, %arg1: tensor<12x384xbf16>) -> tensor<?x12x384xbf16> {
%add = "tf.AddV2"(%arg0, %arg1) : (tensor<?x12x384xbf16>, tensor<12x384xbf16>) -> tensor<?x12x384xbf16>
%0 = "tf.XlaSharding"(%add) { _XlaSharding = "sharding_info_1"} : (tensor<?x12x384xbf16>) -> tensor<?x12x384xbf16>
func.return %0 : tensor<?x12x384xbf16>
}



// -----
// CHECK-LABEL: func @check_BatchMatMul_variant_shape_without_input_sharding_propagation
func.func @check_BatchMatMul_variant_shape_without_input_sharding_propagation(%arg0: tensor<?x12x256xbf16>, %arg1: tensor<256x384xbf16>) {
// CHECK: tf_device.cluster_func
// CHECK-SAME: input_sharding_configuration = ["", ""]
// CHECK-SAME: output_sharding_configuration = ["sharding_info_1"]
"tf_device.cluster_func"(%arg0, %arg1) {
func = @func,
use_spmd_for_xla_partitioning = true, num_cores_per_replica = 1 : i64
} : (tensor<?x12x256xbf16>, tensor<256x384xbf16>) -> tensor<?x12x384xbf16>
func.return
}

// CHECK-LABEL: func @func
// CHECK: {{.*}}mhlo.sharding = ""{{.*}}mhlo.sharding = ""{{.*}}->{{.*}}mhlo.sharding = "sharding_info_1"
func.func @func(%arg0: tensor<?x12x256xbf16>, %arg1: tensor<256x384xbf16>) -> tensor<?x12x384xbf16> {
%mul = "tf.BatchMatMul"(%arg0, %arg1) : (tensor<?x12x256xbf16>, tensor<256x384xbf16>) -> tensor<?x12x384xbf16>
%0 = "tf.XlaSharding"(%mul) { _XlaSharding = "sharding_info_1"} : (tensor<?x12x384xbf16>) -> tensor<?x12x384xbf16>
func.return %0 : tensor<?x12x384xbf16>
}
@@ -129,13 +129,33 @@ bool BinaryOpHasTraitsForSharding(Operation* op) {
  return false;
}

-bool DoTypesHaveSameShape(Value value_0, Value value_1) {
+bool DoTypesHavePartialSameShape(Value value_0, Value value_1) {
  auto shape_0 =
      mlir::dyn_cast_or_null<mlir::RankedTensorType>(value_0.getType());
  auto shape_1 =
      mlir::dyn_cast_or_null<mlir::RankedTensorType>(value_1.getType());
  if (shape_0 && shape_1) {
-    return shape_0.getShape() == shape_1.getShape();
+    if (shape_0.hasStaticShape() && shape_1.hasStaticShape())
+      return shape_0.getShape() == shape_1.getShape();
+    int i = 0, j = 0;
+    while (i < shape_0.getShape().size() && j < shape_1.getShape().size()) {
+      if (shape_0.getShape()[i] != shape_1.getShape()[j] &&
+          !shape_0.isDynamicDim(i) && !shape_1.isDynamicDim(j)) {
+        return false;
+      }
+      if (shape_0.getShape()[i] == shape_1.getShape()[j]) {
+        i++;
+        j++;
+      } else {
+        if (shape_0.isDynamicDim(i)) {
+          i++;
+        }
+        if (shape_1.isDynamicDim(j)) {
+          j++;
+        }
+      }
+    }
+    return i == shape_0.getShape().size() && j == shape_1.getShape().size();
  }
  return false;
}
@@ -337,7 +357,8 @@ std::optional<llvm::StringRef> GetXlaShardingFromArg(
        }

        if (BinaryOpHasTraitsForSharding(owner)) {
-          if (DoTypesHaveSameShape(value_to_visit, owner->getResult(0))) {
+          if (DoTypesHavePartialSameShape(value_to_visit,
+                                          owner->getResult(0))) {
            next_values_to_visit.push_back(use.getOwner()->getResult(0));
            continue;
          }
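
In effect, GetXlaShardingFromArg now keeps walking through a binary op whenever the visited operand's shape matches the op's result shape up to dynamic dimensions. That is why the XlaSharding annotation behind the AddV2 in the first new test propagates to both cluster inputs, while the BatchMatMul test still gets empty input shardings: 256 and 384 are both static and differ.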