diff --git a/orttraining/orttraining/python/training/artifacts.py b/orttraining/orttraining/python/training/artifacts.py
index 549614de496a..c500af98edb3 100644
--- a/orttraining/orttraining/python/training/artifacts.py
+++ b/orttraining/orttraining/python/training/artifacts.py
@@ -53,6 +53,8 @@ def generate_artifacts(
         3. Checkpoint (directory): Contains the model parameters.
         4. Optimizer model (onnx.ModelProto): Model containing the optimizer graph.
 
+    All generated ModelProtos use the same opsets as defined by *model*.
+
     Args:
         model: The base model to be used for gradient graph generation.
         requires_grad: List of names of model parameters that require gradient computation
@@ -207,11 +209,17 @@ def _export_to_ort_format(model_path, output_dir, extra_options):
 
     logging.info("Optimizer enum provided: %s", optimizer.name)
 
+    opset_version = None
+    for domain in model.opset_import:
+        if domain.domain == "":
+            opset_version = domain.version
+            break
+
     optim_model = None
     optim_blocks = {OptimType.AdamW: onnxblock.optim.AdamW, OptimType.SGD: onnxblock.optim.SGD}
     optim_block = optim_blocks[optimizer]()
 
-    with onnxblock.empty_base():
+    with onnxblock.empty_base(opset_version=opset_version):
         _ = optim_block(model_params)
         optim_model = optim_block.to_model_proto()
 
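
For context, here is a minimal standalone sketch (not part of the patch) of what the added lookup does: it scans model.opset_import for the default ("") domain and takes its version, which the patch then forwards to onnxblock.empty_base(opset_version=...) so the generated optimizer model inherits the base model's opset. The toy Identity model below is purely illustrative, not taken from the repository.

    import onnx
    from onnx import TensorProto, helper

    # Build a trivial model whose default-domain ("") opset is 14.
    node = helper.make_node("Identity", ["x"], ["y"])
    graph = helper.make_graph(
        [node],
        "toy",
        [helper.make_tensor_value_info("x", TensorProto.FLOAT, [1])],
        [helper.make_tensor_value_info("y", TensorProto.FLOAT, [1])],
    )
    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 14)])

    # Same lookup as in the patch: find the default-domain opset version.
    opset_version = None
    for domain in model.opset_import:
        if domain.domain == "":
            opset_version = domain.version
            break

    print(opset_version)  # 14 -- the value empty_base(opset_version=...) would receive

Before this change, empty_base() presumably fell back to its own default opset, so the optimizer ModelProto could carry a different opset than the base model; pinning it to the base model's version keeps the generated artifacts consistent, as the new docstring line states.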