-
Notifications
You must be signed in to change notification settings - Fork 117
/
example.py
113 lines (91 loc) · 3.49 KB
/
example.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the GNU General Public License version 3.
from typing import Tuple
import os
import sys
import torch
import fire
import time
import json
import pyarrow as pa
from pathlib import Path
from llama import ModelArgs, Transformer, Tokenizer, LLaMA
def _convert_checkpoints_to_arrow(ckpt_path: Path, arrow_dir: Path) -> None:
    """Dump every tensor of every ``*.pth`` shard in *ckpt_path* into *arrow_dir*.

    The resulting layout is ``arrow_dir/<shard index>/<state-dict key>``, one
    Arrow IPC tensor file per state-dict entry.

    Raises:
        FileNotFoundError: if *ckpt_path* contains no ``*.pth`` file.
    """
    checkpoints = sorted(ckpt_path.glob("*.pth"))
    if not checkpoints:
        raise FileNotFoundError(f"no *.pth checkpoint files found in {ckpt_path}")

    print('Converting checkpoints to arrow format')
    # Write into a staging directory and rename it once everything succeeded,
    # so an interrupted conversion never leaves a half-filled `arrow_dir`
    # that later runs would mistake for a complete cache.
    staging_dir = arrow_dir.with_name(arrow_dir.name + '.tmp')
    for ckpt_file in checkpoints:
        print(ckpt_file)
        # Second-to-last dot-separated part of the file name,
        # e.g. "consolidated.00.pth" -> "00".
        index = ckpt_file.name.split('.')[-2]
        # NOTE: torch.load unpickles the file; only point this at checkpoints
        # obtained from a trusted source.
        ckpt = torch.load(ckpt_file, map_location='cpu')
        shard_dir = staging_dir / index
        shard_dir.mkdir(parents=True, exist_ok=True)
        for key, tensor in ckpt.items():
            arrow_tensor = pa.Tensor.from_numpy(tensor.numpy())
            with pa.output_stream(shard_dir / key) as sink:
                pa.ipc.write_tensor(arrow_tensor, sink)
        # Drop the shard before loading the next one to keep peak memory down.
        del ckpt
    staging_dir.rename(arrow_dir)


def load(
    ckpt_dir: str,
    tokenizer_path: str,
    max_seq_len: int,
    max_batch_size: int,
) -> LLaMA:
    """Build a ``LLaMA`` generator from the checkpoint stored in *ckpt_dir*.

    On the first call the ``*.pth`` checkpoints are converted to per-tensor
    Arrow files under ``<ckpt_dir>/arrow``. The weights of shard ``00`` are
    then memory-mapped from that directory and copied into a freshly built
    ``Transformer``.

    Args:
        ckpt_dir: Directory holding ``params.json`` and the ``*.pth``
            checkpoint file(s); a leading ``~`` is expanded.
        tokenizer_path: Path handed to ``Tokenizer(model_path=...)``.
        max_seq_len: Forwarded to ``ModelArgs``.
        max_batch_size: Forwarded to ``ModelArgs``.

    Returns:
        A ``LLaMA`` instance wrapping the loaded model and tokenizer.

    Raises:
        FileNotFoundError: if no checkpoint is available to convert, or the
            Arrow cache contains no tensors for shard ``00``.
    """
    start_time = time.time()

    # Expand "~" once and use the same path everywhere below.
    ckpt_path = Path(ckpt_dir).expanduser()
    arrow_dir = ckpt_path / 'arrow'
    if not arrow_dir.exists():
        _convert_checkpoints_to_arrow(ckpt_path, arrow_dir)

    with open(ckpt_path / "params.json", "r") as params_file:
        params = json.load(params_file)

    print("Loading checkpoint")
    segments = sorted((arrow_dir / '00').glob("*"))
    if not segments:
        raise FileNotFoundError(
            f"no arrow tensors found in {arrow_dir / '00'}; "
            f"delete {arrow_dir} to force a fresh conversion"
        )

    mapped_files = []
    try:
        checkpoint = {}
        for seg in segments:
            mapped = pa.memory_map(str(seg))
            mapped_files.append(mapped)
            # The file name is the state-dict key it was written under.
            checkpoint[seg.name] = torch.from_numpy(
                pa.ipc.read_tensor(mapped).to_numpy()
            )

        # New parameters are created as bfloat16 tensors from here on.
        torch.set_default_tensor_type(torch.BFloat16Tensor)
        model_args: ModelArgs = ModelArgs(
            max_seq_len=max_seq_len, max_batch_size=max_batch_size, **params
        )

        print("Loading tokenizer")
        tokenizer = Tokenizer(model_path=tokenizer_path)
        model_args.vocab_size = tokenizer.n_words

        print("Loading model")
        model = Transformer(model_args)
        # Load from the memory-mapped tensors instead of unpickling the .pth
        # file a second time. strict=False tolerates keys that exist on only
        # one side (checkpoint vs. model).
        model.load_state_dict(checkpoint, strict=False)
        del checkpoint
    finally:
        # Release the memory maps even when building the model failed.
        for mapped in mapped_files:
            mapped.close()

    generator = LLaMA(model, tokenizer)
    print(f"Loaded in {time.time() - start_time:.2f} seconds")
    return generator
def main(
    ckpt_dir: str,
    tokenizer_path: str,
    temperature: float = 0.8,
    top_p: float = 0.95,  # use 0.95 or so for top_p sampler, and 0.0 for top_k sampler
    top_k: int = 40,
    repetition_penalty: float = (1.0 / 0.85),  # 1.0 to disable repetition_penalty
    sampler: str = 'top_p',  # top_p or top_k
    max_seq_len: int = 2048,
    max_batch_size: int = 1,
):
    """Load the model, complete one hard-coded prompt and print every result.

    Args:
        ckpt_dir: Checkpoint directory, forwarded to ``load``.
        tokenizer_path: Tokenizer model path, forwarded to ``load``.
        temperature: Forwarded to ``generate``.
        top_p: Forwarded to ``generate``.
        top_k: Forwarded to ``generate``.
        repetition_penalty: Forwarded to ``generate``.
        sampler: Forwarded to ``generate``; ``'top_p'`` or ``'top_k'``.
        max_seq_len: Forwarded to ``load`` and also used as ``max_gen_len``.
        max_batch_size: Forwarded to ``load``.
    """
    llama = load(ckpt_dir, tokenizer_path, max_seq_len, max_batch_size)

    prompts = [
        # "I believe the meaning of life is",
        """Write the Python code with detailed comments to generate 256 random integers in the range from -128 to 512, inclusive.
\\begin{code}\n""",
    ]

    # Collect the sampling knobs once so the generate() call stays short.
    sampling_options = {
        "max_gen_len": max_seq_len,
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,
        "repetition_penalty": repetition_penalty,
        "sampler": sampler,
    }
    completions = llama.generate(prompts, **sampling_options)

    separator = "\n==================================\n"
    for completion in completions:
        print(separator)
        print(completion)
        print(separator)
# Script entry point: when the file is executed directly (not imported),
# pass `main` to python-fire so that its parameters can be supplied as
# command-line arguments.
if __name__ == "__main__":
    fire.Fire(main)