Skip to content

Commit

Permalink
tests
Browse files Browse the repository at this point in the history
  • Loading branch information
benjeffery committed Jul 26, 2024
1 parent e705d66 commit e2c5449
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 36 deletions.
51 changes: 48 additions & 3 deletions tests/test_variantdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,7 @@ def test_sgkit_accessors_defaults(tmp_path):
)


@pytest.mark.skipif(sys.platform == "win32", reason="No cyvcf2 on windows")
def test_variantdata_sites_time_default(tmp_path):
ts, zarr_path = tsutil.make_ts_and_zarr(tmp_path)
samples = tsinfer.VariantData(zarr_path, "variant_ancestral_allele")
Expand All @@ -259,15 +260,14 @@ def test_variantdata_sites_time_default(tmp_path):
)


@pytest.mark.skipif(sys.platform == "win32", reason="No cyvcf2 on windows")
def test_variantdata_sites_time_array(tmp_path):
ts, zarr_path = tsutil.make_ts_and_zarr(tmp_path)
sites_time = np.arange(ts.num_sites)
samples = tsinfer.VariantData(
zarr_path, "variant_ancestral_allele", sites_time_name_or_array=sites_time
)

assert np.array_equal(samples.sites_time, sites_time)

wrong_length_sites_time = np.arange(ts.num_sites + 1)
with pytest.raises(
ValueError,
Expand Down Expand Up @@ -338,7 +338,8 @@ def test_sgkit_variant_bad_mask_length(self, tmp_path):
tsinfer.VariantData(zarr_path, "variant_ancestral_allele")
with pytest.raises(
ValueError,
match="Mask must be the same length as the number of unmasked sites",
match="Site mask array must be the same length as the number of"
" unmasked sites",
):
tsinfer.VariantData(
zarr_path,
Expand Down Expand Up @@ -477,6 +478,50 @@ def test_variantdata_default_masks(self, tmp_path):
samples.samples_select, np.full(samples.num_samples, True, dtype=bool)
)

def test_variantdata_mask_as_arrays(self, tmp_path):
    """
    Masks supplied directly as numpy boolean arrays should be exposed,
    negated, as the ``individuals_select`` and ``sites_select`` arrays.
    """
    ts, zarr_path = tsutil.make_ts_and_zarr(tmp_path)

    # Start with nothing masked, then mask out one individual and one site.
    masked_individuals = np.zeros(ts.num_individuals, dtype=bool)
    masked_individuals[3] = True
    masked_sites = np.zeros(ts.num_sites, dtype=bool)
    masked_sites[5] = True

    variant_data = tsinfer.VariantData(
        zarr_path,
        "variant_ancestral_allele",
        sample_mask_name_or_array=masked_individuals,
        site_mask_name_or_array=masked_sites,
    )

    # The select arrays use True=keep, i.e. the inverse of the masks.
    keep_individuals = ~masked_individuals
    keep_sites = ~masked_sites
    assert np.array_equal(variant_data.individuals_select, keep_individuals)
    assert np.array_equal(variant_data.sites_select, keep_sites)

def test_variantdata_incorrect_mask_lengths(self, tmp_path):
    """
    Passing a mask array of the wrong length should raise a ValueError
    whose message identifies which mask was at fault.
    """
    _, zarr_path = tsutil.make_ts_and_zarr(tmp_path)

    # Each case: (keyword argument, wrongly sized mask, expected message).
    # NOTE(review): presumably the fixture has neither 2 individuals nor
    # 3 sites, so both arrays are the wrong length -- confirm in tsutil.
    bad_mask_cases = (
        (
            "sample_mask_name_or_array",
            np.array([True, False]),
            "Samples mask array must be the same length as the"
            " number of individuals",
        ),
        (
            "site_mask_name_or_array",
            np.array([True, False, True]),
            "Site mask array must be the same length as the number of"
            " unmasked sites",
        ),
    )

    for mask_kwarg, bad_mask, expected_message in bad_mask_cases:
        with pytest.raises(ValueError, match=expected_message):
            tsinfer.VariantData(
                zarr_path,
                "variant_ancestral_allele",
                **{mask_kwarg: bad_mask},
            )

def test_sgkit_subset_equivalence(self, tmp_path, tmpdir):
(
mat_sd,
Expand Down
54 changes: 21 additions & 33 deletions tsinfer/formats.py
Original file line number Diff line number Diff line change
Expand Up @@ -2311,60 +2311,48 @@ def __init__(
genotypes_arr = self.data["call_genotype"]
_, self._num_individuals_before_mask, self.ploidy = genotypes_arr.shape

# Store sites_select
if site_mask_name_or_array is None:
self.sites_select = np.full(
self.data["variant_position"].shape, True, dtype=bool
site_mask_name_or_array = np.full(
self.data["variant_position"].shape, False, dtype=bool
)
elif isinstance(site_mask_name_or_array, np.ndarray):
if (
site_mask_name_or_array.shape[0]
!= self.data["variant_position"].shape[0]
):
raise ValueError(
"Mask array must be the same length as the number of unmasked sites"
)
self.sites_select = ~site_mask_name_or_array.astype(bool)
pass
else:
try:
if (
self.data[site_mask_name_or_array].shape[0]
!= self.data["variant_position"].shape[0]
):
raise ValueError(
"Mask must be the same length as the number of unmasked sites"
)
# We negate the mask as it is much easier in numpy to have True=keep
self.sites_select = ~(
self.data[site_mask_name_or_array].astype(bool)[:]
)
site_mask_name_or_array = self.data[site_mask_name_or_array][:]
except KeyError:
raise ValueError(
f"The sites mask {site_mask_name_or_array} was not found"
f" in the dataset."
)
if site_mask_name_or_array.shape[0] != self.data["variant_position"].shape[0]:
raise ValueError(
"Site mask array must be the same length as the number of unmasked sites"
)
# We negate the mask as it is much easier in numpy to have True=keep
self.sites_select = ~site_mask_name_or_array.astype(bool)

if sample_mask_name_or_array is None:
self.individuals_select = np.full(
self._num_individuals_before_mask, True, dtype=bool
sample_mask_name_or_array = np.full(
self._num_individuals_before_mask, False, dtype=bool
)
elif isinstance(sample_mask_name_or_array, np.ndarray):
if sample_mask_name_or_array.shape[0] != self._num_individuals_before_mask:
raise ValueError(
"Samples mask array must be the same length as the number of"
" individuals"
)
self.individuals_select = ~sample_mask_name_or_array.astype(bool)
pass
else:
try:
# We negate the mask as it is much easier in numpy to have True=keep
self.individuals_select = ~(
self.data[sample_mask_name_or_array][:].astype(bool)
)
sample_mask_name_or_array = self.data[sample_mask_name_or_array][:]
except KeyError:
raise ValueError(
f"The samples mask {sample_mask_name_or_array} was not"
f" found in the dataset."
)
if sample_mask_name_or_array.shape[0] != self._num_individuals_before_mask:
raise ValueError(
"Samples mask array must be the same length as the number of"
" individuals"
)
self.individuals_select = ~sample_mask_name_or_array.astype(bool)

self._num_sites = np.sum(self.sites_select)
assert self.ploidy == self.data["call_genotype"].chunks[2]
Expand Down

0 comments on commit e2c5449

Please sign in to comment.