From c20488ced70488c9e95b6c11fdea309efe2fdc99 Mon Sep 17 00:00:00 2001 From: Jambay Kinley Date: Wed, 29 Nov 2023 18:27:04 -0800 Subject: [PATCH] skip_infer for SkipGroupNorm in SymbolicShapeInference (#18630) ### Description https://github.com/microsoft/onnxruntime/pull/18273 added the `SkipGroupNorm` contrib op, but it did not skip ONNX shape inference for this op in `SymbolicShapeInference`. This leads to failed shape inference of the transformers optimized model with `enable_skip_group_norm=True`. It also results in an invalid float16 model for the SD CUDA example. This PR adds `SkipGroupNorm` to `skip_infer` so that ONNX shape inference is skipped for this op and the relevant dispatcher is used instead. ### Motivation and Context Fix shape inference failure for models with `SkipGroupNorm` nodes. --- onnxruntime/python/tools/symbolic_shape_infer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/onnxruntime/python/tools/symbolic_shape_infer.py b/onnxruntime/python/tools/symbolic_shape_infer.py index a9cbef98d916..e90eea553c18 100755 --- a/onnxruntime/python/tools/symbolic_shape_infer.py +++ b/onnxruntime/python/tools/symbolic_shape_infer.py @@ -467,6 +467,7 @@ def _onnx_infer_single_node(self, node): "PythonOp", "MultiHeadAttention", "GroupNorm", + "SkipGroupNorm", "BiasSplitGelu", "BiasAdd", "NhwcConv",