From afb4645d561ac8b33073e572cb5c9400f838fde0 Mon Sep 17 00:00:00 2001 From: Onkar Chougule <168134249+ochougul@users.noreply.github.com> Date: Sat, 14 Sep 2024 00:23:01 +0530 Subject: [PATCH] AWQ+GPTQ (#101) * Awq feature (#100) * added preprocess layer before loading quantized awq weights Signed-off-by: Onkar Chougule * added onnx export Signed-off-by: Onkar Chougule * added ScaledActivation class Signed-off-by: Onkar Chougule * refactoring the code to right places and added one single test for now Signed-off-by: Onkar Chougule * cleaned code Signed-off-by: Onkar Chougule * added proper tests, added decorator for updating quantizers, cleaned code Signed-off-by: Onkar Chougule * fixed CLI Signed-off-by: Onkar Chougule * added auto file for decorator Signed-off-by: Onkar Chougule --------- Signed-off-by: Onkar Chougule * bugfix for tests Signed-off-by: Onkar Chougule * fixed tests for AWQ model Signed-off-by: Onkar Chougule * Adding support for GPTQ models (#103) * Adding support for gptq models Signed-off-by: Amit Raj * Code cleaning and formating Signed-off-by: Amit Raj * ruff format and fixed some bug Signed-off-by: Amit Raj * Added tests for gptq Signed-off-by: Amit Raj * Bug-fix-1 Signed-off-by: Amit Raj * fixed bugs-2 Signed-off-by: Amit Raj * fixed bug-3 Signed-off-by: Amit Raj * Added docstring Signed-off-by: Amit Raj * Addressed comments Signed-off-by: Amit Raj * Addressed comments Signed-off-by: Amit Raj * fixed bugs-3 Signed-off-by: Amit Raj * ruff check and format Signed-off-by: Amit Raj * Addressed comments-3 Signed-off-by: Amit Raj --------- Signed-off-by: Amit Raj Signed-off-by: Onkar Chougule * added liscence at top for missing file Signed-off-by: Onkar Chougule * added export_and_compile and fixed bugs Signed-off-by: Onkar Chougule * removed GPTQ test Signed-off-by: Onkar Chougule * removed threading from pytest Signed-off-by: Onkar Chougule --------- Signed-off-by: Onkar Chougule Signed-off-by: Amit Raj Co-authored-by: Amit Raj <168538872+quic-amitraj@users.noreply.github.com> --- QEfficient/base/common.py | 11 +- QEfficient/base/pytorch_transforms.py | 32 ++ QEfficient/customop/matmulnbits.py | 182 +++++++++ .../exporter/export_hf_to_cloud_ai_100.py | 2 +- .../transformers/models/modeling_auto.py | 163 ++++++-- .../transformers/quantizers/__init__.py | 6 + QEfficient/transformers/quantizers/auto.py | 41 ++ QEfficient/transformers/quantizers/awq.py | 81 ++++ QEfficient/transformers/quantizers/gptq.py | 80 ++++ .../quantizers/quant_transforms.py | 100 +++++ .../transformers/quantizers/quantizer_awq.py | 84 ++++ .../transformers/quantizers/quantizer_gptq.py | 151 +++++++ .../quantizers/quantizer_utils.py | 380 ++++++++++++++++++ QEfficient/utils/_utils.py | 4 +- scripts/Jenkinsfile | 2 +- tests/base/test_pytorch_transforms.py | 46 ++- .../models/test_causal_lm_models.py | 6 +- .../test_transformer_pytorch_transforms.py | 57 +++ tests/utils.py | 16 +- 19 files changed, 1384 insertions(+), 60 deletions(-) create mode 100644 QEfficient/customop/matmulnbits.py create mode 100644 QEfficient/transformers/quantizers/__init__.py create mode 100644 QEfficient/transformers/quantizers/auto.py create mode 100644 QEfficient/transformers/quantizers/awq.py create mode 100644 QEfficient/transformers/quantizers/gptq.py create mode 100644 QEfficient/transformers/quantizers/quant_transforms.py create mode 100644 QEfficient/transformers/quantizers/quantizer_awq.py create mode 100644 QEfficient/transformers/quantizers/quantizer_gptq.py create mode 100644 
QEfficient/transformers/quantizers/quantizer_utils.py diff --git a/QEfficient/base/common.py b/QEfficient/base/common.py index df7496e1..866c6218 100644 --- a/QEfficient/base/common.py +++ b/QEfficient/base/common.py @@ -31,7 +31,6 @@ class QEFF_MODEL_TYPE(Enum): CAUSALLM = "LLM" DIFFUSION = "DIFFUSION" - AWQ = "AWQ" MODEL_TYPE_TO_QEFF_AUTO_MODEL_MAP: Dict[QEFF_MODEL_TYPE, Type[QEFFBaseModel]] = { @@ -56,15 +55,7 @@ def get_hf_model_type(hf_model_path: str) -> QEFF_MODEL_TYPE: ) if config.__class__ in MODEL_FOR_CAUSAL_LM_MAPPING: - # FIXME: Add logic to handle if quantization config is stored in separate quant_config.json outside of config, also create a separate function for this and below lines - quant_config = getattr(config, "quantization_config", getattr(config, "quant_config", None)) - if quant_config is not None: - if quant_config.get("quant_method", None) == "awq": - return QEFF_MODEL_TYPE.AWQ - else: - raise NotImplementedError(f"current model type is not yet supported {type(config)}") - else: - return QEFF_MODEL_TYPE.CAUSALLM + return QEFF_MODEL_TYPE.CAUSALLM else: raise NotImplementedError(f"model type {type(config)} is not yet supported") diff --git a/QEfficient/base/pytorch_transforms.py b/QEfficient/base/pytorch_transforms.py index b787cebc..6e21d11b 100644 --- a/QEfficient/base/pytorch_transforms.py +++ b/QEfficient/base/pytorch_transforms.py @@ -55,3 +55,35 @@ def register(cls, from_module: Type[nn.Module], to_module: Type[nn.Module]): FlashAttention.register(LLamaAttention, LlamaFlashAttention) """ cls._module_mapping[from_module] = to_module + + +class ModuleMutatorTransform(PytorchTransform): + """Serves as a base class for any transform that mutates a pytorch module in any way. + Mutating here means we initialize a new pytorch module object using info from the original module and + replace the original module with the new module. + + Raises: + NotImplementedError: Not supposed to be used directly; create a subclass, implement the mutate method, and assign a valid nn.Module class to the _match_class variable. + """ + + _match_class: nn.Module + + @classmethod + def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]: + transformed = False + for name, module in model.named_children(): + if isinstance(module, cls._match_class): + setattr(model, name, cls.mutate(module, model)) + transformed = True + else: + cls.apply(module) + + if isinstance(model, cls._match_class): + model = cls.mutate(model, None) + transformed = True + + return model, transformed + + @classmethod + def mutate(cls, original_module: nn.Module, parent_module: nn.Module): + raise NotImplementedError("Please implement your own method by inheriting this class") diff --git a/QEfficient/customop/matmulnbits.py b/QEfficient/customop/matmulnbits.py new file mode 100644 index 00000000..4b813430 --- /dev/null +++ b/QEfficient/customop/matmulnbits.py @@ -0,0 +1,182 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import math + +import torch +from torch import nn + + +class QuantLinearTorchFunction(torch.autograd.Function): + @staticmethod + def symbolic(g, x, qself_qweight, qself_scales, qself_qzeros, g_idx, bits, group_size, in_features, out_features): + input_tuple = (x, qself_qweight, qself_scales, qself_qzeros) + input_tuple += (g_idx,) if g_idx is not None else () + return g.op( + "com.microsoft::MatMulNBits", + *input_tuple, + outputs=1, + K_i=in_features, + N_i=out_features, + bits_i=bits, + block_size_i=group_size, + ) + + @staticmethod + def forward(ctx, x, qself_qweight, qself_scales, qself_qzeros, g_idx, bits, group_size, in_features, out_features): + if torch.onnx.is_in_onnx_export(): + return torch.zeros(x.shape[:-1] + (out_features,), dtype=x.dtype).float() + fp_weight = dequantize_blockwise_bits( + qself_qweight, qself_scales, qself_qzeros, bits, group_size, g_idx, in_features, out_features + )[0].float() + + return torch.matmul(x.float(), fp_weight.T.float()) + + +def dequantize_blockwise_bits(quant_values, scale, zero_point, bits, group_size, g_idx, rows, cols): + if bits != 4: + raise ValueError("Only bits=4 is supported for executing quantized model") + if group_size != 128: + raise ValueError("Only group_size=128 is supported for executing quantized model") + expand_quant_value = (quant_values.unsqueeze(-1) >> torch.tensor([[[[0, 4]]]], dtype=torch.int32)) & 0x0F + expand_quant_value = expand_quant_value.reshape(*quant_values.shape[:-1], -1) + aligned_scale = scale.reshape(*quant_values.shape[:-1], 1) + if zero_point.dtype == scale.dtype: + expand_zero_point = zero_point.reshape(*quant_values.shape[:-1], -1) + else: + expand_zero_point = (zero_point.unsqueeze(-1) >> torch.tensor([[[[0, 4]]]], dtype=torch.int32)) & 0x0F + try: + expand_zero_point = expand_zero_point.reshape(*quant_values.shape[:-1], -1) + # FIXME: remove try-except + except RuntimeError: + expand_zero_point = expand_zero_point.reshape(quant_values.shape[0], -1, 1) + expand_zero_point = expand_zero_point[:, : quant_values.shape[1]] + if g_idx is not None and g_idx[:32].sum().item() != 0: + float_values = ( + (expand_quant_value.reshape(expand_quant_value.shape[0], -1) - expand_zero_point[:, g_idx, 0]) + * aligned_scale[:, g_idx, 0] + ).to(scale.dtype) + else: + float_values = ((expand_quant_value - expand_zero_point) * aligned_scale).to(scale.dtype) + float_values = float_values.reshape(cols, -1) + if rows != float_values.shape[-1]: + float_values = float_values[:, :rows] + expand_zero_point = expand_zero_point[:, :rows] + if expand_zero_point.ndim == 3: + expand_zero_point = expand_zero_point.squeeze(-1) + if aligned_scale.ndim == 3: + aligned_scale = aligned_scale.squeeze(-1) + + return float_values, expand_zero_point, aligned_scale + + +class QuantLinearORT(nn.Module): + def __init__(self, bits, group_size, in_features, out_features, bias): + super().__init__() + if bits not in [2, 3, 4, 5, 6, 7, 8]: + raise NotImplementedError("Only 2,4,5,6,7,8 bits are supported.") + self.in_features = in_features + self.out_features = out_features + self.bits = bits + self.group_size = group_size if group_size != -1 else in_features + self.act_order = None + + q_rows = in_features // self.group_size + self.register_buffer( + "qweight", + torch.zeros((out_features, q_rows, self.group_size // (8 // bits)), dtype=torch.uint8), + ) + self.register_buffer( + "qzeros", + torch.zeros((q_rows + (q_rows & 
1)) * (out_features // 8 * self.bits), dtype=torch.uint8), + ) + self.register_buffer( + "scales", torch.zeros((math.ceil(in_features / self.group_size) * out_features), dtype=torch.float16) + ) + self.register_buffer( + "g_idx", torch.tensor([i // self.group_size for i in range(in_features)], dtype=torch.int32) + ) + if bias: + self.register_buffer("bias", torch.zeros((out_features), dtype=torch.float16)) + else: + self.bias = None + + def quant_weight(self, weight, scales, zeros, g_idx): + scale_zeros = zeros * scales + scale_mat = scales[g_idx] + scale_zeros_mat = scale_zeros[g_idx] + int_weight_T = torch.round(((weight + scale_zeros_mat) / scale_mat).float()).to(torch.int) + return int_weight_T + + def pack_on_device(self, int_weight, int_zeros): + if self.bits != 4: + raise ValueError("only 4bit is supported by ONNXRUNTIME for now.") + + # Order of groups + self.act_order = self.g_idx[: self.group_size // self.bits].sum().item() != 0 + + intzeros_pt = int_zeros.T if int_zeros.dtype == self.scales.dtype else int_zeros.T.byte() + scales_pt = self.scales.T.to(int_weight.device) + intweight_pt = int_weight.byte() + + block_size = self.group_size + rows, cols = intweight_pt.shape + blob_size = block_size // 2 + k_blocks = (rows + block_size - 1) // block_size + padded_rows = k_blocks * block_size + pad_len = padded_rows - rows + if pad_len > 0: + intweight_pt = torch.nn.functional.pad(intweight_pt, (0, 0, 0, pad_len), "constant", 0) + intzeros_pt = torch.nn.functional.pad(intzeros_pt, (0, intzeros_pt.shape[-1] & 1, 0, 0), "constant", 0) + + # Pack zeros if they are not float + if int_zeros.dtype != self.scales.dtype: + intzeros_pt = (intzeros_pt[:, 0::2]) | (intzeros_pt[:, 1::2] << 4) + intzeros_pt = intzeros_pt.reshape(-1) + + # Pack weights + intweight_pt_T = int_weight.T + intweight_pt_T = (intweight_pt_T[:, 0::2]) | (intweight_pt_T[:, 1::2] << 4) + intweight_pt_T = intweight_pt_T.reshape(cols, k_blocks, blob_size) + + scales_pt = scales_pt.reshape(-1) + + # Validation checks + if (self.qweight.shape != intweight_pt_T.shape) and ( + self.qzeros.shape == intzeros_pt.shape or self.qzeros.dtype != intzeros_pt.dtype + ): + raise RuntimeError("Something went wrong while packing the weights in QuantLinearORT module") + + # Assign buffers + self.scales = scales_pt.float() + self.qweight = intweight_pt_T.byte() # Convert to uint8 + if int_zeros.dtype != self.scales.dtype: + self.qzeros = intzeros_pt.byte() # Convert to uint8 + else: + self.qzeros = intzeros_pt + + def pack(self, linear, scales, zeros, g_idx=None): + layer_weight = linear.weight.data + self.scales = scales.T + self.g_idx = g_idx.clone() + int_weight = self.quant_weight(layer_weight.T, scales.T, zeros.T, g_idx) + return self.pack_on_device(int_weight, zeros.T) + + def forward(self, inputs): + out = QuantLinearTorchFunction().apply( + inputs, + self.qweight, + self.scales, + self.qzeros, + self.g_idx if self.act_order else None, + self.bits, + self.group_size, + self.in_features, + self.out_features, + ) + out = out + self.bias if self.bias is not None else out + return out diff --git a/QEfficient/exporter/export_hf_to_cloud_ai_100.py b/QEfficient/exporter/export_hf_to_cloud_ai_100.py index 0649b1f7..706d1410 100644 --- a/QEfficient/exporter/export_hf_to_cloud_ai_100.py +++ b/QEfficient/exporter/export_hf_to_cloud_ai_100.py @@ -443,7 +443,7 @@ def qualcomm_efficient_converter( model_kv = model_kv if model_kv.is_transformed else QEfficient.transform(model_kv) if kv else model_kv if onnx_dir_path is None: - model_card_dir = 
os.path.join(QEFF_MODELS_DIR, str(model_name)) + model_card_dir = os.path.join(QEFF_MODELS_DIR, str(model_kv.model_card_name)) onnx_dir_path = os.path.join(model_card_dir, "onnx") os.makedirs(onnx_dir_path, exist_ok=True) diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index fa15fa8e..5cd058be 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -5,6 +5,7 @@ # # ---------------------------------------------------------------------------- +import hashlib import os from typing import Any, List, Optional, Union @@ -14,7 +15,12 @@ import QEfficient from QEfficient.base.modeling_qeff import QEFFBaseModel, Runtime from QEfficient.transformers.pytorch_transforms import CBTransform, CustomOpsTransform, KVCacheTransform +from QEfficient.transformers.quantizers.auto import QEFF_AUTO_QUANTIZATION_CONFIG_MAPPING, with_replaced_quantizers +from QEfficient.transformers.quantizers.quant_transforms import AwqToMatmulNbitsTransform, GPTQToMatmulNbitsTransform +from QEfficient.transformers.quantizers.quantizer_awq import QEffAwqConfig +from QEfficient.transformers.quantizers.quantizer_gptq import QEffGPTQConfig from QEfficient.utils import get_qpc_dir_path, load_hf_tokenizer +from QEfficient.utils.constants import QEFF_MODELS_DIR from QEfficient.utils.logging_utils import logger # Dictionary that defines the interface from transformers to be used underneath the QEFF interface @@ -30,6 +36,11 @@ class QEFFTransformersBase(QEFFBaseModel): """ def __init__(self, model: nn.Module, pretrained_model_name_or_path: str, **kwargs) -> None: + if hasattr(model.config, "quantization_config") and not isinstance( + model.config.quantization_config, tuple(QEFF_AUTO_QUANTIZATION_CONFIG_MAPPING.values()) + ): + raise AssertionError("Please use `from_pretrained` method to load quantized models") + super().__init__(model) self.model.config.use_cache = ( True # Always pass use_cache = True, to get KV values as output during ONNX export @@ -37,12 +48,16 @@ def __init__(self, model: nn.Module, pretrained_model_name_or_path: str, **kwarg self.pretrained_model_name_or_path = pretrained_model_name_or_path # Set model card name, which is used to decide ONNX, QPC files path during export and compile resp. - model_card_name = kwargs.pop("model_card_name", None) - self.model_card_name = ( - model_card_name - if model_card_name - else (self.pretrained_model_name_or_path if not os.path.isdir(self.pretrained_model_name_or_path) else None) - ) + if model_card_name := kwargs.pop("model_card_name", None): + self.model_card_name = model_card_name + elif os.path.isdir(self.pretrained_model_name_or_path): + hash_object = hashlib.sha256() + hash_object.update(self.pretrained_model_name_or_path.encode("utf-8")) + self.model_card_name = hash_object.hexdigest() + else: + self.model_card_name = self.pretrained_model_name_or_path + + self.full_batch_size = kwargs.get("full_batch_size", None) self.kwargs = kwargs self._tokenizer = None self.is_transformed = False @@ -53,6 +68,7 @@ def __repr__(self) -> str: return f"{self.__class__.__name__}\n" + self.model.__repr__() @classmethod + @with_replaced_quantizers def from_pretrained(cls, pretrained_model_name_or_path: str, *args, **kwargs): """ This method serves as the easiest entry point into using QEfficient. The interface is designed to be similar to transformers.AutoModelForCausalLM. 
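The decorator on ``from_pretrained`` above, together with the ``transform``/``export``/``compile`` changes in the hunks that follow, is what lets an AWQ or GPTQ checkpoint flow through this class end to end. A minimal usage sketch (the model card name is a placeholder and the top-level import path is an assumption, not taken from this patch):

.. code-block:: python

    from QEfficient import QEFFAutoModelForCausalLM  # assumed public import path

    # from_pretrained is wrapped with @with_replaced_quantizers, so an AWQ/GPTQ
    # quantization_config in the checkpoint resolves to QEffAwqConfig/QEffGPTQConfig
    # and the weights load into the CPU-side dequantizing linear modules.
    model = QEFFAutoModelForCausalLM.from_pretrained("<awq-or-gptq-model-card>")  # placeholder card

    # transform() prepends AwqToMatmulNbitsTransform / GPTQToMatmulNbitsTransform when a
    # quantization_config is present, then applies the usual KV-cache transforms.
    model.transform()

    onnx_path = model.export()              # ONNX path derived from model_card_name
    qpc_path = model.compile(num_cores=14)  # device_group=None -> device chosen automatically
    model.generate(prompts=["Hello"], device_id=[0])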
@@ -64,7 +80,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, *args, **kwargs): ``Mandatory`` Args: :transform (bool): Whether to optimize model for KV retention; default is ``True``. Pass ``False`` to get BertStyle model. :model_card_name (str): ``HuggingFace`` model card name or name of the model if custom, used for deciding directory name while saving ``ONNX/qpc`` files. - + :full_batch_size (int): Pass this if you want to execute the model with continuous batching. Example usage: .. code-block:: python @@ -92,6 +108,10 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, *args, **kwargs): logger.warning(f"Updating attn_implementation to be 'eager', got {attn_implementation}") kwargs.update({"attn_implementation": "eager"}) + if low_cpu_mem_usage := kwargs.get("low_cpu_mem_usage", None): + logger.warning(f"Updating low_cpu_mem_usage to be 'False', got {low_cpu_mem_usage}") + kwargs.update({"low_cpu_mem_usage": False}) + model = QEFFAutoModelToTransformersAutoModelMap[cls.__name__].from_pretrained( pretrained_model_name_or_path, *args, **kwargs ) @@ -148,9 +168,26 @@ def transform(self, **kwargs): """ if self.is_transformed: return - if kwargs.get("full_batch_size", None): - self._pytorch_transforms.remove(KVCacheTransform) - self._pytorch_transforms.append(CBTransform) + + if self.full_batch_size is not None: + if KVCacheTransform in self._pytorch_transforms: + self._pytorch_transforms[self._pytorch_transforms.index(KVCacheTransform)] = CBTransform + if CBTransform not in self._pytorch_transforms: + raise RuntimeError("please don't update _pytorch_transforms variable") + else: + if CBTransform in self._pytorch_transforms: + self._pytorch_transforms[self._pytorch_transforms.index(CBTransform)] = KVCacheTransform + if KVCacheTransform not in self._pytorch_transforms: + raise RuntimeError("Please don't update _pytorch_transforms variable") + + # Update list of pytorch transforms if the model falls in AWQ/GPTQ category + if hasattr(self.model.config, "quantization_config"): + if isinstance(self.model.config.quantization_config, QEffAwqConfig): + self._pytorch_transforms.insert(0, AwqToMatmulNbitsTransform) + + if isinstance(self.model.config.quantization_config, QEffGPTQConfig): + self._pytorch_transforms.insert(0, GPTQToMatmulNbitsTransform) + for transform in self._pytorch_transforms: transform.apply(self.model) self.is_transformed = True @@ -158,7 +195,7 @@ def execute(self, *args, **kwargs): # type: ignore raise NotImplementedError("Reached too far!!") - def export(self, model_card_name: Optional[str] = None) -> str: + def export(self) -> str: """ Exports the model to ``ONNX`` format using ``torch.onnx.export``. The model should already be transformed i.e. ``self.is_transformed`` should be ``True``. @@ -166,7 +203,7 @@ def export(self, model_card_name: Optional[str] = None) -> str: We currently don't support exporting non-transformed models. Please refer to the ``convert_to_cloud_bertstyle`` function in the **Low-Level API** for a legacy function that supports this." ``Optional`` Args: - :model_card_name (Optional[str]): Name of the model card. Mandatory when model is initialized with path for ``pretrained_model_name_or_path`` argument during initialization. ``Defaults to None.`` + This method does not take any arguments. Raises: :AttributeError: If ``pretrained_model_name_or_path`` is a path, this function needs model card name of the model so that it can distinguish between directories while saving the ``ONNX`` files generated.
So, user needs to pass ``model_card_name`` as a valid ``string`` in that case, Otherwise this will raise the error. @@ -175,15 +212,13 @@ def export(self, model_card_name: Optional[str] = None) -> str: :str: Path of the generated ``ONNX`` graph. """ assert self.is_transformed, "Please first run transform on the QEFFAutoModelForCausalLM object" - - # Make sure model_card_name is available for export - if self.model_card_name is None and model_card_name is None: - raise AttributeError("Please pass model_card_name as valid string input") - elif model_card_name is not None: - self.model_card_name = model_card_name - # Export - _, onnx_model_path = QEfficient.export(model_name=self.model_card_name, model_kv=self, tokenizer=self.tokenizer) + _, onnx_model_path = QEfficient.export( + model_name=self.model_card_name, + model_kv=self, + tokenizer=self.tokenizer, + full_batch_size=self.full_batch_size, + ) self.onnx_path = onnx_model_path return self.onnx_path @@ -191,8 +226,7 @@ def export(self, model_card_name: Optional[str] = None) -> str: def compile( self, num_cores: int, - device_group: List[int], - model_card_name: Optional[str] = None, + device_group: List[int] = None, batch_size: int = 1, prompt_len: int = 32, ctx_len: int = 128, @@ -208,7 +242,7 @@ def compile( ``Mandatory`` Args: :num_cores (int): Number of cores used to compile the model. - :device_group (List[int]): If this is a list of more that one integers, tensor-slicing is invoked. + :device_group (List[int]): If this is a list of more than one integer, tensor-slicing is invoked; defaults to None, in which case a suitable device is chosen automatically. ``Optional`` Args: :model_card_name (Optional[str], optional): Name of the model, Mandatory if ``self.pretrained_model_name_or_path`` is a path. ``Defaults to None``. :batch_size (int, optional): Batch size. ``Defaults to 1``. @@ -225,7 +259,7 @@ def compile( # Export first if self.ort_runtime_args are not populated if self.onnx_path is None: logger.info(f"Exporting the {self.model.__class__.__name__} model to ONNX for compilation!") - self.export(model_card_name=model_card_name) + self.export() # Prepare qpc dir path qpc_dir_path = get_qpc_dir_path( @@ -238,7 +272,7 @@ def compile( mxfp6=mxfp6, mxint8=mxint8, device_group=device_group, - full_batch_size=None, + full_batch_size=self.full_batch_size, ) # Compile @@ -254,13 +288,73 @@ def compile( ctx_len=ctx_len, mxfp6=mxfp6, mxint8=mxint8, + full_batch_size=self.full_batch_size, ) self.qpc_path = qpc_dir_path - self.device_id = device_group + return self.qpc_path + + def export_and_compile( + self, + num_cores: int, + device_group: List[int], + batch_size: int = 1, + prompt_len: int = 32, + ctx_len: int = 128, + mxfp6: bool = True, + mxint8: bool = False, + mos: int = -1, + aic_enable_depth_first: bool = False, + qpc_dir_suffix: Optional[str] = None, + full_batch_size: Optional[int] = None, + ) -> str: + """ + This API is specific to the internal VLLM use-case and is not recommended to be used in your application unless you are using VLLM.
+ """ + _, transformed = CBTransform.apply(self.model) + if not transformed: + raise RuntimeError("Could not apply Continuous batch transform on the model") + if full_batch_size is not None: + self.full_batch_size = full_batch_size + self.export() + + qpc_base_dir_name = get_qpc_dir_path( + model_card_name=self.model_card_name, + num_cores=num_cores, + mos=mos, + batch_size=batch_size, + prompt_len=prompt_len, + ctx_len=ctx_len, + mxfp6=mxfp6, + mxint8=mxint8, + device_group=device_group, + full_batch_size=self.full_batch_size, + ) + qpc_base_dir_name = ( + os.path.dirname(qpc_base_dir_name) + "_" + qpc_dir_suffix if qpc_dir_suffix else qpc_base_dir_name + ) + model_card_dir = os.path.join(QEFF_MODELS_DIR, str(self.model_card_name)) + os.makedirs(model_card_dir, exist_ok=True) + qpc_dir_path = os.path.join(model_card_dir, qpc_base_dir_name) + + # Compile + self.qpc_path = QEfficient.compile( + onnx_path=self.onnx_path, + qpc_path=qpc_dir_path, + num_cores=num_cores, + device_group=device_group, + aic_enable_depth_first=aic_enable_depth_first, + mos=mos, + batch_size=batch_size, + prompt_len=prompt_len, + ctx_len=ctx_len, + mxfp6=mxfp6, + mxint8=mxint8, + full_batch_size=full_batch_size, + ) return self.qpc_path - def generate(self, prompts: List[str], runtime: str = "AI_100", **kwargs): + def generate(self, prompts: List[str], device_id: List[int] = None, runtime: str = "AI_100", **kwargs): """ This method generates output until ``eos`` or ``generation_len`` by executing the compiled ``qpc`` on ``Cloud AI 100`` Hardware cards. This is a sequential execution based on the ``batch_size`` of the compiled model and the number of prompts passed. @@ -268,20 +362,23 @@ def generate(self, prompts: List[str], runtime: str = "AI_100", **kwargs): ``Mandatory`` Args: :prompts (List[str]): List of prompts to run the execution. + :device_id (List[int]): Ids of devices for running the qpc pass as [0] in case of normal model / [0, 1, 2, 3] in case of tensor slicing model ``optional`` Args: :runtime (str, optional): Only ``AI_100`` runtime is supported as of now; ``ONNXRT`` and ``PyTorch`` coming soon. Defaults to "AI_100". """ assert Runtime(runtime) == Runtime.AI_100, "Only AI_100 runtime is supported right now via generate API" - self.run_cloud_ai_100(prompts=prompts, **kwargs) + self.run_cloud_ai_100(prompts=prompts, device_id=device_id, **kwargs) - def run_cloud_ai_100(self, prompts: List[str], **kwargs): + def run_cloud_ai_100(self, prompts: List[str], device_id: List[int] = None, **kwargs): assert isinstance(self.qpc_path, str), "Please run compile API first!" - assert ( - self.device_id is not None - ), "please pass valid device_id as input argument" # FIXME: replace with isinstance generation_len = kwargs.pop("generation_len", None) return QEfficient.cloud_ai_100_exec_kv( - self.tokenizer, self.qpc_path, prompt=prompts, device_id=self.device_id, generation_len=generation_len + self.tokenizer, + self.qpc_path, + prompt=prompts, + device_id=device_id, + generation_len=generation_len, + full_batch_size=self.full_batch_size, ) diff --git a/QEfficient/transformers/quantizers/__init__.py b/QEfficient/transformers/quantizers/__init__.py new file mode 100644 index 00000000..d259e435 --- /dev/null +++ b/QEfficient/transformers/quantizers/__init__.py @@ -0,0 +1,6 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved. 
+# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- diff --git a/QEfficient/transformers/quantizers/auto.py b/QEfficient/transformers/quantizers/auto.py new file mode 100644 index 00000000..aa84f908 --- /dev/null +++ b/QEfficient/transformers/quantizers/auto.py @@ -0,0 +1,41 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- + +from transformers.quantizers.auto import AUTO_QUANTIZATION_CONFIG_MAPPING, AUTO_QUANTIZER_MAPPING + +from QEfficient.transformers.quantizers.quantizer_awq import QEffAwqConfig, QEffAwqQuantizer +from QEfficient.transformers.quantizers.quantizer_gptq import QEffGPTQConfig, QEffGPTQQuantizer + +QEFF_AUTO_QUANTIZER_MAPPING = {"awq": QEffAwqQuantizer, "gptq": QEffGPTQQuantizer} +QEFF_AUTO_QUANTIZATION_CONFIG_MAPPING = {"awq": QEffAwqConfig, "gptq": QEffGPTQConfig} + + +def with_replaced_quantizers(func): + def wrapper(*args, **kwargs): + transformers_replaced_quantization_config_mapping = dict() + transformers_replaced_quantizer_mapping = dict() + + for k in QEFF_AUTO_QUANTIZATION_CONFIG_MAPPING.keys(): + # Replace quantization config + transformers_replaced_quantization_config_mapping[k] = AUTO_QUANTIZATION_CONFIG_MAPPING[k] + AUTO_QUANTIZATION_CONFIG_MAPPING[k] = QEFF_AUTO_QUANTIZATION_CONFIG_MAPPING[k] + + # Replace quantizer + transformers_replaced_quantizer_mapping[k] = AUTO_QUANTIZER_MAPPING[k] + AUTO_QUANTIZER_MAPPING[k] = QEFF_AUTO_QUANTIZER_MAPPING[k] + + # Call the function for loading quantized models here + out = func(*args, **kwargs) + + # Put back quantization config and quantizer + for k in QEFF_AUTO_QUANTIZATION_CONFIG_MAPPING.keys(): + AUTO_QUANTIZATION_CONFIG_MAPPING[k] = transformers_replaced_quantization_config_mapping[k] + AUTO_QUANTIZER_MAPPING[k] = transformers_replaced_quantizer_mapping[k] + + return out + + return wrapper diff --git a/QEfficient/transformers/quantizers/awq.py b/QEfficient/transformers/quantizers/awq.py new file mode 100644 index 00000000..a36c4631 --- /dev/null +++ b/QEfficient/transformers/quantizers/awq.py @@ -0,0 +1,81 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import torch +import torch.nn as nn + +from QEfficient.transformers.quantizers.quantizer_utils import dequantize_gemm + + +class WQLinear_GEMM(nn.Module): + def __init__(self, bits, group_size, in_features, out_features, bias): + super().__init__() + + if bits != 4: + raise NotImplementedError("Only 4-bit are supported for now.") + + self.in_features = in_features + self.out_features = out_features + self.bits = bits + self.group_size = group_size if group_size != -1 else in_features + + # quick sanity check (make sure alignment) + if self.in_features % self.group_size != 0: + raise ValueError( + f"in_features should be perfectly divisible by group_size, got in_features = {self.in_features}, group_size = {self.group_size} while initializing WQLinear_GEMM module" + ) + if out_features % (32 // self.bits) != 0: + raise ValueError( + f"out_features must be perfectly divisible by number of weights packed into int32 value i.e. 
8, got out_features={self.out_features}" + ) + + # For compatibility with QuantLinearORT + self.g_idx = torch.tensor([i // group_size for i in range(in_features)], dtype=torch.int32) + self.register_buffer( + "qweight", + torch.zeros( + (in_features, out_features // (32 // self.bits)), + dtype=torch.int32, + ), + ) + self.register_buffer( + "qzeros", + torch.zeros( + (in_features // self.group_size, out_features // (32 // self.bits)), + dtype=torch.int32, + ), + ) + self.register_buffer( + "scales", + torch.zeros( + (in_features // self.group_size, out_features), + dtype=torch.float16, + ), + ) + if bias: + self.register_buffer( + "bias", + torch.zeros( + (out_features), + dtype=torch.float16, + ), + ) + else: + self.bias = None + + def forward(self, x): + # Only Inference supported + with torch.no_grad(): + out_shape = x.shape[:-1] + (self.out_features,) + + out = dequantize_gemm(self.qweight, self.qzeros, self.scales, self.bits, self.group_size) + out = torch.matmul(x.float(), out.float()) + + out = out + self.bias if self.bias is not None else out + out = out.reshape(out_shape) + + return out diff --git a/QEfficient/transformers/quantizers/gptq.py b/QEfficient/transformers/quantizers/gptq.py new file mode 100644 index 00000000..f0c4bedb --- /dev/null +++ b/QEfficient/transformers/quantizers/gptq.py @@ -0,0 +1,80 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- + +import math + +import torch +from torch import nn + +from QEfficient.transformers.quantizers.quantizer_utils import dequantize_gptq + + +class QuantLinearGPTQ(nn.Module): + """ + A quantized linear layer using GPTQ (Generalized Post-Training Quantization). + This class supports only 4-bit quantization and is compatible with QuantLinearORT. + + Research paper link- GPTQ: Accurate Post-Training Quantization for Generative Pre-trained Transformers (https://arxiv.org/abs/2210.17323) + + Attributes: + in_features (int): The number of input features. + out_features (int): The number of output features. + bits (int): The number of bits used for quantization (must be 4). + act_order (None or bool): The activation order. + orig_fp_weight (None or torch.Tensor): The original floating-point weights. + maxq (int): The maximum quantization value. + group_size (int): The group size for quantization. + pack_mode (str): The packing mode, set to "GPTQ". + qweight (torch.Tensor): The quantized weight tensor. + qzeros (torch.Tensor): The quantized zeros tensor. + scales (torch.Tensor): The scales tensor. + g_idx (torch.Tensor): The group index tensor. + bias (torch.Tensor or None): The bias tensor, if applicable. 
+ """ + + def __init__(self, bits, group_size, in_features, out_features, bias): + super().__init__() + if bits != 4: + raise NotImplementedError("Only 4 bits are supported.") + self.in_features = in_features + self.out_features = out_features + self.bits = bits + self.act_order = None + self.orig_fp_weight = None + self.maxq = 2**self.bits - 1 + self.group_size = group_size if group_size != -1 else in_features + self.pack_mode = "GPTQ" + + # For compatibility with QuantLinearORT + self.register_buffer( + "qweight", + torch.zeros((in_features // 32 * self.bits, out_features), dtype=torch.int32), + ) + self.register_buffer( + "qzeros", + torch.zeros((math.ceil(in_features / self.group_size), out_features // 32 * self.bits), dtype=torch.int32), + ) + self.register_buffer( + "scales", + torch.zeros((math.ceil(in_features / self.group_size), out_features), dtype=torch.float16), + ) + self.g_idx = torch.tensor([i // group_size for i in range(in_features)], dtype=torch.int32) + if bias: + self.register_buffer( + "bias", + torch.zeros((out_features), dtype=torch.float16), + ) + else: + self.bias = None + + def forward(self, x): + # Only Inference supported + out, _, _ = dequantize_gptq(self.qweight.T, self.qzeros, self.scales, self.bits, self.g_idx) + out = torch.matmul(x.float(), out.float()) + out = out + self.bias if self.bias is not None else out + + return out diff --git a/QEfficient/transformers/quantizers/quant_transforms.py b/QEfficient/transformers/quantizers/quant_transforms.py new file mode 100644 index 00000000..b20d8335 --- /dev/null +++ b/QEfficient/transformers/quantizers/quant_transforms.py @@ -0,0 +1,100 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved. 
+# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import torch +from torch import nn + +from QEfficient.base.pytorch_transforms import ModuleMutatorTransform +from QEfficient.customop.matmulnbits import QuantLinearORT +from QEfficient.transformers.quantizers.awq import WQLinear_GEMM +from QEfficient.transformers.quantizers.gptq import QuantLinearGPTQ +from QEfficient.transformers.quantizers.quantizer_utils import dequantize_gptq, unpack_weights + + +class AwqToMatmulNbitsTransform(ModuleMutatorTransform): + _match_class = WQLinear_GEMM + + @staticmethod + def unpack_and_dequantize_awq(qweight, qzeros, scales, bits, group_size): + # Unpack the qweight and qzeros tensors + scales, int_weight, int_zeros = unpack_weights(qweight, qzeros, scales, bits, "awq") + + # fp16 weights + scales_expand = scales.repeat_interleave(group_size, dim=0) + int_zeros_expand = int_zeros.repeat_interleave(group_size, dim=0) + int_weight = (int_weight - int_zeros_expand) * scales_expand + + return int_weight.T, scales, int_zeros.to(torch.int32) + + @classmethod + def mutate(cls, original_module: nn.Module, parent_module: nn.Module): + fp16_weight, scales, zeros = cls.unpack_and_dequantize_awq( + original_module.qweight, + original_module.qzeros, + original_module.scales, + original_module.bits, + original_module.group_size, + ) + + original_module.weight = fp16_weight + new_module = QuantLinearORT( + original_module.bits, + original_module.group_size, + original_module.in_features, + original_module.out_features, + original_module.bias is not None, + ) + new_module.bias = original_module.bias if original_module.bias is not None else None + new_module.pack(original_module, scales.T, zeros.T, original_module.g_idx) + return new_module + + +class GPTQToMatmulNbitsTransform(ModuleMutatorTransform): + """ + A transformation class that mutates a ``QuantLinearGPTQ`` module to a ``QuantLinearORT`` + module by unpacking and dequantizing the quantized weights. + """ + + _match_class = QuantLinearGPTQ + + @staticmethod + def unpack_and_dequantize_gptq(qweight, qzeros, scales, bits, g_idx): + # Unpack the qweight and qzeros tensors + int_weight, scales, int_zeros = dequantize_gptq(qweight.T, qzeros, scales, bits, g_idx) + return int_weight, scales, int_zeros.to(torch.int32) + + @classmethod + def mutate(cls, original_module: nn.Module, parent_module: nn.Module): + """ + ``Mutates`` the original ``QuantLinearGPTQ`` module to a ``QuantLinearORT`` module. + + Args: + original_module (nn.Module): The original ``QuantLinearGPTQ`` module. + parent_module (nn.Module): The parent module containing the original module. + + Returns: + :nn.Module: The new ``QuantLinearORT`` module with unpacked and dequantized weights. 
+ """ + + fp16_weight, scales, zeros = cls.unpack_and_dequantize_gptq( + original_module.qweight, + original_module.qzeros, + original_module.scales, + original_module.bits, + original_module.g_idx, + ) + original_module.weight = fp16_weight.T + new_module = QuantLinearORT( + original_module.bits, + original_module.group_size, + original_module.in_features, + original_module.out_features, + original_module.bias is not None, + ) + new_module.bias = original_module.bias if original_module.bias is not None else None + new_module.pack(original_module, scales.T, zeros.T, original_module.g_idx) + return new_module diff --git a/QEfficient/transformers/quantizers/quantizer_awq.py b/QEfficient/transformers/quantizers/quantizer_awq.py new file mode 100644 index 00000000..86c04547 --- /dev/null +++ b/QEfficient/transformers/quantizers/quantizer_awq.py @@ -0,0 +1,84 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import torch +from transformers.quantizers.quantizer_awq import AwqQuantizer +from transformers.utils.quantization_config import AwqBackendPackingMethod, AwqConfig, AWQLinearVersion + +from QEfficient.transformers.quantizers.awq import WQLinear_GEMM +from QEfficient.transformers.quantizers.quantizer_utils import ( + get_keys_to_not_convert, + replace_linear_layer_with_target_layer, + replace_quantization_scales, +) +from QEfficient.utils.logging_utils import logger + + +class QEffAwqConfig(AwqConfig): + def post_init(self): + """ + Safety checker that arguments are correct + """ + + if self.backend not in [AwqBackendPackingMethod.AUTOAWQ]: + raise ValueError( + f"Only quantization backend {AwqBackendPackingMethod.AUTOAWQ} is supported - not recognized backend {self.backend}" + ) + + self.version = AWQLinearVersion.from_str(self.version) + if self.version not in [AWQLinearVersion.GEMM]: + raise ValueError( + f"Only {AWQLinearVersion.GEMM} version in supported - not recognized version {self.version}" + ) + + if self.do_fuse or self.fuse_max_seq_len is not None: + raise ValueError( + f"fused modules are not supported, got do_fuse={self.do_fuse}, fuse_max_seq_len={self.fuse_max_seq_len}" + ) + + if self.bits != 4: + raise ValueError(f"Only 4-bit AWQ quantization is supported, got bits={self.bits}") + + +class QEffAwqQuantizer(AwqQuantizer): + target_cls = WQLinear_GEMM + + def __init__(self, quantization_config: QEffAwqConfig, **kwargs): + super().__init__(quantization_config, **kwargs) + + def validate_environment(self, device_map, **kwargs): + # No need to validate as we will always use pytorch CPU version. 
+ return True + + @property + def is_trainable(self): + return False + + def update_torch_dtype(self, torch_dtype): + if torch_dtype not in [None, torch.float32]: + logger.warning(f"Requested dtype {torch_dtype} is not supported, overriding to None") + return None + + def _process_model_before_weight_loading(self, model, **kwargs): + self.modules_to_not_convert = get_keys_to_not_convert(model) + + if self.quantization_config.modules_to_not_convert is not None: + self.modules_to_not_convert.extend(self.quantization_config.modules_to_not_convert) + + model, has_been_replaced = replace_linear_layer_with_target_layer( + model, + target_cls=self.target_cls, + quantization_config=self.quantization_config, + modules_to_not_convert=self.modules_to_not_convert, + ) + + model = replace_quantization_scales(model, model.config.model_type) + if not has_been_replaced: + logger.warning( + "You are loading an AWQ model but no linear modules were found in your model." + " Please double check your model architecture, or submit an issue on github if you think this is a bug." + ) diff --git a/QEfficient/transformers/quantizers/quantizer_gptq.py b/QEfficient/transformers/quantizers/quantizer_gptq.py new file mode 100644 index 00000000..76dfe371 --- /dev/null +++ b/QEfficient/transformers/quantizers/quantizer_gptq.py @@ -0,0 +1,151 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import torch +from transformers.quantizers.quantizer_gptq import HfQuantizer +from transformers.utils.quantization_config import GPTQConfig + +from QEfficient.transformers.quantizers.gptq import QuantLinearGPTQ +from QEfficient.transformers.quantizers.quantizer_utils import ( + get_keys_to_not_convert, + repack_zeros, + replace_linear_layer_with_target_layer, +) +from QEfficient.utils.logging_utils import logger + + +class QEffGPTQConfig(GPTQConfig): + """ + Configuration class for QEffGPTQ, extending GPTQConfig. + This class includes a post-initialization safety checker to ensure that the configuration arguments are correct. + """ + + def post_init(self): + r""" + Safety checker that arguments are correct. + """ + if self.bits != 4: + raise ValueError(f"Only 4-bit quantization is supported, got bits={self.bits}") + if self.desc_act: + raise ValueError("Only GPTQ model without descending activation size supported.") + if self.group_size != -1 and self.group_size <= 0: + raise ValueError("group_size must be greater than 0 or equal to -1") + if not (0 < self.damp_percent < 1): + raise ValueError("damp_percent must be between 0 and 1.") + + +class QEffGPTQQuantizer(HfQuantizer): + """ + Quantizer class for QEffGPTQ, extending HfQuantizer. + This class handles the initialization, environment validation, dtype updating, and model processing for quantization. + """ + + target_cls = QuantLinearGPTQ + + def __init__(self, quantization_config: QEffGPTQConfig, **kwargs): + """ + Initializes the QEffGPTQQuantizer with the given quantization configuration. + + Args: + quantization_config (QEffGPTQConfig): The quantization configuration. + **kwargs: Additional keyword arguments. + """ + super().__init__(quantization_config, **kwargs) + + def validate_environment(self, device_map, **kwargs): + """ + Validates the environment for quantization. + + Args: + device_map (dict): The device map for the model. 
+ **kwargs: Additional keyword arguments. + + Returns: + :bool: True if the environment is valid, False otherwise. + """ + return True + + def update_torch_dtype(self, torch_dtype): + """ + Updates the torch data type for quantization. + + Args: + torch_dtype (torch.dtype): The requested torch data type. + + Returns: + :torch.dtype: The updated torch data type. + """ + if torch_dtype not in [None, torch.float32]: + logger.warning(f"Requested dtype {torch_dtype} is not supported, overriding to None") + return None + + def _process_model_before_weight_loading(self, model, **kwargs): + """ + Processes the model before loading weights, ensuring it is suitable for quantization. + + Args: + model (torch.nn.Module): The model to process. + **kwargs: Additional keyword arguments. + + Returns: + :torch.nn.Module: The processed model. + """ + if model.__class__.main_input_name != "input_ids": + raise RuntimeError("We can only quantize pure text model.") + if not self.pre_quantized: + raise RuntimeError("Model is not quantized") + + self.modules_to_not_convert = get_keys_to_not_convert(model) + + model, has_been_replaced = replace_linear_layer_with_target_layer( + model, + target_cls=self.target_cls, + quantization_config=self.quantization_config, + modules_to_not_convert=self.modules_to_not_convert, + ) + if not has_been_replaced: + logger.warning( + "You are loading a GPTQ model but no linear modules were found in your model." + " Please double check your model architecture, or submit an issue on GitHub if you think this is a bug." + ) + return model + + def _process_model_after_weight_loading(self, model, **kwargs): + """ + Processes the model after loading weights, repacking quantization layers. + + Args: + model (torch.nn.Module): The model to process. + **kwargs: Additional keyword arguments. + + Returns: + :torch.nn.Module: The processed model. + """ + for name, module in model.named_modules(): + if isinstance(module, QuantLinearGPTQ): + izeros = repack_zeros(module.qzeros, module.bits) + module.qzeros = izeros + + @property + def is_trainable(self): + """ + Indicates if the quantizer is trainable. + + Returns: + :bool: False, indicating the quantizer is not trainable. + """ + return False + + @property + def is_serializable(self): + """ + Indicates if the quantizer is serializable. + + Returns: + :bool: True, indicating the quantizer is serializable. + """ + return True diff --git a/QEfficient/transformers/quantizers/quantizer_utils.py b/QEfficient/transformers/quantizers/quantizer_utils.py new file mode 100644 index 00000000..755898cf --- /dev/null +++ b/QEfficient/transformers/quantizers/quantizer_utils.py @@ -0,0 +1,380 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import copy + +import torch +from torch import nn +from transformers.integrations.awq import AWQ_SCALES_MAPPINGS + + +class ScaledActivation(nn.Module): + """ + A wrapper class for activation modules that scales the output by a specified factor. + + Args: + module (nn.Module): The activation module to wrap. + scales (torch.Tensor): The scaling factors. + + Attributes: + act (nn.Module): The activation module. + scales (nn.Parameter): The scaling factors. 
+ """ + + def __init__(self, module, scales): + super().__init__() + self.act = module + self.scales = nn.Parameter(scales.data) + + def forward(self, x): + return self.act(x) / self.scales.view(1, 1, -1).to(x.device) + + +def get_keys_to_not_convert(model): + """ + Identifies and returns the names of parameters that should not be converted to a different precision. + + Args: + model (nn.Module): The model to analyze. + + Returns: + :list: A list of parameter names that should remain in full precision. + """ + # Create a copy of the model and tie the weights, then + # check if it contains tied weights + tied_model = copy.deepcopy(model) # this has 0 cost since it is done inside `init_empty_weights` context manager` + tied_model.tie_weights() + + tied_params = find_tied_parameters(tied_model) + # For compatibility with Accelerate < 0.18 + if isinstance(tied_params, dict): + tied_keys = sum(list(tied_params.values()), []) + list(tied_params.keys()) + else: + tied_keys = sum(tied_params, []) + has_tied_params = len(tied_keys) > 0 + + # If there is not tied weights, we want to keep the lm_head(output_embedding) in full precision + if not has_tied_params: + output_emb = model.get_output_embeddings() + if output_emb is not None: + list_last_module = [name for name, module in model.named_modules() if id(module) == id(output_emb)] + return list_last_module + + # otherwise, no tied weights, no output embedding defined, simply keep the last module in full precision + list_modules = list(model.named_parameters()) + list_last_module = [list_modules[-1][0]] + # add last module together with tied weights + intersection = set(list_last_module) - set(tied_keys) + list_untouched = list(set(tied_keys)) + list(intersection) + + # remove ".weight" from the keys + names_to_remove = [".weight", ".bias"] + filtered_module_names = [] + for name in list_untouched: + for name_to_remove in names_to_remove: + if name_to_remove in name: + name = name.replace(name_to_remove, "") + filtered_module_names.append(name) + + return filtered_module_names + + +def find_tied_parameters(model: nn.Module, **kwargs): + """ + Recursively finds and returns tied parameters within a given model. + + Args: + model (nn.Module): The model to search within. + **kwargs: Additional keyword arguments for internal use. + + Returns: + :list: A list of lists, where each sublist contains the names of tied parameters. + """ + # Initialize result and named_parameters before recursing. + named_parameters = kwargs.get("named_parameters", None) + prefix = kwargs.get("prefix", "") + result = kwargs.get("result", {}) + + if named_parameters is None: + named_parameters = {n: p for n, p in model.named_parameters()} + else: + # A tied parameter will not be in the full `named_parameters` seen above but will be in the `named_parameters` + # of the submodule it belongs to. So while recursing we track the names that are not in the initial + # `named_parameters`. + for name, parameter in model.named_parameters(): + full_name = name if prefix == "" else f"{prefix}.{name}" + if full_name not in named_parameters: + # When we find one, it has to be one of the existing parameters. + for new_name, new_param in named_parameters.items(): + if new_param is parameter: + if new_name not in result: + result[new_name] = [] + result[new_name].append(full_name) + + # Once we have treated direct parameters, we move to the child modules. 
+ for name, child in model.named_children(): + child_name = name if prefix == "" else f"{prefix}.{name}" + find_tied_parameters(child, named_parameters=named_parameters, prefix=child_name, result=result) + + return [sorted([weight] + list(set(tied))) for weight, tied in result.items()] + + +def replace_linear_layer_with_target_layer( + model: torch.nn.Module, + target_cls, + quantization_config=None, + modules_to_not_convert=None, + current_key_name=None, + has_been_replaced=False, +): + """ + Replaces all nn.Linear layers in the model with a specified target class, except for specified modules. + + Args: + model (torch.nn.Module): The model containing the layers to be replaced. + target_cls (type): The target class to replace nn.Linear layers with. + quantization_config (object, optional): Configuration object for quantization. + modules_to_not_convert (list, optional): List of module names to exclude from replacement. + current_key_name (list, optional): List of current key names for recursion. + has_been_replaced (bool, optional): Flag indicating if any layer has been replaced. + + Returns: + :tuple: The modified model and a flag indicating if any layer has been replaced. + """ + if modules_to_not_convert is None: + modules_to_not_convert = [] + + # target_cls = WQLinear_GEMM + + for name, module in model.named_children(): + if current_key_name is None: + current_key_name = [] + current_key_name.append(name) + + if isinstance(module, nn.Linear) and name not in modules_to_not_convert: + # Check if the current key is not in the `modules_to_not_convert` + if not any(key in ".".join(current_key_name) for key in modules_to_not_convert): + in_features = module.in_features + out_features = module.out_features + + model._modules[name] = target_cls( + bits=quantization_config.bits, + group_size=quantization_config.group_size, + in_features=in_features, + out_features=out_features, + bias=module.bias is not None, + # dev=module.weight.device, + ) + has_been_replaced = True + + # Force requires grad to False to avoid unexpected errors + model._modules[name].requires_grad_(False) + if len(list(module.children())) > 0: + _, has_been_replaced = replace_linear_layer_with_target_layer( + module, + target_cls, + modules_to_not_convert=modules_to_not_convert, + current_key_name=current_key_name, + quantization_config=quantization_config, + has_been_replaced=has_been_replaced, + ) + # Remove the last key for recursion + current_key_name.pop(-1) + return model, has_been_replaced + + +def replace_quantization_scales(model, model_type): + """ + Replaces the quantization scales in the model based on the specified model type. + + Args: + model (torch.nn.Module): The model containing the layers to be modified. + model_type (str): The type of the model to determine the scale mappings. + + Returns: + :torch.nn.Module: The modified model with updated quantization scales. 
+ """ + if model_type not in AWQ_SCALES_MAPPINGS: + return model + for name, module in model.named_children(): + act_name = AWQ_SCALES_MAPPINGS[model_type]["act"] + layer_before_act_name = AWQ_SCALES_MAPPINGS[model_type]["layer_before_act"] + if name == act_name and hasattr(model, layer_before_act_name): + layer_before_act = getattr(model, AWQ_SCALES_MAPPINGS[model_type]["layer_before_act"]) + size = layer_before_act.out_features + scale_like = torch.ones(size) + model._modules[name] = ScaledActivation(module, scale_like) + replace_quantization_scales(module, model_type) + return model + + +def reverse_awq_order(int_weights: torch.Tensor, int_zeros: torch.Tensor, bits: int): + """ + Reverses the order of the AWQ (Adaptive Weight Quantization) tensors. + + Args: + int_weights (torch.Tensor): The integer weight tensor. + int_zeros (torch.Tensor): The integer zeros tensor. + bits (int): The number of bits used for quantization. + + Returns: + :tuple: The reversed integer weight and zeros tensors. + """ + reverse_order_tensor = torch.arange( + int_weights.shape[-1], + dtype=torch.int32, + ) + reverse_order_tensor = reverse_order_tensor.view(-1, 32 // bits) + reverse_order_tensor = reverse_order_tensor[:, [0, 4, 1, 5, 2, 6, 3, 7]] + reverse_order_tensor = reverse_order_tensor.view(-1) + + int_zeros = int_zeros[:, reverse_order_tensor] + int_weights = int_weights[:, reverse_order_tensor] + + return int_weights, int_zeros + + +def unpack_weights_and_zeros(qweight: torch.Tensor, qzeros: torch.Tensor, bits: int, quant: str): + """ + Unpacks the quantized weights and zeros tensors based on the specified bit width and quantization type. + + Args: + qweight (torch.Tensor): The quantized weight tensor. + qzeros (torch.Tensor): The quantized zeros tensor. + bits (int): The number of bits used for quantization. + quant (str): The quantization type ("awq" or other). + + Returns: + :tuple: A tuple containing the unpacked integer weight and zeros tensors. + """ + + shifts = torch.arange(0, 32, bits) + + # unpacking weights column-wise + int_weights = torch.bitwise_right_shift(qweight[:, :, None], shifts[None, None, :]).to( + torch.int8 # smallest dtype available + ) + int_weights = int_weights.reshape(int_weights.shape[0], -1) + + # unpacking zeros column-wise + int_zeros = torch.bitwise_right_shift(qzeros[:, :, None], shifts[None, None, :]).to( + torch.int8 # smallest dtype available + ) + int_zeros = int_zeros.reshape(int_zeros.shape[0], -1) + + if quant == "awq": + return reverse_awq_order(int_weights, int_zeros, bits) + + return int_weights, int_zeros + + +def dequantize_gemm(qweight, qzeros, scales, bits, group_size): + """ + Dequantizes the GEMM (General Matrix Multiply) quantized weights and zeros. + + Args: + qweight (torch.Tensor): The quantized weight tensor. + qzeros (torch.Tensor): The quantized zeros tensor. + scales (torch.Tensor): The scales tensor. + bits (int): The number of bits used for quantization. + group_size (int): The group size for quantization. + + Returns: + :torch.Tensor: The dequantized weight tensor. 
+ """ + # Unpack the qweight and qzeros tensors + scales, int_weight, int_zeros = unpack_weights(qweight, qzeros, scales, bits, "awq") + + # fp16 weights + scales = scales.repeat_interleave(group_size, dim=0) + int_zeros = int_zeros.repeat_interleave(group_size, dim=0) + + int_weight = (int_weight - int_zeros) * scales + + return int_weight + + +def dequantize_gptq(qweight, qzeros, scales, bits, g_idx): + """ + Dequantizes the ```GPTQ (Generalized Post-Training Quantization)``` quantized weights and zeros. + + Args: + qweight (torch.Tensor): The quantized weight tensor. + qzeros (torch.Tensor): The quantized zeros tensor. + scales (torch.Tensor): The scales tensor. + bits (int): The number of bits used for quantization. + g_idx (torch.Tensor): The group index tensor. + + Returns: + :tuple: A tuple containing the dequantized weight tensor, scales tensor, and zeros tensor. + """ + scales, int_weight, int_zeros = unpack_weights(qweight, qzeros, scales, bits, "gptq") + scales = scales.view(-1, 1, scales.size(-1)) + scales = scales.view(scales.shape[0], -1) + scale_zeros = int_zeros * scales + scale_mat = scales[g_idx] + scale_zeros_mat = scale_zeros[g_idx] + int_weight = int_weight.T * scale_mat - scale_zeros_mat.float() + + return int_weight, scales, int_zeros + + +def unpack_weights(qweight, qzeros, scales, bits, quant): + """ + Unpacks the quantized weights and zeros tensors and performs overflow checks. + + Args: + qweight (torch.Tensor): The quantized weight tensor. + qzeros (torch.Tensor): The quantized zeros tensor. + scales (torch.Tensor): The scales tensor. + bits (int): The number of bits used for quantization. + quant (str): The quantization type ("awq" or "gptq"). + + Returns: + :tuple: A tuple containing the scales tensor, unpacked integer weight tensor, and unpacked integer zeros tensor. + """ + int_weight, int_zeros = unpack_weights_and_zeros(qweight, qzeros, bits, quant) + + # overflow checks + int_weight = torch.bitwise_and(int_weight, (2**bits) - 1) + int_zeros = torch.bitwise_and(int_zeros, (2**bits) - 1) + + return scales, int_weight, int_zeros + + +def repack_zeros(qzeros, bits): + """ + Unpacks the quantized zeros tensor. + + Args: + qzeros (torch.Tensor): The quantized zeros tensor. + bits (int): The number of bits used for quantization. + + Returns: + :torch.Tensor: The unpacked integer zeros tensor. + """ + + shifts = torch.arange(0, 32, bits, dtype=torch.int32, device=qzeros.device).unsqueeze(0) + izeros = torch.bitwise_right_shift(qzeros[:, :, None], shifts[None, None, :]).to( + torch.int32 # smallest dtype available + ) + izeros = torch.bitwise_and(izeros[0], (2**bits) - 1).view(-1, 1, izeros[0].size(1) * izeros[0].size(2)) + izeros = izeros.view(izeros.shape[0], -1) + izeros += 1 + qzeros.mul_(0) + if qzeros.shape[0] == izeros.shape[0]: + qzeros = qzeros.T + izeros = izeros.T + compress_ratio = 32 // bits + i = 0 + row = 0 + while row < qzeros.shape[0]: + for j in range(i, i + compress_ratio): + qzeros[row:] |= izeros[j::compress_ratio] << (bits * (j - i)) + break + qzeros = qzeros.T + return qzeros diff --git a/QEfficient/utils/_utils.py b/QEfficient/utils/_utils.py index d20a4beb..8a9e3d1c 100644 --- a/QEfficient/utils/_utils.py +++ b/QEfficient/utils/_utils.py @@ -199,7 +199,7 @@ def check_and_assign_cache_dir(local_model_dir, cache_dir): def padding_check_and_fix(tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast]) -> None: """ - Checks and fixes tokenizer paddding side and pad_token_id viability. 
+ Checks and fixes tokenizer padding side and pad_token_id viability. -------- tokenizer: `Union[PreTrainedTokenizer, PreTrainedTokenizerFast]` - Pass model tokenizer to check and fix. @@ -251,7 +251,7 @@ def get_padding_shape_from_config(config, batch_size, seq_len): n_heads = config.num_attention_heads d_head = config.hidden_size // config.num_attention_heads else: - raise ValueError("Invalid model configuration: n_head/n_heads or num_key_value_heads not found.") + raise ValueError("Invalid model configuration: n_head/d_heads or num_key_value_heads not found.") padding_shape = [batch_size, n_heads, seq_len, d_head] return padding_shape diff --git a/scripts/Jenkinsfile b/scripts/Jenkinsfile index 79ccc292..3facc515 100644 --- a/scripts/Jenkinsfile +++ b/scripts/Jenkinsfile @@ -36,7 +36,7 @@ pipeline sh ''' . preflight_qeff/bin/activate export TOKENIZERS_PARALLELISM=false - pytest tests --ignore tests/cloud -n 4 --junitxml=tests/tests_log1.xml + pytest tests --ignore tests/cloud --junitxml=tests/tests_log1.xml pytest tests/cloud --junitxml=tests/tests_log2.xml junitparser merge tests/tests_log1.xml tests/tests_log2.xml tests/tests_log.xml deactivate diff --git a/tests/base/test_pytorch_transforms.py b/tests/base/test_pytorch_transforms.py index 981977c7..764bb887 100644 --- a/tests/base/test_pytorch_transforms.py +++ b/tests/base/test_pytorch_transforms.py @@ -9,7 +9,20 @@ import torch from torch import nn -from QEfficient.base.pytorch_transforms import ModuleMappingTransform +from QEfficient.base.pytorch_transforms import ModuleMappingTransform, ModuleMutatorTransform + + +class TestModel(nn.Module): + def __init__(self): + super().__init__() + + self.a = nn.Linear(32, 64) + self.b = nn.Linear(64, 32) + + def forward(self, x): + x = self.a(x) + x = self.b(x) + return x def test_module_mapping_transform(): @@ -19,24 +32,35 @@ def test_module_mapping_transform(): class TestTransform(ModuleMappingTransform): _module_mapping = {nn.Linear: nn.Identity} - class TestModel(nn.Module): - def __init__(self): - super().__init__() + model = TestModel() + x = torch.rand(1, 32) + y1 = model(x) + assert torch.any(y1 != x) + + model, transformed = TestTransform.apply(model) + assert transformed + y2 = model(x) + assert torch.all(y2 == x) - self.a = nn.Linear(32, 64) - self.b = nn.Linear(64, 32) - def forward(self, x): - x = self.a(x) - x = self.b(x) - return x +def test_module_mutator_transform(): + with pytest.raises(TypeError): + ModuleMutatorTransform() + + class TestTransform(ModuleMutatorTransform): + _match_class = nn.Linear + + @classmethod + def mutate(cls, original_module: nn.Module, parent_module: nn.Module): + return nn.Identity() model = TestModel() + prev_ids = [id(model.a), id(model.b)] x = torch.rand(1, 32) y1 = model(x) assert torch.any(y1 != x) - model, transformed = TestTransform.apply(model) assert transformed + assert not ([id(model.a), id(model.b)] == prev_ids) y2 = model(x) assert torch.all(y2 == x) diff --git a/tests/transformers/models/test_causal_lm_models.py b/tests/transformers/models/test_causal_lm_models.py index 9b630d0c..46612ab4 100644 --- a/tests/transformers/models/test_causal_lm_models.py +++ b/tests/transformers/models/test_causal_lm_models.py @@ -15,7 +15,7 @@ from QEfficient.utils.constants import Constants from QEfficient.utils.device_utils import get_available_device_id from QEfficient.utils.run_utils import ApiRunner -from tests.utils import load_pytorch_model +from tests.utils import load_pytorch_model, replace_transformers_quantizers test_models = [ 
"TinyLlama/TinyLlama-1.1B-Chat-v1.0", @@ -30,6 +30,9 @@ "wtang06/mpt-125m-c4", "hakurei/gpt-j-random-tinier", "mistralai/Mixtral-8x7B-Instruct-v0.1", + "TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ", # AWQ model + # "TheBloke/Llama-2-7B-Chat-GPTQ", # GPTQ model -> Enable once GPTQ+ROPE + # issue is resolved ] @@ -40,6 +43,7 @@ def test_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name): Test function to validate the model before and after KV changes on Pytorch :param model_name: Name of model. """ + replace_transformers_quantizers() if model_name == "microsoft/Phi-3-mini-4k-instruct": n_layer = 2 # test only 2 layer models else: diff --git a/tests/transformers/test_transformer_pytorch_transforms.py b/tests/transformers/test_transformer_pytorch_transforms.py index a458ebeb..1775871a 100644 --- a/tests/transformers/test_transformer_pytorch_transforms.py +++ b/tests/transformers/test_transformer_pytorch_transforms.py @@ -9,7 +9,11 @@ import torch from transformers import AutoConfig, AutoModelForCausalLM +from QEfficient.customop.matmulnbits import QuantLinearORT from QEfficient.transformers.pytorch_transforms import CustomOpsTransform, KVCacheTransform +from QEfficient.transformers.quantizers.awq import WQLinear_GEMM +from QEfficient.transformers.quantizers.gptq import QuantLinearGPTQ +from QEfficient.transformers.quantizers.quant_transforms import AwqToMatmulNbitsTransform, GPTQToMatmulNbitsTransform from QEfficient.utils._utils import get_padding_shape_from_config from QEfficient.utils.logging_utils import logger @@ -189,3 +193,56 @@ def test_kv_cache_transform( input_len=8, logits_tolerance=logits_tolerance, ) + + +@pytest.mark.parametrize("in_features", [2048, 4096]) +@pytest.mark.parametrize("out_features", [2048, 4096]) +def test_awq_to_matmulnbits_transform(in_features, out_features): + wqlinear = WQLinear_GEMM(bits=4, group_size=128, in_features=in_features, out_features=out_features, bias=False) + + wqlinear.qweight = torch.randint( + low=-(2**31), high=2**31 - 1, size=(in_features, out_features // 8), dtype=torch.int32 + ) + wqlinear.qzeros = torch.randint( + low=-(2**31), high=2**31 - 1, size=(in_features // wqlinear.group_size, out_features // 8), dtype=torch.int32 + ) + wqlinear.scales = torch.rand(in_features // wqlinear.group_size, out_features, dtype=torch.float32) + + rand_data = torch.rand(4, in_features) + old_out = wqlinear(rand_data) + new_module, transformed = AwqToMatmulNbitsTransform.apply(wqlinear) + assert transformed + new_out = new_module(rand_data) + assert isinstance(new_module, QuantLinearORT) + assert compare_original_vs_kv_model_pt_outputs( + old_out, new_out, tolerance=1e-8 + ), "Test failed because MAE is greater than tolerance" + + +@pytest.mark.parametrize("in_features", [4096, 4096]) +@pytest.mark.parametrize("out_features", [4096, 4096]) +def test_gptq_to_matmulnbits_transform(in_features, out_features): + quant_linear_gptq = QuantLinearGPTQ( + bits=4, group_size=128, in_features=in_features, out_features=out_features, bias=False + ) + quant_linear_gptq.qweight = torch.randint( + low=-(2**31), high=2**31 - 1, size=(in_features // 8, out_features), dtype=torch.int32 + ) + quant_linear_gptq.qzeros = torch.randint( + low=-(2**31), + high=2**31 - 1, + size=(in_features // quant_linear_gptq.group_size, out_features // 8), + dtype=torch.int32, + ) + quant_linear_gptq.scales = torch.rand( + in_features // quant_linear_gptq.group_size, out_features, dtype=torch.float32 + ) + rand_data = torch.rand(4, in_features) + old_out = quant_linear_gptq(rand_data) + new_module, 
transformed = GPTQToMatmulNbitsTransform.apply(quant_linear_gptq) + assert transformed + new_out = new_module(rand_data) + assert isinstance(new_module, QuantLinearORT) + assert compare_original_vs_kv_model_pt_outputs( + old_out, new_out, tolerance=1e-4 + ), "Test failed because MAE is greater than tolerance" diff --git a/tests/utils.py b/tests/utils.py index 5f743396..c8d830c5 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -9,7 +9,10 @@ import unittest from transformers import AutoModelForCausalLM +from transformers.quantizers.auto import AUTO_QUANTIZATION_CONFIG_MAPPING, AUTO_QUANTIZER_MAPPING +from QEfficient.transformers.quantizers.quantizer_awq import QEffAwqConfig, QEffAwqQuantizer +from QEfficient.transformers.quantizers.quantizer_gptq import QEffGPTQConfig, QEffGPTQQuantizer from QEfficient.utils import hf_download from QEfficient.utils.device_utils import is_multi_qranium_setup_available @@ -43,8 +46,19 @@ def load_pytorch_model(model_config): ignore_patterns=["*.onnx", "*.ot", "*.md", "*.tflite", "*.pdf", "*.h5", "*.msgpack"], ) model_hf = AutoModelForCausalLM.from_pretrained( - model_path, use_cache=True, num_hidden_layers=model_config["n_layer"], attn_implementation="eager" + model_path, + use_cache=True, + num_hidden_layers=model_config["n_layer"], + attn_implementation="eager", + low_cpu_mem_usage=False, ) # Run models for single layers only params = sum(p.numel() for p in model_hf.parameters()) model_hf.eval() return model_hf, params + + +def replace_transformers_quantizers(): + AUTO_QUANTIZER_MAPPING.update({"awq": QEffAwqQuantizer}) + AUTO_QUANTIZATION_CONFIG_MAPPING.update({"awq": QEffAwqConfig}) + AUTO_QUANTIZER_MAPPING.update({"gptq": QEffGPTQQuantizer}) + AUTO_QUANTIZATION_CONFIG_MAPPING.update({"gptq": QEffGPTQConfig})
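
For clarity, a minimal sketch of the int4 unpacking performed by unpack_weights_and_zeros() and the dequantization step in dequantize_gemm() from QEfficient/transformers/quantizers/quantizer_utils.py above. The tensor shapes, the group size, and the random inputs are illustrative assumptions only, and the AWQ column reordering (reverse_awq_order) is skipped for brevity:

import torch

bits = 4
group_size = 4                                    # illustrative only; the AWQ/GPTQ tests in this patch use 128
shifts = torch.arange(0, 32, bits)                # [0, 4, 8, ..., 28]

# Each packed int32 word holds 32 // bits = 8 quantized values.
qweight = torch.randint(low=-(2**31), high=2**31 - 1, size=(16, 2), dtype=torch.int32)
qzeros = torch.randint(low=-(2**31), high=2**31 - 1, size=(16 // group_size, 2), dtype=torch.int32)
scales = torch.rand(16 // group_size, 16, dtype=torch.float32)

# Shift every nibble down to the low bits, then mask to [0, 2**bits - 1],
# mirroring unpack_weights_and_zeros() plus the overflow check in unpack_weights().
int_weight = torch.bitwise_right_shift(qweight[:, :, None], shifts[None, None, :]).to(torch.int8)
int_weight = torch.bitwise_and(int_weight.reshape(qweight.shape[0], -1), (2**bits) - 1)

int_zeros = torch.bitwise_right_shift(qzeros[:, :, None], shifts[None, None, :]).to(torch.int8)
int_zeros = torch.bitwise_and(int_zeros.reshape(qzeros.shape[0], -1), (2**bits) - 1)

# dequantize_gemm() then broadcasts scales and zeros over each group
# (repeat_interleave along the input dimension) and computes (w - z) * s.
scales_full = scales.repeat_interleave(group_size, dim=0)
zeros_full = int_zeros.repeat_interleave(group_size, dim=0)
deq = (int_weight - zeros_full) * scales_full
print(deq.shape)                                  # (16, 16) floating-point weights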
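
Similarly, a short usage sketch of the replace_transformers_quantizers() helper added in tests/utils.py, following the flow of test_causal_lm_models.py: the quantizer registry is patched first, then a quantized checkpoint is loaded. The checkpoint name is the AWQ model listed in the tests; how a non-test caller would import the helper is an assumption, and the comment about downstream transforms is inferred from the rest of this patch:

from transformers import AutoModelForCausalLM

from tests.utils import replace_transformers_quantizers

# Point the "awq"/"gptq" entries of transformers' AUTO_QUANTIZER_MAPPING and
# AUTO_QUANTIZATION_CONFIG_MAPPING at the QEff quantizer and config classes.
replace_transformers_quantizers()

# Loading an AWQ checkpoint should now go through the QEff quantizer path,
# whose WQLinear_GEMM modules the patch's AwqToMatmulNbitsTransform can later
# rewrite into the QuantLinearORT (MatMulNBits) custom op for export.
model = AutoModelForCausalLM.from_pretrained(
    "TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ",
    use_cache=True,
    attn_implementation="eager",
    low_cpu_mem_usage=False,
)
model.eval()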