Skip to content

Commit

Permalink
v0.3.1
Browse files Browse the repository at this point in the history
  • Loading branch information
jstzwj committed Nov 23, 2022
1 parent 88b523c commit babb973
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 12 deletions.
4 changes: 2 additions & 2 deletions hifigan/hub/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
CKPT_URLS = {
"hifigan-16k": "https://github.com/vtuber-plan/hifi-gan/releases/download/v0.3.1/hifigan-16k-B67C217083569F978E07EFD1AD7B1766.pt",
"hifigan-48k": "https://github.com/vtuber-plan/hifi-gan/releases/download/v0.3.1/hifigan-48k-B67C217083569F978E07EFD1AD7B1766.pt",
"hifigan-16k": "https://github.com/vtuber-plan/hifi-gan/releases/download/v0.3.1/hifigan-16k-net-g-de9ba6d4a7cedf02d8b7ab1ef1a3dc6b.pt",
"hifigan-48k": "https://github.com/vtuber-plan/hifi-gan/releases/download/v0.3.1/hifigan-48k-net-g-e5e8e381165ff7e02ff2986f00eabf42.pt",
}
import torch
from ..model.generators.generator import Generator
Expand Down
38 changes: 28 additions & 10 deletions save_state.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,44 @@
import os
import glob
import shutil
from typing import Optional
import torch
from hifigan.model.hifigan import HifiGAN
import hashlib

def save(ckpt_path: str):
def save(ckpt_path: str, name: str):
model = HifiGAN.load_from_checkpoint(checkpoint_path=ckpt_path, strict=True)
# print(model.net_g.state_dict())
torch.save(model.net_g.state_dict(), "net_g.pt")
torch.save(model.net_period_d.state_dict(), "net_period_d.pt")
torch.save(model.net_scale_d.state_dict(), "net_scale_d.pt")
torch.save(model.net_g.state_dict(), f"hifigan-{name}-net-g.pt")
torch.save(model.net_period_d.state_dict(), f"hifigan-{name}-net-period-d.pt")
torch.save(model.net_scale_d.state_dict(), f"hifigan-{name}-net-scale-d.pt")

def main():
h = hashlib.md5(open(f"hifigan-{name}-net-g.pt",'rb').read()).hexdigest()
shutil.move(f"hifigan-{name}-net-g.pt", f"hifigan-{name}-net-g-{h}.pt")
h = hashlib.md5(open(f"hifigan-{name}-net-period-d.pt",'rb').read()).hexdigest()
shutil.move(f"hifigan-{name}-net-period-d.pt", f"hifigan-{name}-net-period-d-{h}.pt")
h = hashlib.md5(open(f"hifigan-{name}-net-scale-d.pt",'rb').read()).hexdigest()
shutil.move(f"hifigan-{name}-net-scale-d.pt", f"hifigan-{name}-net-scale-d-{h}.pt")

def last_checkpoint(path: str) -> Optional[str]:
ckpt_path = None
if os.path.exists("logs/lightning_logs"):
versions = glob.glob("logs/lightning_logs/version_*")
if os.path.exists(os.path.join(path, "lightning_logs")):
versions = glob.glob(os.path.join(path, "lightning_logs", "version_*"))
if len(list(versions)) > 0:
last_ver = sorted(list(versions))[-1]
last_ver = sorted(list(versions), key=lambda p: int(p.split("_")[-1]))[-1]
last_ckpt = os.path.join(last_ver, "checkpoints/last.ckpt")
if os.path.exists(last_ckpt):
ckpt_path = last_ckpt

return ckpt_path

import argparse
def main():
parser = argparse.ArgumentParser()
parser.add_argument('-d', '--dir', type=str, default="./logs", help='Loggin Path')
parser.add_argument('-n', '--name', type=str, default="48k", help='sr')
args = parser.parse_args()
ckpt_path = last_checkpoint(args.dir)
print(ckpt_path)
save(ckpt_path)
save(ckpt_path, args.name)
if __name__ == "__main__":
main()

0 comments on commit babb973

Please sign in to comment.