Skip to content

Commit

Permalink
Merge branch 'jax'
Browse files Browse the repository at this point in the history
  • Loading branch information
christianhbye committed Jun 14, 2024
2 parents 87de38b + 3d9659d commit 4258545
Show file tree
Hide file tree
Showing 36 changed files with 1,660 additions and 459 deletions.
9 changes: 5 additions & 4 deletions .github/workflows/push.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,19 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.7", "3.8", "3.9", "3.10"]
python-version: ["3.8", "3.9", "3.10"]

steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
cache: 'pip'
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install importlib-metadata==4.13.0
python -m pip install --upgrade "jax[cpu]"
python -m pip install .[dev]
- name: Lint with flake8
run: |
Expand All @@ -30,9 +31,9 @@ jobs:
#mypy ./croissant/
- name: Test with pytest
run: |
pytest --cov=croissant --cov-report=xml tests/
pytest --cov=croissant --cov-report=xml croissant/tests croissant/core/tests croissant/jax/tests
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v3
uses: codecov/codecov-action@v4
with:
token: ${{ secrets.CODECOV_TOKEN }}
fail_ci_if_error: true
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ Overall, this makes CROISSANT a very fast visibility simulator. CROISSANT can th
## Installation
For the latest release, do `pip install croissant-sim` (see https://pypi.org/project/croissant-sim). Git clone this repository for the newest changes (this is under activate development, do so at your own risk!).

To access the JAX features, JAX must also be installed. See the [installation guide](https://github.com/google/jax#installation).

## Demo
Jupyter Notebook: https://nbviewer.org/github/christianhbye/croissant/blob/main/notebooks/example_sim.ipynb

Expand Down
14 changes: 7 additions & 7 deletions croissant/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
__author__ = "Christian Hellum Bye"
__version__ = "3.0.0"
__version__ = "4.0.0"

from . import constants, dpss, sphtransform
from .healpix import Alm, HealpixMap
from .beam import Beam
from .rotations import Rotator
from .simulator import Simulator
from .sky import Sky
from . import constants
from . import core
from . import dpss
from . import jax
from . import utils
from .core import * # noqa F403
10 changes: 5 additions & 5 deletions croissant/constants.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import numpy as np

# nside's for which healpix has computed pixel weights:
PIX_WEIGHTS_NSIDE = (32, 64, 128, 256, 512, 1024, 2048, 4096)
from math import pi, sqrt

# sidereal days in seconds
# https://nssdc.gsfc.nasa.gov/planetary/factsheet/earthfact.html
sidereal_day_earth = 23.9345 * 3600
# https://nssdc.gsfc.nasa.gov/planetary/factsheet/moonfact.html
sidereal_day_moon = 655.720 * 3600

Y00 = 1 / np.sqrt(4 * np.pi) # the 0,0 spherical harmonic function
sidereal_day = {"earth": sidereal_day_earth, "moon": sidereal_day_moon}

Y00 = 1 / sqrt(4 * pi) # the 0,0 spherical harmonic function
PIX_WEIGHTS_NSIDE = (32, 64, 128, 512, 1024, 2048, 4096)
6 changes: 6 additions & 0 deletions croissant/core/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from . import simulator, sphtransform
from .healpix import Alm, HealpixMap
from .beam import Beam
from .rotations import Rotator
from .simulator import Simulator
from .sky import Sky
2 changes: 1 addition & 1 deletion croissant/beam.py → croissant/core/beam.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from healpy import npix2nside, pix2ang
import numpy as np

from .constants import Y00
from ..constants import Y00
from .healpix import Alm
from .sphtransform import map2alm

Expand Down
35 changes: 3 additions & 32 deletions croissant/healpix.py → croissant/core/healpix.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,34 +3,12 @@
from scipy.interpolate import RectSphereBivariateSpline
import warnings

from . import constants
from .. import constants
from ..utils import coord_rep
from .rotations import Rotator
from .sphtransform import alm2map, map2alm


def coord_rep(coord):
"""
Shorthand notation for coordinate systems.
Parameters
----------
coord : str
The name of the coordinate system.
Returns
-------
rep : str
The one-letter shorthand notation for the coordinate system.
"""
coord = coord.upper()
if coord[0] == "E" and coord[1] == "Q":
rep = "C"
else:
rep = coord[0]
return rep


def healpix2lonlat(nside, pix=None):
"""
Compute the longtitudes and latitudes of the pixel centers of a healpix
Expand Down Expand Up @@ -567,14 +545,7 @@ def rot_alm_z(self, phi=None, times=None, world="moon"):
"""
if times is not None:
if world.lower() == "moon":
sidereal_day = constants.sidereal_day_moon
elif world.lower() == "earth":
sidereal_day = constants.sidereal_day_earth
else:
raise ValueError(
f"World must be 'moon' or 'earth', not {world}."
)
sidereal_day = constants.sidereal_day[world]
phi = 2 * np.pi * times / sidereal_day
return self.rot_alm_z(phi=phi, times=None)

Expand Down
64 changes: 4 additions & 60 deletions croissant/rotations.py → croissant/core/rotations.py
Original file line number Diff line number Diff line change
@@ -1,69 +1,12 @@
from astropy.coordinates import AltAz, EarthLocation
import healpy as hp
from lunarsky import LunarTopo, MoonLocation, SkyCoord
from lunarsky import LunarTopo, MoonLocation
import numpy as np

from ..utils import get_rot_mat, rotmat_to_euler
from .sphtransform import map2alm, alm2map


def get_rot_mat(from_frame, to_frame):
"""
Get the rotation matrix that transforms from one frame to another.
Parameters
----------
from_frame : str or astropy frame
The coordinate frame to transform from.
to_frame : str or astropy frame
The coordinate frame to transform to.
Returns
-------
rmat : np.ndarray
The rotation matrix.
"""
# cannot instantiate a SkyCoord with a gaalctic frame from cartesian
from_name = from_frame.name if hasattr(from_frame, "name") else from_frame
if from_name.lower() == "galactic":
from_frame = to_frame
to_frame = "galactic"
return_inv = True
else:
return_inv = False
x, y, z = np.eye(3) # unit vectors
sc = SkyCoord(
x=x, y=y, z=z, frame=from_frame, representation_type="cartesian"
)
rmat = sc.transform_to(to_frame).cartesian.xyz.value
if return_inv:
rmat = rmat.T
return rmat


def rotmat_to_euler(mat):
"""
Convert a rotation matrix to Euler angles in the ZYX convention. This is
sometimes referred to as Tait-Bryan angles X1-Y2-Z3.
Parameters
----------
mat : np.ndarray
The rotation matrix.
Returns
--------
eul : tup
The Euler angles.
"""
beta = np.arcsin(mat[0, 2])
alpha = np.arctan2(mat[1, 2] / np.cos(beta), mat[2, 2] / np.cos(beta))
gamma = np.arctan2(mat[0, 1] / np.cos(beta), mat[0, 0] / np.cos(beta))
eul = (gamma, beta, alpha)
return eul


class Rotator(hp.Rotator):
def __init__(
self,
Expand All @@ -90,7 +33,8 @@ def __init__(
rot : sequence of floats
Euler angles in degrees (or radians if deg=False) describing the
rotation. The order of the angles depends on the value of
eulertype.
``eulertype''. When ``eulertype'' is "ZYX", the angles are
(yaw, -pitch, roll).
coord : sequence of strings
Coordinate systems to rotate between. Supported values are
"G" (galactic), "C" (equatorial), "E" (ecliptic), "M" (MCMF),
Expand Down
60 changes: 4 additions & 56 deletions croissant/simulator.py → croissant/core/simulator.py
Original file line number Diff line number Diff line change
@@ -1,64 +1,10 @@
from astropy import units
from astropy.coordinates import EarthLocation
from copy import deepcopy
from lunarsky import MoonLocation, Time
import matplotlib.pyplot as plt
import numpy as np
import warnings

from . import dpss


def time_array(t_start=None, t_end=None, N_times=None, delta_t=None):
"""
Generate an array of evenly sampled times to run the simulation at.
Parameters
----------
t_start : str or astropy.time.Time
The start time of the simulation.
t_end : str or astropy.time.Time
The end time of the simulation.
N_times : int
The number of times to run the simulation at.
delta_t : float or astropy.units.Quantity
The time step between each time in the simulation.
Returns
-------
times : astropy.time.Time or astropy.units.Quantity
The evenly sampled times to run the simulation at.
"""

if t_start is not None:
t_start = Time(t_start, scale="utc")

try:
dt = np.arange(N_times) * delta_t
except TypeError:
t_end = Time(t_end, scale="utc")
total_time = (t_end - t_start).sec
if N_times is None:
try:
delta_t = delta_t.to_value("s")
except AttributeError:
warnings.warn(
"delta_t is not an astropy.units.Quantity. Assuming "
"units of seconds.",
UserWarning,
)
dt = np.arange(0, total_time + delta_t, delta_t)
else:
dt = np.linspace(0, total_time, N_times)
dt = dt * units.s

if t_start is None:
times = dt
else:
times = t_start + dt

return times
from .. import dpss


class Simulator:
Expand All @@ -73,7 +19,9 @@ def __init__(
times=None,
):
"""
Simulator class. Prepares and runs simulations.
BaseSimulator class. Prepares simulations. End users should use the
subclasses in core/simulator.py and crojax/simulator.py to
instantiate this class and run simulations.
"""
self.world = world.lower()
# set up frequencies to run the simulation at
Expand Down
2 changes: 1 addition & 1 deletion croissant/sky.py → croissant/core/sky.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import numpy as np
from pygdsm import GlobalSkyModel2016 as GSM16
from pygdsm import GlobalSkyModel16 as GSM16
from .healpix import Alm
from .sphtransform import map2alm

Expand Down
77 changes: 77 additions & 0 deletions croissant/core/sphtransform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
import numpy as np
import healpy as hp
from ..constants import PIX_WEIGHTS_NSIDE


def alm2map(alm, nside, lmax=None):
"""
Compute the healpix map from the spherical harmonics coefficients.
Parameters
----------
alm : array-like
The spherical harmonics coefficients in the healpy convention. Shape
([nfreq], hp.Alm.getsize(lmax)).
nside : int
The nside of the output map(s).
lmax : int
The lmax of the spherical harmonics transform. Defaults to 3*nside-1.
Returns
-------
map : np.ndarray
The healpix map. Shape ([nfreq], hp.nside2npix(nside)).
"""
alm = np.array(alm, copy=True)
if alm.ndim == 1:
return hp.alm2map(alm, nside, lmax=lmax)
else:
npix = hp.nside2npix(nside)
nfreqs = alm.shape[0]
hp_map = np.empty((nfreqs, npix))
for i in range(nfreqs):
map_i = hp.alm2map(alm[i], nside, lmax=lmax)
hp_map[i] = map_i
return hp_map


def map2alm(data, lmax=None):
"""
Compute the spherical harmonics coefficents of a healpix map.
Parameters
----------
data : array-like
The healpix map(s). Shape ([nfreq], hp.nside2npix(nside)).
lmax : int
The lmax of the spherical harmonics transform. Defaults to 3*nside-1.
Returns
-------
alm : np.ndarray
The spherical harmonics coefficients in the healpy convention. Shape
([nfreq], hp.Alm.getsize(lmax)).
"""
data = np.array(data, copy=True)
npix = data.shape[-1]
nside = hp.npix2nside(npix)
use_pix_weights = nside in PIX_WEIGHTS_NSIDE
use_ring_weights = not use_pix_weights
kwargs = {
"lmax": lmax,
"pol": False,
"use_weights": use_ring_weights,
"use_pixel_weights": use_pix_weights,
}
if data.ndim == 1:
alm = hp.map2alm(data, **kwargs)
else:
# compute the alms of the first map to determine the size of the array
alm0 = hp.map2alm(data[0], **kwargs)
alm = np.empty((len(data), alm0.size), dtype=alm0.dtype)
alm[0] = alm0
for i in range(1, len(data)):
alm[i] = hp.map2alm(data[i], **kwargs)
return alm
File renamed without changes.
File renamed without changes.
Loading

0 comments on commit 4258545

Please sign in to comment.