Skip to content

Commit

Permalink
Added single experiment and other scripts; mostly debugging large mar…
Browse files Browse the repository at this point in the history
…gin loss + logit suppression outlier exposure
  • Loading branch information
tnoe1 committed Jul 28, 2022
1 parent 4f3953e commit baf7c1e
Show file tree
Hide file tree
Showing 12 changed files with 5,729 additions and 36 deletions.
60 changes: 30 additions & 30 deletions experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,9 +194,9 @@ def main():
# Setup Weights and Biases and specify hyperparameters
test_method = "oe_test" if args.oe_test else "new_class_test"
if args.baseline:
project_name = "TN_Masters_Proj_Baseline_{}".format(test_method)
project_name = "TN_Masters_Proj_Baseline_Val_{}".format(test_method)
else:
project_name = "TN_Masters_Proj_{}_{}_{}".format(args.detection_type, args.loss, test_method)
project_name = "TN_Masters_Proj_Val_{}_{}_{}".format(args.detection_type, args.loss, test_method)

wandb.init(project=project_name)

Expand Down Expand Up @@ -323,34 +323,34 @@ def main():
wandb.log({"loss": loss, "ID_Accuracy": acc, "AUROC": auc, "metric_combined": metric_combined, "epoch": i})

# Save Model
if i % 5 == 0 or metric_combined > metric_combined_running_max:
if metric_combined > metric_combined_running_max:
metric_combined_running_max = metric_combined

if args.baseline:
directory = "baseline_{}".format(test_method)
torch.save({
'epoch': i,
'model_state_dict': net.state_dict(),
'optimizer_state_dict': optim.state_dict(),
'auc': auc,
'id_accuracy': acc,
'split': args.split,
'test_method': test_method,
}, 'models/{}/day_{}_{}_time_{}_{}_split_{}_epoch_{}.pth'.format(directory, now.month, now.day, now.hour, now.minute, args.split, i))
else:
directory = "{}_{}_{}".format(args.loss, args.detection_type, test_method)
torch.save({
'epoch': i,
'model_state_dict': net.state_dict(),
'optimizer_state_dict': optim.state_dict(),
'auc': auc,
'id_accuracy': acc,
'split': args.split,
'test_method': test_method,
'loss': args.loss,
'detection_type': args.detection_type,
}, 'models/{}/day_{}_{}_time_{}_{}_split_{}_epoch_{}.pth'.format(directory, now.month, now.day, now.hour, now.minute, args.split, i))
# if i % 5 == 0 or metric_combined > metric_combined_running_max:
# if metric_combined > metric_combined_running_max:
# metric_combined_running_max = metric_combined

if args.baseline:
directory = "val_baseline_{}".format(test_method)
torch.save({
'epoch': i,
'model_state_dict': net.state_dict(),
'optimizer_state_dict': optim.state_dict(),
'auc': auc,
'id_accuracy': acc,
'split': args.split,
'test_method': test_method,
}, 'models/{}/day_{}_{}_time_{}_{}_split_{}_epoch_{}.pth'.format(directory, now.month, now.day, now.hour, now.minute, args.split, i))
else:
directory = "val_{}_{}_{}".format(args.loss, args.detection_type, test_method)
torch.save({
'epoch': i,
'model_state_dict': net.state_dict(),
'optimizer_state_dict': optim.state_dict(),
'auc': auc,
'id_accuracy': acc,
'split': args.split,
'test_method': test_method,
'loss': args.loss,
'detection_type': args.detection_type,
}, 'models/{}/day_{}_{}_time_{}_{}_split_{}_epoch_{}.pth'.format(directory, now.month, now.day, now.hour, now.minute, args.split, i))

if __name__ == '__main__':
main()
Loading

0 comments on commit baf7c1e

Please sign in to comment.