# config.yaml
# Settings shared by every iteration of the pipeline.
common:
  # Number of iterations to run pipeline.
  num_iters: 4
  # Final stage to execute in final iteration. Choose from 'classifier',
  # 'generate_features', 'anomaly_detection', 'attribution', 'clustering',
  # 'merge', 'refine'.
  # NOTE(review): the step sections of this file define an 'ood_detection'
  # stage but no 'anomaly_detection' or 'attribution' stage -- confirm which
  # stage names the pipeline actually accepts.
  final_stage: 'refine'
  # Root path for saving results from each stage and iteration.
  save_path: 'results/run_default'
  # Resume from previous completed stage.
  resume: true
  # Root path containing training images.
  train_data_path: 'data/train/'
  # Root path containing evaluation images.
  eval_data_path: 'data/eval/'
  # Number of workers for dataloader.
  num_workers: 16
  # Batch size for feature generation eval.
  eval_batch_size: 1024

# Iteration 1 hyperparams.
step1:
  # Classifier training stage based on seen classes and discovered clusters.
  classifier:
    # Refer to networks.py for types of networks.
    network: 'multiclassnet'
    # Batch size for training.
    batch_size: 256
    # Number of epochs for training.
    num_epochs: 50
    # Learning rate for optimizer. Keep the decimal point and the signed
    # exponent: some YAML 1.1 loaders (e.g. PyYAML) read a bare 1e-4 as a
    # string rather than a float.
    learning_rate: 1.0e-4
    # Set of classes to be used with their labels for training. Refer to
    # class_list.yaml for set names.
    train_classes: 'Set1'
    # Type of transform. Choose from 'default' for default transform or
    # 'jpeg_blur' for augmented transformed images.
    transform: 'default'
    # Size of input image for network classification.
    resize: 256

  # Generating features for all images.
  generate_features:
    # Feature extraction point from network. Choose from 'backbone' for
    # output of backbone or 'fc' for output of fc layers.
    extract_level: 'fc'
    # Set of classes whose images are to be discovered. Refer to
    # class_list.yaml for set names. Labels are used for evaluation of
    # cluster metrics.
    eval_classes: 'Set1'

  # Out of distribution (OOD) detection of unlabeled images.
  ood_detection:
    # Number of hashes to be used for Winner Take All (WTA) hash.
    num_hashes: 2048
    # Window size for the hash.
    window_size: 2
    # Normalization type for features. Choose from 'mean_std' for
    # mean-standard deviation, 'min_max' for minimum-maximum, 'none' for no
    # normalization.
    normalize_type: 'mean_std'
    # Type of thresholding for OOD based on array of scores: 'quantiles' for
    # choosing the {thresh_value}th quantile or 'mean_std' for choosing
    # mean + {thresh_value} * std.
    thresh_type: 'quantiles'
    # [0-1] for quantile type thresholding.
    thresh_value: 0.9

  # K-means clustering.
  clustering:
    # Number of clusters for K-means.
    num_clusters: 500
    # Number of iterations for clustering.
    max_iter: 300
    # Normalization type for features. Typically same as the type used for
    # OOD detection.
    # NOTE(review): the `_copy` directive is embedded inside a quoted string
    # here, unlike the mapping form (`_copy: '/step1'`) used by later steps.
    # Confirm the config loader resolves this form; otherwise the literal
    # string would be used as the normalization type.
    normalize_type: '_copy: ../../ood_detection/normalize_type'
    # Minimum size of clusters (number of images) to be kept. Rest are
    # discarded to undiscovered set.
    size_threshold: 0

  # Merge clusters based on a directed connected graph.
  merge:
    # Whether to use WTA hash for forming nearest neighbour graph.
    use_wta: false
    # Type of connected graph used for forming connected components. Choose
    # from 'strong' for strongly connected graph or 'weak' for weakly
    # connected graph.
    connection_type: 'strong'

  refine:
    # Minimum size of clusters (number of images) to be kept. Rest are
    # discarded to undiscovered set.
    size_threshold: 100
    # Clusters below this purity threshold (based on SVM predictions) are
    # discarded entirely.
    purity_threshold: 0.5
    # Type of SVM kernel. Choose from 'rbf', 'linear'.
    kernel_type: 'rbf'
    # Disabled option: number of additional rounds of merging and refining
    # the discovery set.
    # NOTE(review): the intended nesting level of this disabled key is not
    # recoverable from the file -- confirm before re-enabling.
    # addn_merge_refine_rounds: 0

# Iteration 2 hyperparams.
# NOTE(review): `_copy` is not YAML syntax; it is presumably resolved by the
# project's config loader -- confirm how copied and local keys are merged.
step2:
  # Copy hyperparams from step 1.
  _copy: '/step1'
  classifier:
    # Start from the step 1 classifier settings, then apply the keys below.
    _copy: '/step1/classifier'
    # Whether to start from previous iteration weights.
    pretrained: true
    # Whether to use weighted cross entropy loss due to imbalanced number of
    # images in discovered clusters.
    weighted: true
    # Whether to include discovered images belonging to seen classes for
    # training.
    add_seen: false
    # Number of epochs for training (step 1 uses 50).
    num_epochs: 100

# Iteration 3 hyperparams.
step3:
  # Copy hyperparams from step 2.
  _copy: '/step2'

# Iteration 4 hyperparams.
step4:
  # Copy hyperparams from step 3.
  _copy: '/step3'