
Commit

chore: update docs + cleaner linear search
lgrcia committed May 27, 2024
1 parent ef2c6c6 commit 87b285d
Showing 7 changed files with 237 additions and 454 deletions.
12 changes: 6 additions & 6 deletions docs/notebooks/single.ipynb

25 changes: 17 additions & 8 deletions docs/notebooks/tutorials/GP_optimization.ipynb

222 changes: 98 additions & 124 deletions docs/notebooks/tutorials/exocomet.ipynb

135 changes: 30 additions & 105 deletions docs/notebooks/tutorials/ground_based.ipynb

241 changes: 57 additions & 184 deletions docs/notebooks/tutorials/tess_search.ipynb

51 changes: 26 additions & 25 deletions nuance/nuance.py
@@ -326,38 +326,39 @@ def linear_search(
         if backend is None:
             backend = jax.default_backend()
 
-        if backend == "cpu":
-            eval_t0_Ds_function = core.pmap_cpus
-            if batch_size is None:
-                batch_size = DEVICES_COUNT
+        if batch_size is None:
+            batch_size = {"cpu": DEVICES_COUNT, "gpu": 1000}[backend]
 
-        elif backend == "gpu":
-            eval_t0_Ds_function = core.vmap_gpu
-            if batch_size is None:
-                batch_size = 1000
+        @jax.jit
+        def solve(t0, D):
+            m = self.model(self.time, t0, D)
+            ll, w, v = self.eval_model(m)
+            return jnp.array([w[-1], v[-1, -1], ll])
 
-        eval_t0s_Ds = eval_t0_Ds_function(self.eval_model, self.model, self.time)
+        if backend == "cpu":
+            solve_batch = jax.pmap(
+                jax.vmap(solve, in_axes=(None, 0)), in_axes=(0, None)
+            )
+        else:
+            solve_batch = jax.vmap(
+                jax.vmap(solve, in_axes=(None, 0)), in_axes=(0, None)
+            )
 
-        batches_n = int(np.ceil(len(t0s) / batch_size))
-        padded_t0s = np.pad(t0s, pad_width=[0, batches_n * batch_size - len(t0s)])
-        batched_t0s = np.array(np.array_split(padded_t0s, batches_n))
+        t0s_padded = np.pad(t0s, [0, batch_size - (len(t0s) % batch_size) % batch_size])
+        t0s_batches = np.reshape(
+            t0s_padded, (len(t0s_padded) // batch_size, batch_size)
+        )
 
-        ll = np.zeros((len(padded_t0s), len(Ds)))
-        depths = ll.copy()
-        vars = ll.copy()
-        depths = ll.copy()
+        _progress = lambda x: (tqdm(x, unit_scale=batch_size) if progress else x)
 
-        _progress = lambda x: (tqdm(x) if progress else x)
+        results = []
 
-        for i, t0 in enumerate(_progress(batched_t0s)):
-            _depths, _vars, _ll = eval_t0s_Ds(t0, Ds)
-            depths[i * batch_size : (i + 1) * batch_size, :] = _depths.T
-            vars[i * batch_size : (i + 1) * batch_size, :] = _vars.T
-            ll[i * batch_size : (i + 1) * batch_size, :] = _ll.T
+        for t0_batch in _progress(t0s_batches):
+            results.append(solve_batch(t0_batch, Ds))
 
-        depths = np.array(depths[0 : len(t0s), :])
-        vars = np.array(vars[0 : len(t0s), :])
-        ll = np.array(ll[0 : len(t0s), :])
+        depths, vars, ll = np.transpose(results, axes=[3, 0, 1, 2]).reshape(
+            (3, len(t0s_padded), len(Ds))
+        )[:, 0 : len(t0s), :]
 
         if positive:
             ll0 = self.eval_model(np.zeros_like(self.time))[0]
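
For reference, below is a minimal standalone sketch of the batching pattern the new linear_search uses on CPU: pad the t0 grid to a multiple of the batch size, reshape it into (n_batches, batch_size), map a jitted per-(t0, D) solve with jax.pmap over the batch and jax.vmap over the duration grid, then stack, transpose, and trim the padded results. The solve function, array sizes, and variable names here are toy stand-ins, not nuance's GP likelihood evaluation.

import os

# Assumption: set before importing jax so pmap sees one "device" per CPU core.
os.environ["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={os.cpu_count()}"

import jax
import jax.numpy as jnp
import numpy as np

DEVICES_COUNT = jax.device_count()


@jax.jit
def solve(t0, D):
    # Dummy (depth, variance, log-likelihood)-shaped triplet for illustration only.
    return jnp.array([t0 * D, D**2, -((t0 - D) ** 2)])


# pmap over a batch of t0s (one per device), vmap over the duration grid Ds.
solve_batch = jax.pmap(jax.vmap(solve, in_axes=(None, 0)), in_axes=(0, None))

t0s = np.linspace(0.0, 10.0, 103)  # deliberately not a multiple of the batch size
Ds = np.linspace(0.01, 0.2, 15)
batch_size = DEVICES_COUNT

# Pad to a multiple of batch_size and reshape into (n_batches, batch_size).
pad = (batch_size - len(t0s) % batch_size) % batch_size
t0s_padded = np.pad(t0s, [0, pad])
t0s_batches = t0s_padded.reshape(-1, batch_size)

# Each batch result has shape (batch_size, len(Ds), 3).
results = [solve_batch(batch, Ds) for batch in t0s_batches]

# Stack, move the statistic axis first, flatten the batches, drop the padding.
depths, variances, ll = np.transpose(results, axes=[3, 0, 1, 2]).reshape(
    (3, len(t0s_padded), len(Ds))
)[:, : len(t0s), :]

print(depths.shape, ll.shape)  # (103, 15) (103, 15)

On CPU, the outer pmap spreads each batch across the host devices created by the XLA flag; on GPU the commit swaps the outer pmap for a second vmap so the whole batch runs on a single device.
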
5 changes: 3 additions & 2 deletions tests/test_full.py
@@ -1,6 +1,7 @@
-import numpy as np
 import os
 
 import jax
+import numpy as np
 
 jax.config.update("jax_enable_x64", True)
+os.environ["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={os.cpu_count()}"
@@ -43,7 +44,7 @@ def test_mask_t0s_not_equal_time():
     nu = Nuance(time, flux, gp=gp, X=X)
 
     # linear search
-    t0s = np.random.choice(time, size=100, replace=False)
+    t0s = np.random.choice(time, size=122, replace=False)
     Ds = np.linspace(0.01, 0.2, 15)
     nu.linear_search(t0s, Ds)
 
