def landy_szalay_estimator(data, randoms=None, rbins=10, box_dims=None, binning='log', **kwargs):
    '''
    Estimate the two-point correlation function with the Landy-Szalay estimator,

        xi(r) = (DD - 2*DR + RR) / RR,

    where DD, DR and RR are the normalised data-data, data-random and
    random-random pair counts in each radial bin.

    Parameters:
        data (numpy array): Either an (N, ndim) array of point positions
            (e.g., galaxies), or a 3D gridded field. For a 3D grid, every cell
            with a value > 0 is turned into one point located at
            index * box_dims / data.shape[0].
        randoms (numpy array): (M, ndim) array of positions of the random catalogue.
            If None, M = N points are drawn uniformly between data.min() and data.max().
        rbins (integer or array-like): The number of radial bin edges or the bin
            edges themselves. An integer n produces n edges (i.e. n - 1 bins)
            spaced logarithmically or linearly between rmin and rmax.
        box_dims (float): Length of the box in each direction in Mpc.
            If None, conv.LB is used.
        binning (string): 'log' for logarithmic binning, anything else for linear
            binning. Only used when rbins is an integer.
        **kwargs: rmin (default 0.01) and rmax (default box_dims) set the range
            of the bins when rbins is an integer.

    Returns:
        xi (numpy array): The estimated correlation function, one value per bin.
            Bins that contain no random-random pair are set to NaN.
        rbin_centers (numpy array): The (arithmetic) centers of the radial bins.

    Raises:
        ValueError: If the positions are not 2D arrays of matching dimension,
            if either catalogue has fewer than two points, or if fewer than
            two bin edges are available.

    Example:
        data = np.random.uniform(0, 100, (1000, 3))     # 1000 data points in a 100 Mpc box
        randoms = np.random.uniform(0, 100, (1000, 3))  # 1000 points for the random catalogue
        xi, r = landy_szalay_estimator(data, randoms, box_dims=100)
    '''
    if box_dims is None:
        box_dims = conv.LB
        print(f'Setting box_dims to ({box_dims},{box_dims},{box_dims}) Mpc.')

    data = np.asarray(data)
    if data.ndim == 3:
        # A gridded field was given: convert the occupied cells into positions.
        data = np.argwhere(data > 0) * box_dims / data.shape[0]
    if data.ndim != 2:
        raise ValueError('data must be an (N, ndim) array of positions or a 3D grid.')

    num_data = len(data)
    if num_data < 2:
        raise ValueError('At least two data points are needed to count pairs.')

    if randoms is None:
        randoms = np.random.uniform(data.min(), data.max(), data.shape)
    randoms = np.asarray(randoms, dtype=float)
    if randoms.ndim != 2 or randoms.shape[1] != data.shape[1]:
        raise ValueError('randoms must be an (M, ndim) array with the same ndim as data.')

    num_random = len(randoms)
    if num_random < 2:
        raise ValueError('At least two random points are needed to count pairs.')

    # Determine the radial bins
    if isinstance(rbins, (int, np.integer)):
        rmin = kwargs.get('rmin')
        rmax = kwargs.get('rmax')
        if rmin is None:
            rmin = 0.01
        if rmax is None:
            rmax = box_dims
        if binning == 'log':
            rbins = np.logspace(np.log10(rmin), np.log10(rmax), rbins)
        else:
            rbins = np.linspace(rmin, rmax, rbins)
    else:
        # Explicit bin edges: the search radius follows from the edges themselves.
        rbins = np.asarray(rbins, dtype=float)

    if rbins.ndim != 1 or rbins.size < 2:
        raise ValueError('rbins must provide at least two bin edges.')

    rbin_centers = (rbins[:-1] + rbins[1:]) / 2  # Midpoints of the bins

    # Step 1: Pair Counting

    # count_neighbors returns, for every edge r, the number of pairs separated
    # by a distance <= r, so differencing consecutive edges gives the number of
    # pairs with rbins[i] < distance <= rbins[i+1].
    data_tree = cKDTree(data)
    random_tree = cKDTree(randoms)

    # A tree counted against itself yields ordered pairs plus the zero-distance
    # self pairs. The self pairs are present at every radius and cancel in the
    # difference; halving turns ordered pairs into unique pairs.
    DD = np.diff(data_tree.count_neighbors(data_tree, rbins)) / 2.0
    RR = np.diff(random_tree.count_neighbors(random_tree, rbins)) / 2.0
    DR = np.diff(data_tree.count_neighbors(random_tree, rbins)).astype(float)

    # Normalize the pair counts by the total number of available pairs
    DD = DD / (num_data * (num_data - 1) / 2.0)
    RR = RR / (num_random * (num_random - 1) / 2.0)
    DR = DR / (num_data * num_random)

    # Step 2: Calculate the Landy-Szalay estimator
    numerator = DD - 2 * DR + RR
    xi = np.full(numerator.shape, np.nan)
    # The estimator is undefined where no random pair falls in the bin.
    np.divide(numerator, RR, out=xi, where=RR > 0)

    return xi, rbin_centers