Skip to content

Commit

Permalink
RB for WINTER (#893)
Browse files Browse the repository at this point in the history
  • Loading branch information
robertdstein authored May 15, 2024
1 parent b640145 commit da49e55
Show file tree
Hide file tree
Showing 11 changed files with 673 additions and 64 deletions.
24 changes: 24 additions & 0 deletions .github/workflows/continuous_integration.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,27 @@ jobs:
# Steps represent a sequence of tasks that will be executed as part of the job
steps:
# Checks-out your repository under $GITHUB_WORKSPACE, so your job can access it
- name: Print disk space
run: df -h

- name: Free Disk Space (Ubuntu)
uses: jlumbroso/free-disk-space@main
with:
# this might remove tools that are actually needed,
# if set to "true" but frees about 6 GB
tool-cache: false

# all of these default to true, but feel free to set to
# "false" if necessary for your workflow
android: true
dotnet: true
haskell: true
large-packages: false
docker-images: true
swap-storage: true

- name: Print disk space
run: df -h

- uses: actions/checkout@v3

Expand Down Expand Up @@ -89,6 +110,9 @@ jobs:
make -C q3c
sudo make -C q3c install
- name: Print disk space
run: df -h

# First make sure the doc tests are up to date
- name: Run doc tests
shell: bash -el {0}
Expand Down
7 changes: 5 additions & 2 deletions mirar/paths.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
"The raw data directory will need to be specified manually for path function."
f"The raw directory is being set to {default_dir}."
)
logger.warning(warning)
logger.info(warning)
base_raw_dir: Path = default_dir
else:
base_raw_dir = Path(_base_raw_dir)
Expand All @@ -58,7 +58,7 @@
f"Run 'export OUTPUT_DATA_DIR=/path/to/data' to set this. "
f"The output directory is being set to {default_dir}."
)
logger.warning(warning)
logger.info(warning)
base_output_dir = default_dir
else:
base_output_dir = Path(_base_output_dir)
Expand All @@ -70,6 +70,9 @@
RAW_IMG_SUB_DIR = "raw"
CAL_OUTPUT_SUB_DIR = "calibration"

ml_models_dir = base_output_dir.joinpath("ml_models")
ml_models_dir.mkdir(exist_ok=True)


def raw_img_dir(
sub_dir: str = "", raw_dir: Path = base_raw_dir, img_sub_dir: str = RAW_IMG_SUB_DIR
Expand Down
19 changes: 18 additions & 1 deletion mirar/pipelines/winter/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
# pylint: disable=duplicate-code
import os

from winterrb.model import WINTERNet

from mirar.catalog.kowalski import PS1, TMASS, Gaia, GaiaBright, PS1SGSc
from mirar.downloader.get_test_data import get_test_data_dir
from mirar.paths import (
Expand Down Expand Up @@ -39,6 +41,7 @@
)
from mirar.pipelines.winter.constants import NXSPLIT, NYSPLIT
from mirar.pipelines.winter.generator import (
apply_rb_to_table,
mask_stamps_around_bright_stars,
select_winter_sky_flat_images,
winter_anet_sextractor_config_path_generator,
Expand Down Expand Up @@ -127,10 +130,12 @@
CustomSourceTableModifier,
ForcedPhotometryDetector,
SourceBatcher,
SourceDebatcher,
SourceLoader,
SourceWriter,
ZOGYSourceDetector,
)
from mirar.processors.sources.machine_learning import Pytorch
from mirar.processors.split import SUB_ID_KEY, SplitImage, SwarpImageSplitter
from mirar.processors.utils import (
CustomImageBatchModifier,
Expand Down Expand Up @@ -691,9 +696,21 @@

load_sources = [
SourceLoader(input_dir_name="candidates"),
SourceBatcher(BASE_NAME_KEY),
]

# Real/bogus classification step: a Pytorch processor built around a WINTERNet
# instance, followed by a header edit that records which model version was used.
# NOTE: the version tag appears twice below (in the weights URL and in the
# "rbversion" header value); both currently read v1.0.0 and should be changed
# together when the model is updated.
ml_classify = [
Pytorch(
model=WINTERNet(),
model_weights_url="https://github.com/winter-telescope/winterrb/raw/"
"v1.0.0/models/winterrb_v1_0_0_weights.pth",
apply_to_table=apply_rb_to_table,
),
HeaderEditor(edit_keys="rbversion", values="v1.0.0"),
]

crossmatch_candidates = [
SourceDebatcher(),
XMatch(catalog=TMASS(num_sources=3, search_radius_arcmin=0.5)),
XMatch(catalog=PS1(num_sources=3, search_radius_arcmin=0.5)),
XMatch(catalog=PS1SGSc(num_sources=3, search_radius_arcmin=0.5)),
Expand Down Expand Up @@ -786,7 +803,7 @@

avro_export = avro_write + avro_broadcast

process_candidates = crossmatch_candidates + name_candidates + avro_write
process_candidates = ml_classify + crossmatch_candidates + name_candidates + avro_write

load_avro = [SourceLoader(input_dir_name="preavro")]

Expand Down
1 change: 1 addition & 0 deletions mirar/pipelines/winter/generator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
winter_ref_photometric_catalogs_purifier,
winter_reference_phot_calibrator,
)
from mirar.pipelines.winter.generator.realbogus import apply_rb_to_table
from mirar.pipelines.winter.generator.reduce import (
mask_stamps_around_bright_stars,
select_winter_dome_flats_images,
Expand Down
5 changes: 3 additions & 2 deletions mirar/pipelines/winter/generator/candidates.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,13 +282,14 @@ def winter_candidate_quality_filterer(source_table: SourceBatch) -> SourceBatch:
mask = (
(src_df["nbad"] < 2)
& (src_df["ndethist"] > 0)
# & (src_df["chipsf"] < 3.0)
& (src_df["sumrat"] > 0.7)
& ((src_df["rb"] > 0.5) | pd.isnull(src_df["rb"]))
& (src_df["sumrat"] > 0.6)
& (src_df["fwhm"] < 10.0)
& (src_df["magdiff"] < 1.6)
& (src_df["magdiff"] > -1.0)
& (src_df["mindtoedge"] > 50.0)
& (src_df["isdiffpos"])
& ((src_df["sgscore1"] < 0.5) | pd.isnull(src_df["sgscore1"]))
)
filtered_df = src_df[mask].reset_index(drop=True)

Expand Down
33 changes: 33 additions & 0 deletions mirar/pipelines/winter/generator/realbogus.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
"""
Functions to apply rbscore
"""

import numpy as np
import pandas as pd
import torch
from torch import nn
from winterrb.utils import make_triplet


def apply_rb_to_table(model: nn.Module, table: pd.DataFrame) -> pd.DataFrame:
    """
    Score every row of a source table with a real/bogus model

    Each row is converted to a triplet with ``make_triplet``, run through the
    model with gradients disabled, and the resulting score is stored in a new
    (or overwritten) ``rb`` column. The input table is modified in place and
    also returned.

    :param model: Pytorch model used to score each source
    :param table: Table of sources
    :return: The same table, with the ``rb`` column filled in
    """

    def _score_row(source_row) -> float:
        """Return the model output for a single source row as a float"""
        cutouts = make_triplet(source_row, normalize=True)
        # Add a leading axis of length one, then move the last axis to
        # position 1 (presumably channels-last -> channels-first; confirm
        # against make_triplet's output layout).
        with_batch_axis = np.expand_dims(cutouts, axis=0)
        model_input = np.transpose(with_batch_axis, (0, 3, 1, 2))
        with torch.no_grad():
            prediction = model(torch.from_numpy(model_input))
        return float(prediction[0])

    table["rb"] = [_score_row(source_row) for _, source_row in table.iterrows()]

    return table
4 changes: 2 additions & 2 deletions mirar/pipelines/winter/models/_candidates.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ class CandidatesTable(WinterBase): # pylint: disable=too-few-public-methods
# Real/bogus properties

rb = Column(Float, nullable=True)
rbversion = Column(Float, nullable=True)
rbversion = Column(VARCHAR(10), nullable=True)

# Solar system properties

Expand Down Expand Up @@ -284,7 +284,7 @@ class Candidate(BaseDB):
scorr: float = Field(ge=0)

rb: float | None = Field(ge=0, default=None)
rbversion: float | None = Field(ge=0, default=None)
rbversion: str | None = Field(default=None, max_length=10)

ssdistnr: float | None = Field(ge=0, default=None)
ssmagnr: float | None = Field(ge=0, default=None)
Expand Down
5 changes: 5 additions & 0 deletions mirar/processors/sources/machine_learning/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
"""
Module for machine learning models
"""

from mirar.processors.sources.machine_learning.pytorch import Pytorch
132 changes: 132 additions & 0 deletions mirar/processors/sources/machine_learning/pytorch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
"""
Module with classes to apply an ML score from pytorch
"""

import logging
from pathlib import Path
from typing import Callable

import pandas as pd
import requests
import torch
from torch import nn

from mirar.data import SourceBatch
from mirar.paths import ml_models_dir
from mirar.processors.base_processor import BaseSourceProcessor

logger = logging.getLogger(__name__)


class Pytorch(BaseSourceProcessor):
    """
    Class to apply a pytorch model to a source table

    The model weights are downloaded from ``model_weights_url`` the first time
    they are needed and stored in ``ml_models_dir``. Later calls reuse the
    stored file, and the loaded model is kept on the processor instance.
    """

    base_key = "pytorch"

    def __init__(
        self,
        model: nn.Module,
        model_weights_url: str,
        apply_to_table: Callable[[nn.Module, pd.DataFrame], pd.DataFrame],
    ):
        """
        :param model: Pytorch module that the downloaded state dict is loaded into
        :param model_weights_url: URL of the weights file; its last path
            component is used as the local file name
        :param apply_to_table: Function called with the loaded model and each
            source dataframe; its return value replaces the dataframe
        """
        super().__init__()
        self._model = model
        self.model_weights_url = model_weights_url
        self.model_name = Path(self.model_weights_url).name
        self.apply_to_table = apply_to_table

        # Filled in lazily by get_model(), once the weights have been loaded
        self.model = None

    def __str__(self) -> str:
        return f"Processor to use Pytorch model {self.model_name} to score sources"

    def get_ml_path(self) -> Path:
        """
        Get the path to the ML model

        :return: Path to the ML model
        """
        return ml_models_dir.joinpath(self.model_name)

    def download_model(self):
        """
        Download the ML model weights to the local model directory

        The response is streamed into a temporary ``.part`` file next to the
        final path and only renamed into place once the download has finished.
        An interrupted or failed download therefore never leaves a truncated
        file at the final path, which get_model() would otherwise accept as a
        valid copy on every later run.

        :raises requests.HTTPError: If the server returns an error status
        :raises FileNotFoundError: If the file is missing after the download
        """

        url = self.model_weights_url
        local_path = self.get_ml_path()
        partial_path = local_path.with_name(local_path.name + ".part")

        logger.info(f"Downloading model {self.model_name} from {url} to {local_path}")

        try:
            with requests.get(url, stream=True, timeout=120.0) as response:
                response.raise_for_status()
                with open(partial_path, "wb") as handle:
                    for chunk in response.iter_content(chunk_size=8192):
                        handle.write(chunk)
            # Atomic on the same filesystem: the final path either does not
            # exist or holds the complete file
            partial_path.replace(local_path)
        finally:
            # No-op after a successful rename; removes the leftover otherwise
            partial_path.unlink(missing_ok=True)

        if not local_path.exists():
            err = f"Model {self.model_name} not downloaded"
            logger.error(err)
            raise FileNotFoundError(err)

    @staticmethod
    def load_model(path: Path):
        """
        Function to load a pytorch model dict from a path

        :param path: Path to the model
        :return: Pytorch model dict
        :raises FileNotFoundError: If the path does not exist
        :raises ValueError: If the file suffix is not ``.pth`` or ``.pt``
        """
        if not path.exists():
            err = f"Model {path} not found"
            logger.error(err)
            raise FileNotFoundError(err)

        if path.suffix in [".pth", ".pt"]:
            return torch.load(path)

        raise ValueError(f"Unknown model type {path.suffix}")

    def get_model(self):
        """
        Load the ML model weights. Download it if it doesn't exist.

        The model is loaded once and then reused for subsequent calls.

        :return: ML model, with weights loaded and set to evaluation mode
        """

        if self.model is None:

            model = self._model

            local_path = self.get_ml_path()

            if not local_path.exists():
                self.download_model()

            # NOTE(review): torch.load unpickles the file, and the weights come
            # from a remote URL. Consider weights_only=True if the minimum
            # supported torch version provides it.
            # map_location="cpu" lets weights that were saved from a GPU be
            # read on a CPU-only machine; load_state_dict then copies the
            # values into the model's own parameters.
            state_dict = torch.load(local_path, map_location="cpu")
            model.load_state_dict(state_dict)
            model.eval()

            self.model = model

        return self.model

    def _apply_to_sources(
        self,
        batch: SourceBatch,
    ) -> SourceBatch:
        """
        Apply the model to every source table in the batch

        :param batch: Batch of source tables
        :return: The same batch, with each table's data replaced by the
            output of ``apply_to_table``
        """

        model = self.get_model()

        for source_table in batch:
            sources = source_table.get_data()
            new = self.apply_to_table(model, sources)
            source_table.set_data(new)

        return batch
Loading

0 comments on commit da49e55

Please sign in to comment.