diff --git a/lenspyx/utils_hp.py b/lenspyx/utils_hp.py index 26401f6..d751844 100644 --- a/lenspyx/utils_hp.py +++ b/lenspyx/utils_hp.py @@ -94,13 +94,13 @@ def synalm(cl:np.ndarray, lmax:int, mmax:int or None, rlm_dtype=np.float64): return alm -def synalms(cls: dict, lmax:int, mmax:int or None, seed=None): - """Creates a Gaussian field alm from input cl array +def synalms(cls: dict, lmax:int, mmax:int or None, seed=None, rlm_dtype:type = np.float64): + """Creates Gaussian field alms from input cl dictionary - Parameters + Parametersseed ---------- - cls : ndarrays - The power spectra of the maps, assumes 'normal' ordering + cls : dict + The power spectra of the maps (e.g. as coming from CAMB) lmax : int Maximum multipole simulated mmax: int @@ -146,10 +146,10 @@ def synalms(cls: dict, lmax:int, mmax:int or None, seed=None): labels += f * (f in labelsf) ncomp = len(labels) - mat = np.empty((lmax + 1, ncomp, ncomp), dtype=float) + mat = np.empty((lmax + 1, ncomp, ncomp), dtype=rlm_dtype) for i, f in enumerate(labels): for j, g in enumerate(labels[i:]): - mat[:, i + j, i] = cls.get(f + g, cls.get(g + f, np.zeros(lmax + 1, dtype=float)))[:lmax + 1] + mat[:, i + j, i] = cls.get(f + g, cls.get(g + f, np.zeros(lmax + 1, dtype=rlm_dtype)))[:lmax + 1] ts, vs = np.linalg.eigh(mat) assert np.all(ts >= 0.) # Matrix not positive semidefinite for m, t, v in zip(mat, ts, vs): @@ -157,8 +157,8 @@ def synalms(cls: dict, lmax:int, mmax:int or None, seed=None): # Build phases: alm_size = Alm.getsize(lmax, mmax) rng = default_rng(seed) - phases = 1j * rng.standard_normal((ncomp, alm_size), dtype=float) - phases += rng.standard_normal((ncomp, alm_size), dtype=float) + phases = 1j * rng.standard_normal((ncomp, alm_size), dtype=rlm_dtype) + phases += rng.standard_normal((ncomp, alm_size), dtype=rlm_dtype) phases *= np.sqrt(0.5) real_idcs = Alm.getidx(lmax, np.arange(lmax + 1, dtype=int), 0) phases[:, real_idcs] = phases[:, real_idcs].real * np.sqrt(2.) @@ -169,7 +169,13 @@ def synalms(cls: dict, lmax:int, mmax:int or None, seed=None): labels_wgrad = labels_wgrad.replace('b', 'eb') if 'o' in labels and 'p' not in labels: labels_wgrad = labels_wgrad.replace('o', 'po') - alms = np.zeros((len(labels_wgrad), phases[0].size), dtype=complex) + if rlm_dtype == np.float32: + dtype_complex = np.complex64 + elif rlm_dtype == np.float64: + dtype_complex = np.complex128 + else: + assert 0, "please either choose np.float32 (single precission), or np.float64 (double precission) as rlm_dtype" + alms = np.zeros((len(labels_wgrad), phases[0].size), dtype=dtype_complex) # for L in Ls: #L @ L.T is full matrx for i, f in enumerate(labels): idx = labels_wgrad.index(f)