diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc index 21f49fc8498e63..a0510f6dff8246 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc @@ -4124,6 +4124,12 @@ absl::StatusOr AutoSharding::Run( } } + // A negative solver timeout means we want to disable iterative ND sharding. + if (option_.solver_timeout_in_seconds < 0) { + option_.solve_nd_sharding_iteratively = false; + option_.solver_timeout_in_seconds *= -1; + } + bool module_is_changed = false; VLOG(1) << "Original mesh shape " << spmd::ToString(option_.device_mesh_shape); diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_test.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_test.cc index ad18259d411fa5..d0a121ea6276bf 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_test.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_test.cc @@ -1214,6 +1214,43 @@ ENTRY %entry { op::Sharding("{devices=[2,2]<=[4]}"))); } +TEST_F(AutoShardingTest, NegativeTimeoutDisablesNDIterativeSolve) { + constexpr absl::string_view kHloString = R"( +HloModule module +ENTRY %entry { + %param0 = f32[8192,23]{1,0} parameter(0), sharding={devices=[4,1]0,1,2,3} + %param1 = f32[23,23]{1,0} parameter(1) + %dot = f32[8192,23]{1,0} dot(%param0, %param1), lhs_contracting_dims={1}, rhs_contracting_dims={1} + ROOT %copy = f32[8192,23]{1,0} copy(%dot) +})"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloString)); + AutoShardingOption option; + option.enable = true; + option.device_mesh_shape = {2, 2}; + option.device_mesh_ids = {0, 1, 2, 3}; + option.device_mesh_alpha = {1.0, 1.0}; + option.device_mesh_beta = {0.01, 1.0}; + option.solve_nd_sharding_iteratively = true; + option.solver_timeout_in_seconds = -300; + option.preserve_shardings = + AutoShardingOption::PreserveShardingsType::kKeepAllShardings; + TF_ASSERT_OK_AND_ASSIGN(bool changed, AutoSharding(option).Run(module.get())); + VLOG(2) << module->ToString(); + EXPECT_TRUE(changed); + const HloInstruction* param0 = FindInstruction(module.get(), "param0"); + const HloInstruction* param1 = FindInstruction(module.get(), "param1"); + const HloInstruction* dot = FindInstruction(module.get(), "dot"); + ASSERT_NE(param0, nullptr); + ASSERT_NE(param1, nullptr); + ASSERT_NE(dot, nullptr); + EXPECT_THAT(param0, op::Sharding("{devices=[4,1]0,1,2,3}")); + EXPECT_THAT(param1, op::Sharding("{replicated}")); + EXPECT_THAT(dot, AnyOf(op::Sharding("{devices=[4,1]0,1,2,3}"), + op::Sharding("{devices=[2,2]<=[4]}"))); +} + TEST_F(AutoShardingTest, DotInsertReshardingReshapes) { constexpr absl::string_view kHloString = R"( HloModule module