diff --git a/croissant/jax/tests/test_alm.py b/croissant/jax/tests/test_alm.py index 1e8dd85..3ca3077 100644 --- a/croissant/jax/tests/test_alm.py +++ b/croissant/jax/tests/test_alm.py @@ -14,7 +14,7 @@ def test_total_power(lmax): alm = jnp.zeros(shape, dtype=jnp.complex128) a00_idx = crojax.alm.getidx(lmax, 0, 0) alm = alm.at[a00_idx].set(1 / Y00) - power = crojax.alm.total_power(alm) + power = crojax.alm.total_power(alm, lmax) assert jnp.isclose(power, 4 * jnp.pi) # m(theta) = cos(theta)**2 @@ -22,7 +22,7 @@ def test_total_power(lmax): alm = alm.at[a00_idx].set(1 / (3 * Y00)) a20_idx = crojax.alm.getidx(lmax, 2, 0) alm = alm.at[a20_idx].set(4 * jnp.sqrt(jnp.pi / 5) * 1 / 3) - power = crojax.alm.total_power(alm) + power = crojax.alm.total_power(alm, lmax) expected_power = 4 * jnp.pi / 3 assert jnp.isclose(power, expected_power)