-
Notifications
You must be signed in to change notification settings - Fork 1
/
test.py
57 lines (44 loc) · 1.46 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
from model import UNet
from customDataset import CustomDataset
import numpy as np
# ---------------------------------------------------------------------------
# Inference setup: rebuild the UNet, restore its weights from the checkpoint
# file, and define the preprocessing applied to every input image.
# ---------------------------------------------------------------------------
model_path = 'best_model.pth'

# Prefer the GPU when one is visible to torch; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

model = UNet(n_class=1)
# NOTE(review): depending on the installed torch version's `weights_only`
# default, torch.load may unpickle arbitrary objects -- only load checkpoints
# from a trusted source. map_location remaps stored tensors onto `device`.
checkpoint = torch.load(model_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
# Module.to() and Module.eval() both return the module itself.
model = model.to(device).eval()

# Resize every input to a fixed 256x256 and convert it to a float tensor.
# NOTE(review): presumably this must match the training-time preprocessing
# (no Normalize step here) -- confirm against the training script.
_preprocess_steps = [
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
]
transform = transforms.Compose(_preprocess_steps)
def predict_single_image(img_path, model, transform, device, threshold=0.1, *,
                         show_probability=True):
    """Segment one image with *model* and return a thresholded binary mask.

    The image is loaded, converted to RGB, passed through *transform*, given
    a leading batch dimension of 1, and run through *model* with gradient
    tracking disabled. A sigmoid is applied to the raw output, singleton
    dimensions are squeezed away, and every pixel whose probability is
    strictly greater than *threshold* becomes 255; all others become 0.

    Args:
        img_path: Path to the input image file.
        model: Segmentation network. Assumed to already be on *device* and in
            eval mode (the script that calls this does both before use).
        transform: Callable applied to the RGB PIL image; presumably returns
            a (C, H, W) tensor, since a batch dim is prepended to its output.
        device: Device the input tensor is moved to before inference.
        threshold: Probability cut-off for the foreground class.
        show_probability: If True (the default, preserving the original
            behaviour), draw the sigmoid probability map on a new matplotlib
            figure. The figure is not shown here; it appears at the caller's
            next ``plt.show()``. Pass False for a side-effect-free call.

    Returns:
        ``np.uint8`` array with values in {0, 255}, shaped like the squeezed
        model output -- (H, W) for a single-channel model and a batch of one
        (TODO confirm for multi-class models, where a channel axis remains).
    """
    # Context manager releases the file handle promptly. convert() returns a
    # fully loaded copy, so `rgb` stays valid after the file is closed.
    with Image.open(img_path) as img:
        rgb = img.convert('RGB')
    batch = transform(rgb).unsqueeze(0).to(device)  # prepend batch dim of 1

    with torch.no_grad():
        logits = model(batch)
    # The raw output is treated as logits; sigmoid maps it to probabilities.
    probs = torch.sigmoid(logits).squeeze().cpu().numpy()

    if show_probability:
        # Debug view of the continuous probability map, before thresholding.
        plt.figure(figsize=(8, 8))
        plt.imshow(probs, cmap='gray')

    return np.where(probs > threshold, 255, 0).astype(np.uint8)
# ---------------------------------------------------------------------------
# Run the model on one image and show the input next to its predicted mask.
# ---------------------------------------------------------------------------
test_image_path = 'random.jpg'
predicted_mask = predict_single_image(test_image_path, model, transform, device)

# Open the file in a context manager so its handle is released, and convert
# to RGB -- the same conversion applied to the model's input -- so that
# single-channel images are not rendered through matplotlib's default
# colormap. convert() returns a loaded copy that outlives the closed file.
with Image.open(test_image_path) as _raw_image:
    original_image = _raw_image.convert('RGB')

# NOTE: the mask is at the transform's output resolution (256x256 per the
# Resize step), not the original image's size, so the panels may differ in
# pixel dimensions.
plt.figure(figsize=(10, 5))

plt.subplot(1, 2, 1)
plt.imshow(original_image)
plt.title('Original Image')
plt.axis('off')

plt.subplot(1, 2, 2)
plt.imshow(predicted_mask, cmap='gray')
plt.title('Predicted Mask')
plt.axis('off')

plt.show()