Skip to content

Commit

Permalink
Ready for rb
Browse files Browse the repository at this point in the history
  • Loading branch information
robertdstein committed May 14, 2024
1 parent dd90ea3 commit 1dfa8eb
Show file tree
Hide file tree
Showing 9 changed files with 553 additions and 109 deletions.
19 changes: 18 additions & 1 deletion mirar/pipelines/winter/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
# pylint: disable=duplicate-code
import os

from winterrb.model import WINTERNet

from mirar.catalog.kowalski import PS1, TMASS, Gaia, GaiaBright, PS1SGSc
from mirar.downloader.get_test_data import get_test_data_dir
from mirar.paths import (
Expand Down Expand Up @@ -39,6 +41,7 @@
)
from mirar.pipelines.winter.constants import NXSPLIT, NYSPLIT
from mirar.pipelines.winter.generator import (
apply_rb_to_table,
mask_stamps_around_bright_stars,
select_winter_sky_flat_images,
winter_anet_sextractor_config_path_generator,
Expand Down Expand Up @@ -127,10 +130,12 @@
CustomSourceTableModifier,
ForcedPhotometryDetector,
SourceBatcher,
SourceDebatcher,
SourceLoader,
SourceWriter,
ZOGYSourceDetector,
)
from mirar.processors.sources.machine_learning import Pytorch
from mirar.processors.split import SUB_ID_KEY, SplitImage, SwarpImageSplitter
from mirar.processors.utils import (
CustomImageBatchModifier,
Expand Down Expand Up @@ -682,9 +687,21 @@

load_sources = [
SourceLoader(input_dir_name="candidates"),
SourceBatcher(BASE_NAME_KEY),
]

ml_classify = [
Pytorch(
model=WINTERNet(),
model_weights_url="https://github.com/winter-telescope/winterrb/raw/"
"v1.0.0/models/winterrb_v1_0_0_weights.pth",
apply_to_table=apply_rb_to_table,
),
HeaderEditor(edit_keys="rbversion", values="v1.0.0"),
]

crossmatch_candidates = [
SourceDebatcher(),
XMatch(catalog=TMASS(num_sources=3, search_radius_arcmin=0.5)),
XMatch(catalog=PS1(num_sources=3, search_radius_arcmin=0.5)),
XMatch(catalog=PS1SGSc(num_sources=3, search_radius_arcmin=0.5)),
Expand Down Expand Up @@ -777,7 +794,7 @@

avro_export = avro_write + avro_broadcast

process_candidates = crossmatch_candidates + name_candidates + avro_write
process_candidates = ml_classify + crossmatch_candidates + name_candidates + avro_write

load_avro = [SourceLoader(input_dir_name="preavro")]

Expand Down
1 change: 1 addition & 0 deletions mirar/pipelines/winter/generator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
winter_ref_photometric_catalogs_purifier,
winter_reference_phot_calibrator,
)
from mirar.pipelines.winter.generator.realbogus import apply_rb_to_table
from mirar.pipelines.winter.generator.reduce import (
mask_stamps_around_bright_stars,
select_winter_dome_flats_images,
Expand Down
5 changes: 3 additions & 2 deletions mirar/pipelines/winter/generator/candidates.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,13 +282,14 @@ def winter_candidate_quality_filterer(source_table: SourceBatch) -> SourceBatch:
mask = (
(src_df["nbad"] < 2)
& (src_df["ndethist"] > 0)
# & (src_df["chipsf"] < 3.0)
& (src_df["sumrat"] > 0.7)
& ((src_df["rb"] > 0.5) | pd.isnull(src_df["rb"]))
& (src_df["sumrat"] > 0.6)
& (src_df["fwhm"] < 10.0)
& (src_df["magdiff"] < 1.6)
& (src_df["magdiff"] > -1.0)
& (src_df["mindtoedge"] > 50.0)
& (src_df["isdiffpos"])
& ((src_df["sgscore1"] < 0.5) | pd.isnull(src_df["sgscore1"]))
)
filtered_df = src_df[mask].reset_index(drop=True)

Expand Down
35 changes: 35 additions & 0 deletions mirar/pipelines/winter/generator/realbogus.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
"""
Functions to apply rbscore
"""

import numpy as np
import pandas as pd
import torch
from torch import nn
from winterrb.utils import make_triplet

# ML_KEYS =


def apply_rb_to_table(model: nn.Module, table: pd.DataFrame) -> pd.DataFrame:
"""
Apply the realbogus score to a table of sources
:param model: Pytorch model
:param table: Table of sources
:return: Table of sources with realbogus score
"""

rb_scores = []

for _, row in table.iterrows():
triplet = make_triplet(row, normalize=True)
triplet_reshaped = np.transpose(np.expand_dims(triplet, axis=0), (0, 3, 1, 2))
with torch.no_grad():
outputs = model(torch.from_numpy(triplet_reshaped))

rb_scores.append(float(outputs[0]))

table["rb"] = rb_scores

return table
4 changes: 2 additions & 2 deletions mirar/pipelines/winter/models/_candidates.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ class CandidatesTable(WinterBase): # pylint: disable=too-few-public-methods
# Real/bogus properties

rb = Column(Float, nullable=True)
rbversion = Column(Float, nullable=True)
rbversion = Column(VARCHAR(10), nullable=True)

# Solar system properties

Expand Down Expand Up @@ -284,7 +284,7 @@ class Candidate(BaseDB):
scorr: float = Field(ge=0)

rb: float | None = Field(ge=0, default=None)
rbversion: float | None = Field(ge=0, default=None)
rbversion: str | None = Field(default=None, max_length=10)

ssdistnr: float | None = Field(ge=0, default=None)
ssmagnr: float | None = Field(ge=0, default=None)
Expand Down
5 changes: 5 additions & 0 deletions mirar/processors/sources/machine_learning/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
"""
Module for machine learning models
"""

from mirar.processors.sources.machine_learning.pytorch import Pytorch
Original file line number Diff line number Diff line change
@@ -1,45 +1,46 @@
"""
Module with classes to use apply an ML score
Module with classes to use apply an ML score from pytorch
"""

from typing import Callable
import logging
import pickle
from pathlib import Path
from typing import Callable

import pandas as pd
import requests
import torch
from torch import nn

from mirar.data import SourceBatch
from mirar.paths import ml_models_dir
from mirar.processors.base_processor import BaseSourceProcessor
import tensorflow as tf

logger = logging.getLogger(__name__)


class MLScore(BaseSourceProcessor):
class Pytorch(BaseSourceProcessor):
"""
Class to apply an ML model to a source table
Class to apply a pytorch model to a source table
"""

base_key = "MLScore"
base_key = "pytorch"

def __init__(
self,
model_url: str,
apply_to_row: Callable[[object, pd.Series], pd.Series] = None,
model: nn.Module,
model_weights_url: str,
apply_to_table: Callable[[nn.Module, pd.DataFrame], pd.DataFrame],
):
super().__init__()
self.model_url = model_url
self.model_name = Path(self.model_url).name
self.apply_to_row = apply_to_row
self._model = model
self.model_weights_url = model_weights_url
self.model_name = Path(self.model_weights_url).name
self.apply_to_table = apply_to_table

self._model = None
self.model = None

def __str__(self) -> str:
return (
f"Processor to use ML model {self.model_name} to score sources"
)
return f"Processor to use Pytorch model {self.model_name} to score sources"

def get_ml_path(self) -> Path:
"""
Expand All @@ -54,17 +55,16 @@ def download_model(self):
Download the ML model
"""

url = self.model_url
url = self.model_weights_url
local_path = self.get_ml_path()

logger.info(
f"Downloading model {self.model_name} "
f"from {url} to {local_path}"
f"Downloading model {self.model_name} " f"from {url} to {local_path}"
)

with requests.get(url, stream=True, timeout=120.) as r:
with requests.get(url, stream=True, timeout=120.0) as r:
r.raise_for_status()
with open(local_path, 'wb') as f:
with open(local_path, "wb") as f:
for chunk in r.iter_content(chunk_size=8192):
# If you have chunk encoded response uncomment if
# and set chunk_size parameter to None.
Expand All @@ -78,63 +78,55 @@ def download_model(self):

@staticmethod
def load_model(path):
"""
Function to load a pytorch model dict from a path
:param path: Path to the model
:return: Pytorch model dict
"""
if not path.exists():
err = f"Model {path} not found"
logger.error(err)
raise FileNotFoundError(err)

if path.suffix in [".h5", ".hdf5", ".keras"]:
return tf.keras.models.load_model(path)
if path.suffix == ".pkl":
with open(path, "rb") as model_file:
return pickle.load(model_file)
if path.suffix in [".pth", ".pt"]:
return torch.load(path)

raise ValueError(f"Unknown model type {path.suffix}")

def get_model(self):
"""
Load the ML model. Download it if it doesn't exist.
Load the ML model weights. Download it if it doesn't exist.
:return: ML model
"""

if self._model is None:
if self.model is None:

model = self._model

local_path = self.get_ml_path()

if not local_path.exists():
self.download_model()

self._model = self.load_model(local_path)
model.load_state_dict(torch.load(local_path))
model.eval()

self.model = model

return self._model
return self.model

def _apply_to_sources(
self,
batch: SourceBatch,
) -> SourceBatch:

print("Applying ML model to sources")

model = self.get_model()

print(model)

for source_table in batch:

sources = source_table.get_data()

new = []

for _, source in sources.iterrow():
row = self.apply_to_row(model, source)
new.append(row)

new = pd.DataFrame(new)
new = self.apply_to_table(model, sources)
source_table.set_data(new)

return batch


ml = MLScore(model_url='https://github.com//winter-telescope/winter_rb_models/raw/v1.0.0/models/winterdrb_VGG6_20240410_051006.keras') #FIXME: Update URL
ml.apply(SourceBatch())
Loading

0 comments on commit 1dfa8eb

Please sign in to comment.