Skip to content

Commit

Permalink
Working tests
Browse files Browse the repository at this point in the history
  • Loading branch information
robertdstein committed Jun 14, 2024
1 parent 82472a4 commit 21fb572
Show file tree
Hide file tree
Showing 6 changed files with 56 additions and 34 deletions.
27 changes: 23 additions & 4 deletions tests/test_nlc.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@
from astropy.io import fits
import numpy as np

from winternlc.config import EXAMPLE_IMG_PATH, EXAMPLE_CORRECTED_IMG_PATH, corrections_dir
from winternlc.config import EXAMPLE_IMG_PATH, EXAMPLE_CORRECTED_IMG_PATH, EXAMPLE_MASKED_IMG_PATH
from winternlc.non_linear_correction import nlc_single
from winternlc.mask import mask_single


logger = logging.getLogger(__name__)
Expand All @@ -33,9 +34,7 @@ def test_nlc_correction(self):
image = hdul[ext].data
board_id = header.get("BOARD_ID", None)
print(f"Processing extension {ext} with BOARD_ID {board_id}")
corrected_image = nlc_single(
image, board_id, corrections_dir
)
corrected_image = nlc_single(image, board_id)

comparison_image = hdul_corrected[ext].data

Expand All @@ -50,3 +49,23 @@ def test_nlc_correction(self):
self.assertAlmostEqual(float(np.nanmean(ratio)), 1., delta=0.001)
self.assertAlmostEqual(float(np.nanmedian(ratio)), 1., delta=0.001)
self.assertAlmostEqual(float(np.nanstd(ratio)), 0., delta=0.01)

def test_mask(self):
"""
Test mask application on test image
"""

logger.info("Testing mask")
with (fits.open(EXAMPLE_IMG_PATH) as hdul, fits.open(EXAMPLE_MASKED_IMG_PATH) as hdul_corrected):
for ext in range(1, len(hdul)):
header = hdul[ext].header
image = hdul[ext].data
board_id = header.get("BOARD_ID", None)
print(f"Processing extension {ext} with BOARD_ID {board_id}")
corrected_image = mask_single(
image, board_id
)

comparison_image = hdul_corrected[ext].data

self.assertTrue(np.allclose(corrected_image, comparison_image, equal_nan=True))
1 change: 1 addition & 0 deletions winternlc/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

EXAMPLE_IMG_PATH = example_data_dir / "example_science_image_mef.fits"
EXAMPLE_CORRECTED_IMG_PATH = example_data_dir / "corrected_example_science_image_mef.fits"
EXAMPLE_MASKED_IMG_PATH = example_data_dir / "masked_example_science_image_mef.fits"
_corrections_dir = os.getenv("WINTERNLC_DIR")
if _corrections_dir is None:
corrections_dir = Path.home() / "Data/winternlc/"
Expand Down
10 changes: 3 additions & 7 deletions winternlc/make_generic_rational_corrections.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@

import matplotlib.pyplot as plt
import numpy as np
from astropy.io import fits
from scipy.optimize import curve_fit

from winternlc.config import DEFAULT_CUTOFF, output_directory, test_directory
Expand Down Expand Up @@ -188,11 +187,11 @@ def save_rational_coefficients(
)
else:
np.save(
str(get_coeffs_path(output_dir, board_id)),
str(get_coeffs_path(board_id, output_dir)),
rat_coeffs,
)
np.save(
str(get_mask_path(output_dir, board_id)),
str(get_mask_path(board_id, output_dir)),
bad_pix,
)

Expand Down Expand Up @@ -301,10 +300,7 @@ def load_and_plot_rational(
f"rat_coeffs_board_{board_ids_by_extension[ext]}_ext_{ext}_test.npy",
)
else:
rat_coeffs_path = os.path.join(
output_dir,
f"rat_coeffs_board_{board_ids_by_extension[ext]}_ext_{ext}.npy",
)
rat_coeffs_path = get_coeffs_path(board_id=board_ids_by_extension[ext], cor_dir=output_dir)

if os.path.exists(rat_coeffs_path):
rat_coeffs = np.load(rat_coeffs_path)
Expand Down
4 changes: 2 additions & 2 deletions winternlc/make_rationial8_multithread_corrections.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,11 +156,11 @@ def save_rational_coefficients(
)
else:
np.save(
str(get_coeffs_path(output_dir, board_id)),
str(get_coeffs_path(board_id, output_dir)),
rat_coeffs,
)
np.save(
str(get_mask_path(output_dir, board_id)),
str(get_mask_path(board_id, output_dir)),
bad_pixel_mask,
)

Expand Down
18 changes: 11 additions & 7 deletions winternlc/mask.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,31 @@
from pathlib import Path
from winternlc.config import corrections_dir

import numpy as np


def get_mask_path(cor_dir: Path, board_id: int) -> Path:
def get_mask_path(board_id: int, cor_dir: Path = corrections_dir) -> Path:
"""
Returns the path to the rational coefficients file for a given board ID.
:param cor_dir: Directory containing the correction files
:param board_id: Board ID
:param cor_dir: Directory containing the correction files
:return: Path to the rational coefficients file
"""
return cor_dir / f"bad_pixel_mask_board_{board_id}.npy"


def load_mask(cor_dir: Path | str, board_id: int) -> np.ndarray:
def load_mask(board_id: int, cor_dir: Path | str, ) -> np.ndarray:
"""
Loads the rational coefficients for a given board ID.
:param cor_dir: Directory containing the correction files
:param board_id: Board ID
:param cor_dir: Directory containing the correction files
:return: Rational coefficients
"""
mask_path = get_mask_path(cor_dir, board_id)
mask_path = get_mask_path(board_id=board_id, cor_dir=cor_dir)

if not mask_path.exists():
raise FileNotFoundError(
Expand All @@ -37,14 +40,15 @@ def apply_mask(image: np.ndarray, mask: np.ndarray) -> np.ndarray:
Applies a bad pixel mask to an image.
:param image: Image to mask
:param mask: Bad pixel mask (boolean array)
:return: Masked image (masked pixels set to NaN)
"""
image[mask] = np.nan # Set bad pixels to NaN
return image


def mask_single(image: np.ndarray, board_id: int, cor_dir: str) -> np.ndarray:
def mask_single(image: np.ndarray, board_id: int, cor_dir: str | Path = corrections_dir) -> np.ndarray:
"""
Applies a bad pixel mask to an image.
Expand All @@ -54,5 +58,5 @@ def mask_single(image: np.ndarray, board_id: int, cor_dir: str) -> np.ndarray:
:return: Masked image (masked pixels set to NaN)
"""
mask = load_mask(cor_dir, board_id)
mask = load_mask(board_id=board_id, cor_dir=cor_dir)
return apply_mask(image, mask)
30 changes: 16 additions & 14 deletions winternlc/non_linear_correction.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,30 +6,32 @@

import numpy as np

from winternlc.config import DEFAULT_CUTOFF
from winternlc.config import DEFAULT_CUTOFF, corrections_dir
from winternlc.rational import rational_func


def get_coeffs_path(cor_dir: Path, board_id: int) -> Path:
def get_coeffs_path(board_id: int, cor_dir: Path = corrections_dir) -> Path:
"""
Returns the path to the rational coefficients file for a given board ID.
:param cor_dir: Directory containing the correction files
:param board_id: Board ID
:param cor_dir: Directory containing the correction files
:return: Path to the rational coefficients file
"""
return cor_dir / f"rat_coeffs_board_{board_id}.npy"


def load_rational_coeffs(cor_dir: Path | str, board_id: int) -> np.ndarray:
def load_rational_coeffs(board_id: int, cor_dir: Path | str = corrections_dir) -> np.ndarray:
"""
Loads the rational coefficients for a given board ID.
:param cor_dir: Directory containing the correction files
:param board_id: Board ID
:param cor_dir: Directory containing the correction files
:return: Rational coefficients
"""
rat_coeffs_path = get_coeffs_path(cor_dir, board_id)
rat_coeffs_path = get_coeffs_path(board_id=board_id, cor_dir=cor_dir)

if not rat_coeffs_path.exists():
raise FileNotFoundError(
Expand All @@ -41,13 +43,13 @@ def load_rational_coeffs(cor_dir: Path | str, board_id: int) -> np.ndarray:


def apply_nonlinearity_correction(
image: np.ndarray, coeffs: np.ndarray, cutoff: float = DEFAULT_CUTOFF
image: np.ndarray, coefficients: np.ndarray, cutoff: float = DEFAULT_CUTOFF
) -> np.ndarray:
"""
Applies nonlinearity correction to an image using precomputed rational coefficients.
Applies non-linearity correction to an image using precomputed rational coefficients.
:param image: Image to correct
:param coeffs: Rational coefficients for the correction
:param coefficients: Rational coefficients for the correction
:param cutoff: Cutoff value for the image
"""

Expand All @@ -58,19 +60,19 @@ def apply_nonlinearity_correction(
image = image / cutoff

# Vectorized application of the fitted function
coeffs = coeffs.reshape(-1, 8)
image = rational_func(image.flatten(), *coeffs.T).reshape(image.shape)
coefficients = coefficients.reshape(-1, 8)
image = rational_func(image.flatten(), *coefficients.T).reshape(image.shape)

# Scale back by cutoff
image = cutoff * image
return image


def nlc_single(
image: np.ndarray, board_id: int, cor_dir: str, cutoff: float = DEFAULT_CUTOFF
image: np.ndarray, board_id: int, cor_dir: str | Path = corrections_dir, cutoff: float = DEFAULT_CUTOFF
) -> np.ndarray:
"""
Applies nonlinearity correction to an image using precomputed rational coefficients.
Applies non-linearity correction to an image using precomputed rational coefficients.
:param image: Image to correct
:param board_id: Board ID of the image
Expand All @@ -79,5 +81,5 @@ def nlc_single(
:return: Corrected image
"""
rat_coeffs = load_rational_coeffs(cor_dir, board_id)
rat_coeffs = load_rational_coeffs(board_id=board_id, cor_dir=cor_dir)
return apply_nonlinearity_correction(image, rat_coeffs, cutoff)

0 comments on commit 21fb572

Please sign in to comment.