Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
virajkaram committed Jul 3, 2023
1 parent 3024bc2 commit d048072
Showing 1 changed file with 54 additions and 230 deletions.
284 changes: 54 additions & 230 deletions mirar/processors/photcal.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,25 +6,18 @@
from collections.abc import Callable
from pathlib import Path

import astropy.units as u
import numpy as np
from astropy.coordinates import SkyCoord
from astropy.io import fits
from astropy.stats import sigma_clip, sigma_clipped_stats
from astropy.table import Table

from mirar.catalog.base_catalog import BaseCatalog
from mirar.data import Image, ImageBatch
from mirar.errors import ProcessorError
from mirar.paths import BASE_NAME_KEY, copy_temp_file, get_output_dir, get_output_path
from mirar.processors.astromatic.sextractor.sextractor import (
SEXTRACTOR_HEADER_KEY,
Sextractor,
sextractor_checkimg_map,
)
from mirar.processors.base_processor import BaseImageProcessor, PrerequisiteError
from mirar.processors.candidates.utils.regions_writer import write_regions_file
from mirar.utils.ldac_tools import get_table_from_ldac
from mirar.paths import get_output_dir
from mirar.processors.astromatic.sextractor.sextractor import sextractor_checkimg_map
from mirar.processors.astrometry.validate import get_fwhm
from mirar.processors.base_catalog_xmatch_processor import BaseProcessorWithCrossMatch

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -84,7 +77,34 @@ def default_photometric_img_catalog_purifier(catalog: Table, image: Image) -> Ta
return catalog[clean_mask]


class PhotCalibrator(BaseImageProcessor):
def get_maglim(
bkg_rms_image_path: str | Path,
zeropoint: float | list[float],
aperture_radius_pixels: float | list[float],
) -> float:
"""
Function to calculate limiting magnitude
Args:
bkg_rms_image_path:
zeropoint:
aperture_radius_pixels:
Returns:
"""
if isinstance(zeropoint, float):
zeropoint = [zeropoint]
if isinstance(aperture_radius_pixels, float):
aperture_radius_pixels = [aperture_radius_pixels]

zeropoint = np.array(zeropoint, dtype=float)
aperture_radius_pixels = np.array(aperture_radius_pixels, dtype=float)
bkg_rms_image = fits.getdata(bkg_rms_image_path)
bkg_rms_med = np.nanmedian(bkg_rms_image)
noise = bkg_rms_med * np.sqrt(np.pi * aperture_radius_pixels)
maglim = -2.5 * np.log10(5 * noise) + zeropoint
return maglim


class PhotCalibrator(BaseProcessorWithCrossMatch):
"""
Photometric calibrator processor
"""
Expand All @@ -95,23 +115,24 @@ def __init__(
self,
ref_catalog_generator: Callable[[Image], BaseCatalog],
temp_output_sub_dir: str = "phot",
redo: bool = True,
image_photometric_catalog_purifier: Callable[
[Table, Image], Table
] = default_photometric_img_catalog_purifier,
num_matches_threshold: int = 5,
crossmatch_radius_arcsec: float = 1.0,
write_regions: bool = False,
cache: bool = False,
):
super().__init__()
self.redo = redo # What is this for?
self.ref_catalog_generator = ref_catalog_generator
self.temp_output_sub_dir = temp_output_sub_dir
self.image_photometric_catalog_purifier = image_photometric_catalog_purifier
self.cache = cache

super().__init__(
ref_catalog_generator=ref_catalog_generator,
temp_output_sub_dir=temp_output_sub_dir,
crossmatch_radius_arcsec=crossmatch_radius_arcsec,
sextractor_catalog_purifier=image_photometric_catalog_purifier,
write_regions=write_regions,
cache=cache,
required_parameters=REQUIRED_PARAMETERS,
)
self.num_matches_threshold = num_matches_threshold
self.write_regions = write_regions

def __str__(self) -> str:
return "Processor to perform photometric calibration."
Expand All @@ -135,18 +156,12 @@ def calculate_zeropoint(
clean_img_cat: Catalog of sources from image to xmatch with ref_cat
Returns:
"""
ref_coords = SkyCoord(ra=ref_cat["ra"], dec=ref_cat["dec"], unit=(u.deg, u.deg))

clean_img_coords = SkyCoord(
ra=clean_img_cat["ALPHAWIN_J2000"],
dec=clean_img_cat["DELTAWIN_J2000"],
unit=(u.deg, u.deg),
matched_img_cat, matched_ref_cat, _ = self.xmatch_catalogs(
ref_cat=ref_cat,
image_cat=clean_img_cat,
crossmatch_radius_arcsec=self.crossmatch_radius_arcsec,
)

idx, d2d, _ = ref_coords.match_to_catalog_sky(clean_img_coords)
match_mask = d2d < 1.0 * u.arcsec
matched_ref_cat = ref_cat[match_mask]
matched_img_cat = clean_img_cat[idx[match_mask]]
logger.info(
f"Cross-matched {len(matched_img_cat)} sources from catalog to the image."
)
Expand Down Expand Up @@ -227,89 +242,25 @@ def _apply_to_images(
pass

for image in batch:
ref_catalog = self.ref_catalog_generator(image)
ref_cat_path = ref_catalog.write_catalog(image, output_dir=phot_output_dir)
temp_cat_path = copy_temp_file(
output_dir=phot_output_dir, file_path=image[SEXTRACTOR_HEADER_KEY]
)

temp_files = [temp_cat_path]
fwhm_med, _, fwhm_std, med_fwhm_pix, _, _ = self.get_fwhm(temp_cat_path)
image["FWHM_MED"] = fwhm_med
image["FWHM_STD"] = fwhm_std
ref_cat, _, cleaned_img_cat = self.setup_catalogs(image)

ref_cat = get_table_from_ldac(ref_cat_path)
img_cat = get_table_from_ldac(temp_cat_path)
_, _, _, med_fwhm_pix, _, _ = get_fwhm(cleaned_img_cat)

if len(ref_cat) == 0:
err = "No sources found in reference catalog"
logger.error(err)
raise PhotometryReferenceError(err)

clean_img_cat = self.image_photometric_catalog_purifier(img_cat, image)
logger.debug(f"Found {len(clean_img_cat)} clean sources in image.")
logger.debug(f"Found {len(cleaned_img_cat)} clean sources in image.")

if len(clean_img_cat) == 0:
if len(cleaned_img_cat) == 0:
err = "No clean sources found in image"
logger.error(err)
raise PhotometrySourceError(err)

if self.write_regions:
ref_coords = SkyCoord(
ra=ref_cat["ra"], dec=ref_cat["dec"], unit=(u.deg, u.deg)
)

img_coords = SkyCoord(
ra=img_cat["ALPHAWIN_J2000"],
dec=img_cat["DELTAWIN_J2000"],
unit=(u.deg, u.deg),
)

clean_img_coords = SkyCoord(
ra=clean_img_cat["ALPHAWIN_J2000"],
dec=clean_img_cat["DELTAWIN_J2000"],
unit=(u.deg, u.deg),
)

ref_regions_path = get_output_path(
base_name=image.header[BASE_NAME_KEY] + "ref.reg",
dir_root=self.temp_output_sub_dir,
sub_dir=self.night_sub_dir,
)
cleaned_img_regions_path = get_output_path(
base_name=image.header[BASE_NAME_KEY] + "cleaned_img.reg",
dir_root=self.temp_output_sub_dir,
sub_dir=self.night_sub_dir,
)
img_regions_path = get_output_path(
base_name=image.header[BASE_NAME_KEY] + "img.reg",
dir_root=self.temp_output_sub_dir,
sub_dir=self.night_sub_dir,
)

write_regions_file(
regions_path=ref_regions_path,
x_coords=ref_coords.ra.deg,
y_coords=ref_coords.dec.deg,
system="wcs",
region_radius=2.0 / 3600,
)
write_regions_file(
regions_path=cleaned_img_regions_path,
x_coords=clean_img_coords.ra.deg,
y_coords=clean_img_coords.dec.deg,
system="wcs",
region_radius=2.0 / 3600,
)
write_regions_file(
regions_path=img_regions_path,
x_coords=img_coords.ra.deg,
y_coords=img_coords.dec.deg,
system="wcs",
region_radius=2.0 / 3600,
)

zp_dicts = self.calculate_zeropoint(ref_cat, clean_img_cat)
zp_dicts = self.calculate_zeropoint(
ref_cat=ref_cat, clean_img_cat=cleaned_img_cat
)

aperture_diameters = []
zp_values = []
Expand All @@ -328,7 +279,7 @@ def _apply_to_images(

if sextractor_checkimg_map["BACKGROUND_RMS"] in image.header.keys():
logger.info("Calculating limiting magnitudes from background RMS file")
limmags = self.get_maglim(
limmags = get_maglim(
image[sextractor_checkimg_map["BACKGROUND_RMS"]],
zp_values,
np.array(aperture_diameters) / 2.0,
Expand All @@ -340,131 +291,4 @@ def _apply_to_images(
image[f"MAGLIM_{int(diam)}"] = limmags[ind]
image["MAGLIM"] = limmags[-1]

if not self.cache:
for temp_file in temp_files:
temp_file.unlink()
logger.debug(f"Deleted temporary file {temp_file}")

return batch

@staticmethod
def get_fwhm(img_cat_path):
"""
Calculate median FWHM from a ldac path
Args:
img_cat_path:
Returns:
"""
imcat = get_table_from_ldac(img_cat_path)
# TODO: de-hardcode
nemask = (
(imcat["X_IMAGE"] > 50)
& (imcat["X_IMAGE"] < 2000)
& (imcat["Y_IMAGE"] > 50)
& (imcat["Y_IMAGE"] < 2000)
)
imcat = imcat[nemask]
med_fwhm = np.median(imcat["FWHM_WORLD"])
mean_fwhm = np.mean(imcat["FWHM_WORLD"])
std_fwhm = np.std(imcat["FWHM_WORLD"])

med_fwhm_pix = np.median(imcat["FWHM_IMAGE"])
mean_fwhm_pix = np.mean(imcat["FWHM_IMAGE"])
std_fwhm_pix = np.std(imcat["FWHM_IMAGE"])
return med_fwhm, mean_fwhm, std_fwhm, med_fwhm_pix, mean_fwhm_pix, std_fwhm_pix

@staticmethod
def get_maglim(
bkg_rms_image_path: str | Path,
zeropoint: float | list[float],
aperture_radius_pixels: float | list[float],
) -> float:
"""
Function to calculate limiting magnitude
Args:
bkg_rms_image_path:
zeropoint:
aperture_radius_pixels:
Returns:
"""
if isinstance(zeropoint, float):
zeropoint = [zeropoint]
if isinstance(aperture_radius_pixels, float):
aperture_radius_pixels = [aperture_radius_pixels]

zeropoint = np.array(zeropoint, dtype=float)
aperture_radius_pixels = np.array(aperture_radius_pixels, dtype=float)
bkg_rms_image = fits.getdata(bkg_rms_image_path)
bkg_rms_med = np.nanmedian(bkg_rms_image)
noise = bkg_rms_med * np.sqrt(np.pi * aperture_radius_pixels)
maglim = -2.5 * np.log10(5 * noise) + zeropoint
return maglim

def get_sextractor_module(self) -> Sextractor:
"""
Get the Sextractor module from the preceding steps
"""
mask = [isinstance(x, Sextractor) for x in self.preceding_steps]
return np.array(self.preceding_steps)[mask][-1]

def check_prerequisites(
self,
):
mask = [isinstance(x, Sextractor) for x in self.preceding_steps]
if np.sum(mask) < 1:
err = (
f"{self.__module__} requires {Sextractor} as a prerequisite. "
f"However, the following steps were found: {self.preceding_steps}."
)
logger.error(err)
raise PrerequisiteError(err)

sextractor_param_path = self.get_sextractor_module().parameters_name

logger.debug(f"Checking file {sextractor_param_path}")

with open(sextractor_param_path, "rb") as param_file:
sextractor_params = [
x.strip().decode() for x in param_file.readlines() if len(x.strip()) > 0
]
sextractor_params = [
x.split("(")[0] for x in sextractor_params if x[0] not in ["#"]
]

for param in REQUIRED_PARAMETERS:
if param not in sextractor_params:
err = (
f"Missing parameter: {self.__module__} requires {param} to run, "
f"but this parameter was not found in sextractor config file "
f"'{sextractor_param_path}' . "
f"Please add the parameter to this list!"
)
logger.error(err)
raise PrerequisiteError(err)

def get_sextractor_apertures(self) -> list[float]:
"""
Function to extract sextractor aperture sizes from config file
Returns:
"""
sextractor_config_path = self.get_sextractor_module().config

with open(sextractor_config_path, "rb") as sextractor_config_file:
aperture_lines = [
x.decode()
for x in sextractor_config_file.readlines()
if np.logical_and(b"PHOT_APERTURES" in x, x.decode()[0] != "#")
]

if len(aperture_lines) > 1:
err = (
f"The config file {sextractor_config_path} has "
f"multiple entries for PHOT_APERTURES."
)
logger.error(err)
raise ProcessorError(err)

line = aperture_lines[0].replace("PHOT_APERTURES", " ").split("#")[0]

return [float(x) for x in line.split(",") if x not in [""]]

0 comments on commit d048072

Please sign in to comment.