Skip to content

Commit

Permalink
pass axis keyword
Browse files Browse the repository at this point in the history
  • Loading branch information
weaverba137 committed Aug 2, 2024
1 parent abc1372 commit d45c7a8
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 2 deletions.
2 changes: 1 addition & 1 deletion speclite/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -1769,7 +1769,7 @@ def pad_spectrum(self, spectrum, wavelength, axis=-1, method='median'):
padded_wavelength = np.asanyarray(wavelength)
for response in sorted_responses:
padded_spectrum, padded_wavelength = response.pad_spectrum(
padded_spectrum, padded_wavelength, method=method)
padded_spectrum, padded_wavelength, axis=axis, method=method)
return padded_spectrum, padded_wavelength


Expand Down
15 changes: 14 additions & 1 deletion speclite/tests/test_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ def test_get_ab_maggies_modify_wavelength():
waves = np.arange(3000, 11000, 0.11)
waves0 = waves.copy()
flux = waves * 1.
filters[0].get_ab_magnitude(flux, waves)
m = filters[0].get_ab_magnitude(flux, waves)
assert (waves == waves0).all()
# print(np.nonzero(waves != waves0), waves[2249], waves0[2249], waves[40932], waves0[40932])

Expand Down Expand Up @@ -421,6 +421,19 @@ def test_response_pad_shape():
assert list(pflux.shape) == expected_shape


def test_response_pad_regression():
"""This is a regression test for
https://github.com/desihub/speclite/issues/25
"""
lam0 = np.arange(3800., 5500., 1.) * u.AA
flam0 = np.tile(np.ones_like(lam0.value)[..., None, None], (1, 20, 20)) * u.Unit('1e-17 erg / (s cm2 AA)')
assert flam0.shape == (1700, 20, 20)
sdss = load_filters('sdss2010-*')
flam, lam = sdss.pad_spectrum(spectrum=flam0, wavelength=lam0, method='zero', axis=0)
# print(flam0.shape, lam0.shape, flam.shape, lam.shape)
assert flam.shape == (lam.shape[0], 20, 20)


def test_sequence_pad():
filters = load_filters('sdss2010-r', 'sdss2010-g')
wave = np.linspace(5000., 10000., 100)
Expand Down

0 comments on commit d45c7a8

Please sign in to comment.