Skip to content

Commit

Permalink
new_feature: Added functionality 1. IOU Score 2.Choosing devices
Browse files Browse the repository at this point in the history
  • Loading branch information
ishan121028 committed Oct 3, 2023
1 parent 216acfc commit b26cd6c
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 8 deletions.
6 changes: 5 additions & 1 deletion Segmentation/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ def main():
parser.add_argument('--logging_directory', type=str, default='logs', help='Directory for logging')
parser.add_argument('--checkpoint_directory', type=str, default='checkpoints', help='Directory for saving checkpoints')
parser.add_argument('--classes', type=int, default='2', help='No. of classes you want to segment your model into.')
parser.add_argument('--iou', type=bool, default=False, help='Enable or disable IoU')
parser.add_argument('--device', type=str, default='cpu', help='Device to train on')
args = parser.parse_args()

# Create the logging directory
Expand Down Expand Up @@ -94,7 +96,9 @@ def main():
num_epochs=args.epochs,
learning_rate=args.learning_rate,
checkpoint_dir=args.checkpoint_directory,
logger=logging
logger=logging,
iou=args.iou,
device=args.device
)


Expand Down
43 changes: 36 additions & 7 deletions Segmentation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,15 @@ def parse_folder(dataset_path):
return False


def train_unet(model, train_data_loader, test_data_loader, num_epochs, learning_rate, checkpoint_dir, logger=None):
def train_unet(model, train_data_loader, test_data_loader, num_epochs, learning_rate, checkpoint_dir, logger=None, iou=False, device='cpu'):
if device == 'cpu':
device = torch.device('cpu')
elif device == 'cuda':
device = torch.device('cuda')
else:
print(f"{device} is not a valid device.")
return None

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

Expand All @@ -84,36 +92,49 @@ def train_unet(model, train_data_loader, test_data_loader, num_epochs, learning_
for epoch in range(num_epochs):
model.train()
train_loss = 0.0
iou_score_mean = 0.0

for inputs, targets in tqdm(train_data_loader, desc=f'Epoch {epoch + 1}/{num_epochs}', leave=False):
optimizer.zero_grad()
outputs = model(inputs)
targets = targets.squeeze(1)
outputs.to(device)
targets.to(device)
loss = criterion(outputs, targets)
loss.backward()
optimizer.step()
train_loss += loss.item()
iou_score_mean += iou_score(outputs, targets)

iou_score_mean = iou_score_mean / len(train_data_loader)
average_train_loss = train_loss / len(train_data_loader)

if logger:
logger.info(f'Epoch {epoch + 1}/{num_epochs}, Train Loss: {average_train_loss:.4f}')
if logger and iou:
logger.info(f'Epoch {epoch + 1}/{num_epochs}, Train Loss: {average_train_loss:.4f}, IOU Score: {iou_score_mean:.4f}')
else:
logger.info(f"Epoch {epoch + 1}/{num_epochs}, Train Loss: {average_train_loss:.4f}")

# Validation
model.eval()
val_loss = 0.0

iou_score_mean = 0.0
with torch.no_grad():
for inputs, targets in tqdm(test_data_loader, desc=f'Validation', leave=False):
outputs = model(inputs)
targets = targets.squeeze(1)
outputs.to(device)
targets.to(device)
loss = criterion(outputs, targets)
val_loss += loss.item()
iou_score_mean += iou_score(outputs, targets)

iou_score_mean = iou_score_mean / len(test_data_loader)
average_val_loss = val_loss / len(test_data_loader)

if logger:
logger.info(f'Epoch {epoch + 1}/{num_epochs}, Validation Loss: {average_val_loss:.4f}')
if logger and iou:
logger.info(f'Epoch {epoch + 1}/{num_epochs}, Validation Loss: {average_val_loss:.4f}. IOU Score: {iou_score_mean:.4f}')
else:
logger.info(f"Epoch {epoch + 1}/{num_epochs}, Validation Loss: {average_val_loss:.4f}")

# Save model checkpoint if validation loss improves
if average_val_loss < best_loss:
Expand All @@ -126,4 +147,12 @@ def train_unet(model, train_data_loader, test_data_loader, num_epochs, learning_
print('Finished Training')

def generate_model_summary(model, input_size):
torchsummary.summary(model, input_size=input_size)
torchsummary.summary(model, input_size=input_size)

def iou_score(output, target):
smooth = 1e-6
output = output.argmax(1)
intersection = (output & target).float().sum((1, 2))
union = (output | target).float().sum((1, 2))
iou = (intersection + smooth) / (union + smooth)
return iou.mean()

0 comments on commit b26cd6c

Please sign in to comment.