Skip to content

Commit

Permalink
fix: better combined
Browse files Browse the repository at this point in the history
fix: better combined
  • Loading branch information
lgrcia authored Nov 29, 2023
2 parents 069f497 + 10a2a0f commit 75ccf42
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 15 deletions.
90 changes: 77 additions & 13 deletions nuance/combined.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from nuance.utils import periodic_transit
from tqdm.autonotebook import tqdm
import jax.numpy as jnp
from nuance import utils


def solve_triangular(*gps_y):
Expand All @@ -31,26 +32,37 @@ class CombinedNuance:
"""The c parameter of the transit model."""

def __post_init__(self):
for d in self.datasets:
assert (
d.search_data is not None
), "Linear search missing for at least one dataset. Run `linear_search` on all datasets."

t0s = np.hstack([d.search_data.t0s for d in self.datasets])
Ds = np.hstack([d.search_data.Ds for d in self.datasets])
ll = np.vstack([d.search_data.ll for d in self.datasets])
z = np.vstack([d.search_data.z for d in self.datasets])
vz = np.vstack([d.search_data.vz for d in self.datasets])
ll0 = np.sum([d.search_data.ll0 for d in self.datasets])

self.search_data = SearchData(t0s, Ds, ll, z, vz, ll0)
self._fill_search_data()
self._compute_L()

def _fill_search_data(self):
if all([d.search_data is not None for d in self.datasets]):
t0s = np.hstack([d.search_data.t0s for d in self.datasets])
Ds = np.hstack([d.search_data.Ds for d in self.datasets])
ll = np.vstack([d.search_data.ll for d in self.datasets])
z = np.vstack([d.search_data.z for d in self.datasets])
vz = np.vstack([d.search_data.vz for d in self.datasets])
ll0 = np.sum([d.search_data.ll0 for d in self.datasets])

self.search_data = SearchData(t0s, Ds, ll, z, vz, ll0)
else:
self.search_data = None

@property
def n(self):
"""Number of datasets"""
return len(self.datasets)

@property
def time(self):
"""Time of all datasets"""
return np.hstack([d.time for d in self.datasets])

@property
def flux(self):
"""Flux of all datasets"""
return np.hstack([d.flux for d in self.datasets])

def _compute_L(self):
Liy = solve_triangular(*[(d.gp, d.flux) for d in self.datasets])
LiX = solve_triangular(*[(d.gp, d.X.T) for d in self.datasets])
Expand All @@ -66,6 +78,10 @@ def eval_m(ms):

self.eval_m = eval_m

def linear_search(self, t0s, Ds, progress=True):
for d in self.datasets:
d.linear_search(t0s, Ds, progress=progress)

def periodic_transits(self, t0, D, P, c=None):
if c is None:
c = self.c
Expand Down Expand Up @@ -167,3 +183,51 @@ def _search(p):
new_search_data.Q_params = params

return new_search_data

def models(self, t0, D, P, c=None):
"""Solve the combined model for a given set of parameters.
Parameters
----------
t0 : float
epoch, same unit as time
D : float
duration, same unit as time
P : float, optional
period, same unit as time, by default None
c : float, optional
c parameter of the transit model, by default None
Returns
-------
list
(w, v): linear coefficients and their covariance matrix
"""
if c is None:
c = self.c
m = self.periodic_transits(t0, D, P, c)
w, _ = self.eval_m(m)

# means
w_idxs = [0, *np.cumsum([d.X.shape[0] for d in self.datasets])]
means = []
for i in range(len(w_idxs) - 1):
means.append(np.array(w[w_idxs[i] : w_idxs[i + 1]]) @ self.datasets[i].X)

# signals
signals = [
utils.transit(d.time, t0, D, P=P, c=c) * w[-1] for d in self.datasets
]

# noises
noises = []
for i, d in enumerate(self.datasets):
_, cond = d.gp.condition(d.flux - means[i] - signals[i])
noises.append(cond.mean)

return np.hstack(means), np.hstack(signals), np.hstack(noises)

def mask(self):
new_self = self.__class__(datasets=[d.mask() for d in self.datasets], c=self.c)
new_self.datasets = [d.mask() for d in self.datasets]
new_self._fill_search_data()
4 changes: 3 additions & 1 deletion nuance/nuance.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,13 @@

from functools import partial

from . import CPU_counts, utils
from . import utils
from .search_data import SearchData

# set_start_method("spawn")

CPU_counts = jax.device_count()


@dataclass
class Nuance:
Expand Down
1 change: 0 additions & 1 deletion nuance/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ def periodic_transit(t, t0, D, P=1, c=12):


def interp_split_times(time, p, dphi=0.01):
dt = np.median(np.diff(time))
tmax, tmin = np.max(time), np.min(time)
total = tmax - tmin
# since for very small periods we might fold on few points only, it's better to impose
Expand Down

0 comments on commit 75ccf42

Please sign in to comment.