From 62122de00ce668324488d6c3a812c41a2b09f245 Mon Sep 17 00:00:00 2001 From: "haoxiangsnr@gmail.com" Date: Tue, 23 Jan 2024 02:03:20 +0000 Subject: [PATCH] egs: add fullband gsn --- recipes/intel_ndns/cirm_gsn/dataloader.py | 106 +++++++++++++++ recipes/intel_ndns/cirm_gsn/default.toml | 86 ++++++++++++ recipes/intel_ndns/cirm_gsn/run.py | 151 ++++++++++++++++++++++ recipes/intel_ndns/cirm_gsn/trainer.py | 107 +++++++++++++++ 4 files changed, 450 insertions(+) create mode 100644 recipes/intel_ndns/cirm_gsn/dataloader.py create mode 100644 recipes/intel_ndns/cirm_gsn/default.toml create mode 100644 recipes/intel_ndns/cirm_gsn/run.py create mode 100644 recipes/intel_ndns/cirm_gsn/trainer.py diff --git a/recipes/intel_ndns/cirm_gsn/dataloader.py b/recipes/intel_ndns/cirm_gsn/dataloader.py new file mode 100644 index 0000000..4d0f342 --- /dev/null +++ b/recipes/intel_ndns/cirm_gsn/dataloader.py @@ -0,0 +1,106 @@ +import glob +import os +import re + +import numpy as np +import soundfile as sf +from torch.utils.data import Dataset + +from audiozen.acoustics.io import subsample + + +class DNSAudio(Dataset): + def __init__(self, root="./", limit=None, offset=0, sublen=6, train=True) -> None: + """Audio dataset loader for DNS. + + Args: + root: Path of the dataset location, by default './'. 
+ """ + super().__init__() + self.root = root + print(f"Loading dataset from {root}...") + self.noisy_files = glob.glob(root + "noisy/**.wav") + + if offset > 0: + self.noisy_files = self.noisy_files[offset:] + + if limit: + self.noisy_files = self.noisy_files[:limit] + + print(f"Found {len(self.noisy_files)} files.") + + self.file_id_from_name = re.compile(r"fileid_(\d+)") + self.snr_from_name = re.compile(r"snr(-?\d+)") + self.target_level_from_name = re.compile(r"tl(-?\d+)") + self.source_info_from_name = re.compile("^(.*?)_snr") + + self.train = train + self.sublen = sublen + self.length = len(self.noisy_files) + + def __len__(self) -> int: + """Length of the dataset.""" + return self.length + + def _get_filenames(self, n): + noisy_file = self.noisy_files[n % self.length] + filename = noisy_file.split(os.sep)[-1] + file_id = int(self.file_id_from_name.findall(filename)[0]) + clean_file = self.root + f"clean/clean_fileid_{file_id}.wav" + noise_file = self.root + f"noise/noise_fileid_{file_id}.wav" + snr = int(self.snr_from_name.findall(filename)[0]) + target_level = int(self.target_level_from_name.findall(filename)[0]) + source_info = self.source_info_from_name.findall(filename)[0] + metadata = { + "snr": snr, + "target_level": target_level, + "source_info": source_info, + } + return noisy_file, clean_file, noise_file, metadata + + def __getitem__(self, n): + """Gets the nth sample from the dataset. + + Args: + n: Index of the sample to be retrieved. + + Returns: + Noisy audio sample, clean audio sample, noise audio sample, sample metadata. 
+        """
+        noisy_file, clean_file, noise_file, metadata = self._get_filenames(n)
+        noisy_audio, sampling_frequency = sf.read(noisy_file)
+        clean_audio, _ = sf.read(clean_file)
+        num_samples = 30 * sampling_frequency # 30 sec data
+        train_num_samples = self.sublen * sampling_frequency
+        metadata["fs"] = sampling_frequency
+
+        if len(noisy_audio) > num_samples:
+            noisy_audio = noisy_audio[:num_samples]
+        else:
+            noisy_audio = np.concatenate([noisy_audio, np.zeros(num_samples - len(noisy_audio))])
+        if len(clean_audio) > num_samples:
+            clean_audio = clean_audio[:num_samples]
+        else:
+            clean_audio = np.concatenate([clean_audio, np.zeros(num_samples - len(clean_audio))])
+
+        noisy_audio = noisy_audio.astype(np.float32)
+        clean_audio = clean_audio.astype(np.float32)
+
+        if self.train:
+            noisy_audio, start_position = subsample(
+                noisy_audio,
+                subsample_length=train_num_samples,
+                return_start_idx=True,
+            )
+            clean_audio = subsample(
+                clean_audio,
+                subsample_length=train_num_samples,
+                start_idx=start_position,
+            )
+
+        return noisy_audio, clean_audio, noisy_file
+
+
+if __name__ == "__main__":
+    train_set = DNSAudio(root="../../data/MicrosoftDNS_4_ICASSP/training_set/")
+    validation_set = DNSAudio(root="../../data/MicrosoftDNS_4_ICASSP/validation_set/")
diff --git a/recipes/intel_ndns/cirm_gsn/default.toml b/recipes/intel_ndns/cirm_gsn/default.toml
new file mode 100644
index 0000000..7049901
--- /dev/null
+++ b/recipes/intel_ndns/cirm_gsn/default.toml
@@ -0,0 +1,86 @@
+[meta]
+save_dir = "exp"
+description = "Train a fullband speech enhancement model with a GSN sequence model"
+seed = 20220815
+
+[trainer]
+path = "trainer.Trainer"
+[trainer.args]
+debug = false
+max_steps = 0
+max_epochs = 200
+max_grad_norm = 10
+save_max_score = true
+save_ckpt_interval = 10
+max_patience = 20
+plot_norm = true
+validation_interval = 10
+max_num_checkpoints = 20
+scheduler_name = "constant_schedule_with_warmup"
+warmup_steps = 0
+warmup_ratio = 0.00
+gradient_accumulation_steps = 1
+
+[loss_function] +path = "torch.nn.MSELoss" +[loss_function.args] + +[optimizer] +path = "torch.optim.AdamW" +[optimizer.args] +lr = 1e-3 + +[model] +path = "audiozen.models.cirm_gsn.modeling_cirm_gsn.Model" +[model.args] +n_fft = 512 +hop_length = 128 +win_length = 512 +fdrc = 0.5 +input_size = 257 +hidden_size = 268 +num_layers = 4 +proj_size = 257 +output_activate_function = false +df_order = 3 +use_pre_layer_norm_fb = true +bn = true +shared_weights = true +sequence_model = "GSN" +num_spks = 1 + +[acoustics] +n_fft = 512 +hop_length = 128 +win_length = 512 +sr = 16000 + +[train_dataset] +path = "dataloader.DNSAudio" +[train_dataset.args] +root = "/datasets/datasets_fullband/training_set/" +limit = false +offset = 0 +[train_dataset.dataloader] +batch_size = 64 +num_workers = 8 +drop_last = true +pin_memory = true + +[validate_dataset] +path = "dataloader.DNSAudio" +[validate_dataset.args] +root = "/datasets/datasets_fullband/validation_set/" +train = false +[validate_dataset.dataloader] +batch_size = 16 +num_workers = 8 + +[test_dataset] +path = "dataloader.DNSAudio" +[test_dataset.args] +root = "/nfs/xhao/data/intel_ndns/test_set/" +train = false +[test_dataset.dataloader] +batch_size = 1 +num_workers = 0 diff --git a/recipes/intel_ndns/cirm_gsn/run.py b/recipes/intel_ndns/cirm_gsn/run.py new file mode 100644 index 0000000..e390dab --- /dev/null +++ b/recipes/intel_ndns/cirm_gsn/run.py @@ -0,0 +1,151 @@ +import argparse +from math import sqrt +from pathlib import Path + +import toml +from accelerate import Accelerator, DistributedDataParallelKwargs +from accelerate.utils import set_seed +from torch.utils.data import DataLoader + +from audiozen.logger import init_logging_logger +from audiozen.utils import instantiate + + +def run(config, resume): + init_logging_logger(config) + + ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=False) + accelerator = Accelerator( + 
gradient_accumulation_steps=config["trainer"]["args"]["gradient_accumulation_steps"], + kwargs_handlers=[ddp_kwargs], + ) + + set_seed(config["meta"]["seed"], device_specific=True) + + model = instantiate(config["model"]["path"], args=config["model"]["args"]) + + optimizer = instantiate( + config["optimizer"]["path"], + args={"params": model.parameters()} + | config["optimizer"]["args"] + | {"lr": config["optimizer"]["args"]["lr"] * sqrt(accelerator.num_processes)}, + ) + + loss_function = instantiate( + config["loss_function"]["path"], + args=config["loss_function"]["args"], + ) + + (model, optimizer) = accelerator.prepare(model, optimizer) + + if "train" in args.mode: + train_dataset = instantiate(config["train_dataset"]["path"], args=config["train_dataset"]["args"]) + train_dataloader = DataLoader( + dataset=train_dataset, collate_fn=None, shuffle=True, **config["train_dataset"]["dataloader"] + ) + train_dataloader = accelerator.prepare(train_dataloader) + + if "train" in args.mode or "validate" in args.mode: + if not isinstance(config["validate_dataset"], list): + config["validate_dataset"] = [config["validate_dataset"]] + + validate_dataloaders = [] + for validate_config in config["validate_dataset"]: + validate_dataset = instantiate(validate_config["path"], args=validate_config["args"]) + + validate_dataloaders.append( + accelerator.prepare( + DataLoader( + dataset=validate_dataset, + **validate_config["dataloader"], + ) + ) + ) + + if "test" in args.mode: + if not isinstance(config["test_dataset"], list): + config["test_dataset"] = [config["test_dataset"]] + + test_dataloaders = [] + for test_config in config["test_dataset"]: + test_dataset = instantiate(test_config["path"], args=test_config["args"]) + + test_dataloaders.append( + accelerator.prepare( + DataLoader( + dataset=test_dataset, + **test_config["dataloader"], + ) + ) + ) + + trainer = instantiate(config["trainer"]["path"], initialize=False)( + accelerator=accelerator, + config=config, + 
resume=resume, + model=model, + optimizer=optimizer, + loss_function=loss_function, + ) + + for flag in args.mode: + if flag == "train": + trainer.train(train_dataloader, validate_dataloaders) + elif flag == "validate": + trainer.validate(validate_dataloaders) + elif flag == "test": + trainer.test(test_dataloaders, config["meta"]["ckpt_path"]) + elif flag == "predict": + raise NotImplementedError("Predict is not implemented yet.") + elif flag == "finetune": + raise NotImplementedError("Finetune is not implemented yet.") + else: + raise ValueError(f"Unknown mode: {flag}.") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Audio-ZEN") + parser.add_argument( + "-C", + "--configuration", + required=True, + type=str, + help="Configuration (*.toml).", + ) + parser.add_argument( + "-M", + "--mode", + nargs="+", + type=str, + default=["train"], + choices=["train", "validate", "test", "predict", "finetune"], + help="Mode of the experiment.", + ) + parser.add_argument( + "-R", + "--resume", + action="store_true", + help="Resume the experiment from latest checkpoint.", + ) + parser.add_argument( + "--ckpt_path", + type=str, + default=None, + help="Checkpoint path for test. It can be 'best', 'latest', or a path to a checkpoint.", + ) + + args = parser.parse_args() + + config_path = Path(args.configuration).expanduser().absolute() + config = toml.load(config_path.as_posix()) + + config["meta"]["exp_id"] = config_path.stem + config["meta"]["config_path"] = config_path.as_posix() + + if "test" in args.mode: + if args.ckpt_path is None: + raise ValueError("checkpoint path is required for test. 
Use '--ckpt_path'.") + else: + config["meta"]["ckpt_path"] = args.ckpt_path + + run(config, args.resume) diff --git a/recipes/intel_ndns/cirm_gsn/trainer.py b/recipes/intel_ndns/cirm_gsn/trainer.py new file mode 100644 index 0000000..e62abe2 --- /dev/null +++ b/recipes/intel_ndns/cirm_gsn/trainer.py @@ -0,0 +1,107 @@ +import pandas as pd +from accelerate.logging import get_logger +from tqdm import tqdm + +from audiozen.loss import SISNRLoss, freq_MAE, mag_MAE +from audiozen.metric import DNSMOS, PESQ, SISDR, STOI +from audiozen.trainer import Trainer as BaseTrainer + + +logger = get_logger(__name__) + + +class Trainer(BaseTrainer): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.dns_mos = DNSMOS(input_sr=self.sr, device=self.accelerator.process_index) + self.stoi = STOI(sr=self.sr) + self.pesq_wb = PESQ(sr=self.sr, mode="wb") + self.pesq_nb = PESQ(sr=self.sr, mode="nb") + self.sisnr_loss = SISNRLoss(return_neg=False) + self.si_sdr = SISDR() + self.north_star_metric = "si_sdr" + + def training_step(self, batch, batch_idx): + self.optimizer.zero_grad() + + noisy_y, clean_y, _ = batch + + batch_size, *_ = noisy_y.shape + + enhanced_y, enhanced_mag, *_ = self.model(noisy_y) + + loss_freq_mae = freq_MAE(enhanced_y, clean_y) + loss_mag_mae = mag_MAE(enhanced_y, clean_y) + loss_sdr = self.sisnr_loss(enhanced_y, clean_y) + loss_sdr_norm = 0.001 * (100 - loss_sdr) + loss = loss_freq_mae + loss_mag_mae + loss_sdr_norm # + loss_g_fake + + self.accelerator.backward(loss) + self.optimizer.step() + + return { + "loss": loss, + "loss_freq_mae": loss_freq_mae, + "loss_mag_mae": loss_mag_mae, + "loss_sdr": loss_sdr, + "loss_sdr_norm": loss_sdr_norm, + } + + def validation_step(self, batch, batch_idx, dataloader_idx=0): + mix_y, ref_y, id = batch + est_y, *_ = self.model(mix_y) + + if len(id) != 1: + raise ValueError(f"Expected batch size 1 during validation, got {len(id)}") + + # calculate metrics + mix_y = mix_y.squeeze(0).detach().cpu().numpy() + 
ref_y = ref_y.squeeze(0).detach().cpu().numpy() + est_y = est_y.squeeze(0).detach().cpu().numpy() + + si_sdr = self.si_sdr(est_y, ref_y) + dns_mos = self.dns_mos(est_y) + + out = si_sdr | dns_mos + return [out] + + def validation_epoch_end(self, outputs, log_to_tensorboard=True): + score = 0.0 + + for dataloader_idx, dataloader_outputs in enumerate(outputs): + logger.info(f"Computing metrics on epoch {self.state.epochs_trained} for dataloader {dataloader_idx}...") + + loss_dict_list = [] + for step_loss_dict_list in tqdm(dataloader_outputs): + loss_dict_list.extend(step_loss_dict_list) + + df_metrics = pd.DataFrame(loss_dict_list) + + # Compute mean of all metrics + df_metrics_mean = df_metrics.mean(numeric_only=True) + df_metrics_mean_df = df_metrics_mean.to_frame().T + + time_now = self._get_time_now() + df_metrics.to_csv( + self.metrics_dir / f"dl_{dataloader_idx}_epoch_{self.state.epochs_trained}_{time_now}.csv", + index=False, + ) + df_metrics_mean_df.to_csv( + self.metrics_dir / f"dl_{dataloader_idx}_epoch_{self.state.epochs_trained}_{time_now}_mean.csv", + index=False, + ) + + logger.info(f"\n{df_metrics_mean_df.to_markdown()}") + score += df_metrics_mean[self.north_star_metric] + + if log_to_tensorboard: + for metric, value in df_metrics_mean.items(): + self.writer.add_scalar(f"metrics_{dataloader_idx}/{metric}", value, self.state.epochs_trained) + + return score + + def test_step(self, batch, batch_idx, dataloader_idx=0): + return self.validation_step(batch, batch_idx, dataloader_idx) + + def test_epoch_end(self, outputs, log_to_tensorboard=True): + return self.validation_epoch_end(outputs, log_to_tensorboard=False)