diff --git a/src/Corr3.cpp b/src/Corr3.cpp index 9f1bfb7a..e7835081 100644 --- a/src/Corr3.cpp +++ b/src/Corr3.cpp @@ -840,6 +840,23 @@ void BaseCorr3::process111( process111Sorted(c1, c2, c3, metric, d1sq, d2sq, d3sq); process111Sorted(c1, c3, c2, metric, d1sq, d3sq, d2sq); } + } else if (O == 2) { + if (BinTypeHelper::sort_d123) { + // If the BinType allows sorting, but we have c3 fixed, then just check d1,d2. + if (d1sq > d3sq) { + xdbg<<"123\n"; + // 123 -> 123 + process111Sorted(c1, c2, c3, metric, d1sq, d2sq, d3sq); + } else { + xdbg<<"321\n"; + // 321 -> 123 + process111Sorted(c3, c2, c1, metric, d3sq, d2sq, d1sq); + } + } else { + // If can't swap 12, do both ways and switch ordered to 4. + process111Sorted(c1, c2, c3, metric, d1sq, d2sq, d3sq); + process111Sorted(c3, c2, c1, metric, d3sq, d2sq, d1sq); + } } else if (O == 3) { if (BinTypeHelper::sort_d123) { // If the BinType allows sorting, but we have c3 fixed, then just check d1,d2. @@ -3180,6 +3197,7 @@ template void ProcessCross12b(BaseCorr3& corr, BaseField& field1, BaseField& field2, int ordered, bool dots) { + Assert(ordered == 0 || ordered == 1); Assert((ValidMC::_M == M)); #ifndef DIRECT_MULTIPOLE if (B == LogMultipole) { @@ -3255,6 +3273,7 @@ template void ProcessCross21b(BaseCorr3& corr, BaseField& field1, BaseField& field2, int ordered, bool dots) { + Assert(ordered == 0 || ordered == 3); Assert((ValidMC::_M == M)); #ifndef DIRECT_MULTIPOLE if (B == LogMultipole) { @@ -3338,6 +3357,7 @@ void ProcessCrossb(BaseCorr3& corr, BaseField& field1, BaseField& field2, BaseField& field3, int ordered, bool dots) { + Assert(ordered >= 0 && ordered <= 4); Assert((ValidMC::_M == M)); #ifndef DIRECT_MULTIPOLE if (B == LogMultipole) { @@ -3352,6 +3372,18 @@ void ProcessCrossb(BaseCorr3& corr, corr.template multipole::_B,ValidMC::_M>( field1, field2, field3, dots, 1); break; + case 2: + corr.template multipole::_B,ValidMC::_M>( + field1, field2, field3, dots, 4); + corr.template multipole::_B,ValidMC::_M>( + field3, field2, field1, dots, 4); + break; + case 3: + corr.template multipole::_B,ValidMC::_M>( + field1, field2, field3, dots, 4); + corr.template multipole::_B,ValidMC::_M>( + field2, field1, field3, dots, 4); + break; case 4: corr.template multipole::_B,ValidMC::_M>( field1, field2, field3, dots, 4); @@ -3368,6 +3400,9 @@ void ProcessCrossb(BaseCorr3& corr, case 1: corr.template process111::_M>(field1, field2, field3, dots); break; + case 2: + corr.template process111::_M>(field1, field2, field3, dots); + break; case 3: corr.template process111::_M>(field1, field2, field3, dots); break; diff --git a/tests/test_kkg.py b/tests/test_kkg.py index 3ca6b9ac..40fdfd42 100644 --- a/tests/test_kkg.py +++ b/tests/test_kkg.py @@ -327,6 +327,20 @@ def test_direct_logruv_cross(): np.testing.assert_allclose(gkk.weight, true_weight_sum1, rtol=1.e-5) np.testing.assert_allclose(gkk.zeta, true_zeta_sum1, rtol=1.e-5) + # With these, ordered=False is equivalent to the G vertex being fixed. + kkg.process(cat1, cat2, cat3, ordered=3) + np.testing.assert_array_equal(kkg.ntri, true_ntri_sum3) + np.testing.assert_allclose(kkg.weight, true_weight_sum3, rtol=1.e-5) + np.testing.assert_allclose(kkg.zeta, true_zeta_sum3, rtol=1.e-5) + kgk.process(cat1, cat3, cat2, ordered=2) + np.testing.assert_array_equal(kgk.ntri, true_ntri_sum2) + np.testing.assert_allclose(kgk.weight, true_weight_sum2, rtol=1.e-5) + np.testing.assert_allclose(kgk.zeta, true_zeta_sum2, rtol=1.e-5) + gkk.process(cat3, cat1, cat2, ordered=1) + np.testing.assert_array_equal(gkk.ntri, true_ntri_sum1) + np.testing.assert_allclose(gkk.weight, true_weight_sum1, rtol=1.e-5) + np.testing.assert_allclose(gkk.zeta, true_zeta_sum1, rtol=1.e-5) + # Error to have cat3, but not cat2 with assert_raises(ValueError): kkg.process(cat1, cat3=cat3) @@ -404,6 +418,19 @@ def test_direct_logruv_cross(): np.testing.assert_allclose(gkk.weight, true_weight_sum1, rtol=1.e-5) np.testing.assert_allclose(gkk.zeta, true_zeta_sum1, rtol=1.e-5) + kkg.process(cat1p, cat2p, cat3p, ordered=3) + np.testing.assert_array_equal(kkg.ntri, true_ntri_sum3) + np.testing.assert_allclose(kkg.weight, true_weight_sum3, rtol=1.e-5) + np.testing.assert_allclose(kkg.zeta, true_zeta_sum3, rtol=1.e-5) + kgk.process(cat1p, cat3p, cat2p, ordered=2) + np.testing.assert_array_equal(kgk.ntri, true_ntri_sum2) + np.testing.assert_allclose(kgk.weight, true_weight_sum2, rtol=1.e-5) + np.testing.assert_allclose(kgk.zeta, true_zeta_sum2, rtol=1.e-5) + gkk.process(cat3p, cat1p, cat2p, ordered=1) + np.testing.assert_array_equal(gkk.ntri, true_ntri_sum1) + np.testing.assert_allclose(gkk.weight, true_weight_sum1, rtol=1.e-5) + np.testing.assert_allclose(gkk.zeta, true_zeta_sum1, rtol=1.e-5) + # patch_method=local kkg.process(cat1p, cat2p, cat3p, patch_method='local') np.testing.assert_array_equal(kkg.ntri, true_ntri_123) @@ -431,6 +458,19 @@ def test_direct_logruv_cross(): np.testing.assert_allclose(gkk.weight, true_weight_sum1, rtol=1.e-5) np.testing.assert_allclose(gkk.zeta, true_zeta_sum1, rtol=1.e-5) + kkg.process(cat1p, cat2p, cat3p, ordered=3, patch_method='local') + np.testing.assert_array_equal(kkg.ntri, true_ntri_sum3) + np.testing.assert_allclose(kkg.weight, true_weight_sum3, rtol=1.e-5) + np.testing.assert_allclose(kkg.zeta, true_zeta_sum3, rtol=1.e-5) + kgk.process(cat1p, cat3p, cat2p, ordered=2, patch_method='local') + np.testing.assert_array_equal(kgk.ntri, true_ntri_sum2) + np.testing.assert_allclose(kgk.weight, true_weight_sum2, rtol=1.e-5) + np.testing.assert_allclose(kgk.zeta, true_zeta_sum2, rtol=1.e-5) + gkk.process(cat3p, cat1p, cat2p, ordered=1, patch_method='local') + np.testing.assert_array_equal(gkk.ntri, true_ntri_sum1) + np.testing.assert_allclose(gkk.weight, true_weight_sum1, rtol=1.e-5) + np.testing.assert_allclose(gkk.zeta, true_zeta_sum1, rtol=1.e-5) + with assert_raises(ValueError): kkg.process(cat1p, cat2p, cat3p, patch_method='nonlocal') with assert_raises(ValueError): @@ -1269,6 +1309,19 @@ def test_direct_logsas_cross(): np.testing.assert_allclose(gkk.weight, true_weight_sum1, rtol=1.e-5) np.testing.assert_allclose(gkk.zeta, true_zeta_sum1, rtol=1.e-5) + kkg.process(cat1, cat2, cat3, ordered=3, algo='triangle') + np.testing.assert_array_equal(kkg.ntri, true_ntri_sum3) + np.testing.assert_allclose(kkg.weight, true_weight_sum3, rtol=1.e-5) + np.testing.assert_allclose(kkg.zeta, true_zeta_sum3, rtol=1.e-4) + kgk.process(cat1, cat3, cat2, ordered=2, algo='triangle') + np.testing.assert_array_equal(kgk.ntri, true_ntri_sum2) + np.testing.assert_allclose(kgk.weight, true_weight_sum2, rtol=1.e-5) + np.testing.assert_allclose(kgk.zeta, true_zeta_sum2, rtol=1.e-5) + gkk.process(cat3, cat1, cat2, ordered=1, algo='triangle') + np.testing.assert_array_equal(gkk.ntri, true_ntri_sum1) + np.testing.assert_allclose(gkk.weight, true_weight_sum1, rtol=1.e-5) + np.testing.assert_allclose(gkk.zeta, true_zeta_sum1, rtol=1.e-5) + # Error to have cat3, but not cat2 with assert_raises(ValueError): kkg.process(cat1, cat3=cat3, algo='triangle') @@ -1308,6 +1361,19 @@ def test_direct_logsas_cross(): np.testing.assert_allclose(gkk.weight, true_weight_sum1, rtol=1.e-5) np.testing.assert_allclose(gkk.zeta, true_zeta_sum1, rtol=1.e-5) + kkg.process(cat1p, cat2p, cat3p, ordered=3, algo='triangle') + np.testing.assert_array_equal(kkg.ntri, true_ntri_sum3) + np.testing.assert_allclose(kkg.weight, true_weight_sum3, rtol=1.e-5) + np.testing.assert_allclose(kkg.zeta, true_zeta_sum3, rtol=1.e-5) + kgk.process(cat1p, cat3p, cat2p, ordered=2, algo='triangle') + np.testing.assert_array_equal(kgk.ntri, true_ntri_sum2) + np.testing.assert_allclose(kgk.weight, true_weight_sum2, rtol=1.e-5) + np.testing.assert_allclose(kgk.zeta, true_zeta_sum2, rtol=1.e-5) + gkk.process(cat3p, cat1p, cat2p, ordered=1, algo='triangle') + np.testing.assert_array_equal(gkk.ntri, true_ntri_sum1) + np.testing.assert_allclose(gkk.weight, true_weight_sum1, rtol=1.e-5) + np.testing.assert_allclose(gkk.zeta, true_zeta_sum1, rtol=1.e-5) + kkg.process(cat1p, cat2p, cat3p, patch_method='local', algo='triangle') np.testing.assert_array_equal(kkg.ntri, true_ntri_123) np.testing.assert_allclose(kkg.weight, true_weight_123, rtol=1.e-5) @@ -1334,6 +1400,19 @@ def test_direct_logsas_cross(): np.testing.assert_allclose(gkk.weight, true_weight_sum1, rtol=1.e-5) np.testing.assert_allclose(gkk.zeta, true_zeta_sum1, rtol=1.e-5) + kkg.process(cat1p, cat2p, cat3p, ordered=3, patch_method='local', algo='triangle') + np.testing.assert_array_equal(kkg.ntri, true_ntri_sum3) + np.testing.assert_allclose(kkg.weight, true_weight_sum3, rtol=1.e-5) + np.testing.assert_allclose(kkg.zeta, true_zeta_sum3, rtol=1.e-5) + kgk.process(cat1p, cat3p, cat2p, ordered=2, patch_method='local', algo='triangle') + np.testing.assert_array_equal(kgk.ntri, true_ntri_sum2) + np.testing.assert_allclose(kgk.weight, true_weight_sum2, rtol=1.e-5) + np.testing.assert_allclose(kgk.zeta, true_zeta_sum2, rtol=1.e-5) + gkk.process(cat3p, cat1p, cat2p, ordered=1, patch_method='local', algo='triangle') + np.testing.assert_array_equal(gkk.ntri, true_ntri_sum1) + np.testing.assert_allclose(gkk.weight, true_weight_sum1, rtol=1.e-5) + np.testing.assert_allclose(gkk.zeta, true_zeta_sum1, rtol=1.e-5) + with assert_raises(ValueError): kkg.process(cat1p, cat2p, cat3p, patch_method='nonlocal', algo='triangle') with assert_raises(ValueError): diff --git a/tests/test_kkk.py b/tests/test_kkk.py index d94e013b..76dae0d5 100644 --- a/tests/test_kkk.py +++ b/tests/test_kkk.py @@ -321,6 +321,13 @@ def test_direct_logruv(): with assert_raises(ValueError): kkk.process(cat, algo='multipole') + # Error to use integer order with 1 or 2 catalogs + with assert_raises(ValueError): + kkk.process(cat, ordered=1) + with assert_raises(ValueError): + kkk.process(cat, cat, ordered=1) + with assert_raises(ValueError): + kkk.process(cat, cat, ordered=2) @timer def test_direct_logruv_spherical(): @@ -644,6 +651,30 @@ def test_direct_logruv_cross(): np.testing.assert_allclose(kkk.weight, true_weight_sum, rtol=1.e-5) np.testing.assert_allclose(kkk.zeta, true_zeta_sum, rtol=1.e-5) + def combine2(z1,w1,z2,w2): + w12 = w1 + w2 + w12[w12==0] = 1 + return (z1*w1 + z2*w2) / w12 + + # Test ordered = 1,2,3 options + kkk.process(cat1, cat2, cat3, ordered=1) + np.testing.assert_array_equal(kkk.ntri, true_ntri_123 + true_ntri_132) + np.testing.assert_allclose(kkk.weight, true_weight_123 + true_weight_132, rtol=1.e-5) + zeta_1 = combine2(true_zeta_123, true_weight_123, true_zeta_132, true_weight_132) + np.testing.assert_allclose(kkk.zeta, zeta_1, rtol=1.e-5) + + kkk.process(cat1, cat2, cat3, ordered=2) + np.testing.assert_array_equal(kkk.ntri, true_ntri_123 + true_ntri_321) + np.testing.assert_allclose(kkk.weight, true_weight_123 + true_weight_321, rtol=1.e-5) + zeta_2 = combine2(true_zeta_123, true_weight_123, true_zeta_321, true_weight_321) + np.testing.assert_allclose(kkk.zeta, zeta_2, rtol=1.e-5, atol=1.e-6) + + kkk.process(cat1, cat2, cat3, ordered=3) + np.testing.assert_array_equal(kkk.ntri, true_ntri_123 + true_ntri_213) + np.testing.assert_allclose(kkk.weight, true_weight_123 + true_weight_213, rtol=1.e-5) + zeta_3 = combine2(true_zeta_123, true_weight_123, true_zeta_213, true_weight_213) + np.testing.assert_allclose(kkk.zeta, zeta_3, rtol=1.e-5) + # Error to have cat3, but not cat2 with assert_raises(ValueError): kkk.process(cat1, cat3=cat3) @@ -2787,6 +2818,19 @@ def test_direct_logmultipole_cross(): np.testing.assert_allclose(kkk.weight, true_weight_sum, rtol=1.e-5) np.testing.assert_allclose(kkk.zeta, true_zeta_sum, rtol=1.e-4) + # Test ordered = 1,2,3 options + kkk.process(cat1, cat2, cat3, ordered=1) + np.testing.assert_allclose(kkk.weight, true_weight_123 + true_weight_132, rtol=1.e-5) + np.testing.assert_allclose(kkk.zeta, true_zeta_123 + true_zeta_132, rtol=1.e-5, atol=1.e-6) + + kkk.process(cat1, cat2, cat3, ordered=2) + np.testing.assert_allclose(kkk.weight, true_weight_123 + true_weight_321, rtol=1.e-5) + np.testing.assert_allclose(kkk.zeta, true_zeta_123 + true_zeta_321, rtol=1.e-5, atol=1.e-6) + + kkk.process(cat1, cat2, cat3, ordered=3) + np.testing.assert_allclose(kkk.weight, true_weight_123 + true_weight_213, rtol=1.e-5) + np.testing.assert_allclose(kkk.zeta, true_zeta_123 + true_zeta_213, rtol=1.e-5, atol=1.e-6) + # Split into patches to test the list-based version of the code. # First with just one catalog with patches cat1 = treecorr.Catalog(x=x1, y=y1, w=w1, k=k1, npatch=8, rng=rng) @@ -2839,6 +2883,19 @@ def test_direct_logmultipole_cross(): np.testing.assert_allclose(kkk.weight, true_weight_sum, rtol=1.e-5) np.testing.assert_allclose(kkk.zeta, true_zeta_sum, rtol=1.e-4) + # Test ordered = 1,2,3 options + kkk.process(cat1, cat2, cat3, ordered=1) + np.testing.assert_allclose(kkk.weight, true_weight_123 + true_weight_132, rtol=1.e-5) + np.testing.assert_allclose(kkk.zeta, true_zeta_123 + true_zeta_132, rtol=1.e-5, atol=1.e-6) + + kkk.process(cat1, cat2, cat3, ordered=2) + np.testing.assert_allclose(kkk.weight, true_weight_123 + true_weight_321, rtol=1.e-5) + np.testing.assert_allclose(kkk.zeta, true_zeta_123 + true_zeta_321, rtol=1.e-5, atol=1.e-6) + + kkk.process(cat1, cat2, cat3, ordered=3) + np.testing.assert_allclose(kkk.weight, true_weight_123 + true_weight_213, rtol=1.e-5) + np.testing.assert_allclose(kkk.zeta, true_zeta_123 + true_zeta_213, rtol=1.e-5, atol=1.e-6) + # No tests of accuracy yet, but make sure patch-based covariance works. cov = kkk.estimate_cov('sample') cov = kkk.estimate_cov('jackknife') diff --git a/treecorr/corr3base.py b/treecorr/corr3base.py index cf54dcac..5877fd43 100644 --- a/treecorr/corr3base.py +++ b/treecorr/corr3base.py @@ -1707,6 +1707,9 @@ def is_my_job(my_indices, i, j, k, n1, n2, n3): temp = self.copy() temp.results = {} + # Convert bool into corresponding int + ordered = 0 if ordered is False else 4 if ordered is True else ordered + if local: for ii,c1 in enumerate(cat1): i = c1._single_patch if c1._single_patch is not None else ii @@ -1715,13 +1718,13 @@ def is_my_job(my_indices, i, j, k, n1, n2, n3): c3e = self._make_expanded_patch(c1, cat3, metric, low_mem) self.logger.info('Process patch %d with surrounding local patches',i) self._single_process123(c1, c2e, c3e, (i,i,i), metric, - 1 if not ordered else ordered, + 1 if ordered in (0,1) else 4, num_threads, temp, True) if low_mem: c1.unload() - if not ordered: + if ordered in (0,3): # local method doesn't do unordered properly as is. - # It can only handle ordered=1 or 3. + # It can only handle ordered=1 or 4. # So in this case, we need to repeat with c2 and c3 in the first spot. for ii,c2 in enumerate(cat2): i = c2._single_patch if c2._single_patch is not None else ii @@ -1730,10 +1733,12 @@ def is_my_job(my_indices, i, j, k, n1, n2, n3): c1e = self._make_expanded_patch(c2, cat1, metric, low_mem) c3e = self._make_expanded_patch(c2, cat3, metric, low_mem) self.logger.info('Process patch %d from cat2 with surrounding local patches',i) - self._single_process123(c2, c1e, c3e, (i,i,i), metric, 1, + self._single_process123(c2, c1e, c3e, (i,i,i), metric, + 1 if ordered == 0 else 4, num_threads, temp, True) if low_mem: c2.unload() + if ordered in (0,2): for ii,c3 in enumerate(cat3): i = c3._single_patch if c3._single_patch is not None else ii if (is_my_job(my_indices, i, i, i, n1, n2, n3) @@ -1741,7 +1746,8 @@ def is_my_job(my_indices, i, j, k, n1, n2, n3): c1e = self._make_expanded_patch(c3, cat1, metric, low_mem) c2e = self._make_expanded_patch(c3, cat2, metric, low_mem) self.logger.info('Process patch %d from cat3 with surrounding local patches',i) - self._single_process123(c3, c2e, c1e, (i,i,i), metric, 1, + self._single_process123(c3, c2e, c1e, (i,i,i), metric, + 1 if ordered == 0 else 4, num_threads, temp, True) if low_mem: c3.unload() @@ -2022,6 +2028,14 @@ def process(self, cat1, cat2=None, cat3=None, *, metric=None, ordered=True, num_ All arguments may be lists, in which case all items in the list are used for that element of the correlation. + .. note:: + + In addition to ordered = True or False, you may also set ordered to 1, 2 or 3 + which means that the catalog in that position is fixed, but the other two + vertices are unordered. E.g. if ordered=3, then P3 will always come from cat3, + but P1 and P2 will each come from one of cat1 or cat2 in either order. + This option is only valid when all three catalogs (cat1, cat2, cat3) are given. + Parameters: cat1 (Catalog): A catalog or list of catalogs for the first field. cat2 (Catalog): A catalog or list of catalogs for the second field. @@ -2108,8 +2122,12 @@ def process(self, cat1, cat2=None, cat3=None, *, metric=None, ordered=True, num_ raise ValueError("{} cannot use one catalog version of process".format(self._letters)) if cat3 is not None: raise ValueError("For two catalog case, use cat1,cat2, not cat1,cat3") + if not (ordered is True or ordered is False): + raise ValueError("The integer options for ordered are only valid with 3 catalogs") self._process_all_auto(cat1, metric, num_threads, comm, low_mem, local) elif cat3 is None: + if not (ordered is True or ordered is False): + raise ValueError("The integer options for ordered are only valid with 3 catalogs") if self._letter2 == self._letter3: self._process_all_cross12(cat1, cat2, metric, ordered, num_threads, comm, low_mem, local)