Skip to content

Commit

Permalink
rm weights only
Browse files Browse the repository at this point in the history
  • Loading branch information
admin committed Sep 2, 2024
1 parent 1c97741 commit 91271b9
Show file tree
Hide file tree
Showing 4 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def generate_abc(args):
vocab_size=128,
)
model: nn.Module = TunesFormer(patch_config, char_config, SHARE_WEIGHTS)
checkpoint = torch.load(args.weights, weights_only=False)
checkpoint = torch.load(args.weights)
model.load_state_dict(checkpoint["model"])
model = model.to(DEVICE)
model.eval()
Expand Down
2 changes: 1 addition & 1 deletion infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def generate_music(
vocab_size=128,
)
model = TunesFormer(patch_config, char_config, share_weights=SHARE_WEIGHTS)
checkpoint = torch.load(weights, weights_only=False)
checkpoint = torch.load(weights)
model.load_state_dict(checkpoint["model"])
model = model.to(DEVICE)
model.eval()
Expand Down
2 changes: 1 addition & 1 deletion rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def load_model(
if torch.cuda.device_count() > 1:
model = torch.nn.DataParallel(model)

checkpoint = torch.load(weights_path, weights_only=False)
checkpoint = torch.load(weights_path)
if torch.cuda.device_count() > 1:
model.module.load_state_dict(checkpoint["model"])
else:
Expand Down
2 changes: 1 addition & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ def train(subset: str, dld_mode="reuse_dataset_if_exists", bsz=1):
snapshot_download("MuGeminorum/tunesformer", cache_dir=TEMP_DIR)
+ "/weights.pth"
)
checkpoint = torch.load(tunesformer_weights_path, weights_only=False)
checkpoint = torch.load(tunesformer_weights_path)
if torch.cuda.device_count() > 1:
model.module.load_state_dict(checkpoint["model"])
else:
Expand Down

0 comments on commit 91271b9

Please sign in to comment.