-
Notifications
You must be signed in to change notification settings - Fork 6
/
e2e.yaml
74 lines (67 loc) · 1.27 KB
/
e2e.yaml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
# Run identifier; referenced elsewhere in this file through ${name}.
name: text-sed-default
# Output directory, built from the run name via interpolation.
output_dir: checkpoints/${name}
# Seed value — presumably for RNG reproducibility; confirm in the consumer.
seed: 2617
# Settings grouped under `model`, including the diffusion options.
# NOTE(review): nesting was reconstructed. With the indentation stripped,
# `model:` parsed as an empty (null) key and every setting below landed at the
# document root. Confirm the grouping against the code that reads this file.
model:
  embed_model_name: bert-large-cased
  num_layers: 12
  bottleneck_dim: 16
  model_dim: 1536
  head_dim: null
  num_heads: 16
  seq_len: 32
  use_abs_pos: false
  use_rotary: true
  mask_type: span
  max_num_spans: 1
  ema_decay: 0.999
  ema_every: 1
  # Diffusion configs
  noise_schedule: cosine
  sampler: ddim
  num_steps: 1000
  # NOTE(review): these two reference sibling keys by bare name. If the loader
  # resolves interpolations from the document root (as OmegaConf does), they
  # will not resolve from inside `model` and would need `${model.seq_len}` /
  # `${model.num_steps}`. Left unchanged here — confirm with the consumer.
  max_gen_len: ${seq_len}
  num_gen_steps: ${num_steps}
  guide_scale: null
  use_self_cond: true
  time_delta: 0.0
# Settings grouped under `optimizer`.
# NOTE(review): nesting was reconstructed; with the indentation stripped,
# `optimizer:` parsed as an empty (null) key. Confirm against the consumer.
optimizer:
  type: adamw
  # Exponent-form floats are written with an explicit mantissa dot: YAML 1.1
  # loaders (e.g. PyYAML) read a bare `1e-4` as the string "1e-4", while
  # `1.0e-4` is a float under both YAML 1.1 and 1.2.
  lr: 1.0e-4
  lr_scheduler: cosine
  warmup_steps: 5000
  weight_decay: 1.0e-4
  eps: 1.0e-6
  betas: [0.9, 0.99]
  max_grad_norm: 1.0
# Settings grouped under `train`.
# NOTE(review): nesting was reconstructed. Left flat, `batch_size` and
# `total_steps` here share a mapping with the same keys from `valid`, and most
# parsers silently keep only the last occurrence. Confirm against the consumer.
train:
  batch_size: 128
  total_steps: 500000
  eval_every: 5000
  log_every: 100
  log_stats: false
  save_every: 10000
  sample_every: 1000
  num_samples: 8
  checkpoint_path: null
  use_amp: true
  dtype: bfloat16
# Settings grouped under `valid`.
# NOTE(review): nesting was reconstructed. Left flat, these two keys duplicate
# `batch_size` / `total_steps` from `train` in the same mapping. Confirm
# against the consumer.
valid:
  batch_size: 64
  total_steps: 1000
# Settings grouped under `data`.
# NOTE(review): nesting was reconstructed from key semantics. Left flat,
# `use_auth_token` appears twice in one mapping and `name: null` overrides the
# run name defined at the document root. Confirm the grouping — in particular
# which keys belong to `train_kwargs` — against the consumer.
data:
  # Presumably forwarded as keyword arguments when loading the training
  # split — TODO confirm.
  train_kwargs:
    path: "e2e_nlg"
    name: null
    use_auth_token: false
    split: "train"
  # Explicit empty mapping (not null).
  valid_kwargs: {}
  # Booleans use lowercase true/false, matching the rest of this file.
  use_fast_tokenizer: true
  use_auth_token: false
  num_preprocess_workers: 1
  text_attr: "human_reference"
# Settings grouped under `logging`.
# NOTE(review): nesting was reconstructed; with the indentation stripped,
# `logging:` parsed as an empty (null) key. Confirm against the consumer.
logging:
  wandb_id: null
  wandb_entity: null
  # Reuses the root-level run name through interpolation.
  wandb_project: ${name}
  wandb_group: "Processes"
  wandb_mode: null  # "online"