Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
benjeffery committed Jul 25, 2024
1 parent 0c5d895 commit ee3474f
Show file tree
Hide file tree
Showing 5 changed files with 171 additions and 111 deletions.
12 changes: 6 additions & 6 deletions tests/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -1411,15 +1411,15 @@ def test_cache_used_by_timing(self, tmpdir):
sample_data, ancestor_ts, match_data_dir=tmpdir
)
time2 = time.time() - t
assert time2 < time1 / 1.25
assert time2 < time1
final_ts1.tables.assert_equals(final_ts2.tables, ignore_provenance=True)


@pytest.mark.skipif(sys.platform == "win32", reason="No cyvcf2 on windows")
class TestBatchAncestorMatching:
def test_equivalance(self, tmp_path, tmpdir):
ts, zarr_path = tsutil.make_ts_and_zarr(tmp_path)
samples = tsinfer.SgkitSampleData(zarr_path)
samples = tsinfer.VariantData(zarr_path)
ancestors = tsinfer.generate_ancestors(
samples, path=str(tmpdir / "ancestors.zarr")
)
Expand All @@ -1436,7 +1436,7 @@ def test_equivalance(self, tmp_path, tmpdir):

def test_equivalance_many_at_once(self, tmp_path, tmpdir):
ts, zarr_path = tsutil.make_ts_and_zarr(tmp_path)
samples = tsinfer.SgkitSampleData(zarr_path)
samples = tsinfer.VariantData(zarr_path)
ancestors = tsinfer.generate_ancestors(
samples, path=str(tmpdir / "ancestors.zarr")
)
Expand All @@ -1459,7 +1459,7 @@ def test_equivalance_many_at_once(self, tmp_path, tmpdir):

def test_equivalance_with_partitions(self, tmp_path, tmpdir):
ts, zarr_path = tsutil.make_ts_and_zarr(tmp_path)
samples = tsinfer.SgkitSampleData(zarr_path)
samples = tsinfer.VariantData(zarr_path)
ancestors = tsinfer.generate_ancestors(
samples, path=str(tmpdir / "ancestors.zarr")
)
Expand All @@ -1485,7 +1485,7 @@ def test_equivalance_with_partitions(self, tmp_path, tmpdir):

def test_max_partitions(self, tmp_path, tmpdir):
ts, zarr_path = tsutil.make_ts_and_zarr(tmp_path)
samples = tsinfer.SgkitSampleData(zarr_path)
samples = tsinfer.VariantData(zarr_path)
ancestors = tsinfer.generate_ancestors(
samples, path=str(tmpdir / "ancestors.zarr")
)
Expand Down Expand Up @@ -1516,7 +1516,7 @@ def test_max_partitions(self, tmp_path, tmpdir):

def test_errors(self, tmp_path, tmpdir):
ts, zarr_path = tsutil.make_ts_and_zarr(tmp_path)
samples = tsinfer.SgkitSampleData(zarr_path)
samples = tsinfer.VariantData(zarr_path)
tsinfer.generate_ancestors(samples, path=str(tmpdir / "ancestors.zarr"))
metadata = tsinfer.match_ancestors_batch_init(
tmpdir / "work", zarr_path, tmpdir / "ancestors.zarr", 1000
Expand Down
85 changes: 50 additions & 35 deletions tests/test_sgkit.py → tests/test_variantdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def ts_to_dataset(ts, chunks=None, samples=None):
@pytest.mark.skipif(sys.platform == "win32", reason="No cyvcf2 on windows")
def test_sgkit_dataset_roundtrip(tmp_path):
ts, zarr_path = tsutil.make_ts_and_zarr(tmp_path)
samples = tsinfer.SgkitSampleData(zarr_path)
samples = tsinfer.VariantData(zarr_path)
inf_ts = tsinfer.infer(samples)
ds = sgkit.load_dataset(zarr_path)

Expand Down Expand Up @@ -120,7 +120,7 @@ def test_sgkit_individual_metadata_not_clobbered(tmp_path):
tskit.MetadataSchema.permissive_json()
)

samples = tsinfer.SgkitSampleData(zarr_path)
samples = tsinfer.VariantData(zarr_path)
inf_ts = tsinfer.infer(samples)
ds = sgkit.load_dataset(zarr_path)

Expand All @@ -139,7 +139,7 @@ def test_sgkit_dataset_accessors(tmp_path):
ts, zarr_path = tsutil.make_ts_and_zarr(
tmp_path, add_optional=True, shuffle_alleles=False
)
samples = tsinfer.SgkitSampleData(zarr_path)
samples = tsinfer.VariantData(zarr_path, sites_time_name_or_array="sites_time")
ds = sgkit.load_dataset(zarr_path)

assert samples.format_name == "tsinfer-sgkit-sample-data"
Expand Down Expand Up @@ -203,7 +203,7 @@ def test_sgkit_dataset_accessors(tmp_path):

# Need to shuffle for the ancestral allele test
ts, zarr_path = tsutil.make_ts_and_zarr(tmp_path, add_optional=True)
samples = tsinfer.SgkitSampleData(zarr_path)
samples = tsinfer.VariantData(zarr_path)
for i in range(ts.num_sites):
assert (
samples.sites_alleles[i][samples.sites_ancestral_allele[i]]
Expand All @@ -214,7 +214,7 @@ def test_sgkit_dataset_accessors(tmp_path):
@pytest.mark.skipif(sys.platform == "win32", reason="No cyvcf2 on windows")
def test_sgkit_accessors_defaults(tmp_path):
ts, zarr_path = tsutil.make_ts_and_zarr(tmp_path)
samples = tsinfer.SgkitSampleData(zarr_path)
samples = tsinfer.VariantData(zarr_path)
ds = sgkit.load_dataset(zarr_path)

default_schema = tskit.MetadataSchema.permissive_json().schema
Expand Down Expand Up @@ -254,7 +254,7 @@ def test_simulate_genotype_call_dataset(tmp_path):
ds = ts_to_dataset(ts)
ds.update({"variant_ancestral_allele": ds["variant_allele"][:, 0]})
ds.to_zarr(tmp_path, mode="w")
sd = tsinfer.SgkitSampleData(tmp_path)
sd = tsinfer.VariantData(tmp_path)
ts = tsinfer.infer(sd)
for v, ds_v, sd_v in zip(ts.variants(), ds.call_genotype, sd.sites_genotypes):
assert np.all(v.genotypes == ds_v.values.flatten())
Expand All @@ -271,7 +271,9 @@ def test_sgkit_variant_mask(self, tmp_path, sites):
for i in sites:
sites_mask[i] = False
tsutil.add_array_to_dataset("variant_mask_42", sites_mask, zarr_path)
samples = tsinfer.SgkitSampleData(zarr_path, sites_mask_name="variant_mask_42")
samples = tsinfer.VariantData(
zarr_path, site_mask_name_or_array="variant_mask_42"
)
assert samples.num_sites == len(sites)
assert np.array_equal(samples.sites_select, ~sites_mask)
assert np.array_equal(
Expand All @@ -298,12 +300,14 @@ def test_sgkit_variant_bad_mask_length(self, tmp_path):
ds = sgkit.load_dataset(zarr_path)
sites_mask = np.zeros(ds.sizes["variants"] + 1, dtype=int)
tsutil.add_array_to_dataset("variant_mask_foobar", sites_mask, zarr_path)
tsinfer.SgkitSampleData(zarr_path)
tsinfer.VariantData(zarr_path)
with pytest.raises(
ValueError,
match="Mask must be the same length as the number of unmasked sites",
):
tsinfer.SgkitSampleData(zarr_path, sites_mask_name="variant_mask_foobar")
tsinfer.VariantData(
zarr_path, site_mask_name_or_array="variant_mask_foobar"
)

def test_bad_select_length_at_iterator(self, tmp_path):
ts, zarr_path = tsutil.make_ts_and_zarr(tmp_path)
Expand All @@ -325,8 +329,8 @@ def test_sgkit_sample_mask(self, tmp_path, sample_list):
for i in sample_list:
samples_mask[i] = False
tsutil.add_array_to_dataset("samples_mask_69", samples_mask, zarr_path)
samples = tsinfer.SgkitSampleData(
zarr_path, sgkit_samples_mask_name="samples_mask_69"
samples = tsinfer.VariantData(
zarr_path, sample_mask_name_or_array="samples_mask_69"
)
assert samples.ploidy == 3
assert samples.num_individuals == len(sample_list)
Expand Down Expand Up @@ -380,10 +384,10 @@ def test_sgkit_sample_and_site_mask(self, tmp_path):
tsutil.add_array_to_dataset(
"samples_mask_foobar", samples_mask, zarr_path, dims=["samples"]
)
samples = tsinfer.SgkitSampleData(
samples = tsinfer.VariantData(
zarr_path,
sites_mask_name="variant_mask_foobar",
sgkit_samples_mask_name="samples_mask_foobar",
site_mask_name_or_array="variant_mask_foobar",
sample_mask_name_or_array="samples_mask_foobar",
)
genotypes = samples.sites_genotypes
ds_genotypes = ds.call_genotype.values[~variant_mask][
Expand All @@ -403,19 +407,19 @@ def test_sgkit_sample_and_site_mask(self, tmp_path):

def test_sgkit_missing_masks(self, tmp_path):
ts, zarr_path = tsutil.make_ts_and_zarr(tmp_path)
samples = tsinfer.SgkitSampleData(zarr_path)
samples = tsinfer.VariantData(zarr_path)
samples.individuals_select
samples.sites_select
with pytest.raises(
ValueError, match="The sites mask foobar was not found in the dataset."
):
tsinfer.SgkitSampleData(zarr_path, sites_mask_name="foobar")
tsinfer.VariantData(zarr_path, site_mask_name_or_array="foobar")
with pytest.raises(
ValueError,
match="The sgkit samples mask foobar2 was not found in the dataset.",
match="The samples mask foobar2 was not found in the dataset.",
):
samples = tsinfer.SgkitSampleData(
zarr_path, sgkit_samples_mask_name="foobar2"
samples = tsinfer.VariantData(
zarr_path, sample_mask_name_or_array="foobar2"
)
samples.individuals_select

Expand Down Expand Up @@ -500,7 +504,7 @@ def test_sgkit_subset_equivalence(self, tmp_path, tmpdir):
def test_sgkit_ancestral_allele_same_ancestors(tmp_path):
ts, zarr_path = tsutil.make_ts_and_zarr(tmp_path)
ts_sampledata = tsinfer.SampleData.from_tree_sequence(ts)
sg_sampledata = tsinfer.SgkitSampleData(zarr_path)
sg_sampledata = tsinfer.VariantData(zarr_path)
ts_ancestors = tsinfer.generate_ancestors(ts_sampledata)
sg_ancestors = tsinfer.generate_ancestors(sg_sampledata)
assert np.array_equal(ts_ancestors.sites_position, sg_ancestors.sites_position)
Expand All @@ -526,9 +530,8 @@ def test_missing_ancestral_allele(tmp_path):
ds = sgkit.load_dataset(zarr_path)
ds = ds.drop_vars(["variant_ancestral_allele"])
sgkit.save_dataset(ds, str(zarr_path) + ".tmp")
samples = tsinfer.SgkitSampleData(str(zarr_path) + ".tmp")
with pytest.raises(ValueError, match="variant_ancestral_allele was not found"):
tsinfer.infer(samples)
tsinfer.VariantData(str(zarr_path) + ".tmp")


@pytest.mark.skipif(sys.platform == "win32", reason="No cyvcf2 on Windows")
Expand All @@ -549,12 +552,12 @@ def test_ancestral_missingness(tmp_path):
["variants"],
)
ds = sgkit.load_dataset(str(zarr_path) + ".tmp")
sd = tsinfer.SgkitSampleData(str(zarr_path) + ".tmp")
with pytest.warns(
UserWarning,
match=r"not found in the variant_allele array for the 4 [\s\S]*'💩': 2",
):
inf_ts = tsinfer.infer(sd)
sd = tsinfer.VariantData(str(zarr_path) + ".tmp")
inf_ts = tsinfer.infer(sd)
for i, (inf_var, var) in enumerate(zip(inf_ts.variants(), ts.variants())):
if i in [0, 11, 12, 15]:
assert inf_var.site.metadata == {"inference_type": "parsimony"}
Expand All @@ -574,15 +577,15 @@ def test_sgkit_ancestor(small_sd_fixture, tmp_path):
sgkit.display_genotypes(ds)


class TestSgkitSampleDataErrors:
class TestVariantDataErrors:
def test_missing_phase(self, tmp_path):
path = tmp_path / "data.zarr"
ds = sgkit.simulate_genotype_call_dataset(n_variant=3, n_sample=3)
sgkit.save_dataset(ds, path)
with pytest.raises(
ValueError, match="The call_genotype_phased array is missing"
):
tsinfer.SgkitSampleData(path)
tsinfer.VariantData(path)

def test_phased(self, tmp_path):
path = tmp_path / "data.zarr"
Expand All @@ -591,15 +594,23 @@ def test_phased(self, tmp_path):
ds["call_genotype"].dims,
np.ones(ds["call_genotype"].shape, dtype=bool),
)
ds["variant_ancestral_allele"] = (
ds["variant_position"].dims,
np.array(["A", "C", "G"], dtype="S1"),
)
sgkit.save_dataset(ds, path)
tsinfer.SgkitSampleData(path)
tsinfer.VariantData(path)

def test_ploidy1_missing_phase(self, tmp_path):
path = tmp_path / "data.zarr"
# Ploidy==1 is always ok
ds = sgkit.simulate_genotype_call_dataset(n_variant=3, n_sample=3, n_ploidy=1)
ds["variant_ancestral_allele"] = (
ds["variant_position"].dims,
np.array(["A", "C", "G"], dtype="S1"),
)
sgkit.save_dataset(ds, path)
tsinfer.SgkitSampleData(path)
tsinfer.VariantData(path)

def test_ploidy1_unphased(self, tmp_path):
path = tmp_path / "data.zarr"
Expand All @@ -608,24 +619,28 @@ def test_ploidy1_unphased(self, tmp_path):
ds["call_genotype"].dims,
np.zeros(ds["call_genotype"].shape, dtype=bool),
)
ds["variant_ancestral_allele"] = (
ds["variant_position"].dims,
np.array(["A", "C", "G"], dtype="S1"),
)
sgkit.save_dataset(ds, path)
tsinfer.SgkitSampleData(path)
tsinfer.VariantData(path)

def test_duplicate_positions(self, tmp_path):
path = tmp_path / "data.zarr"
ds = sgkit.simulate_genotype_call_dataset(n_variant=3, n_sample=3, phased=True)
ds["variant_position"][2] = ds["variant_position"][1]
sgkit.save_dataset(ds, path)
with pytest.raises(ValueError, match="duplicate or out-of-order values"):
tsinfer.SgkitSampleData(path)
tsinfer.VariantData(path)

def test_bad_order_positions(self, tmp_path):
path = tmp_path / "data.zarr"
ds = sgkit.simulate_genotype_call_dataset(n_variant=3, n_sample=3, phased=True)
ds["variant_position"][0] = ds["variant_position"][2] - 0.5
sgkit.save_dataset(ds, path)
with pytest.raises(ValueError, match="duplicate or out-of-order values"):
tsinfer.SgkitSampleData(path)
tsinfer.VariantData(path)

def test_empty_alleles_not_at_end(self, tmp_path):
path = tmp_path / "data.zarr"
Expand All @@ -639,7 +654,7 @@ def test_empty_alleles_not_at_end(self, tmp_path):
np.array(["C", "A", "A"], dtype="S1"),
)
sgkit.save_dataset(ds, path)
samples = tsinfer.SgkitSampleData(path)
samples = tsinfer.VariantData(path)
with pytest.raises(ValueError, match="Empty alleles must be at the end"):
tsinfer.infer(samples)

Expand All @@ -649,7 +664,7 @@ class TestSgkitMatchSamplesToDisk:
@pytest.mark.parametrize("slice", [(0, 6), (0, 0), (0, 3), (12, 15)])
def test_match_samples_to_disk_write(self, slice, tmp_path, tmpdir):
ts, zarr_path = tsutil.make_ts_and_zarr(tmp_path)
samples = tsinfer.SgkitSampleData(zarr_path)
samples = tsinfer.VariantData(zarr_path)
ancestors = tsinfer.generate_ancestors(samples)
anc_ts = tsinfer.match_ancestors(samples, ancestors)
tsinfer.match_samples_slice_to_disk(
Expand All @@ -664,7 +679,7 @@ def test_match_samples_to_disk_write(self, slice, tmp_path, tmpdir):

def test_match_samples_to_disk_slice_error(self, tmp_path, tmpdir):
ts, zarr_path = tsutil.make_ts_and_zarr(tmp_path)
samples = tsinfer.SgkitSampleData(zarr_path)
samples = tsinfer.VariantData(zarr_path)
ancestors = tsinfer.generate_ancestors(samples)
anc_ts = tsinfer.match_ancestors(samples, ancestors)
with pytest.raises(
Expand All @@ -678,7 +693,7 @@ def test_match_samples_to_disk_full(self, tmp_path, tmpdir):
match_data_dir = tmpdir / "match_data"
os.mkdir(match_data_dir)
ts, zarr_path = tsutil.make_ts_and_zarr(tmp_path)
samples = tsinfer.SgkitSampleData(zarr_path)
samples = tsinfer.VariantData(zarr_path)
ancestors = tsinfer.generate_ancestors(samples)
anc_ts = tsinfer.match_ancestors(samples, ancestors)
ts = tsinfer.match_samples(samples, anc_ts)
Expand Down
8 changes: 4 additions & 4 deletions tests/tsutil.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,10 +447,10 @@ def make_materialized_and_masked_sampledata(tmp_path, tmpdir):
mat_ds = mat_ds.unify_chunks()
sgkit.save_dataset(mat_ds, tmpdir / "subset.zarr", auto_rechunk=True)

mat_sd = tsinfer.SgkitSampleData(tmpdir / "subset.zarr")
mask_sd = tsinfer.SgkitSampleData(
mat_sd = tsinfer.VariantData(tmpdir / "subset.zarr")
mask_sd = tsinfer.VariantData(
zarr_path,
sites_mask_name="variant_mask_foobar",
sgkit_samples_mask_name="samples_mask_foobar",
site_mask_name_or_array="variant_mask_foobar",
sample_mask_name_or_array="samples_mask_foobar",
)
return mat_sd, mask_sd, samples_mask, variant_mask
Loading

0 comments on commit ee3474f

Please sign in to comment.