-
Notifications
You must be signed in to change notification settings - Fork 0
/
val_cls.py
69 lines (56 loc) · 2.04 KB
/
val_cls.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import os, sys
from pathlib import Path
# Make the directory above this script's folder importable, so that the
# `lymonet` import below resolves when the script is run directly.
here = Path(__file__).parent
p = str(here.parent)
if p not in sys.path:  # avoid adding a duplicate entry
    sys.path.append(p)
from lymonet.apis import YOLO, LYMO
from dataclasses import dataclass, field
import hai
def run(args):
    """Validate a LYMO model and print the resulting metrics.

    Args:
        args: Options object (normally an ``Args`` instance) exposing the
            attributes ``model``, ``data``, ``split``, ``batch``, ``imgsz``,
            ``conf``, ``half`` and ``device``. An optional ``weights``
            attribute, when present and truthy, is loaded into the model
            before validation.

    The options object is only read; it is never modified.
    """
    # Read the options through attribute access instead of popping them out
    # of ``args.__dict__``. Popping deleted the instance attributes from the
    # caller's object, after which ``args.model`` would silently resolve to
    # the dataclass-level default rather than the value that was parsed.
    model_name_or_cfg = args.model
    model_weights = getattr(args, 'weights', None)

    model = LYMO(model_name_or_cfg)
    if model_weights:
        model = model.load(model_weights)

    # Evaluate the model's performance on the requested split.
    # Per the original author's note, `results` is validator.metrics.
    # NOTE(review): `workers`, `project` and `name` exist on Args but are not
    # forwarded here -- confirm whether that is intentional.
    results = model.val(
        data=args.data,
        split=args.split,
        batch=args.batch,
        imgsz=args.imgsz,
        conf=args.conf,
        half=args.half,
        device=args.device,
    )
    print(results)
@dataclass
class Args:
    # Options for the classification validation script. An instance is built
    # by hai.parse_args_into_dataclasses(Args) in the __main__ guard and then
    # handed to run(). Field order and defaults are kept exactly as before.

    # ---- model / task selection ------------------------------------------
    # Value passed to LYMO(...) by run().
    model: str = "/home/tml/VSProjects/lymonet/lymonet/ultralytics/runs/classify/yolov8s_class10/weights/best.pt"
    mode: str = "val"
    task: str = "classify"
    val: bool = True
    # Alternative settings left disabled by the original author:
    #   model: str = f'{here}/lymonet/configs/yolov8s_1MHSA_CA.yaml'
    #   model: str = "yolov8x.yaml"
    #   weights: str = 'yolov8n.pt'
    #   data: str = f'{here}/lymonet/configs/lymo_mixed2.yaml'

    # ---- dataset ----------------------------------------------------------
    data: str = "/data/tml/lymonet/lymo_yolo_square1"
    split: str = "val"
    #   epochs: int = 300   (disabled)

    # ---- evaluation settings ----------------------------------------------
    batch: int = 32
    imgsz: int = 640
    workers: int = 16
    conf: float = 0.001  # confidence threshold
    device: str = "0"  # GPU id
    half: bool = False  # whether to use FP16

    # ---- output naming ----------------------------------------------------
    project: str = "runs/val"
    name: str = "lymonet"
    #   merge_type: str = None   (disabled)
if __name__ == '__main__':
    # Script entry point: build the Args options via hai's parser and
    # hand them straight to run().
    run(hai.parse_args_into_dataclasses(Args))