forked from ultralytics/yolov5
-
Notifications
You must be signed in to change notification settings - Fork 0
/
patchmentation_yolov5.py
92 lines (76 loc) · 3.71 KB
/
patchmentation_yolov5.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import patchmentation_utils as utils
from typing import List
def get_args_parser(argv=None):
    """Define the command-line interface and return the parsed arguments.

    Despite its name, this function returns the parsed ``argparse.Namespace``
    rather than the parser object itself.

    Args:
        argv: Optional list of argument strings to parse. When ``None`` (the
            default) argparse falls back to ``sys.argv[1:]``, which matches
            the behaviour of calling this function with no arguments.

    Returns:
        argparse.Namespace holding one attribute per option defined below.
    """
    import argparse

    parser = argparse.ArgumentParser()

    # Which dataset versions to process; ``main`` iterates over this list.
    parser.add_argument('--version', type=str, default=[], nargs='+', help='dataset version')

    # Training / evaluation steps and their settings.
    parser.add_argument('--train', action='store_true', help='train yolov5')
    parser.add_argument('--test', action='store_true', help='test yolov5')
    parser.add_argument('--epochs', type=int, default=None, help='number of training epoch')
    parser.add_argument('--batch-size', type=int, default=None, help='number of batch size')
    # ``main`` matches --data entries to --version entries by position.
    parser.add_argument('--data', type=str, default=[], nargs='+', help='dataset.yaml')
    parser.add_argument('--hyp', type=str, default='data/hyps/hyp.patchmentation.yaml', help='hyperparameters')
    parser.add_argument('--weights', type=str, default='yolov5s.pt', help='initial weights')
    parser.add_argument('--project', type=str, default=None, help='folder to save project')

    # Output management steps.
    parser.add_argument('--remove', action='store_true', help='remove output')
    parser.add_argument('--zip', action='store_true', help='zip output')
    parser.add_argument('--unzip', action='store_true', help='unzip output')
    parser.add_argument('--remove-zip', action='store_true', help='remove zip output')
    parser.add_argument('--upload', action='store_true', help='upload output zip')
    # ``main`` matches --download entries to --version entries by position.
    parser.add_argument('--download', type=str, default=[], nargs='+', help='download output zip')
    parser.add_argument('--plot', action='store_true', help='extra output plots')
    parser.add_argument('--overwrite', action='store_true', help='overwrite existing output / zip')

    return parser.parse_args(argv)
def project(args):
    """Return the default project folder derived from ``args.weights``.

    The folder is ``'runs/patchmentation/'`` followed by whatever
    ``utils._remove_ext`` returns for the weights value.
    """
    weights_stem = utils._remove_ext(args.weights)
    return 'runs/patchmentation/' + weights_stem
def _resolve_data(args, index, version):
    """Return the dataset yaml to use for ``version``.

    Uses the ``--data`` entry at the same position as the version when one was
    supplied, otherwise falls back to ``utils.get_file_yaml(version)``.
    """
    if len(args.data) > index:
        return args.data[index]
    return utils.get_file_yaml(version)


def _require_option(value, flag, action):
    """Raise ``ValueError`` when an option that ``action`` needs is missing."""
    if value is None:
        raise ValueError('%s is required with %s' % (flag, action))


def main(args):
    """Run every requested step for each dataset version in ``args.version``.

    For each version the enabled steps run in this fixed order: train, test,
    zip, upload, remove-zip, download, unzip, plot, remove. ``--overwrite``
    makes a step delete its previous output before running.

    Args:
        args: Parsed command-line namespace as produced by
            ``get_args_parser``. ``args.project`` is filled in with the
            default from ``project(args)`` when it is ``None``.

    Raises:
        ValueError: If ``--train`` is used without ``--batch-size`` or
            ``--epochs``, or ``--test`` is used without ``--batch-size``.
    """
    if args.project is None:
        args.project = project(args)

    for index, version in enumerate(args.version):
        print(version)

        if args.train:
            # Validate before the overwrite below so a missing option cannot
            # wipe existing output and then abort without training.
            _require_option(args.batch_size, '--batch-size', '--train')
            _require_option(args.epochs, '--epochs', '--train')
            if args.overwrite:
                utils.remove(args.project, version)
            data = _resolve_data(args, index, version)
            utils.train(data, args.hyp, args.weights, args.epochs, args.batch_size, args.project, version)
            # NOTE(review): plotting right after training is treated as part
            # of the train step (separate from the --plot step) - confirm.
            utils.plot(args.project, version)

        if args.test:
            # Same ordering rationale as for training: check first, delete after.
            _require_option(args.batch_size, '--batch-size', '--test')
            if args.overwrite:
                utils.remove_test(args.project, version)
            data = _resolve_data(args, index, version)
            utils.test(data, utils.get_weights(args.project, version), args.batch_size, args.project, version)

        if args.zip:
            if args.overwrite:
                utils.remove_zip(args.project, version)
            utils.zip(args.project, version)

        if args.upload:
            utils.upload(args.project, version)

        if args.remove_zip:
            utils.remove_zip(args.project, version)

        # --download entries are matched to versions by position; versions
        # without a corresponding entry are simply not downloaded.
        if len(args.download) > index:
            if args.overwrite:
                utils.remove_zip(args.project, version)
            utils.download(args.project, version, args.download[index])

        if args.unzip:
            if args.overwrite:
                utils.remove_folder_output(args.project, version)
            utils.unzip(args.project, version)

        if args.plot:
            utils.plot(args.project, version)

        if args.remove:
            utils.remove(args.project, version)
if __name__ == '__main__':
    # Script entry point: parse the command line, echo the parsed namespace
    # for the run log, then execute the selected steps.
    args = get_args_parser()
    print(args)
    main(args)