Skip to content

Commit

Permalink
Additional tests
Browse files Browse the repository at this point in the history
  • Loading branch information
benjeffery committed Jul 26, 2024
1 parent 7ddf477 commit e705d66
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 1 deletion.
41 changes: 41 additions & 0 deletions tests/test_variantdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,37 @@ def test_sgkit_accessors_defaults(tmp_path):
)


def test_variantdata_sites_time_default(tmp_path):
ts, zarr_path = tsutil.make_ts_and_zarr(tmp_path)
samples = tsinfer.VariantData(zarr_path, "variant_ancestral_allele")

assert (
np.all(np.isnan(samples.sites_time))
and samples.sites_time.size == samples.num_sites
)


def test_variantdata_sites_time_array(tmp_path):
ts, zarr_path = tsutil.make_ts_and_zarr(tmp_path)
sites_time = np.arange(ts.num_sites)
samples = tsinfer.VariantData(
zarr_path, "variant_ancestral_allele", sites_time_name_or_array=sites_time
)

assert np.array_equal(samples.sites_time, sites_time)

wrong_length_sites_time = np.arange(ts.num_sites + 1)
with pytest.raises(
ValueError,
match="Sites time array must be the same length as the number of selected sites",
):
tsinfer.VariantData(
zarr_path,
"variant_ancestral_allele",
sites_time_name_or_array=wrong_length_sites_time,
)


def test_simulate_genotype_call_dataset(tmp_path):
# Test that byte alleles are correctly converted to string
ts = msprime.sim_ancestry(4, sequence_length=1000, random_seed=123)
Expand Down Expand Up @@ -436,6 +467,16 @@ def test_sgkit_missing_masks(self, tmp_path):
)
samples.individuals_select

def test_variantdata_default_masks(self, tmp_path):
ts, zarr_path = tsutil.make_ts_and_zarr(tmp_path)
samples = tsinfer.VariantData(zarr_path, "variant_ancestral_allele")
assert np.array_equal(
samples.sites_select, np.full(samples.num_sites, True, dtype=bool)
)
assert np.array_equal(
samples.samples_select, np.full(samples.num_samples, True, dtype=bool)
)

def test_sgkit_subset_equivalence(self, tmp_path, tmpdir):
(
mat_sd,
Expand Down
2 changes: 1 addition & 1 deletion tsinfer/formats.py
Original file line number Diff line number Diff line change
Expand Up @@ -2387,7 +2387,7 @@ def __init__(
elif isinstance(sites_time_name_or_array, np.ndarray):
if sites_time_name_or_array.shape[0] != self.num_sites:
raise ValueError(
"Sites time array must be the same length as the number selected"
"Sites time array must be the same length as the number of selected"
" sites"
)
self._sites_time = sites_time_name_or_array
Expand Down

0 comments on commit e705d66

Please sign in to comment.