diff --git a/deeprvat/deeprvat/train.py b/deeprvat/deeprvat/train.py index c2e2abd4..5e20cc2b 100644 --- a/deeprvat/deeprvat/train.py +++ b/deeprvat/deeprvat/train.py @@ -2,23 +2,33 @@ import gc import itertools import logging -import sys import pickle +import random import shutil +import sys from pathlib import Path from pprint import pformat, pprint from typing import Dict, Optional, Tuple -import torch.nn.functional as F -import numpy as np import click import math +import numpy as np import optuna import pandas as pd import pytorch_lightning as pl +import torch.nn.functional as F +import deeprvat.deeprvat.models as deeprvat_models import torch import yaml import zarr +from deeprvat.data import DenseGTDataset +from deeprvat.metrics import ( + AveragePrecisionWithLogits, + PearsonCorr, + PearsonCorrTorch, + RSquared, +) +from deeprvat.utils import suggest_hparams from numcodecs import Blosc from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.callbacks.early_stopping import EarlyStopping @@ -27,15 +37,6 @@ from torch.utils.data import DataLoader, Dataset, Subset from tqdm import tqdm -import deeprvat.deeprvat.models as deeprvat_models -from deeprvat.data import DenseGTDataset -from deeprvat.metrics import ( - PearsonCorr, - PearsonCorrTorch, - RSquared, - AveragePrecisionWithLogits, -) -from deeprvat.utils import suggest_hparams logging.basicConfig( format="[%(asctime)s] %(levelname)s:%(name)s: %(message)s", @@ -236,6 +237,8 @@ def __init__( ) self.cache_tensors = cache_tensors + if self.cache_tensors: + logger.info("Keeping all input tensors in main memory") for _, pheno_data in self.data.items(): if pheno_data["y"].shape == (pheno_data["input_tensor_zarr"].shape[0], 1): @@ -294,7 +297,7 @@ def __getitem__(self, index): annotations = ( self.data[pheno]["input_tensor"][idx] if self.cache_tensors - else self.data[pheno]["input_tensor_zarr"].oindex[idx, :, :, :] + else self.data[pheno]["input_tensor_zarr"][idx[0]:idx[-1] + 1, :, :, :] ) result[pheno] = { @@ -339,6 +342,7 @@ def __init__( upsampling_factor: int = 1, batch_size: Optional[int] = None, num_workers: Optional[int] = 0, + pin_memory: bool = False, cache_tensors: bool = False, ): logger.info("Intializing datamodule") @@ -401,6 +405,7 @@ def __init__( "train_proportion", "batch_size", "num_workers", + "pin_memory", "cache_tensors", ) @@ -440,7 +445,7 @@ def train_dataloader(self): cache_tensors=self.hparams.cache_tensors, ) return DataLoader( - dataset, batch_size=None, num_workers=self.hparams.num_workers + dataset, batch_size=None, num_workers=self.hparams.num_workers, pin_memory=self.hparams.pin_memory ) def val_dataloader(self): @@ -456,7 +461,7 @@ def val_dataloader(self): cache_tensors=self.hparams.cache_tensors, ) return DataLoader( - dataset, batch_size=None, num_workers=self.hparams.num_workers + dataset, batch_size=None, num_workers=self.hparams.num_workers, pin_memory=self.hparams.pin_memory )