diff --git a/noxfile.py b/noxfile.py index 06c6c9c..57568db 100644 --- a/noxfile.py +++ b/noxfile.py @@ -1,3 +1,5 @@ +"""Nox Test Sessions.""" + from __future__ import annotations import argparse @@ -13,9 +15,7 @@ @nox.session def lint(session: nox.Session) -> None: - """ - Run the linter. - """ + """Run the linter.""" session.install("pre-commit") session.run( "pre-commit", "run", "--all-files", "--show-diff-on-failure", *session.posargs @@ -25,9 +25,7 @@ def lint(session: nox.Session) -> None: # TODO: turn this on eventually @nox.session def pylint(session: nox.Session) -> None: - """ - Run PyLint. - """ + """Run PyLint.""" # This needs to be installed into the package environment, and is slower # than a pre-commit check session.install(".", "pylint") @@ -36,19 +34,14 @@ def pylint(session: nox.Session) -> None: @nox.session def tests(session: nox.Session) -> None: - """ - Run the unit and regular tests. - """ + """Run the unit and regular tests.""" session.install(".[test]") session.run("pytest", *session.posargs) @nox.session(reuse_venv=True) def docs(session: nox.Session) -> None: - """ - Build the docs. Pass "--serve" to serve. Pass "-b linkcheck" to check links. - """ - + """Build the docs. Pass "--serve" to serve. Pass "-b linkcheck" to check links.""" parser = argparse.ArgumentParser() parser.add_argument("--serve", action="store_true", help="Serve after building") parser.add_argument( @@ -87,10 +80,7 @@ def docs(session: nox.Session) -> None: @nox.session def build_api_docs(session: nox.Session) -> None: - """ - Build (regenerate) API docs. - """ - + """Build (regenerate) API docs.""" session.install("sphinx") session.chdir("docs") session.run( @@ -106,10 +96,7 @@ def build_api_docs(session: nox.Session) -> None: @nox.session def build(session: nox.Session) -> None: - """ - Build an SDist and wheel. - """ - + """Build an SDist and wheel.""" build_path = DIR.joinpath("build") if build_path.exists(): shutil.rmtree(build_path) diff --git a/pyproject.toml b/pyproject.toml index 3d9cecb..3b0993d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -127,33 +127,22 @@ disallow_incomplete_defs = true src = ["src"] [tool.ruff.lint] -extend-select = [ - "B", # flake8-bugbear - "I", # isort - "ARG", # flake8-unused-arguments - "C4", # flake8-comprehensions - "EM", # flake8-errmsg - "ICN", # flake8-import-conventions - "G", # flake8-logging-format - "PGH", # pygrep-hooks - "PIE", # flake8-pie - "PL", # pylint - "PT", # flake8-pytest-style - "PTH", # flake8-use-pathlib - "RET", # flake8-return - "RUF", # Ruff-specific - "SIM", # flake8-simplify - "T20", # flake8-print - "UP", # pyupgrade - "YTT", # flake8-2020 - "EXE", # flake8-executable - "NPY", # NumPy specific rules - "PD", # pandas-vet -] +extend-select = ["ALL"] ignore = [ + "ANN101", # Missing type annotation for self in method + "ANN401", # Dynamically typed expressions (typing.Any) are disallowed in `**kwargs` + "BLE001", # Using bare `except` "COM812", # Missing trailing comma + "D107", # Missing docstring in `__init__` + "D203", # 1 blank line required before class docstring + "D213", # Multi-line docstring summary should start at the second line + "ERA001", # Commented-out code + "FIX002", # Line contains TODO + "N8", # Naming conventions "PD", # Pandas "PLR", # Design related pylint codes + "TD002", # Missing author in TODO + "TD003", # Missing issue link on the line following this TODO # TODO: fix these and remove "ARG001", # Unused function argument `X` "ARG003", # Unused method argument `X` @@ -167,7 +156,8 @@ isort.required-imports = ["from __future__ import annotations"] # typing-modules = ["cats._compat.typing"] [tool.ruff.lint.per-file-ignores] -"tests/**" = ["T20"] +"docs/conf.py" = ["A001", "D100", "INP001"] +"tests/**" = ["ANN", "D1", "INP", "S101", "T20"] "noxfile.py" = ["T20"] "__init__.py" = ["F403"] # TODO: fix these and remove diff --git a/src/cats/__init__.py b/src/cats/__init__.py index 654a467..cc04876 100644 --- a/src/cats/__init__.py +++ b/src/cats/__init__.py @@ -1,5 +1,4 @@ -""" -Copyright (c) 2023 CATS. All rights reserved. +"""Copyright (c) 2023 CATS. All rights reserved. cats: Community Atlas of Tidal Streams """ diff --git a/src/cats/_version.pyi b/src/cats/_version.pyi index 91744f9..5bb2b22 100644 --- a/src/cats/_version.pyi +++ b/src/cats/_version.pyi @@ -1,4 +1,2 @@ -from __future__ import annotations - version: str version_tuple: tuple[int, int, int] | tuple[int, int, int, str, str] diff --git a/src/cats/cmd/_core.py b/src/cats/cmd/_core.py index 3f28e54..e46e2b9 100644 --- a/src/cats/cmd/_core.py +++ b/src/cats/cmd/_core.py @@ -2,11 +2,14 @@ from __future__ import annotations +from typing import TYPE_CHECKING, Any, Callable + import astropy.units as u import matplotlib as mpl import matplotlib.pyplot as plt import numpy as np import scipy +from astropy.convolution import Gaussian1DKernel, convolve from astropy.coordinates import Distance from isochrones.mist import MIST_Isochrone from matplotlib.patches import PathPatch @@ -15,7 +18,13 @@ from ugali.analysis.isochrone import factory as isochrone_factory from cats.inputs import stream_inputs as inputs -from cats.pawprint.pawprint import Footprint2D +from cats.pawprint._footprint import Footprint2D + +if TYPE_CHECKING: + from matplotlib.figure import Figure + from numpy.typing import NDArray + + from cats.pawprint import Pawprint __authors__ = "Ani, Kiyan, Richard" @@ -37,7 +46,7 @@ class Isochrone: Stream multidimensional footprint. """ - def __init__(self, name: str, /, cat, pawprint) -> None: + def __init__(self, name: str, /, cat: Any, pawprint: Pawprint) -> None: self.stream = name self.cat = cat self.pawprint = pawprint @@ -84,10 +93,8 @@ def __init__(self, name: str, /, cat, pawprint) -> None: # otherwise it will just shift to the background self.correct_isochrone() - def sel_sky(self): - """ - Initialising the on-sky polygon mask to return only contained sources. - """ + def sel_sky(self) -> None: + """Initialize the on-sky polygon mask.""" on_poly_patch = mpl.patches.Polygon( self.pawprint.skyprint["stream"].vertices[::50], facecolor="none", @@ -95,25 +102,17 @@ def sel_sky(self): linewidth=2, ) on_points = np.vstack((self.cat["phi1"], self.cat["phi2"])).T - on_mask = on_poly_patch.get_path().contains_points(on_points) - - self.on_skymask = on_mask + self.on_skymask = on_poly_patch.get_path().contains_points(on_points) - def sel_pm(self): - """ - Initialising the proper motions polygon mask to return only contained sources. - """ + def sel_pm(self) -> None: + """Initialize the proper motions polygon mask.""" on_points = np.vstack( (self.cat["pm_phi1_cosphi2_unrefl"], self.cat["pm_phi2_unrefl"]) ).T - on_mask = self.pawprint.pmprint.inside_footprint(on_points) - self.on_pmmask = on_mask - - def sel_pm12(self): - """ - Initialising the proper motions polygon mask to return only contained sources. - """ + self.on_pmmask = self.pawprint.pmprint.inside_footprint(on_points) + def sel_pm12(self) -> None: + """Initialize the proper motions polygon mask.""" on_pm1_points = np.vstack( (self.cat["phi1"], self.cat["pm_phi1_cosphi2_unrefl"]) ).T @@ -126,10 +125,8 @@ def sel_pm12(self): self.on_pm2mask = on_pm2_mask self.on_pm12mask = on_pm1_mask & on_pm2_mask - def generate_isochrone(self): - """ - load an isochrone, LF model for a given metallicity, age, distance - """ + def generate_isochrone(self) -> None: + """Load an isochrone, LF model for a given metallicity, age, distance.""" # Convert feh to z Y_p = 0.245 # Primordial He abundance (WMAP, 2003) c = 1.54 # He enrichment ratio @@ -200,15 +197,20 @@ def generate_isochrone(self): self.mmag_2 = mmag_2 self.mmass_pdf = mmass_pdf - def data_cmd(self, xrange=(-0.5, 1.0), yrange=(15, 22)): - """ - Empirical CMD generated from the input catalogue, with distance gradient accounted for. + def data_cmd( + self, + xrange: tuple[float, float] = (-0.5, 1.0), + yrange: tuple[float, float] = (15, 22), + ) -> None: + """Make Empirical CMD. - ------------------------------------------------------------------ + Empirical CMD generated from the input catalogue, with distance gradient + accounted for. - Parameters: - xrange: Set the range of color values. Default is [-0.5, 1.0]. - yrange: Set the range of magnitude values. Default is [15, 22]. + Parameters + ---------- + xrange, yrange: tuple[float, float] + Set the range of color values. """ tab = self.cat x_bins = np.arange( @@ -218,7 +220,8 @@ def data_cmd(self, xrange=(-0.5, 1.0), yrange=(15, 22)): yrange[0], yrange[1], inputs[self.stream]["bin_sizes"][1] ) # Used 0.2 for Jhelum - # if this is the second runthrough and a proper motion mask already exists, use that instead of the rough one + # if this is the second runthrough and a proper motion mask already + # exists, use that instead of the rough one if self.pawprint.pm1print is not None: data, xedges, yedges = np.histogram2d( (tab[self.data_color1] - tab[self.data_color2])[ @@ -246,21 +249,21 @@ def data_cmd(self, xrange=(-0.5, 1.0), yrange=(15, 22)): self.y_edges = yedges self.CMD_data = data.T - def correct_isochrone(self): - """ + def correct_isochrone(self) -> None: + """Correct the isochrone. + Correlate the 2D histograms from the data and the theoretical isochrone to find the shift in color and magnitude necessary for the best match """ - - signal, xedges, yedges = np.histogram2d( + signal, *_ = np.histogram2d( self.color, self.mag, bins=[self.x_edges, self.y_edges], weights=np.ones(len(self.mag)), ) - signal_counts, xedges, yedges = np.histogram2d( + signal_counts, *_ = np.histogram2d( self.color, self.mag, bins=[self.x_edges, self.y_edges] ) signal = signal / signal_counts @@ -272,24 +275,36 @@ def correct_isochrone(self): self.x_shift = (x - len(ccor2d[0]) / 2.0) * (self.x_edges[1] - self.x_edges[0]) self.y_shift = (y - len(ccor2d) / 2.0) * (self.y_edges[1] - self.y_edges[0]) - def make_poly(self, iso_low, iso_high, maxmag=26, minmag=14): + def make_poly( + self, + iso_low: InterpolatedUnivariateSpline, + iso_high: InterpolatedUnivariateSpline, + maxmag: float = 26, + minmag: float = 14, + ) -> tuple[Any, Any]: + """Generate the CMD polygon mask. + + Parameters + ---------- + iso_low: InterpolatedUnivariateSpline + spline function describing the "left" bound of the theorietical + isochrone + iso_high: InterpolatedUnivariateSpline + spline function describing the "right" bound of the theoretical + isochrone + maxmag: float + faint limit of theoretical isochrone, should be deeper than all data + minmag: float + bright limit of theoretical isochrone, either include just MS and + subgiant branch or whole isochrone + + Returns + ------- + cmd_poly : NDArray + Polygon vertices in CMD space. + cmd_mask : NDArray[bool] + Boolean mask in CMD sapce. """ - Generate the CMD polygon mask. - - ------------------------------------------------------------------ - - Parameters: - iso_low: spline function describing the "left" bound of the theorietical isochrone - iso_high: spline function describing the "right" bound of the theoretical isochrone - maxmag: faint limit of theoretical isochrone, should be deeper than all data - minmag: bright limit of theoretical isochrone, either include just MS and subgiant branch or whole isochrone - - Returns: - cmd_poly: Polygon vertices in CMD space. - cmd_mask: Boolean mask in CMD sapce. - - """ - mag_vals = np.arange(minmag, maxmag, 0.01) col_low_vals = iso_low(mag_vals) col_high_vals = iso_high(mag_vals) @@ -312,10 +327,11 @@ def make_poly(self, iso_low, iso_high, maxmag=26, minmag=14): return cmd_footprint, cmd_mask - def get_tolerance(self, scale_err=1, base_tol=0.075): - """ - Convolving errors to create wider selections near mag limit - Code written by Nora Shipp and adapted by Kiyan Tavangar + def get_tolerance(self, scale_err: float = 1, base_tol: float = 0.075) -> float: + """Convolving errors to create wider selections near mag limit. + + .. codeauthor:: + Nora Shipp, Kiyan Tavangar """ if self.phot_survey == "PS1": offset = 0.00363355415 @@ -334,30 +350,35 @@ def get_tolerance(self, scale_err=1, base_tol=0.075): mu = 23.9127145 scale = 1.09685211 - def err(x): + def err(x: float) -> float: return offset + np.exp((x - mu) / scale) return scale_err * err(self.mag) + base_tol - def simpleSln(self, maxmag=22, scale_err=2, mass_thresh=0.80): - """ - Select the stars that are within the CMD polygon cut - -------------------------------- - Parameters: - - maxmag: faint limit of created CMD polygon, should be deeper than all data - - mass_thresh: upper limit for the theoretical mass that dictates the bright limit of the - theoretical isochrone used for polygon - - coloff: shift in color from theoretical isochrone to data - - magoff: shift in magnitude from theoretical isochrone to data - - Returns: + def simpleSln( + self, maxmag: float = 22, scale_err: float = 2, mass_thresh: float = 0.80 + ) -> tuple[Any, Any, Any, Any]: + """Select the stars that are within the CMD polygon cut. + + Parameters + ---------- + maxmag: float + faint limit of created CMD polygon, should be deeper than all data + scale_err : float + TODO. + mass_thresh : float + TODO. + + Returns + ------- - cmd_poly: vertices of the CMD polygon cut - cmd_mask: bitmask of stars that pass the polygon cut - iso_model: the theoretical isochrone after shifts - - iso_low: the "left" bound of the CMD polygon cut made from theoretical isochrone - - iso_high: the "right" bound of the CMD polygon cut made from theoretical isochrone + - iso_low: the "left" bound of the CMD polygon cut made from theoretical + isochrone + - iso_high: the "right" bound of the CMD polygon cut made from + theoretical isochrone """ - coloff = self.x_shift magoff = self.y_shift ind = self.masses < mass_thresh @@ -384,8 +405,6 @@ def simpleSln(self, maxmag=22, scale_err=2, mass_thresh=0.80): iso_low, iso_high, maxmag, minmag=self.turnoff ) - # self.pawprint.cmd_filters = ... need to specify this since g vs g-r is a specific choice - # self.pawprint.add_cmd_footprint(cmd_footprint, 'g_r', 'g', 'cmdprint') self.pawprint.cmdprint = cmd_footprint self.pawprint.hbprint = hb_print @@ -393,8 +412,9 @@ def simpleSln(self, maxmag=22, scale_err=2, mass_thresh=0.80): return cmd_footprint, self.cmd_mask, hb_print, self.hb_mask, self.pawprint - def make_hb_print(self): - # probably want to incorporate this into cmdprint and have two discontinuous regions + def make_hb_print(self) -> None: + # probably want to incorporate this into cmdprint and have two + # discontinuous regions if self.phot_survey == "PS1": if self.band2 == "i": g_i_0 = np.array([-0.9, -0.6, -0.2, 0.45, 0.6, -0.6, -0.9]) @@ -463,10 +483,10 @@ def make_hb_print(self): return hb_footprint, hb_mask - def plot_CMD(self, scale_err=2): - """ - Plot the shifted isochrone over a 2D histogram of the polygon-selected - data. + def plot_CMD(self, scale_err: float = 2) -> Figure: + """Plot the shifted isochrone. + + Over a 2D histogram of the polygon-selected data. Returns matplotlib Figure. @@ -548,16 +568,15 @@ def plot_CMD(self, scale_err=2): return fig - def convolve_1d(self, probabilities, mag_err): - """ - 1D Gaussian convolution. - - ------------------------------------------------------------------ - - Parameters: - probabilities: - mag_err: Uncertainty in the magnitudes. + def convolve_1d(self, probabilities: NDArray, mag_err: NDArray) -> NDArray: + """1D Gaussian convolution. + Parameters + ---------- + probabilities : NDArray + Probability of the magnitudes. + mag_err : NDArray + Uncertainty in the magnitudes. """ self.probabilities = probabilities self.mag_err = mag_err @@ -568,20 +587,21 @@ def convolve_1d(self, probabilities, mag_err): self.convolved = convolved - def convolve_errors(self, g_errors, r_errors, intr_err=0.1): - """ - - 1D Gaussian convolution of the data with uncertainties. - - ------------------------------------------------------------------ - - Parameters: - g_errors: g magnitude uncertainties. - r_errors: r magnitude uncertainties. - intr_err: Free to set. Default is 0.1. - + def convolve_errors( + self, + g_errors: Callable[[NDArray], NDArray], + r_errors: Callable[[NDArray], NDArray], + intr_err: float = 0.1, + ) -> None: + """1D Gaussian convolution of the data with uncertainties. + + Parameters + ---------- + g_errors, r_errors : Callable[[ndarray], ndarray] + g, r magnitude uncertainties. + intr_err: + Free to set. Default is 0.1. """ - for i in range(len(probabilities)): probabilities[i] = convolve_1d( probabilities[i], @@ -595,11 +615,8 @@ def convolve_errors(self, g_errors, r_errors, intr_err=0.1): self.probabilities = probabilities - def errFn(self): - """ - Generate the errors for the magnitudes? - """ - + def errFn(self) -> None: + """Generate the errors for the magnitudes.""" gerrs = np.zeros(len(self.y_bins)) rerrs = np.zeros(len(self.x_bins)) diff --git a/src/cats/combine_pm_cmd.py b/src/cats/combine_pm_cmd.py index b2710dc..92c8757 100644 --- a/src/cats/combine_pm_cmd.py +++ b/src/cats/combine_pm_cmd.py @@ -1,39 +1,26 @@ +"""Combine PM and CMD cuts.""" + from __future__ import annotations +from typing import Any + import astropy.table as at -import matplotlib as mpl +import matplotlib.pyplot as plt import pandas as pd from cats.cmd.CMD import Isochrone from cats.pawprint.pawprint import Footprint2D, Pawprint -plt = mpl.pyplot - -plt.rc( - "xtick", - top=True, - direction="in", - labelsize=15, -) -plt.rc( - "ytick", - right=True, - direction="in", - labelsize=15, -) -plt.rc( - "font", - family="Arial", -) +plt.rc("xtick", top=True, direction="in", labelsize=15) +plt.rc("ytick", right=True, direction="in", labelsize=15) +plt.rc("font", family="Arial") def generate_isochrone_vertices( - cat, - sky_poly, - pm_poly, - config, -): - """ + cat: Any, sky_poly: Any, pm_poly: Any, config: Any +) -> Any: + """Generate Isochrone Vertices. + Use the generated class to make a new polygon for the given catalog in CMD space given a sky and PM polygon. """ @@ -64,24 +51,18 @@ def generate_isochrone_vertices( def generate_pm_vertices( - cat, - sky_poly, - cmd_poly, - config, -): - """ + cat: Any, sky_poly: Any, cmd_poly: Any, config: Any +) -> list[list[float]]: + """Generate Proper Motion Vertices. + Use the generated class to make a new polygon for the given catalog in PM space given a sky and CMD polygon. """ - return [ - [-7.0, 0.0], - [-5.0, 0.0], - [-5.0, 1.6], - [-7.0, -1.6], - ] + return [[-7.0, 0.0], [-5.0, 0.0], [-5.0, 1.6], [-7.0, -1.6]] -def load_sky_region(fn): +def load_sky_region(fn: Any) -> tuple[list[float], list[float]]: + """Load Sky Region.""" sky_print = [ [-5, -2], [+5, -2], @@ -93,6 +74,7 @@ def load_sky_region(fn): def main() -> int: + """Run Script.""" # load in config file, catalog from filename config = pd.read_json("config.json") cat = at.Table.read(config.streaminfo.cat_fn) diff --git a/src/cats/data.py b/src/cats/data.py index 9ea28b8..019b1f7 100644 --- a/src/cats/data.py +++ b/src/cats/data.py @@ -28,8 +28,11 @@ def make_astro_photo_joined_data( Parameters ---------- gaia_data : `pyia.GaiaData` + The Gaia data. phot_data : `cats.photometry.PhotometricSurvey` + The photometry data. track6d : `galstreams.Track6D` + The stream track. Returns ------- diff --git a/src/cats/inputs.py b/src/cats/inputs.py index 5405ec9..d0d40c2 100644 --- a/src/cats/inputs.py +++ b/src/cats/inputs.py @@ -1,3 +1,5 @@ +"""Stream Configuration Inputs.""" + from __future__ import annotations from typing import Any diff --git a/src/cats/pawprint/__init__.py b/src/cats/pawprint/__init__.py index 0ffc0dd..9b06878 100644 --- a/src/cats/pawprint/__init__.py +++ b/src/cats/pawprint/__init__.py @@ -1,3 +1,5 @@ +"""Pawprint module.""" + from __future__ import annotations from . import _core, _footprint diff --git a/src/cats/pawprint/_core.py b/src/cats/pawprint/_core.py index 7941303..dda8a5a 100644 --- a/src/cats/pawprint/_core.py +++ b/src/cats/pawprint/_core.py @@ -3,6 +3,7 @@ __all__ = ["Pawprint"] import pathlib +from typing import TYPE_CHECKING, Any import asdf import astropy.table as apt @@ -11,6 +12,9 @@ from astropy.coordinates import SkyCoord from gala.coordinates import GreatCircleICRSFrame +if TYPE_CHECKING: + from typing_extensions import Self + class Pawprint(dict): """Dictionary class to store a "pawprint". @@ -22,7 +26,7 @@ class Pawprint(dict): New convention: everything is in phi1 phi2 (don't cross the streams) """ - def __init__(self, data): + def __init__(self, data: dict[str, Any]) -> None: self.stream_name = data["stream_name"] self.pawprint_ID = data["pawprint_ID"] self.stream_frame = data["stream_frame"] @@ -74,9 +78,7 @@ def __init__(self, data): self.track = data["track"] @classmethod - def from_file(cls, fname): - import asdf - + def from_file(cls: type[Self], fname: str) -> Self: data = {} with asdf.open("fname") as a: # first transfer the stuff that goes directly @@ -110,34 +112,32 @@ def from_file(cls, fname): return cls(data) @classmethod - def pawprint_from_galstreams(cls, stream_name, pawprint_ID, width): - def _get_stream_frame_from_file(summary_file): + def pawprint_from_galstreams( + cls: type[Self], stream_name: str, pawprint_ID: Any, width: float + ) -> Self: + def _get_stream_frame_from_file(summary_file: str) -> GreatCircleICRSFrame: t = apt.QTable.read(summary_file) x = {} atts = [x.replace("mid.", "") for x in t.keys() if "mid" in x] - for ( - att - ) in ( - atts - ): # we're effectively looping over skycoords defined for mid here (ra, dec, ...) - x[att] = t[f"mid.{att}"][ - 0 - ] # <- make sure to set it up as a scalar. if not, frame conversions get into trouble + # we're effectively looping over skycoords defined for mid here (ra, + # dec, ...) + for att in atts: + # Make sure to set it up as a scalar. if not, frame conversions + # get into trouble + x[att] = t[f"mid.{att}"][0] mid_point = SkyCoord(**x) x = {} atts = [x.replace("pole.", "") for x in t.keys() if "pole" in x] - for ( - att - ) in ( - atts - ): # we're effectively looping over skycoords defined for pole here (ra, dec, ...) + # we're effectively looping over skycoords defined for pole here + # (ra, dec, ...) + for att in atts: x[att] = t[f"pole.{att}"][0] - # Make sure to set the pole's distance attribute to 1 (zero causes problems, when transforming to stream frame coords) - x["distance"] = ( - 1.0 * u.kpc - ) # it shouldn't matter, but if it's zero it does crazy things + # Make sure to set the pole's distance attribute to 1 (zero causes + # problems, when transforming to stream frame coords) it shouldn't + # matter, but if it's zero it does crazy things + x["distance"] = 1.0 * u.kpc mid_pole = SkyCoord(**x) return GreatCircleICRSFrame(pole=mid_pole, ra0=mid_point.icrs.ra) @@ -160,9 +160,8 @@ def _get_stream_frame_from_file(summary_file): summary_file=summary_file, ) try: - data["width"] = ( - 2 * data["track"].track_width["width_phi2"] - ) # one standard deviation on each side (is this wide enough?) + # one standard deviation on each side (is this wide enough?) + data["width"] = 2 * data["track"].track_width["width_phi2"] except Exception: data["width"] = width data["stream_vertices"] = data["track"].create_sky_polygon_footprint_from_track( @@ -181,7 +180,9 @@ def _get_stream_frame_from_file(summary_file): return cls(data) - def add_cmd_footprint(self, new_footprint, color, mag, name): + def add_cmd_footprint( + self, new_footprint: Any, color: Any, mag: Any, name: str + ) -> None: if self.cmd_filters is None: self.cmd_filters = dict((name, [color, mag])) self.cmdprint = dict((name, new_footprint)) @@ -189,13 +190,13 @@ def add_cmd_footprint(self, new_footprint, color, mag, name): self.cmd_filters[name] = [color, mag] self.cmdprint[name] = new_footprint - def add_pm_footprint(self, new_footprint, name): + def add_pm_footprint(self, new_footprint: Any, name: str) -> None: if self.pmprint is None: self.pmprint = dict((name, new_footprint)) else: self.pmprint[name] = new_footprint - def save_pawprint(self): + def save_pawprint(self) -> None: # WARNING this doesn't save the track yet - need schema # WARNING the stream frame doesn't save right either fname = self.stream_name + self.pawprint_ID + ".asdf" @@ -207,7 +208,6 @@ def save_pawprint(self): "width": self.width, "on_stream": {"sky": self.skyprint["stream"].export()}, "off_stream": self.skyprint["background"].export(), - # 'track':self.track #TODO } if self.cmdprint is not None: tree["on_stream"]["cmd"] = { diff --git a/src/cats/pawprint/_footprint.py b/src/cats/pawprint/_footprint.py index 4d788ac..d18bcd8 100644 --- a/src/cats/pawprint/_footprint.py +++ b/src/cats/pawprint/_footprint.py @@ -2,14 +2,29 @@ __all__ = ["Footprint2D"] +from typing import TYPE_CHECKING, Any + import astropy.table as apt import numpy as np from astropy.coordinates import SkyCoord from matplotlib.path import Path as mpl_path +if TYPE_CHECKING: + from astropy.coordinates import BaseCoordinateFrame + from numpy import bool_ + from numpy.typing import NDArray + from typing_extensions import Self + class Footprint2D(dict): - def __init__(self, vertex_coordinates, footprint_type, stream_frame=None): + """A 2D footprint.""" + + def __init__( + self, + vertex_coordinates: Any, + footprint_type: Any, + stream_frame: BaseCoordinateFrame | None = None, + ) -> None: if footprint_type == "sky": if isinstance(vertex_coordinates, SkyCoord): vc = vertex_coordinates @@ -29,25 +44,36 @@ def __init__(self, vertex_coordinates, footprint_type, stream_frame=None): self.footprint = mpl_path(self.vertices) @classmethod - def from_vertices(cls, vertex_coordinates, footprint_type): + def from_vertices( + cls: type[Self], vertex_coordinates: Any, footprint_type: Any + ) -> Self: return cls(vertex_coordinates, footprint_type) @classmethod - def from_box(cls, min1, max1, min2, max2, footprint_type): + def from_box( + cls: type[Self], + min1: float, + max1: float, + min2: float, + max2: float, + footprint_type: str, + ) -> Self: vertices = cls.get_vertices_from_box(min1, max1, min2, max2) return cls(vertices, footprint_type) @classmethod - def from_file(cls, fname): + def from_file(cls: type[Self], fname: str) -> Self: with apt.Table.read(fname) as t: vertices = t["vertices"] footprint_type = t["footprint_type"] return cls(vertices, footprint_type) - def get_vertices_from_box(self, min1, max1, min2, max2): + def get_vertices_from_box( + self, min1: float, max1: float, min2: float, max2: float + ) -> list[list[float]]: return [[min1, min2], [min1, max2], [max1, min2], [max1, max2]] - def inside_footprint(self, data): + def inside_footprint(self, data: SkyCoord | Any) -> NDArray[bool_] | None: if isinstance(data, SkyCoord): if self.stream_frame is None: print("can't!") @@ -63,7 +89,7 @@ def inside_footprint(self, data): else: return self.footprint.contains_points(data) - def export(self): + def export(self) -> dict[str, Any]: data = {} data["stream_frame"] = self.stream_frame data["vertices"] = self.vertices diff --git a/src/cats/photometry/_base.py b/src/cats/photometry/_base.py index 8c2404f..3aed24d 100644 --- a/src/cats/photometry/_base.py +++ b/src/cats/photometry/_base.py @@ -4,15 +4,17 @@ import abc from dataclasses import dataclass -from typing import ClassVar +from typing import TYPE_CHECKING, ClassVar -import numpy as np -import numpy.typing as npt -from astropy.coordinates import SkyCoord from astropy.table import QTable, Table -from dustmaps.map_base import DustMap from dustmaps.sfd import SFDQuery -from typing_extensions import Self + +if TYPE_CHECKING: + from astropy.coordinates import SkyCoord + from dustmaps.map_base import DustMap + from numpy import bool_ + from numpy.typing import NDArray + from typing_extensions import Self @dataclass(frozen=True) @@ -61,7 +63,7 @@ def get_skycoord(self) -> SkyCoord: """Return a SkyCoord object from the data table.""" @abc.abstractmethod - def get_star_mask(self) -> npt.NDArray[np.bool_]: + def get_star_mask(self) -> NDArray[bool_]: """Star-galaxy separation.""" def get_ext_corrected_phot( diff --git a/src/cats/photometry/_builtin/desi.py b/src/cats/photometry/_builtin/desi.py index 8a6006a..5cedcb3 100644 --- a/src/cats/photometry/_builtin/desi.py +++ b/src/cats/photometry/_builtin/desi.py @@ -2,14 +2,17 @@ __all__ = ["DESY6Phot"] -from typing import ClassVar, TypedDict +from typing import TYPE_CHECKING, ClassVar, TypedDict -import astropy.coordinates as coord import astropy.units as u +from astropy.coordinates import SkyCoord from astropy.table import QTable from cats.photometry._base import AbstractPhotometricSurvey +if TYPE_CHECKING: + from numpy import bool_ + class DESY6BandNames(TypedDict): WAVG_MAG_PSF_G: str @@ -30,13 +33,15 @@ class DESY6Phot(AbstractPhotometricSurvey): extinction_coeffs: ClassVar[DESY6ExtinctionCoeffs] = {"g": 3.237, "r": 2.176} custom_extinction: ClassVar[bool] = True - def get_skycoord(self): - return coord.SkyCoord(self.data["RA"] * u.deg, self.data["DEC"] * u.deg) + def get_skycoord(self) -> SkyCoord: + return SkyCoord(self.data["RA"] * u.deg, self.data["DEC"] * u.deg) - def get_star_mask(self): + def get_star_mask(self) -> NDArray[bool_]: return (self.data["EXT_FITVD"] >= 0) & (self.data["EXT_FITVD"] < 2) - def get_ext_corrected_phot(self, dustmaps_cls=None): + def get_ext_corrected_phot( + self, dustmaps_cls: tuple[Dustmap] | None = None + ) -> QTable: if dustmaps_cls is None: dustmaps_cls = self.dustmaps_cls diff --git a/src/cats/photometry/_builtin/gaia.py b/src/cats/photometry/_builtin/gaia.py index e44e2e0..b43483b 100644 --- a/src/cats/photometry/_builtin/gaia.py +++ b/src/cats/photometry/_builtin/gaia.py @@ -2,7 +2,7 @@ __all__ = ["GaiaDR3Phot"] -from typing import ClassVar, TypedDict +from typing import TYPE_CHECKING, ClassVar, TypedDict import numpy as np from astropy.table import QTable @@ -10,6 +10,12 @@ from cats.photometry._base import AbstractPhotometricSurvey +if TYPE_CHECKING: + from astropy.coordinates import SkyCoord + from dustmaps.map_base import DustMap + from numpy import bool_ + from numpy.typing import NDArray + class GaiaDR3BandNames(TypedDict): phot_g_mean_mag: str @@ -25,13 +31,15 @@ class GaiaDR3Phot(AbstractPhotometricSurvey): } custom_extinction: ClassVar[bool] = True - def get_skycoord(self): + def get_skycoord(self) -> SkyCoord: return GaiaData(self.data).get_skycoord(distance=False) - def get_star_mask(self): + def get_star_mask(self) -> NDArray[bool_]: return np.ones(len(self.data), dtype=bool) - def get_ext_corrected_phot(self, dustmaps_cls=None): + def get_ext_corrected_phot( + self, dustmaps_cls: type[DustMap] | None = None + ) -> QTable: if dustmaps_cls is None: dustmaps_cls = self.dustmaps_cls diff --git a/src/cats/photometry/_builtin/ps1.py b/src/cats/photometry/_builtin/ps1.py index 65b5c80..6602822 100644 --- a/src/cats/photometry/_builtin/ps1.py +++ b/src/cats/photometry/_builtin/ps1.py @@ -2,15 +2,17 @@ __all__ = ["PS1Phot"] -from typing import ClassVar, TypedDict +from typing import TYPE_CHECKING, ClassVar, TypedDict import astropy.units as u -import numpy as np -import numpy.typing as npt from astropy.coordinates import SkyCoord from cats.photometry._base import AbstractPhotometricSurvey +if TYPE_CHECKING: + import numpy.typing as npt + from numpy import bool_ + class PS1BandNames(TypedDict): gMeanPSFMag: str @@ -54,11 +56,10 @@ def get_skycoord(self) -> SkyCoord: frame="icrs", ) - def get_star_mask(self) -> npt.NDArray[np.bool_]: + def get_star_mask(self) -> npt.NDArray[bool_]: """Star/galaxy separation for PS1. - See: - https://outerspace.stsci.edu/display/PANSTARRS/How+to+separate+stars+and+galaxies + See: https://outerspace.stsci.edu/display/PANSTARRS/How+to+separate+stars+and+galaxies Returns ------- diff --git a/src/cats/proper_motions.py b/src/cats/proper_motions.py index 1915217..c57695a 100644 --- a/src/cats/proper_motions.py +++ b/src/cats/proper_motions.py @@ -2,9 +2,13 @@ from __future__ import annotations +from typing import TYPE_CHECKING, Any, Callable + import matplotlib as mpl import matplotlib.pyplot as plt import numpy as np +from astropy.modeling import fitting, models +from matplotlib.colors import LogNorm from mpl_toolkits.axes_grid1 import make_axes_locatable from scipy.interpolate import InterpolatedUnivariateSpline as IUS from scipy.spatial import ConvexHull @@ -12,20 +16,29 @@ from cats.inputs import stream_inputs as inputs from cats.pawprint.pawprint import Footprint2D +if TYPE_CHECKING: + from matplotlib.cm import ScalarMappable + from matplotlib.colorbar import Colorbar + from matplotlib.figure import Figure + from numpy.typing import NDArray + + from cats.pawprint import Pawprint + __author__ = ("Sophia", "Nora", "Nondh", "Lina", "Bruno", "Kiyan") -def rough_pm_poly(pawprint, data, buffer=2): - """ - Will return a polygon with a rough cut in proper motion space. +def rough_pm_poly( + pawprint: Pawprint, data: Any, buffer: float = 2 +) -> tuple[Footprint2D, NDArray[np.bool_]]: + """Will return a polygon with a rough cut in proper motion space. + This aims to be ~100% complete with no thoughts about purity. The goal is to use this cut in conjunction with the cmd cut in order to see the stream as a clear overdensity in (phi_1, phi_2), which will allow membership probability modeling """ - stream_fr = pawprint.track.stream_frame - track = pawprint.track.track.transform_to(stream_fr) - # track_refl = gc.reflex_correct(track) + stream_frame = pawprint.track.stream_frame + track = pawprint.track.track.transform_to(stream_frame) # use the galstream proper motion track track_pm1_min = np.min(track.pm_phi1_cosphi2.value) @@ -52,43 +65,54 @@ def rough_pm_poly(pawprint, data, buffer=2): class ProperMotionSelection: + """Proper Motion Selection. + + Parameters + ---------- + stream: object + galstream object that contains stream's proper motion tracks. + data: object + data dictionary. + pawprint: object + pawprint object that contains stream's sky and proper motion + polygons. + + best_pm_phi1_mean: + best initial guess for mean of pm_phi1. + best_pm_phi2_mean: + best initial guess for mean of pm_phi2. + best_pm_phi1_std: + best initial guess for pm_phi1 standard deviation. + best_pm_phi2_std: + best initial guess for pm_phi2 standard deviation. + n_dispersion_phi1: + float, default set to 1 standard deviation around phi_1. + n_dispersion_phi2: + float, default set to 1 standard deviation around phi_2. + refine_factor: + int, default set to 100, how smooth are the edges of the polygons. + cutoff: + float, in [0,1], cutoff on the height of the pdf to keep the stars + that have a probability to belong to the 2D gaussian above the + cutoff value. + """ + def __init__( self, - stream, - data, - pawprint, - # CMD_mask=True, - # spatial_mask_on=True, - # spatial_mask_off=True, - # pm_phi1_grad = None, # think we should take this from pawprint, or at least make that the default - # pm_phi2_grad = None, - best_pm_phi1_mean=None, - best_pm_phi2_mean=None, - best_pm_phi1_std=None, - best_pm_phi2_std=None, - cutoff=0.95, - n_dispersion_phi1=1, - n_dispersion_phi2=1, - refine_factor=100, - ): - """ - stream_obj: galstream object that contains stream's proper motion tracks - data: - :param: stream_obj: from galstreams so far #TODO: generalize - :param: CMD_mask: Used before the PM - :param: spatial_mask_on: - :param: spatial_mask_off: - :param: best_pm_phi1_mean: best initial guess for mean of pm_phi1 - :param: best_pm_phi2_mean: best initial guess for mean of pm_phi2 - :param: best_pm_phi1_std: best initial guess for pm_phi1 standard deviation - :param: best_pm_phi2_std: best initial guess for pm_phi2 standard deviation - :param: n_dispersion_phi1: float, default set to 1 standard deviation around phi_1 - :param: n_dispersion_phi2: float, default set to 1 standard deviation around phi_2 - :param: refine_factor: int, default set to 100, how smooth are the edges of the polygons - :param: cutoff: float, in [0,1], cutoff on the height of the pdf to keep the stars that have a probability to belong to the 2D gaussian above the cutoff value - """ - - # stream_obj starting as galstream but then should be replaced by best values that we find + stream: Any, + data: Any, + pawprint: Any, + best_pm_phi1_mean: float | None = None, + best_pm_phi2_mean: float | None = None, + best_pm_phi1_std: float | None = None, + best_pm_phi2_std: float | None = None, + cutoff: float = 0.95, + n_dispersion_phi1: int = 1, + n_dispersion_phi2: int = 1, + refine_factor: int = 100, + ) -> None: + # stream_obj starting as galstream but then should be replaced by best + # values that we find self.stream = stream self.stream_obj = pawprint.track self.data = data @@ -99,7 +123,10 @@ def __init__( self.cutoff = cutoff if not (self.cutoff <= 1 and self.cutoff >= 0): - msg = "the value of self.cutoff put in does not make sense! It has to be between 0 and 1" + msg = ( + "the value of self.cutoff put in does not make sense! " + "It has to be between 0 and 1" + ) raise AssertionError(msg) # Get tracks from galstreams with splines @@ -125,7 +152,8 @@ def __init__( # distmod_spl = np.poly1d([2.41e-4, 2.421e-2, 15.001]) # self.dist_mod_correct = distmod_spl(self.cat["phi1"]) - self.dist_mod - # SHOULD THE CMD CUT ALSO MAKE AN OFFSTREAM MASK? MAY BE USEFUL TO MAKE CUTS FOR SOME STREAMS + # SHOULD THE CMD CUT ALSO MAKE AN OFFSTREAM MASK? MAY BE USEFUL TO MAKE + # CUTS FOR SOME STREAMS self.initial_masks() self.pm_phi1_cosphi2 = self.data["pm_phi1_cosphi2_unrefl"][self.mask] self.pm_phi2 = self.data["pm_phi2_unrefl"][self.mask] @@ -138,12 +166,6 @@ def __init__( print(mid_phi1) if best_pm_phi1_mean is None: - # TODO: generalize this later to percentile_values = [16, 50, 84] - - # if self.stream == 'Fjorm-M68': - # self.best_pm_phi1_mean = 1 - # self.best_pm_phi2_mean = 4 - # else: self.best_pm_phi1_mean = spline_pm1(mid_phi1) self.best_pm_phi2_mean = spline_pm2(mid_phi1) @@ -192,12 +214,6 @@ def __init__( self.pm_phi1_cosphi2 = data["pm_phi1_cosphi2_unrefl"][self.mask] self.pm_phi2 = data["pm_phi2_unrefl"][self.mask] - # Plot the ellipse-like cut - # self.plot_pms_scatter(self.data, mask=True, - # n_dispersion_phi1=n_dispersion_phi1, - # n_dispersion_phi2=n_dispersion_phi2) - # self.plot_pm_hist(self.data, pms=[self.best_pm_phi1_mean, self.best_pm_phi2_mean]) - ###################################################### ## PM cut in PM space using PM gradient information ## ###################################################### @@ -233,19 +249,12 @@ def __init__( ) = self.build_pm12_polys_and_masks() self.mask = self.pm1_mask & self.pm2_mask & self.spatial_mask_on & self.CMD_mask - # Plot the cut in (phi1, pm1) and (phi1, pm2) space - # self.plot_pms_scatter(self.data, mask=True, - # n_dispersion_phi1=n_dispersion_phi1, - # n_dispersion_phi2=n_dispersion_phi2) - # self.plot_pm_hist(self.data, pms=[self.best_pm_phi1_mean, self.best_pm_phi2_mean]) - return - def from_galstreams(self): - stream_fr = self.stream_obj.stream_frame - self.track = self.stream_obj.track.transform_to(stream_fr) - # track_refl = gc.reflex_correct(track) - # self.track_refl = track_refl + def from_galstreams(self) -> tuple[IUS, IUS, IUS, IUS]: + """Get tracks from galstreams with splines.""" + stream_frame = self.stream_obj.stream_frame + self.track = self.stream_obj.track.transform_to(stream_frame) self.galstream_phi1 = self.track.phi1.value self.galstream_phi2 = self.track.phi2.value @@ -269,10 +278,6 @@ def from_galstreams(self): spline_pm1 = IUS(self.galstream_phi1, self.galstream_pm_phi1_cosphi2) spline_pm2 = IUS(self.galstream_phi1, self.galstream_pm_phi2) - # spline_phi2 = US(self.galstream_phi1, self.galstream_phi2, k=3, s=len(self.galstream_phi1)/1000) - # spline_pm1 = US(self.galstream_phi1, self.galstream_pm_phi1_cosphi2, k=3, s=len(self.galstream_phi1)/1000) - # spline_pm2 = US(self.galstream_phi1, self.galstream_pm_phi2, k=3, s=len(self.galstream_phi1)/1000) - if self.stream == "GD-1": spline_dist = np.poly1d( [2.41e-4, 2.421e-2, 15.001] @@ -282,10 +287,8 @@ def from_galstreams(self): return spline_phi2, spline_pm1, spline_pm2, spline_dist - def sel_sky(self): - """ - Initialising the on-sky polygon mask to return only contained sources. - """ + def sel_sky(self) -> tuple[NDArray[np.bool_], NDArray[np.bool_]]: + """Initialize the on-sky polygon mask to return only contained sources.""" on_poly_patch = mpl.patches.Polygon( self.pawprint.skyprint["stream"].vertices[::100], facecolor="none", @@ -306,11 +309,11 @@ def sel_sky(self): return on_mask, off_mask - def sel_cmd(self): - """ - Initialising the proper motions polygon mask to return only contained sources. - """ + def sel_cmd(self) -> NDArray[np.bool_]: + """Initialize the proper motions polygon mask. + Set to return only contained sources. + """ mag = inputs[self.stream]["mag"] color1 = inputs[self.stream]["color1"] color2 = inputs[self.stream]["color2"] @@ -323,18 +326,16 @@ def sel_cmd(self): ).T return self.pawprint.cmdprint.inside_footprint(cmd_points) - def initial_masks(self): - """ - Generate the initial spatial, and CMD masks based on the input - """ + def initial_masks(self) -> None: + """Generate the initial spatial, and CMD masks based on the input.""" self.spatial_mask_on, self.spatial_mask_off = self.sel_sky() self.CMD_mask = self.sel_cmd() self.mask = self.spatial_mask_on & self.CMD_mask self.off_mask = self.spatial_mask_off & self.CMD_mask - def rough_pm(self, buffer=2): - """ - Will return a polygon with a rough cut in proper motion space. + def rough_pm(self, buffer: float = 2) -> tuple[NDArray, NDArray[np.bool_]]: + """Will return a polygon with a rough cut in proper motion space. + This aims to be ~100% complete with no thoughts about purity. The goal is to use this cut in conjunction with the cmd cut in order to see the stream as a clear overdensity in (phi_1, phi_2), which will @@ -367,28 +368,45 @@ def rough_pm(self, buffer=2): return self.rough_pm_poly, self.rough_pm_mask @staticmethod - def two_dimensional_gaussian(x, y, x0, y0, sigma_x, sigma_y): - """ - Evaluates a two dimensional gaussian distribution in x, y, with means x0, y0, and dispersions sigma_x and sigma_y - """ + def two_dimensional_gaussian( + x: float, y: float, x0: float, y0: float, sigma_x: float, sigma_y: float + ) -> float: + """Evaluate a two dimensional gaussian distribution. + In x, y, with means x0, y0, and dispersions sigma_x and sigma_y. + """ return np.exp( -((x - x0) ** 2 / (2 * sigma_x**2) + (y - y0) ** 2 / (2 * sigma_y**2)) ) def build_poly_and_mask( - self, n_dispersion_phi1=3, n_dispersion_phi2=3, refine_factor=100 - ): - """ - Builds the mask of the proper motion with n_dispersion around the mean - :param: n_dispersion_phi1: float, default set to 1 standard deviation around phi_1 - :param: n_dispersion_phi2: float, default set to 1 standard deviation around phi_2 - :param: refine_factor: int, default set to 100, how smooth are the edges of the polygons - :param: cutoff: float, in [0,1], cutoff on the height of the pdf to keep the stars that have a probability to belong to the 2D gaussian above the cutoff value - - :output: is a list of points that are the vertices of a polygon + self, + n_dispersion_phi1: float = 3, + n_dispersion_phi2: float = 3, + refine_factor: int = 100, + ) -> tuple[NDArray, NDArray[np.bool_]]: + """Build mask of the proper motion with around the mean. + + Parameters + ---------- + n_dispersion_phi1: float + default set to 1 standard deviation around phi_1 + n_dispersion_phi2: float + default set to 1 standard deviation around phi_2 + refine_factor: int + default set to 100, how smooth are the edges of the polygons + cutoff: float + in [0,1], cutoff on the height of the pdf to keep the stars that + have a probability to belong to the 2D gaussian above the cutoff + value + + Returns + ------- + NDArray + vertices of the polygon. + NDArray[np.bool_] + mask of the polygon. """ - # First generate the 2D histograms pm_phi1_min, pm_phi1_max = ( self.best_pm_phi1_mean - n_dispersion_phi1 * self.best_pm_phi1_std, @@ -436,9 +454,11 @@ def build_poly_and_mask( return self.pm_poly, self.pm_mask - def build_pm12_polys_and_masks(self): - """ - This assumes that galstreams is correct, which is maybe not a great assumption but will work for now. + def build_pm12_polys_and_masks(self) -> tuple[NDArray, NDArray, NDArray, NDArray]: + """Build the pm1 and pm2 polygons and masks. + + This assumes that galstreams is correct, which is maybe not a great + assumption but will work for now. """ self.pm1_poly = np.concatenate( [ @@ -492,14 +512,20 @@ def build_pm12_polys_and_masks(self): return self.pm1_poly, self.pm2_poly, self.pm1_mask, self.pm2_mask - def build_mask(self, data, spline_pm1, spline_pm2, pm_poly): - """ - This builds a mask (i.e. finds the data points satisfying pm constraints) - that does not use the peak fitting used elsewhere. - It relies on splines for pm_phi1_cosphi2 and pm_phi2 vs phi1 which must be given as inputs - Most of the time, these will naturally come from galstreams + def build_mask( + self, + data: dict, + spline_pm1: Callable[[NDArray], NDArray], + spline_pm2: Callable[[NDArray], NDArray], + pm_poly: NDArray, + ) -> NDArray[np.bool_]: + """Build a mask. + + Finds the data points satisfying pm constraints that does not use the + peak fitting used elsewhere. It relies on splines for pm_phi1_cosphi2 + and pm_phi2 vs phi1 which must be given as inputs Most of the time, + these will naturally come from galstreams. """ - pm1_data_corrected = data["pm_phi1_cosphi2_unrefl"] - spline_pm1(data["phi1"]) pm2_data_corrected = data["pm_phi2_unrefl"] - spline_pm2(data["phi1"]) @@ -514,18 +540,35 @@ def build_mask(self, data, spline_pm1, spline_pm2, pm_poly): def plot_pms_scatter( self, - data, - save=True, - mask=False, - n_dispersion_phi1=1, - n_dispersion_phi2=1, - refine_factor=100, - **kwargs, - ): - """ - Plot proper motions on stream and off stream scatter or hist2d plots - :param: save: boolean, whether or not to save the figure - :param: mask: boolean, if true, calls in the mask + data: dict, + n_dispersion_phi1: int = 1, + n_dispersion_phi2: int = 1, + refine_factor: int = 100, + *, + save: bool = True, + mask: bool = False, + **kwargs: Any, + ) -> Figure: + """Plot proper motions on stream and off stream scatter or hist2d plots. + + Parameters + ---------- + data: dict + data dictionary + n_dispersion_phi1, n_dispersion_phi2: int + Passed to :meth:`~ProperMotionSelection.build_poly_and_mask`. + refine_factor: int + Passed to :meth:`~ProperMotionSelection.build_poly_and_mask`. + save: bool, keyword-only + whether or not to save the figure + mask: bool, keyword-only + If true, calls in the mask. + **kwargs: Any + Passed to `~matplotlib.axes.Axes.scatter`. + + Returns + ------- + :class:`~matplotlib.figure.Figure` """ data_on = data[self.mask] data_off = data[self.off_mask] @@ -560,8 +603,8 @@ def plot_pms_scatter( ax[0].set_xlim(-20, 20) ax[0].set_ylim(-20, 20) - ax[0].set_xlabel("$\mu_{\phi_1}$ [mas yr$^{-1}$]") - ax[0].set_ylabel("$\mu_{\phi_2}$ [mas yr$^{-1}$]") + ax[0].set_xlabel(r"$\mu_{\phi_1}$ [mas yr$^{-1}$]") + ax[0].set_ylabel(r"$\mu_{\phi_2}$ [mas yr$^{-1}$]") ax[0].set_title("Stream", fontsize="medium") # resize and fix column name @@ -580,8 +623,8 @@ def plot_pms_scatter( ax[1].set_xlim(-20, 20) ax[1].set_ylim(-20, 20) - ax[1].set_xlabel("$\mu_{\phi_1}$ [mas yr$^{-1}$]") - ax[1].set_ylabel("$\mu_{\phi_2}$ [mas yr$^{-1}$]") + ax[1].set_xlabel(r"$\mu_{\phi_1}$ [mas yr$^{-1}$]") + ax[1].set_ylabel(r"$\mu_{\phi_2}$ [mas yr$^{-1}$]") ax[1].set_title("Off stream", fontsize="medium") fig.tight_layout() @@ -600,20 +643,24 @@ def plot_pms_scatter( def plot_pm_hist( self, - data, - dx=0.5, - norm=1, - save=0, - pms=(None, None), - match_norm=False, - stream_coords=True, - reflex_corr=True, - zero_line=True, - pm_lims=(-20, 20), - **kwargs, - ): - # Code from Nora - + data: dict, + dx: float = 0.5, + norm: float = 1, + save: float = 0, + pms: tuple[None, None] = (None, None), + pm_lims: tuple[float, float] = (-20, 20), + *, + match_norm: bool = False, + stream_coords: bool = True, + reflex_corr: bool = True, + zero_line: bool = True, + **kwargs: Any, + ) -> Figure: + """Plot proper motions on stream and off stream histograms. + + .. codeauthor:: + Nora Shipp + """ data_on = data[self.mask] data_off = data[self.off_mask] @@ -635,7 +682,8 @@ def plot_pm_hist( h1 = np.histogram2d(data_on["PMRA0"], data_on["PMDEC0"], bins)[0] h2 = np.histogram2d(data_off["PMRA0"], data_off["PMDEC0"], bins)[0] - # might need to normalise histogram for different areas of off stream mask for subtraction histogram + # might need to normalise histogram for different areas of off stream + # mask for subtraction histogram h2 *= norm # print h1.sum(), h2.sum() @@ -698,9 +746,9 @@ def plot_pm_hist( **kwargs, ) - colorbar(im1) - colorbar(im2) - colorbar(im3) + _colorbar(im1) + _colorbar(im2) + _colorbar(im3) if (pms[0] is None) or (pms[1] is None): ax1.axvline(self.best_pm[0], ls="--", c="k", lw=1) @@ -764,23 +812,33 @@ def plot_pm_hist( # ========================= added Nov 3 (need checking) ======================== - def find_peak_location(self, data, x_width=3.0, y_width=3.0, draw_histograms=True): - """ - find peak location in the proper motion space - :param: data: list of the stellar parameters to get the peak pm. - :param: x_width: float, half x-size of zoomed region box, default set to 3. - :param: y_width: float, half y-size of zoomed region box, default set to 3. - :param: draw_histograms: print histograms, default set to True - - output: [pm_x_cen, pm_y_cen, x_std, y_std]: array + def find_peak_location( + self, + data: dict, + x_width: float = 3.0, + y_width: float = 3.0, + *, + draw_histograms: bool = True, + ) -> tuple[float, float, float, float]: + """Find peak location in the proper motion space. + + Parameters + ---------- + data : dict[str, np.ndarray] + Stellar parameters to get the peak pm. + x_width, y_width : float, optional + half x,y-size of zoomed region box, default set to 3. + draw_histograms: bool + print histograms, default set to True + + Returns + ------- + pm_x_cen, pm_y_cen, x_std, y_std: array pm_x_cen: peak proper motion in phi1 pm_y_cen: peak proper motion in phi2 x_std: standard deviation proper motion in phi1 y_std: standard deviation proper motion in phi2 """ - from astropy.modeling import fitting, models - from matplotlib.colors import LogNorm - x_center, y_center = self.best_pm_phi1_mean, self.best_pm_phi2_mean print(f"Pre-fitting mean PM values: {x_center}, {y_center}") xmin, xmax, ymin, ymax = ( @@ -814,7 +872,8 @@ def find_peak_location(self, data, x_width=3.0, y_width=3.0, draw_histograms=Tru hist = ( H1 - self.mask.sum() / self.off_mask.sum() * H2 ) # check this scale factor --> self.mask.sum()/self.off_mask.sum() - # Do we want to do based on counts or based on area, since we do expect more counts on stream (but maybe negligible) + # Do we want to do based on counts or based on area, since we do expect + # more counts on stream (but maybe negligible) # fitting 2D gaussian (Code from Ani) # Find overdensity @@ -833,7 +892,6 @@ def find_peak_location(self, data, x_width=3.0, y_width=3.0, draw_histograms=Tru y_stddev=0.5, ) fit_g = fitting.LevMarLSQFitter() - # x,y = np.meshgrid(x_edges[(ind[0]-6):(ind[0]+7)], y_edges[(ind[1]-6):(ind[1]+7)]) x, y = np.meshgrid(x_edges[:-1], y_edges[:-1]) # g = fit_g(g_init, x, y, hist_zoom) @@ -874,8 +932,6 @@ def find_peak_location(self, data, x_width=3.0, y_width=3.0, draw_histograms=Tru axes[2].plot( pm_x_cen + x_std * np.cos(t), pm_y_cen + y_std * np.sin(t), c="green" ) - # axes[2].set_xlim(xmin,xmax) - # axes[2].set_ylim(ymin,ymax) for ax in axes[:2]: ax.plot( @@ -890,10 +946,14 @@ def find_peak_location(self, data, x_width=3.0, y_width=3.0, draw_histograms=Tru self.best_pm_phi1_std = x_std self.best_pm_phi2_std = y_std - return [pm_x_cen, pm_y_cen, x_std, y_std] + return (pm_x_cen, pm_y_cen, x_std, y_std) + + +############################################################################### -def colorbar(mappable): +def _colorbar(mappable: ScalarMappable) -> Colorbar: + """Add a colorbar to the figure.""" ax = mappable.axes fig = ax.figure divider = make_axes_locatable(ax) diff --git a/src/cats/cmd/tests/gd1_testcmd.png b/tests/cmd/gd1_testcmd.png similarity index 100% rename from src/cats/cmd/tests/gd1_testcmd.png rename to tests/cmd/gd1_testcmd.png diff --git a/src/cats/cmd/tests/test_GD1.py b/tests/cmd/test_GD1.py similarity index 100% rename from src/cats/cmd/tests/test_GD1.py rename to tests/cmd/test_GD1.py diff --git a/src/cats/cmd/tests/test_run_GD-1.py b/tests/cmd/test_run_GD-1.py similarity index 100% rename from src/cats/cmd/tests/test_run_GD-1.py rename to tests/cmd/test_run_GD-1.py diff --git a/src/cats/cmd/tests/test_run_Jhelum.py b/tests/cmd/test_run_Jhelum.py similarity index 100% rename from src/cats/cmd/tests/test_run_Jhelum.py rename to tests/cmd/test_run_Jhelum.py diff --git a/src/cats/cmd/tests/test_run_Pal5.py b/tests/cmd/test_run_Pal5.py similarity index 100% rename from src/cats/cmd/tests/test_run_Pal5.py rename to tests/cmd/test_run_Pal5.py diff --git a/tests/pawprint/test_mwe.py b/tests/pawprint/test_mwe.py index 4aa7d83..f814709 100644 --- a/tests/pawprint/test_mwe.py +++ b/tests/pawprint/test_mwe.py @@ -38,9 +38,8 @@ on = stars.makeMask(pawprint, what="sky.stream") # is a function of starlist ax.plot(stars.ra[on], stars.dec[on], ".", ms=2.5, color="C0") -# Create a new polygon footprint off-stream, with a given offset and width, and select field points inside it -# -# off_poly = mwsts[st].create_sky_polygon_footprint_from_track(width=1.*u.deg, phi2_offset=3.5*u.deg) +# Create a new polygon footprint off-stream, with a given offset and width, and +# select field points inside it. off = stars.makeMask(pawprint, what="sky.background") # Plot the off-stream polygon footprint and points selected inside it ax.plot(