compress_llama_weights.py (forked from flexflow/FlexFlow)
import dataclasses

import torch
import numpy as np
from transformers import AutoModelForCausalLM


@dataclasses.dataclass
class CompressionConfig:
    """Group-wise quantization."""
    num_bits: int       # bit width of the quantized values (<= 8)
    group_size: int     # number of consecutive elements that share one range
    group_dim: int      # tensor dimension along which groups are formed
    symmetric: bool     # True: signed, scale-only; False: unsigned, min + scale
    enabled: bool = True
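

# Illustrative note (mine, not part of the original script): the settings used
# in __main__ below,
#     CompressionConfig(num_bits=8, group_size=32, group_dim=0, symmetric=False),
# quantize every run of 32 consecutive elements along dimension 0 of a weight
# to 8-bit values, with each group keeping its own minimum and scale.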


def compress(tensor, config):
    """Simulate group-wise quantization."""
    if not config.enabled:
        return tensor

    group_size, num_bits, group_dim, symmetric = (
        config.group_size, config.num_bits, config.group_dim, config.symmetric)
    assert num_bits <= 8

    original_shape = tensor.shape
    num_groups = (original_shape[group_dim] + group_size - 1) // group_size
    new_shape = (original_shape[:group_dim] + (num_groups, group_size) +
                 original_shape[group_dim+1:])

    # Pad the group dimension up to a multiple of group_size.
    pad_len = (group_size - original_shape[group_dim] % group_size) % group_size
    if pad_len != 0:
        pad_shape = original_shape[:group_dim] + (pad_len,) + original_shape[group_dim+1:]
        tensor = torch.cat([
            tensor,
            torch.zeros(pad_shape, dtype=tensor.dtype, device=tensor.device)],
            dim=group_dim)
    data = tensor.view(new_shape)

    # Quantize
    if symmetric:
        B = 2 ** (num_bits - 1) - 1
        scale = B / torch.max(data.abs(), dim=group_dim + 1, keepdim=True)[0]
        data = data * scale
        data = data.clamp_(-B, B).round_().to(torch.int8)
        return data, scale, original_shape
    else:
        B = 2 ** num_bits - 1
        mn = torch.min(data, dim=group_dim + 1, keepdim=True)[0]
        mx = torch.max(data, dim=group_dim + 1, keepdim=True)[0]

        scale = B / (mx - mn)
        data = data - mn
        data.mul_(scale)

        data = data.clamp_(0, B).round_().to(torch.uint8)
        return data, mn, scale, original_shape
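

# Worked example (mine, not part of the original script): with the asymmetric
# config used in __main__ (num_bits=8, group_size=32, group_dim=0), compressing
# a (4096, 4096) attention projection weight returns
#     data            uint8, shape (128, 32, 4096)   # 4096 / 32 = 128 groups
#     mn, scale       float, shape (128, 1, 4096)    # per-group min and scale
#     original_shape  (4096, 4096)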


def decompress(packed_data, config):
    """Simulate group-wise dequantization."""
    if not config.enabled:
        return packed_data

    group_size, num_bits, group_dim, symmetric = (
        config.group_size, config.num_bits, config.group_dim, config.symmetric)

    # Dequantize
    if symmetric:
        data, scale, original_shape = packed_data
        data = data / scale
    else:
        data, mn, scale, original_shape = packed_data
        data = data / scale
        data.add_(mn)

    # Unpad
    pad_len = (group_size - original_shape[group_dim] % group_size) % group_size
    if pad_len:
        padded_original_shape = (
            original_shape[:group_dim] +
            (original_shape[group_dim] + pad_len,) +
            original_shape[group_dim+1:])
        data = data.reshape(padded_original_shape)
        indices = [slice(0, x) for x in original_shape]
        return data[tuple(indices)].contiguous()
    else:
        return data.view(original_shape)
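

# Sketch of a round-trip sanity check (mine, not part of the original script;
# the helper name is hypothetical). It quantizes a random tensor, dequantizes
# it, and prints the largest absolute reconstruction error, which should be on
# the order of one per-group quantization step for 8-bit asymmetric grouping.
def _roundtrip_check(shape=(256, 64), num_bits=8, group_size=32):
    config = CompressionConfig(
        num_bits=num_bits, group_size=group_size, group_dim=0, symmetric=False)
    x = torch.randn(shape)
    packed = compress(x, config)
    y = decompress(packed, config)
    print("max abs reconstruction error:", (x - y).abs().max().item())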


if __name__ == "__main__":
    # torch.set_default_tensor_type(torch.HalfTensor)
    # torch.set_default_tensor_type(torch.cuda.HalfTensor)
    model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
    config = CompressionConfig(
        num_bits=8, group_size=32, group_dim=0, symmetric=False)
    for name, params in model.named_parameters():
        # Map Hugging Face parameter names to FlexFlow/LLaMA-style names.
        name = (
            name.replace(".", "_")
            .replace("self_attn", "attention")
            .replace("q_proj", "wq")
            .replace("k_proj", "wk")
            .replace("v_proj", "wv")
            .replace("o_proj", "wo")
            .replace("mlp", "feed_forward")
            .replace("gate_proj", "w1")
            .replace("down_proj", "w2")
            .replace("up_proj", "w3")
            .replace("input_layernorm", "attention_norm")
            .replace("post_attention_layernorm", "ffn_norm")
            .replace("embed_tokens", "tok_embeddings")
            .replace("lm_head", "output")
            .replace("model_", "")
        )
        # Compress only the attention projection, feed-forward, and output weights.
        if "feed_forward" in name or "output" in name or "attention_w" in name:
            data, mn, scale, original_shape = compress(params, config)
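        # Illustrative note (mine, not part of the original script): the
        # renaming above maps names such as
        #     "model.layers.0.self_attn.q_proj.weight" -> "layers_0_attention_wq_weight"
        #     "lm_head.weight"                         -> "output_weight"
        # Both of these match the filter above ("attention_w" / "output"), so
        # their weights are compressed; layer norms and the token embedding do
        # not match and are left uncompressed.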