-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathAlonsoMarderModel.py
635 lines (575 loc) · 29.8 KB
/
AlonsoMarderModel.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
import os
os.environ['MKL_NUM_THREADS'] = "1"
import logging
from typing import List, Tuple, Dict, Union
import numpy as np
from collections import OrderedDict
from scipy.integrate import solve_ivp
import time
class AlonsoMarderModel(object):
"""
AlonsoMarderModel provides a class for the AlonsoMarder neuronal modal provided in:
https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6395073/
"""
def __init__(self, injected_current=None, initial_conditions=None, reversal_potentials=None, conductances=None,
inf_alphas=None, inf_betas=None, tau_cs=None, tau_ds=None, tau_as=None, tau_bs=None,
tau_a2s=None, tau_b2s=None, tau_ca=None, spike_threshold=None):
self.voltage_trace = None
self.spike_times = None
self.time_steps = None
self.constants = {
'reftemp_celsius': 10.0, # degC, reference temperature
'temp_celsius': 10.0, # degC, temperature
'gas_constant': 8.314 * pow(10, 3), # Ideal Gas Constant (*10^3 to put into mV)
'base_temp_kelvin': 273.15, # Temperature in Kelvin
'z': 2.0, # Valence of Caclium Ions
'faraday': 96485.33, # Faraday's Constant
'ca_conc_extracellular': 3000.0, # Outer Ca Concentration (uM)
'ca_conv_factor': 0.94, # Outer Ca Concentration (uM)
'ca_conc_background': 0.05, # Outer Ca Concentration (uM)
'tau_ca_conc_intracellular': tau_ca or 6.535e2, # Outer Ca Concentration (uM)
'C': 10.0, # nF / cm^2, membrane capacitance
'I_app': injected_current or 0.0, # nA, externally-applied current
}
self.initial_conditions = initial_conditions or OrderedDict({
'V': -51., # mV, membrane voltage
'm_Na': 0., # sodium activation variable
'h_Na': 0., # sodium inactivation variable
'm_CaT': 0., # low-threshold calcium activation variable
'h_CaT': 0., # low-threshold calcium inactivation variable
'm_CaS': 0., # slow calcium activation variable
'h_CaS': 0., # slow calcium inactivation variable
'm_H': 0., # hyperpolarization-activated cation activation variable
'h_H': 1., # hyperpolarization-activated cation inactivation variable
'm_Kd': 0., # potassium activation variable
'h_Kd': 1., # potassium inactivation variable
'm_KCa': 0., # mV, calcium-dependent potassium activation variable
'h_KCa': 1., # mV, calcium-dependent potassium inactivation variable
'm_A': 0., # mV, transient potassium activation variable
'h_A': 0., # mV, transient potassium inactivation variable
'm_L': 1., # leak channel activation variable
'h_L': 1., # leak channel inactivation variable
'ca_conc_intracellular': 5., # uM, intracellular Ca concentration initial condition
})
self.spike_threshold = spike_threshold or {
'spike_threshold': self.initial_conditions["V"] + 15.0,
'threshold_spike': -15.,
'threshold_mid_upper': -35.,
'threshold_mid_lower': -45.,
'threshold_slow_wave': -49.999,
}
self.reversal_potentials = reversal_potentials or {
'E_L': -50., # mV, leak reversal potential
'E_Na': 30., # mV, sodium reversal potential
'E_CaT': self._calculate_calcium_rev_potential(
self.initial_conditions['ca_conc_intracellular'],
self.constants['temp_celsius']), # mV, low-threshold calcium reversal potential
'E_CaS': self._calculate_calcium_rev_potential(
self.initial_conditions['ca_conc_intracellular'],
self.constants['temp_celsius']), # mV, slow calcium reversal potential
'E_Kd': -80., # mV, potassium reversal potential
'E_KCa': -80., # mV, calcium-dependent potassium reversal potential
'E_A': -80., # mV, transient potassium reversal potential
'E_H': -20., # mV, hyperpolarization-activated cation reversal potential
}
self.conductances = conductances or {
'g_Na': 1.0764e3, # uS, transient sodium conductance
'g_CaT': 6.4056e0, # uS, low-threshold calcium conductance
'g_CaS': 1.0048e1, # uS, slow calcium conductance
'g_A': 8.0384e0, # uS, transient potassium conductance
'g_KCa': 1.7584e1, # uS, calcium-dependent potassium conductance
'g_Kd': 1.240928e2, # uS, potassium conductance
'g_H': 1.1304e-1, # uS, hyperpolarization-activated cation conductance
'g_L': 1.7584e-1, # uS, leak conductance
}
self.inf_alphas = inf_alphas or {
'm_Na': 25.5, # sodium activation variable
'h_Na': 48.9, # sodium inactivation variable
'm_CaT': 27.1, # low-threshold calcium activation variable
'h_CaT': 32.1, # low-threshold calcium inactivation variable
'm_CaS': 33.0, # slow calcium activation variable
'h_CaS': 60.0, # slow calcium inactivation variable
'm_H': 70.0, # hyperpolarization-activated cation activation variable
'm_Kd': 12.3, # potassium activation variable
'm_KCa': 28.3, # mV, calcium-dependent potassium activation variable
'm_A': 27.2, # mV, transient potassium activation variable
'h_A': 56.9, # mV, transient potassium inactivation variable
}
self.inf_betas = inf_betas or {
'm_Na': -5.29, # sodium activation variable
'h_Na': 5.18, # sodium inactivation variable
'm_CaT': -7.20, # low-threshold calcium activation variable
'h_CaT': 5.50, # low-threshold calcium inactivation variable
'm_CaS': -8.1, # slow calcium activation variable
'h_CaS': 6.20, # slow calcium inactivation variable
'm_H': 6.0, # hyperpolarization-activated cation activation variable
'm_Kd': -11.8, # potassium activation variable
'm_KCa': -12.6, # mV, calcium-dependent potassium activation variable
'm_A': -8.70, # mV, transient potassium activation variable
'h_A': 4.90, # mV, transient potassium inactivation variable
}
self.tau_cs = tau_cs or {
'm_Na': 1.32, # sodium activation variable
'h_Na_0': 0.0, # sodium inactivation variable
'h_Na_1': 1.50, # sodium inactivation variable
'm_CaT': 21.7, # low-threshold calcium activation variable
'h_CaT': 105.0, # low-threshold calcium inactivation variable
'm_CaS': 1.40, # slow calcium activation variable
'h_CaS': 60.0, # slow calcium inactivation variable
'm_H': 272.0, # hyperpolarization-activated cation activation variable
'm_Kd': 7.20, # potassium activation variable
'm_KCa': 90.3, # mV, calcium-dependent potassium activation variable
'm_A': 11.6, # mV, transient potassium activation variable
'h_A': 38.6, # mV, transient potassium inactivation variable
}
self.tau_ds = tau_ds or {
'm_Na': 1.26, # sodium activation variable
'h_Na_0': -0.67, # sodium inactivation variable
'h_Na_1': -1.00, # sodium inactivation variable
'm_CaT': 21.3, # low-threshold calcium activation variable
'h_CaT': 89.8, # low-threshold calcium inactivation variable
'm_CaS': 7.00, # slow calcium activation variable
'h_CaS': 150.0, # slow calcium inactivation variable
'm_H': -1499.0, # hyperpolarization-activated cation activation variable
'm_Kd': 6.40, # potassium activation variable
'm_KCa': 75.1, # mV, calcium-dependent potassium activation variable
'm_A': 10.4, # mV, transient potassium activation variable
'h_A': 29.2, # mV, transient potassium inactivation variable
}
self.tau_as = tau_as or {
'm_Na': 120.0, # sodium activation variable
'h_Na_0': 62.9, # sodium inactivation variable
'h_Na_1': 34.9, # sodium inactivation variable
'm_CaT': 68.1, # low-threshold calcium activation variable
'h_CaT': 55.0, # low-threshold calcium inactivation variable
'm_CaS': 27.0, # slow calcium activation variable
'h_CaS': 55.0, # slow calcium inactivation variable
'm_H': 42.2, # hyperpolarization-activated cation activation variable
'm_Kd': 28.3, # potassium activation variable
'm_KCa': 46.0, # mV, calcium-dependent potassium activation variable
'm_A': 32.9, # mV, transient potassium activation variable
'h_A': 38.9, # mV, transient potassium inactivation variable
}
self.tau_bs = tau_bs or {
'm_Na': -25.0, # sodium activation variable
'h_Na_0': -10.0, # sodium inactivation variable
'h_Na_1': 3.60, # sodium inactivation variable
'm_CaT': -20.5, # low-threshold calcium activation variable
'h_CaT': -16.9, # low-threshold calcium inactivation variable
'm_CaS': 10.0, # slow calcium activation variable
'h_CaS': 9.00, # slow calcium inactivation variable
'm_H': -8.73, # hyperpolarization-activated cation activation variable
'm_Kd': -19.2, # potassium activation variable
'm_KCa': -22.7, # mV, calcium-dependent potassium activation variable
'm_A': -15.2, # mV, transient potassium activation variable
'h_A': -26.5, # mV, transient potassium inactivation variable
}
self.tau_a2s = tau_a2s or {
'm_CaS': 70.0,
'h_CaS': 65.0,
}
self.tau_b2s = tau_b2s or {
'm_CaS': -13.0,
'h_CaS': -16.0,
}
self.q10s = {
'i_Na': 3.,
'i_CaT': 3.,
'i_CaS': 3.,
'i_H': 1.,
'i_Kd': 4.,
'i_KCa': 4.,
'i_A': 3.,
'i_L': 1.,
'g_Na': 1.,
'm_Na': 1.,
'h_Na': 1.,
'g_CaT': 1.,
'm_CaT': 1.,
'h_CaT': 1.,
'g_CaS': 1.,
'm_CaS': 1.,
'h_CaS': 1.,
'g_A': 1.,
'm_A': 1.,
'h_A': 1.,
'g_KCa': 1.,
'm_KCa': 1.,
'h_KCa': 1.,
'g_Kd': 1.,
'm_Kd': 1.,
'h_Kd': 1.,
'g_H': 1.,
'm_H': 1.,
'h_H': 1.,
'g_L': 1.,
'tau_Ca': 1.,
}
self.channel_types = ('Na', 'CaT', 'CaS', 'H', 'Kd', 'KCa', 'A', 'L')
self.channel_currents = {}
self.state_vars_constant = ('m_L', 'h_H', 'h_Kd', 'h_KCa', 'h_L')
self.state_vars_labels = self.get_state_vars_labels()
self.state_vars = [self.initial_conditions[key] for key in self.state_vars_labels]
def get_state_var(self, key: str) -> float:
"""
get a state variable
:param key: name of channel state var
:return: the current value of the given state variable
"""
if key in self.state_vars_constant:
return 1.0
index_state_var = self.state_vars_labels.index(key)
return self.state_vars[index_state_var]
def get_state_vars_labels(self) -> list:
"""
get state variable labels
:return: state variable labels
"""
state_vars_m = [f'm_{ch}' for ch in self.channel_types]
[state_vars_m.remove(x) for x in self.state_vars_constant if x in state_vars_m]
state_vars_h = [f'h_{ch}' for ch in self.channel_types]
[state_vars_h.remove(x) for x in self.state_vars_constant if x in state_vars_h]
return ['V', 'ca_conc_intracellular'] + state_vars_m + state_vars_h
def _calculate_calcium_rev_potential(self, ca_conc_intracellular, temp_celsius):
"""
computed dynamically using the Nernst equation assuming an extra-cellular calcium concentration of 3e3 uMolars.
:param ca_conc_intracellular: calcium intracellular concentration
:param temp_celsius: temperature in c
:return: calcium reverse potential
"""
rtzf_term = self.constants['gas_constant'] * (self.constants['base_temp_kelvin'] + temp_celsius)
rtzf_term /= (self.constants['z'] * self.constants['faraday'])
return rtzf_term * np.log10(self.constants['ca_conc_extracellular'] / ca_conc_intracellular)
def _calculate_normal_inf_response(self, key: str, voltage: float) -> float:
"""
calculates infinite response
:param key: name of the channel's state variables
:param voltage: current voltage
:return: normal infinite response
"""
return 1. / (1. + np.exp((voltage + self.inf_alphas[key]) / self.inf_betas[key]))
def _calculate_kca_inf_response(self, key: str, voltage: float) -> float:
"""
calculates infinite response with respect to KCa
:param key: name of channel's state variables
:param voltage: current voltage
:return: spec infinite response
"""
index_of_conc_kca = self.state_vars_labels.index('ca_conc_intracellular')
conc_kca = self.state_vars[index_of_conc_kca]
left_term = conc_kca / (conc_kca + 3.0)
return left_term / (1. + np.exp((voltage + self.inf_alphas[key]) / self.inf_betas[key]))
def _calculate_inf_response(self, key: str, voltage: float) -> float:
"""
compute the infinite response for a channel's state variable
:param key: the name of a channel's state variable
:param voltage: current voltage
:return: the value of the channel's state variable's current infinite response
"""
if 'KCa' in key:
return self._calculate_kca_inf_response(key, voltage)
else:
return self._calculate_normal_inf_response(key, voltage)
def get_current_voltage(self) -> float:
"""
get value from state variable and return float
:return: current voltage
"""
index_voltage = self.state_vars_labels.index('V')
return self.state_vars[index_voltage]
def _calculate_normal_tau(self, key: str, voltage: float) -> float:
"""
calcualtes tau normally by : CT - DT/(1. + exp((Volt + AT)/BT))
:param key: name of channels state variables current voltage
:param voltage: current voltage
:return: normal tau
"""
timeconst = self.tau_cs[key]
timeconst -= self.tau_ds[key] / (1. + np.exp((voltage + self.tau_as[key]) / self.tau_bs[key]))
return timeconst
def _calculate_cas_tau(self, key: str, voltage: float) -> float:
"""
calculates CaS tau different from normal tau by : CT + DT/(exp((Volt + AT)/BT) + exp((Volt + AT2)/BT2))
:param key: name of channels state variables
:param voltage: current voltage
:return: spec tau
"""
div_term = np.exp((voltage + self.tau_as[key]) / self.tau_bs[key])
div_term += np.exp((voltage + self.tau_a2s[key]) / self.tau_b2s[key])
return self.tau_cs[key] + self.tau_ds[key] / div_term
def _calculate_double_tau(self, key: str, voltage: float) -> float:
"""
calculate normal tau multiplied twice with different channel states
:param key: name of channels state variables
:param voltage: current voltage
:return: double tau
"""
total = self._calculate_normal_tau(f'{key}_0', voltage) * self._calculate_normal_tau(f'{key}_1', voltage)
return total
def _calculate_tau(self, key: str) -> float:
"""
calculate the time constant for a channel's state variable
:param key: the name of a channel's state variable
:return: the value of the channel's state variable's current time constant
"""
voltage = self.get_current_voltage()
if 'h_Na' in key:
return self._calculate_double_tau(key, voltage)
elif 'CaS' in key:
return self._calculate_cas_tau(key, voltage)
else:
return self._calculate_normal_tau(key, voltage)
def scale_time(self, key: str, value_to_scale: float) -> float:
"""
scare the input by using q10 corresponding to the key in the dict
:param key: name of channels state variables
:param value_to_scale: q10 scalar
:return: scaled value with corresponding q10
"""
temp = self.constants['temp_celsius']
reftemp = self.constants['reftemp_celsius']
return value_to_scale * pow(self.q10s[key], -(temp - reftemp) / 10.0)
def _calculate_channel_current(self, channel: str) -> float:
"""
calculates the channel current
:param channel: ionic channel names
:return: channel current
"""
g = self.conductances[f'g_{channel}']
e_rev = self.reversal_potentials[f'E_{channel}']
h = self.get_state_var(f'h_{channel}')
m = self.get_state_var(f'm_{channel}')
voltage = self.get_current_voltage()
q = self.q10s[f'i_{channel}']
pow_term = pow(m, q)
return g * pow_term * h * (voltage - e_rev)
def _calculate_dvdt(self, channel_currents: dict) -> float:
"""
calculates the derivative of the voltage wrt time
:param channel_currents: channel currents
:return: deriv of current
"""
return (-sum(channel_currents.values()) + self.constants['I_app']) / self.constants['C']
def _calculate_dstate_dt(self, key: str) -> float:
"""
calculates the derivative of the states channels
:param key: name of channel state variables
:return: deriv of state variable
"""
state_var = self.get_state_var(key)
voltage = self.get_current_voltage()
inf_state_var = self._calculate_inf_response(key, voltage)
# tau_state_var = self.scale_time(f'tau_{key}', self._calculate_tau(key))
tau_state_var = self._calculate_tau(key)
return (inf_state_var - state_var) / tau_state_var
def _calculate_dca_conc_intracellular_dt(self, channel_currents: dict) -> float:
"""
calculates the derivative of calcium's intracellular concentration
:param channel_currents: channel current
:return: deriv of calcium's intracellular concentration
"""
ca_conc_intra = self.get_state_var('ca_conc_intracellular')
outcalc = -self.constants['ca_conv_factor']
outcalc *= (channel_currents['CaT'] + channel_currents['CaS'])
outcalc += self.constants['ca_conc_background'] - ca_conc_intra
outcalc /= self.scale_time('tau_Ca', self.constants['tau_ca_conc_intracellular'])
return outcalc
def _calculate_dstate_variable(self, key, channel_currents: dict) -> float:
"""
calculates state variables
:param key: name of channel state variables
:param channel_currents: channel current
:return: state variables
"""
if key == 'V':
return self._calculate_dvdt(channel_currents)
elif key == 'ca_conc_intracellular':
return self._calculate_dca_conc_intracellular_dt(channel_currents)
elif key in self.state_vars_constant:
return np.float64(0.0)
else:
return self._calculate_dstate_dt(key)
# noinspection PyUnusedLocal
def update_state_variables(self, t: float, y: np.array) -> np.array:
"""
updates state variables (dy/dt) function for solve_ivp or other ODE solver
:param t: current time, ignored
:param y: array of current values of state variables
:return: updated state variables
"""
self.state_vars = [y[self.state_vars_labels.index(key)]
for key in self.state_vars_labels]
self._update_ca_rev_potential()
self.channel_currents = {ch: self._calculate_channel_current(ch)
for ch in self.channel_types}
return np.array([self._calculate_dstate_variable(key, self.channel_currents)
for key in self.state_vars_labels])
# return np.array(new_dstate_vars)
# new_array = np.array(new_dstate_vars)
# if np.all(np.isfinite(new_array)):
# return new_array
# else:
# return np.ones_like(new_array) + np.nan
def get_state_vars_and_labels(self) -> Tuple[List[float], List[str]]:
"""
Get the current values of the state variables and their labels
:return: a Tuple of: a list of the values of the state variables, and a list of the labels
"""
return self.state_vars, self.state_vars_labels
def get_initial_conditions(self) -> Tuple[List[float], List[str]]:
"""
Get the initial conditions
:return: a Tuple of: a list of the values of the state variables, and a list of the labels
"""
return [self.initial_conditions[key] for key in self.state_vars_labels], self.state_vars_labels
def _update_ca_rev_potential(self) -> None:
"""
updates calcium reverse potential
:return: None
"""
new_ca_rev_potential = self._calculate_calcium_rev_potential(
self.get_state_var('ca_conc_intracellular'),
self.constants['temp_celsius'])
self.reversal_potentials['E_CaT'] = new_ca_rev_potential # mV, low-threshold calcium reversal potential
self.reversal_potentials['E_CaS'] = new_ca_rev_potential # mV, slow calcium reversal potential
def run_simulation(self, time_steps: np.ndarray) -> Dict[str, Union[np.ndarray, bool, float]]:
"""
run simulation of model
:param time_steps: time steps
:return: 't': time steps, 'y': voltage trace, and 'spike_times': spike times
"""
self.time_steps = time_steps
init_cond = self.get_initial_conditions()[0]
# time_start = time.time()
sol = solve_ivp(self.update_state_variables, [self.time_steps[0],
self.time_steps[-1]],
init_cond, "BDF", self.time_steps)
# out_time = time.time() - time_start
# logging.debug(f'AMM: compute time: {out_time} s')
self.voltage_trace = sol.y[0]
logging.debug("AMM: computing trace characteristics")
dict_characteristics = self.convert_trace_to_spike_characteristics_tonic(self.voltage_trace,
self.time_steps,
self.spike_threshold)
logging.debug("AMM: returning")
return {"t": self.time_steps,
"y": self.voltage_trace,
**dict_characteristics}
@staticmethod
def convert_trace_to_spike_times_upward(voltage: np.array, times: np.array, threshold: float) -> np.ndarray:
"""
Converts voltage traces to spike times
:param voltage: array of voltages (mV)
:param times: array of time values (.1ms)
:param threshold: threshold for spike activation
:return: an array of spike times
"""
# grab all indices above the activation threshold0
voltage_indices_ge_th = np.where(voltage >= threshold)[0]
voltage_indices_ge_th = voltage_indices_ge_th[np.where(voltage_indices_ge_th < (len(times) - 1))[0]]
# only grab the index that is directly below the threshold and remove all other indexes
voltage_indeces_upward_th = voltage_indices_ge_th[np.where(voltage[voltage_indices_ge_th - 1] < threshold)]
# use index from voltage and use for time
return times[voltage_indeces_upward_th]
@staticmethod
def convert_trace_to_spike_times_downward(voltage: np.array, times: np.array, threshold: float) -> np.ndarray:
"""
Converts voltage traces to spike times
:param voltage: array of voltages (mV)
:param times: array of time values (.1ms)
:param threshold: threshold for spike activation
:return: an array of spike times
"""
# grab all indices above the activation threshold0
voltage_indices_ge_th = np.where(voltage >= threshold)[0]
voltage_indices_ge_th = voltage_indices_ge_th[np.where(voltage_indices_ge_th < (len(times) - 1))[0]]
# only grab the index that is directly below the threshold and remove all other indexes
voltage_indeces_downward_th = voltage_indices_ge_th[np.where(voltage[voltage_indices_ge_th + 1] < threshold)]
# use index from voltage and use for time
return times[voltage_indeces_downward_th]
@staticmethod
def calculate_num_spikes_per_burst(spike_times, temporal_interval: Union[int, float]):
spike_diff = np.diff(spike_times)
spike_gt_temp = np.where(np.array(spike_diff) > temporal_interval)[0]
# burst_end_loc = spike_gt_temp.astype(np.int) + 1
burst_end_loc_append = np.append(spike_gt_temp, len(spike_times) - 1)
return np.diff(np.insert(burst_end_loc_append, 0, 0))
@staticmethod
def clean_up_num_spikes_per_burst(num_spikes: np.ndarray) -> np.ndarray:
# indices_of_trailing_downward = np.where(num_spikes == 1)[0] - 1
# num_spikes[indices_of_trailing_downward] += 1
return num_spikes[num_spikes > 1]
@staticmethod
def calculate_num_interbursts(spike_times, temporal_interval):
# spike_diff = [(spike_times[x + 1] - spike_times[x]) for x in range(len(spike_times) - 1)]
spike_diff = np.diff(spike_times)
spike_gt_temp = np.where(np.array(spike_diff) > temporal_interval)[0]
burst_end_loc = spike_gt_temp.astype(np.int) + 1
burst_end_loc_append = np.append(spike_gt_temp, len(spike_times)-1)
burst_start_loc_append = np.insert(burst_end_loc, 0, 0)
return {'bs_loc': burst_start_loc_append,
'be_loc': burst_end_loc_append
}
def convert_trace_to_spike_characteristics_tonic(self, voltage: np.ndarray, times: np.ndarray,
threshold: dict) -> Dict[str, np.ndarray]:
times_before_th_upward = self.convert_trace_to_spike_times_upward(voltage, times, threshold['threshold_spike'])
times_before_th_downward = self.convert_trace_to_spike_times_downward(voltage, times,
threshold['threshold_mid_upper'])
times_before_th_downward2 = self.convert_trace_to_spike_times_downward(voltage, times,
threshold['threshold_mid_lower'])
e_lag = np.mean(np.abs(times_before_th_upward[:times_before_th_downward.size] - times_before_th_downward))
t_sp = -20
# check if a spike is within 100ms AFTER collecting all the spike times that crosses upward on a th.
temporal_interval = 100.
spike_times = times_before_th_upward
num_bursts = self.calculate_num_interbursts(spike_times, temporal_interval)
# TODO: use log-barrier function instead of if
if spike_times.size < 1 or times_before_th_downward.size < 1 or times_before_th_downward2.size < 1 \
or num_bursts['bs_loc'].size < 2 or num_bursts['be_loc'].size < 2:
return {'burst_frequency_mean': 1e5,
'duty_cycle_mean': 1e5,
'times_before_th_downward': times_before_th_downward,
'times_before_th_downward2': times_before_th_downward2,
'num_sw': 0.0,
'e_lag': 0.0,
'num_spikes': np.array([]),
'num_mid': np.array([]),
'spike_per_burst': np.array([]),
'spike_per_burst_mid': np.array([]),
}
bs = spike_times[num_bursts['bs_loc']]
be = spike_times[num_bursts['be_loc']]
if np.all(bs[-1] == be[-1]):
bs = bs[:-1]
be = be[:-1]
burst_duration = be - bs # [:be.size]
# period_half = be
# period = burst_duration + period_half
period = np.diff(bs)
burst_frequency = 1 / np.diff(be)
duty_cycle = burst_duration[:period.size] / np.diff(bs)
burst_frequency_mean = np.mean(burst_frequency) # <fb>
duty_cycle_mean = np.mean(duty_cycle) # <dc>
num_sw = np.size(self.convert_trace_to_spike_times_downward(voltage, times, threshold['threshold_slow_wave']))
# num_spikes = np.size(self.convert_trace_to_spike_times_upward(voltage, times, threshold['threshold_spike']))
num_spike_per_burst = self.clean_up_num_spikes_per_burst(
np.diff(np.insert(num_bursts['be_loc'], 0, 0)))
# num_mid = np.array([np.size(times_before_th_downward), np.size(times_before_th_downward2)])
nspike_mid1 = self.clean_up_num_spikes_per_burst(
self.calculate_num_spikes_per_burst(times_before_th_downward, temporal_interval))
nspike_mid2 = self.clean_up_num_spikes_per_burst(
self.calculate_num_spikes_per_burst(times_before_th_downward2, temporal_interval))
spike_per_burst_mid = (nspike_mid1, nspike_mid2)
return {'burst_frequency_mean': burst_frequency_mean,
'duty_cycle_mean': duty_cycle_mean,
'times_before_th_downward': times_before_th_downward,
'times_before_th_downward2': times_before_th_downward2,
'num_sw': num_sw,
'e_lag': e_lag,
# 'num_spikes': num_spikes,
# 'num_mid': num_mid,
'burst_frequency': burst_frequency,
'duty_cycle': duty_cycle,
'spike_per_burst': num_spike_per_burst,
'spike_per_burst_mid': spike_per_burst_mid,
}