Skip to content

Commit

Permalink
multi node training
Browse files Browse the repository at this point in the history
  • Loading branch information
jstzwj committed Nov 23, 2022
1 parent 0a5c73f commit 88b523c
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 1 deletion.
9 changes: 8 additions & 1 deletion examples/test.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
import sys
import os

sys.path.append(
os.path.abspath(os.path.join(os.path.dirname(__file__), '../'))
)

import torch
import os
import glob
Expand Down Expand Up @@ -62,7 +69,7 @@ def load_remote():
hifigan = load_local().to(device)

# Load audio
wav, sr = torchaudio.load("zszy_48k.wav")
wav, sr = torchaudio.load("7200000318_0_generated.wav")
assert sr == 48000

# mel = mel_spectrogram_torch(wav, 2048, 128, 48000, 512, 2048, 0, None, False)
Expand Down
3 changes: 3 additions & 0 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def main():
parser.add_argument('-c', '--config', type=str, default="./configs/48k.json", help='JSON file for configuration')
parser.add_argument('-a', '--accelerator', type=str, default="gpu", help='training device')
parser.add_argument('-d', '--device', type=str, default="0", help='training device ids')
parser.add_argument('-n', '--num-nodes', type=int, default=1, help='training node number')
args = parser.parse_args()

hparams = get_hparams(args.config)
Expand Down Expand Up @@ -75,6 +76,8 @@ def main():
if hparams.train.fp16_run:
trainer_params["amp_backend"] = "native"
trainer_params["precision"] = 16

trainer_params["num_nodes"] = args.num_nodes

# data
train_dataset = MelDataset(hparams.data.training_files, hparams.data)
Expand Down

0 comments on commit 88b523c

Please sign in to comment.