Skip to content

Commit

Permalink
Updating validation functions
Browse files Browse the repository at this point in the history
  • Loading branch information
tnoe1 committed Jul 29, 2022
1 parent ed50ea8 commit 667d152
Show file tree
Hide file tree
Showing 7 changed files with 861 additions and 274 deletions.
4 changes: 2 additions & 2 deletions experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ def main():
# Setup Weights and Biases and specify hyperparameters
test_method = "oe_test" if args.oe_test else "new_class_test"
if args.baseline:
project_name = "TN_Masters_Proj_Baseline_Val_{}".format(test_method)
project_name = "TN_Masters_Proj_Fixed_Baseline_Val_{}".format(test_method)
else:
project_name = "TN_Masters_Proj_Val_{}_{}_{}".format(args.detection_type, args.loss, test_method)

Expand Down Expand Up @@ -260,7 +260,7 @@ def main():
dist_norm=dist_norm
)

if args.detection_type == "LS":
if args.detection_type == "LS" or args.baseline:
net = FeatureExtractor(net_type, num_training_classes)
elif args.detection_type == "KS":
# Include a kitchen sink class
Expand Down
5 changes: 4 additions & 1 deletion experiment_single.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,7 +431,10 @@ def main():
# Setup Weights and Biases and specify hyperparameters
test_method = "oe_test" if args.oe_test else "new_class_test"
if args.baseline:
project_name = "Single_TN_Masters_Proj_Baseline_Val_{}".format(test_method)
if args.margin_baseline:
project_name = "Single_TN_Masters_Proj_margin_Baseline_Val_{}".format(test_method)
else:
project_name = "Single_TN_Masters_Proj_Baseline_Val_{}".format(test_method)
else:
project_name = "Single_TN_Masters_Proj_Val_{}_{}_{}".format(args.detection_type, args.loss, test_method)

Expand Down
53 changes: 9 additions & 44 deletions margin_check.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2,51 +2,9 @@
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Files already downloaded and verified\n",
"Files already downloaded and verified\n",
"id batch 0 starting\n",
"id batch 1 starting\n",
"id batch 2 starting\n",
"id batch 3 starting\n",
"ood batch 0 starting\n",
"ood batch 1 starting\n",
"ood batch 2 starting\n",
"ood batch 3 starting\n",
"ood batch 4 starting\n",
"ood batch 5 starting\n",
"ood batch 6 starting\n",
"ood batch 7 starting\n",
"ood batch 8 starting\n",
"ood batch 9 starting\n",
"ood batch 10 starting\n",
"ood batch 11 starting\n",
"ood batch 12 starting\n",
"ood batch 13 starting\n",
"ood batch 14 starting\n",
"ood batch 15 starting\n",
"ood batch 16 starting\n",
"ood batch 17 starting\n",
"ood batch 18 starting\n",
"ood batch 19 starting\n",
"ood batch 20 starting\n",
"ood batch 21 starting\n",
"ood batch 22 starting\n",
"ood batch 23 starting\n",
"Test set: Accuracy: 649/1000 (65%)\n",
"Test Set: AUROC: 0.5688960000000001\n",
"\n",
"Accuracy: 64.9\n",
"AUROC: 0.5688960000000001\n"
]
}
],
"outputs": [],
"source": [
"import torch\n",
"device = torch.device(\"cuda:0\")\n",
Expand Down Expand Up @@ -308,6 +266,10 @@
" pred_sequence.append(id_pred)\n",
" target_sequence.append(id_target)\n",
"\n",
" # FOR DEBUG MAX LOGIT SCORE\n",
" # id_scores, _ = torch.max(id_output, dim=1)\n",
" # id_scores = -1 * id_scores\n",
"\n",
" # Compute anomaly scores\n",
" # Use discriminant function to compute id_scores\n",
" # id_distance = torch.abs(id_distance)\n",
Expand Down Expand Up @@ -355,6 +317,9 @@
" pred_sequence.append(ood_pred)\n",
" target_sequence.append(ood_target)\n",
"\n",
" # FOR DEBUG MAX LOGIT SCORE\n",
" # ood_scores, _ = torch.max(ood_output, dim=1)\n",
" # ood_scores = -1 * ood_scores\n",
" ood_scores, _ = torch.max(-1 * ood_distance, dim=1)\n",
"\n",
" # Detaching is important here because it removes these scores from computational graph\n",
Expand Down
Loading

0 comments on commit 667d152

Please sign in to comment.