diarization_pipeline.py
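
"""Run four speaker diarization experiments on the AMI MixHeadset test set.

The experiments below (summarized from the code in this file) are:

1. baseline     - diarization pipeline with a SincTDNN speaker embedding (f_net)
2. baseline_ovl - baseline plus an overlapped-speech detector (ovl)
3. comp         - compositional speaker embeddings (f_net + g_net)
4. comp_ovl     - compositional embeddings plus the overlap detector

Each experiment writes one RTTM file per meeting under results/<experiment>/.
"""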
import os
import sys

import torch
from pyannote.audio.models import SincTDNN
from pyannote.audio.train.task import Task, TaskOutput, TaskType
from pyannote.audio.utils.signal import Binarize
from pyannote.core import Segment, Annotation

from speaker_diarization_overlap import SpeakerDiarizationOverlap
from g_net import GNet

# run on the first GPU when available, otherwise on CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def baseline(pipeline, test_file):
    """Plain diarization: keep the single cluster label of every speech turn."""
    hypothesis = Annotation()
    diarization = pipeline(test_file)
    # iterate over segments in chronological order (relies on SortedDict internals)
    for seg_cnt, seg in enumerate(diarization._tracks._list):
        # drop very short segments
        if seg.duration <= 0.1:
            continue
        # each segment carries exactly one track; take its label
        label = diarization._tracks[seg]
        label = label[list(label.keys())[0]]
        hypothesis[seg] = label
    return hypothesis


def baseline_ovl(pipeline, ovl, test_file):
    """Baseline diarization combined with an overlapped-speech detector."""
    hypothesis = Annotation()
    ovl_scores = ovl(test_file)
    # turn frame-level overlap scores into an overlap timeline
    binarize = Binarize(offset=0.52, onset=0.52, log_scale=True,
                        min_duration_off=0.1, min_duration_on=0.1)
    overlap = binarize.apply(ovl_scores, dimension=1)
    test_file['overlap'] = overlap
    # with an 'overlap' key, the pipeline returns two annotations;
    # the second one is used below to label overlap regions
    diarization, diarization1 = pipeline(test_file)
    for seg_cnt, seg in enumerate(diarization._tracks._list):
        if seg.duration <= 0.1:
            continue
        label = diarization._tracks[seg]
        label1 = diarization1._tracks[seg]
        label = label[list(label.keys())[0]]
        label1 = label1[list(label1.keys())[0]]
        hypothesis[seg] = label
        # on detected overlap, also emit the second speaker on the intersection
        for ref_seg in overlap:
            if ref_seg.intersects(seg):
                start, end = max(seg.start, ref_seg.start), min(seg.end, ref_seg.end)
                hypothesis[Segment(start, end)] = label
                # tiny epsilon keeps the two labels on distinct segments
                hypothesis[Segment(start, end + 1e-7)] = label1
    return hypothesis


def comp(pipeline, test_file):
    """Compositional embeddings: a label of the form 'A_B' stands for two overlapping speakers."""
    hypothesis = Annotation()
    diarization = pipeline(test_file)
    for seg_cnt, seg in enumerate(diarization._tracks._list):
        if seg.duration <= 0.1:
            continue
        label = diarization._tracks[seg]
        label = label[list(label.keys())[0]]
        if '_' in str(label):
            # composed label: emit both speakers on (almost) the same segment
            a, b = label.split('_')
            hypothesis[seg] = a
            hypothesis[Segment(seg.start, seg.end + 1e-7)] = b
        else:
            hypothesis[seg] = label
    return hypothesis


def comp_ovl(pipeline, ovl, test_file):
    """Compositional embeddings combined with the overlapped-speech detector."""
    hypothesis = Annotation()
    # turn frame-level overlap scores into an overlap timeline
    binarize = Binarize(offset=0.52, onset=0.52, log_scale=True,
                        min_duration_off=0.1, min_duration_on=0.1)
    ovl_scores = ovl(test_file)
    overlap = binarize.apply(ovl_scores, dimension=1)
    test_file['overlap'] = overlap
    diarization, diarization1 = pipeline(test_file)
    for seg_cnt, seg in enumerate(diarization._tracks._list):
        if seg.duration <= 0.1:
            continue
        label = diarization._tracks[seg]
        label1 = diarization1._tracks[seg]
        label = label[list(label.keys())[0]]
        label1 = label1[list(label1.keys())[0]]
        hypothesis[seg] = label
        # on detected overlap, split the composed label 'A_B' into its two speakers
        for ref_seg in overlap:
            if ref_seg.intersects(seg):
                start, end = max(seg.start, ref_seg.start), min(seg.end, ref_seg.end)
                hypothesis[Segment(start, end)] = label1.split('_')[0]
                hypothesis[Segment(start, end + 1e-7)] = label1.split('_')[1]
    return hypothesis


if __name__ == "__main__":
    # path to the AMI corpus root, passed on the command line
    ami_path = sys.argv[1]

    # create result folders for all four experiments
    os.makedirs('results/baseline', exist_ok=True)
    os.makedirs('results/baseline_ovl', exist_ok=True)
    os.makedirs('results/comp', exist_ok=True)
    os.makedirs('results/comp_ovl', exist_ok=True)

    # collect the meeting names from the reference RTTM (the second field is the uri)
    audios = []
    with open(f'{ami_path}/AMI/MixHeadset.test.rttm') as f:
        for line in f:
            audio_name = line.strip().split()[1].split('.')[0]
            if audio_name not in audios:
                audios.append(audio_name)

    # baseline experiment
    task = Task(TaskType.REPRESENTATION_LEARNING, TaskOutput.VECTOR)
    specifications = {'X': {'dimension': 1}, 'task': task}
    sincnet = {'instance_normalize': True, 'stride': [5, 1, 1], 'waveform_normalize': True}
    tdnn = {'embedding_dim': 512}
    embedding = {'batch_normalize': False, 'unit_normalize': False}
    # speaker embedding network f_net, loaded from the f_vxc checkpoint
    f_net = SincTDNN(specifications=specifications, sincnet=sincnet,
                     tdnn=tdnn, embedding=embedding).to(device)
    f_net.load_state_dict(torch.load("checkpoints/f_vxc.pt"))
    pipeline = SpeakerDiarizationOverlap('baseline', None, device,
                                         sad_scores='sad_ami', scd_scores='scd_ami',
                                         embedding='emb_ami', method='affinity_propagation')
    pipeline.load_params('config.yml')
    # plug f_net into the clustering and assignment stages of the pipeline
    pipeline._pipelines['speech_turn_clustering']._embedding.scorer_.model_ = f_net
    pipeline._pipelines['speech_turn_assignment']._embedding.scorer_.model_ = f_net
    for cnt, audio in enumerate(audios):
        print(audio)
        test_file = {'uri': f'{audio}.Mix-Headset',
                     'audio': f'{ami_path}/amicorpus/{audio}/audio/{audio}.Mix-Headset.wav'}
        hypothesis = baseline(pipeline, test_file)
        hypothesis.uri = audio + '.Mix-Headset'
        with open(f'results/baseline/{audio}.rttm', 'w') as f:
            hypothesis.write_rttm(f)

    # baseline with overlap detector
    ovl = torch.load('checkpoints/ovl.pt')
    pipeline = SpeakerDiarizationOverlap('baseline_ovl', None, device,
                                         sad_scores='sad_ami', scd_scores='scd_ami',
                                         embedding='emb_ami', method='affinity_propagation')
    pipeline.load_params('config.yml')
    pipeline._pipelines['speech_turn_clustering']._embedding.scorer_.model_ = f_net
    pipeline._pipelines['speech_turn_assignment']._embedding.scorer_.model_ = f_net
    for cnt, audio in enumerate(audios):
        print(audio)
        test_file = {'uri': f'{audio}.Mix-Headset',
                     'audio': f'{ami_path}/amicorpus/{audio}/audio/{audio}.Mix-Headset.wav'}
        hypothesis = baseline_ovl(pipeline, ovl, test_file)
        hypothesis.uri = audio + '.Mix-Headset'
        with open(f'results/baseline_ovl/{audio}.rttm', 'w') as f:
            hypothesis.write_rttm(f)

    # compositional embedding
    f_net.load_state_dict(torch.load("checkpoints/best_f.pt"))
    g_net = GNet().to(device)
    g_net.load_state_dict(torch.load("checkpoints/best_g.pt"))
    pipeline = SpeakerDiarizationOverlap('comp', g_net, device,
                                         sad_scores='sad_ami', scd_scores='scd_ami',
                                         embedding='emb_ami', method='affinity_propagation')
    pipeline.load_params('config.yml')
    pipeline._pipelines['speech_turn_clustering']._embedding.scorer_.model_ = f_net
    pipeline._pipelines['speech_turn_assignment']._embedding.scorer_.model_ = f_net
    for cnt, audio in enumerate(audios):
        print(audio)
        test_file = {'uri': f'{audio}.Mix-Headset',
                     'audio': f'{ami_path}/amicorpus/{audio}/audio/{audio}.Mix-Headset.wav'}
        hypothesis = comp(pipeline, test_file)
        hypothesis.uri = audio + '.Mix-Headset'
        with open(f'results/comp/{audio}.rttm', 'w') as f:
            hypothesis.write_rttm(f)

    # compositional embedding with overlap detector
    pipeline = SpeakerDiarizationOverlap('comp_ovl', g_net, device,
                                         sad_scores='sad_ami', scd_scores='scd_ami',
                                         embedding='emb_ami', method='affinity_propagation')
    pipeline.load_params('config.yml')
    pipeline._pipelines['speech_turn_clustering']._embedding.scorer_.model_ = f_net
    pipeline._pipelines['speech_turn_assignment']._embedding.scorer_.model_ = f_net
    for cnt, audio in enumerate(audios):
        print(audio)
        test_file = {'uri': f'{audio}.Mix-Headset',
                     'audio': f'{ami_path}/amicorpus/{audio}/audio/{audio}.Mix-Headset.wav'}
        hypothesis = comp_ovl(pipeline, ovl, test_file)
        hypothesis.uri = audio + '.Mix-Headset'
        with open(f'results/comp_ovl/{audio}.rttm', 'w') as f:
            hypothesis.write_rttm(f)
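
# Usage sketch (illustration only, not part of the original script): the script
# expects the AMI corpus root as its single command-line argument, e.g.
#
#     python diarization_pipeline.py /path/to/ami
#
# where /path/to/ami contains AMI/MixHeadset.test.rttm and
# amicorpus/<meeting>/audio/<meeting>.Mix-Headset.wav, matching the paths used
# above. Assuming pyannote.metrics is installed, the RTTM files written under
# results/ could then be scored against the reference, for instance:
#
#     from pyannote.database.util import load_rttm
#     from pyannote.metrics.diarization import DiarizationErrorRate
#
#     metric = DiarizationErrorRate()
#     reference = load_rttm('/path/to/ami/AMI/MixHeadset.test.rttm')
#     for uri, hyp in load_rttm('results/baseline/ES2004a.rttm').items():
#         metric(reference[uri], hyp)
#     print(abs(metric))
#
# ('ES2004a' above is only a placeholder meeting name.)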