-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathFxNLMS_algorithm.py
79 lines (65 loc) · 2.61 KB
/
FxNLMS_algorithm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
import scipy.signal as signal
import progressbar
#------------------------------------------------------------------------------
# Class: FxNLMS algorithm
#------------------------------------------------------------------------------
class FxNLMS():
    """Adaptive FIR control filter trained with the normalized filtered-x LMS rule.

    The filter keeps a tapped delay line ``Xd`` of the most recent ``Len`` input
    samples (newest at index 0) and a weight row vector ``Wc`` that PyTorch
    autograd optimizes. ``LossFunction`` scales the squared error by the input
    power, so a plain SGD step on ``Wc`` realizes the NLMS update.
    """

    def __init__(self, Len):
        """Create zero-initialized weights and delay line, both of shape (1, Len)."""
        # The weights are the only tensor that autograd differentiates.
        self.Wc = torch.zeros(1, Len, requires_grad=True, dtype=torch.float)
        # Delay line of past input samples; the newest sample lives at [0, 0].
        self.Xd = torch.zeros(1, Len, dtype=torch.float)

    def feedforward(self, Xf):
        """Push one input sample into the delay line and compute the filter output.

        Args:
            Xf: the new scalar input sample (presumably the filtered reference
                of the FxLMS structure -- TODO confirm against the caller).

        Returns:
            (yt, power): the (1, 1) filter output ``Wc @ Xd.T`` and the (1, 1)
            instantaneous input power ``Xd @ Xd.T`` used for normalization.
        """
        # Shift right by one; roll wraps the oldest sample around to index 0,
        # where it is immediately overwritten by the new sample.
        self.Xd = torch.roll(self.Xd, 1, 1)
        self.Xd[0, 0] = Xf
        yt = self.Wc @ self.Xd.t()
        # The power normalization is what distinguishes FxNLMS from FxLMS.
        power = self.Xd @ self.Xd.t()
        return yt, power

    def LossFunction(self, y, d, power, eps=1e-8):
        """Return the power-normalized loss and the error ``e = d - y``.

        The loss is ``e**2 / (2 * (power + eps))``. Because ``Xd`` does not
        require grad, its gradient w.r.t. ``Wc`` is ``-e * Xd / (power + eps)``.
        One SGD step with learning rate ``mu`` is therefore the regularized NLMS
        update ``Wc += mu * e * Xd / (power + eps)``.

        Args:
            y: filter output returned by ``feedforward``.
            d: disturbance sample.
            power: input power returned by ``feedforward``.
            eps: small regularizer. Without it, an all-zero delay line
                (power == 0) gives an infinite loss and an inf * 0 = NaN
                gradient that permanently corrupts ``Wc``.

        Returns:
            (loss, e): the normalized loss tensor and the error tensor.
        """
        e = d - y  # disturbance minus control signal
        return e**2 / (2 * (power + eps)), e

    def _get_coeff_(self):
        """Return the current weights as a detached NumPy array of shape (1, Len).

        The array shares memory with ``Wc``, so later updates are visible in it.
        """
        return self.Wc.detach().numpy()
#------------------------------------------------------------------------------
# Function : train_fxnlms_algorithm()
#------------------------------------------------------------------------------
def train_fxnlms_algorithm(Model, Ref, Disturbance, Stepsize=0.0001):
    """Run sample-by-sample FxNLMS adaptation and return the error history.

    For every sample, the model output is computed from ``Ref[i]``. The
    normalized loss against ``Disturbance[i]`` is back-propagated, and one SGD
    step with learning rate ``Stepsize`` is applied to ``Model.Wc``.

    Args:
        Model: object exposing ``Wc``, ``feedforward`` and ``LossFunction``
            (e.g. an ``FxNLMS`` instance); its state is updated in place.
        Ref: per-sample inputs fed to ``Model.feedforward``. Either 1-D, or of
            shape (1, N) as returned by this module's noise generator.
        Disturbance: per-sample disturbance, in the same layouts as ``Ref``.
            Its length sets the number of iterations.
        Stepsize: SGD learning rate (mu).

    Returns:
        list of float: the error ``e = d - y`` at each sample, computed with
        the weights from before that sample's update.

    Raises:
        ValueError: if ``Ref`` has fewer samples than ``Disturbance``.
    """
    def _as_1d(seq):
        # Accept (1, N) arrays/tensors by taking their single row; anything
        # else (already 1-D, or a plain sequence) is returned unchanged.
        if getattr(seq, "ndim", 1) == 2 and seq.shape[0] == 1:
            return seq[0]
        return seq

    Ref = _as_1d(Ref)
    Disturbance = _as_1d(Disturbance)
    len_data = Disturbance.shape[0]
    if len_data == 0:
        return []
    # Fail before training starts instead of with an IndexError part-way through.
    if len(Ref) < len_data:
        raise ValueError(
            f"Ref has {len(Ref)} samples but Disturbance has {len_data}"
        )

    # Two progress ticks per sample: one after the forward pass, one after the update.
    bar = progressbar.ProgressBar(
        maxval=2 * len_data,
        widgets=[progressbar.Bar('=', '[', ']'), ' ', progressbar.Percentage()],
    )
    optimizer = optim.SGD([Model.Wc], lr=Stepsize)
    bar.start()
    error_signal = []
    for itera in range(len_data):
        # Forward pass
        xin = Ref[itera]
        dis = Disturbance[itera]
        y, power = Model.feedforward(xin)
        loss, e = Model.LossFunction(y, dis, power)
        bar.update(2 * itera + 1)
        # Backward pass and weight update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        error_signal.append(e.item())
        bar.update(2 * itera + 2)
    bar.finish()
    return error_signal
#------------------------------------------------------------
# Function : Generate band-limited (broadband) test noise
#------------------------------------------------------------
def Generating_boardband_noise_wavefrom_tensor(Wc_F, Seconds, fs, filter_len=1024):
    """Generate unit-variance band-limited Gaussian noise as a float32 tensor.

    White Gaussian noise is filtered by a Hamming-windowed FIR band-pass
    filter. The first ``filter_len`` output samples are discarded as the
    filter's start-up transient, and the result is scaled to unit variance.

    Args:
        Wc_F: band edges in Hz, passed to ``scipy.signal.firwin`` as
            ``cutoff`` with ``pass_zero='bandpass'`` (e.g. ``[f_low, f_high]``).
        Seconds: duration of the returned signal in seconds; may be fractional.
        fs: sampling rate in Hz.
        filter_len: number of FIR taps, which is also the number of leading
            samples discarded.

    Returns:
        torch.Tensor of dtype float32 and shape (1, round(Seconds * fs)).
    """
    bandpass_filter = signal.firwin(
        filter_len, Wc_F, pass_zero='bandpass', window='hamming', fs=fs
    )
    # Cast to int: a fractional duration would otherwise make the length a
    # float, which np.random.randn rejects.
    num_samples = int(round(Seconds * fs))
    xin = np.random.randn(filter_len + num_samples)
    y = signal.lfilter(bandpass_filter, 1, xin)
    yout = y[filter_len:]
    # Scale to unit variance (the mean is not removed).
    yout = yout / np.sqrt(np.var(yout))
    return torch.from_numpy(yout).type(torch.float).unsqueeze(0)