-
Notifications
You must be signed in to change notification settings - Fork 11
/
vnet2d_inference.py
51 lines (42 loc) · 2.2 KB
/
vnet2d_inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import os
# These CUDA environment variables are set before TensorFlow is imported so they are
# already in place when TensorFlow initialises CUDA; setting them later may have no effect.
# Enumerate GPUs by PCI bus ID instead of CUDA's default fastest-first ordering.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# "-1" is not a valid device index, so CUDA exposes no GPUs and inference runs on the CPU.
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
from tensorflow.python.client import device_lib
# Print the devices TensorFlow can see at import time (a sanity check that no GPU is listed).
print(device_lib.list_local_devices())
from Vnet2d.vnet_model import Vnet2dModule, AGVnet2dModule
from dataprocess.utils import calcu_iou
import cv2
import os
import numpy as np
def predict_test():
Vnet2d = Vnet2dModule(512, 512, channels=1, costname="dice coefficient", inference=True,
model_path="log\segmeation\\vnet2d\model\Vnet2d.pd")
test_image_path = r"E:\MedicalData\TNSCUI2020\TNSCUI2020_test\image"
test_mask_path = r"E:\MedicalData\TNSCUI2020\TNSCUI2020_test\vnet_mask"
allimagefiles = os.listdir(test_image_path)
for imagefile in allimagefiles:
imagefilepath = os.path.join(test_image_path, imagefile)
src_image = cv2.imread(imagefilepath, cv2.IMREAD_GRAYSCALE)
resize_image = cv2.resize(src_image, (512, 512))
pd_mask_image = Vnet2d.prediction(resize_image / 255.)
new_mask_image = cv2.resize(pd_mask_image, (src_image.shape[1], src_image.shape[0]))
maskfilepath = os.path.join(test_mask_path, imagefile)
cv2.imwrite(maskfilepath, new_mask_image)
def predict_testag():
Vnet2d = AGVnet2dModule(512, 512, channels=1, costname="dice coefficient", inference=True,
model_path="log\segmeation\\agvnet2d\model\\agVnet2d.pd")
test_image_path = r"E:\MedicalData\TNSCUI2020\TNSCUI2020_test\image"
test_mask_path = r"E:\MedicalData\TNSCUI2020\TNSCUI2020_test\agvnet_mask"
allimagefiles = os.listdir(test_image_path)
for imagefile in allimagefiles:
imagefilepath = os.path.join(test_image_path, imagefile)
src_image = cv2.imread(imagefilepath, cv2.IMREAD_GRAYSCALE)
resize_image = cv2.resize(src_image, (512, 512))
pd_mask_image = Vnet2d.prediction(resize_image / 255.)
new_mask_image = cv2.resize(pd_mask_image, (src_image.shape[1], src_image.shape[0]))
maskfilepath = os.path.join(test_mask_path, imagefile)
cv2.imwrite(maskfilepath, new_mask_image)
if __name__ == "__main__":
    # Which inference routine to run; swap in predict_testag to use the
    # AGVnet2dModule model instead of Vnet2dModule.
    runner = predict_test
    runner()
    print('success')