import string
## Setting the GTK backend does not work on the gj cluster,
## as I don't have pyGTK built for my local python2.6.
## So skip the matplotlib backend setup when on gj or its nodes.
import socket
hname = socket.gethostname()
if 'gulabjamun' not in hname and 'node' not in hname:
import matplotlib
#matplotlib.use('Agg')
#matplotlib.use('GTK')
from pylab import *
from matplotlib import collections
from mpl_toolkits.axes_grid.inset_locator import inset_axes
from poisson_utils import *
## choose figure or poster defaults
poster = False
if not poster:
####### figure defaults
label_fontsize = 8 # pt
plot_linewidth = 0.5 # pt
    linewidth = 1.0 # was 0.5
axes_linewidth = 0.5
marker_size = 3.0 # markersize=<...>
cap_size = 2.0 # for errorbar caps, capsize=<...>
columnwidth = 85/25.4 # inches
twocolumnwidth = 174/25.4 # inches
linfig_height = columnwidth*2.0/3.0
fig_dpi = 300
else:
####### poster defaults
label_fontsize = 12 # pt
plot_linewidth = 1.0 # pt
linewidth = 1.0
axes_linewidth = 1.0
marker_size = 3.0
cap_size = 2.0 # for errorbar caps
columnwidth = 4 # inches
linfig_height = columnwidth*2.0/3.0
#######################
########## You need to run:
## From gj:
##   ./restart_mpd_static
## The 0th (boss) process will always be on node000, as it is the first node in ~/hostfile.
## Hence, from node000, cd to the working directory simulations/
## (so that sys.path has accurate relative paths), then run:
##   mpiexec -machinefile ~/hostfile -n <numprocs> ~/Python-2.6.4/bin/python2.6 <script_name> [args]
## The rank 0 process collates results from all the worker jobs (ranks start from 0).
## I assume the rank 0 process always runs on the machine whose
## X window system has a Display connected and can show the graphs!
## The rank 0 stdout is always directed to the terminal from which mpiexec was run.
## I hope X output also works the same way.
## For long simulations, save results to a text file
## for replotting later, to avoid the above ambiguity.
from mpi4py import MPI
mpicomm = MPI.COMM_WORLD
mpisize = mpicomm.Get_size() # Total number of processes
mpirank = mpicomm.Get_rank() # Number of my process
mpiname = MPI.Get_processor_name() # Name of my node
# The 0th process is the boss who collates/receives all data from workers
boss = 0
print 'Process '+str(mpirank)+' on '+mpiname+'.'
def calc_STA( inputlist, spiketrain, dt, STAtime):
""" spiketrain is a list of output spiketimes,
inputlist is the input timeseries with dt sample time,
STAtime is time for which STA must be computed.
User must ensure that inputlist is at least
as long as max spiketime in spiketrain.
endidx = int(spiketime/dt) not round(spiketime/dt)
as index 0 represents input from time 0 to time dt.
    returns the number of relevant spikes and
    the sum of the spike-triggered input windows as a list of length STAtime/dt
    (divide by the spike count to get the STA).
    """
lenSTA = int(STAtime/dt)
STAsum = zeros(lenSTA)
numspikes = 0
for spiketime in spiketrain:
if spiketime<STAtime: continue
endidx = int(spiketime/dt)
startidx = endidx - lenSTA
STAsum += inputlist[startidx:endidx]
numspikes += 1
return numspikes,STAsum
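## Illustrative usage sketch for calc_STA (commented out); all values below are hypothetical:
#dt = 1e-4                                   # sample time of the input (s)
#stim = uniform(-1.0,1.0,size=int(1.0/dt))   # 1 s of random input (uniform comes in via pylab)
#spikes = [0.21,0.345,0.62]                  # output spike times (s)
#nspikes,STAsum = calc_STA(stim,spikes,dt,STAtime=50e-3)
#if nspikes>0: STA = STAsum/float(nspikes)   # divide the summed windows by the spike count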
def get_phaseimage( phaselist, phasemax, dt,\
overlay = False, rasterwidth = 10, rasterheight = 1 ):
phaseim = None
for resplist in phaselist:
## each spike 'tick' is horizontal with dimensions rasterwidth x rasterheight
phaseim_line = zeros( (int(phasemax/dt)*rasterheight, rasterwidth) )
for phase in resplist:
## a horizontal line of rasterwidth for every spike
## of a given trial, respcycle and phase.
row = int(phase/dt)*rasterheight
phaseim_line[row:row+rasterheight,:] = 1.0
if phaseim is None: phaseim = phaseim_line
else:
if overlay: phaseim += phaseim_line
## on numpy array(), axis=1, keep rows same, add cols.
else: phaseim = append( phaseim, phaseim_line, axis=1)
return phaseim
def plot_rasters(listof_rasterlists, runtime,\
colorlist=['r','g','b'], labellist=['v1','v2','v3'], labels=True):
fig = figure(facecolor='none')
ax = fig.add_subplot(111)
numrasterlists = len(listof_rasterlists)
for rlistnum,rasterlist in enumerate(listof_rasterlists):
numrasters = float(len(rasterlist))
seglist = []
for rnum,raster in enumerate(rasterlist):
for t in raster:
## append a segment for a spike
seglist.append(((t,rlistnum+rnum/numrasters),\
(t,rlistnum+(rnum+1)/numrasters)))
## plot the raster
if labels:
segs = collections.LineCollection(seglist,\
color=colorlist[rlistnum%len(colorlist)],\
label=labellist[rlistnum%len(labellist)])
else:
segs = collections.LineCollection(seglist,\
color=colorlist[rlistnum%len(colorlist)])
ax.add_collection(segs)
ax.set_xlim(0.0,runtime)
if labels:
ax.set_ylim(0,numrasterlists*1.3) # extra 0.3 height for legend
biglegend()
else:
ax.set_ylim(0,numrasterlists)
axes_labels(ax,'time (s)','spike raster trial#')
title('Spike rasters', fontsize=24)
def crosscorr(x,y):
""" pass numpy arrays x and y, so that element by element division & multiplication works.
The older version of correlate() in numpy 1.4 gives only a scalar for tau=0. """
return correlate(x,y)/(sqrt(correlate(x,x))*correlate(y,y))
def crosscorrgram(x, y, dt, halfwindow, starttime, endtime, norm='none'):
""" pass arrays x and y of numtrials arrays of spike times.
x[trialnum][tnum], y[trialnum][tnum]
dt is the binsize, T = endtime-starttime
Valid time length of the correlogram is from
-halfwindow to +halfwindow.
Analysis as per http://mulab.physiol.upenn.edu/crosscorrelation.html
I further normalize by total number of spikes (Dhawale et al 2010 fig S5).
I restrict the reference spike train x, between starttime+halfwindow to endtime-halfwindow;
I restrict the compared spike train y, between starttime to endtime.
For each spike in x, there is a sliding window of spikes in y.
norm = 'overall': a la Ashesh et al 2010, divide by (total #spikes in all sliding windows over y).
Above is same as dividing by (#spikesx * (mean #spikes in a sliding window of y)).
norm = 'analogous': divide by (sqrt(#spikesx) * sqrt(#spikesy))
Similar to dividing by sqrt(autocorrelationx)*sqrt(autocorrelationy)
norm = 'ref': divide by (#spikesx)
i.e. use the number of spikes in the reference spiketrain as the norm factor.
Normalizes such that tau=0 value of auto-corr i.e. crosscorrgram(x,x,...) = 1
norm = 'none': no division
NOTE: mean is not subtracted from the two spike trains.
To do that, first convert the list of spiketimes to a spike raster of 0s and 1s.
Then subtract the respective means, and sum after element-wise multiplication.
Finally, this function returns the average correlogram over all the trials.
With a reference spiketrain x and normalization by #spikesy,
the crosscorrgram becomes somewhat asymmetrical wrt x and y? """
T = endtime-starttime
xstarttime = starttime+halfwindow
xendtime = endtime-halfwindow
## div 2 and +1 to make number of bins odd
bins = int(4*halfwindow/dt)/2 + 1
centralbinnum = bins/2 ## integer division
corrgramavg = array([0.0]*bins)
## x[trialnum][tnum]
numtrials = len(x)
corrnums = 0
for trialnum in range(numtrials):
xtrialnum = x[trialnum]
if len(xtrialnum) == 0: continue
spikenumx = 0
spikenumy_allwindows = 0
corrgram = array([0.0]*bins)
for tx in xtrialnum:
## be careful, MOOSE inserts 0.0-s at the end of the fire times list!!!
if tx<=xstarttime or tx>=xendtime: continue
## central bin is centered around t=0
## tx=ty falls in the center of the central bin.
ystarttime = tx-halfwindow
yendtime = tx+halfwindow
spikenumx += 1
for ty in y[trialnum]:
if ty<=ystarttime or ty>=yendtime: continue
                binnum = int(round((ty-tx)/dt))+centralbinnum # int() since round() returns a float
corrgram[binnum] += 1.0
spikenumy_allwindows += 1
## if variable spikenumy_thistrial exists, add it.
#if 'spikenumy_thistrial' in locals():
# spikenumy += spikenumy_thistrial
## Normalization:
## Divide by (total #spikes in all sliding windows over y).
## This is same as dividing by (#spikesx * (mean #spikes in sliding windows of y)).
if norm=='overall':
if spikenumy_allwindows>0:
corrgram /= float(spikenumy_allwindows)
corrgramavg += corrgram
corrnums += 1
#else: corrgram = [float('nan')]
## Divide by (sqrt(#spikesx) * sqrt(#spikesy))
## Similar to dividing by sqrt(autocorrelationx)*sqrt(autocorrelationy)
elif norm=='analogous':
spikenumy = 0
for ty in y[trialnum]:
if ty<starttime or ty>endtime: continue
spikenumy += 1
if spikenumx>0 and spikenumy>0:
corrgram /= (sqrt(spikenumx)*sqrt(spikenumy))
corrgramavg += corrgram
corrnums += 1
## Divide by (#spikesx)
## Normalizes such that tau=0 value of auto-corr i.e. crosscorrgram(x,x,...) = 1
elif norm=='ref':
if spikenumx>0:
corrgram /= spikenumx
corrgramavg += corrgram
corrnums += 1
else:
corrgramavg += corrgram
corrnums += 1
if corrnums==0: return array([nan]*bins)
else: return corrgramavg/float(corrnums)
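## Illustrative usage sketch for crosscorrgram (commented out); the spike-time lists are hypothetical:
#xtrials = [[0.12,0.53,0.81],[0.25,0.66]]    # x[trialnum][tnum]: reference spike train
#ytrials = [[0.13,0.55,0.80],[0.27,0.64]]    # y[trialnum][tnum]: compared spike train
#corrgram = crosscorrgram(xtrials,ytrials,dt=5e-3,halfwindow=0.1,starttime=0.0,endtime=1.0,norm='ref')
## corrgram has an odd number of bins spanning -halfwindow to +halfwindow around tau=0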
## --------------------------------------
## matplotlib stuff
def axes_off(ax,x=True,y=True):
if x:
for xlabel_i in ax.get_xticklabels():
xlabel_i.set_visible(False)
xlabel_i.set_fontsize(0.0)
    if y:
        for ylabel_i in ax.get_yticklabels():
            ylabel_i.set_fontsize(0.0)
            ylabel_i.set_visible(False)
if x:
for tick in ax.get_xticklines():
tick.set_visible(False)
if y:
for tick in ax.get_yticklines():
tick.set_visible(False)
def set_tick_widths(ax,tick_width):
for tick in ax.xaxis.get_major_ticks():
tick.tick1line.set_markeredgewidth(tick_width)
tick.tick2line.set_markeredgewidth(tick_width)
for tick in ax.xaxis.get_minor_ticks():
tick.tick1line.set_markeredgewidth(tick_width)
tick.tick2line.set_markeredgewidth(tick_width)
for tick in ax.yaxis.get_major_ticks():
tick.tick1line.set_markeredgewidth(tick_width)
tick.tick2line.set_markeredgewidth(tick_width)
for tick in ax.yaxis.get_minor_ticks():
tick.tick1line.set_markeredgewidth(tick_width)
tick.tick2line.set_markeredgewidth(tick_width)
def axes_labels(ax,xtext,ytext,adjustpos=False,fontsize=label_fontsize,xpad=None,ypad=None):
ax.set_xlabel(xtext,fontsize=fontsize,labelpad=xpad)
# increase xticks text sizes
for label in ax.get_xticklabels():
label.set_fontsize(fontsize)
ax.set_ylabel(ytext,fontsize=fontsize,labelpad=ypad)
# increase yticks text sizes
for label in ax.get_yticklabels():
label.set_fontsize(fontsize)
if adjustpos:
## [left,bottom,width,height]
ax.set_position([0.135,0.125,0.84,0.75])
set_tick_widths(ax,axes_linewidth)
def biglegend(legendlocation='upper right',ax=None,fontsize=label_fontsize, **kwargs):
if ax is not None:
leg=ax.legend(loc=legendlocation, **kwargs)
else:
leg=legend(loc=legendlocation, **kwargs)
# increase legend text sizes
for t in leg.get_texts():
t.set_fontsize(fontsize)
def beautify_plot(ax,x0min=True,y0min=True,
xticksposn='bottom',yticksposn='left',xticks=None,yticks=None,
drawxaxis=True,drawyaxis=True):
"""
x0min,y0min control whether to set min of axis at 0.
xticksposn,yticksposn governs whether ticks are at
'both', 'top', 'bottom', 'left', 'right', or 'none'.
    xticks/yticks is a list of ticks, else [min,max] is taken.
Due to rendering issues,
axes do not overlap exactly with the ticks, dunno why.
"""
ax.get_yaxis().set_ticks_position(yticksposn)
ax.get_xaxis().set_ticks_position(xticksposn)
xmin, xmax = ax.get_xaxis().get_view_interval()
ymin, ymax = ax.get_yaxis().get_view_interval()
if x0min: xmin=0
if y0min: ymin=0
if xticks is None: ax.set_xticks([xmin,xmax])
else: ax.set_xticks(xticks)
if yticks is None: ax.set_yticks([ymin,ymax])
else: ax.set_yticks(yticks)
### do not set width and color of axes by below method
### axhline and axvline are not influenced by spine below.
#ax.axhline(linewidth=axes_linewidth, color="k")
#ax.axvline(linewidth=axes_linewidth, color="k")
## spine method of hiding axes is cleaner,
## but alignment problem with ticks in TkAgg backend remains.
for loc, spine in ax.spines.items(): # items() returns [(key,value),...]
spine.set_linewidth(axes_linewidth)
if loc == 'left' and not drawyaxis:
spine.set_color('none') # don't draw spine
elif loc == 'bottom' and not drawxaxis:
spine.set_color('none') # don't draw spine
elif loc in ['right','top']:
spine.set_color('none') # don't draw spine
### alternate method of drawing axes, but for it,
### need to set frameon=False in add_subplot(), etc.
#if drawxaxis:
# ax.add_artist(Line2D((xmin, xmax), (ymin, ymin),\
# color='black', linewidth=axes_linewidth))
#if drawyaxis:
# ax.add_artist(Line2D((xmin, xmin), (ymin, ymax),\
# color='black', linewidth=axes_linewidth))
## axes_labels() sets sizes of tick labels too.
axes_labels(ax,'','',adjustpos=False)
ax.set_xlim(xmin,xmax)
ax.set_ylim(ymin,ymax)
return xmin,xmax,ymin,ymax
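## Illustrative usage sketch for beautify_plot (commented out); the plotted data is hypothetical:
#fig = figure(figsize=(columnwidth,linfig_height),dpi=fig_dpi,facecolor='w')
#ax = fig.add_subplot(111)
#t = arange(0.0,1.0,0.01)
#ax.plot(t,sin(2*pi*t),linewidth=plot_linewidth)
#beautify_plot(ax,x0min=True,y0min=False,xticks=[0,0.5,1],yticks=[-1,0,1])
#axes_labels(ax,'time (s)','activity (arb)')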
def fig_clip_off(fig):
## clipping off for all objects in this fig
for o in fig.findobj():
o.set_clip_on(False)
## ------
## from https://gist.github.com/dmeliza/3251476#file-scalebars-py
# Adapted from mpl_toolkits.axes_grid2
# LICENSE: Python Software Foundation (http://docs.python.org/license.html)
from matplotlib.offsetbox import AnchoredOffsetbox
class AnchoredScaleBar(AnchoredOffsetbox):
def __init__(self, transform, sizex=0, sizey=0, labelx=None, labely=None, loc=4,
pad=0.1, borderpad=0.1, sep=2, prop=None, label_fontsize=label_fontsize, color='k', **kwargs):
"""
        Draw a horizontal and/or vertical bar with the size in data coordinates
        of the given axes. A label will be drawn underneath (center-aligned).
- transform : the coordinate frame (typically axes.transData)
- sizex,sizey : width of x,y bar, in data units. 0 to omit
- labelx,labely : labels for x,y bars; None to omit
- loc : position in containing axes
- pad, borderpad : padding, in fraction of the legend font size (or prop)
- sep : separation between labels and bars in points.
- **kwargs : additional arguments passed to base class constructor
"""
from matplotlib.patches import Rectangle
from matplotlib.offsetbox import AuxTransformBox, VPacker, HPacker, TextArea, DrawingArea
bars = AuxTransformBox(transform)
if sizex:
bars.add_artist(Rectangle((0,0), sizex, 0, fc="none", linewidth=axes_linewidth, color=color))
if sizey:
bars.add_artist(Rectangle((0,0), 0, sizey, fc="none", linewidth=axes_linewidth, color=color))
if sizex and labelx:
textareax = TextArea(labelx,minimumdescent=False,textprops=dict(size=label_fontsize,color=color))
bars = VPacker(children=[bars, textareax], align="center", pad=0, sep=sep)
if sizey and labely:
## VPack a padstr below the rotated labely, else label y goes below the scale bar
## Just adding spaces before labely doesn't work!
padstr = '\n '*len(labely)
textareafiller = TextArea(padstr,textprops=dict(size=label_fontsize/3.0))
textareay = TextArea(labely,textprops=dict(size=label_fontsize,rotation='vertical',color=color))
## filler / pad string VPack-ed below labely
textareayoffset = VPacker(children=[textareay, textareafiller], align="center", pad=0, sep=sep)
## now HPack this padded labely to the bars
bars = HPacker(children=[textareayoffset, bars], align="top", pad=0, sep=sep)
AnchoredOffsetbox.__init__(self, loc, pad=pad, borderpad=borderpad,
child=bars, prop=prop, frameon=False, **kwargs)
def add_scalebar(ax, matchx=True, matchy=True, hidex=True, hidey=True, \
label_fontsize=label_fontsize, color='k', **kwargs):
""" Add scalebars to axes
Adds a set of scale bars to *ax*, matching the size to the ticks of the plot
and optionally hiding the x and y axes
- ax : the axis to attach ticks to
- matchx,matchy : if True, set size of scale bars to spacing between ticks
if False, size should be set using sizex and sizey params
- hidex,hidey : if True, hide x-axis and y-axis of parent
- **kwargs : additional arguments passed to AnchoredScaleBars
Returns created scalebar object
"""
def f(axis):
l = axis.get_majorticklocs()
return len(l)>1 and (l[1] - l[0])
if matchx:
kwargs['sizex'] = f(ax.xaxis)
kwargs['labelx'] = str(kwargs['sizex'])
if matchy:
kwargs['sizey'] = f(ax.yaxis)
kwargs['labely'] = str(kwargs['sizey'])
sb = AnchoredScaleBar(ax.transData, label_fontsize=label_fontsize, color=color, **kwargs)
ax.add_artist(sb)
if hidex : ax.xaxis.set_visible(False)
if hidey : ax.yaxis.set_visible(False)
return sb
## from https://gist.github.com/dmeliza/3251476#file-scalebars-py -- ends
## ------
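## Illustrative usage sketch for add_scalebar (commented out); `ax` is a hypothetical axes object:
#sb = add_scalebar(ax,matchx=False,matchy=False,hidex=True,hidey=True,
#    sizex=0.1,labelx='0.1 s',sizey=5.0,labely='5 mV') # explicit bar sizes instead of tick-matched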
## matplotlib stuff ends
## -----------------------------------------
def plotSpikes(firetimes, runtime, plotdt):
firetimes = array(firetimes)
# MOOSE often inserts one or two spiketime = 0.0 entries when storing spikes, so discount those:
firetimes = firetimes[ where(firetimes>0.0)[0] ]
firetimes = firetimes[ where(diff(firetimes)>2*plotdt)[0] ] # Take the falling edge of every threshold crossing.
firearray = zeros(int(round(runtime/plotdt)),dtype=int8) # 1D array of type int8
firelen = len(firearray)
    for firetime in firetimes:
        ## index = (firetime/runtime)*firelen; round the full expression before int()
        firearray[int(round(firelen*float(firetime)/runtime))] = 1
return firearray
def plotBins(firetimes, numbins, runtime, settletime):
binlist = [0]*numbins
firetimes = array(firetimes)
## MOOSE often inserts one or two spiketime = 0.0 entries
## when storing spikes, so discount those:
firetimes = firetimes[ where(firetimes>0.0)[0] ]
for firetime in firetimes:
if firetime>=settletime:
            ## The small number has been added to the denominator to avoid index-out-of-range errors
## Nothing to do about causality here:
## while plotting, keep bintime to right edge to make it causal
binnum = int((firetime-settletime)/(runtime-settletime+0.0001)*numbins)
binlist[binnum] += 1
return [binspikes/((runtime-settletime)/float(numbins)) for binspikes in binlist] # return firing rate in Hz
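## Illustrative usage sketch for plotBins (commented out); the spike-time list is hypothetical:
#firetimes = [0.25,0.31,0.62,0.90]           # spike times (s), e.g. from a MOOSE spike table
#rates = plotBins(firetimes,numbins=10,runtime=1.0,settletime=0.2) # firing rate (Hz) in each bin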
def plotOverlappingBins(firetimes, numbins, time_period, settletime, bin_width_time):
"""
Firing rate in overlapping bins (moving average).
numbins # of bins in the time (settletime) to (time_period+settletime)
Assumes periodic/wrapped boundary conditions with period=time_period.
This way the end bins are accurate,
else they will not have data to one end and show lower firing rates.
Typically, adjust settletime to bin
only the first or second respiratory cycle.
"""
CAUSAL = True
binlist = [0]*numbins
firetimes = array(firetimes)
## MOOSE often inserts one or two spiketime = 0.0 entries
## when storing spikes, so discount those:
firetimes = firetimes[ where(firetimes>0.0)[0] ]
bindt = time_period/float(numbins)
## if CAUSAL, take spikes only to the left of bin centre_times.
if CAUSAL: centre_times = arange(bindt, time_period+bindt/2.0, bindt)
else: centre_times = arange(bindt/2, time_period, bindt)
bin_half_t = bin_width_time/2.0
rightmost_t = time_period
for firetime in firetimes:
        ## the wrap-around below corrects the end bins, per the docstring's periodic assumption
if firetime>=settletime and firetime<(settletime+time_period):
firetime -= settletime
## Each firetime is in multiple bins depending on bin_width_time
for binnum,bin_centre_t in enumerate(centre_times):
## if CAUSAL, take spikes only to the left of bin centre_times.
if CAUSAL:
bin_left = bin_centre_t - bin_width_time
bin_right = bin_centre_t
else:
bin_left = bin_centre_t - bin_half_t
bin_right = bin_centre_t + bin_half_t
if firetime >= bin_left and firetime < bin_right:
binlist[binnum] += 1
## Next lines implement circularity of firetimes
if bin_left < 0 and firetime >= (bin_left+rightmost_t):
binlist[binnum] += 1
if bin_right > rightmost_t and firetime < (bin_right-rightmost_t):
binlist[binnum] += 1
return [float(binspikes)/bin_width_time for binspikes in binlist] # return firing rate in Hz
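## Illustrative usage sketch for plotOverlappingBins (commented out); arguments are hypothetical:
## moving-average firing rate over one 0.5 s cycle, using 50 ms wide overlapping bins.
#rates = plotOverlappingBins(firetimes,numbins=20,time_period=0.5,settletime=0.2,bin_width_time=50e-3)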
def calcFreq(timeTable, runtime, settletime, plotdt, threshold, spiketable):
# input: if spiketable is True: timeTable has spike times: i.e. a MOOSE table which has stepMode = TAB_SPIKE
# input: if spiketable is False: timeTable has Vm-s: i.e. a MOOSE table which has stepMode = TAB_BUF
# output: (meanrate, meanrate2, events)
# output: events is a list of times of falling edges of 'spikes' separated by at least 2*eventdt.
# output: meanrate2 is just #spikes/time
# output: meanrate is mean of 1/inter-spike-interval (removing very short ISIs)
tablenumpy = array(timeTable) # convert the MOOSE table into a numpy array
if spiketable: # timeTable has spike times
        # only those spike times which are after settletime.
        # it is important to do this even if settletime == 0.0,
        # since MOOSE inserts spurious t=0.0 spiketime entries in a spike table.
events = tablenumpy[ where(tablenumpy>settletime)[0] ]
else: # timeTable has Vm-s
cutout = tablenumpy[ int(settletime/plotdt): ] # cutout only those spike times which are after settle time.
if len(cutout) <= 0:
events = []
else:
            thresholded = where(cutout>threshold)[0] # gives indices wherever cutout > THRESHOLD
# where difference between two adjacent indices in thresholded > 2
# i.e. takes falling edge of every threshold crossing
# THIS IS UNLIKE SPIKETABLE ABOVE WHICH TAKES RISING EDGE!
take = where(diff(thresholded)>2)[0] # numpy's where and diff
indices = thresholded[ take ] # indexed by a list! works for ndarray only, not for usual python lists.
# numpy multiplication of array by scalar -- very different from python list multiplication by integer!!!
events = indices*plotdt + settletime
# calculate mean firing rate as 1/inter-spike-interval (removing very short ISIs)
if len(events)>1: # at least two events needed!
firingRateList = array([])
for i in range(len(events)-1):
firingtimespan = (events[i+1]-events[i])
firingRateList = append(firingRateList,1.0/firingtimespan)
############ Have to filter out the APs which have very closely spaced double firings :
############ typically happens for close to zero currents like 0.15nA etc.
## Keep removing all firing rate entries which are greater than twice the min.
while firingRateList.max() > 2*firingRateList.min():
firingRateList = delete(firingRateList,firingRateList.argmax())
################### Finally calculate the actual mean
meanrate = array(firingRateList).mean() # 1/s = Hz
else:
meanrate = 0
# mean firing rate as #spikes/time
meanrate2 = len(events)/(runtime-settletime) # Hz
return (meanrate, meanrate2, events)
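## Illustrative usage sketch for calcFreq (commented out); `VmTable` is a hypothetical MOOSE Vm table:
#meanrate,meanrate2,events = calcFreq(VmTable,runtime=1.0,settletime=0.2,
#    plotdt=1e-4,threshold=0.0,spiketable=False) # threshold-cross a Vm trace
#print 'ISI-based rate =',meanrate,'Hz; count-based rate =',meanrate2,'Hz'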
def minimum_distance(a,b,p):
""" a,b,p are vectors.
    Given two end-points a and b of a line segment,
    find the minimum distance from p to the line segment.
    Points may be 2D or 3D."""
a = array(a)
b = array(b)
p = array(p)
## length sq of line segment
len_sq = norm(b-a)**2
if len_sq==0: return norm(p-a)
    ## take the infinite line as c = a + t*(b-a)
## find t where the point p drops a normal to line
t = dot(b-a,p-a)/len_sq
## point before a
if t<0: return norm(p-a)
## point after b
elif t>1.0: return norm(p-b)
else: return norm( p - (a+t*(b-a)) )
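## Quick sanity checks for minimum_distance (commented out):
#print minimum_distance((0,0),(1,0),(0,1))   # -> 1.0, perpendicular foot falls inside the segment
#print minimum_distance((0,0),(1,0),(2,0))   # -> 1.0, closest point is the end-point b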
def outcode(x,y,xmin,ymin,xmax,ymax):
""" utility function for the Cohen-Sutherland clipper below. """
outcode = 0x0 # inside
if y>ymax: outcode |= 0x1 # top
elif y<ymin: outcode |= 0x2 # bottom
if x>xmax: outcode |= 0x4 # right
elif x<xmin: outcode |= 0x8 # left
return outcode
def clip_line_to_rectangle(x1,y1,x2,y2,xmin,ymin,xmax,ymax):
""" clip line segment from (x1,y1) to (x2,y2) inside a rectange.
2D points only. Cohen-Sutherland algorithm (Wikipedia)"""
outcode1 = outcode(x1,y1,xmin,ymin,xmax,ymax)
outcode2 = outcode(x2,y2,xmin,ymin,xmax,ymax)
accept = False
done = False
while not done:
if not (outcode1 | outcode2): # segment fully inside
accept = True
done = True
elif outcode1 & outcode2: # segment fully outside i.e.
# both points are to the left/right/top/bottom of rectangle
done = True
else:
if outcode1>0: outcode_ex = outcode1
else: outcode_ex = outcode2
if outcode_ex & 0x1 : # top
x = x1 + (x2 - x1) * (ymax - y1) / (y2 - y1)
y = ymax
elif outcode_ex & 0x2: # bottom
x = x1 + (x2 - x1) * (ymin - y1) / (y2 - y1)
y = ymin
elif outcode_ex & 0x4: # right
y = y1 + (y2 - y1) * (xmax - x1) / (x2 - x1)
x = xmax
else: # left
y = y1 + (y2 - y1) * (xmin - x1) / (x2 - x1)
x = xmin
## get the new co-ordinates of the line
if (outcode_ex == outcode1):
x1 = x
y1 = y
outcode1 = outcode(x1, y1, xmin, ymin, xmax, ymax)
else:
x2 = x
y2 = y
outcode2 = outcode(x2, y2, xmin, ymin, xmax, ymax)
return accept,x1,y1,x2,y2
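## Quick sanity check for clip_line_to_rectangle (commented out):
## the diagonal from (-1,-1) to (2,2), clipped to the unit square, becomes (0,0) to (1,1).
#accept,cx1,cy1,cx2,cy2 = clip_line_to_rectangle(-1.0,-1.0,2.0,2.0,0.0,0.0,1.0,1.0)
#print accept,(cx1,cy1),(cx2,cy2)            # -> True (0.0, 0.0) (1.0, 1.0)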
## copied savitzky_golay from the scipy cookbook: http://www.scipy.org/Cookbook/SavitzkyGolay
def savitzky_golay(y, window_size, order, deriv=0):
r"""Smooth (and optionally differentiate) data with a Savitzky-Golay filter.
The Savitzky-Golay filter removes high frequency noise from data.
It has the advantage of preserving the original shape and
features of the signal better than other types of filtering
    approaches, such as moving-average techniques.
Parameters
----------
y : array_like, shape (N,)
the values of the time history of the signal.
window_size : int
the length of the window. Must be an odd integer number.
order : int
the order of the polynomial used in the filtering.
        Must be less than `window_size` - 1.
deriv: int
the order of the derivative to compute (default = 0 means only smoothing)
Returns
-------
ys : ndarray, shape (N)
        the smoothed signal (or its n-th derivative).
Notes
-----
The Savitzky-Golay is a type of low-pass filter, particularly
suited for smoothing noisy data. The main idea behind this
approach is to make for each point a least-square fit with a
    polynomial of high order over an odd-sized window centered at
the point.
Examples
--------
t = np.linspace(-4, 4, 500)
y = np.exp( -t**2 ) + np.random.normal(0, 0.05, t.shape)
ysg = savitzky_golay(y, window_size=31, order=4)
import matplotlib.pyplot as plt
plt.plot(t, y, label='Noisy signal')
plt.plot(t, np.exp(-t**2), 'k', lw=1.5, label='Original signal')
plt.plot(t, ysg, 'r', label='Filtered signal')
plt.legend()
plt.show()
References
----------
.. [1] A. Savitzky, M. J. E. Golay, Smoothing and Differentiation of
Data by Simplified Least Squares Procedures. Analytical
Chemistry, 1964, 36 (8), pp 1627-1639.
.. [2] Numerical Recipes 3rd Edition: The Art of Scientific Computing
W.H. Press, S.A. Teukolsky, W.T. Vetterling, B.P. Flannery
Cambridge University Press ISBN-13: 9780521880688
"""
try:
window_size = np.abs(np.int(window_size))
order = np.abs(np.int(order))
except ValueError, msg:
raise ValueError("window_size and order have to be of type int")
if window_size % 2 != 1 or window_size < 1:
raise TypeError("window_size size must be a positive odd number")
if window_size < order + 2:
raise TypeError("window_size is too small for the polynomials order")
order_range = range(order+1)
half_window = (window_size -1) // 2
# precompute coefficients
b = np.mat([[k**i for i in order_range] for k in range(-half_window, half_window+1)])
m = np.linalg.pinv(b).A[deriv]
# pad the signal at the extremes with
# values taken from the signal itself
firstvals = y[0] - np.abs( y[1:half_window+1][::-1] - y[0] )
lastvals = y[-1] + np.abs(y[-half_window-1:-1][::-1] - y[-1])
y = np.concatenate((firstvals, y, lastvals))
return np.convolve( m, y, mode='valid')
def circular_convolve(x,y,n):
"""
From: http://www.dspguru.com/dsp/tutorials/a-little-mls-tutorial
    I need to compute the circular cross-correlation of y and x:
        Ryx[n] = 1/(N-1) * \sum_{i=0}^{N-2} y[i] * x[i-n]
    Python indexing (negative indices) automatically gives you circularity!
"""
N = len(y)
return array( 1.0/(N-1)*sum( [ y[i]*x[i-n] for i in range(N-1) ] ) )
def circular_correlate(x,y):
    N = len(y)
    return array( [ circular_convolve(x,y,n) for n in range(N) ] )
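## Quick sanity check for circular_correlate (commented out); x is a hypothetical binary sequence.
## Correlating a sequence with itself peaks at zero lag:
#x = array([1,0,1,1,0,0,1,0])
#print circular_correlate(x,x)               # maximum value is at index n=0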
############################ Information theory functions ###############################
def binary_entropy(p0,p1):
H = 0.0
if p0>0.0:
H -= p0*log2(p0)
if p1>0.0:
H -= p1*log2(p1)
return H
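## Quick sanity check for binary_entropy (commented out):
#print binary_entropy(0.5,0.5)               # fair coin -> 1.0 bit
#print binary_entropy(1.0,0.0)               # certain outcome -> 0.0 bits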
def makestring(intlist):
    ## ''.join() avoids quadratic repeated string concatenation
    return ''.join(str(val) for val in intlist)
## yield defines a generator
def find_substr_endchars(mainstr,substr,delay=0):
"""
    Returns a generator of characters,
    each following a 'substr' occurrence in 'mainstr' after 'delay'.
    """
## don't use regular expressions re module, which finds only non-overlapping matches
## we want to find overlapping matches too.
substrlen = len(substr)
while True:
idx = mainstr.find(substr)
## find returns -1 if substr not found
if idx != -1:
endcharidx = idx+substrlen+delay
if endcharidx<len(mainstr):
yield mainstr[endcharidx]
else: # reached end of string
break
## chop the mainstr just after the start of substr,
## not after the end, as we want overlapping strings also
mainstr = mainstr[idx+1:]
else: # substr not found
break
## yield defines a generator
def find_substrs12_endchars(sidestr,mainstr,substr1,substr2,delay1=0,delay2=0):
"""
    Returns a generator of characters from mainstr Y,
    each following a substr1 occurrence in sidestr X after delay1,
    and following a substr2 occurrence in mainstr Y after delay2.
"""
## don't use regular expressions re module, which finds only non-overlapping matches
## we want to find overlapping matches too.
substr2len = len(substr2)
substr1len = len(substr1)
abs_idx1 = 0 ## mainstr is getting chopped, but we maintain abs index on sidestr
while True:
idx2 = mainstr.find(substr2)
## find returns -1 if substr2 not found
if idx2 != -1:
endcharidx2 = idx2+substr2len+delay2
### NOTE: abs_startidx1 is one earlier than definition!!! I think necessary for causality.
## put +1 below to switch to definition in Quinn et al 2010
abs_startidx1 = abs_idx1 + endcharidx2 - substr1len-delay1
if endcharidx2<len(mainstr): # mainstr Y has characters left?
if abs_startidx1 >= 0: # sidestr X has sufficient chars before?
## sidestr has substr1 before the char to be returned? and mainstr is not over
## IMP: below if's first term is the only place directed info enters.
## Remove first term below and you get just the entropy of mainstr Y: VERIFIED.
#print sidestr[abs_startidx1:abs_startidx1+substr1len], substr1, abs_startidx1
if sidestr[abs_startidx1:abs_startidx1+substr1len]==substr1:
yield mainstr[endcharidx2]
else: # reached end of string
break
## chop the mainstr just after the start of substr2,
## not after the end, as we want overlapping strings also
mainstr = mainstr[idx2+1:]
## don't chop sidestr as substr1len may be greater than substr2len
## in the next iteration, idx2 will be relative, but for sidestr we maintain abs_idx1
abs_idx1 += idx2+1
else: # substr2 not found
break
def calc_entropyrate(spiketrains,markovorder,delay=0):
"""
spiketrains is a list of spiketrain = [<0|1>,...]
should be int-s, else float 1.0-s make the str-s go crazy!
J = markovorder >= 1. Cannot handle non-Markov i.e. J=0 presently!!
Returns entropy rate, assuming markov chain of order J :
H(X_{J+1}|X_J..X_1).
Assumes spike train bins are binary i.e. 0/1 value in each timebin.
delay is self delay of effect of X on X.
NOTE: one time step i.e. causal delay is permanently present, 'delay' is extra.
"""
Hrate = 0.0
N = 0
if markovorder>0:
## create all possible binary sequences priorstr=X_1...X_J
## convert integer to binary repr str of length markovorder padded with zeros (=0)
reprstr = '{:=0'+str(markovorder)+'b}'
priorstrs = [ reprstr.format(i) for i in range(int(2**markovorder)) ]
else:
## return numpy nan if markovorder <= 0
return nan
## Convert the list of timebins to a string of 0s and 1s.
## Don't do it in loops below, else the same op is repeated len(priorstrs) times.
## Below conversion is quite computationally expensive.
mcs = []
for spiketrain in spiketrains:
## A generator expression is given as argument to makestring
mcs.append(makestring(val for val in spiketrain))
## Calculate entropy for each priorstr, and sum weighted by probability of each priorstr
for priorstr in priorstrs:
num1s = 0
num0s = 0
for mc in mcs:
for postchar in find_substr_endchars(mc,priorstr,delay):
if int(postchar): # if the character just after priorstr is nonzero i.e. 1
num1s += 1
else:
num0s += 1
N_givenprior = float(num1s + num0s)
## H(X|Y) = \sum p(Y=y)*H(X|Y=y) ; the normalization by N is done at the end
## p(Y=y) = N_givenprior/N where N is total after all loops
if N_givenprior>0:
Hrate += N_givenprior * binary_entropy(num0s/N_givenprior,num1s/N_givenprior)
N += N_givenprior
if N!=0: Hrate = Hrate/N
return Hrate
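## Illustrative usage sketch for calc_entropyrate (commented out); the trains are hypothetical.
## A deterministic alternating train is fully predictable from one bin of history,
## so its order-1 entropy rate is 0 bits/bin:
#trains = [[0,1]*50, [1,0]*50]               # trials of binary time bins (ints, not floats!)
#print calc_entropyrate(trains,markovorder=1) # -> 0.0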
def calc_dirtinforate(spiketrains1,spiketrains2,markovorder1,markovorder2,delay1=0,delay2=0):
"""
Returns directed information rate from spiketrains1 X to spiketrains2 Y.
Returns directed information rate (lim_{n->/inf} 1/n ...),
assuming train2 as markov chain of order K,
and train1 affecting it with markov order J.
    I(X^n->Y^n) = H( Y_{K+1} | Y_K..Y_1 ) - H( Y_{L} | Y^{L-1}_{L-K} X^{L-1}_{L-J} ),
    where L = max(J,K).
    NOTE: I have changed X^{L}_{L-J-1} in the definition above to X^{L-1}_{L-J} for causality!
Assumes spike train bins are binary i.e. integer 0/1 value in each timebin.
spiketrains1 and 2 are each a list of spiketrain = [<0|1>,...]
should be int-s, else float 1.0-s make the str-s go crazy!
dimensions of both must be the same.
J = markovorder1 >= 1. Cannot handle non-Markov i.e. J=0 presently!!
K = markovorder2 >= 1. Cannot handle non-Markov i.e. K=0 presently!!
Keep J,K<5, else too computationally intensive.
The prior substrings are searched delay1 and delay2 before Y_n in trains 1 and 2.
delay1 is lateral/side delay of effect of X on Y,
delay2 is self/main delay of effect of Y on Y.
NOTE: one time step i.e. causal delay is permanently present, delay1 and 2 are extra.
"""
dirtIrate_term2 = 0.0
N = 0
## for the 'cause' spike train
if markovorder1>0:
## create all possible binary sequences priorstr=X_1...X_J
## convert integer to binary repr str of length markovorder padded with zeros (=0)
reprstr = '{:=0'+str(markovorder1)+'b}'
priorstrs1 = [ reprstr.format(i) for i in range(int(2**markovorder1)) ]
else:
## return numpy nan if markovorder <= 0
return nan
## for the 'effect' spike train
if markovorder2>0:
## create all possible binary sequences priorstr=X_1...X_K
## convert integer to binary repr str of length markovorder padded with zeros (=0)
reprstr = '{:=0'+str(markovorder2)+'b}'
priorstrs2 = [ reprstr.format(i) for i in range(int(2**markovorder2)) ]
else:
## return numpy nan if markovorder <= 0
return nan
## Convert the list of timebins to a string of 0s and 1s.
## Don't do it in loops below, else the same op is repeated len(priorstrs) times.
## Below conversion is quite computationally expensive.
mcs1 = []
for spiketrain in spiketrains1:
## A generator expression is given as argument to makestring
mcs1.append(makestring(val for val in spiketrain))
mcs2 = []
for spiketrain in spiketrains2:
## A generator expression is given as argument to makestring
mcs2.append(makestring(val for val in spiketrain))
## Calculate entropy for each combo of priorstr 1 & 2,
## and sum weighted by probability of each combo
for priorstr1 in priorstrs1:
for priorstr2 in priorstrs2:
num1s = 0
num0s = 0
for chaini,mc1 in enumerate(mcs1):
mc2 = mcs2[chaini]
for postchar in find_substrs12_endchars(mc1,mc2,priorstr1,priorstr2,delay1,delay2):
## if the character just after priorstr1 & priorstr2, is nonzero i.e. 1
if int(postchar):
num1s += 1
else:
num0s += 1
N_givenpriors = float(num1s + num0s)
## H(Y|Y^X^) = \sum p(Y^=y^)*H(Y|Y^=y^,X^=x^) ;
## the normalization by N is done at the end
## p(Y^=y^,X^=x^) = N_givenpriors/N where N is total after all loops
if N_givenpriors>0:
dirtIrate_term2 += N_givenpriors * \
binary_entropy(num0s/N_givenpriors,num1s/N_givenpriors)
N += N_givenpriors
if N!=0: dirtIrate_term2 = dirtIrate_term2/N
    ## H( Y_{K+1} | Y_K..Y_1 )
    dirtIrate_term1 = calc_entropyrate(spiketrains2,markovorder2,delay2)
    ## I(X^n->Y^n) = H( Y_{K+1} | Y_K..Y_1 ) - H( Y_{L} | Y^{L-1}_{L-K} X^{L-1}_{L-J} )
dirtIrate = dirtIrate_term1 - dirtIrate_term2
return dirtIrate
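## Illustrative usage sketch for calc_dirtinforate (commented out); the trains are hypothetical.
## If Y copies X with a one-bin lag, conditioning on X's past removes all of Y's
## conditional entropy, so the directed information rate X->Y equals term1 above:
#trainsX = [[0,1,1,0,1,0,0,1]*20]
#trainsY = [[0]+train[:-1] for train in trainsX] # Y is X delayed by one bin
#print calc_dirtinforate(trainsX,trainsY,markovorder1=1,markovorder2=1)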
def get_spiketrain_from_spiketimes(\
spiketimes,starttime,timerange,numbins,warnmultiple=True,forcebinary=True):
""" bin number of spikes from starttime to endtime, into dt bins.
if warnmultiple, warn if multiple spikes are binned into a single bin.
if forcebinary, set multiple spikes in a bin to 1.
"""
## important to make these int, else spikestrs in entropy calculations go haywire!
spiketrain = zeros(numbins,dtype=int)
for spiketime in spiketimes:
spiketime_reinit = spiketime-starttime
if 0.0 < spiketime_reinit < timerange:
binnum = int(spiketime_reinit/timerange*numbins)
spiketrain[binnum] += 1
if forcebinary or warnmultiple:
multiplespike_indices = where(spiketrain>1)[0] # spiketrain must be a numpy array()
if len(multiplespike_indices)>0:
## if non-empty number of multiple spikes indices, set to 1, print warning
if forcebinary:
spiketrain[multiplespike_indices] = 1
if warnmultiple:
                ## do not print warnings if user has turned them off
                print "There is more than 1 spike in", \
                    len(multiplespike_indices), "bins."
                if forcebinary: print "Have forced all such bins to 1."
return spiketrain
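## Illustrative sketch chaining the above (commented out); the spike times are hypothetical:
#spiketimes = [0.21,0.34,0.56,0.77,0.93]
#train = get_spiketrain_from_spiketimes(spiketimes,starttime=0.2,timerange=1.0,numbins=100)
#print calc_entropyrate([train],markovorder=2) # entropy rate of the binned, binary train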
##############################################################################