From babb97312bdf8c92aeaa7449fd2ba127485ba15f Mon Sep 17 00:00:00 2001 From: jstzwj <1103870790@qq.com> Date: Thu, 24 Nov 2022 00:17:08 +0800 Subject: [PATCH] v0.3.1 --- hifigan/hub/__init__.py | 4 ++-- save_state.py | 38 ++++++++++++++++++++++++++++---------- 2 files changed, 30 insertions(+), 12 deletions(-) diff --git a/hifigan/hub/__init__.py b/hifigan/hub/__init__.py index 08a91cb..fc67544 100644 --- a/hifigan/hub/__init__.py +++ b/hifigan/hub/__init__.py @@ -1,6 +1,6 @@ CKPT_URLS = { - "hifigan-16k": "https://github.com/vtuber-plan/hifi-gan/releases/download/v0.3.1/hifigan-16k-B67C217083569F978E07EFD1AD7B1766.pt", - "hifigan-48k": "https://github.com/vtuber-plan/hifi-gan/releases/download/v0.3.1/hifigan-48k-B67C217083569F978E07EFD1AD7B1766.pt", + "hifigan-16k": "https://github.com/vtuber-plan/hifi-gan/releases/download/v0.3.1/hifigan-16k-net-g-de9ba6d4a7cedf02d8b7ab1ef1a3dc6b.pt", + "hifigan-48k": "https://github.com/vtuber-plan/hifi-gan/releases/download/v0.3.1/hifigan-48k-net-g-e5e8e381165ff7e02ff2986f00eabf42.pt", } import torch from ..model.generators.generator import Generator diff --git a/save_state.py b/save_state.py index 3ab574f..a12dbd0 100644 --- a/save_state.py +++ b/save_state.py @@ -1,26 +1,44 @@ import os import glob +import shutil +from typing import Optional import torch from hifigan.model.hifigan import HifiGAN +import hashlib -def save(ckpt_path: str): +def save(ckpt_path: str, name: str): model = HifiGAN.load_from_checkpoint(checkpoint_path=ckpt_path, strict=True) # print(model.net_g.state_dict()) - torch.save(model.net_g.state_dict(), "net_g.pt") - torch.save(model.net_period_d.state_dict(), "net_period_d.pt") - torch.save(model.net_scale_d.state_dict(), "net_scale_d.pt") + torch.save(model.net_g.state_dict(), f"hifigan-{name}-net-g.pt") + torch.save(model.net_period_d.state_dict(), f"hifigan-{name}-net-period-d.pt") + torch.save(model.net_scale_d.state_dict(), f"hifigan-{name}-net-scale-d.pt") -def main(): + h = hashlib.md5(open(f"hifigan-{name}-net-g.pt",'rb').read()).hexdigest() + shutil.move(f"hifigan-{name}-net-g.pt", f"hifigan-{name}-net-g-{h}.pt") + h = hashlib.md5(open(f"hifigan-{name}-net-period-d.pt",'rb').read()).hexdigest() + shutil.move(f"hifigan-{name}-net-period-d.pt", f"hifigan-{name}-net-period-d-{h}.pt") + h = hashlib.md5(open(f"hifigan-{name}-net-scale-d.pt",'rb').read()).hexdigest() + shutil.move(f"hifigan-{name}-net-scale-d.pt", f"hifigan-{name}-net-scale-d-{h}.pt") + +def last_checkpoint(path: str) -> Optional[str]: ckpt_path = None - if os.path.exists("logs/lightning_logs"): - versions = glob.glob("logs/lightning_logs/version_*") + if os.path.exists(os.path.join(path, "lightning_logs")): + versions = glob.glob(os.path.join(path, "lightning_logs", "version_*")) if len(list(versions)) > 0: - last_ver = sorted(list(versions))[-1] + last_ver = sorted(list(versions), key=lambda p: int(p.split("_")[-1]))[-1] last_ckpt = os.path.join(last_ver, "checkpoints/last.ckpt") if os.path.exists(last_ckpt): ckpt_path = last_ckpt - + return ckpt_path + +import argparse +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('-d', '--dir', type=str, default="./logs", help='Loggin Path') + parser.add_argument('-n', '--name', type=str, default="48k", help='sr') + args = parser.parse_args() + ckpt_path = last_checkpoint(args.dir) print(ckpt_path) - save(ckpt_path) + save(ckpt_path, args.name) if __name__ == "__main__": main() \ No newline at end of file