Skip to content

Commit

Permalink
Set correct data filepaths.
Browse files Browse the repository at this point in the history
  • Loading branch information
alserene committed Sep 4, 2024
1 parent b340bd9 commit 34c20c6
Showing 1 changed file with 12 additions and 30 deletions.
42 changes: 12 additions & 30 deletions tests/test_doppler_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Doppler inference tests.
"""

import os
import pytest

import jax
Expand Down Expand Up @@ -34,46 +35,27 @@

jax.config.update("jax_enable_x64", True)

# import importlib.resources
# import jaxodi

# TEST_DATA_S = str(importlib.resources.files(jaxodi).joinpath("tests", "map_solve_S_input.npy"))
# TEST_DATA_FLUX = str(importlib.resources.files(jaxodi).joinpath("tests", "map_solve_flux_input.npy"))
# TEST_DATA_CHO_C = str(importlib.resources.files(jaxodi).joinpath("tests", "map_solve_cho_C_input.npy"))
# TEST_DATA_MU = str(importlib.resources.files(jaxodi).joinpath("tests", "map_solve_mu_input.npy"))
# TEST_DATA_INVL = str(importlib.resources.files(jaxodi).joinpath("tests", "map_solve_invL_input.npy"))

# @pytest.fixture(autouse=True)
# def saved_input_data():
# Get current working directory
CWD = os.getcwd()

# with open(TEST_DATA_S, "rb") as f:
# S = jnp.load(f)
# with open(TEST_DATA_FLUX, "rb") as f:
# flux = jnp.load(f)
# with open(TEST_DATA_CHO_C, "rb") as f:
# cho_C = jnp.load(f)
# with open(TEST_DATA_MU, "rb") as f:
# mu = jnp.load(f)
# with open(TEST_DATA_INVL, "rb") as f:
# invL = jnp.load(f)
# If running the tests from within the test folder
# (rather than the root).
if "/tests" in CWD:
CWD = CWD.replace("/tests", "")

# return (S, flux, cho_C, mu, invL)

import os
cwd = os.getcwd()

@pytest.fixture(autouse=True)
def saved_input_data():

with open(f"{cwd}/tests/map_solve_S_input.npy", "rb") as f:
with open(f"{CWD}/tests/map_solve_S_input.npy", "rb") as f:
S = jnp.load(f)
with open(f"{cwd}/tests/map_solve_flux_input.npy", "rb") as f:
with open(f"{CWD}/tests/map_solve_flux_input.npy", "rb") as f:
flux = jnp.load(f)
with open(f"{cwd}/tests/map_solve_cho_C_input.npy", "rb") as f:
with open(f"{CWD}/tests/map_solve_cho_C_input.npy", "rb") as f:
cho_C = jnp.load(f)
with open(f"{cwd}/tests/map_solve_mu_input.npy", "rb") as f:
with open(f"{CWD}/tests/map_solve_mu_input.npy", "rb") as f:
mu = jnp.load(f)
with open(f"{cwd}/tests/map_solve_invL_input.npy", "rb") as f:
with open(f"{CWD}/tests/map_solve_invL_input.npy", "rb") as f:
invL = jnp.load(f)

return (S, flux, cho_C, mu, invL)
Expand Down

0 comments on commit 34c20c6

Please sign in to comment.