Skip to content

Commit

Permalink
feat: custom models and core refactor
Browse files Browse the repository at this point in the history
feat: custom models and core refactor
  • Loading branch information
lgrcia authored Feb 2, 2024
2 parents 986d980 + 9b4eb77 commit 12ad2de
Show file tree
Hide file tree
Showing 14 changed files with 1,883 additions and 1,302 deletions.
2 changes: 2 additions & 0 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ notebooks/combined.ipynb
:caption: Tutorials
notebooks/tutorials/tess_search.ipynb
notebooks/tutorials/exocomet.ipynb
```

```{toctree}
Expand Down
654 changes: 331 additions & 323 deletions docs/notebooks/combined.ipynb

Large diffs are not rendered by default.

940 changes: 467 additions & 473 deletions docs/notebooks/multi.ipynb

Large diffs are not rendered by default.

542 changes: 277 additions & 265 deletions docs/notebooks/periodic.ipynb

Large diffs are not rendered by default.

10 changes: 5 additions & 5 deletions docs/notebooks/single.ipynb

Large diffs are not rendered by default.

550 changes: 550 additions & 0 deletions docs/notebooks/tutorials/exocomet.ipynb

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions docs/notebooks/tutorials/tess_search.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "8363e9c40d0046d7bc9557036254652a",
"model_id": "95a52071645a4d99a86f1e84c6302c14",
"version_major": 2,
"version_minor": 0
},
Expand Down Expand Up @@ -275,7 +275,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "eada00d375094950b4285f7835528a66",
"model_id": "352af69d7cb64138859b675746091c83",
"version_major": 2,
"version_minor": 0
},
Expand Down
22 changes: 12 additions & 10 deletions nuance/combined.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,9 @@
from scipy.linalg import block_diag
from tqdm.autonotebook import tqdm

from nuance import utils
from nuance import core, utils
from nuance.nuance import Nuance
from nuance.search_data import SearchData
from nuance.utils import periodic_transit


def solve_triangular(*gps_y):
Expand Down Expand Up @@ -37,6 +36,11 @@ def __post_init__(self):
self._fill_search_data()
self._compute_L()

@property
def model(self):
"""The model"""
return self.datasets[0].model

def _fill_search_data(self):
if all([d.search_data is not None for d in self.datasets]):
t0s = np.hstack([d.search_data.t0s for d in self.datasets])
Expand Down Expand Up @@ -87,7 +91,7 @@ def linear_search(self, t0s, Ds, progress=True):
def periodic_transits(self, t0, D, P, c=None):
if c is None:
c = self.c
return [periodic_transit(d.search_data.t0s, t0, D, P) for d in self.datasets]
return [self.model(d.search_data.t0s, t0, D, P) for d in self.datasets]

def solve(self, t0, D, P, c=None):
"""Solve the combined model for a given set of parameters.
Expand Down Expand Up @@ -189,7 +193,7 @@ def _search(p):

return new_search_data

def models(self, t0, D, P, c=None):
def models(self, t0, D, P):
"""Solve the combined model for a given set of parameters.
Parameters
Expand All @@ -210,7 +214,7 @@ def models(self, t0, D, P, c=None):
"""
if c is None:
c = self.c
m = self.periodic_transits(t0, D, P, c)
m = self.model(t0, D, P)
w, _ = self.eval_m(m)

# means
Expand All @@ -220,9 +224,7 @@ def models(self, t0, D, P, c=None):
means.append(np.array(w[w_idxs[i] : w_idxs[i + 1]]) @ self.datasets[i].X)

# signals
signals = [
utils.transit(d.time, t0, D, P=P, c=c) * w[-1] for d in self.datasets
]
signals = [self.model(d.time, t0, D, P=P) * w[-1] for d in self.datasets]

# noises
noises = []
Expand All @@ -232,9 +234,9 @@ def models(self, t0, D, P, c=None):

return np.hstack(means), np.hstack(signals), np.hstack(noises)

def mask_transit(self, t0: float, D: float, P: float):
def mask_model(self, t0: float, D: float, P: float):
new_self = self.__class__(
datasets=[d.mask_transit(t0, D, P) for d in self.datasets], c=self.c
datasets=[d.mask_model(t0, D, P) for d in self.datasets], c=self.c
)
new_self._fill_search_data()
new_self._compute_L()
Expand Down
31 changes: 31 additions & 0 deletions nuance/core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import jax
import jax.numpy as jnp


def eval_model(flux, X, gp):
Liy = gp.solver.solve_triangular(flux)
LiX = gp.solver.solve_triangular(X.T)

@jax.jit
def function(m):
Xm = jnp.vstack([X, m])
Lim = gp.solver.solve_triangular(m)
LiXm = jnp.hstack([LiX, Lim[:, None]])
LiXmT = LiXm.T
LimX2 = LiXmT @ LiXm
w = jnp.linalg.lstsq(LimX2, LiXmT @ Liy)[0]
v = jnp.linalg.inv(LimX2)
return gp.log_probability(flux - w @ Xm), w, v

return function


@jax.jit
def transit_protopapas(t, t0, D, P=1e15, c=12, d=1.0):
_t = P * jnp.sin(jnp.pi * (t - t0) / P) / (jnp.pi * D)
return -d * 0.5 * jnp.tanh(c * (_t + 1 / 2)) + 0.5 * jnp.tanh(c * (_t - 1 / 2))


@jax.jit
def transit_box(time, t0, D, P=1e15):
return -((jnp.abs(time - t0) % P) < D / 2).astype(float)
Loading

0 comments on commit 12ad2de

Please sign in to comment.