Skip to content

Commit

Permalink
Black code
Browse files Browse the repository at this point in the history
  • Loading branch information
endast committed Apr 25, 2024
1 parent 96af15c commit 4e09428
Showing 1 changed file with 18 additions and 20 deletions.
38 changes: 18 additions & 20 deletions deeprvat/deeprvat/associate.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,30 +2,29 @@
import itertools
import logging
import math
import pickle
import os
import pickle
import re
import sys
from pathlib import Path
from pprint import pprint
from typing import Dict, List, Optional, Tuple, Union
from typing import Dict, List, Optional, Tuple

import click
import dask.dataframe as dd
import numpy as np
import pandas as pd
import pyranges as pr
import statsmodels.api as sm

import torch
import torch.nn as nn
import statsmodels.api as sm
import yaml
from bgen import BgenWriter
import zarr

from numcodecs import Blosc
from seak import scoretest

from statsmodels.tools.tools import add_constant
from torch.utils.data import DataLoader, Dataset, Subset
from tqdm import tqdm, trange
import zarr
import re

import deeprvat.deeprvat.models as deeprvat_models
from deeprvat.data import DenseGTDataset
Expand Down Expand Up @@ -303,12 +302,11 @@ def compute_burdens_(
logger.info(f"Writing chunks to {burdens_chunk_path}")

for i, batch in tqdm(
enumerate(dl),
file=sys.stdout,
total=(n_samples // batch_size + (n_samples % batch_size != 0)),
enumerate(dl),
file=sys.stdout,
total=(n_samples // batch_size + (n_samples % batch_size != 0)),
):


this_burdens, this_y, this_x, this_sampleid = get_burden(
batch, agg_models, device=device, skip_burdens=skip_burdens
)
Expand All @@ -326,7 +324,7 @@ def compute_burdens_(
burdens = zarr.open(
burdens_chunk_path / f"burdens.zarr",
mode="a",
shape=(n_total_samples,) + this_burdens.shape[1:],
shape=this_burdens.shape,
chunks=(1000, 1000, 1),
dtype=np.float32,
compressor=Blosc(clevel=compression_level),
Expand All @@ -338,23 +336,23 @@ def compute_burdens_(
y = zarr.open(
burdens_chunk_path / f"y.zarr",
mode="a",
shape=(n_total_samples,) + this_y.shape[1:],
shape=this_y.shape,
chunks=(None, None),
dtype=np.float32,
compressor=Blosc(clevel=compression_level),
)
x = zarr.open(
burdens_chunk_path / f"x.zarr",
mode="a",
shape=(n_total_samples,) + this_x.shape[1:],
shape=this_x.shape,
chunks=(None, None),
dtype=np.float32,
compressor=Blosc(clevel=compression_level),
)
sample_ids = zarr.open(
burdens_chunk_path / f"sample_ids.zarr",
mode="a",
shape=(n_total_samples),
shape=this_sampleid.shape,
chunks=(None),
dtype=np.float32,
compressor=Blosc(clevel=compression_level),
Expand All @@ -381,9 +379,9 @@ def compute_burdens_(
if not skip_burdens:
burdens[chunk_start:chunk_end] = chunk_burden

y[chunk_start:chunk_end] = chunk_y
x[chunk_start:chunk_end] = chunk_x
sample_ids[chunk_start:chunk_end] = chunk_sampleid
y = chunk_y
x = chunk_x
sample_ids = chunk_sampleid

if torch.cuda.is_available():
logger.info(
Expand Down

0 comments on commit 4e09428

Please sign in to comment.