diff --git a/src/tools21cm/helper_functions.py b/src/tools21cm/helper_functions.py index e1c7eee..326c2fe 100644 --- a/src/tools21cm/helper_functions.py +++ b/src/tools21cm/helper_functions.py @@ -373,6 +373,23 @@ def get_data_and_type(indata, cbin_bits=32, cbin_order='c', raw_density=False): return indata, 'unknown' raise Exception('Could not determine type of data') +def save_data(savefile, data, filetype=None, **kwargs): + if '.npy' in savefile[-5:] or filetype.lower() in ['npy', 'python_pickle']: + np.save(savefile, data) + elif '.pkl' in savefile[-5:] or filetype.lower() in ['pkl','pickle']: + import pickle + pickle.dump(data, open(savefile, 'wb')) + elif '.cbin' in savefile[-5:] or filetype.lower()=='cbin': + save_cbin(savefile, data, bits=kwargs.get('bits',32), order=kwargs.get('order','C')) + elif '.fits' in savefile[-5:] or filetype in ['fits']: + save_fits(data, savefile, header=kwargs.get('header')) + elif '.bin' in savefile[-5:] or filetype in ['bin', 'binary']: + save_raw_binary(savefile, data, bits=kwargs.get('bits',64), order=kwargs.get('order','C')) + else: + print('Unknown filetype.') + return False + return True + def get_mesh_size(filename): ''' diff --git a/src/tools21cm/nbody_pkdgrav.py b/src/tools21cm/nbody_pkdgrav.py index a354b43..93265ba 100644 --- a/src/tools21cm/nbody_pkdgrav.py +++ b/src/tools21cm/nbody_pkdgrav.py @@ -6,6 +6,8 @@ from scipy.interpolate import splev, splrep import pandas as pd +from .helper_functions import save_data + class ReaderPkdgrav3: def __init__(self, box_len, nGrid, Omega_m=0.31, rho_c=2.77536627e11, verbose=True): @@ -85,6 +87,33 @@ def read_fof_data(self, filename, z=None, dtype=None): ''' Read the FOF data. + Parameters: + - filenames (str): The name of the data file or a list of files. + - z (float, optional): Redshift of the data. Defaults to None, which goes to 0. + + Returns: + - numpy.ndarray: A structured array. + ''' + if isinstance(filename, list): + for ii,ff in enumerate(filename): + hl0 = self._read_fof_data(ff, z=z, dtype=dtype) + if self.verbose: + print(f'{ff} contains {hl0.shape[0]} haloes') + data = hl0 if ii==0 else np.concatenate((data,hl0), axis=0) + else: + data = self._read_fof_data(filename, z=z, dtype=dtype) + + self.fof_data_dtype = dtype + self.fof_data = data + if self.verbose: + print(f'Total haloes: {data.shape[0]}') + return data + + + def _read_fof_data(self, filename, z=None, dtype=None): + ''' + Read the FOF data. + Parameters: - filename (str): The name of the data file. - z (float, optional): Redshift of the data. Defaults to None, which goes to 0. @@ -129,12 +158,7 @@ def read_fof_data(self, filename, z=None, dtype=None): # ('nDM', '