Skip to content

Commit

Permalink
[MIGraphX EP] Add support for MIGraphX Exhaustive tune flag (#46) (#2…
Browse files Browse the repository at this point in the history
…1599)

### Description
<!-- Describe your changes. -->
Set the exhaustive tune flag through the MIGraphX API and make this a
Session option in Onnxruntime

### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
Allow users to use MIGraphX Exhaustive tuning with Onnxruntime
inferences
This goers hand in hand with save/load after a model and been compiled
and tuning has found.

---------

Co-authored-by: Ted Themistokleous <tedthemistokleous@amd.com>
Co-authored-by: Tianlei Wu <tlwu@microsoft.com>
  • Loading branch information
3 people authored Aug 21, 2024
1 parent 26a4993 commit 0e827c2
Show file tree
Hide file tree
Showing 8 changed files with 34 additions and 2 deletions.
1 change: 1 addition & 0 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -621,6 +621,7 @@ typedef struct OrtMIGraphXProviderOptions {
const char* migraphx_save_model_path; // migraphx model path name
int migraphx_load_compiled_model; // migraphx int8 cal table. Default 0 = false, noznero = true
const char* migraphx_load_model_path; // migraphx model path name
bool migraphx_exhaustive_tune; // migraphx tuned compile Default = false
} OrtMIGraphXProviderOptions;

/** \brief OpenVINO Provider Options
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,12 @@ MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProv
dump_model_ops_ = (std::stoi(dump_model_ops_env) == 0 ? false : true);
}

// Allow for exhaustive tune during compile
const std::string exhaustive_tune_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kExhaustiveTune);
if (!exhaustive_tune_env.empty()) {
exhaustive_tune_ = (std::stoi(exhaustive_tune_env) == 0 ? false : true);
}

metadef_id_generator_ = ModelMetadefIdGenerator::Create();

LOGS_DEFAULT(VERBOSE) << "[MIGraphX EP] MIGraphX provider options: "
Expand All @@ -190,6 +196,7 @@ MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProv
<< ", migraphx_int8_enable: " << int8_enable_
<< ", migraphx_int8_enable: " << int8_enable_
<< ", dump_model_ops: " << dump_model_ops_
<< ", exhaustive_tune: " << exhaustive_tune_
<< ", migraphx_int8_calibration_cache_name: " << int8_calibration_cache_name_
<< ", int8_calibration_cache_available: " << int8_calibration_cache_available_
<< ", use_native_migraphx_calibration_table: " << int8_use_native_migraphx_calibration_table_
Expand Down Expand Up @@ -1181,6 +1188,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>&

migraphx::compile_options co;
co.set_fast_math(false);
co.set_exhaustive_tune_flag(exhaustive_tune_);
LOGS_DEFAULT(INFO) << "Model Compile: Begin" << std::endl;
prog.compile(t_, co);
LOGS_DEFAULT(INFO) << "Model Compile: Complete" << std::endl;
Expand Down Expand Up @@ -1345,6 +1353,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>&
LOGS_DEFAULT(INFO) << "Model Compile: Begin" << std::endl;
migraphx::compile_options co;
co.set_fast_math(false);
co.set_exhaustive_tune_flag(exhaustive_tune_);
prog.compile(t, co);

save_compiled_model(prog, mgx_state->save_compiled_mode, mgx_state->save_compiled_path);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ static const char kSaveCompiledModel[] = "ORT_MIGRAPHX_SAVE_COMPILED_MODEL";
static const char kSavedModelPath[] = "ORT_MIGRAPHX_SAVE_COMPILE_PATH";
static const char kLoadCompiledModel[] = "ORT_MIGRAPHX_LOAD_COMPILED_MODEL";
static const char kLoadModelPath[] = "ORT_MIGRAPHX_LOAD_COMPILE_PATH";
static const char kExhaustiveTune[] = "ORT_MIGRAPHX_EXHAUSTIVE_TUNE";

}; // namespace migraphx_env_vars

Expand All @@ -50,6 +51,7 @@ struct MIGraphXFuncState {
bool load_compiled_mode = false;
std::string load_compiled_path;
bool dump_model_ops = false;
bool exhaustive_tune = false;
};

// Logical device representation.
Expand Down Expand Up @@ -101,6 +103,7 @@ class MIGraphXExecutionProvider : public IExecutionProvider {
migraphx::target t_;
OrtMutex mgx_mu_;
hipStream_t stream_ = nullptr;
bool exhaustive_tune_ = false;
mutable std::filesystem::path model_path_;

std::unordered_map<std::string, migraphx::program> map_progs_;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ constexpr const char* kSaveCompiledModel = "migx_save_compiled_model";
constexpr const char* kSaveModelPath = "migx_save_model_name";
constexpr const char* kLoadCompiledModel = "migx_load_compiled_model";
constexpr const char* kLoadModelPath = "migx_load_model_name";
constexpr const char* kExhaustiveTune = "migx_exhaustive_tune";

} // namespace provider_option_names
} // namespace migraphx
Expand All @@ -45,6 +46,7 @@ MIGraphXExecutionProviderInfo MIGraphXExecutionProviderInfo::FromProviderOptions
.AddAssignmentToReference(migraphx::provider_option_names::kInt8Enable, info.int8_enable)
.AddAssignmentToReference(migraphx::provider_option_names::kSaveCompiledModel, info.save_compiled_model)
.AddAssignmentToReference(migraphx::provider_option_names::kLoadCompiledModel, info.load_compiled_model)
.AddAssignmentToReference(migraphx::provider_option_names::kExhaustiveTune, info.exhaustive_tune)
.Parse(options));

return info;
Expand All @@ -57,6 +59,7 @@ ProviderOptions MIGraphXExecutionProviderInfo::ToProviderOptions(const MIGraphXE
{migraphx::provider_option_names::kInt8Enable, MakeStringWithClassicLocale(info.int8_enable)},
{migraphx::provider_option_names::kSaveCompiledModel, MakeStringWithClassicLocale(info.save_compiled_model)},
{migraphx::provider_option_names::kLoadCompiledModel, MakeStringWithClassicLocale(info.load_compiled_model)},
{migraphx::provider_option_names::kExhaustiveTune, MakeStringWithClassicLocale(info.exhaustive_tune)},
};
return options;
}
Expand All @@ -68,6 +71,7 @@ ProviderOptions MIGraphXExecutionProviderInfo::ToProviderOptions(const OrtMIGrap
{migraphx::provider_option_names::kInt8Enable, MakeStringWithClassicLocale(info.migraphx_int8_enable)},
{migraphx::provider_option_names::kSaveCompiledModel, MakeStringWithClassicLocale(info.migraphx_save_compiled_model)},
{migraphx::provider_option_names::kLoadCompiledModel, MakeStringWithClassicLocale(info.migraphx_load_compiled_model)},
{migraphx::provider_option_names::kExhaustiveTune, MakeStringWithClassicLocale(info.migraphx_exhaustive_tune)},
};
return options;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ struct MIGraphXExecutionProviderInfo {
std::string save_model_file{"./compiled_model.mxr"};
bool load_compiled_model{true};
std::string load_model_file{"./compiled_model.mxr"};
bool exhaustive_tune{false};

static MIGraphXExecutionProviderInfo FromProviderOptions(const ProviderOptions& options);
static ProviderOptions ToProviderOptions(const MIGraphXExecutionProviderInfo& info);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ struct MIGraphX_Provider : Provider {
info.device_id = static_cast<OrtDevice::DeviceId>(options.device_id);
info.target_device = "gpu";
info.fp16_enable = options.migraphx_fp16_enable;
info.exhaustive_tune = options.migraphx_exhaustive_tune;
info.int8_enable = options.migraphx_int8_enable;
info.int8_calibration_table_name = "";
if (options.migraphx_int8_calibration_table_name != nullptr) {
Expand All @@ -85,6 +86,7 @@ struct MIGraphX_Provider : Provider {
migx_options.device_id = internal_options.device_id;
migx_options.migraphx_fp16_enable = internal_options.fp16_enable;
migx_options.migraphx_int8_enable = internal_options.int8_enable;
migx_options.migraphx_exhaustive_tune = internal_options.exhaustive_tune;

char* dest = nullptr;
auto str_size = internal_options.int8_calibration_table_name.size();
Expand Down
13 changes: 12 additions & 1 deletion onnxruntime/python/onnxruntime_pybind_state.cc
Original file line number Diff line number Diff line change
Expand Up @@ -844,7 +844,8 @@ std::unique_ptr<IExecutionProvider> CreateExecutionProviderInstance(
1,
"./compiled_model.mxr",
1,
"./compiled_model.mxr"};
"./compiled_model.mxr",
1};
for (auto option : it->second) {
if (option.first == "device_id") {
if (!option.second.empty()) {
Expand Down Expand Up @@ -929,6 +930,16 @@ std::unique_ptr<IExecutionProvider> CreateExecutionProviderInstance(
"[ERROR] [MIGraphX] The value for the key 'migx_load_model_name' should be a "
"file name i.e. 'compiled_model.mxr'.\n");
}
} else if (option.first == "migraphx_exhaustive_tune") {
if (option.second == "True" || option.second == "true") {
params.migraphx_exhaustive_tune = true;
} else if (option.second == "False" || option.second == "false") {
params.migraphx_exhaustive_tune = false;
} else {
ORT_THROW(
"[ERROR] [MIGraphX] The value for the key 'migraphx_exhaustive_tune' should be"
" 'True' or 'False'. Default value is 'False'.\n");
}
} else {
ORT_THROW("Invalid MIGraphX EP option: ", option.first);
}
Expand Down
3 changes: 2 additions & 1 deletion onnxruntime/test/util/default_providers.cc
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,8 @@ std::unique_ptr<IExecutionProvider> DefaultMIGraphXExecutionProvider() {
1,
"./compiled_model.mxr",
1,
"./compiled_model.mxr"};
"./compiled_model.mxr",
1};
return MIGraphXProviderFactoryCreator::Create(&params)->CreateProvider();
#else
return nullptr;
Expand Down

0 comments on commit 0e827c2

Please sign in to comment.