-
Notifications
You must be signed in to change notification settings - Fork 79
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
9526b1c
commit 00ec864
Showing
4 changed files
with
170 additions
and
46 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,62 +1,80 @@ | ||
import os | ||
import glob | ||
import tqdm | ||
import torch | ||
import librosa | ||
import pyworld | ||
import argparse | ||
import numpy as np | ||
|
||
from scipy.io.wavfile import write | ||
from omegaconf import OmegaConf | ||
|
||
from model.generator import Generator | ||
|
||
|
||
def load_svc_model(checkpoint_path, model): | ||
assert os.path.isfile(checkpoint_path) | ||
checkpoint_dict = torch.load(checkpoint_path, map_location="cpu") | ||
model.load_state_dict(checkpoint_dict["model_g"]) | ||
return model | ||
|
||
|
||
def compute_f0(path): | ||
x, sr = librosa.load(path, sr=16000) | ||
assert sr == 16000 | ||
f0, t = pyworld.dio( | ||
x.astype(np.double), | ||
fs=sr, | ||
f0_ceil=900, | ||
frame_period=1000 * 160 / sr, | ||
) | ||
f0 = pyworld.stonemask(x.astype(np.double), f0, t, fs=16000) | ||
for index, pitch in enumerate(f0): | ||
f0[index] = round(pitch, 1) | ||
return f0 | ||
|
||
|
||
ppg_path = "uni_svc_tmp.ppg.npy" | ||
|
||
|
||
def main(args): | ||
checkpoint = torch.load(args.checkpoint_path) | ||
if args.config is not None: | ||
hp = OmegaConf.load(args.config) | ||
else: | ||
hp = OmegaConf.create(checkpoint['hp_str']) | ||
|
||
model = Generator(hp).cuda() | ||
saved_state_dict = checkpoint['model_g'] | ||
new_state_dict = {} | ||
|
||
for k, v in saved_state_dict.items(): | ||
try: | ||
new_state_dict[k] = saved_state_dict['module.' + k] | ||
except: | ||
new_state_dict[k] = v | ||
model.load_state_dict(new_state_dict) | ||
model.eval(inference=True) | ||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | ||
hp = OmegaConf.load(args.config) | ||
model = Generator(hp) | ||
load_svc_model(args.model, model) | ||
|
||
with torch.no_grad(): | ||
for melpath in tqdm.tqdm(glob.glob(os.path.join(args.input_folder, '*.mel'))): | ||
mel = torch.load(melpath) | ||
if len(mel.shape) == 2: | ||
mel = mel.unsqueeze(0) | ||
mel = mel.cuda() | ||
os.system(f"python svc_inference_ppg.py -w {args.wave} -p {ppg_path}") | ||
|
||
ppg = np.load(ppg_path) | ||
ppg = np.repeat(ppg, 2, 0) # 320 PPG -> 160 * 2 | ||
ppg = torch.FloatTensor(ppg) | ||
|
||
pit = compute_f0(args.wave) | ||
pit = torch.FloatTensor(pit) | ||
|
||
audio = model.inference(mel) | ||
audio = audio.cpu().detach().numpy() | ||
len_pit = pit.size()[0] | ||
len_ppg = ppg.size()[0] | ||
len_min = min(len_pit, len_ppg) | ||
pit = pit[:len_min] | ||
ppg = ppg[:len_min, :] | ||
|
||
model.eval(inference=True) | ||
model.to(device) | ||
with torch.no_grad(): | ||
ppg = ppg.unsqueeze(0).to(device) | ||
pit = pit.unsqueeze(0).to(device) | ||
audio = model.inference(ppg, pit) | ||
audio = audio.cpu().detach().numpy() | ||
|
||
if args.output_folder is None: # if output folder is not defined, audio samples are saved in input folder | ||
out_path = melpath.replace('.mel', '_reconstructed_epoch%04d.wav' % checkpoint['epoch']) | ||
else: | ||
basename = os.path.basename(melpath) | ||
basename = basename.replace('.mel', '_reconstructed_epoch%04d.wav' % checkpoint['epoch']) | ||
out_path = os.path.join(args.output_folder, basename) | ||
write(out_path, hp.audio.sampling_rate, audio) | ||
write("uni_svc_out.wav", hp.audio.sampling_rate, audio) | ||
|
||
|
||
if __name__ == '__main__': | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument('-c', '--config', type=str, default=None, | ||
help="yaml file for config. will use hp_str from checkpoint if not given.") | ||
parser.add_argument('-p', '--checkpoint_path', type=str, required=True, | ||
help="path of checkpoint pt file for evaluation") | ||
parser.add_argument('-i', '--input_folder', type=str, required=True, | ||
help="directory of mel-spectrograms to invert into raw audio.") | ||
parser.add_argument('-o', '--output_folder', type=str, default=None, | ||
help="directory which generated raw audio is saved.") | ||
parser.add_argument('-c', '--config', type=str, required=True, | ||
help="yaml file for config.") | ||
parser.add_argument('-m', '--model', type=str, required=True, | ||
help="path of model for evaluation") | ||
parser.add_argument('-i', '--wave', type=str, required=True, | ||
help="Path of raw audio.") | ||
args = parser.parse_args() | ||
|
||
main(args) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
import os | ||
import torch | ||
import argparse | ||
from omegaconf import OmegaConf | ||
|
||
from model.generator import Generator | ||
|
||
|
||
def load_model(checkpoint_path, model): | ||
assert os.path.isfile(checkpoint_path) | ||
checkpoint_dict = torch.load(checkpoint_path, map_location="cpu") | ||
saved_state_dict = checkpoint_dict["model_g"] | ||
if hasattr(model, "module"): | ||
state_dict = model.module.state_dict() | ||
else: | ||
state_dict = model.state_dict() | ||
new_state_dict = {} | ||
for k, v in state_dict.items(): | ||
try: | ||
new_state_dict[k] = saved_state_dict[k] | ||
except: | ||
new_state_dict[k] = v | ||
if hasattr(model, "module"): | ||
model.module.load_state_dict(new_state_dict) | ||
else: | ||
model.load_state_dict(new_state_dict) | ||
return model | ||
|
||
|
||
def save_model(model, checkpoint_path): | ||
if hasattr(model, 'module'): | ||
state_dict = model.module.state_dict() | ||
else: | ||
state_dict = model.state_dict() | ||
torch.save({'model_g': state_dict}, checkpoint_path) | ||
|
||
|
||
def main(args): | ||
hp = OmegaConf.load(args.config) | ||
model = Generator(hp) | ||
load_model(args.checkpoint_path, model) | ||
save_model(model, "uni_svc.pth") | ||
|
||
|
||
if __name__ == '__main__': | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument('-c', '--config', type=str, required=True, | ||
help="yaml file for config. will use hp_str from checkpoint if not given.") | ||
parser.add_argument('-p', '--checkpoint_path', type=str, required=True, | ||
help="path of checkpoint pt file for evaluation") | ||
args = parser.parse_args() | ||
|
||
main(args) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
import os | ||
import numpy as np | ||
import argparse | ||
import torch | ||
|
||
from whisper.model import Whisper, ModelDimensions | ||
from whisper.audio import load_audio, pad_or_trim, log_mel_spectrogram | ||
|
||
|
||
def load_model(path) -> Whisper: | ||
device = "cuda" if torch.cuda.is_available() else "cpu" | ||
checkpoint = torch.load(path, map_location=device) | ||
dims = ModelDimensions(**checkpoint["dims"]) | ||
model = Whisper(dims) | ||
model.load_state_dict(checkpoint["model_state_dict"]) | ||
return model.to(device) | ||
|
||
|
||
def pred_ppg(whisper: Whisper, wavPath, ppgPath): | ||
audio = load_audio(wavPath) | ||
audln = audio.shape[0] | ||
ppgln = audln // 320 | ||
audio = pad_or_trim(audio) | ||
mel = log_mel_spectrogram(audio).to(whisper.device) | ||
with torch.no_grad(): | ||
ppg = whisper.encoder(mel.unsqueeze(0)).squeeze().data.cpu().float().numpy() | ||
ppg = ppg[:ppgln,] # [length, dim=1024] | ||
print(ppg.shape) | ||
np.save(ppgPath, ppg, allow_pickle=False) | ||
|
||
|
||
if __name__ == "__main__": | ||
parser = argparse.ArgumentParser() | ||
parser.description = 'please enter embed parameter ...' | ||
parser.add_argument("-w", "--wav", help="wav", dest="wav") | ||
parser.add_argument("-p", "--ppg", help="ppg", dest="ppg") | ||
args = parser.parse_args() | ||
print(args.wav) | ||
print(args.ppg) | ||
|
||
wavPath = args.wav | ||
ppgPath = args.ppg | ||
|
||
whisper = load_model("medium.pt") | ||
pred_ppg(whisper, wavPath, ppgPath) |