forked from haofeixu/aanet
-
Notifications
You must be signed in to change notification settings - Fork 0
/
aanet_train.sh
83 lines (79 loc) · 2.13 KB
/
aanet_train.sh
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
#!/usr/bin/env bash
# Stage 1 - train on the Scene Flow training set.
#
# Invokes train.py with CUDA devices 0-3 visible and passes
# checkpoints/aanet_sceneflow as the checkpoint directory. The KITTI-mix
# stage later in this script loads checkpoints/aanet_sceneflow/aanet_best.pth,
# so a failure here aborts the script instead of letting the following
# stages start from a missing or stale checkpoint.
sceneflow_args=(
  # NOTE(review): the mode is `val` although this is a training run -
  # presumably train.py trains and validates in this mode; confirm there.
  --mode val
  --data_dir data/SceneFlow
  --checkpoint_dir checkpoints/aanet_sceneflow

  # Batch sizes for the training and validation loaders.
  --batch_size 64
  --val_batch_size 64

  # Image sizes passed for training and for validation.
  --img_height 288
  --img_width 576
  --val_img_height 576
  --val_img_width 960

  # Network configuration flags.
  --feature_type aanet
  --feature_pyramid_network

  # Schedule: milestone epochs and the total number of epochs.
  --milestones 20,30,40,50,60
  --max_epoch 64
)

CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py "${sceneflow_args[@]}" || {
  # Propagate the trainer's own exit status so callers can tell what failed.
  status=$?
  printf 'Scene Flow training failed (exit %d); aborting.\n' "$status" >&2
  exit "$status"
}
# Stage 2 - train on the mixed KITTI 2012 + KITTI 2015 training set.
#
# Invokes train.py with CUDA devices 0-3 visible, passing the Scene Flow
# checkpoint (checkpoints/aanet_sceneflow/aanet_best.pth) as the pretrained
# model and checkpoints/aanet_kittimix as the checkpoint directory. Both
# KITTI stages later in this script load
# checkpoints/aanet_kittimix/aanet_latest.pth, so a failure here aborts the
# script instead of letting them start from a missing or stale checkpoint.
kittimix_args=(
  --data_dir data/KITTI
  --dataset_name KITTI_mix
  --checkpoint_dir checkpoints/aanet_kittimix
  --pretrained_aanet checkpoints/aanet_sceneflow/aanet_best.pth

  # Batch sizes for the training and validation loaders.
  --batch_size 32
  --val_batch_size 32

  # Image sizes passed for training and for validation.
  --img_height 336
  --img_width 960
  --val_img_height 384
  --val_img_width 1248

  # Network configuration flags.
  --feature_type aanet
  --feature_pyramid_network

  # NOTE(review): presumably enables pseudo ground-truth supervision;
  # confirm the exact meaning in train.py.
  --load_pseudo_gt

  # Schedule: milestone epochs, total epochs, checkpoint save frequency.
  --milestones 400,600,800,900
  --max_epoch 1000
  --save_ckpt_freq 100

  # Validation is switched off for this stage.
  --no_validate
)

CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py "${kittimix_args[@]}" || {
  # Propagate the trainer's own exit status so callers can tell what failed.
  status=$?
  printf 'KITTI-mix training failed (exit %d); aborting.\n' "$status" >&2
  exit "$status"
}
# Stage 3 - train on the KITTI 2015 training set.
#
# Invokes train.py with CUDA devices 0-3 visible, passing the KITTI-mix
# checkpoint (checkpoints/aanet_kittimix/aanet_latest.pth) as the pretrained
# model and checkpoints/aanet_kitti15 as the checkpoint directory.
kitti15_args=(
  --data_dir data/KITTI/kitti_2015/data_scene_flow
  --dataset_name KITTI2015
  --mode train_all
  --checkpoint_dir checkpoints/aanet_kitti15
  --pretrained_aanet checkpoints/aanet_kittimix/aanet_latest.pth

  # Batch sizes for the training and validation loaders.
  --batch_size 32
  --val_batch_size 32

  # Image sizes passed for training and for validation.
  --img_height 336
  --img_width 960
  --val_img_height 384
  --val_img_width 1248

  # Network configuration flags.
  --feature_type aanet
  --feature_pyramid_network

  # NOTE(review): loss/supervision switches - exact semantics live in
  # train.py; confirm there before changing them.
  --load_pseudo_gt
  --highest_loss_only

  # Optimisation schedule: learning rate, milestone epochs, total epochs,
  # checkpoint save frequency.
  --learning_rate 1e-4
  --milestones 400,600,800,900
  --max_epoch 1000
  --save_ckpt_freq 100

  # Validation is switched off for this stage.
  --no_validate
)

CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py "${kitti15_args[@]}"
# Stage 4 - train on the KITTI 2012 training set.
#
# Invokes train.py with CUDA devices 0-3 visible, passing the KITTI-mix
# checkpoint (checkpoints/aanet_kittimix/aanet_latest.pth) as the pretrained
# model and checkpoints/aanet_kitti12 as the checkpoint directory.
kitti12_args=(
  --data_dir data/KITTI/kitti_2012/data_stereo_flow
  --dataset_name KITTI2012
  --mode train_all
  --checkpoint_dir checkpoints/aanet_kitti12
  --pretrained_aanet checkpoints/aanet_kittimix/aanet_latest.pth

  # Batch sizes for the training and validation loaders.
  --batch_size 32
  --val_batch_size 32

  # Image sizes passed for training and for validation.
  --img_height 336
  --img_width 960
  --val_img_height 384
  --val_img_width 1248

  # Network configuration flags.
  --feature_type aanet
  --feature_pyramid_network

  # NOTE(review): loss/supervision switches - exact semantics live in
  # train.py; confirm there before changing them.
  --load_pseudo_gt
  --highest_loss_only

  # Optimisation schedule: learning rate, milestone epochs, total epochs,
  # checkpoint save frequency.
  --learning_rate 1e-4
  --milestones 400,600,800,900
  --max_epoch 1000
  --save_ckpt_freq 100

  # Validation is switched off for this stage.
  --no_validate
)

CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py "${kitti12_args[@]}"