Skip to content

Commit

Permalink
Merge pull request #46 from transientlunatic/fix-injections
Browse files Browse the repository at this point in the history
Fix writing of injection files
  • Loading branch information
transientlunatic authored Jan 7, 2025
2 parents 377ecbb + 7112445 commit d73f8b9
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 10 deletions.
15 changes: 11 additions & 4 deletions heron/injection.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
def make_injection(
waveform=IMRPhenomPv2,
injection_parameters={},
duration=32,
sample_rate=4096,
times=None,
detectors=None,
framefile=None,
Expand All @@ -33,7 +35,7 @@ def make_injection(
waveform = waveform()

if times is None:
times = np.linspace(-0.5, 0.1, int(0.6 * 4096))
times = np.linspace(parameters['gpstime']-duration+2, parameters['gpstime']+2, int(duration * sample_rate))
waveform = waveform.time_domain(
parameters,
times=times,
Expand All @@ -44,16 +46,19 @@ def make_injection(
logger.info(f"Making injection for {detector}")
psd_model = KNOWN_PSDS[psd_model]()
detector = KNOWN_IFOS[detector]()
if times is None:
times = waveform['plus'].times.value
data = psd_model.time_series(times)
print(data)

channel = f"{detector.abbreviation}:Injection"
injection = data + waveform.project(detector)
injection.channel = channel
injections[detector.abbreviation] = injection
likelihood = TimeDomainLikelihood(injection, psd=psd_model)
snr = likelihood.snr(waveform.project(detector))
# likelihood = TimeDomainLikelihood(injection, psd=psd_model)
# snr = likelihood.snr(waveform.project(detector))

logger.info(f"Optimal Filter SNR: {snr}")
#logger.info(f"Optimal Filter SNR: {snr}")

if framefile:
filename = f"{detector.abbreviation}_{framefile}.gwf"
Expand Down Expand Up @@ -146,6 +151,8 @@ def injection(settings):
}
injections = make_injection(
waveform=IMRPhenomPv2,
duration=settings["duration"],
sample_rate=settings["sample rate"],
injection_parameters=parameters,
detectors=detector_dict,
framefile="injection",
Expand Down
1 change: 1 addition & 0 deletions heron/models/lalnoise.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ def time_series(self, times):

dt = times[1] - times[0]
N = len(times)
print(N)
T = times[-1] - times[0]
df = 1 / T
frequencies = torch.arange(len(times) // 2 + 1) * df
Expand Down
12 changes: 6 additions & 6 deletions heron/models/lalsimulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,9 +121,9 @@ def time_domain(self, parameters, times=None):
"""
Retrieve a time domain waveform for a given set of parameters.
"""
epoch = parameters.get("gpstime", parameters.get("epoch", 0))
self._args.update(parameters)
epoch = parameters.get("gpstime", 0)

print("epoch is ", epoch)
if not (self._args == self._cache_key):
self.logger.info(f"Generating new waveform at {self.args}")
self._cache_key = self.args.copy()
Expand Down Expand Up @@ -162,17 +162,17 @@ def time_domain(self, parameters, times=None):
spl_hx = CubicSpline(times_wf, hx.data.data)
hp_data = spl_hp(times)
hx_data = spl_hx(times)
hp_ts = Waveform(data=hp_data, times=times)
hx_ts = Waveform(data=hx_data, times=times)
hp_ts = Waveform(data=hp_data, times=times + epoch)
hx_ts = Waveform(data=hx_data, times=times + epoch)
parameters.pop("time")
else:
hp_data = hp.data.data
hx_data = hx.data.data
hp_ts = Waveform(data=hp_data, dt=hp.deltaT, t0=hp.epoch + epoch)
hx_ts = Waveform(data=hx_data, dt=hx.deltaT, t0=hx.epoch + epoch)

self._cache = WaveformDict(parameters=parameters, plus=hp_ts, cross=hx_ts)

print("written epoch is ", hp_ts.times[0])
return self._cache


Expand Down

0 comments on commit d73f8b9

Please sign in to comment.