From 91271b9bb3e4917ece6121d153399aa51d9d160f Mon Sep 17 00:00:00 2001 From: admin Date: Mon, 2 Sep 2024 17:24:54 +0800 Subject: [PATCH] rm weights only --- generate.py | 2 +- infer.py | 2 +- rl.py | 2 +- train.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/generate.py b/generate.py index b1d3e31..55695bc 100644 --- a/generate.py +++ b/generate.py @@ -72,7 +72,7 @@ def generate_abc(args): vocab_size=128, ) model: nn.Module = TunesFormer(patch_config, char_config, SHARE_WEIGHTS) - checkpoint = torch.load(args.weights, weights_only=False) + checkpoint = torch.load(args.weights) model.load_state_dict(checkpoint["model"]) model = model.to(DEVICE) model.eval() diff --git a/infer.py b/infer.py index 3343156..9c18266 100644 --- a/infer.py +++ b/infer.py @@ -145,7 +145,7 @@ def generate_music( vocab_size=128, ) model = TunesFormer(patch_config, char_config, share_weights=SHARE_WEIGHTS) - checkpoint = torch.load(weights, weights_only=False) + checkpoint = torch.load(weights) model.load_state_dict(checkpoint["model"]) model = model.to(DEVICE) model.eval() diff --git a/rl.py b/rl.py index 2952d3d..5fbacea 100644 --- a/rl.py +++ b/rl.py @@ -93,7 +93,7 @@ def load_model( if torch.cuda.device_count() > 1: model = torch.nn.DataParallel(model) - checkpoint = torch.load(weights_path, weights_only=False) + checkpoint = torch.load(weights_path) if torch.cuda.device_count() > 1: model.module.load_state_dict(checkpoint["model"]) else: diff --git a/train.py b/train.py index cc277b1..b6e5a17 100644 --- a/train.py +++ b/train.py @@ -215,7 +215,7 @@ def train(subset: str, dld_mode="reuse_dataset_if_exists", bsz=1): snapshot_download("MuGeminorum/tunesformer", cache_dir=TEMP_DIR) + "/weights.pth" ) - checkpoint = torch.load(tunesformer_weights_path, weights_only=False) + checkpoint = torch.load(tunesformer_weights_path) if torch.cuda.device_count() > 1: model.module.load_state_dict(checkpoint["model"]) else: