This repository has been archived by the owner on Mar 14, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathoptimize.py
48 lines (39 loc) · 1.58 KB
/
optimize.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
"""Freezes and optimize the model. Use after training."""
import argparse
import torch
from model.model import SpeechRecognition
from collections import OrderedDict
import yaml
def trace(model, n_feats=81, seq_len=300):
    """Trace *model* into a TorchScript module using a random example input.

    The model is put into eval mode first, so the recorded graph reflects
    inference-time behaviour of layers such as dropout or batch-norm, if the
    model has any.

    Args:
        model: Module whose ``forward`` is called as ``forward(x, hidden)`` and
            which exposes ``_init_hidden(batch_size)`` returning the initial
            hidden state.
        n_feats: Size of the second dimension of the example input. Defaults
            to 81, the previously hard-coded value. Presumably this is the
            number of feature bins and should match the model's configuration
            (TODO confirm).
        seq_len: Size of the third dimension of the example input. Defaults
            to 300. Presumably this is the number of time frames.

    Returns:
        The module produced by ``torch.jit.trace``.
    """
    model.eval()
    # Tracing only records the operations executed; the example values just
    # need a usable shape, so random contents are fine.
    example_input = torch.rand(1, n_feats, seq_len)
    hidden = model._init_hidden(1)  # batch size 1, same as example_input
    return torch.jit.trace(model, (example_input, hidden))
def strip_model_prefix(state_dict, prefix="model."):
    """Return a copy of *state_dict* with a leading *prefix* removed from keys.

    Keys in the training checkpoint's ``state_dict`` carry a ``model.`` prefix
    that the bare network's ``load_state_dict`` does not expect. Presumably
    the training wrapper stores the network as ``self.model``; TODO confirm.

    Args:
        state_dict: Mapping of parameter/buffer names to values.
        prefix: Leading key prefix to remove. Defaults to ``"model."``.

    Returns:
        A new ``OrderedDict`` in the input's key order. Keys that do not start
        with *prefix* are kept unchanged. The input is not modified.
    """
    stripped = OrderedDict()
    for key, value in state_dict.items():
        # Strip only a *leading* prefix. A substring replace would also mangle
        # inner names, e.g. "model.language_model.weight" -> "language_weight".
        if key.startswith(prefix):
            key = key[len(prefix):]
        stripped[key] = value
    return stripped


def main():
    """Load a trained checkpoint, trace the model, and save it as TorchScript."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_checkpoint', type=str, default=None, required=True,
                        help='Checkpoint of model to optimize')
    parser.add_argument('--save_path', type=str, default=None, required=True,
                        help='path to save optimized model')
    parser.add_argument('--params', default=None, required=True,
                        type=str, help='YAML file to load parameters')
    args = parser.parse_args()

    print("loading model from", args.model_checkpoint)
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(args.model_checkpoint,
                            map_location=torch.device('cpu'))

    with open(args.params, 'r') as stream:
        config = yaml.safe_load(stream)
    # Positional unpack: assumes the YAML has exactly three top-level sections
    # and that the model hyper-parameters are the second one. TODO: confirm,
    # and prefer looking the section up by key.
    _, h_params, _ = config.values()
    model = SpeechRecognition(**h_params)

    model.load_state_dict(strip_model_prefix(checkpoint['state_dict']))

    print("tracing model...")
    traced_model = trace(model)
    print("saving to", args.save_path)
    traced_model.save(args.save_path)
    print("Done!")


if __name__ == "__main__":
    main()