Merge pull request #139 from Starfish-develop/ml/scaling
mileslucas authored Jun 28, 2021
2 parents 01852a1 + a117921 commit 77e6d13
Showing 7 changed files with 256 additions and 177 deletions.
2 changes: 1 addition & 1 deletion Starfish/__init__.py
@@ -1,4 +1,4 @@
__version__ = "0.3.2"
__version__ = "0.4.0"

from .spectrum import Spectrum

37 changes: 34 additions & 3 deletions Starfish/emulator/emulator.py
@@ -6,6 +6,7 @@
import h5py
import numpy as np
from nptyping import NDArray
from scipy.interpolate import LinearNDInterpolator
from scipy.linalg import cho_factor, cho_solve
from scipy.optimize import minimize
from sklearn.decomposition import PCA
@@ -75,6 +76,7 @@ def __init__(
w_hat: NDArray[float],
flux_mean: NDArray[float],
flux_std: NDArray[float],
factors: NDArray[float],
lambda_xi: float = 1.0,
variances: Optional[NDArray[float]] = None,
lengthscales: Optional[NDArray[float]] = None,
@@ -88,6 +90,10 @@
self.eigenspectra = eigenspectra
self.flux_mean = flux_mean
self.flux_std = flux_std
self.factors = factors
self.factor_interpolator = LinearNDInterpolator(
grid_points, factors, rescale=True
)

self.dv = calculate_dv(wavelength)
self.ncomps = eigenspectra.shape[0]
@@ -198,6 +204,7 @@ def load(cls, filename: Union[str, os.PathLike]):
flux_mean = base["flux_mean"][:]
flux_std = base["flux_std"][:]
w_hat = base["w_hat"][:]
factors = base["factors"][:]
lambda_xi = base["hyperparameters"]["lambda_xi"][()]
variances = base["hyperparameters"]["variances"][:]
lengthscales = base["hyperparameters"]["lengthscales"][:]
@@ -220,6 +227,7 @@ def load(cls, filename: Union[str, os.PathLike]):
variances=variances,
lengthscales=lengthscales,
name=name,
factors=factors,
)
emulator._trained = trained
return emulator
@@ -251,6 +259,7 @@ def save(self, filename: Union[str, os.PathLike]):
base.attrs["trained"] = self._trained
if self.name is not None:
base.attrs["name"] = self.name
base.create_dataset("factors", data=self.factors, compression=9)
hp_group = base.create_group("hyperparameters")
hp_group.create_dataset("lambda_xi", data=self.lambda_xi)
hp_group.create_dataset("variances", data=self.variances, compression=9)
@@ -283,7 +292,8 @@ def from_grid(cls, grid, **pca_kwargs):

fluxes = np.array(list(grid.fluxes))
# Normalize to an average of 1 to remove uninteresting correlation
fluxes /= fluxes.mean(1, keepdims=True)
norm_factors = fluxes.mean(1)
fluxes /= norm_factors[:, np.newaxis]
# Center and whiten
flux_mean = fluxes.mean(0)
fluxes -= flux_mean
@@ -313,6 +323,7 @@ def from_grid(cls, grid, **pca_kwargs):
w_hat=w_hat,
flux_mean=flux_mean,
flux_std=flux_std,
factors=norm_factors,
)
return emulator

@@ -391,7 +402,7 @@ def bulk_fluxes(self) -> NDArray[float]:
return np.vstack([self.eigenspectra, self.flux_mean, self.flux_std])

def load_flux(
self, params: Union[Sequence[float], NDArray[float]]
self, params: Union[Sequence[float], NDArray[float]], norm=False
) -> NDArray[float]:
"""
Interpolate a model given any parameters within the grid's parameter range
@@ -410,7 +421,27 @@ def load_flux(
mu, cov = self(params, reinterpret_batch=False)
weights = np.random.multivariate_normal(mu, cov).reshape(-1, self.ncomps)
X = self.eigenspectra * self.flux_std
return weights @ X + self.flux_mean
flux = weights @ X + self.flux_mean
if norm:
flux *= self.norm_factor(params)[:, np.newaxis]
return np.squeeze(flux)

def norm_factor(self, params: Union[Sequence[float], NDArray[float]]) -> float:
"""
Return the scaling factor for the absolute flux units in flux-normalized spectra
Parameters
----------
params : array_like
The parameters to interpolate at
Returns
-------
factor: float
The multiplicative factor to normalize a spectrum to the model's absolute flux units
"""
_params = np.asarray(params)
return self.factor_interpolator(_params)

def determine_chunk_log(self, wavelength: Sequence[float], buffer: float = 50):
"""
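The new `factors` dataset stores each grid spectrum's mean flux (divided out in `from_grid`), and `norm_factor` interpolates it so that `load_flux(params, norm=True)` can restore the library's absolute flux units. A minimal usage sketch, not part of this commit: the file name and parameter values are placeholders, the import path is assumed from the package layout, and the seeding mirrors the tests below because `load_flux` draws the PCA weights stochastically.

import numpy as np
from Starfish.emulator import Emulator

emu = Emulator.load("emu.hdf5")  # a trained emulator saved with the new "factors" dataset
params = [6000, 4.0, 0.0]        # placeholder [Teff, logg, [Fe/H]] inside the grid

np.random.seed(123)                               # seed for a repeatable weight draw
flux_normed = emu.load_flux(params)               # mean-normalized flux (previous behavior)
np.random.seed(123)
flux_absolute = emu.load_flux(params, norm=True)  # rescaled to absolute flux units
factor = emu.norm_factor(params)                  # interpolated mean-flux factor at these parameters
# factor * flux_normed ≈ flux_absolute (cf. tests/test_emulator/test_emulator.py below)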
43 changes: 30 additions & 13 deletions Starfish/models/spectrum_model.py
@@ -38,6 +38,9 @@ class SpectrumModel:
max_deque_len : int, optional
The maximum number of residuals to retain in a deque of residuals. Default is
100
norm : bool, optional
If true, will rescale the model flux to the appropriate flux normalization
according to the original spectral library. Default is `False`.
name : str, optional
A name for the model. Default is 'SpectrumModel'
@@ -126,6 +129,7 @@ def __init__(
data: Union[str, Spectrum],
grid_params: Sequence[float],
max_deque_len: int = 100,
norm=False,
name: str = "SpectrumModel",
**params,
):
@@ -154,13 +158,15 @@ def __init__(
self.residuals = deque(maxlen=max_deque_len)

# manually handle cheb coeffs to offset index by 1
chebs = params.pop("cheb", [])
cheb_idxs = [str(i) for i in range(1, len(chebs) + 1)]
params["cheb"] = dict(zip(cheb_idxs, chebs))
if "cheb" in params:
chebs = params.pop("cheb")
cheb_idxs = [str(i) for i in range(1, len(chebs) + 1)]
params["cheb"] = dict(zip(cheb_idxs, chebs))
# load rest of params into FlatterDict
self.params = FlatterDict(params)
self.frozen = []
self.name = name
self.norm = norm

# Unpack the grid parameters
self.n_grid_params = len(grid_params)
@@ -170,6 +176,7 @@ def __init__(
self._lnprob = None
self._glob_cov = None
self._loc_cov = None
self._log_scale = params.get("log_scale", None)

self.log = logging.getLogger(self.__class__.__name__)

@@ -296,28 +303,35 @@ def __call__(self):
coeffs = [1, *self.cheb]
fluxes = chebyshev_correct(self.data.wave, fluxes, coeffs)

# Only rescale flux_mean and flux_std
if "log_scale" in self.params:
scale = np.exp(self.params["log_scale"])
fluxes[-2:] = rescale(fluxes[-2:], scale)

weights, weights_cov = self.emulator(self.grid_params)

L, flag = cho_factor(weights_cov, overwrite_a=True)

# Decompose the bulk_fluxes (see emulator/emulator.py for the ordering)
*eigenspectra, flux_mean, flux_std = fluxes

# Complete the reconstruction
X = eigenspectra * flux_std
flux = weights @ X + flux_mean

# optionally scale using absolute flux calibration
if self.norm:
norm = self.emulator.norm_factor(self.grid_params)
else:
norm = 1

# Renorm to data flux if no "log_scale" provided
if "log_scale" not in self.params:
factor = _get_renorm_factor(self.data.wave, flux, self.data.flux)
flux = rescale(flux, factor)
X = rescale(X, factor)
scale = _get_renorm_factor(self.data.wave, flux * norm, self.data.flux)
self._log_scale = np.log(scale)
scale *= norm
self.log.debug(f"fit scale factor using integrated flux ratio: {scale}")
else:
self._log_scale = self.params["log_scale"]
scale = np.exp(self.params["log_scale"]) * norm

flux = rescale(flux, scale)
X = rescale(X, scale)

L, flag = cho_factor(weights_cov, overwrite_a=True)
cov = X.T @ cho_solve((L, flag), X)

# Trivial covariance
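For reference, a condensed restatement of the scaling branch above (a sketch, not library code; `fit_factor` stands in for `_get_renorm_factor(self.data.wave, flux * norm, self.data.flux)`):

import numpy as np

def resolve_scale(params: dict, norm: float, fit_factor: float):
    # Sketch of how __call__ now resolves the applied flux scale.
    if "log_scale" not in params:
        log_scale = np.log(fit_factor)  # cached in self._log_scale; excludes the norm factor
        scale = fit_factor * norm       # the applied scale does include it
    else:
        log_scale = params["log_scale"]
        scale = np.exp(log_scale) * norm
    return scale, log_scale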
@@ -797,6 +811,9 @@ def __repr__(self):
output += f" cheb: {list(value.values())}\n"
else:
output += f" {key}: {value}\n"
if "log_scale" not in self.params:
self()
output += f" log_scale: {self._log_scale} (fit)\n"
if len(self.frozen) > 0:
output += "\nFrozen Parameters\n"
for key in self.frozen:
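A hedged end-to-end sketch of the new `norm` option on `SpectrumModel`; file names, parameter values, and the import path are assumptions for illustration, not taken from this commit.

from Starfish.models import SpectrumModel

model = SpectrumModel(
    "emu.hdf5",                     # trained Emulator, or a path to one
    data="observed_spectrum.hdf5",  # Spectrum instance, or a path to one
    grid_params=[6000, 4.0, 0.0],
    norm=True,                      # rescale the model to the library's absolute flux units
    Av=0.2,
    vsini=30,
)
flux, cov = model()  # with no "log_scale" parameter the scale is fit to the data
print(model)         # the fitted value is reported as "log_scale: ... (fit)"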
4 changes: 2 additions & 2 deletions docs/conversion.rst
@@ -1,8 +1,8 @@
###################
Conversion to 0.3.0
Conversion from v0.2
###################

There have been some significant changes to *Starfish* in the upgrades to version ``0.3.0``. Below are some of the main changes, and we also recommend viewing some of the :doc:`examples/index` to get a hang for the new workflow.
There have been some significant changes to *Starfish* in the upgrades to version ``0.3.0`` and later. Below are some of the main changes, and we also recommend viewing some of the :doc:`examples/index` to get the hang of the new workflow.

.. warning::
The current, updated code base does not have the framework for fitting multi-order Echelle spectra. We are working diligently to update the original functionality to match the updated API. For now, you will have to revert to Starfish ``0.2.0``.
281 changes: 125 additions & 156 deletions examples/single.ipynb

Large diffs are not rendered by default.

21 changes: 21 additions & 0 deletions tests/test_emulator/test_emulator.py
@@ -59,6 +59,27 @@ def test_load_many_fluxes(self, mock_emulator):
assert len(flux) == len(params)
assert np.all(np.isfinite(flux))

def test_load_flux_norm(self, mock_hdf5_interface, mock_emulator):
params = mock_hdf5_interface.grid_points[0]
grid_flux = mock_hdf5_interface.load_flux(params)
np.random.seed(123)
emu_flux = mock_emulator.load_flux(params, norm=True)
assert np.allclose(emu_flux, grid_flux, rtol=1e-2)

factor = mock_emulator.norm_factor(params)
np.random.seed(123)
raw_flux = mock_emulator.load_flux(params)
assert np.allclose(factor * raw_flux, emu_flux)

def test_load_many_fluxes_norm(self, mock_hdf5_interface, mock_emulator):
params = mock_hdf5_interface.grid_points[:2]
np.random.seed(123)
emu_flux = mock_emulator.load_flux(params, norm=True)
factor = mock_emulator.norm_factor(params)
np.random.seed(123)
raw_flux = mock_emulator.load_flux(params)
assert np.allclose(factor[:, np.newaxis] * raw_flux, emu_flux)

def test_warns_before_trained(self, mock_emulator):
with pytest.warns(UserWarning):
mock_emulator([6000, 4.2, 0.0])
45 changes: 43 additions & 2 deletions tests/test_models/test_models.py
@@ -32,8 +32,10 @@ def test_create_from_strings(self, mock_spectrum, mock_trained_emulator, tmpdir)
mock_spectrum.save(tmp_data)

model = SpectrumModel(tmp_emu, grid_params=[6000, 4.0, 0.0], data=tmp_data)

assert mock_trained_emulator.hyperparams == model.emulator.hyperparams
for key in mock_trained_emulator.hyperparams.keys():
assert np.isclose(
mock_trained_emulator.hyperparams[key], model.emulator.hyperparams[key]
)
assert model.data_name == mock_spectrum.name

def test_cheb_coeffs_index(self, mock_model):
@@ -374,3 +376,42 @@ def test_thaw_bad_param(self, mock_model):
fr = mock_model.frozen
mock_model.thaw("pinguino")
assert all([old == new for old, new in zip(fr, mock_model.frozen)])

def test_normalize(self, mock_model):
F1, _ = mock_model()
mock_model.norm = True
F2, _ = mock_model()
factor = mock_model.emulator.norm_factor(mock_model.grid_params)
assert np.allclose(F1 * factor, F2)

def test_str_logscale_cheat(self, mock_model):
mock_model.freeze("logg")
del mock_model["log_scale"]
expected = textwrap.dedent(
f"""
SpectrumModel
-------------
Data: {mock_model.data_name}
Emulator: {mock_model.emulator.name}
Log Likelihood: {mock_model.log_likelihood()}
Parameters
vz: 0
Av: 0
vsini: 30
global_cov:
log_amp: 1
log_ls: 1
local_cov:
0: mu: 10000.0, log_amp: 2, log_sigma: 2
1: mu: 13000.0, log_amp: 1.5, log_sigma: 2
cheb: [0.1, -0.2]
T: 6000
Z: 0
log_scale: {mock_model._log_scale} (fit)
Frozen Parameters
logg: 4.0
"""
).strip()
assert str(mock_model) == expected
