Fix opset version of the optimizer in function generate_artifacts
xadupre committed Nov 6, 2023
1 parent d652b1f commit fc458aa
Showing 1 changed file with 9 additions and 1 deletion.
orttraining/orttraining/python/training/artifacts.py (9 additions, 1 deletion)
@@ -53,6 +53,8 @@ def generate_artifacts(
     3. Checkpoint (directory): Contains the model parameters.
     4. Optimizer model (onnx.ModelProto): Model containing the optimizer graph.
 
+    All generated ModelProto use the same opsets defined by *model*.
+
     Args:
         model: The base model to be used for gradient graph generation.
         requires_grad: List of names of model parameters that require gradient computation
@@ -207,11 +209,17 @@ def _export_to_ort_format(model_path, output_dir, extra_options):
 
     logging.info("Optimizer enum provided: %s", optimizer.name)
 
+    opset_version = None
+    for domain in model.opset_import:
+        if domain.domain == "":
+            opset_version = domain.version
+            break
+
     optim_model = None
     optim_blocks = {OptimType.AdamW: onnxblock.optim.AdamW, OptimType.SGD: onnxblock.optim.SGD}
 
     optim_block = optim_blocks[optimizer]()
-    with onnxblock.empty_base():
+    with onnxblock.empty_base(opset_version=opset_version):
         _ = optim_block(model_params)
         optim_model = optim_block.to_model_proto()
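The loop added above reads the opset version registered for the default (empty-string) ONNX domain in model.opset_import and forwards it to onnxblock.empty_base, so the optimizer graph is built against the same opset as the training model. Below is a minimal sketch of that lookup; it is not part of the commit, assumes only the onnx package, and uses an arbitrary opset 17 for illustration.

# Minimal sketch (not part of this commit): read the default-domain opset
# from an onnx.ModelProto, mirroring the loop added in generate_artifacts.
from onnx import TensorProto, helper

# Build a toy one-node model that imports opset 17 for the default ("") domain.
node = helper.make_node("Identity", ["x"], ["y"])
graph = helper.make_graph(
    [node],
    "toy",
    [helper.make_tensor_value_info("x", TensorProto.FLOAT, [1])],
    [helper.make_tensor_value_info("y", TensorProto.FLOAT, [1])],
)
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 17)])

# Same lookup as the commit: the opset_import entry whose domain is the
# empty string is the default ONNX operator set.
opset_version = None
for domain in model.opset_import:
    if domain.domain == "":
        opset_version = domain.version
        break

print(opset_version)  # 17

As the updated docstring notes, passing this version through keeps every generated ModelProto on the opsets defined by *model*.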
