Skip to content

Commit

Permalink
[IFRT] Add sharding propagation option to IFRT IR compile options.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 692002394
  • Loading branch information
ICGog authored and tensorflower-gardener committed Nov 1, 2024
1 parent 4caa2d2 commit 26b62a0
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,5 @@ import "xla/pjrt/compile_options.proto";
message IfrtIrCompileOptionsProto {
repeated int32 device_ids = 1;
map<string, xla.CompileOptionsProto> compile_option_overrides = 2;
bool propagate_shardings = 3;
}
4 changes: 3 additions & 1 deletion third_party/xla/xla/python/ifrt/ir/ifrt_ir_program.cc
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ IfrtIRCompileOptions::FromProto(const IfrtIrCompileOptionsProto& proto) {
return std::make_unique<IfrtIRCompileOptions>(
std::move(device_ids),
absl::flat_hash_map<std::string, std::shared_ptr<LoadedExecutable>>(),
std::move(compile_options_overrides));
std::move(compile_options_overrides), proto.propagate_shardings());
}

absl::StatusOr<IfrtIrCompileOptionsProto> IfrtIRCompileOptions::ToProto()
Expand All @@ -94,6 +94,8 @@ absl::StatusOr<IfrtIrCompileOptionsProto> IfrtIRCompileOptions::ToProto()
proto.mutable_compile_option_overrides()->insert(
{id, compile_options_proto});
}

proto.set_propagate_shardings(propagate_shardings);
return proto;
}

Expand Down
10 changes: 8 additions & 2 deletions third_party/xla/xla/python/ifrt/ir/ifrt_ir_program.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,12 @@ struct IfrtIRCompileOptions
loaded_exec_binding = {},
std::shared_ptr<absl::flat_hash_map<
std::string, std::unique_ptr<xla::ifrt::CompileOptions>>>
compile_options_overrides = {})
compile_options_overrides = {},
bool propagate_shardings = false)
: device_assignments(std::move(device_assignments)),
loaded_exec_binding(std::move(loaded_exec_binding)),
compile_options_overrides(std::move(compile_options_overrides)) {}
compile_options_overrides(std::move(compile_options_overrides)),
propagate_shardings(propagate_shardings) {}

// Mapping from logical device ids in IFRT IR MLIR module to runtime device
// ids obtained from IFRT client.
Expand All @@ -87,6 +89,10 @@ struct IfrtIRCompileOptions
std::string, std::unique_ptr<xla::ifrt::CompileOptions>>>
compile_options_overrides;

// Whether to propagate shardings from atom program executables for
// unspecified shardings.
bool propagate_shardings;

// Constructs `IfrtIRCompileOptions` from `IfrtIrCompileOptionsProto`.
static absl::StatusOr<std::unique_ptr<IfrtIRCompileOptions>> FromProto(
const IfrtIrCompileOptionsProto& proto);
Expand Down

0 comments on commit 26b62a0

Please sign in to comment.