-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdefaults.yaml
104 lines (92 loc) · 2.57 KB
/
defaults.yaml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
---
# Data-loading steps. Each list item selects an object by name plus the keys
# it reads/writes (`io`); how names are resolved is up to the consuming code.
# NOTE(review): nesting reconstructed from a whitespace-stripped copy —
# confirm against the original file.
data:
  - fuse_object: ""
    # io:
    #   input_key: null
    #   output_key: "dataset_pipeline"
  - object: "get_splits_str_ids"
    io:
      input_key: null
      output_key: "data_splits"
# Cache settings.
# NOTE(review): placed at top level; the whitespace-stripped copy cannot rule
# out that it belonged to the preceding `data` item — confirm.
cache:
  num_workers: 1
  restart_cache: true
  # Placeholder path — replace with a real cache directory.
  root_dir: "path/to/cache"
# MLflow settings; both are left null (unset) in this defaults file.
mlflow:
  MLFLOW_TRACKING_URI: null
  MLFLOW_EXPERIMENT_NAME: null
# Modality-encoding strategy: a single step selected by object name.
modality_encoding_strategy:
  - object: "ModalityEncoding"
# Fusion steps. Each item selects an object by name, with its constructor
# `args` and the data keys it reads/writes (`io`).
fusion_strategy:
  - object: "EncodedUnimodalToConcept"  # early or late
    args:
      use_autoencoders: true
      add_feature_names: false
      encoding_layers:
        - 32
        # Anchor aliased later in this file as the graph model's `n_layers`.
        # NOTE(review): the anchor is on a layer width (16), not a layer
        # count — confirm this coupling is intended.
        - &n_layers 16
      use_pretrained: true
      batch_size: 3
      training:
        model_dir: "model_concept"
        pl_trainer_num_epochs: 1
        pl_trainer_accelerator: "cpu"
    io:
      concept_encoder_model_key: "concept_encoder_model"
      output_key: "data.input.concatenated"
  - object: "ConceptToGraph"
    args:
      # Anchor aliased later in this file as the graph model's
      # `module_identifier`.
      module_identifier: &graph_module "mplex"
      thresh_q: 0.95
    io:
      concept_encoder_model_key: "concept_encoder_model"
      fused_dataset_key: "fused_dataset"
      # Same key the EncodedUnimodalToConcept step writes as `output_key`.
      input_key: "data.input.concatenated"
      output_key: "data.derived_graph"
# Task steps: the graph model, then an Eval step over its test results.
task_strategy:
  - object: "MultimodalGraphModel"
    # NOTE(review): `io` is nested under `args` because it directly follows
    # `args:` in the whitespace-stripped copy; in ConceptToGraph `io` is a
    # sibling of `args` — confirm which nesting the consumer expects.
    args:
      io:
        fused_dataset_key: "fused_dataset"
        # Same key ConceptToGraph writes as its `output_key`.
        input_key: "data.derived_graph"
        target_key: &target "data.gt.gt_global.task_1_label"
        prediction_key: &prediction "model.out"
      model_config:
        graph_model:
          # Both aliases are anchored in `fusion_strategy`.
          module_identifier: *graph_module
          n_layers: *n_layers
          node_emb_dim: 1  # really needed?
        head_model:
          head_hidden_size:
            - 100
            - 20
          dropout: 0.5
          add_softmax: true
          num_classes: 2
      training:
        model_dir: "model_mplex"
        batch_size: 3
        # Presumably picks the epoch maximizing validation AUC — confirm.
        best_epoch_source:
          mode: "max"
          monitor: "validation.metrics.auc"
        train_metrics:
          key: "auc"
          object: "MetricAUCROC"
          args:
            pred: *prediction
            target: *target
        validation_metrics:
          key: "auc"
          object: "MetricAUCROC"
          args:
            pred: *prediction
            target: *target
        pl_trainer_num_epochs: 1
        pl_trainer_accelerator: "cpu"
        pl_trainer_devices: 1
      testing:
        # Anchored so the Eval step below uses the same filename/directory.
        test_results_filename: &test_results_filename "test_results.pickle"
        evaluation_directory: &evaluation_directory "eval"
  - object: "Eval"
    args:
      test_results_filename: *test_results_filename
      evaluation_directory: *evaluation_directory