Skip to content

Commit

Permalink
Interprets negative values for solver_timeout_in_seconds as disabling `solve_nd_sharding_iteratively`.
Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 692002226
  • Loading branch information
tensorflower-gardener committed Nov 1, 2024
1 parent 6ac2f0e commit 4caa2d2
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4124,6 +4124,12 @@ absl::StatusOr<bool> AutoSharding::Run(
}
}

// A negative solver timeout means we want to disable iterative ND sharding.
// The sign acts only as a flag: solve_nd_sharding_iteratively is cleared and
// the timeout is negated back to a positive value, so the magnitude of the
// supplied number is what remains in option_.solver_timeout_in_seconds from
// this point on.
if (option_.solver_timeout_in_seconds < 0) {
option_.solve_nd_sharding_iteratively = false;
option_.solver_timeout_in_seconds *= -1;
}

bool module_is_changed = false;
VLOG(1) << "Original mesh shape "
<< spmd::ToString(option_.device_mesh_shape);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1214,6 +1214,43 @@ ENTRY %entry {
op::Sharding("{devices=[2,2]<=[4]}")));
}

// Runs auto-sharding on a small dot module with a negative
// solver_timeout_in_seconds while solve_nd_sharding_iteratively is requested,
// and checks that the pass still succeeds and produces one of the accepted
// shardings on a 2x2 device mesh.
TEST_F(AutoShardingTest, NegativeTimeoutDisablesNDIterativeSolve) {
  constexpr absl::string_view kModuleText = R"(
HloModule module
ENTRY %entry {
%param0 = f32[8192,23]{1,0} parameter(0), sharding={devices=[4,1]0,1,2,3}
%param1 = f32[23,23]{1,0} parameter(1)
%dot = f32[8192,23]{1,0} dot(%param0, %param1), lhs_contracting_dims={1}, rhs_contracting_dims={1}
ROOT %copy = f32[8192,23]{1,0} copy(%dot)
})";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
                          ParseAndReturnVerifiedModule(kModuleText));

  // Pass configuration: iterative ND solving is requested explicitly, but the
  // timeout is negative.
  AutoShardingOption opts;
  opts.enable = true;
  opts.preserve_shardings =
      AutoShardingOption::PreserveShardingsType::kKeepAllShardings;
  opts.solve_nd_sharding_iteratively = true;
  opts.solver_timeout_in_seconds = -300;
  // 2x2 mesh over four devices.
  opts.device_mesh_shape = {2, 2};
  opts.device_mesh_ids = {0, 1, 2, 3};
  opts.device_mesh_alpha = {1.0, 1.0};
  opts.device_mesh_beta = {0.01, 1.0};

  TF_ASSERT_OK_AND_ASSIGN(bool module_changed,
                          AutoSharding(opts).Run(module.get()));
  VLOG(2) << module->ToString();
  EXPECT_TRUE(module_changed);

  // Every instruction of interest must still exist after the pass.
  const HloInstruction* lhs = FindInstruction(module.get(), "param0");
  ASSERT_NE(lhs, nullptr);
  const HloInstruction* rhs = FindInstruction(module.get(), "param1");
  ASSERT_NE(rhs, nullptr);
  const HloInstruction* dot_instr = FindInstruction(module.get(), "dot");
  ASSERT_NE(dot_instr, nullptr);

  // param0 is expected to carry the same sharding it was annotated with in the
  // HLO text; param1 is expected to be replicated.
  EXPECT_THAT(lhs, op::Sharding("{devices=[4,1]0,1,2,3}"));
  EXPECT_THAT(rhs, op::Sharding("{replicated}"));
  // Either of these two tilings is accepted for the dot.
  EXPECT_THAT(dot_instr, AnyOf(op::Sharding("{devices=[4,1]0,1,2,3}"),
                               op::Sharding("{devices=[2,2]<=[4]}")));
}

TEST_F(AutoShardingTest, DotInsertReshardingReshapes) {
constexpr absl::string_view kHloString = R"(
HloModule module
Expand Down

0 comments on commit 4caa2d2

Please sign in to comment.