From e2c54498dec62fdcd0530a9cb742f5c93a9072aa Mon Sep 17 00:00:00 2001 From: Ben Jeffery Date: Fri, 26 Jul 2024 12:44:03 +0100 Subject: [PATCH] tests --- tests/test_variantdata.py | 51 +++++++++++++++++++++++++++++++++--- tsinfer/formats.py | 54 +++++++++++++++------------------------ 2 files changed, 69 insertions(+), 36 deletions(-) diff --git a/tests/test_variantdata.py b/tests/test_variantdata.py index 074aa040..c7765b11 100644 --- a/tests/test_variantdata.py +++ b/tests/test_variantdata.py @@ -249,6 +249,7 @@ def test_sgkit_accessors_defaults(tmp_path): ) +@pytest.mark.skipif(sys.platform == "win32", reason="No cyvcf2 on windows") def test_variantdata_sites_time_default(tmp_path): ts, zarr_path = tsutil.make_ts_and_zarr(tmp_path) samples = tsinfer.VariantData(zarr_path, "variant_ancestral_allele") @@ -259,15 +260,14 @@ def test_variantdata_sites_time_default(tmp_path): ) +@pytest.mark.skipif(sys.platform == "win32", reason="No cyvcf2 on windows") def test_variantdata_sites_time_array(tmp_path): ts, zarr_path = tsutil.make_ts_and_zarr(tmp_path) sites_time = np.arange(ts.num_sites) samples = tsinfer.VariantData( zarr_path, "variant_ancestral_allele", sites_time_name_or_array=sites_time ) - assert np.array_equal(samples.sites_time, sites_time) - wrong_length_sites_time = np.arange(ts.num_sites + 1) with pytest.raises( ValueError, @@ -338,7 +338,8 @@ def test_sgkit_variant_bad_mask_length(self, tmp_path): tsinfer.VariantData(zarr_path, "variant_ancestral_allele") with pytest.raises( ValueError, - match="Mask must be the same length as the number of unmasked sites", + match="Site mask array must be the same length as the number of" + " unmasked sites", ): tsinfer.VariantData( zarr_path, @@ -477,6 +478,50 @@ def test_variantdata_default_masks(self, tmp_path): samples.samples_select, np.full(samples.num_samples, True, dtype=bool) ) + def test_variantdata_mask_as_arrays(self, tmp_path): + ts, zarr_path = tsutil.make_ts_and_zarr(tmp_path) + sample_mask = np.full(ts.num_individuals, False, dtype=bool) + site_mask = np.full(ts.num_sites, False, dtype=bool) + sample_mask[3] = True + site_mask[5] = True + + samples = tsinfer.VariantData( + zarr_path, + "variant_ancestral_allele", + sample_mask_name_or_array=sample_mask, + site_mask_name_or_array=site_mask, + ) + + assert np.array_equal(samples.individuals_select, ~sample_mask) + assert np.array_equal(samples.sites_select, ~site_mask) + + def test_variantdata_incorrect_mask_lengths(self, tmp_path): + ts, zarr_path = tsutil.make_ts_and_zarr(tmp_path) + incorrect_sample_mask = np.array([True, False]) # Incorrect length + incorrect_site_mask = np.array([True, False, True]) # Incorrect length + + with pytest.raises( + ValueError, + match="Samples mask array must be the same length as the" + " number of individuals", + ): + tsinfer.VariantData( + zarr_path, + "variant_ancestral_allele", + sample_mask_name_or_array=incorrect_sample_mask, + ) + + with pytest.raises( + ValueError, + match="Site mask array must be the same length as the number of" + " unmasked sites", + ): + tsinfer.VariantData( + zarr_path, + "variant_ancestral_allele", + site_mask_name_or_array=incorrect_site_mask, + ) + def test_sgkit_subset_equivalence(self, tmp_path, tmpdir): ( mat_sd, diff --git a/tsinfer/formats.py b/tsinfer/formats.py index 2325ad84..4a00b390 100644 --- a/tsinfer/formats.py +++ b/tsinfer/formats.py @@ -2311,60 +2311,48 @@ def __init__( genotypes_arr = self.data["call_genotype"] _, self._num_individuals_before_mask, self.ploidy = genotypes_arr.shape - # Store sites_select if site_mask_name_or_array is None: - self.sites_select = np.full( - self.data["variant_position"].shape, True, dtype=bool + site_mask_name_or_array = np.full( + self.data["variant_position"].shape, False, dtype=bool ) elif isinstance(site_mask_name_or_array, np.ndarray): - if ( - site_mask_name_or_array.shape[0] - != self.data["variant_position"].shape[0] - ): - raise ValueError( - "Mask array must be the same length as the number of unmasked sites" - ) - self.sites_select = ~site_mask_name_or_array.astype(bool) + pass else: try: - if ( - self.data[site_mask_name_or_array].shape[0] - != self.data["variant_position"].shape[0] - ): - raise ValueError( - "Mask must be the same length as the number of unmasked sites" - ) - # We negate the mask as it is much easier in numpy to have True=keep - self.sites_select = ~( - self.data[site_mask_name_or_array].astype(bool)[:] - ) + site_mask_name_or_array = self.data[site_mask_name_or_array][:] except KeyError: raise ValueError( f"The sites mask {site_mask_name_or_array} was not found" f" in the dataset." ) + if site_mask_name_or_array.shape[0] != self.data["variant_position"].shape[0]: + raise ValueError( + "Site mask array must be the same length as the number of unmasked sites" + ) + # We negate the mask as it is much easier in numpy to have True=keep + self.sites_select = ~site_mask_name_or_array.astype(bool) + if sample_mask_name_or_array is None: - self.individuals_select = np.full( - self._num_individuals_before_mask, True, dtype=bool + sample_mask_name_or_array = np.full( + self._num_individuals_before_mask, False, dtype=bool ) elif isinstance(sample_mask_name_or_array, np.ndarray): - if sample_mask_name_or_array.shape[0] != self._num_individuals_before_mask: - raise ValueError( - "Samples mask array must be the same length as the number of" - " individuals" - ) - self.individuals_select = ~sample_mask_name_or_array.astype(bool) + pass else: try: # We negate the mask as it is much easier in numpy to have True=keep - self.individuals_select = ~( - self.data[sample_mask_name_or_array][:].astype(bool) - ) + sample_mask_name_or_array = self.data[sample_mask_name_or_array][:] except KeyError: raise ValueError( f"The samples mask {sample_mask_name_or_array} was not" f" found in the dataset." ) + if sample_mask_name_or_array.shape[0] != self._num_individuals_before_mask: + raise ValueError( + "Samples mask array must be the same length as the number of" + " individuals" + ) + self.individuals_select = ~sample_mask_name_or_array.astype(bool) self._num_sites = np.sum(self.sites_select) assert self.ploidy == self.data["call_genotype"].chunks[2]