embedding.py
import torch
from data_io import mag


class Embedding(torch.nn.Module):

    def __init__(
        self,
        embedding_size,
    ):
        super(Embedding, self).__init__()
        self.embedding_size = embedding_size
        weight = torch.zeros((embedding_size, embedding_size))
        self.embedding = torch.nn.Embedding.from_pretrained(weight)

    def pretrain(self, training_set, mimic_model, device):
        """ This is a hack of embeddings and autograd to allow soft
        'prototypical' posterior distributions for each senone. """

        # Enable grad, so we can hack it to update our senone "embedding"
        self.embedding.weight.requires_grad = True

        # Initialize counts to a small value, so we don't divide by zero
        senone_counts = torch.zeros([self.embedding_size, 1]).to(device) + 0.01

        # Go through dataset, and add up posteriors and counts
        for example in training_set:

            # Generate posterior
            clean_mag = mag(example['clean'].to(device), truncate=True)
            senones = example['senone'].to(device)
            posteriors = mimic_model(clean_mag)[-1]
            posteriors = posteriors[:, :, :senones.shape[1]].transpose(1, 2)

            # Embed senone, so we can update the result
            embedded = self.embedding(senones)

            # Multiply posteriors so that we can add to the gradient
            embedded *= posteriors

            # Propagate gradient to the embedding
            embedded.sum().backward()

            # Count instances of senones
            example_senone_counts = senones[0].bincount(minlength=self.embedding_size).float().unsqueeze(1)
            senone_counts += example_senone_counts

            # Divide and update
            with torch.no_grad():
                self.embedding.weight *= (senone_counts - example_senone_counts) / senone_counts
                self.embedding.weight += self.embedding.weight.grad / senone_counts
                self.embedding.weight.grad.zero_()

        # Turn off grad again
        self.embedding.weight.requires_grad = False

    def forward(self, x):
        return self.embedding(x)
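
The pretrain method uses the embedding weight's gradient as an accumulator: because the lookup routes each frame's posterior vector to the weight row indexed by its senone, a single backward() call scatter-adds the posteriors per senone, and the running-mean update then turns those accumulated sums into per-senone average posteriors. The following is a minimal, self-contained sketch of that gradient-accumulation idea only; the toy sizes, random posteriors, and plain tensor indexing (standing in for nn.Embedding, the mimic model, and the data_io pipeline) are illustrative assumptions, not part of the original project.

# Self-contained sketch (not part of embedding.py): toy sizes and random
# posteriors stand in for the mimic model and the data pipeline.
import torch

num_senones = 4                                   # plays the role of embedding_size
proto = torch.zeros((num_senones, num_senones), requires_grad=True)

senones = torch.tensor([0, 1, 1, 3, 0, 1])        # frame-level senone labels
posteriors = torch.softmax(torch.randn(len(senones), num_senones), dim=-1)

# Looking up rows of `proto` and multiplying by the posteriors means that
# d(sum)/d(proto[c]) is the sum of all posterior vectors whose label is c,
# so one backward() call accumulates per-senone posterior sums in proto.grad.
(proto[senones] * posteriors).sum().backward()

counts = senones.bincount(minlength=num_senones).float().unsqueeze(1) + 0.01
with torch.no_grad():
    mean_posteriors = proto.grad / counts         # average posterior per senone
print(mean_posteriors)

The class above does the same thing incrementally over a dataset, rescaling the stored means by the old-to-new count ratio before adding each example's contribution, so the weights always hold the running average posterior for each senone.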