diff --git a/tests/test_inference.py b/tests/test_inference.py index 09a6f183..199e8b55 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -1411,7 +1411,7 @@ def test_cache_used_by_timing(self, tmpdir): sample_data, ancestor_ts, match_data_dir=tmpdir ) time2 = time.time() - t - assert time2 < time1 / 1.25 + assert time2 < time1 final_ts1.tables.assert_equals(final_ts2.tables, ignore_provenance=True) @@ -1419,12 +1419,16 @@ def test_cache_used_by_timing(self, tmpdir): class TestBatchAncestorMatching: def test_equivalance(self, tmp_path, tmpdir): ts, zarr_path = tsutil.make_ts_and_zarr(tmp_path) - samples = tsinfer.SgkitSampleData(zarr_path) + samples = tsinfer.VariantData(zarr_path, "variant_ancestral_allele") ancestors = tsinfer.generate_ancestors( samples, path=str(tmpdir / "ancestors.zarr") ) metadata = tsinfer.match_ancestors_batch_init( - tmpdir / "work", zarr_path, tmpdir / "ancestors.zarr", 1000 + tmpdir / "work", + zarr_path, + "variant_ancestral_allele", + tmpdir / "ancestors.zarr", + 1000, ) for group_index, _ in enumerate(metadata["ancestor_grouping"]): tsinfer.match_ancestors_batch_groups( @@ -1436,12 +1440,16 @@ def test_equivalance(self, tmp_path, tmpdir): def test_equivalance_many_at_once(self, tmp_path, tmpdir): ts, zarr_path = tsutil.make_ts_and_zarr(tmp_path) - samples = tsinfer.SgkitSampleData(zarr_path) + samples = tsinfer.VariantData(zarr_path, "variant_ancestral_allele") ancestors = tsinfer.generate_ancestors( samples, path=str(tmpdir / "ancestors.zarr") ) metadata = tsinfer.match_ancestors_batch_init( - tmpdir / "work", zarr_path, tmpdir / "ancestors.zarr", 1000 + tmpdir / "work", + zarr_path, + "variant_ancestral_allele", + tmpdir / "ancestors.zarr", + 1000, ) tsinfer.match_ancestors_batch_groups( tmpdir / "work", 0, len(metadata["ancestor_grouping"]) // 2, 2 @@ -1459,12 +1467,16 @@ def test_equivalance_many_at_once(self, tmp_path, tmpdir): def test_equivalance_with_partitions(self, tmp_path, tmpdir): ts, zarr_path = tsutil.make_ts_and_zarr(tmp_path) - samples = tsinfer.SgkitSampleData(zarr_path) + samples = tsinfer.VariantData(zarr_path, "variant_ancestral_allele") ancestors = tsinfer.generate_ancestors( samples, path=str(tmpdir / "ancestors.zarr") ) metadata = tsinfer.match_ancestors_batch_init( - tmpdir / "work", zarr_path, tmpdir / "ancestors.zarr", 1000 + tmpdir / "work", + zarr_path, + "variant_ancestral_allele", + tmpdir / "ancestors.zarr", + 1000, ) for group_index, group in enumerate(metadata["ancestor_grouping"]): if group["partitions"] is None: @@ -1485,13 +1497,14 @@ def test_equivalance_with_partitions(self, tmp_path, tmpdir): def test_max_partitions(self, tmp_path, tmpdir): ts, zarr_path = tsutil.make_ts_and_zarr(tmp_path) - samples = tsinfer.SgkitSampleData(zarr_path) + samples = tsinfer.VariantData(zarr_path, "variant_ancestral_allele") ancestors = tsinfer.generate_ancestors( samples, path=str(tmpdir / "ancestors.zarr") ) metadata = tsinfer.match_ancestors_batch_init( tmpdir / "work", zarr_path, + "variant_ancestral_allele", tmpdir / "ancestors.zarr", 10000, max_num_partitions=2, @@ -1516,10 +1529,14 @@ def test_max_partitions(self, tmp_path, tmpdir): def test_errors(self, tmp_path, tmpdir): ts, zarr_path = tsutil.make_ts_and_zarr(tmp_path) - samples = tsinfer.SgkitSampleData(zarr_path) + samples = tsinfer.VariantData(zarr_path, "variant_ancestral_allele") tsinfer.generate_ancestors(samples, path=str(tmpdir / "ancestors.zarr")) metadata = tsinfer.match_ancestors_batch_init( - tmpdir / "work", zarr_path, tmpdir / "ancestors.zarr", 1000 + tmpdir / "work", + zarr_path, + "variant_ancestral_allele", + tmpdir / "ancestors.zarr", + 1000, ) with pytest.raises(ValueError, match="out of range"): tsinfer.match_ancestors_batch_groups(tmpdir / "work", -1, 1) diff --git a/tests/test_sgkit.py b/tests/test_variantdata.py similarity index 82% rename from tests/test_sgkit.py rename to tests/test_variantdata.py index fb828ca7..c7765b11 100644 --- a/tests/test_sgkit.py +++ b/tests/test_variantdata.py @@ -78,7 +78,7 @@ def ts_to_dataset(ts, chunks=None, samples=None): @pytest.mark.skipif(sys.platform == "win32", reason="No cyvcf2 on windows") def test_sgkit_dataset_roundtrip(tmp_path): ts, zarr_path = tsutil.make_ts_and_zarr(tmp_path) - samples = tsinfer.SgkitSampleData(zarr_path) + samples = tsinfer.VariantData(zarr_path, "variant_ancestral_allele") inf_ts = tsinfer.infer(samples) ds = sgkit.load_dataset(zarr_path) @@ -120,7 +120,7 @@ def test_sgkit_individual_metadata_not_clobbered(tmp_path): tskit.MetadataSchema.permissive_json() ) - samples = tsinfer.SgkitSampleData(zarr_path) + samples = tsinfer.VariantData(zarr_path, "variant_ancestral_allele") inf_ts = tsinfer.infer(samples) ds = sgkit.load_dataset(zarr_path) @@ -139,7 +139,9 @@ def test_sgkit_dataset_accessors(tmp_path): ts, zarr_path = tsutil.make_ts_and_zarr( tmp_path, add_optional=True, shuffle_alleles=False ) - samples = tsinfer.SgkitSampleData(zarr_path) + samples = tsinfer.VariantData( + zarr_path, "variant_ancestral_allele", sites_time_name_or_array="sites_time" + ) ds = sgkit.load_dataset(zarr_path) assert samples.format_name == "tsinfer-sgkit-sample-data" @@ -203,7 +205,7 @@ def test_sgkit_dataset_accessors(tmp_path): # Need to shuffle for the ancestral allele test ts, zarr_path = tsutil.make_ts_and_zarr(tmp_path, add_optional=True) - samples = tsinfer.SgkitSampleData(zarr_path) + samples = tsinfer.VariantData(zarr_path, "variant_ancestral_allele") for i in range(ts.num_sites): assert ( samples.sites_alleles[i][samples.sites_ancestral_allele[i]] @@ -214,7 +216,7 @@ def test_sgkit_dataset_accessors(tmp_path): @pytest.mark.skipif(sys.platform == "win32", reason="No cyvcf2 on windows") def test_sgkit_accessors_defaults(tmp_path): ts, zarr_path = tsutil.make_ts_and_zarr(tmp_path) - samples = tsinfer.SgkitSampleData(zarr_path) + samples = tsinfer.VariantData(zarr_path, "variant_ancestral_allele") ds = sgkit.load_dataset(zarr_path) default_schema = tskit.MetadataSchema.permissive_json().schema @@ -247,6 +249,37 @@ def test_sgkit_accessors_defaults(tmp_path): ) +@pytest.mark.skipif(sys.platform == "win32", reason="No cyvcf2 on windows") +def test_variantdata_sites_time_default(tmp_path): + ts, zarr_path = tsutil.make_ts_and_zarr(tmp_path) + samples = tsinfer.VariantData(zarr_path, "variant_ancestral_allele") + + assert ( + np.all(np.isnan(samples.sites_time)) + and samples.sites_time.size == samples.num_sites + ) + + +@pytest.mark.skipif(sys.platform == "win32", reason="No cyvcf2 on windows") +def test_variantdata_sites_time_array(tmp_path): + ts, zarr_path = tsutil.make_ts_and_zarr(tmp_path) + sites_time = np.arange(ts.num_sites) + samples = tsinfer.VariantData( + zarr_path, "variant_ancestral_allele", sites_time_name_or_array=sites_time + ) + assert np.array_equal(samples.sites_time, sites_time) + wrong_length_sites_time = np.arange(ts.num_sites + 1) + with pytest.raises( + ValueError, + match="Sites time array must be the same length as the number of selected sites", + ): + tsinfer.VariantData( + zarr_path, + "variant_ancestral_allele", + sites_time_name_or_array=wrong_length_sites_time, + ) + + def test_simulate_genotype_call_dataset(tmp_path): # Test that byte alleles are correctly converted to string ts = msprime.sim_ancestry(4, sequence_length=1000, random_seed=123) @@ -254,7 +287,7 @@ def test_simulate_genotype_call_dataset(tmp_path): ds = ts_to_dataset(ts) ds.update({"variant_ancestral_allele": ds["variant_allele"][:, 0]}) ds.to_zarr(tmp_path, mode="w") - sd = tsinfer.SgkitSampleData(tmp_path) + sd = tsinfer.VariantData(tmp_path, "variant_ancestral_allele") ts = tsinfer.infer(sd) for v, ds_v, sd_v in zip(ts.variants(), ds.call_genotype, sd.sites_genotypes): assert np.all(v.genotypes == ds_v.values.flatten()) @@ -271,7 +304,11 @@ def test_sgkit_variant_mask(self, tmp_path, sites): for i in sites: sites_mask[i] = False tsutil.add_array_to_dataset("variant_mask_42", sites_mask, zarr_path) - samples = tsinfer.SgkitSampleData(zarr_path, sites_mask_name="variant_mask_42") + samples = tsinfer.VariantData( + zarr_path, + "variant_ancestral_allele", + site_mask_name_or_array="variant_mask_42", + ) assert samples.num_sites == len(sites) assert np.array_equal(samples.sites_select, ~sites_mask) assert np.array_equal( @@ -298,12 +335,17 @@ def test_sgkit_variant_bad_mask_length(self, tmp_path): ds = sgkit.load_dataset(zarr_path) sites_mask = np.zeros(ds.sizes["variants"] + 1, dtype=int) tsutil.add_array_to_dataset("variant_mask_foobar", sites_mask, zarr_path) - tsinfer.SgkitSampleData(zarr_path) + tsinfer.VariantData(zarr_path, "variant_ancestral_allele") with pytest.raises( ValueError, - match="Mask must be the same length as the number of unmasked sites", + match="Site mask array must be the same length as the number of" + " unmasked sites", ): - tsinfer.SgkitSampleData(zarr_path, sites_mask_name="variant_mask_foobar") + tsinfer.VariantData( + zarr_path, + "variant_ancestral_allele", + site_mask_name_or_array="variant_mask_foobar", + ) def test_bad_select_length_at_iterator(self, tmp_path): ts, zarr_path = tsutil.make_ts_and_zarr(tmp_path) @@ -325,8 +367,10 @@ def test_sgkit_sample_mask(self, tmp_path, sample_list): for i in sample_list: samples_mask[i] = False tsutil.add_array_to_dataset("samples_mask_69", samples_mask, zarr_path) - samples = tsinfer.SgkitSampleData( - zarr_path, sgkit_samples_mask_name="samples_mask_69" + samples = tsinfer.VariantData( + zarr_path, + "variant_ancestral_allele", + sample_mask_name_or_array="samples_mask_69", ) assert samples.ploidy == 3 assert samples.num_individuals == len(sample_list) @@ -380,10 +424,11 @@ def test_sgkit_sample_and_site_mask(self, tmp_path): tsutil.add_array_to_dataset( "samples_mask_foobar", samples_mask, zarr_path, dims=["samples"] ) - samples = tsinfer.SgkitSampleData( + samples = tsinfer.VariantData( zarr_path, - sites_mask_name="variant_mask_foobar", - sgkit_samples_mask_name="samples_mask_foobar", + "variant_ancestral_allele", + site_mask_name_or_array="variant_mask_foobar", + sample_mask_name_or_array="samples_mask_foobar", ) genotypes = samples.sites_genotypes ds_genotypes = ds.call_genotype.values[~variant_mask][ @@ -403,22 +448,80 @@ def test_sgkit_sample_and_site_mask(self, tmp_path): def test_sgkit_missing_masks(self, tmp_path): ts, zarr_path = tsutil.make_ts_and_zarr(tmp_path) - samples = tsinfer.SgkitSampleData(zarr_path) + samples = tsinfer.VariantData(zarr_path, "variant_ancestral_allele") samples.individuals_select samples.sites_select with pytest.raises( ValueError, match="The sites mask foobar was not found in the dataset." ): - tsinfer.SgkitSampleData(zarr_path, sites_mask_name="foobar") + tsinfer.VariantData( + zarr_path, "variant_ancestral_allele", site_mask_name_or_array="foobar" + ) with pytest.raises( ValueError, - match="The sgkit samples mask foobar2 was not found in the dataset.", + match="The samples mask foobar2 was not found in the dataset.", ): - samples = tsinfer.SgkitSampleData( - zarr_path, sgkit_samples_mask_name="foobar2" + samples = tsinfer.VariantData( + zarr_path, + "variant_ancestral_allele", + sample_mask_name_or_array="foobar2", ) samples.individuals_select + def test_variantdata_default_masks(self, tmp_path): + ts, zarr_path = tsutil.make_ts_and_zarr(tmp_path) + samples = tsinfer.VariantData(zarr_path, "variant_ancestral_allele") + assert np.array_equal( + samples.sites_select, np.full(samples.num_sites, True, dtype=bool) + ) + assert np.array_equal( + samples.samples_select, np.full(samples.num_samples, True, dtype=bool) + ) + + def test_variantdata_mask_as_arrays(self, tmp_path): + ts, zarr_path = tsutil.make_ts_and_zarr(tmp_path) + sample_mask = np.full(ts.num_individuals, False, dtype=bool) + site_mask = np.full(ts.num_sites, False, dtype=bool) + sample_mask[3] = True + site_mask[5] = True + + samples = tsinfer.VariantData( + zarr_path, + "variant_ancestral_allele", + sample_mask_name_or_array=sample_mask, + site_mask_name_or_array=site_mask, + ) + + assert np.array_equal(samples.individuals_select, ~sample_mask) + assert np.array_equal(samples.sites_select, ~site_mask) + + def test_variantdata_incorrect_mask_lengths(self, tmp_path): + ts, zarr_path = tsutil.make_ts_and_zarr(tmp_path) + incorrect_sample_mask = np.array([True, False]) # Incorrect length + incorrect_site_mask = np.array([True, False, True]) # Incorrect length + + with pytest.raises( + ValueError, + match="Samples mask array must be the same length as the" + " number of individuals", + ): + tsinfer.VariantData( + zarr_path, + "variant_ancestral_allele", + sample_mask_name_or_array=incorrect_sample_mask, + ) + + with pytest.raises( + ValueError, + match="Site mask array must be the same length as the number of" + " unmasked sites", + ): + tsinfer.VariantData( + zarr_path, + "variant_ancestral_allele", + site_mask_name_or_array=incorrect_site_mask, + ) + def test_sgkit_subset_equivalence(self, tmp_path, tmpdir): ( mat_sd, @@ -500,7 +603,7 @@ def test_sgkit_subset_equivalence(self, tmp_path, tmpdir): def test_sgkit_ancestral_allele_same_ancestors(tmp_path): ts, zarr_path = tsutil.make_ts_and_zarr(tmp_path) ts_sampledata = tsinfer.SampleData.from_tree_sequence(ts) - sg_sampledata = tsinfer.SgkitSampleData(zarr_path) + sg_sampledata = tsinfer.VariantData(zarr_path, "variant_ancestral_allele") ts_ancestors = tsinfer.generate_ancestors(ts_sampledata) sg_ancestors = tsinfer.generate_ancestors(sg_sampledata) assert np.array_equal(ts_ancestors.sites_position, sg_ancestors.sites_position) @@ -526,9 +629,8 @@ def test_missing_ancestral_allele(tmp_path): ds = sgkit.load_dataset(zarr_path) ds = ds.drop_vars(["variant_ancestral_allele"]) sgkit.save_dataset(ds, str(zarr_path) + ".tmp") - samples = tsinfer.SgkitSampleData(str(zarr_path) + ".tmp") with pytest.raises(ValueError, match="variant_ancestral_allele was not found"): - tsinfer.infer(samples) + tsinfer.VariantData(str(zarr_path) + ".tmp", "variant_ancestral_allele") @pytest.mark.skipif(sys.platform == "win32", reason="No cyvcf2 on Windows") @@ -549,12 +651,12 @@ def test_ancestral_missingness(tmp_path): ["variants"], ) ds = sgkit.load_dataset(str(zarr_path) + ".tmp") - sd = tsinfer.SgkitSampleData(str(zarr_path) + ".tmp") with pytest.warns( UserWarning, match=r"not found in the variant_allele array for the 4 [\s\S]*'💩': 2", ): - inf_ts = tsinfer.infer(sd) + sd = tsinfer.VariantData(str(zarr_path) + ".tmp", "variant_ancestral_allele") + inf_ts = tsinfer.infer(sd) for i, (inf_var, var) in enumerate(zip(inf_ts.variants(), ts.variants())): if i in [0, 11, 12, 15]: assert inf_var.site.metadata == {"inference_type": "parsimony"} @@ -574,7 +676,7 @@ def test_sgkit_ancestor(small_sd_fixture, tmp_path): sgkit.display_genotypes(ds) -class TestSgkitSampleDataErrors: +class TestVariantDataErrors: def test_missing_phase(self, tmp_path): path = tmp_path / "data.zarr" ds = sgkit.simulate_genotype_call_dataset(n_variant=3, n_sample=3) @@ -582,7 +684,7 @@ def test_missing_phase(self, tmp_path): with pytest.raises( ValueError, match="The call_genotype_phased array is missing" ): - tsinfer.SgkitSampleData(path) + tsinfer.VariantData(path, "variant_ancestral_allele") def test_phased(self, tmp_path): path = tmp_path / "data.zarr" @@ -591,15 +693,23 @@ def test_phased(self, tmp_path): ds["call_genotype"].dims, np.ones(ds["call_genotype"].shape, dtype=bool), ) + ds["variant_ancestral_allele"] = ( + ds["variant_position"].dims, + np.array(["A", "C", "G"], dtype="S1"), + ) sgkit.save_dataset(ds, path) - tsinfer.SgkitSampleData(path) + tsinfer.VariantData(path, "variant_ancestral_allele") def test_ploidy1_missing_phase(self, tmp_path): path = tmp_path / "data.zarr" # Ploidy==1 is always ok ds = sgkit.simulate_genotype_call_dataset(n_variant=3, n_sample=3, n_ploidy=1) + ds["variant_ancestral_allele"] = ( + ds["variant_position"].dims, + np.array(["A", "C", "G"], dtype="S1"), + ) sgkit.save_dataset(ds, path) - tsinfer.SgkitSampleData(path) + tsinfer.VariantData(path, "variant_ancestral_allele") def test_ploidy1_unphased(self, tmp_path): path = tmp_path / "data.zarr" @@ -608,8 +718,12 @@ def test_ploidy1_unphased(self, tmp_path): ds["call_genotype"].dims, np.zeros(ds["call_genotype"].shape, dtype=bool), ) + ds["variant_ancestral_allele"] = ( + ds["variant_position"].dims, + np.array(["A", "C", "G"], dtype="S1"), + ) sgkit.save_dataset(ds, path) - tsinfer.SgkitSampleData(path) + tsinfer.VariantData(path, "variant_ancestral_allele") def test_duplicate_positions(self, tmp_path): path = tmp_path / "data.zarr" @@ -617,7 +731,7 @@ def test_duplicate_positions(self, tmp_path): ds["variant_position"][2] = ds["variant_position"][1] sgkit.save_dataset(ds, path) with pytest.raises(ValueError, match="duplicate or out-of-order values"): - tsinfer.SgkitSampleData(path) + tsinfer.VariantData(path, "variant_ancestral_allele") def test_bad_order_positions(self, tmp_path): path = tmp_path / "data.zarr" @@ -625,7 +739,7 @@ def test_bad_order_positions(self, tmp_path): ds["variant_position"][0] = ds["variant_position"][2] - 0.5 sgkit.save_dataset(ds, path) with pytest.raises(ValueError, match="duplicate or out-of-order values"): - tsinfer.SgkitSampleData(path) + tsinfer.VariantData(path, "variant_ancestral_allele") def test_empty_alleles_not_at_end(self, tmp_path): path = tmp_path / "data.zarr" @@ -639,7 +753,7 @@ def test_empty_alleles_not_at_end(self, tmp_path): np.array(["C", "A", "A"], dtype="S1"), ) sgkit.save_dataset(ds, path) - samples = tsinfer.SgkitSampleData(path) + samples = tsinfer.VariantData(path, "variant_ancestral_allele") with pytest.raises(ValueError, match="Empty alleles must be at the end"): tsinfer.infer(samples) @@ -649,7 +763,7 @@ class TestSgkitMatchSamplesToDisk: @pytest.mark.parametrize("slice", [(0, 6), (0, 0), (0, 3), (12, 15)]) def test_match_samples_to_disk_write(self, slice, tmp_path, tmpdir): ts, zarr_path = tsutil.make_ts_and_zarr(tmp_path) - samples = tsinfer.SgkitSampleData(zarr_path) + samples = tsinfer.VariantData(zarr_path, "variant_ancestral_allele") ancestors = tsinfer.generate_ancestors(samples) anc_ts = tsinfer.match_ancestors(samples, ancestors) tsinfer.match_samples_slice_to_disk( @@ -664,7 +778,7 @@ def test_match_samples_to_disk_write(self, slice, tmp_path, tmpdir): def test_match_samples_to_disk_slice_error(self, tmp_path, tmpdir): ts, zarr_path = tsutil.make_ts_and_zarr(tmp_path) - samples = tsinfer.SgkitSampleData(zarr_path) + samples = tsinfer.VariantData(zarr_path, "variant_ancestral_allele") ancestors = tsinfer.generate_ancestors(samples) anc_ts = tsinfer.match_ancestors(samples, ancestors) with pytest.raises( @@ -678,7 +792,7 @@ def test_match_samples_to_disk_full(self, tmp_path, tmpdir): match_data_dir = tmpdir / "match_data" os.mkdir(match_data_dir) ts, zarr_path = tsutil.make_ts_and_zarr(tmp_path) - samples = tsinfer.SgkitSampleData(zarr_path) + samples = tsinfer.VariantData(zarr_path, "variant_ancestral_allele") ancestors = tsinfer.generate_ancestors(samples) anc_ts = tsinfer.match_ancestors(samples, ancestors) ts = tsinfer.match_samples(samples, anc_ts) diff --git a/tests/tsutil.py b/tests/tsutil.py index e17f5626..b6fd6795 100644 --- a/tests/tsutil.py +++ b/tests/tsutil.py @@ -447,10 +447,11 @@ def make_materialized_and_masked_sampledata(tmp_path, tmpdir): mat_ds = mat_ds.unify_chunks() sgkit.save_dataset(mat_ds, tmpdir / "subset.zarr", auto_rechunk=True) - mat_sd = tsinfer.SgkitSampleData(tmpdir / "subset.zarr") - mask_sd = tsinfer.SgkitSampleData( + mat_sd = tsinfer.VariantData(tmpdir / "subset.zarr", "variant_ancestral_allele") + mask_sd = tsinfer.VariantData( zarr_path, - sites_mask_name="variant_mask_foobar", - sgkit_samples_mask_name="samples_mask_foobar", + "variant_ancestral_allele", + site_mask_name_or_array="variant_mask_foobar", + sample_mask_name_or_array="samples_mask_foobar", ) return mat_sd, mask_sd, samples_mask, variant_mask diff --git a/tsinfer/formats.py b/tsinfer/formats.py index 689b5117..4a00b390 100644 --- a/tsinfer/formats.py +++ b/tsinfer/formats.py @@ -2292,56 +2292,67 @@ def populations(self): yield Population(j, metadata=metadata) -class SgkitSampleData(SampleData): +class VariantData(SampleData): FORMAT_NAME = "tsinfer-sgkit-sample-data" FORMAT_VERSION = (0, 1) - def __init__(self, path, sgkit_samples_mask_name=None, sites_mask_name=None): + def __init__( + self, + path, + ancestral_allele_name_or_array, + *, + sample_mask_name_or_array=None, + site_mask_name_or_array=None, + sites_time_name_or_array=None, + ): self.path = path self.data = zarr.open(path, mode="r") - self._sgkit_samples_mask_name = sgkit_samples_mask_name - self._sites_mask_name = sites_mask_name genotypes_arr = self.data["call_genotype"] _, self._num_individuals_before_mask, self.ploidy = genotypes_arr.shape - # Store sites_select - if self._sites_mask_name is None: - self.sites_select = np.full( - self.data["variant_position"].shape, True, dtype=bool + if site_mask_name_or_array is None: + site_mask_name_or_array = np.full( + self.data["variant_position"].shape, False, dtype=bool ) + elif isinstance(site_mask_name_or_array, np.ndarray): + pass else: try: - if ( - self.data[self._sites_mask_name].shape[0] - != self.data["variant_position"].shape[0] - ): - raise ValueError( - "Mask must be the same length as the number of unmasked sites" - ) - # We negate the mask as it is much easier in numpy to have True=keep - self.sites_select = ~(self.data[self._sites_mask_name].astype(bool)[:]) + site_mask_name_or_array = self.data[site_mask_name_or_array][:] except KeyError: raise ValueError( - f"The sites mask {self._sites_mask_name} was not found" + f"The sites mask {site_mask_name_or_array} was not found" f" in the dataset." ) - # Store individuals_select - if self._sgkit_samples_mask_name is None: - self.individuals_select = np.full( - self._num_individuals_before_mask, True, dtype=bool + if site_mask_name_or_array.shape[0] != self.data["variant_position"].shape[0]: + raise ValueError( + "Site mask array must be the same length as the number of unmasked sites" + ) + # We negate the mask as it is much easier in numpy to have True=keep + self.sites_select = ~site_mask_name_or_array.astype(bool) + + if sample_mask_name_or_array is None: + sample_mask_name_or_array = np.full( + self._num_individuals_before_mask, False, dtype=bool ) + elif isinstance(sample_mask_name_or_array, np.ndarray): + pass else: try: # We negate the mask as it is much easier in numpy to have True=keep - self.individuals_select = ~( - self.data[self._sgkit_samples_mask_name][:].astype(bool) - ) + sample_mask_name_or_array = self.data[sample_mask_name_or_array][:] except KeyError: raise ValueError( - f"The sgkit samples mask {self._sgkit_samples_mask_name} was not" + f"The samples mask {sample_mask_name_or_array} was not" f" found in the dataset." ) + if sample_mask_name_or_array.shape[0] != self._num_individuals_before_mask: + raise ValueError( + "Samples mask array must be the same length as the number of" + " individuals" + ) + self.individuals_select = ~sample_mask_name_or_array.astype(bool) self._num_sites = np.sum(self.sites_select) assert self.ploidy == self.data["call_genotype"].chunks[2] @@ -2359,6 +2370,69 @@ def __init__(self, path, sgkit_samples_mask_name=None, sites_mask_name=None): "These must be masked out to run tsinfer." ) + if sites_time_name_or_array is None: + self._sites_time = np.full(self.num_sites, tskit.UNKNOWN_TIME) + elif isinstance(sites_time_name_or_array, np.ndarray): + if sites_time_name_or_array.shape[0] != self.num_sites: + raise ValueError( + "Sites time array must be the same length as the number of selected" + " sites" + ) + self._sites_time = sites_time_name_or_array + else: + try: + self._sites_time = self.data[sites_time_name_or_array][:][ + self.sites_select + ] + except KeyError: + raise ValueError( + f"The sites time {sites_time_name_or_array} was not found" + f" in the dataset." + ) + + if isinstance(ancestral_allele_name_or_array, np.ndarray): + if ancestral_allele_name_or_array.shape[0] != self.num_sites: + raise ValueError( + "Ancestral allele array must be the same length as the number of" + " selected sites" + ) + self._sites_ancestral_allele = ancestral_allele_name_or_array + else: + try: + self._sites_ancestral_allele = self.data[ + ancestral_allele_name_or_array + ][:][self.sites_select] + except KeyError: + raise ValueError( + f"The ancestral allele {ancestral_allele_name_or_array} was not" + f" found in the dataset." + ) + self._sites_ancestral_allele = self._sites_ancestral_allele.astype(str) + unknown_alleles = collections.Counter() + converted = np.zeros(self.num_sites, dtype=np.int8) + for i, allele in enumerate(self._sites_ancestral_allele): + allele_index = -1 + try: + allele_index = np.where(allele == self.sites_alleles[i])[0][0] + except IndexError: + unknown_alleles[allele] += 1 + converted[i] = allele_index + tot = sum(unknown_alleles.values()) + if tot > 0: + frac_bad = tot / self.num_sites + frac_bad_per_type = [v / self.num_sites for v in unknown_alleles.values()] + summarise_unknown = [ + f"'{k}': {v} ({frac * 100:.2f}% of sites)" # Summarise per allele type + for (k, v), frac in zip(unknown_alleles.items(), frac_bad_per_type) + ] + warnings.warn( + "An ancestral allele was not found in the variant_allele array for " + + f"the {tot} sites ({frac_bad * 100 :.2f}%) listed below. " + + "They will be treated as of unknown ancestral state:\n " + + "\n ".join(summarise_unknown) + ) + self._sites_ancestral_allele = converted + # Create zarr arrays for convenience when iterating over chunks self.z_sites_select = zarr.array( self.sites_select, chunks=self.data["call_genotype"].chunks[0], dtype=bool @@ -2427,12 +2501,9 @@ def sites_metadata(self): except KeyError: return [{} for _ in range(self.num_sites)] - @functools.cached_property + @property def sites_time(self): - try: - return self.data["sites_time"][:][self.sites_select] - except KeyError: - return np.full(self.num_sites, tskit.UNKNOWN_TIME) + return self._sites_time @functools.cached_property def sites_position(self): @@ -2442,46 +2513,9 @@ def sites_position(self): def sites_alleles(self): return self.data["variant_allele"][:][self.sites_select].astype(str) - @functools.cached_property + @property def sites_ancestral_allele(self): - unknown_alleles = collections.Counter() - try: - string_allele = ( - self.data["variant_ancestral_allele"][:][self.sites_select] - ).astype(str) - except KeyError: - raise ValueError( - "variant_ancestral_allele was not found in the dataset." - "This is required for tsinfer. If you know that the " - "zero-th allele at each site is the ancestral state, you " - "add this to the dataset by running:\n" - "ds.update({'variant_ancestral_allele': " - "ds['variant_allele'][:,0]})\n" - ) - ret = np.zeros(self.num_sites, dtype=np.int8) - for i, allele in enumerate(string_allele): - allele_index = -1 - try: - allele_index = np.where(allele == self.sites_alleles[i])[0][0] - except IndexError: - unknown_alleles[allele] += 1 - ret[i] = allele_index - tot = sum(unknown_alleles.values()) - if tot > 0: - num_sites = len(string_allele) - frac_bad = tot / num_sites - frac_bad_per_type = [v / num_sites for v in unknown_alleles.values()] - summarise_unknown = [ - f"'{k}': {v} ({frac * 100:.2f}% of sites)" # Summarise per allele type - for (k, v), frac in zip(unknown_alleles.items(), frac_bad_per_type) - ] - warnings.warn( - "An ancestral allele was not found in the variant_allele array for " - + f"the {tot} sites ({frac_bad * 100 :.2f}%) listed below. " - + "They will be treated as of unknown ancestral state:\n " - + "\n ".join(summarise_unknown) - ) - return ret + return self._sites_ancestral_allele @functools.cached_property def sites_genotypes(self): diff --git a/tsinfer/inference.py b/tsinfer/inference.py index 29af5371..a1c13ff6 100644 --- a/tsinfer/inference.py +++ b/tsinfer/inference.py @@ -585,12 +585,13 @@ def match_ancestors( def match_ancestors_batch_init( working_dir, sample_data_path, + ancestral_allele_name_or_array, ancestor_data_path, min_work_per_job, *, max_num_partitions=None, - sgkit_samples_mask_name=None, - sites_mask_name=None, + sample_mask_name_or_array=None, + site_mask_name_or_array=None, recombination_rate=None, mismatch_ratio=None, path_compression=True, @@ -610,10 +611,11 @@ def match_ancestors_batch_init( working_dir.mkdir(parents=True, exist_ok=True) ancestors = formats.AncestorData.load(ancestor_data_path) - sample_data = formats.SgkitSampleData( + sample_data = formats.VariantData( sample_data_path, - sgkit_samples_mask_name=sgkit_samples_mask_name, - sites_mask_name=sites_mask_name, + ancestral_allele_name_or_array=ancestral_allele_name_or_array, + sample_mask_name_or_array=sample_mask_name_or_array, + site_mask_name_or_array=site_mask_name_or_array, ) ancestors._check_finalised() sample_data._check_finalised() @@ -663,9 +665,10 @@ def match_ancestors_batch_init( metadata = { "sample_data_path": str(sample_data_path), + "ancestral_allele_name_or_array": ancestral_allele_name_or_array, "ancestor_data_path": str(ancestor_data_path), - "sgkit_samples_mask_name": sgkit_samples_mask_name, - "sites_mask_name": sites_mask_name, + "sample_mask_name_or_array": sample_mask_name_or_array, + "site_mask_name_or_array": site_mask_name_or_array, "recombination_rate": recombination_rate, "mismatch_ratio": mismatch_ratio, "path_compression": path_compression, @@ -684,10 +687,11 @@ def match_ancestors_batch_init( def initialize_matcher(metadata, ancestors_ts=None, **kwargs): - sample_data = formats.SgkitSampleData( + sample_data = formats.VariantData( metadata["sample_data_path"], - sgkit_samples_mask_name=metadata["sgkit_samples_mask_name"], - sites_mask_name=metadata["sites_mask_name"], + ancestral_allele_name_or_array=metadata["ancestral_allele_name_or_array"], + sample_mask_name_or_array=metadata["sample_mask_name_or_array"], + site_mask_name_or_array=metadata["site_mask_name_or_array"], ) ancestors = formats.AncestorData.load(metadata["ancestor_data_path"]) sample_data._check_finalised()