Skip to content

Commit

Permalink
Merge pull request #135 from iancze/ml/emulator-train
Browse files Browse the repository at this point in the history
Modify emulator train loop
  • Loading branch information
iancze authored Feb 27, 2021
2 parents 8787cc6 + dba7fd2 commit b6f379d
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 4 deletions.
9 changes: 5 additions & 4 deletions Starfish/emulator/emulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,9 +101,10 @@ def __init__(
variances if variances is not None else 1e4 * np.ones(self.ncomps)
)

unique = [sorted(np.unique(param_set)) for param_set in self.grid_points.T]
self._grid_sep = np.array([np.diff(param).max() for param in unique])

if lengthscales is None:
unique = [sorted(np.unique(param_set)) for param_set in self.grid_points.T]
self._grid_sep = np.array([np.diff(param).max() for param in unique])
lengthscales = np.tile(3 * self._grid_sep, (self.ncomps, 1))

self.lengthscales = lengthscales
Expand Down Expand Up @@ -451,7 +452,7 @@ def determine_chunk_log(self, wavelength: Sequence[float], buffer: float = 50):

def train(self, **opt_kwargs):
"""
Trains the emulator's hyperparameters using gradient descent
Trains the emulator's hyperparameters using gradient descent. This is a light wrapper around `scipy.optimize.minimize`. If you are experiencing problems optimizing the emulator, consider implementing your own training loop, using this function as a template.
Parameters
----------
Expand All @@ -466,7 +467,7 @@ def train(self, **opt_kwargs):
"""
# Define our loss function
def nll(P):
if np.any(P < 0):
if np.any(~np.isfinite(P)):
return np.inf
self.set_param_vector(P)
if np.any(self.lengthscales < 2 * self._grid_sep):
Expand Down
1 change: 1 addition & 0 deletions tests/test_emulator/test_emulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ def test_creation(self, mock_emulator):
def test_creation_from_string(self, mock_hdf5):
emu = Emulator.from_grid(mock_hdf5)
assert emu._trained is False
assert np.allclose(emu._grid_sep, [100, 0.5, 0.5]) # issue 134

def test_call(self, mock_emulator):
mu, cov = mock_emulator([6020, 4.21, -0.01])
Expand Down

0 comments on commit b6f379d

Please sign in to comment.