From 47658da2de2886cf463b374e16ccaab88912e719 Mon Sep 17 00:00:00 2001 From: n1ck-guo Date: Wed, 24 Jul 2024 01:27:49 -0400 Subject: [PATCH 1/4] porting sq to autoround Signed-off-by: n1ck-guo --- auto_round/low_cpu_mem/utils.py | 6 +- auto_round/quantizer.py | 4 +- auto_round/smooth_quant/__init__.py | 16 + auto_round/smooth_quant/auto_alpha.py | 657 ++++++++++++++++++++++++ auto_round/smooth_quant/calibration.py | 103 ++++ auto_round/smooth_quant/graph_trace.py | 222 ++++++++ auto_round/smooth_quant/smooth_quant.py | 581 +++++++++++++++++++++ auto_round/smooth_quant/utils.py | 483 +++++++++++++++++ 8 files changed, 2068 insertions(+), 4 deletions(-) create mode 100644 auto_round/smooth_quant/__init__.py create mode 100644 auto_round/smooth_quant/auto_alpha.py create mode 100644 auto_round/smooth_quant/calibration.py create mode 100644 auto_round/smooth_quant/graph_trace.py create mode 100644 auto_round/smooth_quant/smooth_quant.py create mode 100644 auto_round/smooth_quant/utils.py diff --git a/auto_round/low_cpu_mem/utils.py b/auto_round/low_cpu_mem/utils.py index 6b02c2a5..c1c51d3b 100644 --- a/auto_round/low_cpu_mem/utils.py +++ b/auto_round/low_cpu_mem/utils.py @@ -387,10 +387,12 @@ def _layer_wise_to(module, name, device_or_dtype): return module.ori_to(device_or_dtype) elif len(module._modules) == 0: # skip method type - if len(module._parameters) == 0 or module.weight.device.type != 'meta': + if len(module._parameters) == 0: return module.ori_to(device_or_dtype) else: - for n, _ in module.named_parameters(): + for n, p in module.named_parameters(): + if p.device.type != 'meta': + continue param_name = name + "." + n value = load_value(empty_model, param_name, empty_model.path) dtype = None diff --git a/auto_round/quantizer.py b/auto_round/quantizer.py index 34aa864a..8d3c0c9a 100644 --- a/auto_round/quantizer.py +++ b/auto_round/quantizer.py @@ -149,7 +149,7 @@ def __init__(self, orig_layer, enable_minmax_tuning=True, device='cpu'): weight_dtype = torch.float32 orig_layer_weight = self.orig_layer.weight if not hasattr(self.orig_layer, 'get_weight') \ - else self.orig_layer.get_weight() + else self.orig_layer.get_weight().to(device) self.value = torch.nn.Parameter( reshape_tensor( torch.zeros(self.orig_layer.weight.shape, device=self.device, dtype=weight_dtype), @@ -285,7 +285,7 @@ def __init__(self, orig_layer, enable_minmax_tuning=True, device='cpu'): weight_dtype = torch.float32 self.device = device if hasattr(self.orig_layer, 'get_weight'): - self.weight_t = self.orig_layer.get_weight().t() + self.weight_t = self.orig_layer.get_weight().t().to(self.device) else: self.weight_t = self.orig_layer.weight.t() self.weight_t = self.weight_t.to(self.device) diff --git a/auto_round/smooth_quant/__init__.py b/auto_round/smooth_quant/__init__.py new file mode 100644 index 00000000..db2d9048 --- /dev/null +++ b/auto_round/smooth_quant/__init__.py @@ -0,0 +1,16 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. \ No newline at end of file diff --git a/auto_round/smooth_quant/auto_alpha.py b/auto_round/smooth_quant/auto_alpha.py new file mode 100644 index 00000000..ecba1b2e --- /dev/null +++ b/auto_round/smooth_quant/auto_alpha.py @@ -0,0 +1,657 @@ +# +# -*- coding: utf-8 -*- +# +# Copyright (c) 2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import copy +import json + + +import torch +from .utils import logger + + +import numpy +from tqdm import tqdm + +from .calibration import Calibration +from .utils import * + + +@register_autotune("version1") +class AutoAlpha: + def __init__( + self, + model, + dataloader, + absorb_to_layer, + op_types, + device, + q_func, + example_inputs, + weight_clip=True, + alpha_min=0.3, + alpha_max=0.7, + alpha_step=0.1, + shared_criterion="mean", + init_alpha=0.5, + folding=False, + do_blockwise=False, + n_samples=32, + ): + """Initialize the AutoAlpha tuner with necessary parameters and components.""" + + self.model = model.to("cpu") + self.model.eval() + self.dataloader = dataloader + self.alpha_min = alpha_min + self.alpha_max = alpha_max + self.alpha_step = alpha_step + self.shared_criterion = shared_criterion + self.init_alpha = init_alpha + self.loss_type = "blockwise" if do_blockwise else "model_wise" + self.calib_sample_num = n_samples if n_samples else 32 + self.op_types = op_types + self.absorb_to_layer = absorb_to_layer + self.weight_scale_dict = {} + self.q_func = q_func + self.folding = folding + self.example_inputs = example_inputs + self.max_value_info = {} # to record max values for alpha tune + self.weight_clip = weight_clip[0] if isinstance(weight_clip, tuple) else weight_clip + self.input_maxes = {} + self.input_mins = {} + self.input_maxes_abs = {} + self.device = device + + def tune(self): + """The main entry of auto_alpha + :return: Optimal alpha values and scales based on user-defined recipes.""" + calib = Calibration(self.model, self.dataloader, self.q_func, self.device) + calib_iter = 100 + self.input_mins, self.input_maxes = calib.calibrate(calib_iter, self.op_types) + for key in self.input_mins.keys(): + self.input_maxes_abs[key] = torch.max(torch.abs(self.input_mins[key]), torch.abs(self.input_maxes[key])) + + if not self.folding: + diff_modules = set(self.absorb_to_layer.keys()).difference(self.input_mins.keys()) + for d in diff_modules: + del self.absorb_to_layer[d] + + scale_memo_use = 0 + for key in self.absorb_to_layer: + layer_name = self.absorb_to_layer[key][0] + input_max = self.input_maxes_abs[layer_name] + scale_memo_use += 4 * input_max.shape[0] * len(self.absorb_to_layer[key]) + alpha_space_len = (self.alpha_max - self.alpha_min) / self.alpha_step + 1 + scale_memo_use *= alpha_space_len + self._save_scale = enough_memo_store_scale(self.device, scale_memo_use) + + if self.loss_type == "blockwise": + self.block_names = self.get_blocks() + logger.info("Blockwise auto-tuning will be performed") + module_names = self._get_sq_layer_names() + block_names, self.block_to_module = self.block_names, {} + for block in block_names: + self.block_to_module[block] = [] + for module in module_names: + checked = False + for block in block_names: + if block + "." in module: + self.block_to_module[block].append(module) + checked = True + if not checked: + self.block_to_module[module] = [module] + self.block_names = list(self.block_to_module.keys()) + logger.info(f"Blockwise auto-tuning: {len(self.block_names)} blocks found") + logger.debug(f"Blockwise auto-tuning blocks info: {self.block_to_module}") + return self._auto_tune_alpha_blockwise() + else: + return self._auto_tune_alpha() + + def get_blocks(self): + """Obtain a list of blocks in block-wise tuning mode.""" + block_names = [] + for n, m in self.model.named_modules(): + if hasattr(type(m), "__name__") and "ModuleList" in type(m).__name__: + for nn, mm in m.named_children(): + block_name = n + "." + nn + block_names.append(block_name) + break + return block_names + + def _add_blockwise_observer(self, block_modules): + """ + :param block_modules: the block modules which the observer will insert to + :return: + """ + self.blockwise_hook_handles = [] + for key in block_modules.keys(): + hook_func = self._save_blockwise_hook(key) + hook_handle = block_modules[key].register_forward_hook(hook_func) + self.blockwise_hook_handles.append(hook_handle) + + def _save_blockwise_hook(self, name): + """A forward hook to save inputs/outputs of a block + :param name: the block name + :return: A hook function.""" + + def save_blockwise_hook(module, inputs, outputs): + self.block_inputs[name] = inputs[0] + self.block_outputs[name] = outputs[0] + + return save_blockwise_hook + + def _get_all_hook_module_names(self): + """Obtain all the modules that could be hooked based on given op_types.""" + module_names = [] + for n, module in self.model.named_modules(): + if isinstance(module, tuple(self.op_types)): + module_names.append(n) + return module_names + + def _update_scales_for_auto(self, absorb_scales, weight_scales): + """Apply activation and weight scales to the model.""" + for key in self.absorb_to_layer.keys(): + layer_names = self.absorb_to_layer[key] + for layer_name in layer_names: + layer = get_module(self.model, layer_name) + input_scale = absorb_scales[key] + weight_scale = weight_scales[layer_name] + input_scale = reshape_scale_as_input(layer, input_scale) + weight_scale = reshape_scale_as_weight(layer, weight_scale) + layer.update_scale(input_scale, weight_scale) ##FIXME + + def _change_qdq_for_auto(self, enable=True): + """Change the option for qdq.""" + module_names = self._get_all_hook_module_names() + for name in module_names: + name = name.split(".orig_layer")[0] + module = get_module(self.model, name) + if not hasattr(module, "orig_layer"): # skip module if it's not used in calibration + continue + if enable: + module.enable_quant() + else: + module.disable_quant() + + def _qdq_model_wrapper_for_auto(self, save_q_input=False): + """Wrapper all the module with qdq + :return:""" + module_names = self._get_all_hook_module_names() + self.to_unwrap_module_names = module_names + for name in module_names: + if name not in self.input_mins: # skip module if it's not used in calibration + continue + module = get_module(self.model, name) + new_module = WrapperLayer(module, self.input_mins[name], self.input_maxes[name], save_q_input=save_q_input) + set_module(self.model, name, new_module) + + def _qdq_model_unwrapper_for_auto(self): + """Unwrapper all the module with qdq + :return:""" + module_names = self.to_unwrap_module_names + for name in module_names: + module = get_module(self.model, name) + if not hasattr(module, "orig_layer"): # skip module if it's not used in calibration + continue + set_module(self.model, name, module.orig_layer) + + def _cal_scales(self, absorb_to_layer, input_maxes, alpha=0.5): + """Cal the adjust scales + :param absorb_to_layer: A dict mapping absorb layer to smooth quantized layer + :param input_maxes: The channel-wise input max info for layers + :param alpha: Alpha value to balance the quantization difficulty of activation and weight, a float of a dict + :return:""" + absorb_to_input_maxes = {} + for key in absorb_to_layer.keys(): + layer_name = absorb_to_layer[key][0] + absorb_to_input_maxes[key] = input_maxes[layer_name] + + weight_scales_info = {} + absorb_scales_info = {} + for index, key in enumerate(absorb_to_layer.keys()): + alpha_tmp = alpha[key] if isinstance(alpha, dict) else alpha + if alpha_tmp < 0: + scale = torch.ones((1), device=self.device) + else: + input_max = absorb_to_input_maxes[key] + layer_names = absorb_to_layer[key] + weights = [] + for layer_name in layer_names: + weight = reshape_in_channel_to_last(layer_name, self.model) + weights.append(weight) + + weight_max_per_channel = torch.max(torch.abs(torch.cat(weights, dim=0)), dim=0)[0] + if self.weight_clip: + weight_max_per_channel = weight_max_per_channel.clamp(min=1e-5) + + if self._save_scale: + if key in self.weight_scale_dict and alpha_tmp in self.weight_scale_dict[key]: + scale = self.weight_scale_dict[key][alpha_tmp] + else: + scale = cal_scale(input_max, weights, alpha_tmp) + else: + scale = cal_scale(input_max, weights, alpha_tmp) + + absorb_scales_info[key] = 1.0 / scale + absorb_scales_info[key][scale == 0] = 0 + layer_names = absorb_to_layer[key] + for layer_name in layer_names: + ##self._scale_layer_weight(layer_name, scale) + weight_scales_info[layer_name] = scale + if self._save_scale: + if layer_name not in self.weight_scale_dict: + self.weight_scale_dict[layer_name] = {} + self.weight_scale_dict[layer_name][alpha_tmp] = scale + return absorb_scales_info, weight_scales_info + + def _get_auto_loss(self, output, output_q, loss_type="abs", loss_alpha=1.0): + """Get the loss for auto tuning + :param output: Fp32 output for one layer + :param output_q: Quant output for one layer + :param loss_type: The type of loss + :param loss_alpha: Loss alpha i for mean scale error + :return: A tensor of the loss.""" + if len(output.shape) <= 2: + max_value = torch.max(torch.abs(output)) + else: + output = output.reshape(output.shape[0], -1) + output_q = output_q.reshape(output_q.shape[0], -1) + max_value = torch.max(torch.abs(output), dim=-1).values.unsqueeze(-1) + max_value = torch.clip(max_value, 1e-5) + output = output / max_value ##FIXME need copy not replace + output_q = output_q / max_value + if loss_type == "abs": + return torch.sum(torch.pow(torch.abs(output - output_q), 0.5)) + else: + return torch.sum((output - output_q) ** 2) + + def _get_sq_layer_names(self): + """Get all the layers that could be smooth quanted + :return: All the sq layer names.""" + ##TODO this may not fit for folding=False + module_names = [] + for key in self.absorb_to_layer: + module_names += self.absorb_to_layer[key] + return module_names + + def _get_best_alpha(self, absorb_to_layer, loss_alphas, shared_criterion): + """Obtain the optimal alpha values based on shared criterion and loss values recorded in auto-tuning step. + + :return: A dict of layerwise alpha values. + """ + + def dict_to_list(dic): + res = [] + for key in dic.keys(): + res.append((key, dic[key])) + return res + + best_alpha = {} + for ln_name in absorb_to_layer.keys(): + layer_names = absorb_to_layer[ln_name] + cur_shared_criterion = shared_criterion + if len(layer_names) == 1: + cur_shared_criterion = "min" + if cur_shared_criterion == "mean": + loss_tmp = {} + for alpha in loss_alphas[layer_names[0]].keys(): + if alpha not in loss_tmp.keys(): + loss_tmp[alpha] = 0 + for layer_name in layer_names: + loss_tmp[alpha] += loss_alphas[layer_name][alpha] + res = dict_to_list(loss_tmp) + res.sort(key=lambda x: x[1]) + + best_alpha[ln_name] = float(res[0][0]) + + elif cur_shared_criterion == "min" or cur_shared_criterion == "max": + tmp_best_alpha = [] + for layer_name in layer_names: + res = dict_to_list(loss_alphas[layer_name]) + res.sort(key=lambda x: x[1]) + tmp_best_alpha.append(float(res[0][0])) + if cur_shared_criterion == "min": + best_alpha[ln_name] = min(tmp_best_alpha) + else: + best_alpha[ln_name] = max(tmp_best_alpha) + + else: + raise NotImplementedError + return best_alpha + + def _get_one_batch_auto_loss(self, input, alpha_space, orig_best_alpha, input_maxes): + """Calculate the losses for all alpha values given an input. + + :return: A dict of op-wise loss values with respect to alpha values. + """ + self._change_qdq_for_auto(enable=False) + module_names = self._get_sq_layer_names() + forward_wrapper(self.model, input, self.device) ##disable quant and get fp32 output + + fp32_output = {} + for name in module_names: + module = get_module(self.model, name) + fp32_output[name] = module.output + module.output = None + self._change_qdq_for_auto(enable=True) + absorb_input_scales, weight_scales = self._cal_scales(self.absorb_to_layer, input_maxes, orig_best_alpha) + self._update_scales_for_auto(absorb_input_scales, weight_scales) + forward_wrapper(self.model, input, self.device) ##save quant_input + for mod_name in module_names: # save fp32 values + mod = get_module(self.model, mod_name) + if mod_name in self.fp32_output_val: + self.fp32_output_val[mod_name].append(torch.norm(mod.output)) + else: + self.fp32_output_val[mod_name] = [torch.norm(mod.output)] + del mod + + loss_alphas = {} + for name in module_names: + module = get_module(self.model, name) + loss = self._get_auto_loss(fp32_output[name], module.output) + cur_alpha = orig_best_alpha + if isinstance(orig_best_alpha, dict): + cur_alpha = orig_best_alpha[name] + key_name = str(cur_alpha) + loss_alphas[name] = {key_name: loss} + # for name in module_names: + # loss_alphas[name]={} + for alpha in alpha_space: + absorb_input_scales, weight_scales = self._cal_scales(self.absorb_to_layer, input_maxes, alpha) + self._update_scales_for_auto(absorb_input_scales, weight_scales) + for name in module_names: + losses = loss_alphas[name] + if str(alpha) in losses.keys(): + continue + module = get_module(self.model, name) + output = module.q_dq_forward(module.q_input, module.input_scale, module.weight_scale) + loss = self._get_auto_loss(fp32_output[name], output) + loss_alphas[name][str(alpha)] = loss + return loss_alphas + + def _get_one_batch_auto_loss_blockwise(self, input, alpha_space, orig_best_alpha, input_maxes): + """Calculate the losses for all alpha values given an input in blockwise tuning mode. + + :return: A dict of blockwise-wise loss values with respect to alpha values. + """ + self._change_qdq_for_auto(enable=False) + module_names = self._get_sq_layer_names() + + block_modules = {} + for key in self.block_names: + block_modules[key] = get_module(self.model, key) + self._add_blockwise_observer(block_modules) + + forward_wrapper(self.model, input, self.device) ##disable quant and get fp32 output + + fp32_output = {} + for block_name in self.block_names: + fp32_output[block_name] = self.block_outputs[block_name] + self._change_qdq_for_auto(enable=True) + absorb_input_scales, weight_scales = self._cal_scales(self.absorb_to_layer, input_maxes, orig_best_alpha) + self._update_scales_for_auto(absorb_input_scales, weight_scales) + forward_wrapper(self.model, input, self.device) ##save quant_input + for mod_name in module_names: # save fp32 values + mod = get_module(self.model, mod_name) + if mod_name in self.fp32_output_val: + self.fp32_output_val[mod_name].append(torch.norm(mod.output)) + else: + self.fp32_output_val[mod_name] = [torch.norm(mod.output)] + del mod + + loss_alphas = {} + + for block_name in self.block_names: + block = get_module(self.model, block_name) + loss = self._get_auto_loss(fp32_output[block_name], self.block_outputs[block_name]) + cur_alpha = orig_best_alpha + if isinstance(orig_best_alpha, dict): + cur_alpha = orig_best_alpha[self.block_to_module[block_name][0]] + key_name = str(cur_alpha) + loss_alphas[block_name] = {key_name: loss} + # for name in module_names: + # loss_alphas[name]={} + for alpha in alpha_space: + absorb_input_scales, weight_scales = self._cal_scales(self.absorb_to_layer, input_maxes, alpha) + self._update_scales_for_auto(absorb_input_scales, weight_scales) + + for block_name in self.block_names: + losses = loss_alphas[block_name] + if str(alpha) in losses.keys(): + continue + block = get_module(self.model, block_name) + block_copy = copy.deepcopy(block) + for name in self.block_to_module[block_name]: + if name == block_name and len(self.block_to_module[block_name]) == 1: + module, module_copy = block, block_copy + else: + module = get_module(block, name) + module_copy = copy.deepcopy(module) + if module.weight_scale is not None: + module_copy.orig_layer.weight *= module.weight_scale + q_dq_weight = quant_dequant_w_v1(module_copy.orig_layer) + module_copy.orig_layer.weight.data.copy_(q_dq_weight) + module_copy.do_blockwise = True + if not (name == block_name and len(self.block_to_module[block_name]) == 1): + set_module(block_copy, name, module_copy) + try: + output = block_copy(self.block_inputs[block_name])[0] + except: # Llama model decoder_layer forward requires position_id + position_ids = torch.arange(self.block_inputs[block_name].size()[1]) + position_ids = position_ids.view(self.block_inputs[block_name].size()[0], -1) + output = block_copy(self.block_inputs[block_name], position_ids=position_ids)[0] + loss = self._get_auto_loss(fp32_output[block_name], output) + loss_alphas[block_name][str(alpha)] = loss + del block_copy # release memory + return loss_alphas + + def opwise_rank(self, loss_alphas, best_alphas): + """Rank the final losses of ops based on their ratio with respect to op output norm. + + :return: + """ + max_op, max_ratio, max_key = "", 0, "" + ratio_info = {} + for key in self.absorb_to_layer: + for op_name in self.absorb_to_layer[key]: + fp32_norm, loss_ = ( + torch.sum(torch.stack(self.fp32_output_val[op_name])), + loss_alphas[op_name][str(best_alphas[key])], + ) + ratio = loss_ / fp32_norm + max_op = op_name if ratio > max_ratio else max_op + max_key = key if ratio > max_ratio else max_key + max_ratio = max(ratio, max_ratio) + ratio_info[op_name] = ratio + logger.debug( + f"final loss: {op_name}: {loss_}; @alpha {best_alphas[key]}; \ + fp32_output norm: {fp32_norm}; ratio: {ratio}" + ) + import operator + + ratio_info = dict(sorted(ratio_info.items(), key=operator.itemgetter(1), reverse=True)) + for key in list(ratio_info.keys()): + logger.debug(f"sorted opname-ratio: {key}: {ratio_info[key]}") + if max_op != "": + logger.debug( + f"max loss: {max_op}: {loss_alphas[max_op][str(best_alphas[max_key])]} @alpha {best_alphas[max_key]}\ + fp32_output norm: {torch.sum(torch.stack(self.fp32_output_val[max_op]))}; ratio: {max_ratio}" + ) + return None + + def default_tune_setup(self): + """Setup default auto-tune settings. + + :return: A dict of op-wise loss values with respect to alpha values. + """ + round_num = max( # Initialize the alpha search space + len(str(self.alpha_min).split(".")[1]), + len(str(self.alpha_max).split(".")[1]), + len(str(self.alpha_step).split(".")[1]), + ) + self.alpha_space = numpy.round( + numpy.arange(self.alpha_min, self.alpha_max + self.alpha_step, self.alpha_step), round_num + ).tolist() + ##wrapper new module + self._qdq_model_wrapper_for_auto(save_q_input=True) + + absorb_input_scales, weight_scales = self._cal_scales( + self.absorb_to_layer, self.input_maxes_abs, self.init_alpha + ) + self._update_scales_for_auto(absorb_input_scales, weight_scales) + return absorb_input_scales, weight_scales + + def _auto_tune_alpha(self): + """Perform alpha-tuning to obtain layer-wise optimal alpha values and adjust parameters accordingly.""" + logger.info("Start alpha tuning") + + absorb_input_scales, weight_scales = self.default_tune_setup() + + total_cnt, tmp_cnt = 0, 0 + alpha_update_iter, tune_cnt = 0, 4 + # multiply_factor is used to combine samples to calib_sample_num // 4 before summarizing the best alpha + multiply_factor = ( + self.calib_sample_num // tune_cnt if self.calib_sample_num >= tune_cnt else self.calib_sample_num + ) + self.fp32_output_val = {} + best_alphas = self.init_alpha + + if not self.dataloader: + logger.info(f"Auto-tuning failed due to no dataloader, using {best_alphas} instead.") + self._qdq_model_unwrapper_for_auto() + return best_alphas + bar = tqdm(self.dataloader, total=self.calib_sample_num, desc="auto tune alpha") + for input in bar: + if isinstance(input, tuple) or isinstance(input, list): + if len(input) == 2: + input, _ = input # Extract input when both input and label are yielded by dataloader. + loss_alphas = {} + best_alphas_per_module = best_alphas + if isinstance(best_alphas, dict): + for key in self.absorb_to_layer.keys(): + layer_names = self.absorb_to_layer[key] + for layer_name in layer_names: + best_alphas_per_module[layer_name] = best_alphas_per_module[key] + loss_tmp = self._get_one_batch_auto_loss( + input, self.alpha_space, best_alphas_per_module, self.input_maxes_abs + ) + if loss_alphas == {}: + loss_alphas = loss_tmp + else: + for key in loss_alphas.keys(): + cur_loss = loss_alphas[key] + for alpha_key in cur_loss.keys(): + cur_loss[alpha_key] += loss_tmp[key][alpha_key] + total_cnt += self.dataloader.batch_size + tmp_cnt += self.dataloader.batch_size + if tmp_cnt // multiply_factor >= 1: + alpha_update_iter += 1 + tmp_cnt = 0 + best_alphas = self._get_best_alpha(self.absorb_to_layer, loss_alphas, self.shared_criterion) + for key in best_alphas.keys(): + logger.info(f"Auto alpha update iter: {alpha_update_iter}, {key}: {best_alphas[key]}") + absorb_input_scales, weight_scales = self._cal_scales( + self.absorb_to_layer, self.input_maxes_abs, best_alphas + ) + self._update_scales_for_auto(absorb_input_scales, weight_scales) + # does not need to reset the weight_scale_dict, because use the weight of ori_layer, no change + # self.weight_scale_dict = {} + if total_cnt >= self.calib_sample_num: + break + + best_alphas = self._get_best_alpha(self.absorb_to_layer, loss_alphas, self.shared_criterion) + for key in best_alphas.keys(): + logger.info(f"Final alpha {key}:{best_alphas[key]}") + + self.opwise_rank(loss_alphas, best_alphas) + self._qdq_model_unwrapper_for_auto() + logger.info("auto tuning done") + + return best_alphas + + def _auto_tune_alpha_blockwise(self): + """Perform blockwise-alpha-tuning to obtain optimal alpha values and adjust parameters accordingly.""" + logger.info("Start block-wise alpha tuning") + self.block_inputs, self.block_outputs = {}, {} + + absorb_input_scales, weight_scales = self.default_tune_setup() + + total_cnt, tmp_cnt = 0, 0 + alpha_update_iter, tune_cnt = 0, 4 + # multiply_factor is used to combine samples to calib_sample_num // 4 before summarizing the best alpha + multiply_factor = ( + self.calib_sample_num // tune_cnt if self.calib_sample_num >= tune_cnt else self.calib_sample_num + ) + self.fp32_output_val = {} + best_alphas = self.init_alpha + + if not self.dataloader: + logger.info(f"Auto-tuning failed due to no dataloader, using {best_alphas} instead.") + self._qdq_model_unwrapper_for_auto() + return best_alphas + bar = tqdm(self.dataloader, total=self.calib_sample_num, desc="auto tune alpha") + for input in bar: + if isinstance(input, tuple): # Extract input when both input and label are yielded by dataloader. + input = input[0] + loss_alphas = {} + best_alphas_per_module = best_alphas + if isinstance(best_alphas, dict): + for key in self.absorb_to_layer.keys(): + layer_names = self.absorb_to_layer[key] + for layer_name in layer_names: + best_alphas_per_module[layer_name] = best_alphas_per_module[key] + loss_tmp = self._get_one_batch_auto_loss_blockwise( + input, self.alpha_space, best_alphas_per_module, self.input_maxes_abs + ) + if loss_alphas == {}: + for block_name in self.block_names: + for key in self.block_to_module[block_name]: + loss_alphas[key] = loss_tmp[block_name] + else: + for block_name in self.block_names: + for key in self.block_to_module[block_name]: + cur_loss = loss_alphas[key] + for alpha_key in cur_loss.keys(): + cur_loss[alpha_key] += loss_tmp[block_name][alpha_key] + + total_cnt += self.dataloader.batch_size + tmp_cnt += self.dataloader.batch_size + if tmp_cnt // multiply_factor >= 1: + alpha_update_iter += 1 + tmp_cnt = 0 + best_alphas = self._get_best_alpha(self.absorb_to_layer, loss_alphas, self.shared_criterion) + for key in best_alphas.keys(): + logger.info(f"Auto alpha update iter: {alpha_update_iter}, {key}: {best_alphas[key]}") + absorb_input_scales, weight_scales = self._cal_scales( + self.absorb_to_layer, self.input_maxes_abs, best_alphas + ) + self._update_scales_for_auto(absorb_input_scales, weight_scales) + # does not need to reset the weight_scale_dict, because use the weight of ori_layer, no change + # self.weight_scale_dict = {} + if total_cnt >= self.calib_sample_num: + break + + best_alphas = self._get_best_alpha(self.absorb_to_layer, loss_alphas, self.shared_criterion) + for key in best_alphas.keys(): + logger.info(f"Final alpha {key}:{best_alphas[key]}") + + self.opwise_rank(loss_alphas, best_alphas) + self._qdq_model_unwrapper_for_auto() + logger.info("block-wise auto tuning done") + + return best_alphas diff --git a/auto_round/smooth_quant/calibration.py b/auto_round/smooth_quant/calibration.py new file mode 100644 index 00000000..56dc7ac7 --- /dev/null +++ b/auto_round/smooth_quant/calibration.py @@ -0,0 +1,103 @@ +# Copyright (c) 2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import json + +import torch +from .utils import * + + +class Calibration: + def __init__(self, model, dataloder=None, q_func=None, device="cpu"): + self.model = model + self.dataloader = dataloder + self.q_func = q_func + self.device = device + + @torch.no_grad() + def _save_input_pc_hook(self, name): + """A forward hook to save input max of a module + :param name: the module name + :return: A hook function.""" + + def save_input_hook(module, inputs, outputs): + input = inputs[0] + ##TODO check input channel is correct + if len(module.weight.shape) == 4: ##conv3d or conv1d not supported now, need better way + input = input.permute(0, 2, 3, 1) + input = input.reshape(-1, input.shape[-1]) + max_tensor = torch.max(input, dim=0)[0] + min_tensor = torch.min(input, dim=0)[0] + if name not in self.input_maxes.keys(): + self.input_mins[name], self.input_maxes[name] = min_tensor, max_tensor + else: + self.input_mins[name] = torch.min(self.input_mins[name], min_tensor) + self.input_maxes[name] = torch.max(self.input_maxes[name], max_tensor) + + return save_input_hook + + @torch.no_grad() + def _add_min_max_observer(self, modules): + """ + :param modules: the modules which the observer will insert to + :return: + """ + self.hook_handles = [] + for key in modules.keys(): + hook_func = self._save_input_pc_hook(key) + hook_handle = modules[key].register_forward_hook(hook_func) + self.hook_handles.append(hook_handle) + + @torch.no_grad() + def _remove_observer(self): + """Remove the observer from the model + :return:""" + for hook_handle in self.hook_handles: + hook_handle.remove() + + @torch.no_grad() + def _dump_min_max(self, calib_iter=100): + """Dump min max per channel information, the min max value will be saved in input_maxes attribute + :param calibration_method: only support min_max currently + :param calib_iter: Sample size for calibration + :return:""" + logger.info("Calibrating...") + if self.q_func: + self.q_func(self.model) + else: + assert self.dataloader, "Please set dataloader for calibration." + model_forward(self.model, self.dataloader, calib_iter, self.device) + + @torch.no_grad() + def calibrate(self, calib_iter, op_types=[torch.nn.Conv2d, torch.nn.Linear]): ##TODO transformers.conv1d + """ + :param absorb_to_layer: A dict,key is the absorb layer, val is a list of the to be smoothed layer + :param calib_iter: Data size for calibration + :return: A dict that saved the layer name and the channel-wise max value info + """ + ##hook all the module + self.input_mins = {} + self.input_maxes = {} + + hook_modules = {} + for n, module in self.model.named_modules(): + if isinstance(module, tuple(op_types)): + hook_modules[n] = module + + self._add_min_max_observer(hook_modules) + + self._dump_min_max(calib_iter=calib_iter) + self._remove_observer() + return self.input_mins, self.input_maxes diff --git a/auto_round/smooth_quant/graph_trace.py b/auto_round/smooth_quant/graph_trace.py new file mode 100644 index 00000000..0c1ca009 --- /dev/null +++ b/auto_round/smooth_quant/graph_trace.py @@ -0,0 +1,222 @@ +# +# -*- coding: utf-8 -*- +# +# Copyright (c) 2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import torch +from collections import UserDict + +from .utils import get_module, move_input_to_device, logger + + +def get_parent(node, all_parents=False): + if node.inputs() is None: + return None + elif len(list(node.inputs())) == 0: + return None + if not all_parents: + return list(node.inputs())[0].node() + else: + return list(node.inputs()) + + +class GraphTrace: + """""" + + def __init__(self): + self.supported_torch_module_to_aten = { + "Linear": "aten::linear", + "Conv2d": "aten::_convolution", + "ConvTranspose2d": "aten::_convolution", + "LayerNorm": "aten::layer_norm", + "BatchNorm2d": "aten::batch_norm", + "GroupNorm": "aten::group_norm", + "InstanceNorm2d": "aten::instance_norm", + "LlamaRMSNorm": "aten::mul", + "T5LayerNorm": "aten::mul", + "LPLayerNorm": "aten::layer_norm", ##mpt_chat + } + + ##TODO potential bug, need to check only have one bug + ##TODO, must satisfy af(x)=f(ax),current skip layer may be incomplete + self.skip_ops_to_find_absorb = ["aten::to", "aten::relu", "aten::leaky_relu", "aten::hardtanh"] + + self.could_absorb_layers = [ + "aten::layer_norm", + "aten::batch_norm", + "aten::linear", + "aten::_convolution", + "aten::group_norm", + "aten::instance_norm", + "aten::mul", + ] ##TODO,support more norm + + def trace(self, model, dummy_input): + traced_model = None + optimize_numerics = False + orig_device = str(next(model.parameters()).device) + if orig_device != "cpu" and orig_device != "meta": # pragma: no cover + model = model.to("cpu") + dummy_input = move_input_to_device(dummy_input, "cpu") + if isinstance(dummy_input, dict) or isinstance(dummy_input, UserDict): + try: + traced_model = torch.jit.trace( + model, example_kwarg_inputs=dict(dummy_input), strict=False, check_trace=False + ) + traced_model = torch.jit.freeze(traced_model.eval(), optimize_numerics=optimize_numerics) + except Exception as e: + logger.warning(e) + logger.warning("Jit trace in GraphTrace failed, absorb layer detection is skipped") + else: + try: + traced_model = torch.jit.trace(model, dummy_input, strict=False) + traced_model = torch.jit.freeze(traced_model.eval(), optimize_numerics=optimize_numerics) + except: + try: + traced_model = torch.jit.trace(model, dummy_input[0], strict=False) + traced_model = torch.jit.freeze(traced_model.eval(), optimize_numerics=optimize_numerics) + except Exception as e: + logger.warning(e) + logger.warning("Jit trace in GraphTrace failed, absorb layer detection is skipped") + model = model.to(orig_device) + return traced_model + + def get_nodes(self, traced_model, op_types=["Linear"]): + if isinstance(op_types, str): + op_types = [op_types] + nodes = [] + for node in traced_model.graph.nodes(): + node_type = node.kind() + for op_type in op_types: + if node_type == op_type: + nodes.append((node, op_type)) + break + return nodes + + def get_prev_absorb_layer(self, nodes): + prev_absorb_layer = [] + for node in nodes: + parent = get_parent(node) + while 1: + if parent.kind() in self.skip_ops_to_find_absorb: + parent = get_parent(parent) + continue + if parent.kind() in self.could_absorb_layers: + parent_out_kinds = [] + for val_user in list(parent.outputs())[0].uses(): + next_node = val_user.user + parent_out_kinds.append(next_node.kind()) + parent_out_kinds = set(parent_out_kinds) + parent_out_kinds.discard("aten::size") + + if parent_out_kinds == parent_out_kinds.intersection(self.could_absorb_layers): + prev_absorb_layer.append(parent) + elif parent_out_kinds.intersection(self.skip_ops_to_find_absorb): + res = self.skip_op_absorb_helper(parent) + prev_absorb_layer.append(parent) if res else prev_absorb_layer.append(None) + else: # When parent to multiple ops, sq transformation could be wrong. + prev_absorb_layer.append(None) + else: + prev_absorb_layer.append(None) + break + return prev_absorb_layer + + def skip_op_absorb_helper(self, parent_node): + for val_user in list(parent_node.outputs())[0].uses(): + next_node = val_user.user + if next_node.kind() == "aten::size": + continue + elif next_node.kind() in self.could_absorb_layers: + continue + elif next_node.kind() in self.skip_ops_to_find_absorb: + node_res = self.skip_op_absorb_helper(next_node) + if not node_res: + return False + else: + return False + return True + + def mapping_torch_module_to_aten(self, op_types): + res = [] + for op in op_types: + if op not in self.supported_torch_module_to_aten.keys(): + logger.warning(f"{op} is not supported in smooth quant, ignoring...") + continue + res.append(self.supported_torch_module_to_aten[op]) + res = list(set(res)) + return res + + def _check_valid_conv(self, module): + """Remove group conv except depthwise conv + :param module: + + :return: + """ + if not isinstance(module, torch.nn.Conv2d): + return True + if module.groups > 1: + if module.in_channels == module.out_channels and module.groups == module.in_channels: + return True + else: + return False + return True + + def get_absorb_to_layer(self, model, example_input, op_types, skip_unsupported_layers=True): + traced_model = self.trace(model, example_input) + if traced_model is None: + return None, None + + aten_op_types = self.mapping_torch_module_to_aten(op_types) + nodes_types = self.get_nodes(traced_model, aten_op_types) + nodes = [node_type[0] for node_type in nodes_types] + nodes_prev_absorb = self.get_prev_absorb_layer(nodes) + absorb_to_layer = {} + no_absorb_layers = [] + for index, absorb in enumerate(nodes_prev_absorb): + if absorb is None: + no_absorb_layers.append(".".join(nodes[index].scopeName().split("/")[-1].split(".")[1:])) + continue + node = nodes[index] + layer_name = ".".join(node.scopeName().split("/")[-1].split(".")[1:]) + absorb_name = ".".join(absorb.scopeName().split("/")[-1].split(".")[1:]) + if layer_name == "" or absorb_name == "": + continue + if absorb_name in absorb_to_layer.keys(): + absorb_to_layer[absorb_name].append(layer_name) + else: + absorb_to_layer[absorb_name] = [layer_name] + if skip_unsupported_layers: + absorb_to_layer = self.remove_unsupported_layers(model, absorb_to_layer, no_absorb_layers) + return absorb_to_layer, no_absorb_layers + + def remove_unsupported_layers(self, model, absorb_to_layer, no_absorb_layers): + res = {} + for key in absorb_to_layer.keys(): + absorb_layer = get_module(model, key) + layer_type = absorb_layer.__class__.__name__ + if layer_type not in self.supported_torch_module_to_aten.keys(): + no_absorb_layers.extend(absorb_to_layer[key]) + continue + supported = True + for layer_name in absorb_to_layer[key]: + layer = get_module(model, layer_name) + layer_type = layer.__class__.__name__ + if (layer_type not in self.supported_torch_module_to_aten.keys()) or not self._check_valid_conv(layer): + supported = False + no_absorb_layers.extend(absorb_to_layer[key]) + break + if supported: + res[key] = absorb_to_layer[key] + return res diff --git a/auto_round/smooth_quant/smooth_quant.py b/auto_round/smooth_quant/smooth_quant.py new file mode 100644 index 00000000..43a003d1 --- /dev/null +++ b/auto_round/smooth_quant/smooth_quant.py @@ -0,0 +1,581 @@ +# +# -*- coding: utf-8 -*- +# +# Copyright (c) 2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import copy + +import torch +import numpy + +from .calibration import Calibration +from .graph_trace import GraphTrace +from .utils import * + + +class TorchSmoothQuant: + """Fake input channel quantization, for more details please refer to + [1] SmoothQuant: Accurate and Efficient + Post-Training Quantization for Large Language Models + [2] SPIQ: Data-Free Per-Channel Static Input Quantization + Currently, we only handle the layers whose smooth scale could be absorbed, we will support other layers later. + + We only support inplace mode which means the model weights will be changed, you can call recover function + to recover the weights if needed + """ + + def __init__(self, model, dataloader=None, example_inputs=None, q_func=None, traced_model=None): + """ + :param model: Torch model :param dataloader: Calibration dataloader :param traced_model: A specific model + shares the same architecture as the model and could be traced by torch.jit. If not supplied, we use model + instead. + """ + self.model = model + if not isinstance(self.model, torch.nn.Module): + return + device, dtype = self._get_device() + self.model = self.model.to(device) + self.model.eval() + self.device = device + self.dtype = dtype + self.dataloader = dataloader + self.example_inputs = example_inputs + self.q_func = q_func + self.input_maxes = {} + self.input_mins = {} + self.input_maxes_abs = {} + self.traced_model = traced_model + if self.traced_model is None: + self.traced_model = self.model + self.weight_scale_info = {} + self.absorb_scales_info = {} + self.insert_mul = False + self.allow_absorb = True + self.record_max_info = False + self.max_value_info = {} # to record max values for alpha tune + self.absorb_to_layer = {} + self.weight_max_lb = 1e-5 ##weight max low bound + self.weight_scale_dict = {} + self.sq_scale_info = {} + self.max_value_info = {} + self.need_calibration = False + + def _get_device(self): + """Get the model device + :return:Model device.""" + for _, p in self.model.named_parameters(): + return p.data.device, p.data.dtype + + def _scale_layer_weight(self, layer_name, scale, alpha=0.5, input_minmax=None): ##input channel + """Scale the layer weights at input channel, depthwise conv output channel + :param layer_name: The layer name + :param scale: The scale to be multiplied + :param alpha: alpha for SQLinearWrapper + :param input_minmax: input_minmax for SQLinearWrapper + :return:""" + layer = get_module(self.model, layer_name) + if self.insert_mul: + from .utils import SQLinearWrapper + + layer = get_module(self.model, layer_name) + if isinstance(layer, SQLinearWrapper): + layer._recover_sq_linear() + set_module(self.model, layer_name, layer.sq_linear) ##recover + else: + new_module = SQLinearWrapper(layer, 1.0 / scale, input_minmax, alpha) + set_module(self.model, layer_name, new_module) + elif self.allow_absorb: + scale = reshape_scale_as_weight(layer, scale) + layer.weight = torch.nn.Parameter(layer.weight * scale) + return scale + + def _absorb_scales(self, layer_name, scale): ##output channel + """Absorb the scale to the layer at output channel + :param layer_name: The module name + :param scale: The scale to be absorbed + :param alpha_key: The alpha passed to SQLinearWrapper + :return:""" + if self.insert_mul or not self.allow_absorb: + return # absorb is updated in SQLinearWrapper in def _scale_layer_weight + + ##if self.allow absorb + layer = get_module(self.model, layer_name) + if layer.__class__.__name__ == "WrapperLayer": + layer = layer.orig_layer + if ( + isinstance(layer, torch.nn.BatchNorm2d) + or isinstance(layer, torch.nn.GroupNorm) + or isinstance(layer, torch.nn.InstanceNorm2d) + ): + if layer.affine: + layer.weight *= scale + layer.bias *= scale + else: + layer.affine = True + weight = torch.ones(layer.num_features, device=self.device, dtype=self.dtype) * scale + layer.weight = torch.nn.Parameter(weight, requires_grad=False) + bias = torch.zeros(layer.num_features, device=self.device, dtype=self.dtype) + layer.bias = torch.nn.Parameter(bias, requires_grad=False) + elif isinstance(layer, torch.nn.LayerNorm): + if layer.elementwise_affine: + layer.weight *= scale + layer.bias *= scale + else: + layer.elementwise_affine = True + weight = torch.ones(layer.num_features, device=self.device, dtype=self.dtype) * scale + layer.weight = torch.nn.Parameter(torch.ones(weight, requires_grad=False)) + bias = torch.zeros(layer.num_features, device=self.device, dtype=self.dtype) + layer.bias = torch.nn.Parameter(bias, requires_grad=False) + + elif isinstance(layer, torch.nn.Conv2d): + ##the order could not be changed + if hasattr(layer, "bias") and (layer.bias is not None): + layer.bias *= scale + scale = scale.view(scale.shape[0], 1, 1, 1) + layer.weight *= scale + + elif isinstance(layer, torch.nn.Linear): + if hasattr(layer, "bias") and (layer.bias is not None): + layer.bias *= scale + scale = scale.view(scale.shape[0], 1) + layer.weight *= scale + + elif layer.__class__.__name__ == "LlamaRMSNorm" or layer.__class__.__name__ == "T5LayerNorm": ##quite tricky + layer.weight *= scale + + else: + logger.warning( + f"found unsupported layer {type(layer)}, try to multiply scale to " + f"weight and bias directly, this may introduce accuracy issue, please have a check " + ) + if hasattr(layer, "weight") and layer.weight is not None: + layer.weight *= scale + if hasattr(layer, "bias") and layer.bias is not None: + layer.bias *= scale + + def _export_sq_info(self, absorb_to_layer, input_maxes, alpha=0.5): + from .utils import SQLinearWrapper + + absorb_to_input_maxes = {} + for key in absorb_to_layer.keys(): + layer_name = absorb_to_layer[key][0] + absorb_to_input_maxes[key] = input_maxes[layer_name] + + for index, key in enumerate(absorb_to_layer.keys()): + alpha_tmp = alpha[key] if isinstance(alpha, dict) else alpha + layer_names = absorb_to_layer[key] + weights = [] + for layer_name in layer_names: + weight = reshape_in_channel_to_last(layer_name, self.model) + weights.append(weight) + weight_max_per_channel = torch.max(torch.abs(torch.cat(weights, dim=0)), dim=0)[0] + + weight_max_per_channel = weight_max_per_channel.clamp(min=self.weight_max_lb) + + input_max = absorb_to_input_maxes[key] + layer_names = absorb_to_layer[key] + # weight_scale = cal_scale(input_max, weights, alpha_tmp) + input_minmax = [self.input_mins[layer_names[0]].to("cpu"), self.input_maxes[layer_names[0]].to("cpu")] + abs_input_max = torch.max(torch.abs(input_minmax[0]), torch.abs(input_minmax[1])) + input_power = torch.pow(abs_input_max, alpha_tmp) + weight_power = torch.pow(weight_max_per_channel, 1 - alpha_tmp) + weight_scale = torch.clip(input_power / weight_power, min=1e-5) + + input_scale = 1.0 / weight_scale + + self.max_value_info[key] = { + "alpha": alpha_tmp, + "input_minmax": input_minmax, + "weight_max": weight_max_per_channel, + "absorbed_layer": layer_names, + } # max_value_info is used for pytorch backend and sq_scale_info is used for ipex backend. + # the input of layers with same absorb layer is the same. + for op_name in layer_names: + module = copy.deepcopy(get_module(self.model, op_name)) + new_module = SQLinearWrapper(module, 1.0 / weight_scale, input_minmax, alpha_tmp) + self.sq_scale_info[op_name] = {} + self.sq_scale_info[op_name] = { + "alpha": alpha_tmp, + "input_scale_for_mul": input_scale.to("cpu"), + "input_scale_after_mul": new_module.scale, + "input_zero_point_after_mul": new_module.zero_point, + "input_dtype": new_module.dtype, + "weight_scale_after_mul": new_module._get_weight_scale(), + } + + def _cal_scales(self, absorb_to_layer, input_maxes, alpha=0.5): + """Cal the adjust scales + :param absorb_to_layer: A dict mapping absorb layer to smooth quantized layer + :param input_maxes: The channel-wise input max info for layers + :param alpha: Alpha value to balance the quantization difficulty of activation and weight, a float of a dict + :return:""" + absorb_to_input_maxes = {} + for key in absorb_to_layer.keys(): + layer_name = absorb_to_layer[key][0] + absorb_to_input_maxes[key] = input_maxes[layer_name] + + weight_scales_info = {} + absorb_scales_info = {} + for index, key in enumerate(absorb_to_layer.keys()): + alpha_tmp = alpha[key] if isinstance(alpha, dict) else alpha + + input_max = absorb_to_input_maxes[key] + layer_names = absorb_to_layer[key] + weights = [] + for layer_name in layer_names: + weight = reshape_in_channel_to_last(layer_name, self.model) + weights.append(weight) + scale = cal_scale(input_max, weights, alpha_tmp) + absorb_scales_info[key] = 1.0 / scale + absorb_scales_info[key][scale == 0] = 0 + layer_names = absorb_to_layer[key] + for layer_name in layer_names: + ##self._scale_layer_weight(layer_name, scale) + weight_scales_info[layer_name] = scale + return absorb_scales_info, weight_scales_info + + def _adjust_parameters(self, absorb_to_layer, input_maxes, alpha=0.5): + """Adjust the weights and biases + :param absorb_to_layer: A dict mapping absorb layer to smooth quantized layer + :param input_maxes: The channel-wise input max info for layers + :param alpha: Alpha value to balance the quantization difficulty of activation and weight, a float of a dict + :return:""" + absorb_scales_info, weight_scales_info = self._cal_scales(absorb_to_layer, input_maxes, alpha) + if not absorb_scales_info or not weight_scales_info: + return weight_scales_info, absorb_scales_info + for index, key in enumerate(absorb_to_layer.keys()): + if isinstance(alpha, float): + alpha_tmp = alpha + elif isinstance(alpha, dict): + alpha_tmp = alpha[key] + absorb_scale = absorb_scales_info[key] + self._absorb_scales(key, absorb_scale) + layer_names = absorb_to_layer[key] + for layer_name in layer_names: + input_minmax = [self.input_mins[layer_names[0]], self.input_maxes[layer_names[0]]] + self._scale_layer_weight(layer_name, weight_scales_info[layer_name], alpha_tmp, input_minmax) + return weight_scales_info, absorb_scales_info + + def _check_need_calibration(self, alpha, percentile, op_types, scales_per_op, calib_iter): + """ + check need calibration or not + :param alpha: current alpha + :param percentile: current percentile + :param op_types: current op_types + :param scales_per_op: current scales_per_op + :param calib_iter:: current scales_per_op + :return: + """ + need_calib = True + from peft import PeftModel + + is_peft, is_auto = isinstance(self.model, PeftModel), alpha == "auto" + if len(self.input_maxes) == 0: ## the first time + need_calib = True + self.alpha = alpha + self.percentile = percentile + self.op_types = op_types + self.scales_per_op = scales_per_op + self.calib_iter = calib_iter + return False if (is_auto and not is_peft) else need_calib + + if ( + self.percentile == percentile + and self.op_types == op_types + and self.scales_per_op == scales_per_op + and self.calib_iter == calib_iter + ): + if isinstance(alpha, float) or self.alpha == "auto": + need_calib = False + + self.alpha, self.percentile, self.calib_iter = alpha, percentile, calib_iter + self.op_types, self.scales_per_op = op_types, scales_per_op + return need_calib + + @torch.no_grad() + def _parse_absorb_to_layers(self, op_types, folding): + str_op_types = [i.__name__ for i in op_types] + self_absorb_layers = {} + if self.insert_mul: + self_absorb_layers = self._get_all_layer_names(op_types) # TODO: only support linear now. + # fetch modules with the same input + group_modules = self._trace(str_op_types, skip_unsupported_layers=False) + if group_modules is not None: + # use one input for qkv + for k, v in group_modules.items(): + for i in v: + if i in self_absorb_layers: + self_absorb_layers.pop(i) + self_absorb_layers[v[0]] = v + logger.debug(f"self_absorb_layers:{self_absorb_layers}") + if self.allow_absorb: + self.absorb_to_layer, no_absorb_layers = self._trace(str_op_types) + if self.absorb_to_layer is None and no_absorb_layers is None: + return None + + # remove self.self_absorb_layers if it exists in self.absorb_to_layer + for k, v in self.absorb_to_layer.items(): + for i in v: + if i in self_absorb_layers: + self_absorb_layers.pop(i) + self.absorb_to_layer.update(self_absorb_layers) + + if self.absorb_to_layer is None and no_absorb_layers is None: + logger.warning( + "sorry, could not trace the model, smooth quant is ignored." + "If you are using huggingface model," + "you could set torchscript to True " + ) + return None + + # Check if input_maxes match self.absorb_to_layer + # (due to self._get_all_layer_names use layer tree instead of forward_path) + if not folding and self.need_calibration: + if len(self.input_mins) == 0: ##there are some modules not used in forward + calib = Calibration(self.model, self.dataloader, self.q_func, self.device) ## + input_mins, input_maxes = calib.calibrate( + 1, op_types + ) ##TODO if using qfunc for calibration, it will calibrate twice + # use qfunc to calibrate, the input min could be used for fixed alpha transformation + self.input_mins = input_mins + self.input_maxes = input_maxes + diff_modules = set(self.absorb_to_layer.keys()).difference(input_mins.keys()) + for d in diff_modules: + del self.absorb_to_layer[d] + return self.absorb_to_layer + + @torch.no_grad() + def transform( + self, + alpha=0.5, + folding=False, + percentile=100, + op_types=[torch.nn.Linear, torch.nn.Conv2d], + scales_per_op=False, + calib_iter=100, + weight_clip=True, + auto_alpha_args={ + "init_alpha": 0.5, + "alpha_min": 0.0, + "alpha_max": 1.0, + "alpha_step": 0.1, + "shared_criterion": "mean", + "n_samples": 32, ##512 for cuda, 128 for cpu? + }, + ): + """The main entry of smooth quant + :param alpha: Alpha value to balance the quantization difficulty of activation and weight, please refer + to the paper for more details + :param folding: whether insert mul(False) or just allow foldable layers(True) for SmoothQuant + :param percentile: Not supported now + :param op_types: The op typed to be smooth quantized + :param scales_per_op: Not supported now + :param calib_iter: Data size for calibration + :param weight_clip: Whether to clip weight_max when calculating scales. + + :param auto_alpha_args: Hyperparameters used to set the alpha search space in SQ auto-tuning. + By default, the search space is 0.0-1.0 with step_size 0.1. + do_blockwise: Whether to do blockwise auto-tuning. + :param init_alpha: A hyperparameter that is used in SQ auto-tuning; by default it is 0.5. + :return: A FP32 model with the same architecture as the orig model but with different weight which will be + benefit to quantization. + """ + + if not isinstance(self.model, torch.nn.Module): + logger.warning("smoothquant is ignored since the model is not a torch module") + return self.model + + if isinstance(alpha, float) and (alpha < 0): + logger.warning("reset alpha to >=0") + alpha = numpy.clip(alpha, 0.0) + + if folding: + self.insert_mul, self.allow_absorb = False, True + else: + self.insert_mul, self.allow_absorb = True, False + self.weight_clip = weight_clip + + self.revert() + self.need_calibration = self._check_need_calibration(alpha, percentile, op_types, scales_per_op, calib_iter) + if self.need_calibration: + self.input_mins, self.input_maxes = {}, {} + self.absorb_to_layer = self._parse_absorb_to_layers( + op_types, folding + ) ##need to forward to check modules not used in forward + if len(self.input_mins) != 0: ##this is from _parse_absorb_to_layers, ugly code to support q_func + input_maxes_abs = {} + for key in self.input_mins.keys(): + input_maxes_abs[key] = torch.max(torch.abs(self.input_mins[key]), torch.abs(self.input_maxes[key])) + if self.q_func: + self.need_calibration = False # Avoid double-calibration in fixed-value alpha SQ. + + if self.absorb_to_layer is None: + logger.warning("empty absorb_to_layer, smoothquant is ignored ") + return self.model + example_inputs = self._get_example_input() + if alpha == "auto": ##TODO need to polish later + from . import auto_alpha + from .utils import TUNERS + + auto_alpha_version = "version1" + auto_alpha_tuner = TUNERS[auto_alpha_version]( + self.model, + self.dataloader, + self.absorb_to_layer, + op_types=op_types, + device=self.device, + q_func=self.q_func, + folding=folding, + example_inputs=self.example_inputs, + **auto_alpha_args, + ) + self.alpha = auto_alpha_tuner.tune() + input_maxes_abs = auto_alpha_tuner.input_maxes_abs + self.input_mins, self.input_maxes = auto_alpha_tuner.input_mins, auto_alpha_tuner.input_maxes + if auto_alpha_tuner.loss_type == "blockwise": + self.block_names = auto_alpha_tuner.block_names + + elif self.need_calibration: + calib = Calibration(self.model, self.dataloader, self.q_func, self.device) + self.input_mins, self.input_maxes = calib.calibrate(calib_iter, op_types) + input_maxes_abs = {} + for key in self.input_mins.keys(): + input_maxes_abs[key] = torch.max(torch.abs(self.input_mins[key]), torch.abs(self.input_maxes[key])) + + if example_inputs is not None: + out_pre_sq = model_forward_per_sample(self.model, example_inputs, self.device) + + if folding: + self._save_scale = False ##TODO remove it later + + if self.record_max_info: + self._export_sq_info(self.absorb_to_layer, input_maxes_abs, self.alpha) + # # max_info is recorded in self.max_value_info + # self._adjust_parameters(self.absorb_to_layer, input_maxes_abs, alpha) + self.model._smoothquant_optimized = False + return self.model + + self.weight_scale_info, self.absorb_scales_info = self._adjust_parameters( + self.absorb_to_layer, input_maxes_abs, self.alpha + ) + self.model._smoothquant_optimized = True + + if example_inputs is not None: + # Check mathematical equivalency + out_post_sq = model_forward_per_sample(self.model, example_inputs, self.device) + if not self.output_is_equal(out_post_sq, out_pre_sq): + logger.warning( + "Mathematical equivelancy of Smoothquant is not preserved. " + "Please kindly report this issue to https://github.com/intel/neural-compressor." + ) + else: + logger.warning(" Could not get example input, equivelancy check is skipped") + + return self.model + + def output_is_equal(self, out1, out2, atol=1e-04): + try: + if isinstance(out1, tuple): + return all(torch.all(torch.isclose(out1[i], out2[i], atol=atol)) for i in range(len(out1))) + elif isinstance(out1, dict): + return all(torch.all(torch.isclose(out1[k], out2[k], atol=atol)) for k in out1.keys()) + elif isinstance(out1, torch.Tensor): + return torch.all(torch.isclose(out1, out2, atol=atol)) + return False + except: + logger.warning( + "Automatically check failed, Please check equivelancy manually " + "between out_pre_sq and out_post_sq if necessary." + ) + return True + + @torch.no_grad() + def revert(self): + """Revert the model weights + :return:""" + for key in self.weight_scale_info: + self._scale_layer_weight(key, 1.0 / self.weight_scale_info[key]) + for key in self.absorb_scales_info: + self._absorb_scales(key, 1.0 / self.absorb_scales_info[key]) + self.weight_scale_info = {} ##clear the data + self.absorb_scales_info = {} + + def _get_all_layer_names(self, op_types=[torch.nn.Linear]): + """Try the model to find the layers which can be smooth quantized. + + :param op_types: The op types to be smooth quantized + :return: + self_absorb_layer: A dict, absorb layer name (itself): layers to be smooth quantized + """ + self_absorb_layer = {} + op_types = [torch.nn.Linear] # TODOļ¼š only support SQLinearWrapper + for name, module in self.model.named_modules(): + if isinstance(module, tuple(op_types)): + self_absorb_layer[name] = [name] + return self_absorb_layer + + def _get_example_input(self): + if self.dataloader is None and self.example_inputs is None: + return None + if self.example_inputs is None: + try: + for idx, (input, label) in enumerate(self.dataloader): + self.example_inputs = input + break + except: + for idx, input in enumerate(self.dataloader): + self.example_inputs = input + break + + return self.example_inputs + + def _trace(self, op_types, skip_unsupported_layers=True): + """Try the model to find the layers which can be smooth quantized. + + :param op_types: The op types to be smooth quantized + :return: + absorb_to_layer: A dict, absorb layer name:layers to be smooth quantized + no_absorb_layers: A list saving the layers which could not find the absorb layer + """ + + tg = GraphTrace() + self._get_example_input() + absorb_to_layer, no_absorb_layers = tg.get_absorb_to_layer( + self.traced_model, + self.example_inputs, + op_types, + skip_unsupported_layers=skip_unsupported_layers, + ) + if not skip_unsupported_layers: + return absorb_to_layer + if absorb_to_layer is None and no_absorb_layers is None: + logger.warning( + "sorry, could not trace the model, smooth quant is skipped." + "If you are using huggingface model," + "you could set torchscript to True " + "when loading the model or set the return_dict to False" + ) + elif absorb_to_layer == {}: + logger.warning("could not find any layer to be absorbed") + else: + to_absorb_cnt = 0 + for key, item in absorb_to_layer.items(): + to_absorb_cnt += len(item) + logger.info( + f" {to_absorb_cnt} out of {to_absorb_cnt + len(no_absorb_layers)} " + f"layers could be absorbed in smooth quant" + ) + return absorb_to_layer, no_absorb_layers \ No newline at end of file diff --git a/auto_round/smooth_quant/utils.py b/auto_round/smooth_quant/utils.py new file mode 100644 index 00000000..f8bf96e1 --- /dev/null +++ b/auto_round/smooth_quant/utils.py @@ -0,0 +1,483 @@ +# +# -*- coding: utf-8 -*- +# +# Copyright (c) 2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import copy +import json + +import logging +import torch + +from collections import UserDict, defaultdict + +import numpy +from tqdm import tqdm + +logger = logging.getLogger() + + +def enough_memo_store_scale(device, need_space): + if device == "cuda": # pragma: no cover + current_gpu_index = torch.cuda.current_device() + total_memory = torch.cuda.get_device_properties(current_gpu_index).total_memory + used_memory = torch.cuda.memory_allocated(current_gpu_index) + free_space = total_memory - used_memory + else: + import psutil + + free_space = psutil.virtual_memory().free + return free_space >= need_space + + +def move_input_to_device(input, device=torch.device("cpu")): + if isinstance(input, dict) or isinstance(input, UserDict): + tmp_input = {} + for k, inp in input.items(): + tmp_input[k] = move_input_to_device(inp, device) + input = tmp_input + elif isinstance(input, list) or isinstance(input, tuple): + is_tuple = isinstance(input, tuple) + tmp_input = [] + for inp in input: + tmp_input.append(move_input_to_device(inp, device)) + input = tuple(tmp_input) if is_tuple else tmp_input + elif isinstance(input, torch.Tensor): + input = input.to(device) # pylint: disable=no-member + return input + + +##TODO potential bug, data typeR +def forward_wrapper(model, input, device=torch.device("cpu")): + try: + model = model.to(device) + input = move_input_to_device(input, device) + except Exception as e: + logger.warning(e) + logger.warning("Please check the input device if the error raised.") + if isinstance(input, dict) or isinstance(input, UserDict): + output = model(**input) + elif isinstance(input, list) or isinstance(input, tuple): + try: + output = model(*input) + except: + output = model(input) + else: + output = model(input) + return output + + +def model_forward(model, dataloader, iters, device): + try: + cnt = 0 + for idx, (input, label) in enumerate(dataloader): + if input is None: + continue + output = forward_wrapper(model, input, device) + cnt += 1 + if iters != -1 and cnt >= iters: + break + except Exception as e: + cnt = 0 + for idx, input in enumerate(dataloader): + output = forward_wrapper(model, input, device) + cnt += 1 + if iters != -1 and cnt >= iters: + break + + +# def get_block_names(model): +# """Get the block names for transformers-like networks. + +# Args: +# model: The model. + +# Returns: +# block_names: A list of block names. +# """ +# block_names = [] +# target_m = None +# for n, m in model.named_modules(): +# if hasattr(type(m), "__name__") and "ModuleList" in type(m).__name__: +# target_m = (n, m) +# for n, m in target_m[1].named_children(): +# block_names.append(target_m[0] + "." + n) +# return block_names + + +def model_forward_per_sample(model, sample, device): + try: + output = forward_wrapper(model, sample, device) + return output + + except Exception as e: + output = forward_wrapper(model, sample[0], device) + return output + + +def quant_dequant_w_v1(m, num_bits=8, scheme="sym"): + eps = torch.finfo(torch.float32).eps + if isinstance(m, torch.nn.Linear): + x = m.weight + tmp = torch.zeros(torch.max(x, dim=1).values.size()) + if scheme == "sym": + q_min, q_max = -(2.0 ** (num_bits - 1)), 2.0 ** (num_bits - 1) - 1.0 + x_max = torch.max(torch.abs(x), dim=1).values + scale = x_max / (float(q_max - q_min) / 2) + else: + q_min, q_max = 0, 2.0**num_bits - 1.0 + x_max = torch.maximum(torch.max(x, dim=1).values, tmp) + x_min = torch.minimum(torch.min(x, dim=1).values, tmp) + scale = (x_max - x_min) / (2**num_bits - 1) + + scale = torch.clip(scale, min=eps) + + if scheme == "sym": + bias = 0 + else: + bias = torch.round(0 - (torch.min(x, dim=1).values) / scale) + bias = bias.unsqueeze(dim=-1) + scale = scale.unsqueeze(dim=-1) + q_x = torch.round(x / scale + bias) + q_x.clamp_(q_min, q_max) + return (q_x - bias) * scale + elif isinstance(m, torch.nn.Conv2d): + x = m.weight + x = torch.permute(x, (0, 2, 3, 1)) + x = x.reshape(-1, x.shape[-1]) + tmp = torch.zeros(torch.max(x, dim=0).values.size()) + if scheme == "sym": + q_min, q_max = -(2.0 ** (num_bits - 1)), 2.0 ** (num_bits - 1) - 1.0 + x_max = torch.max(torch.abs(x), dim=0).values + scale = x_max / (2 ** (num_bits - 1) - 1) + else: + q_min, q_max = 0, 2.0**num_bits - 1.0 + x_max = torch.maximum(torch.max(x, dim=0).values, tmp) + x_min = torch.minimum(torch.min(x, dim=0).values, tmp) + scale = (x_max - x_min) / (2**num_bits - 1) + scale = torch.clip(scale, min=eps) + if scheme == "sym": + bias = 0 + else: + bias = torch.round(0 - (torch.min(x, dim=0).values) / scale) + bias = bias.unsqueeze(dim=0) + scale = scale.unsqueeze(dim=0) + + q_x = x / scale + bias + q_x.clamp_(q_min, q_max).round_() + q_dq_x = (q_x - bias) * scale + q_dq_x = q_dq_x.view(m.weight.shape[0], m.weight.shape[2], m.weight.shape[3], m.weight.shape[1]) + q_dq_x = torch.permute(q_dq_x, (0, 3, 1, 2)) + return q_dq_x + else: + logger.warning("unsupported layer type, please have a check") + + +# def quant_dequant_w(x, scale, num_bits=8): ##default sym +# scale = scale.unsqueeze(dim=1) +# q_min, q_max = -(2.0 ** (num_bits - 1)), 2.0 ** (num_bits - 1) - 1.0 +# q_x = torch.round(x / scale) +# q_x.clamp_(q_min, q_max) +# return scale * q_x + + +def quant_dequant_x_v1(x, min_x=None, max_x=None, num_bits=8): + eps = torch.finfo(torch.float32).eps + q_min, q_max = 0, 2.0**num_bits - 1.0 + if max_x is None or min_x is None: + max_x, min_x = torch.max(x), torch.min(x) + else: + max_x = torch.max(max_x) + min_x = torch.min(min_x) + scale = (max_x - min_x) / (2**num_bits - 1) + scale = torch.clip(scale, min=eps) + bias = torch.round((0 - min_x) / scale) + q_x = torch.round(x / scale + bias) + q_x.clamp_(q_min, q_max) + return scale * (q_x - bias) + + +# def quant_dequant_x(x, scale, bias, num_bits=8): ##default asym +# q_min, q_max = 0, 2.0**num_bits - 1.0 +# # if max_x is None or min_x is None: +# # max_x, min_x = torch.max(x), torch.min(x) +# # else: +# # max_x = torch.max(max_x) +# # min_x = torch.min(min_x) +# # scale = (max_x - min_x) / (2**num_bits - 1) +# # scale = torch.clip(scale, min=eps) +# # bias = torch.round((0 - min_x) / scale) +# q_x = torch.round(x / scale + bias) +# q_x.clamp_(q_min, q_max) +# return scale * (q_x - bias) + + +def get_module(model, key): + """Get module from model by key name. + + Args: + model (torch.nn.Module): original model + key (str): module name to be replaced + """ + module = model + name_list = key.split(".") + for name in name_list: + if hasattr(module, name): + module = getattr(module, name) + elif hasattr(module, "sq_linear"): # for peft models + module = getattr(module, "sq_linear") + module = getattr(module, name) + elif hasattr(module, "orig_layer"): # for peft models and auto alpha + module = getattr(module, "orig_layer") + module = getattr(module, name) + else: + module = module + return module + + +def set_module(model, key, new_module): + """Set new module into model by key name. + + Args: + model (torch.nn.Module): original model + key (str): module name to be replaced + new_module (torch.nn.Module): new module to be inserted + """ + module = model + name_list = key.split(".") + for name in name_list[:-1]: + if hasattr(module, name): + module = getattr(module, name) + elif hasattr(module, ("sq_linear")): # for peft models that Linears are contained in Linear + module = getattr(module, "sq_linear") + module = getattr(module, name) + elif hasattr(module, ("orig_layer")): # for peft models and auto alpha + module = getattr(module, "orig_layer") + module = getattr(module, name) + else: + module = module + + if hasattr(module, "sq_linear") and name_list[-1] != "sq_linear": # for peft models + module = getattr(module, "sq_linear") + if hasattr(module, "orig_layer") and name_list[-1] != "orig_layer": # for peft models and auto alpha + module = getattr(module, "orig_layer") + setattr(module, name_list[-1], new_module) + + +def cal_scale(input_max_abs, weights, alpha, weight_max_lb=1e-5): + weights = torch.cat(weights, dim=0) + weight_max = torch.max(torch.abs(weights), dim=0)[0] + weight_max = torch.clip(weight_max, weight_max_lb) + input_power = torch.pow(input_max_abs, alpha) + logger.debug(f"{max(input_max_abs)}, {min(input_max_abs)}") + weight_power = torch.pow(weight_max, 1 - alpha) + weight_scale = torch.clip(input_power / weight_power, min=1e-5) + weight_scale[input_power == 0] = 1.0 + return weight_scale + + +def reshape_in_channel_to_last(layer_name, model): + """Move the input channel to the last dim + :param layer_name: Layer name + :return: The reshaped weight.""" + layer = get_module(model, layer_name) + if layer.__class__.__name__ == "WrapperLayer": + layer = layer.orig_layer + + weight = layer.weight ##TODO oc*ic, support transposed conv + if len(weight.shape) == 4: + weight = weight.permute(0, 2, 3, 1) + weight = weight.reshape(-1, weight.shape[-1]) + return weight + + +class WrapperLayer(torch.nn.Module): + def __init__(self, layer, input_min, input_max, save_q_input=False): + super(WrapperLayer, self).__init__() + self.add_module("orig_layer", layer) # set orig_layer in get/set_module + self.quant = False + self.q_input = None + self.fp32_output = None + self.input_max = input_max + self.input_min = input_min + self.weight_scale = None + self.input_scale = None + self.save_q_input = save_q_input + self.do_blockwise = False + + def enable_quant(self): + self.quant = True + + def disable_quant(self): + self.quant = False + + def update_scale(self, input_scale, weight_scale): + self.input_scale = input_scale + self.weight_scale = weight_scale + + ##TODO better tradeoff performance and memory, currently it's too slow + def q_dq_forward(self, x, input_scale, weight_scale): + layer_copy = copy.deepcopy(self.orig_layer) + if weight_scale is not None: + layer_copy.weight *= weight_scale + q_dq_weight = quant_dequant_w_v1(layer_copy) + layer_copy.weight.data.copy_(q_dq_weight) + if input_scale is None: + x = quant_dequant_x_v1(x, self.input_min, self.input_max) + else: + x = input_scale * x + x = quant_dequant_x_v1(x, self.input_min * input_scale, self.input_max * input_scale) ##FIXME + output = layer_copy(x) + return output + + def q_dq_forward_blockwise(self, x, input_scale): + layer_copy = copy.deepcopy(self.orig_layer) + if input_scale is None: + x = quant_dequant_x_v1(x, self.input_min, self.input_max) + else: + x = input_scale * x + x = quant_dequant_x_v1(x, self.input_min * input_scale, self.input_max * input_scale) ##FIXME + output = layer_copy(x) + return output + + def forward(self, x): + if self.quant: + # self.q_input = x * scale ##save the q_input + if self.save_q_input: + self.q_input = x + if not self.do_blockwise: + output = self.q_dq_forward(x, self.input_scale, self.weight_scale) + else: + output = self.q_dq_forward_blockwise(x, self.input_scale) + + else: + output = self.orig_layer(x) + self.output = output + return output + + +def reshape_scale_as_input(layer, scale): + """Reshape the scale for input feature in channel + :param layer: + + :param scale: + :return: + """ + if hasattr(layer, "orig_layer"): + layer = layer.orig_layer + if isinstance(layer, torch.nn.Conv2d): + scale = scale.view(1, scale.shape[0], 1, 1) + + elif isinstance(layer, torch.nn.Linear): + scale = scale.view(1, scale.shape[0]) + + return scale + + +def reshape_scale_as_weight(layer, scale): + """Reshape the scale for weight input channel, depthwise output channel + :param layer: torch module + :param scale: orig scale + :return: reshaped scale.""" + if hasattr(layer, "orig_layer"): + layer = layer.orig_layer + if isinstance(layer, torch.nn.Conv2d) and layer.groups > 1: ##only depthwise conv could hit here + scale = scale.view(scale.shape[0], 1, 1, 1) ##mount on output channel + + elif isinstance(layer, torch.nn.Conv2d): + scale = scale.view(1, scale.shape[0], 1, 1) + + elif isinstance(layer, torch.nn.Linear): + scale = scale.view(1, scale.shape[0]) + + return scale + + +TUNERS = {} + + +def register_autotune(name): + """Class decorator to register a smoothquant auto-tune subclass. + + :return: the class of register + """ + + def register(auto_tune): + TUNERS[name] = auto_tune + return auto_tune + + return register + + +class SQLinearWrapper(torch.nn.Module): + def __init__(self, module, input_scale, input_minmax, alpha=0.5, dtype=torch.quint8): + super().__init__() + self.register_buffer("input_scale", input_scale) + self.alpha = alpha + self.dtype = dtype + # calculate and only save scale, zero_point to avoid memory usage + self.scale, self.zero_point = self._calculate_qparams(input_scale, input_minmax, dtype) + self.add_module("sq_linear", module) + self._update_sq_linear() + self.ipex = False # a flag used for ipex inference + + @property + def weight(self): + return self.sq_linear.weight + + def forward(self, X): + if self.ipex: + X = self.sq_linear(X) + else: + X = torch.mul(X, self.input_scale) + X = self.sq_linear(X) + return X + + def _calculate_qparams(self, input_scale, input_minmax, dtype=torch.quint8): + # calculate scale and zero_point + if dtype == torch.quint8: + quant_min, quant_max = 0, 255 + min_val = torch.min(input_minmax[0] * input_scale) + max_val = torch.max(input_minmax[1] * input_scale) + # work when min_val bigger than zero. + min_val_neg = torch.min(min_val, torch.zeros_like(min_val)) + max_val_pos = torch.max(max_val, torch.zeros_like(max_val)) + scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min) + scale = torch.max(scale, torch.tensor([torch.finfo(torch.float32).eps], device=scale.device)) + zero_point = quant_min - torch.round(min_val_neg / scale).to(torch.int) + zero_point = torch.clamp(zero_point, quant_min, quant_max) + return scale, zero_point + + def _get_weight_scale(self): + # get weight scale and zero_point + from torch.ao.quantization.observer import default_per_channel_weight_observer + + obs = default_per_channel_weight_observer() + obs(self.sq_linear.weight) + scale, _ = obs.calculate_qparams() + return scale + + def _update_sq_linear(self): + # remove mul and reset sq_linear for ipex inference + scale = self.input_scale.view(1, self.input_scale.shape[0]) + with torch.no_grad(): + self.sq_linear.weight /= scale + + def _recover_sq_linear(self): + # remove mul and reset sq_linear for ipex inference + scale = self.input_scale.view(1, self.input_scale.shape[0]) + with torch.no_grad(): + self.sq_linear.weight *= scale \ No newline at end of file From 5fc4246c5d78e3d3e7d0d33b38c46bb7edefb2f8 Mon Sep 17 00:00:00 2001 From: n1ck-guo Date: Tue, 30 Jul 2024 04:20:08 -0400 Subject: [PATCH 2/4] new trace method Signed-off-by: n1ck-guo --- .../smooth_quant/__init__.py | 0 .../smooth_quant/absorb_utils.py | 372 ++++++++++++++++++ .../smooth_quant/auto_alpha.py | 0 .../smooth_quant/calibration.py | 0 .../smooth_quant/graph_trace.py | 0 .../smooth_quant/smooth_quant.py | 15 +- .../{ => algorithm_ext}/smooth_quant/utils.py | 0 7 files changed, 381 insertions(+), 6 deletions(-) rename auto_round/{ => algorithm_ext}/smooth_quant/__init__.py (100%) create mode 100644 auto_round/algorithm_ext/smooth_quant/absorb_utils.py rename auto_round/{ => algorithm_ext}/smooth_quant/auto_alpha.py (100%) rename auto_round/{ => algorithm_ext}/smooth_quant/calibration.py (100%) rename auto_round/{ => algorithm_ext}/smooth_quant/graph_trace.py (100%) rename auto_round/{ => algorithm_ext}/smooth_quant/smooth_quant.py (98%) rename auto_round/{ => algorithm_ext}/smooth_quant/utils.py (100%) diff --git a/auto_round/smooth_quant/__init__.py b/auto_round/algorithm_ext/smooth_quant/__init__.py similarity index 100% rename from auto_round/smooth_quant/__init__.py rename to auto_round/algorithm_ext/smooth_quant/__init__.py diff --git a/auto_round/algorithm_ext/smooth_quant/absorb_utils.py b/auto_round/algorithm_ext/smooth_quant/absorb_utils.py new file mode 100644 index 00000000..f1bcf17b --- /dev/null +++ b/auto_round/algorithm_ext/smooth_quant/absorb_utils.py @@ -0,0 +1,372 @@ +# +# -*- coding: utf-8 -*- +# +# Copyright (c) 2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch + +from .utils import get_module + +SUPPORTED_TORCH_MODULE = [ + "Linear", + "Conv2d", + "ConvTranspose2d", + "LayerNorm", + "BatchNorm2d", + "GroupNorm", + "InstanceNorm2d", + "LlamaRMSNorm", + "T5LayerNorm", + "LPLayerNorm", +] + +GET_ABSORB_LAYERS = {} + +def register_get_func(name): + """Class decorator to register a get_absorb_layers func + """ + def register(func): + GET_ABSORB_LAYERS[name] = func + return func + return register + +def _check_valid_conv(module): + """Remove group conv except depthwise conv + :param module: + + :return: + """ + if not isinstance(module, torch.nn.Conv2d): + return True + if module.groups > 1: + if module.in_channels == module.out_channels and module.groups == module.in_channels: + return True + else: + return False + return True + +def remove_unsupported_layers(model, absorb_to_layer, no_absorb_layers): + res = {} + for key in absorb_to_layer.keys(): + absorb_layer = get_module(model, key) + layer_type = absorb_layer.__class__.__name__ + if layer_type not in SUPPORTED_TORCH_MODULE: + no_absorb_layers.extend(absorb_to_layer[key]) + continue + supported = True + for layer_name in absorb_to_layer[key]: + layer = get_module(model, layer_name) + layer_type = layer.__class__.__name__ + if (layer_type not in SUPPORTED_TORCH_MODULE) or not _check_valid_conv(layer): + supported = False + no_absorb_layers.extend(absorb_to_layer[key]) + break + if supported: + res[key] = absorb_to_layer[key] + return res + +@register_get_func("opt") +def get_opt_absorb_layers(model): + model_layer_name = "model.decoder.layers" + absorb_to_layer = {} + for idx in range(len(model.model.decoder.layers)): + # attention input + absorb_to_layer[f"{model_layer_name}.{idx}.self_attn_layer_norm"] = [ + f"{model_layer_name}.{idx}.self_attn.q_proj", + f"{model_layer_name}.{idx}.self_attn.k_proj", + f"{model_layer_name}.{idx}.self_attn.v_proj", + ] + + # attention out + # no_absorb_layers.append(f"{model_layer_name}.{idx}.self_attn.out_proj") + absorb_to_layer[f"{model_layer_name}.{idx}.v_proj"] = [ + f"{model_layer_name}.{idx}.self_attn.out_proj", + ] + + # linear 1 + absorb_to_layer[f"{model_layer_name}.{idx}.final_layer_norm"] = [ + f"{model_layer_name}.{idx}.fc1", + ] + + # linear 2 + absorb_to_layer[f"{model_layer_name}.{idx}.fc1"] = [ + f"{model_layer_name}.{idx}.fc2", + ] + + # final layer + absorb_to_layer["model.decoder.final_layer_norm"] = ['lm_head'] + + return absorb_to_layer + +@register_get_func('llama') +def get_llama_absorb_layers(model): + model_layer_name = "model.layers" + absorb_to_layer = {} + for idx in range(len(model.model.layers)): + # attention input + absorb_to_layer[f"{model_layer_name}.{idx}.input_layernorm"] = [ + f"{model_layer_name}.{idx}.self_attn.q_proj", + f"{model_layer_name}.{idx}.self_attn.k_proj", + f"{model_layer_name}.{idx}.self_attn.v_proj", + ] + + # attention out + absorb_to_layer[f"{model_layer_name}.{idx}.v_proj"] = [ + f"{model_layer_name}.{idx}.self_attn.o_proj", + ] + + # linear 1 + absorb_to_layer[f"{model_layer_name}.{idx}.post_attention_layernorm"] = [ + f"{model_layer_name}.{idx}.mlp.gate_proj", + f"{model_layer_name}.{idx}.mlp.up_proj", + ] + + # linear 2 + absorb_to_layer[f"{model_layer_name}.{idx}.mlp.up_proj"] = [ + f"{model_layer_name}.{idx}.mlp.down_proj", + ] + + # final layer + absorb_to_layer["model.norm"] = ['lm_head'] + + return absorb_to_layer + +@register_get_func('mistral') +def get_mistral_absorb_layers(model): + model_layer_name = "model.layers" + absorb_to_layer = {} + for idx in range(len(model.model.layers)): + # attention input + absorb_to_layer[f"{model_layer_name}.{idx}.input_layernorm"] = [ + f"{model_layer_name}.{idx}.self_attn.q_proj", + f"{model_layer_name}.{idx}.self_attn.k_proj", + f"{model_layer_name}.{idx}.self_attn.v_proj", + ] + + # attention out + absorb_to_layer[f"{model_layer_name}.{idx}.v_proj"] = [ + f"{model_layer_name}.{idx}.self_attn.o_proj", + ] + + # linear 1 + absorb_to_layer[f"{model_layer_name}.{idx}.post_attention_layernorm"] = [ + f"{model_layer_name}.{idx}.mlp.gate_proj", + f"{model_layer_name}.{idx}.mlp.up_proj", + ] + + # linear 2 + absorb_to_layer[f"{model_layer_name}.{idx}.mlp.up_proj"] = [ + f"{model_layer_name}.{idx}.mlp.down_proj", + ] + + # final layer + absorb_to_layer["model.norm"] = ['lm_head'] + + return absorb_to_layer + +@register_get_func('mixtral') +def get_mixtral_absorb_layers(model): + model_layer_name = "model.layers" + absorb_to_layer = {} + for idx in range(len(model.model.layers)): + # attention input + absorb_to_layer[f"{model_layer_name}.{idx}.input_layernorm"] = [ + f"{model_layer_name}.{idx}.self_attn.q_proj", + f"{model_layer_name}.{idx}.self_attn.k_proj", + f"{model_layer_name}.{idx}.self_attn.v_proj", + ] + + # attention out + absorb_to_layer[f"{model_layer_name}.{idx}.v_proj"] = [ + f"{model_layer_name}.{idx}.self_attn.o_proj", + ] + + # linear in + module = get_module(model, f"{model_layer_name}.{idx}.block_sparse_moe.experts") + absorb_to_layer[f"{model_layer_name}.{idx}.post_attention_layernorm"] = [] + for i in range(len(module)): + absorb_to_layer[f"{model_layer_name}.{idx}.post_attention_layernorm"].extend( + [ + f"{model_layer_name}.{idx}.block_sparse_moe.experts.{i}.w1", + f"{model_layer_name}.{idx}.block_sparse_moe.experts.{i}.w3" + ] + ) + + + # linear out + for i in range(len(module)): + absorb_to_layer[f"{model_layer_name}.{idx}.block_sparse_moe.experts.{i}.w3"] = [ + f"{model_layer_name}.{idx}.block_sparse_moe.experts.{i}.w2" + ] + + # final layer + absorb_to_layer["model.norm"] = ['lm_head'] + breakpoint() + return absorb_to_layer + + +@register_get_func('bloom') +def get_bloom_absorb_layers(model): + model_layer_name = "transformer.h" + absorb_to_layer = {} + for idx in range(len(model.transformer.h)): + # attention input + absorb_to_layer[f"{model_layer_name}.{idx}.input_layernorm"] = [ + f"{model_layer_name}.{idx}.self_attention.query_key_value", + ] + + # linear 1 + absorb_to_layer[f"{model_layer_name}.{idx}.post_attention_layernorm"] = [ + f"{model_layer_name}.{idx}.mlp.dense_h_to_4h", + ] + + # linear 2 + absorb_to_layer[f"{model_layer_name}.{idx}.mlp.gelu_impl"] = [ + f"{model_layer_name}.{idx}.mlp.dense_4h_to_h", + ] + + # final layer + absorb_to_layer["transformer.ln_f"] = ['lm_head'] + + return absorb_to_layer + + +@register_get_func('gptj') +def get_gptj_absorb_layers(model): + model_layer_name = "transformer.h" + absorb_to_layer = {} + for idx in range(len(model.transformer.h)): + # attention input + linear 1 + absorb_to_layer[f"{model_layer_name}.{idx}.ln_1"] = [ + f"{model_layer_name}.{idx}.attn.q_proj", + f"{model_layer_name}.{idx}.attn.k_proj", + f"{model_layer_name}.{idx}.attn.v_proj", + f"{model_layer_name}.{idx}.mlp.fc_in", + ] + + # attention out + absorb_to_layer[f"{model_layer_name}.{idx}.attn.v_proj"] = [ + f"{model_layer_name}.{idx}.attn.out_proj", + ] + + # linear 2 + absorb_to_layer[f"{model_layer_name}.{idx}.mlp.act"] = [ + f"{model_layer_name}.{idx}.mlp.fc_out", + ] + + # final layer + absorb_to_layer["transformer.ln_f"] = ['lm_head'] + + return absorb_to_layer + +@register_get_func('phi3') +def get_phi3_absorb_layers(model): + model_layer_name = "model.layers" + absorb_to_layer = {} + for idx in range(len(model.model.layers)): + # attention input + absorb_to_layer[f"{model_layer_name}.{idx}.input_layernorm"] = [ + f"{model_layer_name}.{idx}.self_attn.qkv_proj", + ] + + # attention out + absorb_to_layer[f"{model_layer_name}.{idx}.self_attn.qkv_proj"] = [ + f"{model_layer_name}.{idx}.self_attn.o_proj", + ] + + # linear 1 + absorb_to_layer[f"{model_layer_name}.{idx}.post_attention_layernorm"] = [ + f"{model_layer_name}.{idx}..mlp.gate_up_proj", + ] + + # linear 2 + absorb_to_layer[f"{model_layer_name}.{idx}.mlp.gate_up_proj"] = [ + f"{model_layer_name}.{idx}.mlp.down_proj", + ] + + # final layer + absorb_to_layer["model.norm"] = ['lm_head'] + + return absorb_to_layer + + +@register_get_func('qwen') +def get_qwen_absorb_layers(model): + model_layer_name = "transformer.h" + absorb_to_layer = {} + for idx in range(len(model.transformer.h)): + # attention + absorb_to_layer[f"{model_layer_name}.{idx}.ln_1"] = [ + f"{model_layer_name}.{idx}.attn.c_attn" + ] + + # mlp + absorb_to_layer[f"{model_layer_name}.{idx}.ln_2"] = [ + f"{model_layer_name}.{idx}.mlp.w2", + f"{model_layer_name}.{idx}.mlp.w1", + ] + + # linear 2 + absorb_to_layer[f"{model_layer_name}.{idx}.mlp.w1"] = [ + f"{model_layer_name}.{idx}.mlp.c_proj", + ] + + # final layer + absorb_to_layer["transformer.ln_f"] = ['lm_head'] + + return absorb_to_layer + + +@register_get_func('qwen2') +def get_qwen2_absorb_layers(model): + model_layer_name = "model.layers" + absorb_to_layer = {} + for idx in range(len(model.model.layers)): + # attention input + absorb_to_layer[f"{model_layer_name}.{idx}.input_layernorm"] = [ + f"{model_layer_name}.{idx}.self_attn.q_proj", + f"{model_layer_name}.{idx}.self_attn.k_proj", + f"{model_layer_name}.{idx}.self_attn.v_proj", + ] + + # attention out + absorb_to_layer[f"{model_layer_name}.{idx}.v_proj"] = [ + f"{model_layer_name}.{idx}.self_attn.o_proj", + ] + + # linear 1 + absorb_to_layer[f"{model_layer_name}.{idx}.post_attention_layernorm"] = [ + f"{model_layer_name}.{idx}.mlp.gate_proj", + f"{model_layer_name}.{idx}.mlp.up_proj", + ] + + # linear 2 + absorb_to_layer[f"{model_layer_name}.{idx}.mlp.up_proj"] = [ + f"{model_layer_name}.{idx}.mlp.down_proj", + ] + + # final layer + absorb_to_layer["model.norm"] = ['lm_head'] + + return absorb_to_layer + +def get_absorb_layers(model, skip_unsupported_layers=True): + model_type = model.config.model_type + assert model_type in GET_ABSORB_LAYERS, f"Unsupported model type: {model_type}" + absorb_to_layer = GET_ABSORB_LAYERS[model_type](model) + no_absorb_layers = [] + if skip_unsupported_layers: + absorb_to_layer = remove_unsupported_layers(model, absorb_to_layer, no_absorb_layers) + return absorb_to_layer, no_absorb_layers + \ No newline at end of file diff --git a/auto_round/smooth_quant/auto_alpha.py b/auto_round/algorithm_ext/smooth_quant/auto_alpha.py similarity index 100% rename from auto_round/smooth_quant/auto_alpha.py rename to auto_round/algorithm_ext/smooth_quant/auto_alpha.py diff --git a/auto_round/smooth_quant/calibration.py b/auto_round/algorithm_ext/smooth_quant/calibration.py similarity index 100% rename from auto_round/smooth_quant/calibration.py rename to auto_round/algorithm_ext/smooth_quant/calibration.py diff --git a/auto_round/smooth_quant/graph_trace.py b/auto_round/algorithm_ext/smooth_quant/graph_trace.py similarity index 100% rename from auto_round/smooth_quant/graph_trace.py rename to auto_round/algorithm_ext/smooth_quant/graph_trace.py diff --git a/auto_round/smooth_quant/smooth_quant.py b/auto_round/algorithm_ext/smooth_quant/smooth_quant.py similarity index 98% rename from auto_round/smooth_quant/smooth_quant.py rename to auto_round/algorithm_ext/smooth_quant/smooth_quant.py index 43a003d1..be102a77 100644 --- a/auto_round/smooth_quant/smooth_quant.py +++ b/auto_round/algorithm_ext/smooth_quant/smooth_quant.py @@ -24,6 +24,7 @@ from .calibration import Calibration from .graph_trace import GraphTrace from .utils import * +from .absorb_utils import get_absorb_layers class TorchSmoothQuant: @@ -415,6 +416,7 @@ def transform( self.absorb_to_layer = self._parse_absorb_to_layers( op_types, folding ) ##need to forward to check modules not used in forward + breakpoint() if len(self.input_mins) != 0: ##this is from _parse_absorb_to_layers, ugly code to support q_func input_maxes_abs = {} for key in self.input_mins.keys(): @@ -553,12 +555,13 @@ def _trace(self, op_types, skip_unsupported_layers=True): tg = GraphTrace() self._get_example_input() - absorb_to_layer, no_absorb_layers = tg.get_absorb_to_layer( - self.traced_model, - self.example_inputs, - op_types, - skip_unsupported_layers=skip_unsupported_layers, - ) + # absorb_to_layer, no_absorb_layers = tg.get_absorb_to_layer( + # self.traced_model, + # self.example_inputs, + # op_types,, auto_alpha_args=auto_alpha_args) + # skip_unsupported_layers=skip_unsupported_layers, + # ) + absorb_to_layer, no_absorb_layers = get_absorb_layers(self.traced_model, skip_unsupported_layers) if not skip_unsupported_layers: return absorb_to_layer if absorb_to_layer is None and no_absorb_layers is None: diff --git a/auto_round/smooth_quant/utils.py b/auto_round/algorithm_ext/smooth_quant/utils.py similarity index 100% rename from auto_round/smooth_quant/utils.py rename to auto_round/algorithm_ext/smooth_quant/utils.py From 0ce0144b164a00f35d6652c43e752289ffc833a4 Mon Sep 17 00:00:00 2001 From: n1ck-guo Date: Tue, 30 Jul 2024 04:31:40 -0400 Subject: [PATCH 3/4] update Signed-off-by: n1ck-guo --- auto_round/algorithm_ext/smooth_quant/absorb_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/auto_round/algorithm_ext/smooth_quant/absorb_utils.py b/auto_round/algorithm_ext/smooth_quant/absorb_utils.py index f1bcf17b..c271d4f3 100644 --- a/auto_round/algorithm_ext/smooth_quant/absorb_utils.py +++ b/auto_round/algorithm_ext/smooth_quant/absorb_utils.py @@ -29,6 +29,7 @@ "LlamaRMSNorm", "T5LayerNorm", "LPLayerNorm", + "RMSNorm", ] GET_ABSORB_LAYERS = {} @@ -303,6 +304,7 @@ def get_phi3_absorb_layers(model): @register_get_func('qwen') def get_qwen_absorb_layers(model): + breakpoint() model_layer_name = "transformer.h" absorb_to_layer = {} for idx in range(len(model.transformer.h)): From 71d25821c2df0a1c41ed3a10a3cace6a384cad6d Mon Sep 17 00:00:00 2001 From: n1ck-guo Date: Sun, 4 Aug 2024 21:12:55 -0400 Subject: [PATCH 4/4] fix bug Signed-off-by: n1ck-guo --- auto_round/algorithm_ext/smooth_quant/absorb_utils.py | 2 -- auto_round/algorithm_ext/smooth_quant/auto_alpha.py | 10 +++++----- auto_round/algorithm_ext/smooth_quant/smooth_quant.py | 7 +++---- 3 files changed, 8 insertions(+), 11 deletions(-) diff --git a/auto_round/algorithm_ext/smooth_quant/absorb_utils.py b/auto_round/algorithm_ext/smooth_quant/absorb_utils.py index c271d4f3..c50a4082 100644 --- a/auto_round/algorithm_ext/smooth_quant/absorb_utils.py +++ b/auto_round/algorithm_ext/smooth_quant/absorb_utils.py @@ -213,7 +213,6 @@ def get_mixtral_absorb_layers(model): # final layer absorb_to_layer["model.norm"] = ['lm_head'] - breakpoint() return absorb_to_layer @@ -304,7 +303,6 @@ def get_phi3_absorb_layers(model): @register_get_func('qwen') def get_qwen_absorb_layers(model): - breakpoint() model_layer_name = "transformer.h" absorb_to_layer = {} for idx in range(len(model.transformer.h)): diff --git a/auto_round/algorithm_ext/smooth_quant/auto_alpha.py b/auto_round/algorithm_ext/smooth_quant/auto_alpha.py index ecba1b2e..ac234247 100644 --- a/auto_round/algorithm_ext/smooth_quant/auto_alpha.py +++ b/auto_round/algorithm_ext/smooth_quant/auto_alpha.py @@ -98,7 +98,7 @@ def tune(self): scale_memo_use += 4 * input_max.shape[0] * len(self.absorb_to_layer[key]) alpha_space_len = (self.alpha_max - self.alpha_min) / self.alpha_step + 1 scale_memo_use *= alpha_space_len - self._save_scale = enough_memo_store_scale(self.device, scale_memo_use) + self._save_scale = not enough_memo_store_scale(self.device, scale_memo_use) if self.loss_type == "blockwise": self.block_names = self.get_blocks() @@ -250,13 +250,13 @@ def _cal_scales(self, absorb_to_layer, input_maxes, alpha=0.5): absorb_scales_info[key] = 1.0 / scale absorb_scales_info[key][scale == 0] = 0 layer_names = absorb_to_layer[key] + if self._save_scale: + if key not in self.weight_scale_dict: + self.weight_scale_dict[key] = {} + self.weight_scale_dict[key][alpha_tmp] = scale for layer_name in layer_names: ##self._scale_layer_weight(layer_name, scale) weight_scales_info[layer_name] = scale - if self._save_scale: - if layer_name not in self.weight_scale_dict: - self.weight_scale_dict[layer_name] = {} - self.weight_scale_dict[layer_name][alpha_tmp] = scale return absorb_scales_info, weight_scales_info def _get_auto_loss(self, output, output_q, loss_type="abs", loss_alpha=1.0): diff --git a/auto_round/algorithm_ext/smooth_quant/smooth_quant.py b/auto_round/algorithm_ext/smooth_quant/smooth_quant.py index be102a77..558ad1f8 100644 --- a/auto_round/algorithm_ext/smooth_quant/smooth_quant.py +++ b/auto_round/algorithm_ext/smooth_quant/smooth_quant.py @@ -416,7 +416,6 @@ def transform( self.absorb_to_layer = self._parse_absorb_to_layers( op_types, folding ) ##need to forward to check modules not used in forward - breakpoint() if len(self.input_mins) != 0: ##this is from _parse_absorb_to_layers, ugly code to support q_func input_maxes_abs = {} for key in self.input_mins.keys(): @@ -553,12 +552,12 @@ def _trace(self, op_types, skip_unsupported_layers=True): no_absorb_layers: A list saving the layers which could not find the absorb layer """ - tg = GraphTrace() - self._get_example_input() + # tg = GraphTrace() + # self._get_example_input() # absorb_to_layer, no_absorb_layers = tg.get_absorb_to_layer( # self.traced_model, # self.example_inputs, - # op_types,, auto_alpha_args=auto_alpha_args) + # op_types, # skip_unsupported_layers=skip_unsupported_layers, # ) absorb_to_layer, no_absorb_layers = get_absorb_layers(self.traced_model, skip_unsupported_layers)