Fixed a bug in pruning
VainF committed Jul 21, 2024
1 parent f7b2bc4 commit 2bab065
Showing 1 changed file with 13 additions and 8 deletions.
torch_pruning/pruner/algorithms/metapruner.py (21 changes: 13 additions & 8 deletions)
@@ -500,6 +500,10 @@ def _prune(self) -> typing.Generator:
             # Handle other scopes for width pruning.
 
             for scope_id, scope_name in enumerate(width_pruning_scope_names):
+
+                if not self.global_pruning:
+                    assert len(ranking_scope[scope_name]) <= 1, "Internal Error: local pruning should contain at most one layer per scope."
+
                 records = ranking_scope[scope_name] # records[i] -> (group, ch_groups, group_size, pruning_ratio, dim_imp)_i
                 # Find the threshold for pruning
                 if len(records)>0:
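A note on the new assertion: under local pruning each layer is ranked in its own scope, so a scope's record list can never hold more than one entry, whereas global pruning pools every layer's record into a shared scope. A minimal sketch of that bookkeeping (hypothetical names, not the actual MetaPruner internals):

    # Hypothetical illustration of the scope invariant checked above.
    ranking_scope = {}
    global_pruning = False
    for layer_name in ["conv1", "conv2", "conv3"]:
        # Local pruning: one scope per layer -> at most one record per scope.
        # Global pruning: a single shared scope collects all records.
        scope = "shared" if global_pruning else f"local_{layer_name}"
        ranking_scope.setdefault(scope, []).append((layer_name, "record"))
    if not global_pruning:
        assert all(len(v) <= 1 for v in ranking_scope.values())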
@@ -514,7 +518,7 @@
                     )
 
                     if n_pruned>0:
-                        topk_imp, _ = torch.topk(concat_imp, k=n_pruned, largest=False)
+                        topk_imp, topk_indices = torch.topk(concat_imp, k=n_pruned, largest=False)
                         thres = topk_imp[-1]
 
                 ##############################################
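The one-line change above keeps the indices that torch.topk returns alongside the values; they are reused further down for local pruning. A self-contained sketch of the thresholding step (made-up scores, variable names mirroring the diff):

    import torch

    concat_imp = torch.tensor([0.9, 0.1, 0.4, 0.1, 0.1])  # made-up scores with a tie
    n_pruned = 2

    # largest=False returns the k smallest values in ascending order,
    # together with their positions in concat_imp.
    topk_imp, topk_indices = torch.topk(concat_imp, k=n_pruned, largest=False)
    thres = topk_imp[-1]  # the largest of the k smallest scores (0.1 here)

    exact = topk_indices                                 # exactly n_pruned indices
    by_thres = (concat_imp <= thres).nonzero().view(-1)  # 3 indices: ties overshoot
    assert len(exact) == n_pruned and len(by_thres) > n_pruned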
@@ -529,7 +533,10 @@
                     pruning_indices = []
                     if len(records)>0 and n_pruned>0:
                         if ch_groups > 1: # re-compute importance for each channel group if grouping is enabled
-                            n_pruned_per_group = n_pruned #len((imp <= thres).nonzero().view(-1)) # if grouping is enabled, the imp tensor is the average importance of each group.
+                            if self.global_pruning: # for global pruning, n_pruned is shared across layers, so each layer must work out how many of its own channels/dims fall under the threshold
+                                n_pruned_per_group = len((imp <= thres).nonzero().view(-1))
+                            else: # for local pruning, n_pruned can be used directly since each scope contains exactly one layer
+                                n_pruned_per_group = n_pruned
                             if n_pruned_per_group>0:
                                 if self.round_to:
                                     n_pruned_per_group = self._round_to(n_pruned_per_group, group_size, self.round_to)
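To see what the new branch does for grouped channels, here is a toy version with stand-in tensors (imp is the per-group averaged importance, as the original comment notes; the numbers are invented):

    import torch

    imp = torch.tensor([0.2, 0.8, 0.1, 0.9])  # made-up per-group importance
    thres = torch.tensor(0.3)                 # shared threshold from topk
    n_pruned = 2                              # budget computed for this scope
    global_pruning = True

    if global_pruning:
        # The budget is shared across layers, so this layer prunes only the
        # groups whose importance falls under the shared threshold.
        n_pruned_per_group = len((imp <= thres).nonzero().view(-1))  # 2 here
    else:
        # A local scope holds a single layer, so the budget applies directly.
        n_pruned_per_group = n_pruned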
@@ -542,12 +549,10 @@
                                 sub_pruning_idxs = sub_imp_argsort[:n_pruned_per_group]+chg*group_size
                                 pruning_indices.append(sub_pruning_idxs)
                         else: # standard pruning
-                            _pruning_indices = (imp <= thres).nonzero().view(-1)
-                            if len(_pruning_indices)>n_pruned:
-                                # sort the selected scores and only take the top n_pruned indices
-                                selected_scores = imp[_pruning_indices]
-                                imp_argsort = torch.argsort(selected_scores)
-                                _pruning_indices = _pruning_indices[imp_argsort[:n_pruned]]
+                            if self.global_pruning: # select against the shared global threshold
+                                _pruning_indices = (imp <= thres).nonzero().view(-1)
+                            else: # local pruning: topk already yields exactly n_pruned indices, even when scores tie
+                                _pruning_indices = topk_indices
                             imp_argsort = torch.argsort(imp)
                             if len(_pruning_indices)>0 and self.round_to:
                                 n_pruned = len(_pruning_indices)
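Reading the last hunk: the old path always selected by threshold and then trimmed the result back to n_pruned with a secondary argsort, a step needed only because tied scores at the threshold can over-select. For local pruning, the topk indices already contain exactly n_pruned entries, so the trimming can be dropped. A hedged before/after sketch with toy numbers (for a single-layer local scope, where imp plays the role of concat_imp):

    import torch

    imp = torch.tensor([0.5, 0.1, 0.1, 0.1])  # three channels tie at the minimum
    n_pruned = 2
    topk_imp, topk_indices = torch.topk(imp, k=n_pruned, largest=False)
    thres = topk_imp[-1]

    # Old behavior: threshold selection over-selects on ties (3 indices here),
    # then argsort-based trimming cut the list back down to n_pruned.
    old = (imp <= thres).nonzero().view(-1)   # tensor([1, 2, 3])

    # New local-pruning behavior: topk already gives exactly n_pruned indices.
    new = topk_indices
    assert len(old) > n_pruned and len(new) == n_pruned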
