-
Notifications
You must be signed in to change notification settings - Fork 0
/
spike_decay.py
65 lines (53 loc) · 2.33 KB
/
spike_decay.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import numpy as np
import matplotlib.pyplot as plt
import time
import torch
from pyJoules.energy_meter import EnergyContext
from pyJoules.device.rapl_device import RaplPackageDomain, RaplDramDomain
from pyJoules.device.nvidia_device import NvidiaGPUDomain
from pyJoules.handler.pandas_handler import PandasHandler
def spike_decay(input_seq, weight_lambda, t_step, batch_size, size_F, device, *, record=False):
    """Run an in-place spike-decay loop on ``device`` while measuring energy with pyJoules.

    A state tensor ``F`` of shape ``(batch_size, *size_F)`` starts at zero.
    For every element ``inp`` of ``input_seq`` it is updated in place as
    ``F = F * weight_lambda + inp``. The whole loop runs inside a pyJoules
    ``EnergyContext`` (tag ``'start_decay_loop'``) whose results are collected
    by a ``PandasHandler``.

    Args:
        input_seq: Iterable of per-step inputs added to ``F``. Each must broadcast
            against ``F`` (the script at the bottom of this file passes Python ints).
        weight_lambda: Decay factor multiplied into ``F`` at every step.
        t_step: Seconds to sleep after each step. Only used when ``record`` is True.
        batch_size: Leading dimension of ``F``.
        size_F: Sequence with the remaining dimensions of ``F``.
        device: ``torch.device`` (or device string) on which ``F`` lives.
        record: Keyword-only, default False. If True, after every step the device
            is synchronized, a ``time.time()`` stamp and a copy of the first
            element of ``F`` are stored, and the loop sleeps ``t_step`` seconds.
            This per-step work happens inside the energy context, so it is
            included in the measurement; leave it off for pure energy runs.

    Returns:
        ``(F_result, timestamps_result, energy_df)``:
        - ``F_result``: recorded 0-dim tensors on ``device``.
        - ``timestamps_result``: matching timestamps. Both lists are empty unless
          ``record`` is True.
        - ``energy_df``: the result of ``PandasHandler.get_dataframe()``.
    """
    device = torch.device(device)
    is_cuda = device.type == "cuda"
    size_F = tuple(size_F)

    full_size = (batch_size,) + size_F
    F = torch.zeros(full_size, dtype=torch.float32, device=device)
    # Shape size_F, so it broadcasts over the batch dimension of F.
    lambda_tensor = torch.ones(size_F, dtype=torch.float32, device=device) * weight_lambda

    F_result = []
    timestamps_result = []

    def _sync():
        # CUDA work is queued asynchronously; wait for it on *this* device only.
        if is_cuda:
            torch.cuda.default_stream(device=device).synchronize()

    pandas_handler = PandasHandler()
    domains = [RaplPackageDomain(0), RaplDramDomain(0)]
    if is_cuda:
        # NOTE(review): assumes the CUDA device index equals the NVML index that
        # pyJoules uses (can differ when CUDA_VISIBLE_DEVICES reorders GPUs) — confirm.
        gpu_index = device.index if device.index is not None else torch.cuda.current_device()
        domains.append(NvidiaGPUDomain(gpu_index))

    with EnergyContext(handler=pandas_handler, domains=domains, start_tag='start_decay_loop'):
        for inp in input_seq:
            F *= lambda_tensor
            F += inp
            if record:
                _sync()
                timestamps_result.append(time.time())
                # clone(): F is updated in place, so a bare view like F[0, 0, 0]
                # would alias storage and every entry would show the final value.
                F_result.append(F.reshape(-1)[0].clone())
                time.sleep(t_step)
        # Make sure all queued kernels finish before the context closes, so the
        # measured interval covers the actual GPU work.
        _sync()

    return F_result, timestamps_result, pandas_handler.get_dataframe()
# Experiment configuration: config['gpu'] > -1 selects that CUDA device,
# any other value selects the CPU.
config = {'gpu': 0}
device = torch.device(f"cuda:{config['gpu']}") if config['gpu'] > -1 else torch.device('cpu')

# Ten-step pattern with a single spike at position 2, repeated 682 times (6820 steps).
input_seq = [0, 0, 1, 0, 0, 0, 0, 0, 0, 0] * 682
weight_lambda = 0.9
t_step = 0
batch_size = 128
# size_F = (37, 48)  # alternative, smaller state shape
size_F = (64, 2656)

# Warm-up call; its results are discarded (presumably so one-time startup
# costs stay out of the measured run below).
spike_decay(input_seq, weight_lambda, t_step, batch_size, size_F, device)
F_result, timestamps_result, energy_df = spike_decay(
    input_seq, weight_lambda, t_step, batch_size, size_F, device
)

# spike_decay only returns per-step samples when recording is enabled. Without
# this guard, timestamps_result[0] raises IndexError on an empty list.
if timestamps_result:
    # Plain Python floats for matplotlib (works for CPU or CUDA 0-dim tensors).
    F_result = [float(F) for F in F_result]
    rel_time = (np.array(timestamps_result) - timestamps_result[0]) * 1000
    plt.figure()
    plt.plot(rel_time, F_result)
    plt.xlabel('Time [ms]')
    plt.ylabel('weight F')
    plt.savefig('figA100.pdf')
else:
    print('No per-step samples were recorded; skipping figA100.pdf.')

print(energy_df)