-
Notifications
You must be signed in to change notification settings - Fork 1
/
train.py
115 lines (91 loc) · 3.58 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import torch
import torch.nn as nn
from torchvision import models
from torch.nn.functional import relu
from torch.utils.data import DataLoader
from torchvision import transforms
from torch.utils.data import Dataset
from PIL import Image
import torch.optim as optim
import os
import torch.nn.functional as F
from glob import glob
import os
import random
from sklearn.model_selection import train_test_split
from torch.utils.data import Subset
import matplotlib.pyplot as plt
from model import UNet
from customDataset import CustomDataset
def get_image_mask_pairs(dir):
    """Pair every ``*_sat.jpg`` image in *dir* with its ``*_mask.png`` mask path.

    Args:
        dir: Directory containing the satellite images and their masks.
            (The name shadows the builtin ``dir`` but is kept so keyword
            callers keep working.)

    Returns:
        A list of ``(image_path, mask_path)`` tuples, sorted by image filename.
        Mask paths are derived from the image filename; whether each mask file
        actually exists is NOT checked here.

    Raises:
        OSError: propagated from ``os.listdir`` if *dir* cannot be listed.
    """
    sat_suffix = '_sat.jpg'
    mask_suffix = '_mask.png'
    # os.listdir returns entries in arbitrary, filesystem-dependent order.
    # Sorting makes the seeded train/valid/test split downstream reproducible
    # across machines.
    sat_images = sorted(f for f in os.listdir(dir) if f.endswith(sat_suffix))
    # Swap only the trailing suffix. str.replace would also rewrite an earlier
    # occurrence of '_sat.jpg' inside the stem.
    return [
        (os.path.join(dir, f), os.path.join(dir, f[: -len(sat_suffix)] + mask_suffix))
        for f in sat_images
    ]
# Root directory holding the '*_sat.jpg' / '*_mask.png' training pairs.
DATA_DIR = './dataset/train'

all_pairs = get_image_mask_pairs(DATA_DIR)
# Fail early with an actionable message. Otherwise sklearn raises a less
# helpful ValueError on an empty sample list.
if not all_pairs:
    raise ValueError(f"No '*_sat.jpg' images found in {DATA_DIR!r}; check the dataset path.")

# Roughly 75% train. The held-out 25% is split 60/40 into validation (~15%)
# and test (~10%). The fixed seed keeps the split stable between runs.
train_pairs, temp_pairs = train_test_split(all_pairs, test_size=0.25, random_state=42)
valid_pairs, test_pairs = train_test_split(temp_pairs, test_size=0.4, random_state=42)

print(f"Total samples: {len(all_pairs)}")
print(f"Train samples: {len(train_pairs)} ({len(train_pairs)/len(all_pairs):.2%})")
print(f"Validation samples: {len(valid_pairs)} ({len(valid_pairs)/len(all_pairs):.2%})")
print(f"Test samples: {len(test_pairs)} ({len(test_pairs)/len(all_pairs):.2%})")
# Preprocessing shared by every split: resize to 256x256, then convert to a tensor.
# NOTE(review): if CustomDataset applies this same transform to the masks,
# Resize's default interpolation may blur mask edges into non-binary values.
# Confirm against customDataset.py.
transform = transforms.Compose(
    [
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
    ]
)

# One dataset per split, all sharing the same preprocessing.
train_dataset, valid_dataset, test_dataset = (
    CustomDataset(split_pairs, transform=transform)
    for split_pairs in (train_pairs, valid_pairs, test_pairs)
)

# Only the training loader shuffles; evaluation order stays deterministic.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
valid_loader, test_loader = (
    DataLoader(split_dataset, batch_size=32, shuffle=False)
    for split_dataset in (valid_dataset, test_dataset)
)
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)

# Single-output-channel UNet, trained with a logits-based binary loss.
model = UNet(n_class=1)
model = model.to(device)
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

num_epochs = 30
# Per-epoch mean losses, appended by the training loop and plotted at the end.
train_losses, valid_losses = [], []
def _train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over *loader*; return the per-sample mean loss."""
    model.train()
    total = 0.0
    for images, masks in loader:
        images, masks = images.to(device), masks.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()
        # Assuming criterion returns a batch mean (the BCEWithLogitsLoss default),
        # re-weighting by batch size keeps a short final batch from skewing the average.
        total += loss.item() * images.size(0)
    return total / len(loader.dataset)


def _evaluate(model, loader, criterion, device):
    """Return the per-sample mean loss of *model* on *loader* without updating weights."""
    model.eval()
    total = 0.0
    with torch.no_grad():
        for images, masks in loader:
            images, masks = images.to(device), masks.to(device)
            outputs = model(images)
            total += criterion(outputs, masks).item() * images.size(0)
    return total / len(loader.dataset)


for epoch in range(num_epochs):
    train_loss = _train_one_epoch(model, train_loader, criterion, optimizer, device)
    train_losses.append(train_loss)

    valid_loss = _evaluate(model, valid_loader, criterion, device)
    valid_losses.append(valid_loss)

    print(f'Epoch {epoch+1}/{num_epochs} | Train_loss: {train_loss:.4f} | Validation loss: {valid_loss:.4f}')

    # Checkpoint after every epoch so an interrupted run keeps its latest
    # weights. The file is overwritten each time, so it ends up holding the
    # final epoch's state.
    optimizer_state = optimizer.state_dict()
    torch.save({
        'epoch': epoch + 1,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer_state,
        # Legacy misspelled key, kept so existing loaders that read it keep
        # working. It refers to the same object, so the state is serialized once.
        'optimizer_state_dic': optimizer_state,
        'train_losses': train_losses,
        'valid_losses': valid_losses,
    }, "trained_model.pth")
# Plot and save the loss curves.
# The x-axis starts at 1 to match the "Epoch N/M" numbering printed during training.
fig = plt.figure(figsize=(10, 6))
plt.plot(range(1, len(train_losses) + 1), train_losses, label='Train Loss')
plt.plot(range(1, len(valid_losses) + 1), valid_losses, label='Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Train and Validation Loss')
plt.legend()
plt.savefig('loss_plot.jpg')
# Release the figure explicitly; pyplot otherwise keeps it alive until exit.
plt.close(fig)