diff --git a/third_party/xla/xla/python/ifrt/ir/ifrt_ir_compile_options.proto b/third_party/xla/xla/python/ifrt/ir/ifrt_ir_compile_options.proto index e005e1f4d083ac..ce13c3716f49c4 100644 --- a/third_party/xla/xla/python/ifrt/ir/ifrt_ir_compile_options.proto +++ b/third_party/xla/xla/python/ifrt/ir/ifrt_ir_compile_options.proto @@ -23,4 +23,5 @@ import "xla/pjrt/compile_options.proto"; message IfrtIrCompileOptionsProto { repeated int32 device_ids = 1; map compile_option_overrides = 2; + bool propagate_shardings = 3; } diff --git a/third_party/xla/xla/python/ifrt/ir/ifrt_ir_program.cc b/third_party/xla/xla/python/ifrt/ir/ifrt_ir_program.cc index babc7ad137dcd2..038b34a7136c3b 100644 --- a/third_party/xla/xla/python/ifrt/ir/ifrt_ir_program.cc +++ b/third_party/xla/xla/python/ifrt/ir/ifrt_ir_program.cc @@ -69,7 +69,7 @@ IfrtIRCompileOptions::FromProto(const IfrtIrCompileOptionsProto& proto) { return std::make_unique( std::move(device_ids), absl::flat_hash_map>(), - std::move(compile_options_overrides)); + std::move(compile_options_overrides), proto.propagate_shardings()); } absl::StatusOr IfrtIRCompileOptions::ToProto() @@ -94,6 +94,8 @@ absl::StatusOr IfrtIRCompileOptions::ToProto() proto.mutable_compile_option_overrides()->insert( {id, compile_options_proto}); } + + proto.set_propagate_shardings(propagate_shardings); return proto; } diff --git a/third_party/xla/xla/python/ifrt/ir/ifrt_ir_program.h b/third_party/xla/xla/python/ifrt/ir/ifrt_ir_program.h index d29daf21d52a52..d4f0da2aac58b6 100644 --- a/third_party/xla/xla/python/ifrt/ir/ifrt_ir_program.h +++ b/third_party/xla/xla/python/ifrt/ir/ifrt_ir_program.h @@ -65,10 +65,12 @@ struct IfrtIRCompileOptions loaded_exec_binding = {}, std::shared_ptr>> - compile_options_overrides = {}) + compile_options_overrides = {}, + bool propagate_shardings = false) : device_assignments(std::move(device_assignments)), loaded_exec_binding(std::move(loaded_exec_binding)), - compile_options_overrides(std::move(compile_options_overrides)) {} + compile_options_overrides(std::move(compile_options_overrides)), + propagate_shardings(propagate_shardings) {} // Mapping from logical device ids in IFRT IR MLIR module to runtime device // ids obtained from IFRT client. @@ -87,6 +89,10 @@ struct IfrtIRCompileOptions std::string, std::unique_ptr>> compile_options_overrides; + // Whether to propagate shardings from atom program executables for + // unspecified shardings. + bool propagate_shardings; + // Constructs `IfrtIRCompileOptions` from `IfrtIrCompileOptionsProto`. static absl::StatusOr> FromProto( const IfrtIrCompileOptionsProto& proto);