New_03_FastWhisper_EcapaDiariz._Combine.py
# To ignore warnings
import warnings
warnings.filterwarnings("ignore")
# Register a no-op ALSA error handler via ctypes so PyAudio does not spam the console on Linux
from ctypes import *
ERROR_HANDLER_FUNC = CFUNCTYPE(None, c_char_p, c_int, c_char_p, c_int, c_char_p)
def py_error_handler(filename, line, function, err, fmt):
    return
c_error_handler = ERROR_HANDLER_FUNC(py_error_handler)
asound = cdll.LoadLibrary('libasound.so')
asound.snd_lib_error_set_handler(c_error_handler)
# Import required libraries
from faster_whisper import WhisperModel
import pyaudio
import wave
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from scipy.io import wavfile
from itertools import combinations
from collections import OrderedDict
from pydub import AudioSegment
from diarization.voice_activity_detection import voice_activity_detection
from SNR.snr import wada_snr
from New_02_ECAPA_Diarization import pairwiseDists
from New_02_ECAPA_Diarization import SpeakerDiarizationChunkEcapa as SpeakerDiarization
from statistics import mode
from anonymization.anonymization import anonymize_text, anonymize_text_with_deny_list
import json
from decimal import Decimal
# Make Torch use the GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Running on:",device)
def record_chunk(p, stream, file_path, chunk_length=3):
    """
    Retrieve "chunk_length" seconds of recorded audio, save it to file_path, and return the audio as a NumPy array.
    """
    frames = []
    for _ in range(0, int(16000 / 1024 * chunk_length)):
        data = stream.read(1024)
        frames.append(data)
    # Save the audio in the mentioned file_path
    wf = wave.open(file_path, 'wb')
    wf.setnchannels(1)
    wf.setsampwidth(p.get_sample_size(pyaudio.paInt16))
    wf.setframerate(16000)
    wf.writeframes(b''.join(frames))
    wf.close()
    # Return the audio as float32 samples normalized to [-1, 1]
    final = b''.join(frames)
    audio_np = np.frombuffer(final, dtype=np.int16).astype(np.float32) / 32768.0
    return audio_np
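
# Sanity-check sketch (illustrative only, never called by the pipeline): because
# record_chunk reads whole 1024-frame buffers, a "3 second" chunk is actually
# int(16000 / 1024 * 3) = 46 buffers, i.e. 47104 samples (~2.944 s at 16 kHz),
# slightly shorter than the nominal chunk_length. The helper name below is
# hypothetical and exists only to show the arithmetic.
def _expected_chunk_samples(chunk_length=3, rate=16000, buffer_size=1024):
    n_buffers = int(rate / buffer_size * chunk_length)
    return n_buffers * buffer_size  # 47104 for the defaults above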
def extract_audio_between_timestamps(audio_np, sample_rate, start_ts, end_ts):
    """
    Function to extract audio from audio_np between the start_ts and end_ts timestamps.

    Parameters:
        audio_np (numpy.ndarray): NumPy array containing the audio data.
        sample_rate (int): Sampling rate of the audio data.
        start_ts (float): Start timestamp (in seconds).
        end_ts (float): End timestamp (in seconds).

    Returns:
        numpy.ndarray: Extracted audio data between the start_ts and end_ts timestamps.
    """
    # Convert timestamps to sample indices
    start_idx = int(start_ts * sample_rate)
    end_idx = int(end_ts * sample_rate)
    # Slice the audio_np array to extract the desired portion
    extracted_audio = audio_np[start_idx:end_idx]
    return extracted_audio
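
# Minimal self-check sketch for extract_audio_between_timestamps (illustrative only,
# never called by the pipeline): at 16 kHz, extracting 1.0 s to 2.5 s from a 3 s
# array should return 1.5 s of audio, i.e. 24000 samples. The function name below
# is hypothetical.
def _demo_extract_audio_between_timestamps():
    dummy_audio = np.zeros(3 * 16000, dtype=np.float32)  # 3 s of silence at 16 kHz
    clip = extract_audio_between_timestamps(dummy_audio, 16000, 1.0, 2.5)
    assert len(clip) == 24000
    return clip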
def AudioSegment_to_np_array(asg):
    """
    Function to convert a pydub AudioSegment to a NumPy array.
    """
    dtype = getattr(np, "int{:d}".format(asg.sample_width * 8))  # Or could create a mapping: {1: np.int8, 2: np.int16, 4: np.int32, 8: np.int64}
    arr = np.ndarray((int(asg.frame_count()), asg.channels), buffer=asg.raw_data, dtype=dtype)
    return arr
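
# Illustrative sketch (not used by the pipeline): AudioSegment_to_np_array picks the
# dtype from sample_width, e.g. sample_width=2 maps to np.int16. pydub's
# AudioSegment.silent() is used here only to build a tiny in-memory segment; the
# function name below is hypothetical.
def _demo_AudioSegment_to_np_array():
    seg = AudioSegment.silent(duration=500, frame_rate=16000)  # 0.5 s of 16-bit mono silence
    arr = AudioSegment_to_np_array(seg)
    # Expected: shape (8000, 1) -> 0.5 s * 16000 Hz frames, one channel of int16 zeros
    return arr.shape, arr.dtype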
def load_video_to_ndarray(filepath):
    """Load a media file and convert it to a 1-D NumPy ndarray of PCM data representing the audio signal.

    Args:
        filepath (str): path to media file (must be an audio or video file supported by ffmpeg)

    Returns:
        sig: 1-D int16 PCM signal, resampled to 16 kHz mono
        sig_nsamp: number of samples in sig
    """
    # Load the file and convert to 16 kHz mono 16-bit PCM on the fly
    chunk_audio = AudioSegment.from_file(filepath)
    SAMPLE_RATE = 16000  # Hz, see load_video for requirements
    chunk_audio = chunk_audio.set_channels(1).set_sample_width(2).set_frame_rate(SAMPLE_RATE)
    sig = AudioSegment_to_np_array(chunk_audio).flatten()
    sig_nsamp = len(sig)
    return sig, sig_nsamp
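
# Hedged usage sketch (illustrative only, never called by the pipeline): load one of
# the enrollment files referenced in main() below and report its duration. The file
# path is assumed to sit next to this script, as main() already assumes; the function
# name below is hypothetical.
def _demo_load_enrollment(filepath='Enrollment_Niranjan.wav'):
    sig, sig_nsamp = load_video_to_ndarray(filepath)
    duration_sec = sig_nsamp / 16000.0  # the signal is resampled to 16 kHz mono above
    return sig, duration_sec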
def main():
    # Initialize the faster-whisper model
    model_size = "large-v3"  # "medium" "distil-large-v2"
    model = WhisperModel(model_size, device=str(device), compute_type="float16")

    # Start recording the audio
    SAMPLE_RATE = 16000  # Hz, see load_video for requirements
    CHUNK_LENGTH = 3  # Length of each recorded chunk in seconds
    MIN_CHUNK_LEN = 0.1  # seconds, chunks shorter than this will be skipped
    p = pyaudio.PyAudio()
    stream = p.open(format=pyaudio.paInt16, channels=1, rate=SAMPLE_RATE, input=True, frames_per_buffer=1024)
    print("Recording Started...")

    # String that accumulates all the transcripts in this session for logging
    accumulated_transcripts = ""

    # Initialize the diarizer with recorded enrollments
    enrollment_utterances = {'Niranjan': load_video_to_ndarray('Enrollment_Niranjan.wav'),
                             'Kamala Harris': load_video_to_ndarray('Enrollment_Kamala_Harris.wav'),
                             'Trump': load_video_to_ndarray('Enrollment_Trump.wav')}
    diarizer = SpeakerDiarization(refSpeakers=enrollment_utterances)  # stepSize decides the frequency of diarization: if it is 0.5, the result is a speaker ID for every 0.5 s of the input audio, returned in a dictionary.

    sessionID = 1  # TODO: get this sort of metadata from the calling function
    chunk_number = 1  # TODO: Should this be 0?
    chunk_start = 0.0
    try:
        while True:
            # Retrieve the recorded audio and transcribe it
            chunk_file = 'temp_chunk.wav'
            recorded_audio = record_chunk(p, stream, chunk_file, chunk_length=CHUNK_LENGTH)
            segments, info = model.transcribe(chunk_file, beam_size=5, vad_filter=True, language='en')  # Note: we can get faster with greedy search instead of beam search.
            os.remove(chunk_file)

            # Prepare the output dictionary
            chunk_results = {"sessionID": sessionID,
                             "chunk_number": chunk_number,
                             "start_sec": chunk_start,  # this needs to be time relative to start of audio file
                             "end_sec": chunk_start + CHUNK_LENGTH  # this needs to be time relative to start of audio file
                             }

            # Unpack the segments and diarize them individually.
            utterance_num = 0
            unique_speakers = list()
            utterances = list()
            for segment in segments:
                segment_length = segment.end - segment.start  # Length of the segment. Used for the diarizer.
                transcript = anonymize_text_with_deny_list(segment.text)  # Anonymize the text
                accumulated_transcripts += transcript + " "
                seg_start_relative = chunk_start + segment.start
                # Find the closest matching speaker of this segment
                audio_to_diarize = extract_audio_between_timestamps(recorded_audio, SAMPLE_RATE, segment.start, segment.end)  # Extract the audio using the segment's timestamps
                try:  # Sometimes the segment is too short to diarize and the call fails, so such cases are skipped.
                    speaker = diarizer.getResults(audio_to_diarize, stepSize=segment_length)
                except Exception:  # Avoid a bare except here: it would also swallow KeyboardInterrupt and bypass the shutdown handler below.
                    print("Segment too short")
                    continue
                """
                # Prepare the output dictionary
                chunk_results = {"sessionID": sessionID,
                                 "chunk_number": chunk_number,
                                 "start_sec": seg_start_relative,
                                 "end_sec": seg_start_relative + segment_length
                                 }
                """
                if speaker not in unique_speakers:
                    unique_speakers.append(speaker)
                utterances.append({'utterance_id': utterance_num,
                                   'text': transcript,
                                   'speaker': speaker,
                                   'start_time': segment.start + chunk_start,
                                   'end_time': segment.end + chunk_start,
                                   'confidence': segment.avg_logprob
                                   })
                utterance_num += 1
                print(speaker, ":", transcript)

            snr_result = wada_snr(recorded_audio)
            chunk_results['unique_speakers'] = unique_speakers
            chunk_results['utterances'] = utterances
            chunk_results['signal_to_noise'] = snr_result
            # yield(chunk_results)
            print(chunk_results)
            print("----------------------------------------------")
            chunk_number += 1
            chunk_start += CHUNK_LENGTH
    except KeyboardInterrupt:
        print("Stopping...")
        with open("log.text", 'w') as log_file:
            log_file.write(accumulated_transcripts)
    finally:
        print("LOG:", accumulated_transcripts)
        stream.stop_stream()
        stream.close()
        p.terminate()

if __name__ == "__main__":
    main()
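
# For reference, one chunk_results dictionary printed by main() has roughly the shape
# sketched below; the values are illustrative placeholders, not captured output.
# {
#     "sessionID": 1,
#     "chunk_number": 1,
#     "start_sec": 0.0,
#     "end_sec": 3.0,
#     "unique_speakers": ["Niranjan"],
#     "utterances": [{"utterance_id": 0, "text": "<anonymized transcript>",
#                     "speaker": "Niranjan", "start_time": 0.4, "end_time": 2.7,
#                     "confidence": -0.21}],
#     "signal_to_noise": 32.5
# }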