Skip to content

Commit

Permalink
debugging new genotype
Browse files (browse the repository at this point in the history)
  • Loading branch information
HolEv committed Jan 20, 2024
1 parent 9943a91 commit db9a28e
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 5 deletions.
5 changes: 3 additions & 2 deletions deeprvat/data/dense_gt.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ def __getitem__(self, idx: int) -> torch.tensor:
all_sparse_variants,
sparse_genotype,
) = self.get_common_variants(sparse_variants, sparse_genotype)

# print('getitem')
rare_variant_annotations = self.get_rare_variants(
idx, all_sparse_variants, sparse_genotype
) #idx is not used by get_rare_variants
Expand Down Expand Up @@ -542,7 +542,7 @@ def setup_variants(
variants = dd.read_parquet(self.variant_filename, engine="pyarrow").compute()
variants = variants.set_index("id", drop=False)
variants = variants.drop(columns="matrix_index", errors="ignore")

self.variants_to_keep = variants['id']
if self.variants_to_keep is not None:
logger.info("Selecting subset of variants as defined by variants_to_keep")
variants = variants.loc[self.variants_to_keep]
Expand Down Expand Up @@ -789,6 +789,7 @@ def get_common_variants(
):
padding_mask = sparse_variants >= 0
if self.variants_to_keep is not None:
# print('filtering for variants_to_keep')
padding_mask &= np.isin(sparse_variants, self.variants_to_keep)

masked_sparse_variants = sparse_variants[padding_mask]
Expand Down
34 changes: 31 additions & 3 deletions deeprvat/data/rare.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,10 +97,28 @@ def embed(
return torch.tensor([])

variants_mapped = self.variant_map[variant_ids]
# print('variant_ids_unique')
# print(np.unique(variant_ids))
# print('variants_mapped')
# print(variants_mapped)
# print('variants_mapped_unique')
# print(np.unique(variants_mapped))
# print('variants_sum')
# print(sum(variants_mapped < 0))
# print(variants_mapped.shape)
mask = variants_mapped >= 0
# print('mask')
# print(mask.shape)
variant_ids = variant_ids[mask]
# print('variant_ids')
# print(variant_ids.shape)
# print('genotype')
# print(genotype.shape)
genotype = genotype[mask]
# print(genotype.shape)
# print(genotype)
rows = []
# print('embed')
for v, g in zip(variant_ids, genotype):
ids = self.exp_anno_id_indices[v] # np.ndarray
# homozygous variants are considered twice
Expand All @@ -109,11 +127,10 @@ def embed(
result = [[] for _ in range(len(self.genes))]
if len(rows) > 0:
rows = np.concatenate(rows)
# logger.info(f"rows {rows}")
for i in rows:
gene = self.gene_map[self.genes_np[i]] # NOTE: Changed
result[gene].append(self.annotation_df_np[i, :])

# print(result)
return result

def collate_fn(
Expand All @@ -131,8 +148,19 @@ def collate_fn(

n_samples = len(batch)
max_n_variants = max(len(gene) for sample in batch for gene in sample)
# print(f'max n variants {max_n_variants}')
n_annotations = len(self.annotations)
result = np.zeros(
# print('collate_fn')
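# execution got until here
# print('hot')
# (16, 17981, 34, 24931)  <- presumably the printed (n_samples, n_genes, n_annotations, max_n_variants) -- TODO confirm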
# print('ho')
# np.zeros(
# (n_samples, self.n_genes, n_annotations, max_n_variants), dtype=np.float32
# )
# print('ho')
# print((n_samples, self.n_genes, n_annotations, max_n_variants))
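result = np.zeros( # NOTE: this allocation is causing the bug -- presumably too large: a float32 array of shape (16, 17981, 34, 24931) would need roughly 1 TB (TODO confirm)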
(n_samples, self.n_genes, n_annotations, max_n_variants), dtype=np.float32
)
for i, sample in enumerate(batch):
Expand Down

0 comments on commit db9a28e

Please sign in to comment.