train.py
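
# Trains one of the segmentation models (PHISeg, Probabilistic U-Net, U-Net,
# U-Net with MC dropout) on the FIVES dataset; the model is chosen via --method.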
import argparse
import subprocess

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from torch.utils.data import DataLoader

from src.data import FIVES
from src.models import PHISeg, ProbUNet, UNet, UNetMCDropout

# Path to the FIVES dataset (HDF5 file).
FIVES_path = "/FIVES.h5"


def get_git_revision_short_hash() -> str:
    """Return the abbreviated hash of the current git HEAD commit."""
    return (
        subprocess.check_output(["git", "rev-parse", "--short", "HEAD"])
        .decode("ascii")
        .strip()
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Main trainer file for all models.")
    parser.add_argument(
        "--random-seed",
        dest="random_seed",
        action="store",
        default=0,
        type=int,
        help="Random seed for pl.seed_everything function.",
    )
    parser.add_argument(
        "--method",
        dest="method",
        action="store",
        default=None,
        type=str,
        help="The method to use [probunet, phiseg, unet, unet-mcdropout].",
    )
    parser.add_argument(
        "--batch-size",
        dest="batch_size",
        action="store",
        default=16,
        type=int,
        help="Batch size for training.",
    )
    args = parser.parse_args()

    # Compose the experiment name from the git hash, seed, method, and batch size.
    git_hash = get_git_revision_short_hash()
    human_readable_extra = ""
    experiment_name = "-".join(
        [
            git_hash,
            f"seed={args.random_seed}",
            args.method,
            human_readable_extra,
            f"bs={args.batch_size}",
        ]
    )

    pl.seed_everything(seed=args.random_seed)

    # Both splits load from the same HDF5 file; the dataset class name is assumed to be FIVES.FIVES.
    train_dataset = FIVES.FIVES(file_path=FIVES_path, t="train", transform=None)
    valid_dataset = FIVES.FIVES(file_path=FIVES_path, t="val", transform=None)
    print(f"Training dataset length: {len(train_dataset)}")
    print(f"Validation dataset length: {len(valid_dataset)}")

    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
    valid_loader = DataLoader(
        valid_dataset, batch_size=args.batch_size, shuffle=False, drop_last=False
    )

    logger = TensorBoardLogger(
        save_dir="./runsFIVES", name=experiment_name, default_hp_metric=False
    )

    checkpoint_callbacks = [
        ModelCheckpoint(
            monitor="val/total_loss",
            filename="best-loss-{epoch}-{step}",
        ),
        ModelCheckpoint(
            monitor="val/dice",
            filename="best-dice-{epoch}-{step}",
            mode="max",
        ),
    ]

    if args.method == "phiseg":
        model = PHISeg(
            total_levels=7, latent_levels=5, zdim=2, num_classes=2, beta=1.0
        )  # changed number of classes
    elif args.method == "probunet":
        model = ProbUNet(
            total_levels=7, zdim=6, num_classes=2, beta=1.0
        )  # changed number of classes
    elif args.method == "unet-mcdropout":
        model = UNetMCDropout(total_levels=7, num_classes=2)
    elif args.method == "unet":
        model = UNet(total_levels=7, num_classes=2)
    else:
        raise ValueError(f"Unknown method: {args.method}.")

    # Train on a single GPU, logging every 50 steps and validating four times per epoch.
    trainer = pl.Trainer(
        logger=logger,
        val_check_interval=0.25,
        log_every_n_steps=50,
        accelerator="gpu",
        devices=1,
        callbacks=checkpoint_callbacks,
    )
    trainer.fit(
        model=model, train_dataloaders=train_loader, val_dataloaders=valid_loader
    )
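
# Example invocation (a sketch; assumes the FIVES.h5 file and the src package are
# available at the paths used above, and that a single GPU is visible):
#
#   python train.py --method phiseg --batch-size 16 --random-seed 0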