diff --git a/examples/test.py b/examples/test.py
index 84e7fd3..ed0cc75 100644
--- a/examples/test.py
+++ b/examples/test.py
@@ -1,3 +1,10 @@
+import sys
+import os
+
+sys.path.append(
+    os.path.abspath(os.path.join(os.path.dirname(__file__), '../'))
+)
+
 import torch
 import os
 import glob
@@ -62,7 +69,7 @@ def load_remote():
 hifigan = load_local().to(device)
 
 # Load audio
-wav, sr = torchaudio.load("zszy_48k.wav")
+wav, sr = torchaudio.load("7200000318_0_generated.wav")
 assert sr == 48000
 
 # mel = mel_spectrogram_torch(wav, 2048, 128, 48000, 512, 2048, 0, None, False)
diff --git a/train.py b/train.py
index 019be5d..f73fa65 100644
--- a/train.py
+++ b/train.py
@@ -46,6 +46,7 @@ def main():
     parser.add_argument('-c', '--config', type=str, default="./configs/48k.json", help='JSON file for configuration')
     parser.add_argument('-a', '--accelerator', type=str, default="gpu", help='training device')
     parser.add_argument('-d', '--device', type=str, default="0", help='training device ids')
+    parser.add_argument('-n', '--num-nodes', type=int, default=1, help='training node number')
     args = parser.parse_args()
 
     hparams = get_hparams(args.config)
@@ -75,6 +76,8 @@ def main():
     if hparams.train.fp16_run:
         trainer_params["amp_backend"] = "native"
         trainer_params["precision"] = 16
+
+    trainer_params["num_nodes"] = args.num_nodes
 
     # data
     train_dataset = MelDataset(hparams.data.training_files, hparams.data)