Skip to content

Commit

Permalink
Fix cli
Browse files Browse the repository at this point in the history
  • Loading branch information
robertdstein committed Jun 14, 2024
1 parent 21fb572 commit 00c517e
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 38 deletions.
34 changes: 33 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,38 @@ pre-commit install
Set the directory

## Get Started

You can use winternlc directly from the command line.

```bash
winternlc-apply /path/to/data.fits
```
winter_corrections/example.py

This will apply the nonlinearity correction to the data and save the corrected data to a new file.

You can also run the correction on multiple files at once.

```bash
winternlc-apply /path/to/data1.fits /path/to/data2.fits
```

Alternatively, you can specify a directory and all the files in the directory will be corrected.

```bash
winternlc-apply /path/to/directory
```

In all cases, you can also specify the output directory.

```bash
winternlc-apply /path/to/data.fits --output-dir /path/to/output
```

If you do not specify an output directory, the corrected files will be saved in the same directory as the input files.

See the help message for more information.

```bash
winternlc --help
```

3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ dev = [
[project.urls]
Homepage = "https://github.com/winter-telescope/winternlc"

[project.scripts]
winternlc-apply = "winternlc.apply:nlc_cli"

[tool.setuptools]
packages = ["winternlc"]

Expand Down
93 changes: 57 additions & 36 deletions winternlc/apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,28 +3,27 @@
from pathlib import Path

from astropy.io import fits

from winternlc.config import (
corrections_dir,
DEFAULT_CUTOFF,
EXAMPLE_IMG_PATH,
example_data_dir,
)
import logging
from winternlc.mask import mask_single
from winternlc.non_linear_correction import nlc_single
from winternlc.config import DEFAULT_CUTOFF, corrections_dir
import argparse

logger = logging.getLogger(__name__)


def apply_nlc_mef(
fits_file: str | Path, cor_dir: str | Path, save_dir: str | Path, cutoff: float
fits_file: Path, save_dir: Path | str | None = None, cor_dir: Path = corrections_dir, cutoff: float = DEFAULT_CUTOFF, output_suffix: str | None = "corrected_"
):
"""
Process a multi-extension FITS file, applying nonlinearity correction to each
extension, and write the corrected FITS file to disk.
:param fits_file: Path to the FITS file
:param cor_dir: Directory containing the correction files
:param save_dir: Directory to save the corrected FITS file
:param cutoff: Cutoff value for the image
:param cor_dir: Directory containing the correction files
:param cutoff: Cutoff value for the correction
:param output_suffix: Suffix to append to the output file name
:return: None
"""
Expand All @@ -33,19 +32,13 @@ def apply_nlc_mef(
header = hdul[ext].header
image = hdul[ext].data
board_id = header.get("BOARD_ID", None)
if board_id is not None:
print(f"Processing extension {ext} with BOARD_ID {board_id}")
start = time.time()
corrected_image = nlc_single(image, board_id, cor_dir, cutoff)
end = time.time()
print(f"took {end-start} s to execute")
hdul[ext].data = corrected_image
else:
print(f"Skipping extension {ext} as it does not have a BOARD_ID")
corrected_image = nlc_single(image, board_id, cor_dir, cutoff)
hdul[ext].data = corrected_image

corrected_fits_file = os.path.join(
save_dir, "corrected_" + os.path.basename(fits_file)
)
if save_dir is None:
save_dir = Path(fits_file).parent

corrected_fits_file = save_dir / f"{output_suffix}{fits_file.name}"
hdul.writeto(corrected_fits_file, overwrite=True)
print(f"Corrected FITS file saved to {corrected_fits_file}")

Expand All @@ -66,25 +59,53 @@ def apply_mask_mef(fits_file: str | Path, cor_dir: str | Path, save_dir: str | P
header = hdul[ext].header
image = hdul[ext].data
board_id = header.get("BOARD_ID", None)
if board_id is not None:
print(f"Masking extension {ext} with BOARD_ID {board_id}")
start = time.time()
corrected_image = mask_single(image, board_id, cor_dir)
end = time.time()
print(f"took {end-start} s to execute")
hdul[ext].data = corrected_image
else:
print(f"Skipping extension {ext} as it does not have a BOARD_ID")
corrected_image = mask_single(image, board_id, cor_dir)
hdul[ext].data = corrected_image

corrected_fits_file = os.path.join(
save_dir, "masked_" + os.path.basename(fits_file)
)
hdul.writeto(corrected_fits_file, overwrite=True)
print(f"Corrected FITS file saved to {corrected_fits_file}")
logger.info(f"Corrected FITS file saved to {corrected_fits_file}")


if __name__ == "__main__":
apply_nlc_mef(
EXAMPLE_IMG_PATH, corrections_dir, example_data_dir, DEFAULT_CUTOFF
def nlc_cli():
"""
Command-line interface for applying non-linearity
correction to multi-extension FITS file(s)
"""
parser = argparse.ArgumentParser(
description="Apply non-linearity correction to multi-extension FITS file"
)
apply_mask_mef(EXAMPLE_IMG_PATH, corrections_dir, example_data_dir)
parser.add_argument(
"-o", "--output_dir", default=None,
type=str,
help="Directory to save the corrected FITS file(s) "
"(default: same as input file)",
)
parser.add_argument("files", nargs="+", help="FITS file(s) to correct")

args = parser.parse_args()

logger.info("Applying non-linearity correction to multi-extension FITS file")

file_paths = []

for f_name in args.files:
path = Path(f_name)
if not path.exists():
raise FileNotFoundError(f"File not found at {path}")
elif path.is_dir():
file_paths.extend(list(path.glob("*.fits")))
else:
file_paths.append(path)

for path in file_paths:
apply_nlc_mef(
fits_file=path,
save_dir=args.output_dir,
)


if __name__ == "__main__":
nlc_cli()
2 changes: 1 addition & 1 deletion winternlc/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
_corrections_dir = os.getenv("WINTERNLC_DIR")
if _corrections_dir is None:
corrections_dir = Path.home() / "Data/winternlc/"
logger.warning(f"No data directory set, using {corrections_dir}")
logger.warning(f"No correction data directory set, using {corrections_dir}")
else:
corrections_dir = Path(_corrections_dir)

Expand Down

0 comments on commit 00c517e

Please sign in to comment.