Skip to content

Commit

Permalink
feat: more functional
Browse files Browse the repository at this point in the history
feat: more functional
  • Loading branch information
lgrcia authored Jul 30, 2024
2 parents f58becf + d97399d commit ca3375d
Show file tree
Hide file tree
Showing 24 changed files with 1,727 additions and 2,675 deletions.
14 changes: 6 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,23 +19,21 @@ Documentation at [nuance.readthedocs.io](https://nuance.readthedocs.io)
## Example

```python
from nuance import Nuance, utils
import numpy as np

(time, flux, error), X, gp = utils.simulated()

nu = Nuance(time, flux, gp=gp, X=X)
from nuance import linear_search, periodic_search, core

# linear search
epochs = time.copy()
durations = np.linspace(0.01, 0.2, 15)
nu.linear_search(epochs, durations)
ls = linear_search(time, flux, gp=gp)(epochs, durations)

# periodic search
periods = np.linspace(0.3, 5, 2000)
search = nu.periodic_search(periods)
snr_function = jax.jit(core.snr(time, flux, gp=gp))
ps_function = periodic_search(epochs, durations, ls, snr_function)
snr, params = ps_function(periods)

t0, D, P = search.best
t0, D, P = params[np.argmax(snr)]
```

## Installation
Expand Down
1 change: 0 additions & 1 deletion docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,6 @@ notebooks/tutorials/exocomet.ipynb
:maxdepth: 1
:caption: Reference
markdown/how.ipynb
notebooks/star.ipynb
notebooks/templates.ipynb
markdown/hardware.md
Expand Down
20 changes: 16 additions & 4 deletions docs/markdown/API.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,26 @@

```{eval-rst}
.. currentmodule:: nuance
.. autosummary::
:toctree: generated
:nosignatures:
:template: class.rst
~nuance.Star
~nuance.nuance.Nuance
~nuance.search_data.SearchData
~nuance.combined.CombinedNuance
Star
```

```{eval-rst}
.. currentmodule:: nuance
.. autofunction:: linear_search
.. autofunction:: periodic_search
.. automodule:: nuance.core
:members:
:show-inheritance:
```
57 changes: 10 additions & 47 deletions docs/markdown/hardware.md
Original file line number Diff line number Diff line change
@@ -1,22 +1,22 @@
# Hardware acceleration

When running the linear search, nuance exploits the parallelization capabilities of JAX by using a default mapping strategy depending on the available devices.
When running the *linear search*, nuance exploits the parallelization capabilities of JAX by using a default mapping strategy depending on the available devices.

## Solving for `(t0, D)`

To solve a particular model (like a transit) with a given epoch `t0` and duration `D`, we define the function
To solve a particular model (like a transit) with a given epoch `t0` and duration `D`, we define the function (output of `core.solve`)

```python
import jax

@jax.jit
def solve(t0, D):
m = model(time, t0, D)
ll, w, v = nu._solve(m)
return w[-1], v[-1, -1], ll
def solve(t0, D, period=None):
m = model(time, t0, D, period=period)
ll, w, v = solve_m(m)
return jnp.array([w[-1], v[-1, -1], ll])
```

where `model` is the [template model](../notebooks/templates.ipynb), and `nu._solve` is the `Nuance._solve` method returning:
where `model` is the [template model](../notebooks/templates.ipynb). This function returns

- `w[-1]` the template model depth
- `v[-1, -1]` the variance of the template model depth
Expand All @@ -35,7 +35,7 @@ t0s_batches = np.reshape(

## JAX mapping

In order to solve a given batch in an optimal way, the `batch_size` can be set depending on the devices available:
In order to solve a given batch in an optimal way, the `batch_size` can be set depending on the devices available (see the `linear_search` documentation):

- If multiple **CPUs** are available, the `batch_size` is chosen as the number of devices (`jax.device_count()`) and we can solve a given batch using

Expand All @@ -61,42 +61,5 @@ for t0_batch in t0s_batches:
```

```{note}
Of course, one familiar with JAX can use their own mapping strategy to evaluate `solve` over a grid of epochs `t0s` and durations `Ds`.
```

## The full method

The method `nuance.Naunce.linear_search` is then

```python
def linear_search( self, t0s, Ds):

backend = jax.default_backend()
batch_size = {"cpu": DEVICES_COUNT, "gpu": 1000}[backend]

@jax.jit
def solve(t0, D):
m = self.model(self.time, t0, D)
ll, w, v = self._solve(m)
return jnp.array([w[-1], v[-1, -1], ll])

# Batches
t0s_padded = np.pad(t0s, [0, batch_size - (len(t0s) % batch_size) % batch_size])
t0s_batches = np.reshape(
t0s_padded, (len(t0s_padded) // batch_size, batch_size)
)

# Mapping
if backend == "cpu":
solve_batch = jax.pmap(jax.vmap(solve, in_axes=(None, 0)), in_axes=(0, None))
else:
solve_batch = jax.vmap(jax.vmap(solve, in_axes=(None, 0)), in_axes=(0, None))

# Iterate
results = []

for t0_batch in t0s_batches:
results.append(solve_batch(t0_batch, Ds))

...
Of course, one familiar with JAX can use their own mapping strategy to evaluate `solve` over a grid of epochs `t0s` and durations `Ds`. For these users, the implementation of the `linear_search` method is a good place to start.
```
246 changes: 121 additions & 125 deletions docs/notebooks/combined.ipynb

Large diffs are not rendered by default.

1,023 changes: 504 additions & 519 deletions docs/notebooks/motivation.ipynb

Large diffs are not rendered by default.

227 changes: 123 additions & 104 deletions docs/notebooks/multi.ipynb

Large diffs are not rendered by default.

163 changes: 92 additions & 71 deletions docs/notebooks/periodic.ipynb

Large diffs are not rendered by default.

106 changes: 61 additions & 45 deletions docs/notebooks/single.ipynb

Large diffs are not rendered by default.

40 changes: 10 additions & 30 deletions docs/notebooks/templates.ipynb

Large diffs are not rendered by default.

190 changes: 58 additions & 132 deletions docs/notebooks/tutorials/exocomet.ipynb

Large diffs are not rendered by default.

289 changes: 127 additions & 162 deletions docs/notebooks/tutorials/ground_based.ipynb

Large diffs are not rendered by default.

156 changes: 72 additions & 84 deletions docs/notebooks/tutorials/tess_search.ipynb

Large diffs are not rendered by default.

92 changes: 0 additions & 92 deletions docs/readme_quick.ipynb

This file was deleted.

5 changes: 2 additions & 3 deletions nuance/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
config = jax.config
config.update("jax_enable_x64", True)

from nuance.combined import CombinedNuance
from nuance.nuance import Nuance
from nuance.search_data import SearchData
from nuance.linear_search import linear_search
from nuance.periodic_search import periodic_search
from nuance.star import Star
Loading

0 comments on commit ca3375d

Please sign in to comment.