Skip to content

Commit

Permalink
Update model configuration and dataset parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
haoxiangsnr committed Mar 11, 2024
1 parent 06581f3 commit 4eee7d1
Show file tree
Hide file tree
Showing 3 changed files with 195 additions and 8 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
[meta]
save_dir = "exp"
description = "Train a model using Generative Adversarial Networks (GANs)"
seed = 20220815

[trainer]
path = "trainer.Trainer"
[trainer.args]
debug = false
max_steps = 0
max_epochs = 200
max_grad_norm = 10
save_max_score = true
save_ckpt_interval = 1
max_patience = 20
plot_norm = true
validation_interval = 1
max_num_checkpoints = 20
scheduler_name = "constant_schedule_with_warmup"
warmup_steps = 0
warmup_ratio = 0.00
gradient_accumulation_steps = 1

[loss_function]
path = "torch.nn.MSELoss"
[loss_function.args]

[optimizer]
path = "torch.optim.AdamW"
[optimizer.args]
lr = 1e-3

[model]
path = "audiozen.models.spiking_fullsubnet.modeling_spiking_fullsubnet.SpikingFullSubNet"
[model.args]
n_fft = 512
hop_length = 128
win_length = 512
fdrc = 0.5
fb_input_size = 64
fb_hidden_size = 320
fb_num_layers = 2
fb_proj_size = 64
fb_output_activate_function = false
sb_hidden_size = 224
sb_num_layers = 2
freq_cutoffs = [0, 32, 128, 256]
df_orders = [5, 3, 1]
center_freq_sizes = [1, 1, 1]
neighbor_freq_sizes = [15, 15, 15]
use_pre_layer_norm_fb = true
use_pre_layer_norm_sb = true
bn = true
shared_weights = true
sequence_model = "GSN"
num_spks = 1

[acoustics]
n_fft = 512
hop_length = 128
win_length = 512
sr = 16000

[train_dataset]
path = "dataloader.DNSAudio"
[train_dataset.args]
root = "/datasets/datasets_fullband/training_set/"
limit = false
offset = 0

[train_dataset.dataloader]
batch_size = 6
num_workers = 6
drop_last = true
pin_memory = true

[validate_dataset]
path = "dataloader.DNSAudio"
[validate_dataset.args]
root = "/datasets/datasets_fullband/validation_set/"
train = false
[validate_dataset.dataloader]
batch_size = 16
num_workers = 8

[test_dataset]
path = "dataloader.DNSAudio"
[test_dataset.args]
root = "/nfs/xhao/data/intel_ndns/test_set/"
train = false
[test_dataset.dataloader]
batch_size = 1
num_workers = 0
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
[meta]
save_dir = "exp"
description = "Train a model using Generative Adversarial Networks (GANs)"
seed = 20220815

[trainer]
path = "trainer.Trainer"
[trainer.args]
debug = false
max_steps = 0
max_epochs = 200
max_grad_norm = 10
save_max_score = true
save_ckpt_interval = 5
max_patience = 20
plot_norm = true
validation_interval = 5
max_num_checkpoints = 20
scheduler_name = "constant_schedule_with_warmup"
warmup_steps = 0
warmup_ratio = 0.00
gradient_accumulation_steps = 1

[loss_function]
path = "torch.nn.MSELoss"
[loss_function.args]

[optimizer]
path = "torch.optim.AdamW"
[optimizer.args]
lr = 1e-3

[model]
path = "audiozen.models.spiking_fullsubnet.modeling_spiking_fullsubnet.SpikingFullSubNet"
[model.args]
n_fft = 512
hop_length = 128
win_length = 512
fdrc = 0.5
fb_input_size = 64
fb_hidden_size = 320
fb_num_layers = 2
fb_proj_size = 64
fb_output_activate_function = false
sb_hidden_size = 224
sb_num_layers = 2
freq_cutoffs = [0, 32, 128, 256]
df_orders = [5, 3, 1]
center_freq_sizes = [2, 32, 64]
neighbor_freq_sizes = [15, 15, 15]
use_pre_layer_norm_fb = true
use_pre_layer_norm_sb = true
bn = true
shared_weights = true
sequence_model = "GSN"
num_spks = 1

[acoustics]
n_fft = 512
hop_length = 128
win_length = 512
sr = 16000

[train_dataset]
path = "dataloader.DNSAudio"
[train_dataset.args]
root = "/datasets/datasets_fullband/training_set/"
limit = false
offset = 0

[train_dataset.dataloader]
batch_size = 32
num_workers = 8
drop_last = true
pin_memory = true

[validate_dataset]
path = "dataloader.DNSAudio"
[validate_dataset.args]
root = "/datasets/datasets_fullband/validation_set/"
train = false
[validate_dataset.dataloader]
batch_size = 1
num_workers = 0

[test_dataset]
path = "dataloader.DNSAudio"
[test_dataset.args]
root = "/nfs/xhao/data/intel_ndns/test_set/"
train = false
[test_dataset.dataloader]
batch_size = 1
num_workers = 0
17 changes: 9 additions & 8 deletions recipes/intel_ndns/spiking_fullsubnet/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,21 +431,22 @@ def forward(self, input):
win_length=512,
fdrc=0.5,
fb_input_size=64,
fb_hidden_size=256,
fb_hidden_size=320,
fb_num_layers=2,
fb_proj_size=64,
fb_output_activate_function=None,
sb_hidden_size=128,
sb_hidden_size=224,
sb_num_layers=2,
freq_cutoffs=[0, 20, 80, 256],
df_orders=[2, 2, 2],
center_freq_sizes=[2, 10, 22],
neighbor_freq_sizes=[8, 16, 32],
freq_cutoffs=[0, 32, 128, 256],
df_orders=[5, 3, 1],
center_freq_sizes=[2, 32, 64],
neighbor_freq_sizes=[15, 15, 15],
use_pre_layer_norm_fb=True,
use_pre_layer_norm_sb=True,
bn=False,
shared_weights=False,
bn=True,
shared_weights=True,
sequence_model="GSN",
num_spks=1,
)

input = torch.rand(2, 16000)
Expand Down

0 comments on commit 4eee7d1

Please sign in to comment.