[OV]: load and convert llms in original precision (#778)
* [OV]: load and convert llm in original precision

* unpatch for onnx

* add torch_dtype option for loading model

* fix rotary emb initialization

* fix patching order

* force precision using --weight-format

* fix quantization tests

* fix test

* move torch import
eaidova authored Aug 19, 2024
1 parent ad1fe8b commit e9800ce
Showing 8 changed files with 207 additions and 9 deletions.
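
In short: the exporter can now load and trace text-generation checkpoints in their original fp16/bf16 precision instead of always upcasting to fp32, and a `torch_dtype` option is threaded through model loading. A minimal usage sketch, assuming the standard optimum-intel entry point (the model id is illustrative):

```python
import torch
from optimum.intel import OVModelForCausalLM

# torch_dtype is popped in _from_transformers and forwarded to main_export
# via model_loading_kwargs, so the checkpoint is loaded and traced in bf16.
model = OVModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",  # illustrative model id
    export=True,
    torch_dtype=torch.bfloat16,
)
```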
42 changes: 38 additions & 4 deletions optimum/exporters/openvino/__main__.py
@@ -21,12 +21,17 @@
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
from requests.exceptions import ConnectionError as RequestsConnectionError
from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizerBase
from transformers.utils import is_torch_available

from optimum.exporters import TasksManager
from optimum.exporters.onnx.base import OnnxConfig
from optimum.exporters.onnx.constants import SDPA_ARCHS_ONNX_EXPORT_NOT_SUPPORTED
from optimum.exporters.openvino.convert import export_from_model
from optimum.intel.utils.import_utils import is_openvino_tokenizers_available, is_transformers_version
from optimum.intel.utils.import_utils import (
is_openvino_tokenizers_available,
is_openvino_version,
is_transformers_version,
)
from optimum.utils.save_utils import maybe_load_preprocessors

from .utils import clear_class_registry
@@ -35,6 +40,11 @@
if TYPE_CHECKING:
from optimum.intel.openvino.configuration import OVConfig


if is_torch_available():
import torch


_COMPRESSION_OPTIONS = {
"int8": {"bits": 8},
"int4_sym_g128": {"bits": 4, "sym": True, "group_size": 128},
@@ -100,6 +110,7 @@ def main_export(
stateful: bool = True,
convert_tokenizer: bool = False,
library_name: Optional[str] = None,
model_loading_kwargs: Optional[Dict[str, Any]] = None,
**kwargs_shapes,
):
"""
@@ -230,7 +241,8 @@

do_gptq_patching = False
custom_architecture = False
- loading_kwargs = {}
+ patch_16bit = False
+ loading_kwargs = model_loading_kwargs or {}
if library_name == "transformers":
config = AutoConfig.from_pretrained(
model_name_or_path,
@@ -281,11 +293,32 @@
"Please provide custom export config if you want load model with remote code."
)
trust_remote_code = False
dtype = loading_kwargs.get("torch_dtype")
if isinstance(dtype, str):
dtype = config.torch_dtype if dtype == "auto" else getattr(torch, dtype)

if (
dtype is None
and framework == "pt"
and not do_gptq_patching
and task.startswith("text-generation")
and getattr(config, "torch_dtype", torch.float32) in [torch.float16, torch.bfloat16]
):
if ov_config is not None and ov_config.dtype in {"fp16", "fp32"}:
dtype = torch.float16 if ov_config.dtype == "fp16" else torch.float32
elif is_openvino_version(">=", "2024.2") and config.torch_dtype == torch.float16:
dtype = torch.float16
elif is_openvino_version(">=", "2024.3") and config.torch_dtype == torch.bfloat16:
dtype = torch.bfloat16

if dtype is not None:
if dtype in [torch.float16, torch.bfloat16]:
patch_16bit = True
loading_kwargs["torch_dtype"] = dtype

logger.warning(loading_kwargs)
# Patch the modules to export GPTQ models w/o GPU
if do_gptq_patching:
import torch

torch.set_default_dtype(torch.float32)
orig_cuda_check = torch.cuda.is_available
torch.cuda.is_available = lambda: True
@@ -383,6 +416,7 @@ class StoreAttr(object):
preprocessors=preprocessors,
device=device,
trust_remote_code=trust_remote_code,
patch_16bit_model=patch_16bit,
**kwargs_shapes,
)

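The dtype resolution added above follows a fixed precedence: an explicit `torch_dtype` wins, then an explicit `--weight-format`/`ov_config` request, then the checkpoint's own `torch_dtype`, gated on OpenVINO runtime support. A condensed restatement as a sketch (not the library's exact code; `resolve_export_dtype` and its boolean version flags are invented for illustration):

```python
import torch

def resolve_export_dtype(config, ov_config, ov_ge_2024_2: bool, ov_ge_2024_3: bool):
    # Mirrors the branch in main_export for text-generation models.
    checkpoint_dtype = getattr(config, "torch_dtype", torch.float32)
    if checkpoint_dtype not in (torch.float16, torch.bfloat16):
        return None  # fp32 checkpoints are exported as before
    if ov_config is not None and ov_config.dtype in {"fp16", "fp32"}:
        # --weight-format fp16/fp32 forces the export precision
        return torch.float16 if ov_config.dtype == "fp16" else torch.float32
    if ov_ge_2024_2 and checkpoint_dtype == torch.float16:
        return torch.float16
    if ov_ge_2024_3 and checkpoint_dtype == torch.bfloat16:
        return torch.bfloat16
    return None
```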
18 changes: 18 additions & 0 deletions optimum/exporters/openvino/convert.py
@@ -105,6 +105,7 @@ def export(
model_kwargs: Optional[Dict[str, Any]] = None,
ov_config: Optional["OVConfig"] = None,
stateful: bool = True,
patch_16bit_model: bool = False,
) -> Tuple[List[str], List[str]]:
"""
Exports a PyTorch or TensorFlow model to an OpenVINO Intermediate Representation.
@@ -156,6 +157,7 @@
ov_config=ov_config,
model_kwargs=model_kwargs,
stateful=stateful,
patch_16bit_model=patch_16bit_model,
)

elif is_tf_available() and issubclass(type(model), TFPreTrainedModel):
@@ -289,6 +291,7 @@ def export_pytorch(
model_kwargs: Optional[Dict[str, Any]] = None,
ov_config: Optional["OVConfig"] = None,
stateful: bool = False,
patch_16bit_model: bool = False,
) -> Tuple[List[str], List[str]]:
"""
Exports a PyTorch model to an OpenVINO Intermediate Representation.
@@ -381,6 +384,10 @@ def ts_patched_forward(*args, **kwargs):
patcher.patched_forward = ts_patched_forward

with patcher:
if patch_16bit_model:
from openvino.frontend.pytorch.patch_model import __make_16bit_traceable

__make_16bit_traceable(model)
check_dummy_inputs_are_allowed(model, dummy_inputs)
sig = inspect.signature(model.forward) if hasattr(model, "forward") else inspect.signature(model.call)
inputs = config.ordered_inputs(model)
@@ -401,6 +408,13 @@ def ts_patched_forward(*args, **kwargs):
"A stateless model will be exported instead. It may result in sub-optimal inference performance."
"Provide a model that can be converted to OpenVINO without fallback to ONNX conversion path."
)

if patch_16bit_model:
from openvino.frontend.pytorch.patch_model import unpatch_model

unpatch_model(model, "_openvino_module_extension_patch_orig_forward")
model.to(torch.float32)

return export_pytorch_via_onnx(
model,
config,
@@ -467,6 +481,7 @@ def export_models(
model_kwargs: Optional[Dict[str, Any]] = None,
ov_config: Optional["OVConfig"] = None,
stateful: bool = True,
patch_16bit_model: bool = False,
) -> Tuple[List[List[str]], List[List[str]]]:
"""
Export the models to OpenVINO IR format
@@ -518,6 +533,7 @@
model_kwargs=model_kwargs,
ov_config=ov_config,
stateful=stateful,
patch_16bit_model=patch_16bit_model,
)
)

@@ -538,6 +554,7 @@ def export_from_model(
preprocessors: List = None,
device: str = "cpu",
trust_remote_code: bool = False,
patch_16bit_model: bool = False,
**kwargs_shapes,
):
model_kwargs = model_kwargs or {}
@@ -700,6 +717,7 @@
stateful=stateful,
opset=opset,
model_kwargs=model_kwargs,
patch_16bit_model=patch_16bit_model,
)


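The 16-bit patching above is applied inside the patcher context before tracing and undone (with a cast back to fp32) before any fallback to the ONNX path. A sketch of that lifecycle in isolation, assuming an OpenVINO build that ships `openvino.frontend.pytorch.patch_model` (the wrapper function itself is invented for illustration):

```python
import torch
import openvino as ov
from openvino.frontend.pytorch.patch_model import __make_16bit_traceable, unpatch_model

def convert_16bit_model(model, example_input):
    # Wrap fp16/bf16 submodules so the PyTorch frontend can trace them.
    __make_16bit_traceable(model)
    try:
        return ov.convert_model(model, example_input=example_input)
    finally:
        # Restore the original forwards and return the module to fp32,
        # as the diff does before falling back to export_pytorch_via_onnx.
        unpatch_model(model, "_openvino_module_extension_patch_orig_forward")
        model.to(torch.float32)
```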
41 changes: 41 additions & 0 deletions optimum/exporters/openvino/model_configs.py
@@ -16,13 +16,15 @@
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union

from packaging import version
from transformers import PreTrainedModel, TFPreTrainedModel
from transformers.utils import is_tf_available

from optimum.exporters.onnx.config import TextDecoderOnnxConfig, TextDecoderWithPositionIdsOnnxConfig
from optimum.exporters.onnx.model_configs import (
CodeGenOnnxConfig,
FalconOnnxConfig,
GemmaOnnxConfig,
GPTNeoXOnnxConfig,
LlamaOnnxConfig,
MistralOnnxConfig,
MPTOnnxConfig,
@@ -31,6 +33,7 @@
VaeDecoderOnnxConfig,
VaeEncoderOnnxConfig,
)
from optimum.exporters.onnx.model_patcher import ModelPatcher
from optimum.exporters.tasks import TasksManager
from optimum.utils import DEFAULT_DUMMY_SHAPES
from optimum.utils.input_generators import (
@@ -50,6 +53,9 @@
ChatGLMModelPatcher,
CodeGenModelPatcher,
DBRXModelPatcher,
FalconModelPatcher,
GptNeoxJapaneseModelPatcher,
GptNeoxModelPatcher,
InternLM2Patcher,
InternLMModelPatcher,
JaisModelPatcher,
Expand All @@ -60,6 +66,7 @@
PersimmonModelPatcher,
Phi3ModelPatcher,
QwenModelPatcher,
RotaryEmbPatcher,
UpdateCausalMaskModelPatcher,
XverseModelPatcher,
)
@@ -505,6 +512,12 @@ def patch_model_for_export(
return UpdateCausalMaskModelPatcher(self, model, model_kwargs=model_kwargs)


def patch_model_for_export(
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
) -> "ModelPatcher":
return RotaryEmbPatcher(self, model, model_kwargs=model_kwargs)


@register_in_tasks_manager("internlm2", *["text-generation", "text-generation-with-past"], library_name="transformers")
class InternLM2OpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
DEFAULT_ONNX_OPSET = 14
@@ -632,6 +645,11 @@ class FalconOpenVINOConfig(FalconOnnxConfig):
) + TextDecoderOnnxConfig.DUMMY_INPUT_GENERATOR_CLASSES
DUMMY_PKV_GENERATOR_CLASS = OVFalconDummyPastKeyValuesGenerator

def patch_model_for_export(
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
) -> "ModelPatcher":
return FalconModelPatcher(self, model, model_kwargs=model_kwargs)


@register_in_tasks_manager("unet", *["semantic-segmentation"], library_name="diffusers")
class UNetOpenVINOConfig(UNetOnnxConfig):
@@ -725,6 +743,11 @@ class GPTNeoxJapaneseOpenVINOConfig(TextDecoderOnnxConfig):
DEFAULT_ONNX_OPSET = 13
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig

def patch_model_for_export(
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
) -> "ModelPatcher":
return GptNeoxJapaneseModelPatcher(self, model, model_kwargs=model_kwargs)


@register_in_tasks_manager(
"cohere",
@@ -913,3 +936,21 @@ def patch_model_for_export(
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
) -> "ModelPatcher":
return MistralModelPatcher(self, model, model_kwargs=model_kwargs)


@register_in_tasks_manager(
"gpt-neox",
*[
"feature-extraction",
"feature-extraction-with-past",
"text-generation",
"text-generation-with-past",
"text-classification",
],
library_name="transformers",
)
class GPTNeoxOpenVINOConfig(GPTNeoXOnnxConfig):
def patch_model_for_export(
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
) -> "ModelPatcher":
return GptNeoxModelPatcher(self, model, model_kwargs=model_kwargs)
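
New architectures hook the fp32 rotary-embedding fix the same way. A hypothetical registration following the pattern above ("my-decoder" is an invented architecture name; the other names reuse imports already present in model_configs.py):

```python
@register_in_tasks_manager(
    "my-decoder",  # hypothetical architecture name
    *["text-generation", "text-generation-with-past"],
    library_name="transformers",
)
class MyDecoderOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
    DEFAULT_ONNX_OPSET = 14
    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig

    def patch_model_for_export(
        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
    ) -> "ModelPatcher":
        # Rebuild the sin/cos caches in fp32 before export.
        return RotaryEmbPatcher(self, model, model_kwargs=model_kwargs)
```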
46 changes: 46 additions & 0 deletions optimum/exporters/openvino/model_patcher.py
@@ -108,6 +108,15 @@ def patch_update_causal_mask(model, transformers_version):
inner_model._update_causal_mask = types.MethodType(_llama_gemma_update_causal_mask, inner_model)


# initialization of the sin/cos cache in bf16/fp16 leads to accuracy loss
# reinitialize it in float32 before export
def _reinitialize_cos_sin_cached_fp32(rotary_emb):
if rotary_emb.cos_cached.dtype != torch.float32:
rotary_emb._set_cos_sin_cache(
seq_len=rotary_emb.max_position_embeddings, device=rotary_emb.inv_freq.device, dtype=torch.float32
)


def _mixtral_sparse_moe_block_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
""" """
batch_size, sequence_length, hidden_dim = hidden_states.shape
@@ -158,6 +167,7 @@ def __enter__(self):
layer.block_sparse_moe.forward = types.MethodType(
_mixtral_sparse_moe_block_forward, layer.block_sparse_moe
)
_reinitialize_cos_sin_cached_fp32(layer.self_attn.rotary_emb)

def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
@@ -689,6 +699,10 @@ def __enter__(self):
self._model.model._orig_update_causal_mask = self._model.model._update_causal_mask
self._model.model._update_causal_mask = types.MethodType(_mistral_update_causal_mask, self._model.model)

else:
for layer in self._model.model.layers:
_reinitialize_cos_sin_cached_fp32(layer.self_attn.rotary_emb)

def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)

@@ -2224,6 +2238,7 @@ def __enter__(self):
orig_self_attn_fwd = layer.self_attn.forward
layer.self_attn.forward = types.MethodType(_persimmon_self_attn_sdpa_forward, layer.self_attn)
layer.self_attn._orig_forward = orig_self_attn_fwd
_reinitialize_cos_sin_cached_fp32(layer.self_attn.rotary_emb)

def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
@@ -2359,8 +2374,39 @@ class UpdateCausalMaskModelPatcher(DecoderModelPatcher):
def __enter__(self):
super().__enter__()
patch_update_causal_mask(self._model, "4.42.0")
if hasattr(self._model.model.layers[0].self_attn.rotary_emb, "_set_cos_sin_cache"):
for layer in self._model.model.layers:
_reinitialize_cos_sin_cached_fp32(layer.self_attn.rotary_emb)

def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
if hasattr(self._model.model, "_orig_update_causal_mask"):
self._model.model._update_causal_mask = self._model.model._orig_update_causal_mask


class RotaryEmbPatcher(DecoderModelPatcher):
def __enter__(self):
super().__enter__()
for layer in self._model.model.layers:
_reinitialize_cos_sin_cached_fp32(layer.self_attn.rotary_emb)


class FalconModelPatcher(DecoderModelPatcher):
def __enter__(self):
super().__enter__()
for layer in self._model.transformer.h:
_reinitialize_cos_sin_cached_fp32(layer.self_attention.rotary_emb)


class GptNeoxModelPatcher(DecoderModelPatcher):
def __enter__(self):
super().__enter__()
for layer in self._model.gpt_neox.layers:
_reinitialize_cos_sin_cached_fp32(layer.attention.rotary_emb)


class GptNeoxJapaneseModelPatcher(DecoderModelPatcher):
def __enter__(self):
super().__enter__()
for layer in self._model.gpt_neox_japanese.layers:
_reinitialize_cos_sin_cached_fp32(layer.attention.rotary_emb)
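
Why the caches are rebuilt in fp32: half-precision angles lose resolution at large positions, so sin/cos tables derived from fp16 values drift badly at long contexts. A quick, self-contained illustration (the dimensions and sequence length are illustrative, not from the PR):

```python
import torch

dim, max_pos = 64, 4096
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
angles = torch.outer(torch.arange(max_pos).float(), inv_freq)

cos_fp32 = angles.cos()
cos_fp16 = angles.half().float().cos()  # angles quantized to fp16 first

# fp16 spacing near 4096 is several units, so large angles are off by whole
# radians and the cosine error approaches 2 at long positions.
print((cos_fp32 - cos_fp16).abs().max())
```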
10 changes: 9 additions & 1 deletion optimum/intel/openvino/modeling_decoder.py
@@ -281,10 +281,17 @@ def _from_transformers(
if load_in_8bit is None and not quantization_config:
ov_export_config = None
else:
ov_export_config = OVConfig(dtype="fp32")
ov_export_config = OVConfig(dtype="auto")

stateful = kwargs.pop("stateful", ensure_stateful_is_available(warn=False) and use_cache)

torch_dtype = kwargs.pop("torch_dtype", None)

model_loading_kwargs = {}

if torch_dtype is not None:
model_loading_kwargs["torch_dtype"] = torch_dtype

main_export(
model_name_or_path=model_id,
output=save_dir_path,
@@ -298,6 +305,7 @@
trust_remote_code=trust_remote_code,
ov_config=ov_export_config,
stateful=stateful,
model_loading_kwargs=model_loading_kwargs,
)

config.is_decoder = True
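Note the `OVConfig(dtype="fp32")` → `OVConfig(dtype="auto")` switch above: requesting weight compression no longer forces an fp32 torch export first; compression is applied on top of the original-precision export. A hedged usage sketch (the model id is illustrative):

```python
from optimum.intel import OVModelForCausalLM

# 8-bit weight compression now runs on top of an export performed in the
# checkpoint's own precision rather than a forced fp32 conversion.
model = OVModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",  # illustrative model id
    export=True,
    load_in_8bit=True,
)
```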