
Commit

subset samples in training_dataset rule
bfclarke committed Nov 30, 2023
1 parent 58db4c4 commit 72c8fc7
Showing 4 changed files with 576 additions and 34 deletions.
96 changes: 80 additions & 16 deletions deeprvat/deeprvat/train.py
@@ -78,12 +78,45 @@ def cli():
pass


def subset_samples(
input_tensor: torch.Tensor,
covariates: torch.Tensor,
y: torch.Tensor,
min_variant_count: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
# First sum over annotations (dim 2) for each variant in each gene.
# Then get the number of non-zero values across all variants in all
# genes for each sample.
n_samples_orig = input_tensor.shape[0]

# n_variants_per_sample = np.sum(
# np.sum(input_tensor.numpy(), axis=2) != 0, axis=(1, 2)
# )
# n_variant_mask = n_variants_per_sample >= min_variant_count
n_variant_mask = (
np.sum(np.any(input_tensor.numpy(), axis=(1, 2)), axis=1) >= min_variant_count
)

# Also make sure we don't have NaN values for y
nan_mask = ~y.squeeze().isnan()
mask = n_variant_mask & nan_mask.numpy()

# Subset all the tensors
input_tensor = input_tensor[mask]
covariates = covariates[mask]
y = y[mask]

logger.info(f"{input_tensor.shape[0]} / {n_samples_orig} samples kept")

return input_tensor, covariates, y
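
For intuition: the active mask treats the input as a (samples, genes, annotations, variants) tensor, counts a variant slot as present for a sample if any gene/annotation value at that slot is non-zero, and drops samples with fewer than min_variant_count such slots or a missing phenotype. A minimal sketch of the same masking on a toy tensor (shapes and values are illustrative only, not taken from the repository):

import numpy as np
import torch

# Toy input: 3 samples x 2 genes x 2 annotations x 4 variant slots (illustrative shapes).
input_tensor = torch.zeros(3, 2, 2, 4)
input_tensor[0, 0, 0, 0] = 1.0      # sample 0: 1 non-zero variant slot
input_tensor[1, 0, 0, :2] = 1.0     # sample 1: 2 non-zero variant slots
input_tensor[2, 1, 1, :3] = 1.0     # sample 2: 3 non-zero variant slots
y = torch.tensor([[0.5], [float("nan")], [1.2]])  # sample 1 has a missing phenotype

min_variant_count = 2
n_variant_mask = (
    np.sum(np.any(input_tensor.numpy(), axis=(1, 2)), axis=1) >= min_variant_count
)
nan_mask = ~y.squeeze().isnan()
mask = n_variant_mask & nan_mask.numpy()
print(mask)  # [False False  True] -> only sample 2 is kept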


def make_dataset_(
config: Dict,
debug: bool = False,
training_dataset_file: str = None,
training_dataset_file: Optional[str] = None,
pickle_only: bool = False,
):
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
n_phenotypes = config.get("n_phenotypes", None)
if n_phenotypes is not None:
if "seed_genes" in config:
@@ -164,12 +197,17 @@ def make_dataset_(
input_tensor = torch.cat(
[
F.pad(r, (0, max_n_variants - r.shape[-1]), value=pad_value)
for r in tqdm(rare_batches, file=sys.stdout)
for r in rare_batches
]
)
covariates = torch.cat([b["x_phenotypes"] for b in batches])
y = torch.cat([b["y"] for b in batches])

logger.info("Subsetting samples by min_variant_count and missing y values")
input_tensor, covariates, y = subset_samples(
input_tensor, covariates, y, config["training"]["min_variant_count"]
)

return input_tensor, covariates, y
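
The torch.cat above works only because each batch is first padded along its last (variant) dimension to a common max_n_variants via F.pad. A small standalone sketch of that padding step, with invented shapes and an assumed pad_value of 0:

import torch
import torch.nn.functional as F

# Two illustrative batches that differ only in the size of the trailing (variant) dimension.
rare_batches = [torch.ones(2, 3, 4, 5), torch.ones(2, 3, 4, 7)]
pad_value = 0.0  # assumed here; the real value comes from the dataset configuration

max_n_variants = max(r.shape[-1] for r in rare_batches)
input_tensor = torch.cat(
    [F.pad(r, (0, max_n_variants - r.shape[-1]), value=pad_value) for r in rare_batches]
)
print(input_tensor.shape)  # torch.Size([4, 3, 4, 7])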


@@ -240,11 +278,21 @@ def __init__(
)

self.cache_tensors = cache_tensors
if self.cache_tensors:
self.zarr_root = zarr.group()
elif temp_dir is not None:
temp_path = Path(resolve_path_with_env(temp_dir)) / "deeprvat_training"
temp_path.mkdir(parents=True, exist_ok=True)
self.input_tensor_dir = TemporaryDirectory(
prefix="training_data", dir=str(temp_path)
)
# Create root group here

self.chunksize = chunksize
if self.cache_tensors:
logger.info("Keeping all input tensors in main memory")

for _, pheno_data in self.data.items():
for pheno, pheno_data in self.data.items():
if pheno_data["y"].shape == (pheno_data["input_tensor_zarr"].shape[0], 1):
pheno_data["y"] = pheno_data["y"].squeeze()
elif pheno_data["y"].shape != (pheno_data["input_tensor_zarr"].shape[0],):
@@ -253,20 +301,34 @@
)

if self.cache_tensors:
pheno_data["input_tensor"] = pheno_data["input_tensor_zarr"][:]
zarr.copy(
pheno_data["input_tensor_zarr"],
self.zarr_root,
name=pheno,
chunks=(self.chunksize, None, None, None),
compressor=Blosc(clevel=1),
)
pheno_data["input_tensor_zarr"] = self.zarr_root[pheno]
# pheno_data["input_tensor"] = pheno_data["input_tensor_zarr"][:]
elif temp_dir is not None:
tensor_path = (
Path(self.input_tensor_dir.name) / pheno / "input_tensor.zarr"
)
zarr.copy(
pheno_data["input_tensor_zarr"],
zarr.DirectoryStore(tensor_path),
chunks=(self.chunksize, None, None, None),
compressor=Blosc(clevel=1),
)
pheno_data["input_tensor_zarr"] = zarr.open(tensor_path)

self.min_variant_count = min_variant_count
self.samples = {
pheno: pheno_data["samples"][split]
for pheno, pheno_data in self.data.items()
}
temp_path = (Path(resolve_path_with_env(temp_dir)) / "deeprvat_training"
if temp_dir is not None
else Path("deeprvat_training"))
temp_path.mkdir(parents=True, exist_ok=True)
self.input_tensor_dir = TemporaryDirectory(prefix="training_data", dir=str(temp_path))

self.subset_samples()
# self.subset_samples()

self.total_samples = sum([s.shape[0] for s in self.samples.values()])
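
The zarr.copy calls above re-chunk each phenotype's input tensor along the sample dimension (chunksize samples per chunk) and recompress it with a light Blosc setting, writing either into an in-memory group (cache_tensors) or into a temporary DirectoryStore. A standalone sketch of the same pattern, with made-up shapes and a hypothetical phenotype name:

import numpy as np
import zarr
from numcodecs import Blosc

# Source array laid out as (samples, genes, annotations, variants); shapes are illustrative.
source = zarr.array(np.random.rand(100, 4, 3, 8))

chunksize = 16  # assumed number of samples per chunk
root = zarr.group()  # in-memory group; a DirectoryStore-backed group works the same way
zarr.copy(
    source,
    root,
    name="some_phenotype",                 # hypothetical name
    chunks=(chunksize, None, None, None),  # None keeps the full extent of that dimension
    compressor=Blosc(clevel=1),            # fast, light compression
)
print(root["some_phenotype"].chunks)  # (16, 4, 3, 8)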

@@ -306,11 +368,12 @@ def __getitem__(self, index):
assert np.array_equal(idx, np.arange(idx[0], idx[-1] + 1))
slice_ = slice(idx[0], idx[-1] + 1)

annotations = (
self.data[pheno]["input_tensor"][slice_]
if self.cache_tensors
else self.data[pheno]["input_tensor_zarr"][slice_, :, :, :]
)
# annotations = (
# self.data[pheno]["input_tensor"][slice_]
# if self.cache_tensors
# else self.data[pheno]["input_tensor_zarr"][slice_, :, :, :]
# )
annotations = self.data[pheno]["input_tensor_zarr"][slice_, :, :, :]

result[pheno] = {
"indices": self.samples[pheno][slice_],
@@ -321,6 +384,7 @@

return result
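
The contiguity assert above exists so the zarr read can be a plain slice rather than a fancy index, which maps onto a single contiguous run of chunks. A small sketch of that access pattern with illustrative shapes:

import numpy as np
import zarr

# Illustrative array chunked along the sample dimension.
arr = zarr.array(np.arange(5 * 2 * 3 * 4).reshape(5, 2, 3, 4), chunks=(2, None, None, None))

idx = np.array([1, 2, 3])  # a contiguous run of sample indices for one batch
assert np.array_equal(idx, np.arange(idx[0], idx[-1] + 1))
slice_ = slice(idx[0], idx[-1] + 1)

annotations = arr[slice_, :, :, :]  # one contiguous read of samples 1..3
print(annotations.shape)  # (3, 2, 3, 4)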

# NOTE: This function is broken with current cache_tensors behavior
def subset_samples(self):
for pheno, pheno_data in self.data.items():
# First sum over annotations (dim 2) for each variant in each gene.
98 changes: 98 additions & 0 deletions lsf/lsf.yaml
@@ -0,0 +1,98 @@
__default__:
- "-q medium"
- "-R \"select[(hname != 'odcf-cn11u15' && hname != 'odcf-cn31u13' && hname != 'odcf-cn31u21' && hname != 'odcf-cn23u23')]\""


# For association testing pipelines

config:
- "-q short"

training_dataset:
- "-q long"

delete_burden_cache:
- "-q short"

choose_training_genes:
- "-q short"

best_cv_run:
- "-q short"
link_avg_burdens:
- "-q short"
best_bagging_run:
- "-q short"

train:
- "-q gpu"
- "-gpu num=1:gmem=10.7G"
- "-R \"select[(hname != 'e230-dgx2-1' && hname != 'e230-dgx2-2' && hname != 'e230-dgxa100-1' && hname != 'e230-dgxa100-2' && hname != 'e230-dgxa100-3' && hname != 'e230-dgxa100-4' && hname != 'e071-gpu06')]\""
# - "-R tensorcore"
# - "-L /bin/bash"

compute_burdens:
- "-q gpu-short"
- "-gpu num=1:j_exclusive=yes:mode=exclusive_process:gmem=15.7G"
- "-R \"select[(hname != 'e230-dgx2-1' && hname != 'e230-dgx2-2' && hname != 'e230-dgxa100-1' && hname != 'e230-dgxa100-2' && hname != 'e230-dgxa100-3' && hname != 'e230-dgxa100-4' && hname != 'e071-gpu06')]\""
- "-W 180"
# - "-R tensorcore"
# - "-L /bin/bash"

link_burdens:
- "-q medium"

compute_plof_burdens:
- "-q medium"

regress:
- "-q long"

combine_regression_chunks:
- "-q short"


# For CV (phenotype prediction) pipeline

deeprvat_config:
- "-q short"

deeprvat_plof_config:
- "-q short"

deeprvat_training_dataset:
- "-q long"

deeprvat_delete_burden_cache:
- "-q short"

deeprvat_best_cv_run:
- "-q short"

deeprvat_train_cv:
- "-q gpu-lowprio"
- "-gpu num=1:j_exclusive=yes:mode=exclusive_process:gmem=10.7G"
- "-R \"select[(hname != 'e230-dgx2-1' && hname != 'e230-dgx2-2' && hname != 'e230-dgxa100-1' && hname != 'e230-dgxa100-2' && hname != 'e230-dgxa100-3' && hname != 'e071-gpu06')]\""
# - "-R tensorcore"
# - "-L /bin/bash"

deeprvat_train_bagging:
- "-q gpu-lowprio"
- "-gpu num=1:j_exclusive=yes:mode=exclusive_process:gmem=10.7G"
- "-R \"select[(hname != 'e230-dgx2-1' && hname != 'e230-dgx2-2' && hname != 'e230-dgxa100-1' && hname != 'e230-dgxa100-2' && hname != 'e230-dgxa100-3' && hname != 'e230-dgxa100-4' && hname != 'e071-gpu06')]\""
# - "-R tensorcore"
# - "-L /bin/bash"

deeprvat_compute_burdens:
- "-q gpu-lowprio"
- "-gpu num=1:j_exclusive=yes:mode=exclusive_process:gmem=10.7G"
- "-R \"select[(hname != 'e230-dgx2-1' && hname != 'e230-dgx2-2' && hname != 'e230-dgxa100-1' && hname != 'e230-dgxa100-2' && hname != 'e230-dgxa100-3' && hname != 'e230-dgxa100-4' && hname != 'e071-gpu06')]\""
- "-W 180"
# - "-R tensorcore"
# - "-L /bin/bash"

deeprvat_compute_plof_burdens:
- "-q medium"

deeprvat_regress:
- "-q long"
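
These entries appear to be a Snakemake-style LSF cluster configuration: each rule name maps to a list of extra bsub flags, with __default__ used when a rule has no entry of its own. A hedged sketch of how such a file could be turned into a submission command (the fallback-only lookup and the helper below are illustrative; the actual profile may combine flags differently):

from typing import List
import yaml  # PyYAML

def bsub_flags(lsf_config: dict, rule: str) -> List[str]:
    """Return the LSF flags for a rule, falling back to __default__."""
    return lsf_config.get(rule, lsf_config["__default__"])

with open("lsf/lsf.yaml") as f:
    lsf_config = yaml.safe_load(f)

# e.g. the flags that would be prepended for the "train" rule
cmd = "bsub " + " ".join(bsub_flags(lsf_config, "train")) + " <job script>"
print(cmd)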
