Skip to content

Commit

Permalink
Merge pull request #5 from ADACS-Australia/doppler-inference
Browse files Browse the repository at this point in the history
update test for solve
  • Loading branch information
JHu-s authored Oct 8, 2024
2 parents 91f232b + 1e97113 commit 243144c
Show file tree
Hide file tree
Showing 3 changed files with 240 additions and 16 deletions.
249 changes: 233 additions & 16 deletions python/jaxodi/doppler_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,16 @@

@jax.jit
def cho_solve(A: Array, b: Array) -> Array:
"""
Solve the linear system A x = b using the Cholesky decomposition of A.
Args:
A (Array): Lower triangular Cholesky factor of the matrix.
b (Array): Right-hand side of the equation.
Returns:
Array: Solution to the linear system.
"""
b_ = jax.scipy.linalg.solve_triangular(A, b, lower=True)
return jax.scipy.linalg.solve_triangular(jnp.transpose(A), b_, lower=False)

Expand All @@ -32,19 +42,18 @@ def map_solve(
spherical harmonic coefficients of a map given a flux timeseries.
Args:
X (matrix): The flux design matrix.
flux (array): The flux timeseries.
cho_C (scalar/vector/matrix): The lower cholesky factorization
X (Array): The flux design matrix.
flux (Array): The flux timeseries.
cho_C (float | Array): The lower cholesky factorization
of the data covariance.
mu (array): The prior mean of the spherical harmonic coefficients.
LInv (scalar/vector/matrix): The inverse prior covariance of the
mu (Array): The prior mean of the spherical harmonic coefficients.
LInv (float | Array): The inverse prior covariance of the
spherical harmonic coefficients.
Returns:
The vector of spherical harmonic coefficients corresponding to the
MAP solution and the Cholesky factorization of the corresponding
covariance matrix.
Tuple[Array, Array]: The vector of spherical harmonic coefficients
corresponding to the MAP solution and the Cholesky factorization
of the corresponding covariance matrix.
"""
# Compute C^-1 . X
if cho_C.ndim == 0:
Expand Down Expand Up @@ -98,6 +107,32 @@ def process_inputs(
logTf=None,
nlogT=None,
):
"""
Process and normalize inputs for Doppler inference.
Args:
flux (Array): The observed spectral timeseries.
nt (int): Number of time points.
nw (int): Number of wavelength points.
nc (int): Number of components.
Ny (int): Number of spatial points.
nw0 (int): Initial number of wavelength points.
nw0_ (int): Final number of wavelength points.
S0e2i (Array): Transformation matrix.
flux_err (float, optional): Flux error. Defaults to 1e-4.
normalized (bool, optional): Whether the flux is normalized. Defaults to True.
baseline (optional): Baseline value. Defaults to None.
spatial_mean (optional): Mean of the spatial component. Defaults to None.
spatial_cov (optional): Covariance of the spatial component. Defaults to None.
spectral_mean (optional): Mean of the spectral component. Defaults to None.
spectral_cov (optional): Covariance of the spectral component. Defaults to None.
logT0 (optional): Logarithmic starting temperature. Defaults to None.
logTf (optional): Logarithmic final temperature. Defaults to None.
nlogT (optional): Number of logarithmic temperature steps. Defaults to None.
Returns:
dict: Processed input values.
"""

# Process defaults
if flux_err is None:
Expand Down Expand Up @@ -257,7 +292,25 @@ def get_D_fixed_spectrum(
xamp, vsini, ydeg, udeg, nk, inc, theta, spectrum, nc, nwp, nt, Ny, nw,
):
"""
Return the Doppler matrix for a fixed spectrum.
Return the Doppler convolution matrix for a fixed spectrum.
Args:
xamp: Amplitude of the Doppler kernel.
vsini: Rotational velocity.
ydeg: Degree of the Y polynomial.
udeg: Degree of the U polynomial.
nk: Number of Doppler kernel points.
inc: Inclination angle.
theta: Parameters for the Doppler kernel.
spectrum: Spectral data.
nc: Number of components.
nwp: Number of wavelength points in the spectrum.
nt: Number of time points.
Ny: Number of spatial points.
nw: Number of wavelength points in the observed data.
Returns:
Array: Convolution matrix.
"""
# Get the convolution kernels
kT = get_kT(xamp, vsini, ydeg, udeg, nk, inc, theta)
Expand Down Expand Up @@ -287,8 +340,17 @@ def get_D_fixed_spectrum(
def sparse_dot(A, B):
"""
Performs matrix multiplication, optimising computation time by utilising sparse matrices.
"""
Args:
A (Array): First matrix.
B (Array): Second matrix.
Returns:
Array: Result of the matrix multiplication.
Raises:
ValueError: If neither input matrix is sparse.
"""
def is_sparse(dense_matrix, threshold=0.9):
flattened_matrix = dense_matrix.flatten()
num_zeros = jnp.sum(flattened_matrix == 0)
Expand All @@ -310,7 +372,22 @@ def get_default_theta(
theta: Array,
_angle_factor: float,
) -> Array:
"""
Scale the parameters of the Doppler kernel by the provided angle factor.
The angle factor is used to convert between different units of angle
measurement, such as degrees to radians. The value of ``_angle_factor``
determines how the input ``theta`` is scaled.
Args:
theta (Array): Parameters of the Doppler kernel.
_angle_factor (float): Factor to convert angle units. For example,
use ``np.pi / 180`` to convert degrees to radians, or
``180 / np.pi`` to convert radians to degrees.
Returns:
Array: Scaled parameters.
"""
return theta * _angle_factor


Expand All @@ -324,8 +401,29 @@ def design_matrix(
"""
Return the Doppler imaging design matrix.
This matrix dots into the spectral map to yield the model for the
observed spectral timeseries (the ``flux``).
Args:
theta (Array): The angular phase(s) at which to compute the
design matrix. This must be a vector of size :py:attr:`nt`.
_angle_factor (float): Scaling factor for the angle.
xamp: Amplitude of the Doppler kernel.
vsini (float): Rotational velocity.
ydeg (int): Degree of the Y polynomial.
udeg (int): Degree of the U polynomial.
nk (int): Number of Doppler kernel points.
inc (float): Inclination angle.
spectrum (Array): Spectral data.
nc (int): Number of components.
nwp (int): Number of wavelength points in the spectrum.
nt (int): Number of time points.
Ny (int): Number of spatial points.
nw (int): Number of wavelength points in the observed data.
_interp (bool): Whether to interpolate the matrix.
_Si2eBlk (Array): Transformation matrix for interpolation.
fix_spectrum (bool, optional): Whether to fix the spectrum. Defaults to True.
fix_map (bool, optional): Whether to fix the map. Defaults to False.
Returns:
Array: The design matrix for the inference problem.
"""
theta = get_default_theta(theta, _angle_factor) # this is just undoing what get_S() did!

Expand All @@ -348,7 +446,31 @@ def _get_S(
theta, _angle_factor, xamp, vsini, ydeg, udeg, nk, inc, spectrum,
nc, nwp, nt, Ny, nw, _interp, _Si2eBlk, fix_spectrum,
):
"""
Compute the design matrix for the Doppler inference, adjusted by the scaling factor.
Args:
theta (Array): Doppler kernel parameters.
_angle_factor (float): Scaling factor for the angle.
xamp: Amplitude of the Doppler kernel.
vsini (float): Rotational velocity.
ydeg (int): Degree of the Y polynomial.
udeg (int): Degree of the U polynomial.
nk (int): Number of Doppler kernel points.
inc (float): Inclination angle.
spectrum (Array): Spectral data.
nc (int): Number of components.
nwp (int): Number of wavelength points in the spectrum.
nt (int): Number of time points.
Ny (int): Number of spatial points.
nw (int): Number of wavelength points in the observed data.
_interp (bool): Whether to interpolate the matrix.
_Si2eBlk (Array): Transformation matrix for interpolation.
fix_spectrum (bool): Whether to fix the spectrum.
Returns:
Array: The design matrix adjusted by the scaling factor.
"""
theta = theta / _angle_factor

dm = design_matrix(
Expand Down Expand Up @@ -386,7 +508,36 @@ def solve_for_map_linear(
fix_spectrum: bool,
) -> tuple[Array, Array]:
"""
Solve for `y` linearly, given a baseline or unnormalized data.
Solve the Doppler inference problem for a linear model, given a baseline
of unnormalized data.
Args:
spatial_mean (Array): Mean vector of the spatial component.
spatial_inv_cov (Array): Inverse covariance matrix of the spatial component.
flux_err (float): Flux error.
nt (int): Number of time points.
nw (int): Number of wavelength points.
nw_ (int): Number of wavelength points for the error model.
T (Array): Temperature parameter.
flux (Array): The observed spectral timeseries.
theta (Array): Doppler kernel parameters.
_angle_factor (float): Scaling factor for the angle.
xamp: Amplitude of the Doppler kernel.
vsini (float): Rotational velocity.
ydeg (int): Degree of the Y polynomial.
udeg (int): Degree of the U polynomial.
nk (int): Number of Doppler kernel points.
inc (float): Inclination angle.
spectrum (Array): Spectral data.
nc (int): Number of components.
nwp (int): Number of wavelength points in the spectrum.
Ny (int): Number of spatial points.
_interp (bool): Whether to interpolate the matrix.
_Si2eBlk (Array): Transformation matrix for interpolation.
fix_spectrum (bool): Whether to fix the spectrum.
Returns:
Tuple[Array, Array]: Tuple containing the solution vector and its covariance matrix.
"""
# Reshape the priors
mu = jnp.reshape(jnp.transpose(spatial_mean), (-1))
Expand Down Expand Up @@ -450,7 +601,40 @@ def solve_bilinear(
_Si2eBlk: Array,
fix_spectrum: bool,
) -> tuple[Array, Array]:
"""
Solve the Doppler inference problem for the spatial and/or spectral map
given a spectral timeseries, using a bilinear model approach.
Args:
flux (Array): The observed spectral timeseries.
nt (int): Number of time points.
nw (int): Number of wavelength points in the observed data.
nw_ (int): Number of wavelength points in the error model.
nc (int): Number of spatial components.
Ny (int): Number of spatial points.
nw0 (int): Initial number of wavelength points.
nw0_ (int): Final number of wavelength points.
S0e2i (Array): Precomputed inverse covariance or related matrix.
flux_err (float): Error associated with the flux data.
normalized (bool): Indicates if the flux is normalized.
theta (Array): Parameters of the Doppler kernel.
_angle_factor (float): Scaling factor for the angle.
xamp: Amplitude of the Doppler kernel.
vsini (float): Rotational velocity.
ydeg (int): Degree of the Y polynomial.
udeg (int): Degree of the U polynomial.
nk (int): Number of Doppler kernel points.
inc (float): Inclination angle.
spectrum: Spectral data.
nwp (int): Number of wavelength points in the spectrum.
_interp (bool): Indicates whether to interpolate the matrix.
_Si2eBlk (Array): Transformation matrix for interpolation.
fix_spectrum (bool): Whether to fix the spectrum in the model.
Returns:
Tuple[Array, Array]: The first element is the solution array `y`,
and the second element is the covariance matrix `cho_ycov`.
"""
# reset() - if have a class with self attributes.

processed_inputs = process_inputs(
Expand Down Expand Up @@ -504,10 +688,43 @@ def solve(
solver: str="bilinear",
) -> tuple[Array, Array]:
"""
Iteratively solves the bilinear or nonlinear problem for the spatial
Iteratively solves the Doppler inference problem for the spatial
and/or spectral map given a spectral timeseries.
"""
Args:
flux (Array): The observed spectral timeseries.
nt (int): Number of time points.
nw (int): Number of wavelength points in the observed data.
nw_ (int): Number of wavelength points in the error model.
nc (int): Number of spatial components.
Ny (int): Number of spatial points.
nw0 (int): Initial number of wavelength points.
nw0_ (int): Final number of wavelength points.
S0e2i (Array): Precomputed inverse covariance or related matrix.
flux_err (float): The data uncertainty.
normalized (bool): Whether the ``flux`` dataset is
continuum-normalized.
theta (Array): The angular phase(s) at which the spectra
were observed.
_angle_factor (float): Scaling factor for the angle.
xamp: Amplitude of the Doppler kernel.
vsini (float): Rotational velocity.
ydeg (int): Degree of the Y polynomial.
udeg (int): Degree of the U polynomial.
nk (int): Number of Doppler kernel points.
inc (float): Inclination angle.
spectrum: Spectral data.
nwp (int): Number of wavelength points in the spectrum.
_interp (bool): Indicates whether to interpolate the matrix.
_Si2eBlk (Array): Transformation matrix for interpolation.
fix_spectrum (bool): If True, fixes the spectrum at the
current value and solves only for the map.
solver (str, optional): The solver method to use, default is "bilinear".
Returns:
Tuple[Array, Array]: The first element is the solution array `y`,
and the second element is the covariance matrix `cho_ycov`.
"""
# Used to calculate S.
theta = get_default_theta(theta, _angle_factor)

Expand Down
Binary file added tests/spot_y.npy
Binary file not shown.
7 changes: 7 additions & 0 deletions tests/test_doppler_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,13 @@ def saved_input_data():

# Load the component maps.
map.load(spectra=spectrum, smoothing=0.075)
with open(f"{CWD}/tests/spot_y.npy", "rb") as f:
map._y = np.load(f)

# image = str(CWD + "/tests/spot.png")
# map.load(maps=[image], spectra=spectrum, smoothing=0.075)
# with open(f"{CWD}/tests/spot_y.npy", "wb") as f:
# np.save(f, map._y)

# Get rotational phases.
theta = np.linspace(-180, 180, map.nt, endpoint=False)
Expand Down

0 comments on commit 243144c

Please sign in to comment.