Skip to content

Commit

Permalink
Add integer options for ordered to fix only 1 catalog vertex
Browse files Browse the repository at this point in the history
  • Loading branch information
rmjarvis committed Dec 26, 2024
1 parent 594e57d commit c4b7d28
Show file tree
Hide file tree
Showing 4 changed files with 194 additions and 5 deletions.
35 changes: 35 additions & 0 deletions src/Corr3.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -840,6 +840,23 @@ void BaseCorr3::process111(
process111Sorted<B,4>(c1, c2, c3, metric, d1sq, d2sq, d3sq);
process111Sorted<B,4>(c1, c3, c2, metric, d1sq, d3sq, d2sq);
}
} else if (O == 2) {
if (BinTypeHelper<B>::sort_d123) {
// If the BinType allows sorting, but we have c3 fixed, then just check d1,d2.
if (d1sq > d3sq) {
xdbg<<"123\n";
// 123 -> 123
process111Sorted<B,O>(c1, c2, c3, metric, d1sq, d2sq, d3sq);
} else {
xdbg<<"321\n";
// 321 -> 123
process111Sorted<B,O>(c3, c2, c1, metric, d3sq, d2sq, d1sq);
}
} else {
// If can't swap 12, do both ways and switch ordered to 4.
process111Sorted<B,4>(c1, c2, c3, metric, d1sq, d2sq, d3sq);
process111Sorted<B,4>(c3, c2, c1, metric, d3sq, d2sq, d1sq);
}
} else if (O == 3) {
if (BinTypeHelper<B>::sort_d123) {
// If the BinType allows sorting, but we have c3 fixed, then just check d1,d2.
Expand Down Expand Up @@ -3180,6 +3197,7 @@ template <int B, int M, int C>
void ProcessCross12b(BaseCorr3& corr, BaseField<C>& field1, BaseField<C>& field2,
int ordered, bool dots)
{
Assert(ordered == 0 || ordered == 1);
Assert((ValidMC<M,C>::_M == M));
#ifndef DIRECT_MULTIPOLE
if (B == LogMultipole) {
Expand Down Expand Up @@ -3255,6 +3273,7 @@ template <int B, int M, int C>
void ProcessCross21b(BaseCorr3& corr, BaseField<C>& field1, BaseField<C>& field2,
int ordered, bool dots)
{
Assert(ordered == 0 || ordered == 3);
Assert((ValidMC<M,C>::_M == M));
#ifndef DIRECT_MULTIPOLE
if (B == LogMultipole) {
Expand Down Expand Up @@ -3338,6 +3357,7 @@ void ProcessCrossb(BaseCorr3& corr,
BaseField<C>& field1, BaseField<C>& field2, BaseField<C>& field3,
int ordered, bool dots)
{
Assert(ordered >= 0 && ordered <= 4);
Assert((ValidMC<M,C>::_M == M));
#ifndef DIRECT_MULTIPOLE
if (B == LogMultipole) {
Expand All @@ -3352,6 +3372,18 @@ void ProcessCrossb(BaseCorr3& corr,
corr.template multipole<ValidMPB<B>::_B,ValidMC<M,C>::_M>(
field1, field2, field3, dots, 1);
break;
case 2:
corr.template multipole<ValidMPB<B>::_B,ValidMC<M,C>::_M>(
field1, field2, field3, dots, 4);
corr.template multipole<ValidMPB<B>::_B,ValidMC<M,C>::_M>(
field3, field2, field1, dots, 4);
break;
case 3:
corr.template multipole<ValidMPB<B>::_B,ValidMC<M,C>::_M>(
field1, field2, field3, dots, 4);
corr.template multipole<ValidMPB<B>::_B,ValidMC<M,C>::_M>(
field2, field1, field3, dots, 4);
break;
case 4:
corr.template multipole<ValidMPB<B>::_B,ValidMC<M,C>::_M>(
field1, field2, field3, dots, 4);
Expand All @@ -3368,6 +3400,9 @@ void ProcessCrossb(BaseCorr3& corr,
case 1:
corr.template process111<B,1,ValidMC<M,C>::_M>(field1, field2, field3, dots);
break;
case 2:
corr.template process111<B,2,ValidMC<M,C>::_M>(field1, field2, field3, dots);
break;
case 3:
corr.template process111<B,3,ValidMC<M,C>::_M>(field1, field2, field3, dots);
break;
Expand Down
79 changes: 79 additions & 0 deletions tests/test_kkg.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,20 @@ def test_direct_logruv_cross():
np.testing.assert_allclose(gkk.weight, true_weight_sum1, rtol=1.e-5)
np.testing.assert_allclose(gkk.zeta, true_zeta_sum1, rtol=1.e-5)

# With these, ordered=False is equivalent to the G vertex being fixed.
kkg.process(cat1, cat2, cat3, ordered=3)
np.testing.assert_array_equal(kkg.ntri, true_ntri_sum3)
np.testing.assert_allclose(kkg.weight, true_weight_sum3, rtol=1.e-5)
np.testing.assert_allclose(kkg.zeta, true_zeta_sum3, rtol=1.e-5)
kgk.process(cat1, cat3, cat2, ordered=2)
np.testing.assert_array_equal(kgk.ntri, true_ntri_sum2)
np.testing.assert_allclose(kgk.weight, true_weight_sum2, rtol=1.e-5)
np.testing.assert_allclose(kgk.zeta, true_zeta_sum2, rtol=1.e-5)
gkk.process(cat3, cat1, cat2, ordered=1)
np.testing.assert_array_equal(gkk.ntri, true_ntri_sum1)
np.testing.assert_allclose(gkk.weight, true_weight_sum1, rtol=1.e-5)
np.testing.assert_allclose(gkk.zeta, true_zeta_sum1, rtol=1.e-5)

# Error to have cat3, but not cat2
with assert_raises(ValueError):
kkg.process(cat1, cat3=cat3)
Expand Down Expand Up @@ -404,6 +418,19 @@ def test_direct_logruv_cross():
np.testing.assert_allclose(gkk.weight, true_weight_sum1, rtol=1.e-5)
np.testing.assert_allclose(gkk.zeta, true_zeta_sum1, rtol=1.e-5)

kkg.process(cat1p, cat2p, cat3p, ordered=3)
np.testing.assert_array_equal(kkg.ntri, true_ntri_sum3)
np.testing.assert_allclose(kkg.weight, true_weight_sum3, rtol=1.e-5)
np.testing.assert_allclose(kkg.zeta, true_zeta_sum3, rtol=1.e-5)
kgk.process(cat1p, cat3p, cat2p, ordered=2)
np.testing.assert_array_equal(kgk.ntri, true_ntri_sum2)
np.testing.assert_allclose(kgk.weight, true_weight_sum2, rtol=1.e-5)
np.testing.assert_allclose(kgk.zeta, true_zeta_sum2, rtol=1.e-5)
gkk.process(cat3p, cat1p, cat2p, ordered=1)
np.testing.assert_array_equal(gkk.ntri, true_ntri_sum1)
np.testing.assert_allclose(gkk.weight, true_weight_sum1, rtol=1.e-5)
np.testing.assert_allclose(gkk.zeta, true_zeta_sum1, rtol=1.e-5)

# patch_method=local
kkg.process(cat1p, cat2p, cat3p, patch_method='local')
np.testing.assert_array_equal(kkg.ntri, true_ntri_123)
Expand Down Expand Up @@ -431,6 +458,19 @@ def test_direct_logruv_cross():
np.testing.assert_allclose(gkk.weight, true_weight_sum1, rtol=1.e-5)
np.testing.assert_allclose(gkk.zeta, true_zeta_sum1, rtol=1.e-5)

kkg.process(cat1p, cat2p, cat3p, ordered=3, patch_method='local')
np.testing.assert_array_equal(kkg.ntri, true_ntri_sum3)
np.testing.assert_allclose(kkg.weight, true_weight_sum3, rtol=1.e-5)
np.testing.assert_allclose(kkg.zeta, true_zeta_sum3, rtol=1.e-5)
kgk.process(cat1p, cat3p, cat2p, ordered=2, patch_method='local')
np.testing.assert_array_equal(kgk.ntri, true_ntri_sum2)
np.testing.assert_allclose(kgk.weight, true_weight_sum2, rtol=1.e-5)
np.testing.assert_allclose(kgk.zeta, true_zeta_sum2, rtol=1.e-5)
gkk.process(cat3p, cat1p, cat2p, ordered=1, patch_method='local')
np.testing.assert_array_equal(gkk.ntri, true_ntri_sum1)
np.testing.assert_allclose(gkk.weight, true_weight_sum1, rtol=1.e-5)
np.testing.assert_allclose(gkk.zeta, true_zeta_sum1, rtol=1.e-5)

with assert_raises(ValueError):
kkg.process(cat1p, cat2p, cat3p, patch_method='nonlocal')
with assert_raises(ValueError):
Expand Down Expand Up @@ -1269,6 +1309,19 @@ def test_direct_logsas_cross():
np.testing.assert_allclose(gkk.weight, true_weight_sum1, rtol=1.e-5)
np.testing.assert_allclose(gkk.zeta, true_zeta_sum1, rtol=1.e-5)

kkg.process(cat1, cat2, cat3, ordered=3, algo='triangle')
np.testing.assert_array_equal(kkg.ntri, true_ntri_sum3)
np.testing.assert_allclose(kkg.weight, true_weight_sum3, rtol=1.e-5)
np.testing.assert_allclose(kkg.zeta, true_zeta_sum3, rtol=1.e-4)
kgk.process(cat1, cat3, cat2, ordered=2, algo='triangle')
np.testing.assert_array_equal(kgk.ntri, true_ntri_sum2)
np.testing.assert_allclose(kgk.weight, true_weight_sum2, rtol=1.e-5)
np.testing.assert_allclose(kgk.zeta, true_zeta_sum2, rtol=1.e-5)
gkk.process(cat3, cat1, cat2, ordered=1, algo='triangle')
np.testing.assert_array_equal(gkk.ntri, true_ntri_sum1)
np.testing.assert_allclose(gkk.weight, true_weight_sum1, rtol=1.e-5)
np.testing.assert_allclose(gkk.zeta, true_zeta_sum1, rtol=1.e-5)

# Error to have cat3, but not cat2
with assert_raises(ValueError):
kkg.process(cat1, cat3=cat3, algo='triangle')
Expand Down Expand Up @@ -1308,6 +1361,19 @@ def test_direct_logsas_cross():
np.testing.assert_allclose(gkk.weight, true_weight_sum1, rtol=1.e-5)
np.testing.assert_allclose(gkk.zeta, true_zeta_sum1, rtol=1.e-5)

kkg.process(cat1p, cat2p, cat3p, ordered=3, algo='triangle')
np.testing.assert_array_equal(kkg.ntri, true_ntri_sum3)
np.testing.assert_allclose(kkg.weight, true_weight_sum3, rtol=1.e-5)
np.testing.assert_allclose(kkg.zeta, true_zeta_sum3, rtol=1.e-5)
kgk.process(cat1p, cat3p, cat2p, ordered=2, algo='triangle')
np.testing.assert_array_equal(kgk.ntri, true_ntri_sum2)
np.testing.assert_allclose(kgk.weight, true_weight_sum2, rtol=1.e-5)
np.testing.assert_allclose(kgk.zeta, true_zeta_sum2, rtol=1.e-5)
gkk.process(cat3p, cat1p, cat2p, ordered=1, algo='triangle')
np.testing.assert_array_equal(gkk.ntri, true_ntri_sum1)
np.testing.assert_allclose(gkk.weight, true_weight_sum1, rtol=1.e-5)
np.testing.assert_allclose(gkk.zeta, true_zeta_sum1, rtol=1.e-5)

kkg.process(cat1p, cat2p, cat3p, patch_method='local', algo='triangle')
np.testing.assert_array_equal(kkg.ntri, true_ntri_123)
np.testing.assert_allclose(kkg.weight, true_weight_123, rtol=1.e-5)
Expand All @@ -1334,6 +1400,19 @@ def test_direct_logsas_cross():
np.testing.assert_allclose(gkk.weight, true_weight_sum1, rtol=1.e-5)
np.testing.assert_allclose(gkk.zeta, true_zeta_sum1, rtol=1.e-5)

kkg.process(cat1p, cat2p, cat3p, ordered=3, patch_method='local', algo='triangle')
np.testing.assert_array_equal(kkg.ntri, true_ntri_sum3)
np.testing.assert_allclose(kkg.weight, true_weight_sum3, rtol=1.e-5)
np.testing.assert_allclose(kkg.zeta, true_zeta_sum3, rtol=1.e-5)
kgk.process(cat1p, cat3p, cat2p, ordered=2, patch_method='local', algo='triangle')
np.testing.assert_array_equal(kgk.ntri, true_ntri_sum2)
np.testing.assert_allclose(kgk.weight, true_weight_sum2, rtol=1.e-5)
np.testing.assert_allclose(kgk.zeta, true_zeta_sum2, rtol=1.e-5)
gkk.process(cat3p, cat1p, cat2p, ordered=1, patch_method='local', algo='triangle')
np.testing.assert_array_equal(gkk.ntri, true_ntri_sum1)
np.testing.assert_allclose(gkk.weight, true_weight_sum1, rtol=1.e-5)
np.testing.assert_allclose(gkk.zeta, true_zeta_sum1, rtol=1.e-5)

with assert_raises(ValueError):
kkg.process(cat1p, cat2p, cat3p, patch_method='nonlocal', algo='triangle')
with assert_raises(ValueError):
Expand Down
57 changes: 57 additions & 0 deletions tests/test_kkk.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,13 @@ def test_direct_logruv():
with assert_raises(ValueError):
kkk.process(cat, algo='multipole')

# Error to use integer order with 1 or 2 catalogs
with assert_raises(ValueError):
kkk.process(cat, ordered=1)
with assert_raises(ValueError):
kkk.process(cat, cat, ordered=1)
with assert_raises(ValueError):
kkk.process(cat, cat, ordered=2)

@timer
def test_direct_logruv_spherical():
Expand Down Expand Up @@ -644,6 +651,30 @@ def test_direct_logruv_cross():
np.testing.assert_allclose(kkk.weight, true_weight_sum, rtol=1.e-5)
np.testing.assert_allclose(kkk.zeta, true_zeta_sum, rtol=1.e-5)

def combine2(z1,w1,z2,w2):
w12 = w1 + w2
w12[w12==0] = 1
return (z1*w1 + z2*w2) / w12

# Test ordered = 1,2,3 options
kkk.process(cat1, cat2, cat3, ordered=1)
np.testing.assert_array_equal(kkk.ntri, true_ntri_123 + true_ntri_132)
np.testing.assert_allclose(kkk.weight, true_weight_123 + true_weight_132, rtol=1.e-5)
zeta_1 = combine2(true_zeta_123, true_weight_123, true_zeta_132, true_weight_132)
np.testing.assert_allclose(kkk.zeta, zeta_1, rtol=1.e-5)

kkk.process(cat1, cat2, cat3, ordered=2)
np.testing.assert_array_equal(kkk.ntri, true_ntri_123 + true_ntri_321)
np.testing.assert_allclose(kkk.weight, true_weight_123 + true_weight_321, rtol=1.e-5)
zeta_2 = combine2(true_zeta_123, true_weight_123, true_zeta_321, true_weight_321)
np.testing.assert_allclose(kkk.zeta, zeta_2, rtol=1.e-5, atol=1.e-6)

kkk.process(cat1, cat2, cat3, ordered=3)
np.testing.assert_array_equal(kkk.ntri, true_ntri_123 + true_ntri_213)
np.testing.assert_allclose(kkk.weight, true_weight_123 + true_weight_213, rtol=1.e-5)
zeta_3 = combine2(true_zeta_123, true_weight_123, true_zeta_213, true_weight_213)
np.testing.assert_allclose(kkk.zeta, zeta_3, rtol=1.e-5)

# Error to have cat3, but not cat2
with assert_raises(ValueError):
kkk.process(cat1, cat3=cat3)
Expand Down Expand Up @@ -2787,6 +2818,19 @@ def test_direct_logmultipole_cross():
np.testing.assert_allclose(kkk.weight, true_weight_sum, rtol=1.e-5)
np.testing.assert_allclose(kkk.zeta, true_zeta_sum, rtol=1.e-4)

# Test ordered = 1,2,3 options
kkk.process(cat1, cat2, cat3, ordered=1)
np.testing.assert_allclose(kkk.weight, true_weight_123 + true_weight_132, rtol=1.e-5)
np.testing.assert_allclose(kkk.zeta, true_zeta_123 + true_zeta_132, rtol=1.e-5, atol=1.e-6)

kkk.process(cat1, cat2, cat3, ordered=2)
np.testing.assert_allclose(kkk.weight, true_weight_123 + true_weight_321, rtol=1.e-5)
np.testing.assert_allclose(kkk.zeta, true_zeta_123 + true_zeta_321, rtol=1.e-5, atol=1.e-6)

kkk.process(cat1, cat2, cat3, ordered=3)
np.testing.assert_allclose(kkk.weight, true_weight_123 + true_weight_213, rtol=1.e-5)
np.testing.assert_allclose(kkk.zeta, true_zeta_123 + true_zeta_213, rtol=1.e-5, atol=1.e-6)

# Split into patches to test the list-based version of the code.
# First with just one catalog with patches
cat1 = treecorr.Catalog(x=x1, y=y1, w=w1, k=k1, npatch=8, rng=rng)
Expand Down Expand Up @@ -2839,6 +2883,19 @@ def test_direct_logmultipole_cross():
np.testing.assert_allclose(kkk.weight, true_weight_sum, rtol=1.e-5)
np.testing.assert_allclose(kkk.zeta, true_zeta_sum, rtol=1.e-4)

# Test ordered = 1,2,3 options
kkk.process(cat1, cat2, cat3, ordered=1)
np.testing.assert_allclose(kkk.weight, true_weight_123 + true_weight_132, rtol=1.e-5)
np.testing.assert_allclose(kkk.zeta, true_zeta_123 + true_zeta_132, rtol=1.e-5, atol=1.e-6)

kkk.process(cat1, cat2, cat3, ordered=2)
np.testing.assert_allclose(kkk.weight, true_weight_123 + true_weight_321, rtol=1.e-5)
np.testing.assert_allclose(kkk.zeta, true_zeta_123 + true_zeta_321, rtol=1.e-5, atol=1.e-6)

kkk.process(cat1, cat2, cat3, ordered=3)
np.testing.assert_allclose(kkk.weight, true_weight_123 + true_weight_213, rtol=1.e-5)
np.testing.assert_allclose(kkk.zeta, true_zeta_123 + true_zeta_213, rtol=1.e-5, atol=1.e-6)

# No tests of accuracy yet, but make sure patch-based covariance works.
cov = kkk.estimate_cov('sample')
cov = kkk.estimate_cov('jackknife')
Expand Down
28 changes: 23 additions & 5 deletions treecorr/corr3base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1707,6 +1707,9 @@ def is_my_job(my_indices, i, j, k, n1, n2, n3):
temp = self.copy()
temp.results = {}

# Convert bool into corresponding int
ordered = 0 if ordered is False else 4 if ordered is True else ordered

if local:
for ii,c1 in enumerate(cat1):
i = c1._single_patch if c1._single_patch is not None else ii
Expand All @@ -1715,13 +1718,13 @@ def is_my_job(my_indices, i, j, k, n1, n2, n3):
c3e = self._make_expanded_patch(c1, cat3, metric, low_mem)
self.logger.info('Process patch %d with surrounding local patches',i)
self._single_process123(c1, c2e, c3e, (i,i,i), metric,
1 if not ordered else ordered,
1 if ordered in (0,1) else 4,
num_threads, temp, True)
if low_mem:
c1.unload()
if not ordered:
if ordered in (0,3):
# local method doesn't do unordered properly as is.
# It can only handle ordered=1 or 3.
# It can only handle ordered=1 or 4.
# So in this case, we need to repeat with c2 and c3 in the first spot.
for ii,c2 in enumerate(cat2):
i = c2._single_patch if c2._single_patch is not None else ii
Expand All @@ -1730,18 +1733,21 @@ def is_my_job(my_indices, i, j, k, n1, n2, n3):
c1e = self._make_expanded_patch(c2, cat1, metric, low_mem)
c3e = self._make_expanded_patch(c2, cat3, metric, low_mem)
self.logger.info('Process patch %d from cat2 with surrounding local patches',i)
self._single_process123(c2, c1e, c3e, (i,i,i), metric, 1,
self._single_process123(c2, c1e, c3e, (i,i,i), metric,
1 if ordered == 0 else 4,
num_threads, temp, True)
if low_mem:
c2.unload()
if ordered in (0,2):
for ii,c3 in enumerate(cat3):
i = c3._single_patch if c3._single_patch is not None else ii
if (is_my_job(my_indices, i, i, i, n1, n2, n3)
and self._letter1 == self._letter3):
c1e = self._make_expanded_patch(c3, cat1, metric, low_mem)
c2e = self._make_expanded_patch(c3, cat2, metric, low_mem)
self.logger.info('Process patch %d from cat3 with surrounding local patches',i)
self._single_process123(c3, c2e, c1e, (i,i,i), metric, 1,
self._single_process123(c3, c2e, c1e, (i,i,i), metric,
1 if ordered == 0 else 4,
num_threads, temp, True)
if low_mem:
c3.unload()
Expand Down Expand Up @@ -2022,6 +2028,14 @@ def process(self, cat1, cat2=None, cat3=None, *, metric=None, ordered=True, num_
All arguments may be lists, in which case all items in the list are used
for that element of the correlation.
.. note::
In addition to ordered = True or False, you may also set ordered to 1, 2 or 3
which means that the catalog in that position is fixed, but the other two
vertices are unordered. E.g. if ordered=3, then P3 will always come from cat3,
but P1 and P2 will each come from one of cat1 or cat2 in either order.
This option is only valid when all three catalogs (cat1, cat2, cat3) are given.
Parameters:
cat1 (Catalog): A catalog or list of catalogs for the first field.
cat2 (Catalog): A catalog or list of catalogs for the second field.
Expand Down Expand Up @@ -2108,8 +2122,12 @@ def process(self, cat1, cat2=None, cat3=None, *, metric=None, ordered=True, num_
raise ValueError("{} cannot use one catalog version of process".format(self._letters))
if cat3 is not None:
raise ValueError("For two catalog case, use cat1,cat2, not cat1,cat3")
if not (ordered is True or ordered is False):
raise ValueError("The integer options for ordered are only valid with 3 catalogs")
self._process_all_auto(cat1, metric, num_threads, comm, low_mem, local)
elif cat3 is None:
if not (ordered is True or ordered is False):
raise ValueError("The integer options for ordered are only valid with 3 catalogs")
if self._letter2 == self._letter3:
self._process_all_cross12(cat1, cat2, metric, ordered, num_threads, comm, low_mem,
local)
Expand Down

0 comments on commit c4b7d28

Please sign in to comment.