From 2bab0656348270291189562a7e7518bcfd310e36 Mon Sep 17 00:00:00 2001
From: VainF <2218880241@qq.com>
Date: Sun, 21 Jul 2024 11:54:39 +0800
Subject: [PATCH] Fixed a bug in pruning

---
 torch_pruning/pruner/algorithms/metapruner.py | 21 +++++++++++++--------
 1 file changed, 13 insertions(+), 8 deletions(-)

diff --git a/torch_pruning/pruner/algorithms/metapruner.py b/torch_pruning/pruner/algorithms/metapruner.py
index 3d17c03..4abd778 100644
--- a/torch_pruning/pruner/algorithms/metapruner.py
+++ b/torch_pruning/pruner/algorithms/metapruner.py
@@ -500,6 +500,10 @@ def _prune(self) -> typing.Generator:
 
         # Handle other scopes for width pruning.
         for scope_id, scope_name in enumerate(width_pruning_scope_names):
+
+            if not self.global_pruning:
+                assert len(ranking_scope[scope_name])<=1, "Internal Error: local pruning should only contain less than one layer per scope."
+
             records = ranking_scope[scope_name] # records[i] -> (group, ch_groups, group_size, pruning_ratio, dim_imp)_i
             # Find the threshold for pruning
             if len(records)>0:
@@ -514,7 +518,7 @@ def _prune(self) -> typing.Generator:
                     )
 
                 if n_pruned>0:
-                    topk_imp, _ = torch.topk(concat_imp, k=n_pruned, largest=False)
+                    topk_imp, topk_indices = torch.topk(concat_imp, k=n_pruned, largest=False)
                     thres = topk_imp[-1]
 
             ##############################################
@@ -529,7 +533,10 @@ def _prune(self) -> typing.Generator:
                 pruning_indices = []
                 if len(records)>0 and n_pruned>0:
                     if ch_groups > 1: # re-compute importance for each channel group if grouping is enabled
-                        n_pruned_per_group = n_pruned #len((imp <= thres).nonzero().view(-1)) # if grouping is enabled, the imp tensor is the average importance of each group.
+                        if self.global_pruning: # for global pruning, the n_pruned may be shared by multiple layers. For each layer, we should know how many channels/dim should be pruned.
+                            n_pruned_per_group = len((imp <= thres).nonzero().view(-1))
+                        else: # for local pruning, we can directly use the n_pruned since each scope only contains one layer
+                            n_pruned_per_group = n_pruned
                         if n_pruned_per_group>0:
                             if self.round_to:
                                 n_pruned_per_group = self._round_to(n_pruned_per_group, group_size, self.round_to)
@@ -542,12 +549,10 @@ def _prune(self) -> typing.Generator:
                                 sub_pruning_idxs = sub_imp_argsort[:n_pruned_per_group]+chg*group_size
                                 pruning_indices.append(sub_pruning_idxs)
                     else: # standard pruning
-                        _pruning_indices = (imp <= thres).nonzero().view(-1)
-                        if len(_pruning_indices)>n_pruned:
-                            # sort the selected scores and only take the top n_pruned indices
-                            selected_scores = imp[_pruning_indices]
-                            imp_argsort = torch.argsort(selected_scores)
-                            _pruning_indices = _pruning_indices[imp_argsort[:n_pruned]]
+                        if self.global_pruning:
+                            _pruning_indices = (imp <= thres).nonzero().view(-1)
+                        else:
+                            _pruning_indices = topk_indices
                         imp_argsort = torch.argsort(imp)
                         if len(_pruning_indices)>0 and self.round_to:
                             n_pruned = len(_pruning_indices)
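
Note on the fix (editor's addition, not part of the commit): the key change in the grouped-channel branch is that the old code used the scope-level n_pruned directly as n_pruned_per_group, which is wrong under global pruning because n_pruned is a budget shared by every layer in the scope; the patch counts (imp <= thres) per layer in that case. In the standard branch, local pruning now reuses the indices already returned by torch.topk instead of re-deriving them from the threshold. One reason the threshold form needs care: when several scores tie exactly at the threshold, (imp <= thres) can select more channels than the budget. A minimal standalone sketch, using hypothetical tensor values that do not come from the repository, illustrates this:

    import torch

    # Three channels tie at the minimum importance score of 0.1.
    imp = torch.tensor([0.1, 0.1, 0.1, 0.9, 0.8])
    n_pruned = 2  # budget: prune exactly two channels

    # topk returns exactly k indices, even when scores tie.
    topk_imp, topk_indices = torch.topk(imp, k=n_pruned, largest=False)
    thres = topk_imp[-1]  # 0.1

    # Threshold selection returns every tied channel, overshooting the budget.
    thres_indices = (imp <= thres).nonzero().view(-1)

    print(len(topk_indices))   # 2
    print(len(thres_indices))  # 3, more than n_pruned

Under global pruning the per-layer (imp <= thres) count is still needed, since the threshold comes from a topk over the concatenated scores of several layers; under local pruning each scope holds at most one layer (enforced by the new assert), so topk_indices can be used directly.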