-
Notifications
You must be signed in to change notification settings - Fork 0
/
streamlit_app.py
128 lines (95 loc) · 4.09 KB
/
streamlit_app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import math
import os
import pickle
import tempfile

import librosa
import numpy as np
import pandas as pd
import streamlit as st
import tensorflow as tf
from tensorflow.keras.models import load_model
# extract MFCCs
def extract_mfcc(signal, SAMPLE_RATE):
    """Split an audio signal into fixed 2.98 s segments and compute MFCCs.

    The model was trained on fixed-size MFCC matrices (129 x 32 according to
    the original author), so the signal is cut into equal segments of 2.98 s
    and the trailing remainder (shorter than one segment) is dropped.

    Args:
        signal: 1-D audio samples (the script obtains them from
            ``librosa.load``).
        SAMPLE_RATE: Sample rate of ``signal`` in Hz. This is now also the rate
            passed to librosa; previously a module-level global ``sr`` was
            read instead.

    Returns:
        numpy.ndarray stacking the transposed MFCC matrix (frames x 32) of
        every segment whose frame count equals the expected value. It is an
        empty array when the signal is shorter than one segment or when no
        segment had the expected frame count.
    """
    # variable declaration
    n_mfcc = 32
    n_fft = 2048
    hop_length = 512

    # Fixed segment length, so at most ~3 seconds of the song are discarded.
    len_segment = 2.98 * SAMPLE_RATE
    num_segments = int(len(signal) / len_segment)
    print("len_segment: {}, num_segments {}, cropped signal Length: {}, Original signal length: {}".format(
        len_segment, num_segments, (len_segment * num_segments), len(signal)))
    if num_segments == 0:
        # Shorter than one segment: previously this fell through to a
        # ZeroDivisionError when computing num_sam. Return the same empty
        # result the function already produces when no segment is usable.
        print("signal shorter than one segment")
        return np.array([])

    signal_cropped = signal[: int(len_segment * num_segments)]
    num_sam = int(len(signal_cropped) / num_segments)
    # Expected number of MFCC frames per segment (rounded up).
    # NOTE(review): librosa's default centred framing yields
    # 1 + num_sam // hop_length frames, which equals this ceil unless num_sam
    # is an exact multiple of hop_length -- confirm if the constants change.
    expected_num_mfcc_vectors_per_segment = math.ceil(num_sam / hop_length)
    print(expected_num_mfcc_vectors_per_segment)
    st.write("Processing Audio")

    data = []
    for s in range(num_segments):
        start_sample = num_sam * s
        finish_sample = start_sample + num_sam
        mfcc = librosa.feature.mfcc(y=signal_cropped[start_sample:finish_sample],
                                    sr=SAMPLE_RATE,
                                    n_fft=n_fft,
                                    n_mfcc=n_mfcc,
                                    hop_length=hop_length)
        # librosa returns (n_mfcc, frames); the model consumes (frames, n_mfcc).
        mfcc = mfcc.T
        if len(mfcc) == expected_num_mfcc_vectors_per_segment:
            data.append(mfcc.tolist())
            print("segment:{}".format(s + 1))
        else:
            # Segments with an unexpected frame count cannot be fed to the model.
            print("wrong length")
    return np.array(data)
def create_aggregate_data(predictions, columns=None):
    """Return the running (cumulative) mean of per-segment predictions.

    Row ``i`` of the result is the mean of rows ``0..i`` (inclusive) of
    ``predictions``. Charting it shows how the overall classification settles
    as more of the song is taken into account.

    Args:
        predictions: 2-D array-like with one row per audio segment and one
            column per class (the script passes the model's output).
        columns: Column labels for the result. Defaults to the module-level
            ``mapping`` loaded by the script, matching the previous behaviour.

    Returns:
        pandas.DataFrame with the same shape as ``predictions``.
    """
    predictions = np.asarray(predictions)
    if columns is None:
        columns = mapping
    # Divide the cumulative sum by the number of rows seen so far. The previous
    # loop averaged predictions[0:i], which excludes row i: row 0 became 0/0
    # (NaN), the last segment was never counted, and the loop was O(n^2).
    counts = np.arange(1, predictions.shape[0] + 1).reshape(-1, 1)
    running_mean = np.cumsum(predictions, axis=0) / counts
    return pd.DataFrame(running_mean, columns=columns)
# Setting it so that Streamlit uses the whole page
st.set_page_config(layout="wide")
st.write("""
## Audio Genre Classification
Hello, this is a web app for audio genre classification
""")

# Sample rate (Hz) every upload is resampled to before MFCC extraction.
SAMPLE_RATE = 22050

# Class labels, indexed by model output position (used both as mapping[i] and
# as DataFrame column names below).
# NOTE(review): pickle.load can execute arbitrary code; this is acceptable only
# because the file ships with the app and is never user-supplied.
with open('data/mapping.pickle', 'rb') as f:
    mapping = pickle.load(f)

# getting the file
uploaded_file = st.file_uploader("Upload Files", type='.mp3')

# checking that we have an uploaded file
if uploaded_file is not None:
    # Save the upload under a unique temporary name instead of
    # 'temp/' + <client-supplied file name>. Concurrent sessions uploading the
    # same name can no longer overwrite each other, a crafted name cannot
    # escape the directory, and no pre-existing 'temp/' directory is needed.
    with tempfile.NamedTemporaryFile(suffix='.mp3', delete=False) as tmp:
        tmp.write(uploaded_file.getbuffer())
        filepath = tmp.name
    # try/finally: the temporary file is removed even if decoding, feature
    # extraction or prediction raises (previously it leaked on any error).
    try:
        # display the audio file
        st.audio(filepath, format='audio/mp3')
        # loading in the audio file
        st.write("Starting to load audio file")
        signal, sr = librosa.load(filepath, sr=SAMPLE_RATE)
        data = extract_mfcc(signal, SAMPLE_RATE)
        if data.size == 0:
            # No usable segment (e.g. the clip is too short); model.predict
            # would fail on an empty batch.
            st.write("Could not extract any audio segments from this file; "
                     "please upload a longer recording.")
        else:
            # expanding dimension to feed into CNN (adds a trailing channel axis)
            data_cnn = np.expand_dims(data, axis=-1)
            model = tf.keras.models.load_model('models/cnn_model_32.h5')
            predictions = model.predict(data_cnn)

            # Certainty = the class's summed score divided by the number of
            # segments, i.e. its average score across the song.
            segment_count = predictions.shape[0]
            sum_predictions = predictions.sum(axis=0)
            predicted_indices = np.argsort(sum_predictions)
            print(sum_predictions, predicted_indices, mapping)

            best, second = predicted_indices[-1], predicted_indices[-2]
            best_pct = round((sum_predictions[best] / segment_count) * 100, 2)
            second_pct = round((sum_predictions[second] / segment_count) * 100, 2)
            st.write("The first prediction is: {} with a {}% certainty , the second guess is: {} with a {}% certainty"
                     .format(mapping[best], best_pct, mapping[second], second_pct))

            # Per-segment scores over time, followed by their running average.
            chart_data = pd.DataFrame(data=predictions, columns=mapping)
            st.line_chart(chart_data)
            chart_data_aggregate = create_aggregate_data(predictions)
            st.line_chart(chart_data_aggregate)
    finally:
        os.remove(filepath)