AWQ+GPTQ (quic#101)
* Awq feature (quic#100)

* added preprocess layer before loading quantized awq weights

Signed-off-by: Onkar Chougule <quic_ochougul@quicinc.com>

* added onnx export

Signed-off-by: Onkar Chougule <quic_ochougul@quicinc.com>

* added ScaledActivation class

Signed-off-by: Onkar Chougule <quic_ochougul@quicinc.com>

* refactored the code into the right places and added a single test for now

Signed-off-by: Onkar Chougule <quic_ochougul@quicinc.com>

* cleaned code

Signed-off-by: Onkar Chougule <quic_ochougul@quicinc.com>

* added proper tests, added decorator for updating quantizers, cleaned code

Signed-off-by: Onkar Chougule <quic_ochougul@quicinc.com>

* fixed CLI

Signed-off-by: Onkar Chougule <quic_ochougul@quicinc.com>

* added auto file for decorator

Signed-off-by: Onkar Chougule <quic_ochougul@quicinc.com>

---------

Signed-off-by: Onkar Chougule <quic_ochougul@quicinc.com>

* bugfix for tests

Signed-off-by: Onkar Chougule <quic_ochougul@quicinc.com>

* fixed tests for AWQ model

Signed-off-by: Onkar Chougule <quic_ochougul@quicinc.com>

* Adding support for GPTQ models (quic#103)

* Adding support for GPTQ models

Signed-off-by: Amit Raj <quic_amitraj@quicinc.com>

* Code cleaning and formatting

Signed-off-by: Amit Raj <quic_amitraj@quicinc.com>

* ruff format and fixed some bugs

Signed-off-by: Amit Raj <quic_amitraj@quicinc.com>

* Added tests for GPTQ

Signed-off-by: Amit Raj <quic_amitraj@quicinc.com>

* Bug-fix-1

Signed-off-by: Amit Raj <quic_amitraj@quicinc.com>

* fixed bugs-2

Signed-off-by: Amit Raj <quic_amitraj@quicinc.com>

* fixed bug-3

Signed-off-by: Amit Raj <quic_amitraj@quicinc.com>

* Added docstring

Signed-off-by: Amit Raj <quic_amitraj@quicinc.com>

* Addressed comments

Signed-off-by: Amit Raj <quic_amitraj@quicinc.com>

* Addressed comments

Signed-off-by: Amit Raj <quic_amitraj@quicinc.com>

* fixed bugs-3

Signed-off-by: Amit Raj <quic_amitraj@quicinc.com>

* ruff check and format

Signed-off-by: Amit Raj <quic_amitraj@quicinc.com>

* Addressed comments-3

Signed-off-by: Amit Raj <quic_amitraj@quicinc.com>

---------

Signed-off-by: Amit Raj <quic_amitraj@quicinc.com>
Signed-off-by: Onkar Chougule <quic_ochougul@quicinc.com>

* added license header at top for missing file

Signed-off-by: Onkar Chougule <quic_ochougul@quicinc.com>

* added export_and_compile and fixed bugs

Signed-off-by: Onkar Chougule <quic_ochougul@quicinc.com>

* removed GPTQ test

Signed-off-by: Onkar Chougule <quic_ochougul@quicinc.com>

* removed threading from pytest

Signed-off-by: Onkar Chougule <quic_ochougul@quicinc.com>

---------

Signed-off-by: Onkar Chougule <quic_ochougul@quicinc.com>
Signed-off-by: Amit Raj <quic_amitraj@quicinc.com>
Co-authored-by: Amit Raj <168538872+quic-amitraj@users.noreply.github.com>
ochougul and quic-amitraj authored Sep 13, 2024
1 parent 0ef6829 commit afb4645
Showing 19 changed files with 1,384 additions and 60 deletions.
11 changes: 1 addition & 10 deletions QEfficient/base/common.py
@@ -31,7 +31,6 @@ class QEFF_MODEL_TYPE(Enum):
 
     CAUSALLM = "LLM"
     DIFFUSION = "DIFFUSION"
-    AWQ = "AWQ"
 
 
 MODEL_TYPE_TO_QEFF_AUTO_MODEL_MAP: Dict[QEFF_MODEL_TYPE, Type[QEFFBaseModel]] = {
@@ -56,15 +55,7 @@ def get_hf_model_type(hf_model_path: str) -> QEFF_MODEL_TYPE:
     )
 
     if config.__class__ in MODEL_FOR_CAUSAL_LM_MAPPING:
-        # FIXME: Add logic to handle if quantization config is stored in separate quant_config.json outside of config, also create a separate function for this and below lines
-        quant_config = getattr(config, "quantization_config", getattr(config, "quant_config", None))
-        if quant_config is not None:
-            if quant_config.get("quant_method", None) == "awq":
-                return QEFF_MODEL_TYPE.AWQ
-            else:
-                raise NotImplementedError(f"current model type is not yet supported {type(config)}")
-        else:
-            return QEFF_MODEL_TYPE.CAUSALLM
+        return QEFF_MODEL_TYPE.CAUSALLM
     else:
         raise NotImplementedError(f"model type {type(config)} is not yet supported")
32 changes: 32 additions & 0 deletions QEfficient/base/pytorch_transforms.py
@@ -55,3 +55,35 @@ def register(cls, from_module: Type[nn.Module], to_module: Type[nn.Module]):
            FlashAttention.register(LLamaAttention, LlamaFlashAttention)
        """
        cls._module_mapping[from_module] = to_module


class ModuleMutatorTransform(PytorchTransform):
    """Serves as a base class for any transform that mutates a pytorch module.
    Mutate here means: we initialize a new pytorch module object using info from the original module and
    replace the original module with the new module.

    Raises:
        NotImplementedError: Not intended to be used directly. Create a subclass, implement the mutate
            method, and assign a valid nn.Module class to the _match_class variable.
    """

    _match_class: nn.Module

    @classmethod
    def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:
        transformed = False
        for name, module in model.named_children():
            if isinstance(module, cls._match_class):
                setattr(model, name, cls.mutate(module, model))
                transformed = True
            else:
                cls.apply(module)

        if isinstance(model, cls._match_class):
            model = cls.mutate(model, None)
            transformed = True

        return model, transformed

    @classmethod
    def mutate(cls, original_module: nn.Module, parent_module: nn.Module):
        raise NotImplementedError("Please implement your own method by inheriting this class")
182 changes: 182 additions & 0 deletions QEfficient/customop/matmulnbits.py
@@ -0,0 +1,182 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------

import math

import torch
from torch import nn


class QuantLinearTorchFunction(torch.autograd.Function):
    @staticmethod
    def symbolic(g, x, qself_qweight, qself_scales, qself_qzeros, g_idx, bits, group_size, in_features, out_features):
        input_tuple = (x, qself_qweight, qself_scales, qself_qzeros)
        input_tuple += (g_idx,) if g_idx is not None else ()
        return g.op(
            "com.microsoft::MatMulNBits",
            *input_tuple,
            outputs=1,
            K_i=in_features,
            N_i=out_features,
            bits_i=bits,
            block_size_i=group_size,
        )

    @staticmethod
    def forward(ctx, x, qself_qweight, qself_scales, qself_qzeros, g_idx, bits, group_size, in_features, out_features):
        if torch.onnx.is_in_onnx_export():
            return torch.zeros(x.shape[:-1] + (out_features,), dtype=x.dtype).float()
        fp_weight = dequantize_blockwise_bits(
            qself_qweight, qself_scales, qself_qzeros, bits, group_size, g_idx, in_features, out_features
        )[0].float()

        return torch.matmul(x.float(), fp_weight.T.float())


def dequantize_blockwise_bits(quant_values, scale, zero_point, bits, group_size, g_idx, rows, cols):
    if bits != 4:
        raise ValueError("Only bits=4 is supported for executing quantized model")
    if group_size != 128:
        raise ValueError("Only group_size=128 is supported for executing quantized model")
    expand_quant_value = (quant_values.unsqueeze(-1) >> torch.tensor([[[[0, 4]]]], dtype=torch.int32)) & 0x0F
    expand_quant_value = expand_quant_value.reshape(*quant_values.shape[:-1], -1)
    aligned_scale = scale.reshape(*quant_values.shape[:-1], 1)
    if zero_point.dtype == scale.dtype:
        expand_zero_point = zero_point.reshape(*quant_values.shape[:-1], -1)
    else:
        expand_zero_point = (zero_point.unsqueeze(-1) >> torch.tensor([[[[0, 4]]]], dtype=torch.int32)) & 0x0F
        try:
            expand_zero_point = expand_zero_point.reshape(*quant_values.shape[:-1], -1)
        # FIXME: remove try-except
        except RuntimeError:
            expand_zero_point = expand_zero_point.reshape(quant_values.shape[0], -1, 1)
            expand_zero_point = expand_zero_point[:, : quant_values.shape[1]]
    if g_idx is not None and g_idx[:32].sum().item() != 0:
        float_values = (
            (expand_quant_value.reshape(expand_quant_value.shape[0], -1) - expand_zero_point[:, g_idx, 0])
            * aligned_scale[:, g_idx, 0]
        ).to(scale.dtype)
    else:
        float_values = ((expand_quant_value - expand_zero_point) * aligned_scale).to(scale.dtype)
    float_values = float_values.reshape(cols, -1)
    if rows != float_values.shape[-1]:
        float_values = float_values[:, :rows]
        expand_zero_point = expand_zero_point[:, :rows]
    if expand_zero_point.ndim == 3:
        expand_zero_point = expand_zero_point.squeeze(-1)
    if aligned_scale.ndim == 3:
        aligned_scale = aligned_scale.squeeze(-1)

    return float_values, expand_zero_point, aligned_scale
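
Aside, not part of the file above: the shift-and-mask in dequantize_blockwise_bits unpacks two 4-bit values from each packed byte, low nibble first. A minimal standalone check:

import torch

packed = torch.tensor([0xB7], dtype=torch.int32)                # 0xB7 = 0b1011_0111
nibbles = (packed.unsqueeze(-1) >> torch.tensor([0, 4])) & 0x0F
print(nibbles.tolist())                                         # [[7, 11]]: low nibble 7, high nibble 11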


class QuantLinearORT(nn.Module):
    def __init__(self, bits, group_size, in_features, out_features, bias):
        super().__init__()
        if bits not in [2, 3, 4, 5, 6, 7, 8]:
            raise NotImplementedError("Only 2, 3, 4, 5, 6, 7, 8 bits are supported.")
        self.in_features = in_features
        self.out_features = out_features
        self.bits = bits
        self.group_size = group_size if group_size != -1 else in_features
        self.act_order = None

        q_rows = in_features // self.group_size
        self.register_buffer(
            "qweight",
            torch.zeros((out_features, q_rows, self.group_size // (8 // bits)), dtype=torch.uint8),
        )
        self.register_buffer(
            "qzeros",
            torch.zeros((q_rows + (q_rows & 1)) * (out_features // 8 * self.bits), dtype=torch.uint8),
        )
        self.register_buffer(
            "scales", torch.zeros((math.ceil(in_features / self.group_size) * out_features), dtype=torch.float16)
        )
        self.register_buffer(
            "g_idx", torch.tensor([i // self.group_size for i in range(in_features)], dtype=torch.int32)
        )
        if bias:
            self.register_buffer("bias", torch.zeros((out_features), dtype=torch.float16))
        else:
            self.bias = None

    def quant_weight(self, weight, scales, zeros, g_idx):
        scale_zeros = zeros * scales
        scale_mat = scales[g_idx]
        scale_zeros_mat = scale_zeros[g_idx]
        int_weight_T = torch.round(((weight + scale_zeros_mat) / scale_mat).float()).to(torch.int)
        return int_weight_T

    def pack_on_device(self, int_weight, int_zeros):
        if self.bits != 4:
            raise ValueError("only 4-bit is supported by ONNXRuntime for now.")

        # Order of groups
        self.act_order = self.g_idx[: self.group_size // self.bits].sum().item() != 0

        intzeros_pt = int_zeros.T if int_zeros.dtype == self.scales.dtype else int_zeros.T.byte()
        scales_pt = self.scales.T.to(int_weight.device)
        intweight_pt = int_weight.byte()

        block_size = self.group_size
        rows, cols = intweight_pt.shape
        blob_size = block_size // 2
        k_blocks = (rows + block_size - 1) // block_size
        padded_rows = k_blocks * block_size
        pad_len = padded_rows - rows
        if pad_len > 0:
            intweight_pt = torch.nn.functional.pad(intweight_pt, (0, 0, 0, pad_len), "constant", 0)
        intzeros_pt = torch.nn.functional.pad(intzeros_pt, (0, intzeros_pt.shape[-1] & 1, 0, 0), "constant", 0)

        # Pack zeros if they are not float
        if int_zeros.dtype != self.scales.dtype:
            intzeros_pt = (intzeros_pt[:, 0::2]) | (intzeros_pt[:, 1::2] << 4)
            intzeros_pt = intzeros_pt.reshape(-1)

        # Pack weights
        intweight_pt_T = int_weight.T
        intweight_pt_T = (intweight_pt_T[:, 0::2]) | (intweight_pt_T[:, 1::2] << 4)
        intweight_pt_T = intweight_pt_T.reshape(cols, k_blocks, blob_size)

        scales_pt = scales_pt.reshape(-1)

        # Validation checks
        if (self.qweight.shape != intweight_pt_T.shape) and (
            self.qzeros.shape == intzeros_pt.shape or self.qzeros.dtype != intzeros_pt.dtype
        ):
            raise RuntimeError("Something went wrong while packing the weights in QuantLinearORT module")

        # Assign buffers
        self.scales = scales_pt.float()
        self.qweight = intweight_pt_T.byte()  # Convert to uint8
        if int_zeros.dtype != self.scales.dtype:
            self.qzeros = intzeros_pt.byte()  # Convert to uint8
        else:
            self.qzeros = intzeros_pt

    def pack(self, linear, scales, zeros, g_idx=None):
        layer_weight = linear.weight.data
        self.scales = scales.T
        self.g_idx = g_idx.clone()
        int_weight = self.quant_weight(layer_weight.T, scales.T, zeros.T, g_idx)
        return self.pack_on_device(int_weight, zeros.T)

    def forward(self, inputs):
        out = QuantLinearTorchFunction().apply(
            inputs,
            self.qweight,
            self.scales,
            self.qzeros,
            self.g_idx if self.act_order else None,
            self.bits,
            self.group_size,
            self.in_features,
            self.out_features,
        )
        out = out + self.bias if self.bias is not None else out
        return out
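
For illustration only (not part of this commit), a minimal sketch of wrapping an fp32 nn.Linear in QuantLinearORT. The import path is assumed from the new file's location, and the scales, zero points and shapes are dummy stand-ins for what an AWQ/GPTQ checkpoint would provide.

import torch
from torch import nn

from QEfficient.customop.matmulnbits import QuantLinearORT

in_features, out_features, group_size = 512, 256, 128
n_groups = in_features // group_size

linear = nn.Linear(in_features, out_features, bias=False)
scales = torch.ones(out_features, n_groups, dtype=torch.float16)    # per-(row, group) scales
zeros = torch.full((out_features, n_groups), 8, dtype=torch.int32)  # per-(row, group) zero points
g_idx = torch.tensor([i // group_size for i in range(in_features)], dtype=torch.int32)

qlinear = QuantLinearORT(
    bits=4, group_size=group_size, in_features=in_features, out_features=out_features, bias=False
)
qlinear.pack(linear, scales, zeros, g_idx)  # quantizes the fp32 weights and packs them into uint8 nibbles

x = torch.randn(1, in_features)
y = qlinear(x)                              # dequantizes block-wise and runs a float matmul; shape (1, 256)

pack() quantizes the float weights via quant_weight() and lays them out in the nibble format expected by the com.microsoft::MatMulNBits op; during ONNX export, symbolic() emits that op directly instead of the float fallback.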
2 changes: 1 addition & 1 deletion QEfficient/exporter/export_hf_to_cloud_ai_100.py
@@ -443,7 +443,7 @@ def qualcomm_efficient_converter(
     model_kv = model_kv if model_kv.is_transformed else QEfficient.transform(model_kv) if kv else model_kv
 
     if onnx_dir_path is None:
-        model_card_dir = os.path.join(QEFF_MODELS_DIR, str(model_name))
+        model_card_dir = os.path.join(QEFF_MODELS_DIR, str(model_kv.model_card_name))
         onnx_dir_path = os.path.join(model_card_dir, "onnx")
         os.makedirs(onnx_dir_path, exist_ok=True)