fix: FP16 clipping applied to INT types (quic#79)
* fix: FP16 clipping applied to int types

- Create FP16Clip transform which applies correct logic for FP16 clipping
- Transforms now return flag indicating transformation
- FP16 clipping should be applied only for FLOAT types
- Added tests for FP16Clip transform

Signed-off-by: Ilango Rajagopal <quic_irajagop@quicinc.com>

* Update comment

Signed-off-by: Ilango Rajagopal <quic_irajagop@quicinc.com>

* Fix gather index while running in ORT

- Simplify test_causal_lm_models

Signed-off-by: Ilango Rajagopal <quic_irajagop@quicinc.com>

* Re-enable asserts in test_causal_lm_models

Signed-off-by: Ilango Rajagopal <quic_irajagop@quicinc.com>

* No need to check the onnx path

Signed-off-by: Ilango Rajagopal <quic_irajagop@quicinc.com>

* Fix onnx path expected with suffix "_clipped_fp16"

Signed-off-by: Ilango Rajagopal <quic_irajagop@quicinc.com>

* Pass external_data dir while clipping FP16

Signed-off-by: Ilango Rajagopal <quic_irajagop@quicinc.com>

* Fix export API tests

Signed-off-by: Ilango Rajagopal <quic_irajagop@quicinc.com>

* Fix compile API tests

Signed-off-by: Ilango Rajagopal <quic_irajagop@quicinc.com>

* compile API test: catch the case when no onnx file is found

Signed-off-by: Ilango Rajagopal <quic_irajagop@quicinc.com>

* fix: Use correct suffix for unclipped onnx

Signed-off-by: Ilango Rajagopal <quic_irajagop@quicinc.com>

* Remove unneeded config.json

Signed-off-by: Ilango Rajagopal <quic_irajagop@quicinc.com>

* ONNXTransform takes `onnx_base_dir` param

Signed-off-by: Ilango Rajagopal <quic_irajagop@quicinc.com>

* Fix output type and typos

Signed-off-by: Ilango Rajagopal <quic_irajagop@quicinc.com>

* Added KVCache & CustomOps transforms

Signed-off-by: Ilango Rajagopal <quic_irajagop@quicinc.com>

* Removed "transforms.py"

Signed-off-by: Ilango Rajagopal <quic_irajagop@quicinc.com>

* Removed assert statement from infer.py as it's already present in tests

Signed-off-by: Onkar Chougule <quic_ochougul@quicinc.com>

---------

Signed-off-by: Ilango Rajagopal <quic_irajagop@quicinc.com>
Signed-off-by: Onkar Chougule <quic_ochougul@quicinc.com>
Co-authored-by: Onkar Chougule <quic_ochougul@quicinc.com>
irajagop and ochougul authored Jul 31, 2024
1 parent bb7920f commit 0b6bbed
Showing 15 changed files with 191 additions and 166 deletions.
28 changes: 25 additions & 3 deletions QEfficient/base/onnx_transforms.py
@@ -5,7 +5,10 @@
 #
 # ----------------------------------------------------------------------------
 
-from onnx import ModelProto
+from typing import Optional, Tuple
+
+import numpy as np
+from onnx import ModelProto, external_data_helper, numpy_helper
 
 
 class OnnxTransform:
@@ -17,18 +20,37 @@ def __init__(self):
         raise TypeError("Transform classes are not to be instantiated. Directly use the `apply` method.")
 
     @classmethod
-    def apply(cls, model: ModelProto) -> ModelProto:
+    def apply(cls, model: ModelProto, onnx_base_dir: Optional[str] = None) -> Tuple[ModelProto, bool]:
         """
         Override this class to apply a transformation.
         :param model: The model's ONNX graph to transform
+        :param onnx_base_dir: Directory where the model and external files are present
         :returns: ONNX graph after applying the transform
+        :returns: Boolean indicating whether transform was applied
         """
         raise NotImplementedError("Use subclasses for ONNX transform")
 
 
 class FP16Clip(OnnxTransform):
-    pass
+    """
+    Clips the tensor values to be in FP16 range.
+    """
+
+    @classmethod
+    def apply(cls, model: ModelProto, onnx_base_dir: Optional[str] = None) -> Tuple[ModelProto, bool]:
+        finfo = np.finfo(np.float16)
+        fp16_max = finfo.max
+        fp16_min = finfo.min
+        transformed = False
+        for tensor in external_data_helper._get_all_tensors(model):
+            nptensor = numpy_helper.to_array(tensor, onnx_base_dir)
+            if nptensor.dtype == np.float32 and (np.any(nptensor > fp16_max) or np.any(nptensor < fp16_min)):
+                nptensor = np.clip(nptensor, fp16_min, fp16_max)
+                new_tensor = numpy_helper.from_array(nptensor, tensor.name)
+                tensor.CopyFrom(new_tensor)
+                transformed = True
+        return model, transformed
 
 
 class SplitWeights(OnnxTransform):
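For reference, a minimal usage sketch of the new transform (not part of the diff; paths are hypothetical). Because only float32 tensors are considered, integer constants such as INT32-max Slice ends are no longer clipped:

import onnx

from QEfficient.base.onnx_transforms import FP16Clip

# Hypothetical paths; onnx_base_dir tells the transform where external
# weight files live so numpy_helper.to_array can resolve them.
model = onnx.load("onnx_dir/model_kv.onnx", load_external_data=False)
model, transformed = FP16Clip.apply(model, onnx_base_dir="onnx_dir")
if transformed:
    onnx.save(model, "onnx_dir/model_kv_clipped_fp16.onnx")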
17 changes: 10 additions & 7 deletions QEfficient/base/pytorch_transforms.py
@@ -5,7 +5,7 @@
 #
 # ----------------------------------------------------------------------------
 
-from typing import Dict, Type
+from typing import Dict, Tuple, Type
 
 from torch import nn
 
@@ -19,12 +19,13 @@ def __init__(self):
         raise TypeError("Transform classes are not to be instantiated. Directly use the `apply` method.")
 
     @classmethod
-    def apply(cls, model: nn.Module) -> nn.Module:
+    def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:
         """
         Override this class method to apply a transformation.
-        :param model: The torch module to transform, this module may be tranformed in-place
-        :returns: Torch module after applying the tranform
+        :param model: The torch module to transform, this module may be transformed in-place
+        :returns: Torch module after applying the transform
+        :returns: Boolean indicating whether transform was applied
         """
         raise NotImplementedError("Use subclasses for Pytorch transform")
 
@@ -37,14 +38,16 @@ class ModuleMapping(PytorchTransform):
     _module_mapping: Dict[Type[nn.Module], Type[nn.Module]]
 
     @classmethod
-    def apply(cls, model: nn.Module) -> nn.Module:
+    def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:
+        transformed = False
         for module in model.modules():
             if repl_module := cls._module_mapping.get(type(module)):
                 module.__class__ = repl_module
-        return model
+                transformed = True
+        return model, transformed
 
     @classmethod
-    def register(cls, from_module: type, to_module: type):
+    def register(cls, from_module: Type[nn.Module], to_module: Type[nn.Module]):
         """
         Add a new module type in the module mapping for this transform. ::
             FlashAttention.register(LLamaAttention, LlamaFlashAttention)
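A sketch of how a `ModuleMapping` subclass might be defined and applied, assuming `register` simply records the pair in `_module_mapping` (all class names here are hypothetical, not part of the diff):

from torch import nn

from QEfficient.base.pytorch_transforms import ModuleMapping

class SlowBlock(nn.Module):  # hypothetical module to be replaced
    def forward(self, x):
        return x * 2

class FastBlock(SlowBlock):  # hypothetical drop-in replacement, same weights
    def forward(self, x):
        return x + x

class BlockSwap(ModuleMapping):  # hypothetical transform with its own mapping table
    _module_mapping = {}

BlockSwap.register(SlowBlock, FastBlock)
model = nn.Sequential(SlowBlock(), nn.ReLU())
model, transformed = BlockSwap.apply(model)  # swaps __class__ in-place
assert transformed and isinstance(model[0], FastBlock)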
9 changes: 2 additions & 7 deletions QEfficient/cloud/export.py
@@ -45,7 +45,7 @@ def get_onnx_model_path(
     ####################
     # Export to the Onnx
     logger.info(f"Exporting Pytorch {model_name} model to ONNX...")
-    _, generated_onnx_model_path = qualcomm_efficient_converter(
+    _, onnx_model_path = qualcomm_efficient_converter(
         model_name=model_name,
         local_model_dir=local_model_dir,
         tokenizer=tokenizer,
@@ -55,12 +55,7 @@ def get_onnx_model_path(
         hf_token=hf_token,
         cache_dir=cache_dir,
     )  # type: ignore
-    logger.info(
-        f"Generated Onnx_path {generated_onnx_model_path} \nOnnx_model_path {onnx_model_path} \nand Onnx_dir_path is {onnx_dir_path}"
-    )
-    assert (
-        generated_onnx_model_path == onnx_model_path
-    ), f"ONNX files were generated at an unusual location, expected {onnx_model_path}, got {generated_onnx_model_path}"
+    logger.info(f"Generated onnx_path: {onnx_model_path}, onnx_dir_path: {onnx_dir_path}")
     return onnx_model_path
5 changes: 1 addition & 4 deletions QEfficient/cloud/infer.py
@@ -81,7 +81,7 @@ def main(
     #########
     # Compile
     #########
-    generated_qpc_path = QEfficient.compile(
+    _ = QEfficient.compile(
         onnx_path=onnx_model_path,
         qpc_path=os.path.dirname(
             qpc_dir_path
@@ -96,9 +96,6 @@ def main(
         mos=mos,
         device_group=device_group,
     )
-    assert (
-        generated_qpc_path == qpc_dir_path
-    ), f"QPC files were generated at an unusual location, expected {qpc_dir_path}; got {generated_qpc_path}"
 
     #########
     # Execute
52 changes: 7 additions & 45 deletions QEfficient/exporter/export_utils.py
@@ -15,7 +15,9 @@
 import onnx
 import onnxruntime
 import torch
-from onnx import external_data_helper, numpy_helper
+from onnx import external_data_helper
+
+from QEfficient.base.onnx_transforms import FP16Clip
 
 
 def export_onnx(
@@ -171,43 +173,10 @@ def fix_onnx_fp16(
     model_base_name: str,
     pt_outputs: Dict[str, torch.Tensor],
 ) -> str:
-    finfo = np.finfo(np.float16)
-    fp16_max = finfo.max
-    fp16_min = finfo.min
-    abs_max_val = np.finfo(np.float16).max
-
     model = onnx.load(os.path.join(gen_models_path, f"{model_base_name}.onnx"))
-
-    fp16_fix = False
-    for tensor in external_data_helper._get_all_tensors(model):
-        nptensor = numpy_helper.to_array(tensor, gen_models_path)
-        if nptensor.dtype == np.float32 and (np.any(nptensor > fp16_max) or np.any(nptensor < fp16_min)):
-            nptensor = np.clip(nptensor, fp16_min, fp16_max)
-            new_tensor = numpy_helper.from_array(nptensor, tensor.name)
-            tensor.CopyFrom(new_tensor)
-            fp16_fix = True
-
-    for idx, node in enumerate(model.graph.node):
-        if node.op_type == "Constant":
-            curr_data = numpy_helper.to_array(node.attribute[0].t)
-            if np.any(np.abs(curr_data) > abs_max_val):
-                updated_data = np.clip(curr_data, -abs_max_val, abs_max_val).astype(curr_data.dtype)
-                updated_tensor = numpy_helper.from_array(updated_data)
-                node.attribute[0].t.CopyFrom(updated_tensor)
-                fp16_fix = True
-
-    for node in model.graph.initializer:
-        curr_data = numpy_helper.to_array(node)
-        if np.any(np.abs(curr_data) > abs_max_val):
-            updated_data = np.clip(curr_data, -abs_max_val, abs_max_val).astype(curr_data.dtype)
-            updated_tensor = numpy_helper.from_array(updated_data)
-            updated_tensor.name = node.name
-            node.CopyFrom(updated_tensor)
-            fp16_fix = True
-
-    # TODO: Check this, variable "fp16_same_as_fp32" is not being used.
-    # fp16_same_as_fp32 = True
-    ort_outputs_fixed = []
+    # TODO: Remove this `fix_onnx_fp16` function and replace with this transform
+    # as we're not utilizing the validations done in this function
+    model, fp16_fix = FP16Clip.apply(model, gen_models_path)
 
     if fp16_fix:
         # Save FP16 model
@@ -246,15 +215,8 @@ def fix_onnx_fp16(
                 # TODO: need to the debug this
                 # info(oname, fix_diff)
                 close_outputs.append(fix_diff < 1e-5)
-            else:
-                print(
-                    f"Failed to run the onnx {model_base_name} model in onnx runtime:%s",
-                    ort_outputs,
-                )
-    # Commenting the below line
-    # fp16_same_as_fp32 = all(close_outputs)
     else:
-        info("No constants out of FP16 range .. ")
+        info("No constants out of FP16 range")
 
     return model_base_name
15 changes: 11 additions & 4 deletions QEfficient/utils/_utils.py
@@ -104,12 +104,19 @@ def onnx_exists(model_name: str) -> Tuple[bool, str, str]:
     os.makedirs(model_card_dir, exist_ok=True)
 
     onnx_dir_path = os.path.join(model_card_dir, "onnx")
-    onnx_model_path = os.path.join(onnx_dir_path, model_name.replace("/", "_") + "_kv_clipped_fp16.onnx")
+    clipped_onnx_model_path = os.path.join(onnx_dir_path, model_name.replace("/", "_") + "_kv_clipped_fp16.onnx")
+    unclipped_onnx_model_path = clipped_onnx_model_path.replace("_clipped_fp16.onnx", ".onnx")
 
     # Compute the boolean indicating if the ONNX model exists
-    onnx_exists_bool = os.path.isfile(onnx_model_path) and os.path.isfile(
-        os.path.join(os.path.dirname(onnx_model_path), "custom_io_fp16.yaml")
-    )
+    onnx_exists_bool = False
+    onnx_model_path = None
+    if os.path.isfile(os.path.join(onnx_dir_path, "custom_io_fp16.yaml")):
+        if os.path.isfile(clipped_onnx_model_path):
+            onnx_exists_bool = True
+            onnx_model_path = clipped_onnx_model_path
+        elif os.path.isfile(unclipped_onnx_model_path):
+            onnx_exists_bool = True
+            onnx_model_path = unclipped_onnx_model_path
 
     # Return the boolean, onnx_dir_path, and onnx_model_path
    return onnx_exists_bool, onnx_dir_path, onnx_model_path
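A short sketch of how a caller might consume the new behavior (not part of the diff; the model name is hypothetical, and the import path assumes the function stays in `QEfficient.utils._utils`):

from QEfficient.utils._utils import onnx_exists

onnx_exists_bool, onnx_dir_path, onnx_model_path = onnx_exists("gpt2")
if onnx_exists_bool:
    # Prefers the "_kv_clipped_fp16.onnx" file, falls back to "_kv.onnx";
    # onnx_model_path is None when neither exists.
    print(f"Reusing cached ONNX at {onnx_model_path}")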
22 changes: 13 additions & 9 deletions QEfficient/utils/run_utils.py
@@ -14,6 +14,7 @@
 from QEfficient.utils.generate_inputs import InputHandler
 
 
+# TODO: Deprecate this class and encourage the use of `QeffAutoModel...` classes
 class ApiRunner:
     """
     ApiRunner class is responsible for running:
@@ -133,19 +134,22 @@ def run_kv_model_on_ort(self, model_path):
         :return generated_ids: numpy.ndarray. Generated output tokens
         """
 
-        # todo:vbaddi; find a better version to do this changes
-        # Currently the gathernd invalid index is set to INT MAX(FP16) and hence fails in OnnxRuntime
-        # Changing the constant value from INT MAX to -1. Fixes the issue.
+        # Replace invalid index value for INT32 max to 0 using add_initializer
         m = onnx.load(model_path, load_external_data=False)
+        # NOTE: OrtValue objects should be kept around until the session is run, hence this dict is required
+        added_initializers = {}
         for node in m.graph.node:
             if node.op_type == "Constant":
                 np_tensor = onnx.numpy_helper.to_array(node.attribute[0].t)
-                if len(np_tensor.shape) == 0 and np_tensor.item() == 65504:
-                    node.attribute[0].t.raw_data = np.array(-1).tobytes()
-
-        onnxruntime_model = model_path[:-5] + "_ort.onnx"
-        onnx.save(m, onnxruntime_model)
-        session = onnxruntime.InferenceSession(onnxruntime_model)
+                if len(np_tensor.shape) == 0 and np_tensor.item() == 2147483647:
+                    added_initializers[node.output[0]] = onnxruntime.OrtValue.ortvalue_from_numpy(
+                        np.array(0, np_tensor.dtype)
+                    )
+
+        session_options = onnxruntime.SessionOptions()
+        for name, value in added_initializers.items():
+            session_options.add_initializer(name, value)
+        session = onnxruntime.InferenceSession(model_path, session_options)
 
         generated_ids = []
         inputs = self.input_handler.prepare_ort_inputs()
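The `add_initializer` override pattern in isolation, as a minimal sketch (not part of the diff; model path and initializer name are hypothetical). The OrtValue must stay referenced until the session has run, which is why the code above keeps the values in a dict:

import numpy as np
import onnxruntime

# Override a graph constant by the name of its output, without rewriting
# the model file on disk; keep `override` alive until after run().
override = onnxruntime.OrtValue.ortvalue_from_numpy(np.array(0, np.int64))
session_options = onnxruntime.SessionOptions()
session_options.add_initializer("slice_end_const", override)
session = onnxruntime.InferenceSession("model.onnx", session_options)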
82 changes: 82 additions & 0 deletions tests/base/test_onnx_transforms.py
@@ -0,0 +1,82 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# ----------------------------------------------------------------------------
+
+import os
+import shutil
+
+import numpy as np
+import onnx
+import pytest
+
+from QEfficient.base.onnx_transforms import FP16Clip
+
+
+@pytest.fixture
+def external_path():
+    external_dir = "tmp_external_data"
+    os.makedirs(external_dir, exist_ok=True)
+    yield external_dir
+    shutil.rmtree(external_dir)
+
+
+def test_fp16clip_transform():
+    test_onnx = onnx.parser.parse_model("""
+    <
+        ir_version: 8,
+        opset_import: ["" : 17]
+    >
+    test_fp16clip (float [n, 32] x) => (float [n, 32] y)
+    <
+        float val1 = {65505.0},
+        int64[1] slice_ends = {2147483647},
+        float zero = {0.0}
+    >
+    {
+        mask = Greater(x, zero)
+        val2 = Constant<value = float {-1e7}>()
+        masked = Where(mask, val1, val2)
+        slice_starts = Constant<value = int64[1] {0}>()
+        y = Slice(masked, slice_starts, slice_ends)
+    }
+    """)
+    onnx.checker.check_model(test_onnx, True, True, True)
+    transformed_onnx, transformed = FP16Clip.apply(test_onnx)
+    assert transformed
+    assert onnx.numpy_helper.to_array(transformed_onnx.graph.initializer[0]) == 65504.0
+    assert onnx.numpy_helper.to_array(transformed_onnx.graph.initializer[1]) == 2147483647
+    assert onnx.numpy_helper.to_array(transformed_onnx.graph.node[1].attribute[0].t) == -65504.0
+
+
+def test_fp16clip_transform_external(external_path):
+    external_weight_file = "fp32_min.weight"
+    test_onnx = onnx.parser.parse_model(
+        """
+        <
+            ir_version: 8,
+            opset_import: ["" : 17]
+        >
+        test_fp16clip (float [n, 32] x) => (float [n, 32] y)
+        <
+            float min_val = [ "location": "<external_weight_file>" ],
+            float zero = {0.0}
+        >
+        {
+            mask = Greater(x, zero)
+            y = Where(mask, x, min_val)
+        }
+        """.replace("<external_weight_file>", str(external_weight_file))
+    )
+
+    # Write onnx and external_data
+    onnx_path = os.path.join(external_path, "test_fp16_clip_external.onnx")
+    onnx.save(test_onnx, onnx_path)
+    np.array(-1e10, dtype="float32").tofile(os.path.join(external_path, external_weight_file))
+
+    onnx.checker.check_model(onnx_path, True, True, True)
+    transformed_onnx, transformed = FP16Clip.apply(test_onnx, external_path)
+    assert transformed
+    assert onnx.numpy_helper.to_array(transformed_onnx.graph.initializer[0]) == -65504.0
3 changes: 2 additions & 1 deletion tests/base/test_pytorch_transforms.py
@@ -36,6 +36,7 @@ def forward(self, x):
     y1 = model(x)
     assert torch.any(y1 != x)
 
-    model = TestTransform.apply(model)
+    model, transformed = TestTransform.apply(model)
+    assert transformed
     y2 = model(x)
     assert torch.all(y2 == x)
5 changes: 4 additions & 1 deletion tests/cloud/conftest.py
@@ -97,7 +97,10 @@ def onnx_dir_path(self):
         return str(os.path.join(self.model_card_dir(), "onnx"))
 
     def onnx_model_path(self):
-        return str(os.path.join(self.onnx_dir_path(), self.model_name.replace("/", "_") + "_kv_clipped_fp16.onnx"))
+        return [
+            str(os.path.join(self.onnx_dir_path(), self.model_name.replace("/", "_") + "_kv_clipped_fp16.onnx")),
+            str(os.path.join(self.onnx_dir_path(), self.model_name.replace("/", "_") + "_kv.onnx")),
+        ]
 
     def model_hf_path(self):
         return str(os.path.join(self.cache_dir, self.model_name))
7 changes: 6 additions & 1 deletion tests/cloud/test_compile.py
@@ -21,8 +21,13 @@ def test_compile(setup, mocker):
     mocker: mocker is itself a pytest fixture, uses to mock or spy internal functions.
     """
     ms = setup
+    for onnx_model_path in ms.onnx_model_path():
+        if os.path.isfile(onnx_model_path):
+            break
+    else:
+        raise RuntimeError(f"onnx file not found: {ms.onnx_model_path()}")
     QEfficient.compile(
-        onnx_path=ms.onnx_model_path(),
+        onnx_path=onnx_model_path,
         qpc_path=os.path.dirname(ms.qpc_dir_path()),
         num_cores=ms.num_cores,
         device_group=ms.device_group,
(Diffs for the remaining 4 changed files are not shown.)
