From 7c4aa31f7ae14c94dd85c5cc2eeee794c6a7ab76 Mon Sep 17 00:00:00 2001 From: Lorenzo Mammana Date: Thu, 30 May 2024 11:17:25 +0200 Subject: [PATCH] Fix: Onnx model export not always working properly in half precision (#119) * build: Add onnxconverter-common as dependency * feat: Allow exporting mixed precision onnx models if a broken model is detected * build: Upgrade version * docs: Update changelog --- CHANGELOG.md | 7 +++ poetry.lock | 80 ++++++++++++++++++------------- pyproject.toml | 5 +- quadra/__init__.py | 2 +- quadra/utils/export.py | 104 ++++++++++++++++++++++++++++++++++++++--- 5 files changed, 156 insertions(+), 42 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9df57f20..e7014a36 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,13 @@ # Changelog All notable changes to this project will be documented in this file. +### [2.1.8] + +#### Added + +- Add onnxconverter-common to the dependencies in order to allow exporting onnx models in mixed precision if issues +are encountered exporting the model entirely in half precision. + ### [2.1.7] #### Fixed diff --git a/poetry.lock b/poetry.lock index 38d27c2e..a519ba4d 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2628,7 +2628,7 @@ files = [ cssselect = ["cssselect (>=0.7)"] html5 = ["html5lib"] htmlsoup = ["BeautifulSoup4"] -source = ["Cython (>=3.0.7)"] +source = ["Cython (>=3.0.8)"] [[package]] name = "mako" @@ -3685,6 +3685,23 @@ protobuf = ">=3.20.2" [package.extras] reference = ["Pillow", "google-re2"] +[[package]] +name = "onnxconverter-common" +version = "1.14.0" +description = "ONNX Converter and Optimization Tools" +optional = true +python-versions = ">=3.8" +files = [ + {file = "onnxconverter-common-1.14.0.tar.gz", hash = "sha256:6e431429bd15325c5b2c3eab61bed0d5634c23ed58f8823961be448d629d014a"}, + {file = "onnxconverter_common-1.14.0-py2.py3-none-any.whl", hash = "sha256:9723e4a9b47f283e298605dce9f357d5ebd5e5e70172fca26e282a1b490916c4"}, +] + +[package.dependencies] +numpy = "*" +onnx = "*" +packaging = "*" +protobuf = "3.20.2" + [[package]] name = "onnxruntime-gpu" version = "1.17.0" @@ -4207,33 +4224,33 @@ test = ["coveralls", "futures", "mock", "pytest (>=2.7.3)", "pytest-benchmark", [[package]] name = "protobuf" -version = "3.20.3" +version = "3.20.2" description = "Protocol Buffers" optional = false python-versions = ">=3.7" files = [ - {file = "protobuf-3.20.3-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:f4bd856d702e5b0d96a00ec6b307b0f51c1982c2bf9c0052cf9019e9a544ba99"}, - {file = "protobuf-3.20.3-cp310-cp310-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:9aae4406ea63d825636cc11ffb34ad3379335803216ee3a856787bcf5ccc751e"}, - {file = "protobuf-3.20.3-cp310-cp310-win32.whl", hash = "sha256:28545383d61f55b57cf4df63eebd9827754fd2dc25f80c5253f9184235db242c"}, - {file = "protobuf-3.20.3-cp310-cp310-win_amd64.whl", hash = "sha256:67a3598f0a2dcbc58d02dd1928544e7d88f764b47d4a286202913f0b2801c2e7"}, - {file = "protobuf-3.20.3-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:899dc660cd599d7352d6f10d83c95df430a38b410c1b66b407a6b29265d66469"}, - {file = "protobuf-3.20.3-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:e64857f395505ebf3d2569935506ae0dfc4a15cb80dc25261176c784662cdcc4"}, - {file = "protobuf-3.20.3-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:d9e4432ff660d67d775c66ac42a67cf2453c27cb4d738fc22cb53b5d84c135d4"}, - {file = "protobuf-3.20.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:74480f79a023f90dc6e18febbf7b8bac7508420f2006fabd512013c0c238f454"}, - {file = "protobuf-3.20.3-cp37-cp37m-win32.whl", hash = "sha256:b6cc7ba72a8850621bfec987cb72623e703b7fe2b9127a161ce61e61558ad905"}, - {file = "protobuf-3.20.3-cp37-cp37m-win_amd64.whl", hash = "sha256:8c0c984a1b8fef4086329ff8dd19ac77576b384079247c770f29cc8ce3afa06c"}, - {file = "protobuf-3.20.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:de78575669dddf6099a8a0f46a27e82a1783c557ccc38ee620ed8cc96d3be7d7"}, - {file = "protobuf-3.20.3-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:f4c42102bc82a51108e449cbb32b19b180022941c727bac0cfd50170341f16ee"}, - {file = "protobuf-3.20.3-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:44246bab5dd4b7fbd3c0c80b6f16686808fab0e4aca819ade6e8d294a29c7050"}, - {file = "protobuf-3.20.3-cp38-cp38-win32.whl", hash = "sha256:c02ce36ec760252242a33967d51c289fd0e1c0e6e5cc9397e2279177716add86"}, - {file = "protobuf-3.20.3-cp38-cp38-win_amd64.whl", hash = "sha256:447d43819997825d4e71bf5769d869b968ce96848b6479397e29fc24c4a5dfe9"}, - {file = "protobuf-3.20.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:398a9e0c3eaceb34ec1aee71894ca3299605fa8e761544934378bbc6c97de23b"}, - {file = "protobuf-3.20.3-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:bf01b5720be110540be4286e791db73f84a2b721072a3711efff6c324cdf074b"}, - {file = "protobuf-3.20.3-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:daa564862dd0d39c00f8086f88700fdbe8bc717e993a21e90711acfed02f2402"}, - {file = "protobuf-3.20.3-cp39-cp39-win32.whl", hash = "sha256:819559cafa1a373b7096a482b504ae8a857c89593cf3a25af743ac9ecbd23480"}, - {file = "protobuf-3.20.3-cp39-cp39-win_amd64.whl", hash = "sha256:03038ac1cfbc41aa21f6afcbcd357281d7521b4157926f30ebecc8d4ea59dcb7"}, - {file = "protobuf-3.20.3-py2.py3-none-any.whl", hash = "sha256:a7ca6d488aa8ff7f329d4c545b2dbad8ac31464f1d8b1c87ad1346717731e4db"}, - {file = "protobuf-3.20.3.tar.gz", hash = "sha256:2e3427429c9cffebf259491be0af70189607f365c2f41c7c3764af6f337105f2"}, + {file = "protobuf-3.20.2-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:09e25909c4297d71d97612f04f41cea8fa8510096864f2835ad2f3b3df5a5559"}, + {file = "protobuf-3.20.2-cp310-cp310-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:e8fbc522303e09036c752a0afcc5c0603e917222d8bedc02813fd73b4b4ed804"}, + {file = "protobuf-3.20.2-cp310-cp310-win32.whl", hash = "sha256:84a1544252a933ef07bb0b5ef13afe7c36232a774affa673fc3636f7cee1db6c"}, + {file = "protobuf-3.20.2-cp310-cp310-win_amd64.whl", hash = "sha256:2c0b040d0b5d5d207936ca2d02f00f765906622c07d3fa19c23a16a8ca71873f"}, + {file = "protobuf-3.20.2-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:3cb608e5a0eb61b8e00fe641d9f0282cd0eedb603be372f91f163cbfbca0ded0"}, + {file = "protobuf-3.20.2-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:84fe5953b18a383fd4495d375fe16e1e55e0a3afe7b4f7b4d01a3a0649fcda9d"}, + {file = "protobuf-3.20.2-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:384164994727f274cc34b8abd41a9e7e0562801361ee77437099ff6dfedd024b"}, + {file = "protobuf-3.20.2-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:e39cf61bb8582bda88cdfebc0db163b774e7e03364bbf9ce1ead13863e81e359"}, + {file = "protobuf-3.20.2-cp37-cp37m-win32.whl", hash = "sha256:18e34a10ae10d458b027d7638a599c964b030c1739ebd035a1dfc0e22baa3bfe"}, + {file = "protobuf-3.20.2-cp37-cp37m-win_amd64.whl", hash = "sha256:8228e56a865c27163d5d1d1771d94b98194aa6917bcfb6ce139cbfa8e3c27334"}, + {file = "protobuf-3.20.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:03d76b7bd42ac4a6e109742a4edf81ffe26ffd87c5993126d894fe48a120396a"}, + {file = "protobuf-3.20.2-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:f52dabc96ca99ebd2169dadbe018824ebda08a795c7684a0b7d203a290f3adb0"}, + {file = "protobuf-3.20.2-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:f34464ab1207114e73bba0794d1257c150a2b89b7a9faf504e00af7c9fd58978"}, + {file = "protobuf-3.20.2-cp38-cp38-win32.whl", hash = "sha256:5d9402bf27d11e37801d1743eada54372f986a372ec9679673bfcc5c60441151"}, + {file = "protobuf-3.20.2-cp38-cp38-win_amd64.whl", hash = "sha256:9c673c8bfdf52f903081816b9e0e612186684f4eb4c17eeb729133022d6032e3"}, + {file = "protobuf-3.20.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:291fb4307094bf5ccc29f424b42268640e00d5240bf0d9b86bf3079f7576474d"}, + {file = "protobuf-3.20.2-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:b4fdb29c5a7406e3f7ef176b2a7079baa68b5b854f364c21abe327bbeec01cdb"}, + {file = "protobuf-3.20.2-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:7a5037af4e76c975b88c3becdf53922b5ffa3f2cddf657574a4920a3b33b80f3"}, + {file = "protobuf-3.20.2-cp39-cp39-win32.whl", hash = "sha256:a9e5ae5a8e8985c67e8944c23035a0dff2c26b0f5070b2f55b217a1c33bbe8b1"}, + {file = "protobuf-3.20.2-cp39-cp39-win_amd64.whl", hash = "sha256:c184485e0dfba4dfd451c3bd348c2e685d6523543a0f91b9fd4ae90eb09e8422"}, + {file = "protobuf-3.20.2-py2.py3-none-any.whl", hash = "sha256:c9cdf251c582c16fd6a9f5e95836c90828d51b0069ad22f463761d27c6c19019"}, + {file = "protobuf-3.20.2.tar.gz", hash = "sha256:712dca319eee507a1e7df3591e639a2b112a2f4a62d40fe7832a16fd19151750"}, ] [[package]] @@ -4782,7 +4799,7 @@ files = [ [[package]] name = "pyyaml-env-tag" version = "0.1" -description = "A custom YAML tag for referencing environment variables in YAML files. " +description = "A custom YAML tag for referencing environment variables in YAML files." optional = false python-versions = ">=3.6" files = [ @@ -5276,7 +5293,6 @@ description = "Scikit-multilearn is a BSD-licensed library for multi-label class optional = false python-versions = "*" files = [ - {file = "scikit-multilearn-0.2.0.linux-x86_64.tar.gz", hash = "sha256:3179fed29b1492f6a69600696c23045b9f494d2b89d1796a8bdc43ccbb33712b"}, {file = "scikit_multilearn-0.2.0-py2-none-any.whl", hash = "sha256:0a389600a6797db6567f2f6ca1d0dca30bebfaaa73f75de62d7ae40f8f03d4fb"}, {file = "scikit_multilearn-0.2.0-py3-none-any.whl", hash = "sha256:068c652f22704a084ca252d05d21a655e7c9b248d0a4543847b74de5fca2b3f0"}, ] @@ -5420,7 +5436,7 @@ huey = ["huey (>=2)"] loguru = ["loguru (>=0.5)"] opentelemetry = ["opentelemetry-distro (>=0.35b0)"] opentelemetry-experimental = ["opentelemetry-distro (>=0.40b0,<1.0)", "opentelemetry-instrumentation-aiohttp-client (>=0.40b0,<1.0)", "opentelemetry-instrumentation-django (>=0.40b0,<1.0)", "opentelemetry-instrumentation-fastapi (>=0.40b0,<1.0)", "opentelemetry-instrumentation-flask (>=0.40b0,<1.0)", "opentelemetry-instrumentation-requests (>=0.40b0,<1.0)", "opentelemetry-instrumentation-sqlite3 (>=0.40b0,<1.0)", "opentelemetry-instrumentation-urllib (>=0.40b0,<1.0)"] -pure-eval = ["asttokens", "executing", "pure_eval"] +pure-eval = ["asttokens", "executing", "pure-eval"] pymongo = ["pymongo (>=3.1)"] pyspark = ["pyspark (>=2.4.4)"] quart = ["blinker (>=1.1)", "quart (>=0.16.1)"] @@ -5713,7 +5729,7 @@ typing-extensions = ">=4.6.0" [package.extras] aiomysql = ["aiomysql (>=0.2.0)", "greenlet (!=0.4.17)"] aioodbc = ["aioodbc", "greenlet (!=0.4.17)"] -aiosqlite = ["aiosqlite", "greenlet (!=0.4.17)", "typing_extensions (!=3.10.0.1)"] +aiosqlite = ["aiosqlite", "greenlet (!=0.4.17)", "typing-extensions (!=3.10.0.1)"] asyncio = ["greenlet (!=0.4.17)"] asyncmy = ["asyncmy (>=0.2.3,!=0.2.4,!=0.2.6)", "greenlet (!=0.4.17)"] mariadb-connector = ["mariadb (>=1.0.1,!=1.1.2,!=1.1.5)"] @@ -5723,7 +5739,7 @@ mssql-pyodbc = ["pyodbc"] mypy = ["mypy (>=0.910)"] mysql = ["mysqlclient (>=1.4.0)"] mysql-connector = ["mysql-connector-python"] -oracle = ["cx_oracle (>=8)"] +oracle = ["cx-oracle (>=8)"] oracle-oracledb = ["oracledb (>=1.0.1)"] postgresql = ["psycopg2 (>=2.7)"] postgresql-asyncpg = ["asyncpg", "greenlet (!=0.4.17)"] @@ -5733,7 +5749,7 @@ postgresql-psycopg2binary = ["psycopg2-binary"] postgresql-psycopg2cffi = ["psycopg2cffi"] postgresql-psycopgbinary = ["psycopg[binary] (>=3.0.7)"] pymysql = ["pymysql"] -sqlcipher = ["sqlcipher3_binary"] +sqlcipher = ["sqlcipher3-binary"] [[package]] name = "sqlparse" @@ -6916,9 +6932,9 @@ docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.link testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-ignore-flaky", "pytest-mypy (>=0.9.1)", "pytest-ruff"] [extras] -onnx = ["onnx", "onnxruntime_gpu", "onnxsim"] +onnx = ["onnx", "onnxconverter-common", "onnxruntime_gpu", "onnxsim"] [metadata] lock-version = "2.0" python-versions = ">=3.9,<3.11" -content-hash = "b6f01476ed98e11c0c68000b29266f86032e4893b8221914d4e8b4b55c2dcadb" +content-hash = "68361649693152e3d71b1b85bb326dfa4984c9b170c1bee846937588b023fd70" diff --git a/pyproject.toml b/pyproject.toml index 36b6cb43..c93758c6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "quadra" -version = "2.1.7" +version = "2.1.8" description = "Deep Learning experiment orchestration library" authors = [ "Federico Belotti ", @@ -86,6 +86,7 @@ typing_extensions = { version = "4.11.0", python = "<3.10" } onnx = { version = "1.15.0", optional = true } onnxsim = { version = "0.4.28", optional = true } onnxruntime_gpu = { version = "1.17.0", optional = true, source = "onnx_cu12" } +onnxconverter-common = { version = "^1.14.0", optional = true } [[tool.poetry.source]] name = "torch_cu121" @@ -141,7 +142,7 @@ mike = "1.1.2" cairosvg = "2.7.0" [tool.poetry.extras] -onnx = ["onnx", "onnxsim", "onnxruntime_gpu"] +onnx = ["onnx", "onnxsim", "onnxruntime_gpu", "onnxconverter-common"] [tool.poetry_bumpversion.file."quadra/__init__.py"] search = '__version__ = "{current_version}"' diff --git a/quadra/__init__.py b/quadra/__init__.py index 9e951908..8d00c9d9 100644 --- a/quadra/__init__.py +++ b/quadra/__init__.py @@ -1,4 +1,4 @@ -__version__ = "2.1.7" +__version__ = "2.1.8" def get_version(): diff --git a/quadra/utils/export.py b/quadra/utils/export.py index 34807e55..0cf7f45f 100644 --- a/quadra/utils/export.py +++ b/quadra/utils/export.py @@ -7,6 +7,7 @@ import torch from anomalib.models.cflow import CflowLightning from omegaconf import DictConfig, ListConfig, OmegaConf +from onnxconverter_common import auto_convert_mixed_precision from torch import nn from quadra.models.base import ModelSignatureWrapper @@ -250,14 +251,14 @@ def export_onnx_model( for i, _ in enumerate(output_names): dynamic_axes[output_names[i]] = {0: "batch_size"} - onnx_config = cast(dict[str, Any], OmegaConf.to_container(onnx_config, resolve=True)) + modified_onnx_config = cast(dict[str, Any], OmegaConf.to_container(onnx_config, resolve=True)) - onnx_config["input_names"] = input_names - onnx_config["output_names"] = output_names - onnx_config["dynamic_axes"] = dynamic_axes + modified_onnx_config["input_names"] = input_names + modified_onnx_config["output_names"] = output_names + modified_onnx_config["dynamic_axes"] = dynamic_axes - simplify = onnx_config.pop("simplify", False) - _ = onnx_config.pop("fixed_batch_size", None) + simplify = modified_onnx_config.pop("simplify", False) + _ = modified_onnx_config.pop("fixed_batch_size", None) if len(inp) == 1: inp = inp[0] @@ -269,7 +270,7 @@ def export_onnx_model( raise ValueError("ONNX export does not support model with dict inputs") try: - torch.onnx.export(model=model, args=inp, f=model_path, **onnx_config) + torch.onnx.export(model=model, args=inp, f=model_path, **modified_onnx_config) onnx_model = onnx.load(model_path) # Check if ONNX model is valid @@ -280,6 +281,19 @@ def export_onnx_model( log.info("ONNX model saved to %s", os.path.join(os.getcwd(), model_path)) + if half_precision: + is_export_ok = _safe_export_half_precision_onnx( + model=model, + export_model_path=model_path, + inp=inp, + onnx_config=onnx_config, + input_shapes=input_shapes, + input_names=input_names, + ) + + if not is_export_ok: + return None + if simplify: log.info("Attempting to simplify ONNX model") onnx_model = onnx.load(model_path) @@ -302,6 +316,82 @@ def export_onnx_model( return os.path.join(os.getcwd(), model_path), input_shapes +def _safe_export_half_precision_onnx( + model: nn.Module, + export_model_path: str, + inp: torch.Tensor | list[torch.Tensor] | tuple[torch.Tensor, ...], + onnx_config: DictConfig, + input_shapes: list[Any], + input_names: list[str], +): + """Check that the exported half precision ONNX model does not contain NaN values. If it does, attempt to export + the model with a more stable export and overwrite the original model. + + Args: + model: PyTorch model to be exported + export_model_path: Path to save the model + inp: Input tensors for the model + onnx_config: ONNX export configuration + input_shapes: Input shapes for the model + input_names: Input names for the model + + Returns: + True if the model is stable or it was possible to export a more stable model, False otherwise. + """ + test_fp_16_model: BaseEvaluationModel = import_deployment_model( + export_model_path, OmegaConf.create({"onnx": {}}), "cuda:0" + ) + if not isinstance(inp, Sequence): + inp = [inp] + + test_output = test_fp_16_model(*inp) + + if not isinstance(test_output, Sequence): + test_output = [test_output] + + # Check if there are nan values in any of the outputs + is_broken_model = any(torch.isnan(out).any() for out in test_output) + + if is_broken_model: + try: + log.warning( + "The exported half precision ONNX model contains NaN values, attempting with a more stable export..." + ) + # Cast back the fp16 model to fp32 to simulate the export with fp32 + model = model.float() + log.info("Starting to export model in full precision") + export_output = export_onnx_model( + model=model, + output_path=os.path.dirname(export_model_path), + onnx_config=onnx_config, + input_shapes=input_shapes, + half_precision=False, + model_name=os.path.basename(export_model_path), + ) + if export_output is not None: + export_model_path, _ = export_output + else: + log.warning("Failed to export model") + return False + + model_fp32 = onnx.load(export_model_path) + test_data = {input_names[i]: inp[i].float().cpu().numpy() for i in range(len(inp))} + log.warning("Attempting to convert model in mixed precision, this may take a while...") + model_fp16 = auto_convert_mixed_precision(model_fp32, test_data, rtol=0.01, atol=0.001, keep_io_types=False) + onnx.save(model_fp16, export_model_path) + + onnx_model = onnx.load(export_model_path) + # Check if ONNX model is valid + onnx.checker.check_model(onnx_model) + return True + except Exception as e: + log.debug("Failed to export model with mixed precision with error: %s", e) + return False + else: + log.info("Exported half precision ONNX model does not contain NaN values, model is stable") + return True + + def export_pytorch_model(model: nn.Module, output_path: str, model_name: str = "model.pth") -> str: """Export pytorch model's parameter dictionary using a deserialized state_dict.