Skip to content

Commit

Permalink
[docs] Fix fp16 mixed precision example to reflect correct input vari…
Browse files Browse the repository at this point in the history
…able names (#22250)

### Description
Correct variable name from `test_data` to `feed_dict` to fix example
code in mixed precision example docs.



### Motivation and Context
Fixes #21822
  • Loading branch information
shubhambhokare1 authored Oct 1, 2024
1 parent 7faa4b1 commit d5908da
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion docs/performance/model-optimizations/float16.md
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,9 @@ from onnxconverter_common import auto_mixed_precision
import onnx

model = onnx.load("path/to/model.onnx")
model_fp16 = auto_convert_mixed_precision(model, test_data, rtol=0.01, atol=0.001, keep_io_types=True)
# Assuming x is the input to the model
feed_dict = {'input': x.numpy()}
model_fp16 = auto_convert_mixed_precision(model, feed_dict, rtol=0.01, atol=0.001, keep_io_types=True)
onnx.save(model_fp16, "path/to/model_fp16.onnx")
```

Expand All @@ -73,6 +75,7 @@ auto_convert_mixed_precision(model, feed_dict, validate_fn=None, rtol=None, atol
```

- `model`: The ONNX model to convert.
- `feed_dict`: Test data used to measure the accuracy of the model during conversion. Format is similar to InferenceSession.run (map of input names to values)
- `validate_fn`: A function accepting two lists of numpy arrays (the outputs of the float32 model and the mixed-precision model, respectively) that returns `True` if the results are sufficiently close and `False` otherwise. Can be used instead of or in addition to `rtol` and `atol`.
- `rtol`, `atol`: Absolute and relative tolerances used for validation. See [numpy.allclose](https://numpy.org/doc/stable/reference/generated/numpy.allclose.html) for more information.
- `keep_io_types`: Whether model inputs/outputs should be left as float32.
Expand Down

0 comments on commit d5908da

Please sign in to comment.