-
Notifications
You must be signed in to change notification settings - Fork 67
/
hubconf.py
75 lines (62 loc) · 2.79 KB
/
hubconf.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
dependencies = ['torch', 'torchaudio', 'numpy']
import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
import logging
import json
from pathlib import Path
from wavlm.WavLM import WavLM, WavLMConfig
from hifigan.models import Generator as HiFiGAN
from hifigan.utils import AttrDict
from matcher import KNeighborsVC
def knn_vc(pretrained=True, progress=True, prematched=True, device='cuda') -> KNeighborsVC:
    """Load kNN-VC (WavLM encoder and HiFiGAN decoder).

    Args:
        pretrained: if True, load pretrained weights for both the encoder and the vocoder.
        progress: if True, show a progress bar while downloading checkpoints.
        prematched: if True, use the vocoder checkpoint trained on `prematched` data.
        device: torch device (string or ``torch.device``) to place the models on.
            A CUDA device is overridden to ``'cpu'`` when no GPU is available.

    Returns:
        A ``KNeighborsVC`` built from the WavLM encoder, the HiFiGAN vocoder and its config.
    """
    # Resolve the device once, before building either sub-model, so the vocoder, the
    # encoder and the KNeighborsVC wrapper all receive the same device -- including on
    # machines without a GPU, where the default 'cuda' cannot be used.
    if torch.device(device).type == 'cuda' and not torch.cuda.is_available():
        logging.warning(f"Overriding device {device} to cpu since no GPU is available.")
        device = 'cpu'
    hifigan, hifigan_cfg = hifigan_wavlm(pretrained, progress, prematched, device)
    wavlm = wavlm_large(pretrained, progress, device)
    return KNeighborsVC(wavlm, hifigan, hifigan_cfg, device)
def hifigan_wavlm(pretrained=True, progress=True, prematched=True, device='cuda') -> "tuple[HiFiGAN, AttrDict]":
    """Load a HiFiGAN vocoder trained to vocode WavLM features.

    Args:
        pretrained: if True, download and load the pretrained generator weights.
        progress: if True, show a progress bar while downloading the checkpoint.
        prematched: if True (and `pretrained`), use the weights trained on `prematched` data.
        device: torch device (string or ``torch.device``) to place the generator on.
            A CUDA device is overridden to ``'cpu'`` when no GPU is available.

    Returns:
        ``(generator, h)``: the generator (in eval mode, after ``remove_weight_norm()``)
        and the ``AttrDict`` config it was constructed from.
    """
    # Without this fallback, moving the generator to 'cuda' fails on GPU-less machines.
    if torch.device(device).type == 'cuda' and not torch.cuda.is_available():
        logging.warning(f"Overriding device {device} to cpu since no GPU is available.")
        device = 'cpu'
    device = torch.device(device)

    # The vocoder config ships alongside this hubconf in the repository.
    config_path = Path(__file__).parent.absolute() / 'hifigan' / 'config_v1_wavlm.json'
    with open(config_path, encoding='utf-8') as f:
        h = AttrDict(json.load(f))

    generator = HiFiGAN(h).to(device)
    if pretrained:
        checkpoint_name = "prematch_g_02500000.pt" if prematched else "g_02500000.pt"
        url = f"https://github.com/bshall/knn-vc/releases/download/v0.1/{checkpoint_name}"
        state_dict_g = torch.hub.load_state_dict_from_url(url, map_location=device, progress=progress)
        generator.load_state_dict(state_dict_g['generator'])
    generator.eval()
    # Presumably folds the weight-norm reparameterisation into plain weights for
    # inference -- see hifigan.models.Generator.remove_weight_norm.
    generator.remove_weight_norm()
    n_params = sum(p.numel() for p in generator.parameters())
    print(f"[HiFiGAN] Generator loaded with {n_params:,d} parameters.")
    return generator, h
def wavlm_large(pretrained=True, progress=True, device='cuda') -> WavLM:
    """Load the WavLM Large model from the original paper.

    See https://github.com/microsoft/unilm/tree/master/wavlm for details.

    Args:
        pretrained: if True, load the pretrained weights into the model. The checkpoint is
            downloaded either way, because the model config is read from it.
        progress: if True, show a progress bar while downloading the checkpoint.
        device: torch device (string or ``torch.device``) to place the model on.
            A CUDA device is overridden to ``'cpu'`` when no GPU is available; other
            device types (e.g. ``'mps'``) are respected as given.

    Returns:
        The WavLM model, moved to `device` and set to eval mode.
    """
    # Only CUDA requests need a fallback; overriding every non-'cpu' device would
    # silently discard valid non-CUDA accelerators whenever CUDA is absent.
    if torch.device(device).type == 'cuda' and not torch.cuda.is_available():
        logging.warning(f"Overriding device {device} to cpu since no GPU is available.")
        device = 'cpu'
    checkpoint = torch.hub.load_state_dict_from_url(
        "https://github.com/bshall/knn-vc/releases/download/v0.1/WavLM-Large.pt",
        map_location=device,
        progress=progress,
    )
    model = WavLM(WavLMConfig(checkpoint['cfg']))
    if pretrained:
        model.load_state_dict(checkpoint['model'])
    model = model.to(torch.device(device))
    model.eval()
    n_params = sum(p.numel() for p in model.parameters())
    print(f"WavLM-Large loaded with {n_params:,d} parameters.")
    return model