Skip to content

Commit

Permalink
infer
Browse files Browse the repository at this point in the history
  • Loading branch information
MaxMax2016 committed Mar 4, 2023
1 parent 9526b1c commit 00ec864
Show file tree
Hide file tree
Showing 4 changed files with 170 additions and 46 deletions.
14 changes: 11 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,16 @@ change sample rate of waves, and put waves to ./data_opencpop/waves
> python svc_trainer.py -c config/default_c32.yaml -n uni_svc
3k wavs of opencpop training~~~~~~
### Infer
export clean model

https://user-images.githubusercontent.com/16432329/222747832-ee6aaa27-6257-49c8-b373-5d13d0c09496.mp4
> python svc_export.py --config config/default_c32.yaml --checkpoint_path chkpt/uni_svc/uni_svc_0740.pt
download preview form release page

# data-sets
> python svc_inference.py --config config/default_c32.yaml --model uni_svc.pth --wave uni_svc_test.wav
### Data-sets
KiSing http://shijt.site/index.php/2021/05/16/kising-the-first-open-source-mandarin-singing-voice-synthesis-corpus/

PopCS https://github.com/MoonInTheRiver/DiffSinger/blob/master/resources/apply_form.md
Expand Down Expand Up @@ -67,3 +71,7 @@ Aishell-3 http://www.aishelltech.com/aishell_3

VCTK https://datashare.ed.ac.uk/handle/10283/2651

# Notice
如果您参考了本项目,请您在您项目中列出本项目。【武德】

If you refer to this project, please list it in your project.
104 changes: 61 additions & 43 deletions svc_inference.py
Original file line number Diff line number Diff line change
@@ -1,62 +1,80 @@
import os
import glob
import tqdm
import torch
import librosa
import pyworld
import argparse
import numpy as np

from scipy.io.wavfile import write
from omegaconf import OmegaConf

from model.generator import Generator


def load_svc_model(checkpoint_path, model):
assert os.path.isfile(checkpoint_path)
checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
model.load_state_dict(checkpoint_dict["model_g"])
return model


def compute_f0(path):
x, sr = librosa.load(path, sr=16000)
assert sr == 16000
f0, t = pyworld.dio(
x.astype(np.double),
fs=sr,
f0_ceil=900,
frame_period=1000 * 160 / sr,
)
f0 = pyworld.stonemask(x.astype(np.double), f0, t, fs=16000)
for index, pitch in enumerate(f0):
f0[index] = round(pitch, 1)
return f0


ppg_path = "uni_svc_tmp.ppg.npy"


def main(args):
checkpoint = torch.load(args.checkpoint_path)
if args.config is not None:
hp = OmegaConf.load(args.config)
else:
hp = OmegaConf.create(checkpoint['hp_str'])

model = Generator(hp).cuda()
saved_state_dict = checkpoint['model_g']
new_state_dict = {}

for k, v in saved_state_dict.items():
try:
new_state_dict[k] = saved_state_dict['module.' + k]
except:
new_state_dict[k] = v
model.load_state_dict(new_state_dict)
model.eval(inference=True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
hp = OmegaConf.load(args.config)
model = Generator(hp)
load_svc_model(args.model, model)

with torch.no_grad():
for melpath in tqdm.tqdm(glob.glob(os.path.join(args.input_folder, '*.mel'))):
mel = torch.load(melpath)
if len(mel.shape) == 2:
mel = mel.unsqueeze(0)
mel = mel.cuda()
os.system(f"python svc_inference_ppg.py -w {args.wave} -p {ppg_path}")

ppg = np.load(ppg_path)
ppg = np.repeat(ppg, 2, 0) # 320 PPG -> 160 * 2
ppg = torch.FloatTensor(ppg)

pit = compute_f0(args.wave)
pit = torch.FloatTensor(pit)

audio = model.inference(mel)
audio = audio.cpu().detach().numpy()
len_pit = pit.size()[0]
len_ppg = ppg.size()[0]
len_min = min(len_pit, len_ppg)
pit = pit[:len_min]
ppg = ppg[:len_min, :]

model.eval(inference=True)
model.to(device)
with torch.no_grad():
ppg = ppg.unsqueeze(0).to(device)
pit = pit.unsqueeze(0).to(device)
audio = model.inference(ppg, pit)
audio = audio.cpu().detach().numpy()

if args.output_folder is None: # if output folder is not defined, audio samples are saved in input folder
out_path = melpath.replace('.mel', '_reconstructed_epoch%04d.wav' % checkpoint['epoch'])
else:
basename = os.path.basename(melpath)
basename = basename.replace('.mel', '_reconstructed_epoch%04d.wav' % checkpoint['epoch'])
out_path = os.path.join(args.output_folder, basename)
write(out_path, hp.audio.sampling_rate, audio)
write("uni_svc_out.wav", hp.audio.sampling_rate, audio)


if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-c', '--config', type=str, default=None,
help="yaml file for config. will use hp_str from checkpoint if not given.")
parser.add_argument('-p', '--checkpoint_path', type=str, required=True,
help="path of checkpoint pt file for evaluation")
parser.add_argument('-i', '--input_folder', type=str, required=True,
help="directory of mel-spectrograms to invert into raw audio.")
parser.add_argument('-o', '--output_folder', type=str, default=None,
help="directory which generated raw audio is saved.")
parser.add_argument('-c', '--config', type=str, required=True,
help="yaml file for config.")
parser.add_argument('-m', '--model', type=str, required=True,
help="path of model for evaluation")
parser.add_argument('-i', '--wave', type=str, required=True,
help="Path of raw audio.")
args = parser.parse_args()

main(args)
53 changes: 53 additions & 0 deletions svc_inference_export.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import os
import torch
import argparse
from omegaconf import OmegaConf

from model.generator import Generator


def load_model(checkpoint_path, model):
assert os.path.isfile(checkpoint_path)
checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
saved_state_dict = checkpoint_dict["model_g"]
if hasattr(model, "module"):
state_dict = model.module.state_dict()
else:
state_dict = model.state_dict()
new_state_dict = {}
for k, v in state_dict.items():
try:
new_state_dict[k] = saved_state_dict[k]
except:
new_state_dict[k] = v
if hasattr(model, "module"):
model.module.load_state_dict(new_state_dict)
else:
model.load_state_dict(new_state_dict)
return model


def save_model(model, checkpoint_path):
if hasattr(model, 'module'):
state_dict = model.module.state_dict()
else:
state_dict = model.state_dict()
torch.save({'model_g': state_dict}, checkpoint_path)


def main(args):
hp = OmegaConf.load(args.config)
model = Generator(hp)
load_model(args.checkpoint_path, model)
save_model(model, "uni_svc.pth")


if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-c', '--config', type=str, required=True,
help="yaml file for config. will use hp_str from checkpoint if not given.")
parser.add_argument('-p', '--checkpoint_path', type=str, required=True,
help="path of checkpoint pt file for evaluation")
args = parser.parse_args()

main(args)
45 changes: 45 additions & 0 deletions svc_inference_ppg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import os
import numpy as np
import argparse
import torch

from whisper.model import Whisper, ModelDimensions
from whisper.audio import load_audio, pad_or_trim, log_mel_spectrogram


def load_model(path) -> Whisper:
device = "cuda" if torch.cuda.is_available() else "cpu"
checkpoint = torch.load(path, map_location=device)
dims = ModelDimensions(**checkpoint["dims"])
model = Whisper(dims)
model.load_state_dict(checkpoint["model_state_dict"])
return model.to(device)


def pred_ppg(whisper: Whisper, wavPath, ppgPath):
audio = load_audio(wavPath)
audln = audio.shape[0]
ppgln = audln // 320
audio = pad_or_trim(audio)
mel = log_mel_spectrogram(audio).to(whisper.device)
with torch.no_grad():
ppg = whisper.encoder(mel.unsqueeze(0)).squeeze().data.cpu().float().numpy()
ppg = ppg[:ppgln,] # [length, dim=1024]
print(ppg.shape)
np.save(ppgPath, ppg, allow_pickle=False)


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.description = 'please enter embed parameter ...'
parser.add_argument("-w", "--wav", help="wav", dest="wav")
parser.add_argument("-p", "--ppg", help="ppg", dest="ppg")
args = parser.parse_args()
print(args.wav)
print(args.ppg)

wavPath = args.wav
ppgPath = args.ppg

whisper = load_model("medium.pt")
pred_ppg(whisper, wavPath, ppgPath)

0 comments on commit 00ec864

Please sign in to comment.