-
Notifications
You must be signed in to change notification settings - Fork 0
/
app.py
134 lines (114 loc) · 5.35 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import streamlit as st
st.set_page_config(page_title="Song Recommendation", layout="wide")
import pandas as pd
from sklearn.neighbors import NearestNeighbors
import plotly.express as px
import streamlit.components.v1 as components
@st.cache(allow_output_mutation=True)
def load_data():
    """Load the track dataset and return it with one row per (track, genre).

    Reads ``cleaned_dataset.csv`` from the current working directory. Each
    ``genres`` cell is converted to text, stripped of its first and last
    character, split on ``", "``, and every piece again loses its first and
    last character. The resulting lists are exploded so that each genre of a
    track gets its own row.

    Returns:
        The exploded ``pandas.DataFrame``.
    """
    def split_genres(cell):
        # assumes cells look like "['pop', 'rock']" -- TODO confirm against the CSV
        inner = str(cell)[1:-1]
        return [piece[1:-1] for piece in inner.split(", ")]

    tracks = pd.read_csv("cleaned_dataset.csv")
    tracks["genres"] = tracks["genres"].apply(split_genres)
    return tracks.explode("genres")
# Genre labels offered in the genre picker; n_neighbors_uri_audio lower-cases
# the selected label before comparing it with the dataset's "genres" column.
genre_names = [
    'Dance Pop',
    'Electronic',
    'Electropop',
    'Hip Hop',
    'Jazz',
    'K-pop',
    'Latin',
    'Pop',
    'Pop Rap',
    'R&B',
    'Rock',
]

# Dataset columns used as the feature vector for the nearest-neighbour search.
# The order must match the order of the values in `test_feat`.
audio_feats = [
    "acousticness",
    "danceability",
    "energy",
    "instrumentalness",
    "valence",
    "tempo",
]

# Loaded once at import time; load_data is wrapped in st.cache.
exploded_track_df = load_data()
def n_neighbors_uri_audio(genre, start_year, end_year, test_feat):
    """Rank a genre's most popular tracks by closeness to the requested features.

    The tracks of ``exploded_track_df`` are filtered by genre and release year,
    the 500 most popular are kept, and those are ordered by a nearest-neighbour
    query around ``test_feat``.

    Args:
        genre: Genre name; lower-cased before comparison with the ``genres``
            column.
        start_year: First release year to include (inclusive).
        end_year: Last release year to include (inclusive).
        test_feat: Target feature values, in the same order as ``audio_feats``.

    Returns:
        A ``(uris, audios)`` tuple. ``uris`` is the list of ``uri`` values in
        the order returned by ``NearestNeighbors.kneighbors``; ``audios`` is
        the matching array of ``audio_feats`` values, one row per URI. When no
        track matches the filters, ``uris`` is ``[]`` and ``audios`` has zero
        rows.
    """
    genre = genre.lower()
    matches_genre = exploded_track_df["genres"] == genre
    within_years = (
        (exploded_track_df["release_year"] >= start_year)
        & (exploded_track_df["release_year"] <= end_year)
    )
    genre_data = exploded_track_df[matches_genre & within_years]
    genre_data = genre_data.sort_values(by='popularity', ascending=False)[:500]

    features = genre_data[audio_feats].to_numpy()
    if len(genre_data) == 0:
        # Fitting NearestNeighbors on zero samples raises; an empty result lets
        # the caller show its "no more songs" message instead of crashing.
        return [], features

    # NOTE(review): features are used unscaled. The tempo slider spans 0-244
    # while the other sliders span 0-1, so tempo probably dominates the
    # distance -- confirm whether that is intended.
    neigh = NearestNeighbors()
    neigh.fit(features)
    order = neigh.kneighbors(
        [test_feat], n_neighbors=len(genre_data), return_distance=False
    )[0]

    uris = genre_data["uri"].iloc[order].tolist()
    audios = features[order]
    return uris, audios
# ---- Page header -----------------------------------------------------------
title = "Welcome To Banger Station."
st.title(title)
st.header("Listen to a few bangers... Jk:Every song here is a banger.")
st.write("Please play around with the audio features.Also dont forget to enjoy these bangers.")
st.markdown("##")

# ---- Input controls --------------------------------------------------------
# Four columns are created, but only the first (sliders) and the third (genre
# picker) receive widgets.
with st.container():
    col1, col2, col3, col4 = st.columns((2, 0.5, 0.5, 0.5))

    with col3:
        st.markdown("***Choose your genre:***")
        genre = st.radio("", genre_names, index=genre_names.index("Pop"))

    with col1:
        st.markdown("***Choose features to customize:***")
        start_year, end_year = st.slider(
            'Select the year range', 1990, 2019, (2015, 2019)
        )

        # The five 0.0-1.0 sliders differ only in label and default value, so
        # they are built from a small table, in display order.
        unit_slider_specs = (
            ('Acousticness', 0.5),
            ('Danceability', 0.5),
            ('Energy', 0.5),
            ('Instrumentalness', 0.0),
            ('Valence', 0.45),
        )
        acousticness, danceability, energy, instrumentalness, valence = [
            st.slider(label, 0.0, 1.0, default)
            for label, default in unit_slider_specs
        ]

        tempo = st.slider('Tempo', 0.0, 244.0, 118.0)
# Number of embedded players shown per "page" of recommendations.
tracks_per_page = 6

# Feature vector in the same order as `audio_feats`.
test_feat = [acousticness, danceability, energy, instrumentalness, valence, tempo]
uris, audios = n_neighbors_uri_audio(genre, start_year, end_year, test_feat)

# One Spotify embed iframe per recommended track.
embed_template = """<iframe src="https://open.spotify.com/embed/track/{}" width="260" height="380" frameborder="0" allowtransparency="true" allow="encrypted-media"></iframe>"""
tracks = [embed_template.format(uri) for uri in uris]

# Remember the inputs of the previous run so that a change in any control can
# restart pagination from the first track.
current_inputs = [genre, start_year, end_year] + test_feat
if 'previous_inputs' not in st.session_state:
    st.session_state['previous_inputs'] = [genre, start_year, end_year] + test_feat

inputs_changed = current_inputs != st.session_state['previous_inputs']
if inputs_changed:
    st.session_state['previous_inputs'] = current_inputs

# The offset is (re)set to zero on the first run and whenever inputs change.
if inputs_changed or 'start_track_i' not in st.session_state:
    st.session_state['start_track_i'] = 0
def _render_track(column, track_html, feature_row):
    """Draw one Spotify embed and its collapsible polar chart inside *column*.

    The chart plots the first five values of ``feature_row`` against the first
    five names in ``audio_feats``.
    """
    with column:
        components.html(track_html, height=400)
        with st.expander("See more details"):
            polar_df = pd.DataFrame(dict(r=feature_row[:5], theta=audio_feats[:5]))
            fig = px.line_polar(polar_df, r='r', theta='theta', line_close=True)
            fig.update_layout(height=400, width=340)
            st.plotly_chart(fig)


with st.container():
    col1, col2, col3 = st.columns([2, 1, 2])

    # Advance to the next page only while the offset is still inside the list.
    if st.button("Recommend More Songs") and st.session_state['start_track_i'] < len(tracks):
        st.session_state['start_track_i'] += tracks_per_page

    page_start = st.session_state['start_track_i']
    page_end = page_start + tracks_per_page
    current_tracks = tracks[page_start:page_end]
    current_audios = audios[page_start:page_end]

    if page_start >= len(tracks):
        st.write("No more songs left to recommend, You heard em' all.")
    else:
        # Alternate between the left and right columns; the middle one stays empty.
        for i, (track, audio) in enumerate(zip(current_tracks, current_audios)):
            _render_track(col1 if i % 2 == 0 else col3, track, audio)