Skip to content

Commit

Permalink
refactor pawprint module
Browse files Browse the repository at this point in the history
Signed-off-by: nstarman <nstarman@users.noreply.github.com>
  • Loading branch information
nstarman committed Nov 15, 2023
1 parent 3d84da1 commit 7137e86
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 69 deletions.
9 changes: 9 additions & 0 deletions src/cats/pawprint/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from __future__ import annotations

from . import _core, _footprint
from ._core import *
from ._footprint import *

__all__ = []
__all__ += _core.__all__
__all__ += _footprint.__all__
73 changes: 4 additions & 69 deletions src/cats/pawprint/pawprint.py → src/cats/pawprint/_core.py
Original file line number Diff line number Diff line change
@@ -1,90 +1,25 @@
from __future__ import annotations

__all__ = ["Pawprint"]

import pathlib

import asdf
import astropy.table as apt
import astropy.units as u
import galstreams as gst
import numpy as np
from astropy.coordinates import SkyCoord
from gala.coordinates import GreatCircleICRSFrame
from matplotlib.path import Path as mpl_path

# class densityClass: #TODO: how to represent densities?


class Footprint2D(dict):
def __init__(self, vertex_coordinates, footprint_type, stream_frame=None):
if footprint_type == "sky":
if isinstance(vertex_coordinates, SkyCoord):
vc = vertex_coordinates
else:
vc = SkyCoord(vertex_coordinates)
self.edges = vc
self.vertices = np.array(
[vc.transform_to(stream_frame).phi1, vc.transform_to(stream_frame).phi2]
).T

elif footprint_type == "cartesian":
self.edges = vertex_coordinates
self.vertices = vertex_coordinates

self.stream_frame = stream_frame
self.footprint_type = footprint_type
self.footprint = mpl_path(self.vertices)

@classmethod
def from_vertices(cls, vertex_coordinates, footprint_type):
return cls(vertex_coordinates, footprint_type)

@classmethod
def from_box(cls, min1, max1, min2, max2, footprint_type):
vertices = cls.get_vertices_from_box(min1, max1, min2, max2)
return cls(vertices, footprint_type)

@classmethod
def from_file(cls, fname):
with apt.Table.read(fname) as t:
vertices = t["vertices"]
footprint_type = t["footprint_type"]
return cls(vertices, footprint_type)

def get_vertices_from_box(self, min1, max1, min2, max2):
return [[min1, min2], [min1, max2], [max1, min2], [max1, max2]]

def inside_footprint(self, data):
if isinstance(data, SkyCoord):
if self.stream_frame is None:
print("can't!")
return None
else:
pts = np.array(
[
data.transform_to(self.stream_frame).phi1.value,
data.transform_to(self.stream_frame).phi2.value,
]
).T
return self.footprint.contains_points(pts)
else:
return self.footprint.contains_points(data)

def export(self):
data = {}
data["stream_frame"] = self.stream_frame
data["vertices"] = self.vertices
data["footprint_type"] = self.footprint_type
return data


class Pawprint(dict):
"""Dictionary class to store a "pawprint":
"""Dictionary class to store a "pawprint".
polygons in multiple observational spaces that define the initial selection
used for stream track modeling,
membership calculation / density modeling, and background modeling.
New convention: everything is in phi1 phi2 (don't cross the streams)
"""

def __init__(self, data):
Expand Down
71 changes: 71 additions & 0 deletions src/cats/pawprint/_footprint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
from __future__ import annotations

__all__ = ["Footprint2D"]

import astropy.table as apt
import numpy as np
from astropy.coordinates import SkyCoord
from matplotlib.path import Path as mpl_path


class Footprint2D(dict):
def __init__(self, vertex_coordinates, footprint_type, stream_frame=None):
if footprint_type == "sky":
if isinstance(vertex_coordinates, SkyCoord):
vc = vertex_coordinates
else:
vc = SkyCoord(vertex_coordinates)
self.edges = vc
self.vertices = np.array(
[vc.transform_to(stream_frame).phi1, vc.transform_to(stream_frame).phi2]
).T

elif footprint_type == "cartesian":
self.edges = vertex_coordinates
self.vertices = vertex_coordinates

self.stream_frame = stream_frame
self.footprint_type = footprint_type
self.footprint = mpl_path(self.vertices)

@classmethod
def from_vertices(cls, vertex_coordinates, footprint_type):
return cls(vertex_coordinates, footprint_type)

@classmethod
def from_box(cls, min1, max1, min2, max2, footprint_type):
vertices = cls.get_vertices_from_box(min1, max1, min2, max2)
return cls(vertices, footprint_type)

@classmethod
def from_file(cls, fname):
with apt.Table.read(fname) as t:
vertices = t["vertices"]
footprint_type = t["footprint_type"]
return cls(vertices, footprint_type)

def get_vertices_from_box(self, min1, max1, min2, max2):
return [[min1, min2], [min1, max2], [max1, min2], [max1, max2]]

def inside_footprint(self, data):
if isinstance(data, SkyCoord):
if self.stream_frame is None:
print("can't!")
return None
else:
pts = np.array(
[
data.transform_to(self.stream_frame).phi1.value,
data.transform_to(self.stream_frame).phi2.value,
]
).T
return self.footprint.contains_points(pts)
else:
return self.footprint.contains_points(data)

def export(self):
data = {}
data["stream_frame"] = self.stream_frame
data["vertices"] = self.vertices
data["footprint_type"] = self.footprint_type
return data
File renamed without changes.

0 comments on commit 7137e86

Please sign in to comment.