-
Notifications
You must be signed in to change notification settings - Fork 561
/
load_model.py
53 lines (42 loc) · 1.29 KB
/
load_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
"""
Script for downloading model weights.
"""
import argparse
import numpy as np
def parse_args():
    """Read the command-line options for this script.

    Returns
    -------
    argparse.Namespace
        Namespace holding a single attribute, ``model`` (str), taken from the
        required ``--model`` option. Missing or malformed options make
        argparse print usage and exit.
    """
    cli = argparse.ArgumentParser(
        description="Download model",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    cli.add_argument("--model", type=str, required=True, help="model name")
    return cli.parse_args()
def main():
    """Prepare the ``--model`` network with pretrained weights in every backend.

    The model name is read from the command line via ``parse_args``. Each
    backend's ``prepare_model`` helper is then imported and called, in this
    order: Gluon, PyTorch, Chainer, TensorFlow 2. Every call gets
    ``use_pretrained=True`` and an empty ``pretrained_model_file_path``.

    Per the module docstring, the point is to download the model weights.
    What each helper does with these arguments is defined elsewhere in the
    project and is not visible here.
    """
    # Keyword arguments shared by every backend's prepare_model call; they are
    # unpacked first so each call sees the same keyword order as before.
    common = dict(
        model_name=parse_args().model,
        use_pretrained=True,
        pretrained_model_file_path="",
    )

    # Each backend is imported right before it is used, so importing this
    # module only pulls in argparse and numpy.
    from gluon.utils import prepare_model as gluon_prepare
    gluon_prepare(**common, dtype=np.float32)

    from pytorch.utils import prepare_model as torch_prepare
    torch_prepare(**common, use_cuda=False)

    from chainer_.utils import prepare_model as chainer_prepare
    chainer_prepare(**common)

    from tensorflow2.utils import prepare_model as tf2_prepare
    tf2_prepare(**common, use_cuda=False)


# Run only when executed as a script, not when imported.
if __name__ == '__main__':
    main()