-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrun.py
30 lines (24 loc) · 922 Bytes
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import warnings
warnings.filterwarnings("ignore")
import numpy as np
from src.experiment_manager import ExperimentManager
from src.dataset import UCSDataset
from src.model import UCSModel
def main():
    """Train a UCSModel, predict over the whole dataset, then postprocess.

    Stages, in order:
      1. Create the ``ExperimentManager`` and read its options via
         ``get_opt()``.
      2. Build the training ``UCSDataset``, call ``set_train(True)`` on it and
         train a ``UCSModel``. The model is constructed with the shape of the
         training dataset's ``gene_map``.
      3. Build two prediction datasets: one with shift ``0`` and one with
         shift ``opt.patch_size // 2``. Call ``set_train(False)`` on both and
         pass them together to ``model.predict_whole``.
      4. Call ``model.postprocess()``.
    """
    manager = ExperimentManager()
    opt = manager.get_opt()

    # Initializing the dataset
    print("Initializing train dataset")
    train_dataset = UCSDataset(manager)
    train_dataset.set_train(True)
    gene_map_shape = train_dataset.gene_map.shape

    # Initializing the model; its constructor needs the gene_map shape.
    model = UCSModel(manager, gene_map_shape)
    model.train_model(train_dataset)

    print("Initializing pred dataset with shift 0 and shift patch size//2")
    # NOTE(review): the second UCSDataset argument is presumably a spatial
    # patch offset. Predicting with an unshifted and a half-patch-shifted
    # tiling likely lets predict_whole combine overlapping patches; confirm
    # in src.dataset / src.model.
    pred_dataset_0 = UCSDataset(manager, 0)
    pred_dataset_1 = UCSDataset(manager, opt.patch_size // 2)
    pred_dataset_0.set_train(False)
    pred_dataset_1.set_train(False)
    model.predict_whole(pred_dataset_0, pred_dataset_1)
    model.postprocess()


if __name__ == "__main__":
    main()