reshard.py · 105 lines (92 loc) · 3.38 KB
#!/usr/bin/env python3
# modified from https://gist.github.com/benob/4850a0210b01672175942203aa36d300
# script to decompose/recompose a LLaMA model into a different number of shards
# note that it holds roughly 2x the full model in CPU memory
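#
# a minimal sketch of a typical invocation (directory names below are
# hypothetical, not taken from this repo), e.g. recomposing a 2-shard
# checkpoint into a single shard:
#
#   python reshard.py 1 models/13B models/13B-1shard
#
# the input directory is expected to contain params.json plus one or more
# *.pth checkpoint files (e.g. consolidated.00.pth); the output directory
# is created if it does not exist.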
import os
import sys
import glob
import json
import torch
if len(sys.argv) != 4:
    print(
        "usage: %s <new-shards> <input-model-path> <output-model-path>" % sys.argv[0],
        file=sys.stderr,
    )
    sys.exit(1)
num_shards = int(sys.argv[1])
input_model_dir = sys.argv[2]
output_model_dir = sys.argv[3]
with open(os.path.join(input_model_dir, "params.json"), "r") as fp:
    params = json.loads(fp.read())
assert params["dim"] % num_shards == 0, (
    "number of shards must divide the parameter dimension %d" % params["dim"]
)
print("loading...")
# load every input shard onto CPU; sorted() keeps consolidated.00, 01, ...
# in rank order so the concatenations below see a deterministic shard order
checkpoints = [
    torch.load(path, map_location=torch.device("cpu"))
    for path in sorted(glob.glob(os.path.join(input_model_dir, "*.pth")))
]
# how each parameter family is partitioned across shards (None = replicated)
layer_kind = {
    "tok_embeddings": "ParallelEmbedding",
    "output": "ColumnParallelLinear",
    "attention.wq": "ColumnParallelLinear",
    "attention.wk": "ColumnParallelLinear",
    "attention.wv": "ColumnParallelLinear",
    "attention.wo": "RowParallelLinear",
    "feed_forward.w1": "ColumnParallelLinear",
    "feed_forward.w2": "RowParallelLinear",
    "feed_forward.w3": "ColumnParallelLinear",
    "attention_norm": None,
    "ffn_norm": None,
    "norm": None,
    "rope.freqs": None,
}
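# how the kinds above map onto tensor axes (a sketch; the shapes are
# hypothetical 13B-style numbers, dim=5120 with 2 input shards):
#   - ColumnParallelLinear: each shard holds a slice of dim 0, e.g. two
#     "attention.wq" shards of shape (2560, 5120) merge to (5120, 5120)
#     via torch.cat(..., 0) and are then re-sliced evenly along dim 0.
#   - ParallelEmbedding / RowParallelLinear: same idea, but along dim 1.
#   - None: the tensor is replicated, so shard 0's copy is reused as-is.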
output = [dict() for x in range(num_shards)]
print("converting...")
for key in checkpoints[0].keys():
    tensors = [m[key] for m in checkpoints]
    print(key)
    print("  in shapes=", [p.shape for p in tensors])
    for pattern, kind in layer_kind.items():
        if key.replace(".weight", "").endswith(pattern):
            print("  kind=", kind)
            if kind == "ColumnParallelLinear":
                # merge the input shards along dim 0, then re-slice evenly
                with torch.no_grad():
                    merged = torch.cat(tensors, 0)
                    slice_size = merged.shape[0] // num_shards
                    for rank in range(num_shards):
                        output[rank][key] = (
                            merged[slice_size * rank : slice_size * (rank + 1), :]
                            .clone()
                            .detach()
                        )
            elif kind in ("ParallelEmbedding", "RowParallelLinear"):
                # same as above, but these layers are partitioned along dim 1
                with torch.no_grad():
                    merged = torch.cat(tensors, 1)
                    slice_size = merged.shape[1] // num_shards
                    for rank in range(num_shards):
                        output[rank][key] = (
                            merged[:, slice_size * rank : slice_size * (rank + 1)]
                            .clone()
                            .detach()
                        )
            else:
                # replicated tensor: copy shard 0's value to every output shard
                for rank in range(num_shards):
                    output[rank][key] = tensors[0]
            print(
                "  out shapes=", [output[rank][key].shape for rank in range(num_shards)]
            )
            print()
            break
    else:
        raise Exception("parameter name not recognized: %s" % key)
print("saving...")
os.makedirs(output_model_dir, exist_ok=True)
with open(os.path.join(output_model_dir, "params.json"), "w") as fp:
    fp.write(json.dumps(params))
for rank in range(num_shards):
    print(" ", rank)
    torch.save(
        output[rank], os.path.join(output_model_dir, "consolidated.%02d.pth" % rank)
    )
print("done.")