embeddings.py
import json
import os

import numpy as np
import torch
from torch.nn.utils.rnn import pad_sequence
from transformers import BertTokenizer, BertModel

# Tokenizer and model are loaded lazily by load_bert().
tokenizer = None
model = None
device = "cuda" if torch.cuda.is_available() else "cpu"


def load_bert():
    # Load the pretrained BERT tokenizer and encoder once, onto the selected device.
    global tokenizer
    global model
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    model = BertModel.from_pretrained("bert-base-uncased").to(device)


# Build an embedding for every lyrics file and save the padded batch as a
# NumPy array, plus a JSON map of the original (unpadded) sequence lengths.
def make_embeddings():
    load_bert()
    dataset_folder = "./dataset/processed-lyrics-spotify"
    dataset_files = os.listdir(dataset_folder)
    embeddings_file = "embeddings"
    embedding_lengths_file = "embedding_lengths.json"
    embeddings = []
    embedding_lengths = {}
    with torch.no_grad():
        for i, file in enumerate(dataset_files):
            with open(f"{dataset_folder}/{file}", 'r') as json_file:
                json_parsed = json.load(json_file)
            lyrics = json_parsed["lyrics"]
            embedding, length = make_embedding(lyrics)
            embeddings.append(embedding.cpu())
            embedding_lengths[file] = length
            print(f"{i + 1}/{len(dataset_files)}")
    # pad_sequence right-pads every song to the longest one: (num_songs, max_len, hidden).
    embeddings = pad_sequence(embeddings, batch_first=True).numpy()
    np.save(embeddings_file, embeddings)  # written as embeddings.npy
    with open(embedding_lengths_file, 'w') as outfile:
        json.dump(embedding_lengths, outfile)


def make_embedding(lyrics, custom_device=None):
    global device
    if custom_device is not None:
        device = custom_device
    if model is None:
        load_bert()
    # Truncate to BERT's maximum input length and run a single forward pass.
    # no_grad keeps this function safe to call on its own; it nests harmlessly
    # inside the no_grad block in make_embeddings().
    with torch.no_grad():
        encoded_input = tokenizer(lyrics, truncation=True, return_tensors='pt').to(device)
        output = model(**encoded_input)
    embedding = output.last_hidden_state[0]  # drop the batch dimension
    length = output.last_hidden_state.shape[1]  # token count before any padding
    return embedding, length
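

# A minimal usage sketch, assuming the dataset folder above is populated.
# The __main__ guard is an illustrative addition, not upstream behavior: it
# runs the export and then reloads the two saved artifacts.
if __name__ == "__main__":
    make_embeddings()
    # np.save added the ".npy" suffix; the JSON maps each file name to its
    # unpadded token count, so padded rows can be masked out later.
    all_embeddings = np.load("embeddings.npy")  # (num_songs, max_len, 768) for bert-base
    with open("embedding_lengths.json") as f:
        lengths = json.load(f)
    print(all_embeddings.shape, len(lengths))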