performance optimizations
bfclarke committed Nov 20, 2023
1 parent 21758ce, commit 6761a90
Showing 1 changed file with 20 additions and 15 deletions.
deeprvat/deeprvat/train.py (35 changes: 20 additions & 15 deletions)
@@ -2,23 +2,33 @@
 import gc
 import itertools
 import logging
-import sys
 import pickle
 import random
 import shutil
+import sys
 from pathlib import Path
 from pprint import pformat, pprint
 from typing import Dict, Optional, Tuple
 
-import torch.nn.functional as F
-import numpy as np
 import click
+import math
+import numpy as np
 import optuna
 import pandas as pd
 import pytorch_lightning as pl
+import torch.nn.functional as F
+import deeprvat.deeprvat.models as deeprvat_models
 import torch
 import yaml
 import zarr
+from deeprvat.data import DenseGTDataset
+from deeprvat.metrics import (
+    AveragePrecisionWithLogits,
+    PearsonCorr,
+    PearsonCorrTorch,
+    RSquared,
+)
+from deeprvat.utils import suggest_hparams
 from numcodecs import Blosc
 from pytorch_lightning.callbacks import ModelCheckpoint
 from pytorch_lightning.callbacks.early_stopping import EarlyStopping
@@ -27,15 +37,6 @@
 from torch.utils.data import DataLoader, Dataset, Subset
 from tqdm import tqdm
 
-import deeprvat.deeprvat.models as deeprvat_models
-from deeprvat.data import DenseGTDataset
-from deeprvat.metrics import (
-    PearsonCorr,
-    PearsonCorrTorch,
-    RSquared,
-    AveragePrecisionWithLogits,
-)
-from deeprvat.utils import suggest_hparams
 
 logging.basicConfig(
     format="[%(asctime)s] %(levelname)s:%(name)s: %(message)s",
@@ -236,6 +237,8 @@ def __init__(
         )
 
         self.cache_tensors = cache_tensors
+        if self.cache_tensors:
+            logger.info("Keeping all input tensors in main memory")
 
         for _, pheno_data in self.data.items():
             if pheno_data["y"].shape == (pheno_data["input_tensor_zarr"].shape[0], 1):
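The cache_tensors flag trades main memory for repeated disk I/O: when it is set, the annotation tensors are read out of zarr once and every later __getitem__ is served from RAM. A minimal sketch of the pattern, using a hypothetical CachedTensorStore helper rather than the actual DeepRVAT class:

    import numpy as np
    import zarr

    class CachedTensorStore:
        # Hypothetical helper, not the DeepRVAT API: illustrates trading
        # RAM for per-access zarr chunk reads.
        def __init__(self, arr: zarr.Array, cache_tensors: bool = False):
            self.arr = arr
            self.cache_tensors = cache_tensors
            if self.cache_tensors:
                # One bulk decode up front; later indexing hits memory.
                self.cached = self.arr[:]

        def get(self, idx):
            if self.cache_tensors:
                return self.cached[idx]
            return self.arr[idx]  # falls back to on-demand zarr reads

    store = CachedTensorStore(zarr.array(np.zeros((100, 8))), cache_tensors=True)
    batch = store.get(np.arange(10))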
Expand Down Expand Up @@ -294,7 +297,7 @@ def __getitem__(self, index):
annotations = (
self.data[pheno]["input_tensor"][idx]
if self.cache_tensors
else self.data[pheno]["input_tensor_zarr"].oindex[idx, :, :, :]
else self.data[pheno]["input_tensor_zarr"][idx[0]:idx[-1] + 1, :, :, :]
)

result[pheno] = {
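A note on the indexing change above: .oindex is zarr's orthogonal (coordinate-list) indexing, a gather over arbitrary rows, while a basic slice lets zarr decode just the contiguous run of chunks covering the batch. The rewrite is only equivalent when idx is sorted and contiguous, since idx[0]:idx[-1] + 1 would otherwise pull in extra rows; the batch construction presumably guarantees this. An illustrative toy comparison (shapes invented, not from the repo):

    import numpy as np
    import zarr

    z = zarr.zeros((10_000, 64, 32, 4), chunks=(1_000, 64, 32, 4))
    idx = np.arange(2_000, 3_000)  # contiguous, sorted batch indices

    gathered = z.oindex[idx, :, :, :]        # coordinate gather
    sliced = z[idx[0]:idx[-1] + 1, :, :, :]  # contiguous chunk read

    # Equivalent ONLY because idx is contiguous and sorted.
    assert (gathered == sliced).all()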
Expand Down Expand Up @@ -339,6 +342,7 @@ def __init__(
upsampling_factor: int = 1,
batch_size: Optional[int] = None,
num_workers: Optional[int] = 0,
pin_memory: bool = False,
cache_tensors: bool = False,
):
logger.info("Intializing datamodule")
Expand Down Expand Up @@ -401,6 +405,7 @@ def __init__(
"train_proportion",
"batch_size",
"num_workers",
"pin_memory",
"cache_tensors",
)
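For context, Lightning's save_hyperparameters records the named __init__ arguments onto self.hparams, which is why the new flag is available as self.hparams.pin_memory in the dataloader methods below. A minimal sketch of that standard behavior (Demo is an invented class):

    import pytorch_lightning as pl

    class Demo(pl.LightningDataModule):
        def __init__(self, pin_memory: bool = False):
            super().__init__()
            # Stores pin_memory so it is readable as self.hparams.pin_memory.
            self.save_hyperparameters("pin_memory")

    assert Demo(pin_memory=True).hparams.pin_memory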

Expand Down Expand Up @@ -440,7 +445,7 @@ def train_dataloader(self):
cache_tensors=self.hparams.cache_tensors,
)
return DataLoader(
dataset, batch_size=None, num_workers=self.hparams.num_workers
dataset, batch_size=None, num_workers=self.hparams.num_workers, pin_memory=self.hparams.pin_memory
)

def val_dataloader(self):
@@ -456,7 +461,7 @@ def val_dataloader(self):
             cache_tensors=self.hparams.cache_tensors,
         )
         return DataLoader(
-            dataset, batch_size=None, num_workers=self.hparams.num_workers
+            dataset, batch_size=None, num_workers=self.hparams.num_workers, pin_memory=self.hparams.pin_memory
         )
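pin_memory=True asks the DataLoader to return batches in page-locked (pinned) host memory, which speeds up host-to-GPU copies and allows them to run asynchronously via non_blocking=True. It composes with batch_size=None as used here, where the dataset already yields whole batches. A sketch of the standard PyTorch usage, not specific to DeepRVAT:

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    ds = TensorDataset(torch.randn(1_000, 64))
    loader = DataLoader(ds, batch_size=32, num_workers=2, pin_memory=True)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    for (batch,) in loader:
        # The copy can overlap with compute only when the source tensor
        # lives in pinned memory.
        batch = batch.to(device, non_blocking=True)
        break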


