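"""Hyperparameter sweep driver for DART, built on wandb sweeps.

Builds a fixed base configuration for the chosen task, declares a grid
search space over the training hyperparameters, and launches a wandb agent
that calls ``train`` once per sampled configuration.

Example invocation (the task name is illustrative; use any task supported
by ``get_data_path``):

    python sweep.py --task_name mrpc --data_split 42
"""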
import os
from argparse import ArgumentParser

import wandb
from attrdict import AttrDict

from run import train
from src.data import get_data_path


def train_wrapper():
    # One sweep trial: the agent injects the sampled hyperparameters through
    # wandb.config; they are forwarded to train() on top of base_config.
    # Use the CLI project name rather than a hardcoded one, so the run lands
    # in the same project as the sweep created below.
    search = wandb.init(project=args.project_name, sync_tensorboard=True)
    train(base_config, **wandb.config)
    search.finish()


if __name__ == '__main__':
    parser = ArgumentParser()
    parser.add_argument('--project_name', type=str, default='DART',
                        help='wandb project name for the sweep')
    parser.add_argument('--task_name', type=str, required=True,
                        help='GLUE task to tune (e.g. mrpc, qqp)')
    parser.add_argument('--data_split', type=int, default=13,
                        choices=[13, 21, 42, 87, 100],
                        help='few-shot split id for the GLUE dataset')
    parser.add_argument('--pretrain_model', type=str, default='pretrain/albert-xxlarge-v2',
                        help='name or path of the pretrained model')
    parser.add_argument('--pet_method', type=str, default='diffpet',
                        choices=['pet', 'diffpet'],
                        help='prompt encoding method')
    parser.add_argument('--random_seed', type=int, default=3407,
                        help='random seed for training')
    parser.add_argument('--max_run', type=int, default=100,
                        help='maximum number of sweep runs to launch')
    args = parser.parse_args()

    # Build the fixed (non-swept) configuration for this task.
    task_name = args.task_name.lower()
    output_dir = os.path.join('output', task_name)
    os.makedirs(output_dir, exist_ok=True)
    train_path, dev_path, test_path = get_data_path(task_name, args.data_split)
    base_config = AttrDict({
        'task_name': task_name,
        'train_path': train_path, 'dev_path': dev_path, 'test_path': test_path,
        'output_dir': output_dir, 'log_file': f'{task_name}.log', 'pred_file': '',
        'use_gpu': True,
        'pretrain_model': args.pretrain_model, 'pet_method': args.pet_method,
        'seed': args.random_seed, 'max_seq_len': 128, 'shuffle': True,
        'eval_every_steps': 20, 'test_batch_size': 32,
        'max_train_epochs': 20, 'early_stop_steps': 5,
        # MRPC and QQP are model-selected on F1; the other tasks on accuracy.
        'save_metric': 'f1_score' if task_name in ['mrpc', 'qqp'] else 'accuracy'
    })

    # Sweep configuration: the hyperparameter search space and the metric to
    # optimize. The full grid has 3*3*2*2*3*3 = 324 combinations, so a
    # --max_run below 324 leaves part of the grid unexplored.
    sweep_config = {
        # 'program' only matters for CLI-launched agents; it is unused here
        # because the agent below runs a local function instead.
        'program': task_name,
        'method': 'grid',
        'metric': {
            'goal': 'maximize',
            'name': 'test f1_score' if task_name in ['mrpc', 'qqp'] else 'test accuracy'
        },
        'parameters': {
            'data_split': {'values': [args.data_split]},
            'warmup_ratio': {'values': [0.05]},
            'learning_rate': {'values': [1e-5, 5e-5, 1e-4]},
            'weight_decay': {'values': [0.01]},
            'adam_epsilon': {'values': [1e-8]},
            'train_batch_size': {'values': [4, 8, 16]},
            'grad_acc_steps': {'values': [1, 2]},
            'max_grad_norm': {'values': [1.0]},
            'full_vocab_loss': {'values': [True, False]},
            'mask_rate': {'values': [0.0, 0.05, 0.10]},
            'mlm_loss_weight': {'values': [0.0, 0.5, 1.0]}
        }
    }
    sweep_id = wandb.sweep(sweep_config, project=args.project_name)

    # Launch the agent: it pulls hyperparameter sets from the sweep and runs
    # train_wrapper once per set, for at most --max_run runs.
    wandb.agent(sweep_id, function=train_wrapper, count=args.max_run)

    # NOTE: this script does not save the optimal hyperparameter set itself.
    # While the sweep runs, track the results and pick the best hyperparameters
    # manually on the wandb site (in your `DART` project).
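    # For reference, a minimal sketch of fetching the best run programmatically
    # via the wandb public API once the sweep has finished. This assumes your
    # default wandb entity is configured; otherwise prefix the path with
    # '<entity>/':
    #
    #   api = wandb.Api()
    #   sweep = api.sweep(f'{args.project_name}/{sweep_id}')
    #   best = sweep.best_run()  # best run w.r.t. the sweep's metric
    #   print(best.config)       # the winning hyperparameter set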