diff --git a/python/jaxodi/doppler_inference.py b/python/jaxodi/doppler_inference.py index e6a05e4..8034cb4 100644 --- a/python/jaxodi/doppler_inference.py +++ b/python/jaxodi/doppler_inference.py @@ -15,6 +15,16 @@ @jax.jit def cho_solve(A: Array, b: Array) -> Array: + """ + Solve the linear system A x = b using the Cholesky decomposition of A. + + Args: + A (Array): Lower triangular Cholesky factor of the matrix. + b (Array): Right-hand side of the equation. + + Returns: + Array: Solution to the linear system. + """ b_ = jax.scipy.linalg.solve_triangular(A, b, lower=True) return jax.scipy.linalg.solve_triangular(jnp.transpose(A), b_, lower=False) @@ -32,19 +42,18 @@ def map_solve( spherical harmonic coefficients of a map given a flux timeseries. Args: - X (matrix): The flux design matrix. - flux (array): The flux timeseries. - cho_C (scalar/vector/matrix): The lower cholesky factorization + X (Array): The flux design matrix. + flux (Array): The flux timeseries. + cho_C (float | Array): The lower cholesky factorization of the data covariance. - mu (array): The prior mean of the spherical harmonic coefficients. - LInv (scalar/vector/matrix): The inverse prior covariance of the + mu (Array): The prior mean of the spherical harmonic coefficients. + LInv (float | Array): The inverse prior covariance of the spherical harmonic coefficients. Returns: - The vector of spherical harmonic coefficients corresponding to the - MAP solution and the Cholesky factorization of the corresponding - covariance matrix. - + Tuple[Array, Array]: The vector of spherical harmonic coefficients + corresponding to the MAP solution and the Cholesky factorization + of the corresponding covariance matrix. """ # Compute C^-1 . X if cho_C.ndim == 0: @@ -98,6 +107,32 @@ def process_inputs( logTf=None, nlogT=None, ): + """ + Process and normalize inputs for Doppler inference. + + Args: + flux (Array): The observed spectral timeseries. + nt (int): Number of time points. + nw (int): Number of wavelength points. + nc (int): Number of components. + Ny (int): Number of spatial points. + nw0 (int): Initial number of wavelength points. + nw0_ (int): Final number of wavelength points. + S0e2i (Array): Transformation matrix. + flux_err (float, optional): Flux error. Defaults to 1e-4. + normalized (bool, optional): Whether the flux is normalized. Defaults to True. + baseline (optional): Baseline value. Defaults to None. + spatial_mean (optional): Mean of the spatial component. Defaults to None. + spatial_cov (optional): Covariance of the spatial component. Defaults to None. + spectral_mean (optional): Mean of the spectral component. Defaults to None. + spectral_cov (optional): Covariance of the spectral component. Defaults to None. + logT0 (optional): Logarithmic starting temperature. Defaults to None. + logTf (optional): Logarithmic final temperature. Defaults to None. + nlogT (optional): Number of logarithmic temperature steps. Defaults to None. + + Returns: + dict: Processed input values. + """ # Process defaults if flux_err is None: @@ -257,7 +292,25 @@ def get_D_fixed_spectrum( xamp, vsini, ydeg, udeg, nk, inc, theta, spectrum, nc, nwp, nt, Ny, nw, ): """ - Return the Doppler matrix for a fixed spectrum. + Return the Doppler convolution matrix for a fixed spectrum. + + Args: + xamp: Amplitude of the Doppler kernel. + vsini: Rotational velocity. + ydeg: Degree of the Y polynomial. + udeg: Degree of the U polynomial. + nk: Number of Doppler kernel points. + inc: Inclination angle. + theta: Parameters for the Doppler kernel. + spectrum: Spectral data. + nc: Number of components. + nwp: Number of wavelength points in the spectrum. + nt: Number of time points. + Ny: Number of spatial points. + nw: Number of wavelength points in the observed data. + + Returns: + Array: Convolution matrix. """ # Get the convolution kernels kT = get_kT(xamp, vsini, ydeg, udeg, nk, inc, theta) @@ -287,8 +340,17 @@ def get_D_fixed_spectrum( def sparse_dot(A, B): """ Performs matrix multiplication, optimising computation time by utilising sparse matrices. - """ + Args: + A (Array): First matrix. + B (Array): Second matrix. + + Returns: + Array: Result of the matrix multiplication. + + Raises: + ValueError: If neither input matrix is sparse. + """ def is_sparse(dense_matrix, threshold=0.9): flattened_matrix = dense_matrix.flatten() num_zeros = jnp.sum(flattened_matrix == 0) @@ -310,7 +372,22 @@ def get_default_theta( theta: Array, _angle_factor: float, ) -> Array: + """ + Scale the parameters of the Doppler kernel by the provided angle factor. + + The angle factor is used to convert between different units of angle + measurement, such as degrees to radians. The value of ``_angle_factor`` + determines how the input ``theta`` is scaled. + Args: + theta (Array): Parameters of the Doppler kernel. + _angle_factor (float): Factor to convert angle units. For example, + use ``np.pi / 180`` to convert degrees to radians, or + ``180 / np.pi`` to convert radians to degrees. + + Returns: + Array: Scaled parameters. + """ return theta * _angle_factor @@ -324,8 +401,29 @@ def design_matrix( """ Return the Doppler imaging design matrix. - This matrix dots into the spectral map to yield the model for the - observed spectral timeseries (the ``flux``). + Args: + theta (Array): The angular phase(s) at which to compute the + design matrix. This must be a vector of size :py:attr:`nt`. + _angle_factor (float): Scaling factor for the angle. + xamp: Amplitude of the Doppler kernel. + vsini (float): Rotational velocity. + ydeg (int): Degree of the Y polynomial. + udeg (int): Degree of the U polynomial. + nk (int): Number of Doppler kernel points. + inc (float): Inclination angle. + spectrum (Array): Spectral data. + nc (int): Number of components. + nwp (int): Number of wavelength points in the spectrum. + nt (int): Number of time points. + Ny (int): Number of spatial points. + nw (int): Number of wavelength points in the observed data. + _interp (bool): Whether to interpolate the matrix. + _Si2eBlk (Array): Transformation matrix for interpolation. + fix_spectrum (bool, optional): Whether to fix the spectrum. Defaults to True. + fix_map (bool, optional): Whether to fix the map. Defaults to False. + + Returns: + Array: The design matrix for the inference problem. """ theta = get_default_theta(theta, _angle_factor) # this is just undoing what get_S() did! @@ -348,7 +446,31 @@ def _get_S( theta, _angle_factor, xamp, vsini, ydeg, udeg, nk, inc, spectrum, nc, nwp, nt, Ny, nw, _interp, _Si2eBlk, fix_spectrum, ): + """ + Compute the design matrix for the Doppler inference, adjusted by the scaling factor. + Args: + theta (Array): Doppler kernel parameters. + _angle_factor (float): Scaling factor for the angle. + xamp: Amplitude of the Doppler kernel. + vsini (float): Rotational velocity. + ydeg (int): Degree of the Y polynomial. + udeg (int): Degree of the U polynomial. + nk (int): Number of Doppler kernel points. + inc (float): Inclination angle. + spectrum (Array): Spectral data. + nc (int): Number of components. + nwp (int): Number of wavelength points in the spectrum. + nt (int): Number of time points. + Ny (int): Number of spatial points. + nw (int): Number of wavelength points in the observed data. + _interp (bool): Whether to interpolate the matrix. + _Si2eBlk (Array): Transformation matrix for interpolation. + fix_spectrum (bool): Whether to fix the spectrum. + + Returns: + Array: The design matrix adjusted by the scaling factor. + """ theta = theta / _angle_factor dm = design_matrix( @@ -386,7 +508,36 @@ def solve_for_map_linear( fix_spectrum: bool, ) -> tuple[Array, Array]: """ - Solve for `y` linearly, given a baseline or unnormalized data. + Solve the Doppler inference problem for a linear model, given a baseline + of unnormalized data. + + Args: + spatial_mean (Array): Mean vector of the spatial component. + spatial_inv_cov (Array): Inverse covariance matrix of the spatial component. + flux_err (float): Flux error. + nt (int): Number of time points. + nw (int): Number of wavelength points. + nw_ (int): Number of wavelength points for the error model. + T (Array): Temperature parameter. + flux (Array): The observed spectral timeseries. + theta (Array): Doppler kernel parameters. + _angle_factor (float): Scaling factor for the angle. + xamp: Amplitude of the Doppler kernel. + vsini (float): Rotational velocity. + ydeg (int): Degree of the Y polynomial. + udeg (int): Degree of the U polynomial. + nk (int): Number of Doppler kernel points. + inc (float): Inclination angle. + spectrum (Array): Spectral data. + nc (int): Number of components. + nwp (int): Number of wavelength points in the spectrum. + Ny (int): Number of spatial points. + _interp (bool): Whether to interpolate the matrix. + _Si2eBlk (Array): Transformation matrix for interpolation. + fix_spectrum (bool): Whether to fix the spectrum. + + Returns: + Tuple[Array, Array]: Tuple containing the solution vector and its covariance matrix. """ # Reshape the priors mu = jnp.reshape(jnp.transpose(spatial_mean), (-1)) @@ -450,7 +601,40 @@ def solve_bilinear( _Si2eBlk: Array, fix_spectrum: bool, ) -> tuple[Array, Array]: + """ + Solve the Doppler inference problem for the spatial and/or spectral map + given a spectral timeseries, using a bilinear model approach. + Args: + flux (Array): The observed spectral timeseries. + nt (int): Number of time points. + nw (int): Number of wavelength points in the observed data. + nw_ (int): Number of wavelength points in the error model. + nc (int): Number of spatial components. + Ny (int): Number of spatial points. + nw0 (int): Initial number of wavelength points. + nw0_ (int): Final number of wavelength points. + S0e2i (Array): Precomputed inverse covariance or related matrix. + flux_err (float): Error associated with the flux data. + normalized (bool): Indicates if the flux is normalized. + theta (Array): Parameters of the Doppler kernel. + _angle_factor (float): Scaling factor for the angle. + xamp: Amplitude of the Doppler kernel. + vsini (float): Rotational velocity. + ydeg (int): Degree of the Y polynomial. + udeg (int): Degree of the U polynomial. + nk (int): Number of Doppler kernel points. + inc (float): Inclination angle. + spectrum: Spectral data. + nwp (int): Number of wavelength points in the spectrum. + _interp (bool): Indicates whether to interpolate the matrix. + _Si2eBlk (Array): Transformation matrix for interpolation. + fix_spectrum (bool): Whether to fix the spectrum in the model. + + Returns: + Tuple[Array, Array]: The first element is the solution array `y`, + and the second element is the covariance matrix `cho_ycov`. + """ # reset() - if have a class with self attributes. processed_inputs = process_inputs( @@ -504,10 +688,43 @@ def solve( solver: str="bilinear", ) -> tuple[Array, Array]: """ - Iteratively solves the bilinear or nonlinear problem for the spatial + Iteratively solves the Doppler inference problem for the spatial and/or spectral map given a spectral timeseries. - """ + Args: + flux (Array): The observed spectral timeseries. + nt (int): Number of time points. + nw (int): Number of wavelength points in the observed data. + nw_ (int): Number of wavelength points in the error model. + nc (int): Number of spatial components. + Ny (int): Number of spatial points. + nw0 (int): Initial number of wavelength points. + nw0_ (int): Final number of wavelength points. + S0e2i (Array): Precomputed inverse covariance or related matrix. + flux_err (float): The data uncertainty. + normalized (bool): Whether the ``flux`` dataset is + continuum-normalized. + theta (Array): The angular phase(s) at which the spectra + were observed. + _angle_factor (float): Scaling factor for the angle. + xamp: Amplitude of the Doppler kernel. + vsini (float): Rotational velocity. + ydeg (int): Degree of the Y polynomial. + udeg (int): Degree of the U polynomial. + nk (int): Number of Doppler kernel points. + inc (float): Inclination angle. + spectrum: Spectral data. + nwp (int): Number of wavelength points in the spectrum. + _interp (bool): Indicates whether to interpolate the matrix. + _Si2eBlk (Array): Transformation matrix for interpolation. + fix_spectrum (bool): If True, fixes the spectrum at the + current value and solves only for the map. + solver (str, optional): The solver method to use, default is "bilinear". + + Returns: + Tuple[Array, Array]: The first element is the solution array `y`, + and the second element is the covariance matrix `cho_ycov`. + """ # Used to calculate S. theta = get_default_theta(theta, _angle_factor) diff --git a/tests/spot_y.npy b/tests/spot_y.npy new file mode 100644 index 0000000..9cc389d Binary files /dev/null and b/tests/spot_y.npy differ diff --git a/tests/test_doppler_inference.py b/tests/test_doppler_inference.py index ec8dd8b..a99bffd 100644 --- a/tests/test_doppler_inference.py +++ b/tests/test_doppler_inference.py @@ -93,6 +93,13 @@ def saved_input_data(): # Load the component maps. map.load(spectra=spectrum, smoothing=0.075) +with open(f"{CWD}/tests/spot_y.npy", "rb") as f: + map._y = np.load(f) + +# image = str(CWD + "/tests/spot.png") +# map.load(maps=[image], spectra=spectrum, smoothing=0.075) +# with open(f"{CWD}/tests/spot_y.npy", "wb") as f: +# np.save(f, map._y) # Get rotational phases. theta = np.linspace(-180, 180, map.nt, endpoint=False)