-
Notifications
You must be signed in to change notification settings - Fork 196
/
example.py
84 lines (60 loc) · 2.08 KB
/
example.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import numpy as np
import os
import pathlib
import time
import torch
from bytesep.models.lightning_modules import get_model_class
from bytesep.separator import Separator
def user_defined_build_separator(
    model_type: str = "ResUNet143_Subbandtime",
) -> Separator:
    r"""Build a Separator around a pretrained bytesep checkpoint.

    Users could modify this function (or pass ``model_type``) to load
    different models. Checkpoint files are read from
    ``<home>/bytesep_data/``.

    Args:
        model_type: str, name of the model class passed to
            ``get_model_class``. Must be one of the keys of
            ``checkpoint_names`` below. The default matches the value that
            was previously hard-coded.

    Returns:
        separator: Separator

    Raises:
        ValueError: if ``model_type`` has no known checkpoint file.
    """
    input_channels = 2
    output_channels = 2
    target_sources_num = 1
    # 30 s segments, assuming a 44.1 kHz sample rate.
    segment_samples = int(44100 * 30.)
    batch_size = 1
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    # Checkpoint file per supported model type. The filenames are kept
    # byte-for-byte (including the "subbtandtime" spelling) because they
    # are the paths actually opened by torch.load below.
    checkpoint_names = {
        "ResUNet143_Subbandtime":
            "resunet143_subbtandtime_vocals_8.7dB_500k_steps_v2.pth",
        "MobileNet_Subbandtime":
            "mobilenet_subbtandtime_accompaniment_14.6dB_500k_steps_v2.pth",
    }
    # Fail early with a clear message instead of leaving checkpoint_path
    # unbound (which previously surfaced as an UnboundLocalError).
    if model_type not in checkpoint_names:
        raise ValueError(
            "Unsupported model_type {!r}; expected one of {}".format(
                model_type, sorted(checkpoint_names)))
    checkpoint_path = os.path.join(pathlib.Path.home(), "bytesep_data",
                                   checkpoint_names[model_type])

    # Get model class.
    Model = get_model_class(model_type)

    # Create model.
    model = Model(
        input_channels=input_channels,
        output_channels=output_channels,
        target_sources_num=target_sources_num,
    )

    # Load checkpoint onto CPU first; weights are stored under the "model" key.
    # NOTE: torch.load unpickles the file, so only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(checkpoint["model"])

    # Move model to device.
    model.to(device)

    # Create separator.
    separator = Separator(
        model=model,
        segment_samples=segment_samples,
        batch_size=batch_size,
        device=device,
    )

    return separator
def main():
    r"""Run bytesep once on dummy (all-zero) audio and print the elapsed time.

    An example of using bytesep in your programme. After installing bytesep,
    users could copy and execute this file in any directory.
    """
    # Build the separator before starting the clock, so only separation
    # itself is timed.
    sep = user_defined_build_separator()

    # Dummy input of shape (2, 44100 * 60); presumably (channels, samples),
    # i.e. 60 s of silent stereo at 44.1 kHz.
    num_samples = 44100 * 60
    silent_waveform = np.zeros((2, num_samples))
    payload = {'waveform': silent_waveform}

    start = time.time()
    separated = sep.separate(payload)  # result not used further in this example
    elapsed = time.time() - start
    print("Done! {:.3f} s".format(elapsed))
if __name__ == "__main__":
    # Run the example only when this file is executed directly, not on import.
    main()