Skip to content

Commit

Permalink
feat: fast injection-recovery
Browse files Browse the repository at this point in the history
  • Loading branch information
lgrcia committed Jul 31, 2024
1 parent f71ffee commit b622feb
Show file tree
Hide file tree
Showing 7 changed files with 385 additions and 41 deletions.
1 change: 1 addition & 0 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ citation.md
:caption: Tutorials
notebooks/tutorials/GP_optimization.ipynb
notebooks/tutorials/analytical-ir.ipynb
notebooks/tutorials/ground_based.ipynb
notebooks/tutorials/tess_search.ipynb
notebooks/tutorials/exocomet.ipynb
Expand Down
37 changes: 19 additions & 18 deletions docs/notebooks/tutorials/GP_optimization.ipynb
Original file line number Diff line number Diff line change
@@ -1,39 +1,40 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"tags": [
"hide-input"
]
},
"outputs": [],
"cell_type": "markdown",
"metadata": {},
"source": [
"import os\n",
"# GP optimization\n",
"\n",
"os.environ[\"XLA_FLAGS\"] = f\"--xla_force_host_platform_device_count={os.cpu_count()}\"\n",
"nuance requires a Gaussian Process (GP) of the light curve to be built and optimized before searching for transits.\n",
"\n",
"import jax\n",
"\n",
"jax.config.update(\"jax_enable_x64\", True)"
"In practice, any `tinygp.GaussianProcess` object can be provided. Here is an example of how to build and optimize a custom GP on the light curve of the active star [TOI 451](https://ui.adsabs.harvard.edu/abs/2021AJ....161...65N/abstract).\n",
"\n",
"```{note}\n",
"This tutorial requires the `lightkurve` package to access the data\n",
"```\n",
"\n",
"In order to run this tutorial on all available CPUs, we set the `XLA_FLAGS` env variable to"
]
},
{
"cell_type": "markdown",
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"# GP optimization"
"import os\n",
"import jax\n",
"\n",
"jax.config.update(\"jax_enable_x64\", True)\n",
"os.environ[\"XLA_FLAGS\"] = f\"--xla_force_host_platform_device_count={os.cpu_count()}\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"nuance requires a Gaussian Process (GP) of the light curve to be built and optimized before searching for transits.\n",
"\n",
"In practice, any `tinygp.GaussianProcess` object can be provided. Here is an example of how to build and optimize a custom GP on the light curve of the active star [TOI 451](https://ui.adsabs.harvard.edu/abs/2021AJ....161...65N/abstract).\n",
"\n",
"## Loading data\n",
"\n",
"As in previous tutorials, we will download light curves of TOI 451 using the `lightkurve` package."
Expand Down
314 changes: 314 additions & 0 deletions docs/notebooks/tutorials/analytical-ir.ipynb

Large diffs are not rendered by default.

1 change: 0 additions & 1 deletion docs/notebooks/tutorials/exocomet.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
"metadata": {},
"outputs": [],
"source": [
"# in order to run on all CPUs\n",
"import os\n",
"import jax\n",
"\n",
Expand Down
48 changes: 34 additions & 14 deletions docs/notebooks/tutorials/tess_search.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@
},
{
"cell_type": "code",
"execution_count": 25,
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -251,15 +251,22 @@
},
{
"cell_type": "code",
"execution_count": 26,
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 14370/14370 [00:19<00:00, 720.80it/s]\n"
]
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "2ec1c2b5f51f4bc2aee273b4ee308dbb",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/14370 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
Expand All @@ -282,7 +289,7 @@
},
{
"cell_type": "code",
"execution_count": 27,
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -301,17 +308,30 @@
},
{
"cell_type": "code",
"execution_count": 28,
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/opt/homebrew/Caskroom/miniforge/base/envs/nuance/lib/python3.10/site-packages/multiprocess/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
" self.pid = os.fork()\n",
"100%|██████████| 11745/11745 [00:10<00:00, 1092.31it/s]\n"
" self.pid = os.fork()\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "c1301ab9b2b049a7814cfc112149aa30",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/11745 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
Expand All @@ -334,7 +354,7 @@
},
{
"cell_type": "code",
"execution_count": 31,
"execution_count": 11,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -372,7 +392,7 @@
},
{
"cell_type": "code",
"execution_count": 40,
"execution_count": 12,
"metadata": {
"tags": [
"hide-input"
Expand All @@ -383,7 +403,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"/var/folders/7v/d8bs1hz144s2ypqglv245hp40000gn/T/ipykernel_34476/741865150.py:9: UserWarning: No artists with labels found to put in legend. Note that artists whose label start with an underscore are ignored when legend() is called with no argument.\n",
"/var/folders/7v/d8bs1hz144s2ypqglv245hp40000gn/T/ipykernel_10984/741865150.py:9: UserWarning: No artists with labels found to put in legend. Note that artists whose label start with an underscore are ignored when legend() is called with no argument.\n",
" ax.legend()\n"
]
},
Expand Down
23 changes: 16 additions & 7 deletions nuance/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,28 +46,37 @@ def solve_model(flux, X, gp):
assert (
len(gp) == len(flux) == len(X)
), "gp, flux, and datasets must have the same length"
Liy = solve_triangular(*[(_gp, _flux) for _gp, _flux in zip(gp, flux)])
base_Liy = solve_triangular(*[(_gp, _flux) for _gp, _flux in zip(gp, flux)])
LiX = solve_triangular(*[(_gp, _X.T) for _gp, _X in zip(gp, X)])

@jax.jit
def function(ms):
def function(ms, depth=None):
if depth is not None:
raise NotImplementedError(
"depth is not implemented for multiple datasets, open an issue"
)
_Liy = base_Liy
Lim = solve_triangular(*[(_gp, m) for _gp, m in zip(gp, ms)])
LiXm = jnp.hstack([LiX, Lim[:, None]])
LiXmT = LiXm.T
LimX2 = LiXmT @ LiXm
w = jnp.linalg.lstsq(LimX2, LiXmT @ Liy)[0]
w = jnp.linalg.lstsq(LimX2, LiXmT @ _Liy)[0]
v = jnp.linalg.inv(LimX2)
return 0.0, w, v

return function

# single gp and dataset
else:
Liy = gp.solver.solve_triangular(flux)
base_Liy = gp.solver.solve_triangular(flux)
LiX = gp.solver.solve_triangular(X.T)

@jax.jit
def function(m):
def function(m, depth=None):
if depth is not None:
Liy = base_Liy + gp.solver.solve_triangular(depth * m)
else:
Liy = base_Liy
Xm = jnp.vstack([X, m])
Lim = gp.solver.solve_triangular(m)
LiXm = jnp.hstack([LiX, Lim[:, None]])
Expand Down Expand Up @@ -125,9 +134,9 @@ def _model(time, epoch, duration, period=None):
else:
_model = model

def function(epoch, duration, period=None):
def function(epoch, duration, period=None, depth=None):
m = _model(time, epoch, duration, period=period)
ll, w, v = solve_m(m)
ll, w, v = solve_m(m, depth=depth)
return jnp.array([w[-1], v[-1, -1], ll])

return function
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "nuance"
version = "0.8.0"
version = "0.8.1"
description = "Transit signals detection among correlated noises"
authors = [{ name = "Lionel Garcia" }]
license = "MIT"
Expand Down

0 comments on commit b622feb

Please sign in to comment.