apply 16bit patching and refactor vlm support
eaidova committed Oct 3, 2024
1 parent a1c1737 commit a97f962
Showing 2 changed files with 94 additions and 153 deletions.
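The first changed file extends `main_export` so that the 16-bit dtype handling, previously applied only to `text-generation` models, also covers multimodal text-generation models; the second refactors how the visual-language model classes resolve and load their OpenVINO parts. Below is a minimal sketch of invoking the export entry point touched by the first file, assuming the usual `main_export` signature; the checkpoint id and the `image-text-to-text` task name are illustrative assumptions, not values taken from this commit.

```python
# Hedged sketch of invoking the export entry point; model id and task name are
# illustrative assumptions, not values taken from this commit.
from optimum.exporters.openvino import main_export

main_export(
    model_name_or_path="llava-hf/llava-1.5-7b-hf",  # assumed example checkpoint
    output="llava_openvino",                        # destination directory for the IR files
    task="image-text-to-text",                      # assumed task name for multimodal generation
)
```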
93 changes: 48 additions & 45 deletions optimum/exporters/openvino/__main__.py
@@ -42,7 +42,7 @@
)
from optimum.utils.save_utils import maybe_load_preprocessors

from .utils import _MAX_UNCOMPRESSED_SIZE, clear_class_registry
from .utils import _MAX_UNCOMPRESSED_SIZE, MULTI_MODAL_TEXT_GENERATION_MODELS, clear_class_registry


if TYPE_CHECKING:
@@ -275,51 +275,54 @@ def main_export(
"Please provide custom export config if you want load model with remote code."
)
trust_remote_code = False
dtype = loading_kwargs.get("torch_dtype")
if isinstance(dtype, str):
dtype = config.torch_dtype if dtype == "auto" else getattr(torch, dtype)
dtype = loading_kwargs.get("torch_dtype")
if isinstance(dtype, str):
dtype = getattr(config, "torch_dtype") if dtype == "auto" else getattr(torch, dtype)

if (
dtype is None
and framework == "pt"
and not do_gptq_patching
and task.startswith("text-generation")
and getattr(config, "torch_dtype", torch.float32) in [torch.float16, torch.bfloat16]
):
if ov_config is not None and ov_config.dtype in {"fp16", "fp32"}:
dtype = torch.float16 if ov_config.dtype == "fp16" else torch.float32
elif is_openvino_version(">=", "2024.2") and config.torch_dtype == torch.float16:
dtype = torch.float16
elif is_openvino_version(">=", "2024.3") and config.torch_dtype == torch.bfloat16:
dtype = torch.bfloat16

if dtype is not None:
if dtype in [torch.float16, torch.bfloat16]:
patch_16bit = True
loading_kwargs["torch_dtype"] = dtype
# Patch the modules to allow export of GPTQ models without a GPU
if do_gptq_patching:
torch.set_default_dtype(torch.float32)
orig_cuda_check = torch.cuda.is_available
torch.cuda.is_available = lambda: True

from optimum.gptq import GPTQQuantizer

orig_post_init_model = GPTQQuantizer.post_init_model

def post_init_model(self, model):
from auto_gptq import exllama_set_max_input_length

class StoreAttr(object):
pass

model.quantize_config = StoreAttr()
model.quantize_config.desc_act = self.desc_act
if self.desc_act and not self.disable_exllama and self.max_input_length is not None:
model = exllama_set_max_input_length(model, self.max_input_length)
return model

GPTQQuantizer.post_init_model = post_init_model
if (
dtype is None
and framework == "pt"
and not do_gptq_patching
and (
task.startswith("text-generation")
or getattr(config, "model_type", None) in MULTI_MODAL_TEXT_GENERATION_MODELS
)
and getattr(config, "torch_dtype", torch.float32) in [torch.float16, torch.bfloat16]
):
if ov_config is not None and ov_config.dtype in {"fp16", "fp32"}:
dtype = torch.float16 if ov_config.dtype == "fp16" else torch.float32
elif is_openvino_version(">=", "2024.2") and config.torch_dtype == torch.float16:
dtype = torch.float16
elif is_openvino_version(">=", "2024.3") and config.torch_dtype == torch.bfloat16:
dtype = torch.bfloat16

if dtype is not None:
if dtype in [torch.float16, torch.bfloat16]:
patch_16bit = True
loading_kwargs["torch_dtype"] = dtype
# Patch the modules to allow export of GPTQ models without a GPU
if do_gptq_patching:
torch.set_default_dtype(torch.float32)
orig_cuda_check = torch.cuda.is_available
torch.cuda.is_available = lambda: True

from optimum.gptq import GPTQQuantizer

orig_post_init_model = GPTQQuantizer.post_init_model

def post_init_model(self, model):
from auto_gptq import exllama_set_max_input_length

class StoreAttr(object):
pass

model.quantize_config = StoreAttr()
model.quantize_config.desc_act = self.desc_act
if self.desc_act and not self.disable_exllama and self.max_input_length is not None:
model = exllama_set_max_input_length(model, self.max_input_length)
return model

GPTQQuantizer.post_init_model = post_init_model

if library_name == "open_clip":
model = _OpenClipForZeroShotImageClassification.from_pretrained(model_name_or_path, cache_dir=cache_dir)
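For reference, here is a standalone sketch of the dtype-resolution rule that the `main_export` hunk above now applies to both plain text-generation and multimodal models; `resolve_export_dtype` and the version flags are illustrative names, not part of optimum-intel.

```python
# Hedged, standalone restatement of the branch above; `resolve_export_dtype` and the
# version flags are illustrative names, not part of optimum-intel.
import torch

def resolve_export_dtype(config_torch_dtype, ov_config_dtype=None,
                         ov_supports_fp16=True, ov_supports_bf16=True):
    """Keep a half-precision checkpoint's dtype for export when possible."""
    if config_torch_dtype not in (torch.float16, torch.bfloat16):
        return None  # full-precision checkpoints are left untouched
    if ov_config_dtype in {"fp16", "fp32"}:
        # an explicit OVConfig dtype wins over the checkpoint dtype
        return torch.float16 if ov_config_dtype == "fp16" else torch.float32
    if ov_supports_fp16 and config_torch_dtype == torch.float16:   # OpenVINO >= 2024.2
        return torch.float16
    if ov_supports_bf16 and config_torch_dtype == torch.bfloat16:  # OpenVINO >= 2024.3
        return torch.bfloat16
    return None

# A bf16 checkpoint keeps its dtype, which also enables 16-bit patching:
dtype = resolve_export_dtype(torch.bfloat16)
patch_16bit = dtype in (torch.float16, torch.bfloat16)
```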
154 changes: 46 additions & 108 deletions optimum/intel/openvino/modeling_visual_language.py
@@ -63,26 +63,11 @@ def _compile_text_emb(self):
logger.info(f"Compiling the Text embeddings model to {self._device} ...")
self.text_emb_request = core.compile_model(self.text_emb_model, self._device, self.ov_config)

def to(self, device: str):
if self._compile_only:
raise ValueError(
"`to()` is not supported with `compile_only` mode, please initialize the model without this option"
)

if isinstance(device, str):
self._device = device.upper()
self.clear_requests()

return self

def clear_requests(self):
if self._compile_only:
raise ValueError(
"`clear_requests()` is not supported with `compile_only` mode, please initialize the model without this option"
)

del self.request
del self.text_emb_request
self.request = None
self.text_emb_request = None

@@ -379,10 +364,6 @@ def _from_pretrained(
raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.")
token = use_auth_token

language_model_file_name = "openvino_language_model.xml"
text_embeddings_file_name = "openvino_text_embeddings_model.xml"
vision_embeddings_file_name = "openvino_vision_embeddings_model.xml"

model_cls = MODEL_TYPE_TO_CLS_MAPPING[config.model_type]

quantization_config = model_cls._prepare_weight_quantization_config(quantization_config, load_in_8bit)
@@ -391,58 +372,22 @@
# Load model from a local directory
if os.path.isdir(model_id):
model_save_dir = Path(model_id)
if not compile_only:
language_model = model_cls.load_model(
os.path.join(model_id, language_model_file_name), quantization_config
)
text_embeddings = model_cls.load_model(
os.path.join(model_id, text_embeddings_file_name), quantization_config
)
vision_embeddings = model_cls.load_model(
os.path.join(model_id, vision_embeddings_file_name), quantization_config
)

for part in model_cls.additional_parts:
part_file_name = f"openvino_{part}_model.xml"
part_model = model_cls.load_model(os.path.join(model_id, part_file_name), quantization_config)
kwargs[part] = part_model
else:
language_model = model_cls._compile_model(
os.path.join(model_id, language_model_file_name),
kwargs.get("device", "CPU"),
kwargs.get("ov_config"),
model_save_dir,
)
text_embeddings = model_cls._compile_model(
os.path.join(model_id, text_embeddings_file_name),
kwargs.get("device", "CPU"),
kwargs.get("ov_config"),
model_save_dir,
)
vision_embeddings = model_cls._compile_model(
os.path.join(model_id, vision_embeddings_file_name),
kwargs.get("device", "CPU"),
kwargs.get("ov_config"),
model_save_dir,
)
for part in model_cls.additional_parts:
part_file_name = f"openvino_{part}_model.xml"
part_model = model_cls._compile_model(
os.path.join(model_id, part_file_name), kwargs.get("device", "CPU"), kwargs.get("ov_config")
)
kwargs[part] = part_model

# Load model from hub
model_file_names = {
"language_model": "openvino_language_model.xml",
"text_embeddings": "openvino_text_embeddings_model.xml",
"vision_embeddings": "openvino_vision_embeddings_model.xml",
}

for part in model_cls.additional_parts:
model_file_names[part] = f"openvino_{part}_model.xml"
model_cls = MODEL_TYPE_TO_CLS_MAPPING[config.model_type]
quantization_config = model_cls._prepare_weight_quantization_config(quantization_config, load_in_8bit)
compile_only = kwargs.get("compile_only", False)
if os.path.isdir(model_id):
model_save_dir = Path(model_id)
file_names = {k: os.path.join(model_id, model_file_names[k]) for k in model_file_names}
else:
model_file_names = {
"language_model": language_model_file_name,
"text_embeddings": text_embeddings_file_name,
"vision_embeddings": vision_embeddings_file_name,
}
for part in model_cls.additional_parts:
model_file_names[part] = part_file_name = f"openvino_{part}_model.xml"

file_names = model_file_names.copy()
file_names = {}
for name, file_name in model_file_names.items():
model_cache_path = hf_hub_download(
repo_id=model_id,
Expand All @@ -454,40 +399,39 @@ def _from_pretrained(
local_files_only=local_files_only,
)
file_names[name] = model_cache_path

model_save_dir = Path(model_cache_path).parent
if not compile_only:
language_model = model_cls.load_model(file_names["language_model"], quantization_config)
text_embeddings = model_cls.load_model(file_names["text_embeddings"], quantization_config)
vision_embeddings = model_cls.load_model(file_names["vision_embeddings"], quantization_config)
for part in model_cls.additional_parts:
kwargs[part] = model_cls.load_model(file_names[part], quantization_config)
else:
language_model = model_cls._compile_model(
file_names["language_model"],
kwargs.get("device", "CPU"),
kwargs.get("ov_config"),
model_save_dir,
)
text_embeddings = model_cls._compile_model(
file_names["text_embeddings"],
kwargs.get("device", "CPU"),
kwargs.get("ov_config"),
model_save_dir,
)
vision_embeddings = model_cls._compile_model(
file_names["vision_embeddings"],
if not compile_only:
language_model = model_cls.load_model(file_names["language_model"], quantization_config)
text_embeddings = model_cls.load_model(file_names["text_embeddings"], quantization_config)
vision_embeddings = model_cls.load_model(file_names["vision_embeddings"], quantization_config)
for part in model_cls.additional_parts:
kwargs[part] = model_cls.load_model(file_names[part], quantization_config)
else:
language_model = model_cls._compile_model(
file_names["language_model"],
kwargs.get("device", "CPU"),
kwargs.get("ov_config"),
model_save_dir,
)
text_embeddings = model_cls._compile_model(
file_names["text_embeddings"],
kwargs.get("device", "CPU"),
kwargs.get("ov_config"),
model_save_dir,
)
vision_embeddings = model_cls._compile_model(
file_names["vision_embeddings"],
kwargs.get("device", "CPU"),
kwargs.get("ov_config"),
model_save_dir,
)
for part in model_cls.additional_parts:
kwargs[part] = model_cls._compile_model(
file_names[part],
kwargs.get("device", "CPU"),
kwargs.get("ov_config"),
model_save_dir,
)
for part in model_cls.additional_parts:
kwargs[part] = model_cls._compile_model(
file_names[part],
kwargs.get("device", "CPU"),
kwargs.get("ov_config"),
model_save_dir,
)
try:
generation_config = GenerationConfig.from_pretrained(
model_id,
@@ -530,15 +474,6 @@ def _from_transformers(
quantization_config: Union[OVWeightQuantizationConfig, Dict] = None,
**kwargs,
):
if use_auth_token is not None:
warnings.warn(
"The `use_auth_token` argument is deprecated and will be removed soon. Please use the `token` argument instead.",
FutureWarning,
)
if token is not None:
raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.")
token = use_auth_token

compile_only = kwargs.pop("compile_only", False)
if compile_only:
logger.warning(
@@ -747,6 +682,7 @@ def get_vision_embeddings(self, pixel_values, input_ids=None, **kwargs):

return image_features

# Adapted from https://github.com/huggingface/transformers/blob/d7950bff82b18c823193d17d72188c5e46d06c83/src/transformers/models/llava/modeling_llava.py#L297C9-L297C45
def merge_vision_text_embeddings(
self,
vision_embeds,
@@ -895,6 +831,7 @@ def _filter_unattended_tokens(self, input_ids, attention_mask, past_key_values):


class _OVLlavaNextForCausalLM(_OVLlavaForCausalLM):
# Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava_next/modeling_llava_next.py#L655
def pack_image_features(self, image_features, image_sizes, image_newline=None):
from transformers.models.llava_next.modeling_llava_next import get_anyres_image_grid_shape, unpad_image

@@ -951,6 +888,7 @@ def pack_image_features(self, image_features, image_sizes, image_newline=None):
feature_lens = torch.tensor(feature_lens, dtype=torch.long, device=image_features.device)
return image_features, feature_lens

# Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava_next/modeling_llava_next.py#L416
def get_multimodal_embeddings(
self,
input_ids,
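The refactor above makes `_from_pretrained` build one mapping of IR file names (language model, text embeddings, vision embeddings, plus any additional parts) for both local directories and Hub repositories, then either loads each part with `load_model` or, in `compile_only` mode, compiles it directly for the requested device. A hedged usage sketch, assuming `OVModelForVisualCausalLM` from `optimum.intel` is the public entry point that dispatches to these classes:

```python
# Hedged usage sketch; the class name, import path, and directory are assumptions
# based on the classes defined in modeling_visual_language.py, not on the commit itself.
from optimum.intel import OVModelForVisualCausalLM

# Regular loading: openvino_language_model.xml, openvino_text_embeddings_model.xml,
# openvino_vision_embeddings_model.xml and any additional parts are read with
# load_model and compiled later on the target device.
model = OVModelForVisualCausalLM.from_pretrained("llava_openvino")

# compile_only mode: the IR files are compiled directly for the requested device,
# skipping the intermediate ov.Model objects.
compiled = OVModelForVisualCausalLM.from_pretrained(
    "llava_openvino", compile_only=True, device="CPU"
)
```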
