Skip to content

Commit

Permalink
Merge branch 'main' into fix_common_ref
Browse files Browse the repository at this point in the history
  • Loading branch information
matteobachetti authored Dec 4, 2024
2 parents 4d62d41 + 32d09c5 commit f07d308
Show file tree
Hide file tree
Showing 4 changed files with 117 additions and 2 deletions.
1 change: 1 addition & 0 deletions docs/changes/866.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Implement filter by time interval in stingraytimeseries objects
61 changes: 61 additions & 0 deletions stingray/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2720,6 +2720,67 @@ def analyze_by_gti(self, func, fraction_step=1, **kwargs):
"""
return self.analyze_segments(func, segment_size=None, fraction_step=fraction_step, **kwargs)

def apply_gti_lists(self, new_gti_lists):
"""Split the event list into different files, each with a different GTI.
Parameters
----------
new_gti_lists : list of lists
A list of lists of GTIs. Each sublist should contain a list of GTIs
for a new file.
Returns
-------
output_files : list of str
A list of the output file names.
"""

if len(new_gti_lists[0]) == len(self.gti) and np.all(
np.abs(np.asanyarray(new_gti_lists[0]).flatten() - self.gti.flatten()) < 1e-3
):
ev = self[:]
yield ev

else:
for gti in new_gti_lists:
if len(gti) == 0:
continue
gti = np.asarray(gti)
lower_edge = np.searchsorted(self.time, gti[0, 0])
upper_edge = np.searchsorted(self.time, gti[-1, 1])
if upper_edge == self.time.size:
upper_edge -= 1
if self.time[upper_edge] > gti[-1, 1]:
upper_edge -= 1
ev = self[lower_edge : upper_edge + 1]

if hasattr(ev, "gti"):
ev.gti = gti

yield ev

def filter_at_time_intervals(self, time_intervals, check_gtis=True):
"""Filter the event list at the given time intervals.
Parameters
----------
time_intervals : 2-d float array
List of time intervals of the form ``[[time0_0, time0_1], [time1_0, time1_1], ...]``
Returns
-------
output_files : list of str
A list of the output file names.
"""
if len(np.shape(time_intervals)) == 1:
time_intervals = [time_intervals]
if check_gtis:
new_gti = [cross_two_gtis(self.gti, [t_int]) for t_int in time_intervals]
else:
new_gti = [np.asarray([t_int]) for t_int in time_intervals]
return self.apply_gti_lists(new_gti)


def interpret_times(time: TTime, mjdref: float = 0) -> tuple[npt.ArrayLike, float]:
"""Understand the format of input times, and return seconds from MJDREF
Expand Down
55 changes: 55 additions & 0 deletions stingray/tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import matplotlib.pyplot as plt
from astropy.table import Table
from stingray.base import StingrayObject, StingrayTimeseries
from stingray.io import FITSTimeseriesReader

_HAS_XARRAY = importlib.util.find_spec("xarray") is not None
_HAS_PANDAS = importlib.util.find_spec("pandas") is not None
Expand Down Expand Up @@ -1052,6 +1053,9 @@ def setup_class(cls):
sting_obj.panesapa = np.asanyarray([[41, 25], [98, 3]])
sting_obj.gti = np.asanyarray([[-0.5, 2.5]])
cls.sting_obj = sting_obj
curdir = os.path.abspath(os.path.dirname(__file__))
datadir = os.path.join(curdir, "data")
cls.fname = os.path.join(datadir, "monol_testA.evt")

def test_print(self, capsys):
print(self.sting_obj)
Expand Down Expand Up @@ -1082,6 +1086,57 @@ def test_change_mjdref(self):
assert np.allclose(new_so.time - 43200, self.sting_obj.time)
assert np.allclose(new_so.gti - 43200, self.sting_obj.gti)

@pytest.mark.parametrize("check_gtis", [True, False])
def test_read_timeseries_by_time_intv(self, check_gtis):
reader = FITSTimeseriesReader(self.fname, output_class=DummyStingrayTs)[:]

# Full slice
evs = list(reader.filter_at_time_intervals([80000100, 80001000], check_gtis=check_gtis))
assert len(evs) == 1
ev0 = evs[0]
assert np.all((ev0.time > 80000100) & (ev0.time < 80001000))
assert np.all((ev0.gti >= 80000100) & (ev0.gti <= 80001000))
assert np.isclose(ev0.gti[0, 0], 80000100)
assert np.isclose(ev0.gti[-1, 1], 80001000)

def test_read_timeseries_by_time_intv_check_bad_gtis(self):
reader = FITSTimeseriesReader(self.fname, output_class=DummyStingrayTs)[:]

# Full slice
evs = list(reader.filter_at_time_intervals([80000100, 80001100], check_gtis=False))
assert len(evs) == 1
ev0 = evs[0]
assert np.all((ev0.time > 80000100) & (ev0.time < 80001025))
assert np.isclose(ev0.gti[0, 0], 80000100)
# This second gti will be ugly, larger than the original gti boundary
assert np.isclose(ev0.gti[-1, 1], 80001100)

@pytest.mark.parametrize("gti_kind", ["same", "one", "multiple"])
def test_read_apply_gti_lists(self, gti_kind):
reader = FITSTimeseriesReader(self.fname, output_class=DummyStingrayTs)[:]
if gti_kind == "same":
gti_list = [reader.gti]
elif gti_kind == "one":
gti_list = [[[80000000, 80001024]]]
elif gti_kind == "multiple":
gti_list = [[[80000000, 80000512]], [[80000513, 80001024]]]

evs = list(reader.apply_gti_lists(gti_list))

# Check that the number of event lists is the same as the number of GTI lists we input
assert len(evs) == len(gti_list)

for i, ev in enumerate(evs):
# Check that the gtis of the output event lists are the same we input
assert np.allclose(ev.gti, gti_list[i])

def test_read_apply_gti_lists_ignore_empty(self):
reader = FITSTimeseriesReader(self.fname, output_class=DummyStingrayTs)[:]
gti_list = [[], [[80000000, 80000512]], [[80000513, 80001024]]]
evs = list(reader.apply_gti_lists(gti_list))
assert np.allclose(evs[0].gti, gti_list[1])
assert np.allclose(evs[1].gti, gti_list[2])


class TestJoinEvents:
def test_join_without_times_simulated(self):
Expand Down
2 changes: 0 additions & 2 deletions stingray/tests/test_crossspectrum.py
Original file line number Diff line number Diff line change
Expand Up @@ -1232,7 +1232,6 @@ def test_timelag(self):
dt=dt,
)

# with pytest.warns(UserWarning) as w:
cs = AveragedCrossspectrum(test_lc1, test_lc2, segment_size=5, norm="none")
time_lag, time_lag_err = cs.time_lag()

Expand All @@ -1245,7 +1244,6 @@ def test_classical_significances(self):
np.random.seed(62)
test_lc1 = Lightcurve(time, np.random.poisson(200, 10000))
test_lc2 = Lightcurve(time, np.random.poisson(200, 10000))
# with pytest.warns(UserWarning) as w:
cs = AveragedCrossspectrum(test_lc1, test_lc2, segment_size=10, norm="leahy")
maxpower = np.max(cs.power)
assert np.all(np.isfinite(cs.classical_significances(threshold=maxpower / 2.0)))
Expand Down

0 comments on commit f07d308

Please sign in to comment.