Merge pull request #116 from StingraySoftware/streamline_sum_of_spectra
Streamline sum of spectra
matteobachetti authored May 27, 2021
2 parents a8495de + ee5d557 commit a2aa8e2
Showing 7 changed files with 316 additions and 90 deletions.
2 changes: 2 additions & 0 deletions hendrics/__init__.py
@@ -25,6 +25,8 @@
    pass

import stingray
import warnings
warnings.filterwarnings("ignore", message=".*Errorbars on cross.*")
from .compat import (
    _MonkeyPatchedEventList,
    filter_for_deadtime,
242 changes: 213 additions & 29 deletions hendrics/fspec.py
@@ -1,11 +1,16 @@
# Licensed under a 3-clause BSD style license - see LICENSE.rst
"""Functions to calculate frequency spectra."""

import copy
import warnings
import contextlib
import os
from stingray.gti import cross_gtis
from stingray.crossspectrum import AveragedCrossspectrum
from stingray.powerspectrum import AveragedPowerspectrum
from stingray.utils import show_progress
from stingray.gti import time_intervals_from_gtis
from stingray.events import EventList
import numpy as np
from astropy import log
from astropy.logger import AstropyUserWarning
@@ -19,6 +24,65 @@
from .io import HEN_FILE_EXTENSION, get_file_type


def average_periodograms(fspec_iterable, total=None):
    """Sum a list (or iterable) of power density spectra.

    Examples
    --------
    >>> pds = AveragedPowerspectrum()
    >>> pds.freq = np.asarray([1, 2, 3])
    >>> pds.power = np.asarray([3, 3, 3])
    >>> pds.power_err = np.asarray([0.1, 0.1, 0.1])
    >>> pds.m = 1
    >>> pds.fftlen = 128
    >>> pds1 = copy.deepcopy(pds)
    >>> pds1.m = 2
    >>> tot_pds = average_periodograms([pds, pds1])
    >>> np.allclose(tot_pds.power, pds.power)
    True
    >>> np.allclose(tot_pds.power_err, pds.power_err / np.sqrt(3))
    True
    >>> tot_pds.m
    3
    """
    for i, contents in enumerate(show_progress(fspec_iterable, total=total)):
        freq = contents.freq
        pds = contents.power
        epds = contents.power_err
        nchunks = contents.m
        rebin = 1
        norm = contents.norm
        fftlen = contents.fftlen
        if i == 0:
            rebin0, norm0, freq0 = rebin, norm, freq
            tot_pds = pds * nchunks
            tot_epds = epds ** 2 * nchunks
            tot_npds = nchunks
            tot_contents = copy.copy(contents)
        else:
            assert np.all(
                rebin == rebin0
            ), "Files must be rebinned in the same way"
            np.testing.assert_array_almost_equal(
                freq,
                freq0,
                decimal=int(-np.log10(1 / fftlen) + 2),
                err_msg="Frequencies must coincide",
            )
            assert norm == norm0, "Files must have the same normalization"

            tot_pds += pds * nchunks
            tot_epds += epds ** 2 * nchunks
            tot_npds += nchunks

    tot_contents.power = tot_pds / tot_npds
    tot_contents.power_err = np.sqrt(tot_epds) / tot_npds
    tot_contents.m = tot_npds

    return tot_contents

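The function above implements an m-weighted average: each input spectrum contributes its power times the number m of averaged segments, and the errors are propagated in quadrature with the same weights, so that err_tot = sqrt(sum(m_i * err_i**2)) / sum(m_i). A minimal standalone sketch of the same arithmetic, with hypothetical values matching the doctest:

import numpy as np

powers = np.asarray([3.0, 3.0])  # power in one frequency bin, per spectrum
errs = np.asarray([0.1, 0.1])    # 1-sigma errors, per spectrum
ms = np.asarray([1, 2])          # segments averaged in each input spectrum

tot_m = ms.sum()                                   # -> 3
tot_power = (powers * ms).sum() / tot_m            # -> 3.0
tot_err = np.sqrt((errs ** 2 * ms).sum()) / tot_m  # -> 0.1 / sqrt(3)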

def _wrap_fun_cpds(arglist):
    f1, f2, outname, kwargs = arglist
    return calc_cpds(f1, f2, outname=outname, **kwargs)
@@ -96,6 +160,80 @@ def _format_lc_data(data, type, fftlen=512.0, bintime=1.0):
    return lc_data


def _distribute_events(events, chunk_length):
    """Split an event list into chunks.

    Examples
    --------
    >>> ev = EventList([1, 2, 3, 4, 5, 6], gti=[[0.5, 6.5]])
    >>> ev.pi = np.ones_like(ev.time)
    >>> ev.mjdref = 56780.
    >>> ev_lists = list(_distribute_events(ev, 2))
    >>> np.allclose(ev_lists[0].time, [1, 2])
    True
    >>> np.allclose(ev_lists[1].time, [3, 4])
    True
    >>> np.allclose(ev_lists[2].time, [5, 6])
    True
    >>> np.allclose(ev_lists[0].gti, [[0.5, 2.5]])
    True
    >>> ev_lists[0].mjdref == ev.mjdref
    True
    >>> ev_lists[2].mjdref == ev.mjdref
    True
    >>> np.allclose(ev_lists[1].pi, [1, 1])
    True
    """
    gti = events.gti
    start_times, stop_times = time_intervals_from_gtis(gti, chunk_length)
    for start, end in zip(start_times, stop_times):
        first, last = np.searchsorted(events.time, [start, end])
        new_ev = EventList(
            events.time[first:last], gti=np.asarray([[start, end]])
        )
        # Copy scalar attributes as they are, and slice the array
        # attributes that have one entry per event (e.g. pi, energy)
        for attr in events.__dict__.keys():
            if attr == "gti":
                continue
            val = getattr(events, attr)
            if np.size(val) == np.size(events.time):
                val = val[first:last]
            setattr(new_ev, attr, val)
        yield new_ev

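The chunk boundaries come from stingray's time_intervals_from_gtis, which tiles each GTI with equal-length segments and discards any remainder shorter than chunk_length. A quick check of the boundaries used in the doctest above:

import numpy as np
from stingray.gti import time_intervals_from_gtis

starts, stops = time_intervals_from_gtis(np.asarray([[0.5, 6.5]]), 2)
# Three 2-s segments tile the 6-s GTI
assert np.allclose(starts, [0.5, 2.5, 4.5])
assert np.allclose(stops, [2.5, 4.5, 6.5])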

def _provide_periodograms(events, fftlen, dt, norm):
    for new_ev in _distribute_events(events, fftlen):
        # Hack: pad the end of each GTI by a small epsilon (dt / 10), so
        # that the full fftlen-long chunk is recognized as a valid GTI
        new_ev.gti[:, 1] += dt / 10
        pds = AveragedPowerspectrum(
            new_ev, dt=dt, segment_size=fftlen, norm=norm, silent=True
        )
        pds.fftlen = fftlen
        yield pds


def _provide_cross_periodograms(events1, events2, fftlen, dt, norm):
    ev1_iter = _distribute_events(events1, fftlen)
    ev2_iter = _distribute_events(events2, fftlen)
    for new_ev1, new_ev2 in zip(ev1_iter, ev2_iter):
        new_ev1.gti[:, 1] += dt / 10
        new_ev2.gti[:, 1] += dt / 10
        # Silence stingray's stdout chatter while the spectrum is
        # calculated; redirect_stdout needs a file object, not the path
        # string in os.devnull
        with open(os.devnull, "w") as devnull, contextlib.redirect_stdout(
            devnull
        ):
            pds = AveragedCrossspectrum(
                new_ev1,
                new_ev2,
                dt=dt,
                segment_size=fftlen,
                norm=norm,
                silent=True,
            )
        pds.fftlen = fftlen
        yield pds

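These generators are meant to be fed to average_periodograms, so that only one fftlen-long chunk of events is turned into a spectrum at any one time. A sketch of the intended pipeline with a hypothetical event list (mirrors the branch added to calc_pds below; assumes both helpers are in scope, e.g. from hendrics.fspec):

import numpy as np
from stingray.events import EventList

# Hypothetical data: uniform random arrival times over a single 512-s GTI
rng = np.random.default_rng(0)
times = np.sort(rng.uniform(0, 512.0, size=20000))
ev = EventList(times, gti=np.asarray([[0.0, 512.0]]))

fftlen, dt = 128.0, 1.0
total = int((ev.gti[-1, 1] - ev.gti[0, 0]) / fftlen)  # chunk count, for the progress bar
pds = average_periodograms(
    _provide_periodograms(ev, fftlen, dt, norm="leahy"), total=total
)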

def calc_pds(
    lcfile,
    fftlen,
@@ -107,6 +245,7 @@ def calc_pds(
    noclobber=False,
    outname=None,
    save_all=False,
    test=False,
):
    """Calculate the PDS from an input light curve file.
@@ -148,11 +287,28 @@
    mjdref = data.mjdref
    instr = data.instr

    length = data.gti[-1, 1] - data.gti[0, 0]
    if hasattr(data, "dt"):
        bintime = max(data.dt, bintime)

    nbins = int(length / bintime)

    if ftype == "events" and (test or nbins > 10 ** 7):
        print("Long observation. Using split analysis")
        total = int(length / fftlen)
        pds = average_periodograms(
            _provide_periodograms(
                data, fftlen, bintime, norm=normalization.lower()
            ),
            total=total,
        )
    else:
        lc_data = _format_lc_data(data, ftype, bintime=bintime, fftlen=fftlen)

        pds = AveragedPowerspectrum(
            lc_data, segment_size=fftlen, norm=normalization.lower()
        )

    if pdsrebin is not None and pdsrebin != 1:
        pds = pds.rebin(pdsrebin)
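As a worked example of the branch condition above (hypothetical numbers): a 100 ks observation with 1 ms time resolution yields 10**8 light-curve bins, well past the 10**7 threshold, so the events would be processed chunk by chunk instead of as a single huge light curve:

length = 100_000.0   # 100 ks observation
bintime = 0.001      # 1 ms bins
nbins = int(length / bintime)
assert nbins == 10 ** 8 and nbins > 10 ** 7  # -> split analysis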
@@ -179,6 +335,7 @@ def calc_cpds(
    back_ctrate=0.0,
    noclobber=False,
    save_all=False,
    test=False,
):
    """Calculate the CPDS from a pair of input light curve files.
@@ -234,12 +391,29 @@
    lc2 = lc2.change_mjdref(lc1.mjdref)
    mjdref = lc1.mjdref

    length = lc1.gti[-1, 1] - lc1.gti[0, 0]
    if hasattr(lc1, "dt"):
        bintime = max(lc1.dt, bintime)

    nbins = int(length / bintime)

    if ftype1 == "events" and (test or nbins > 10 ** 7):
        print("Long observation. Using split analysis")
        total = int(length / fftlen)
        cpds = average_periodograms(
            _provide_cross_periodograms(
                lc1, lc2, fftlen, bintime, norm=normalization.lower()
            ),
            total=total,
        )
    else:
        lc1 = _format_lc_data(lc1, ftype1, fftlen=fftlen, bintime=bintime)
        lc2 = _format_lc_data(lc2, ftype2, fftlen=fftlen, bintime=bintime)

        cpds = AveragedCrossspectrum(
            lc1, lc2, segment_size=fftlen, norm=normalization.lower()
        )

    if pdsrebin is not None and pdsrebin != 1:
        cpds = cpds.rebin(pdsrebin)
@@ -274,6 +448,7 @@ def calc_fspec(
    noclobber=False,
    ignore_instr=False,
    save_all=False,
    test=False,
):
    r"""Calculate the frequency spectra: the PDS, the cospectrum, ...
@@ -323,16 +498,17 @@
    if do_calc_pds:
        wrapped_file_dicts = []
        for f in files:
            wfd = dict(
                fftlen=fftlen,
                save_dyn=save_dyn,
                bintime=bintime,
                pdsrebin=pdsrebin,
                normalization=normalization.lower(),
                back_ctrate=back_ctrate,
                noclobber=noclobber,
                save_all=save_all,
                test=test,
            )
            wfd["fname"] = f
            wrapped_file_dicts.append(wfd)

@@ -361,16 +537,17 @@

    assert len(files1) == len(files2), "An even number of files is needed"

    argdict = dict(
        fftlen=fftlen,
        save_dyn=save_dyn,
        bintime=bintime,
        pdsrebin=pdsrebin,
        normalization=normalization.lower(),
        back_ctrate=back_ctrate,
        noclobber=noclobber,
        save_all=save_all,
        test=test,
    )

    funcargs = []

@@ -547,6 +724,12 @@ def main(args=None):
        default=False,
        action="store_true",
    )
    parser.add_argument(
        "--test",
        help="Only to be used in testing",
        default=False,
        action="store_true",
    )
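For reference, the new flag behaves like any store_true switch; below is a standalone argparse mirror of the snippet above (a sketch, not the full parser that main() builds):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--test",
    help="Only to be used in testing",
    default=False,
    action="store_true",
)
assert parser.parse_args(["--test"]).test is True
assert parser.parse_args([]).test is False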
    _add_default_args(parser, ["loglevel", "debug"])

    args = check_negative_numbers_in_args(args)
@@ -606,4 +789,5 @@ def main(args=None):
        noclobber=args.noclobber,
        ignore_instr=args.ignore_instr,
        save_all=args.save_all,
        test=args.test,
    )