Skip to content

Commit

Permalink
CrDirReader Update [Polars -> Pandas] (#130)
Browse files Browse the repository at this point in the history
* Revert to Pandas from Polars for improved efficiency

* Updated Pandas and Code clean-up

* Refactor process_batch to use Polars for faster aggregation and filtering
  • Loading branch information
Gautam8387 authored Oct 23, 2024
1 parent edca2ed commit ca7d6cc
Showing 1 changed file with 73 additions and 57 deletions.
130 changes: 73 additions & 57 deletions scarf/readers.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@
- LoomReader: A class to read in data in the form of a Loom file.
"""

import math
import os
from abc import ABC, abstractmethod
from typing import Generator, Dict, List, Optional, Tuple
from typing import IO
import math
from typing import IO, Dict, Generator, List, Optional, Tuple

import h5py
import numpy as np
import pandas as pd
Expand Down Expand Up @@ -392,38 +392,47 @@ def _read_dataset(self, key: Optional[str] = None):
vals = None
return vals

def read_header(self) -> pl.DataFrame:
header = pl.read_csv(
def read_header(self) -> pd.DataFrame:
header = pd.read_csv(
self.matFn,
comment_prefix = '%',
separator=self.sep,
has_header=False,
n_rows=1,
new_columns=["nFeatures", "nCells", "nCounts"],
comment="%",
sep=self.sep,
header=None,
nrows=1,
names=["nFeatures", "nCells", "nCounts"],
)
if header['nCells'][0] == 0 and self.nCells > 0:
raise ValueError("ERROR: Barcode count in MTX header is 0 but barcodes are present in the barcodes file")
if header['nCells'][0] > 0 and self.nCells == 0:
raise ValueError("ERROR: Barcode count in MTX header is greater than 0 but no barcodes are present in the barcodes file")
if header['nCells'][0] == 0 and self.nCells == 0:
raise ValueError("ERROR: Barcode count in MTX header and barcodes file is 0. No data to read")
if header["nCells"][0] == 0 and self.nCells > 0:
raise ValueError(
"ERROR: Barcode count in MTX header is 0 but barcodes are present in the barcodes file"
)
if header["nCells"][0] > 0 and self.nCells == 0:
raise ValueError(
"ERROR: Barcode count in MTX header is greater than 0 but no barcodes are present in the barcodes file"
)
if header["nCells"][0] == 0 and self.nCells == 0:
raise ValueError(
"ERROR: Barcode count in MTX header and barcodes file is 0. No data to read"
)
return header

def process_batch(self, dfs: pl.DataFrame, filtering_cutoff: int) -> List:
def process_batch(self, dfs: List[pd.DataFrame], filtering_cutoff: int) -> np.array:
"""Returns a list of valid barcodes after filtering out background barcodes for a given batch.
Args:
dfs: A list of Pandas DataFrames, each containing a chunk of data from the MTX file.
filtering_cutoff: The cutoff value for filtering out background barcodes
"""
dfs_ = dfs.group_by('barcode').agg(pl.sum('count'))
pl_dfs = [pl.DataFrame(df) for df in dfs]
pl_dfs = pl.concat(pl_dfs)
dfs_ = pl_dfs.group_by('barcode').agg(pl.sum('count'))
dfs_ = dfs_.filter(pl.col('count') > filtering_cutoff)
return np.sort(dfs_['barcode'])

def _get_valid_barcodes(
self, filtering_cutoff: int,
batch_size: int = int(10e4),
lines_in_mem: int = int(10e6)
self,
filtering_cutoff: int,
batch_size: int = int(10e3),
lines_in_mem: int = int(10e6),
) -> np.ndarray:
"""Returns a list of valid barcodes after filtering out background barcodes.
Expand All @@ -433,48 +442,53 @@ def _get_valid_barcodes(
lines_in_mem: The number of lines to read into memory
"""
test_counter = 0
matrixIO = pl.scan_csv(
self.matFn,
comment_prefix='%',
# skip_rows=3,
skip_rows_after_header=1,
separator=self.sep,
has_header=False,
matrixIO = pd.read_csv(
self.matFn,
comment="%",
sep=self.sep,
header=0,
chunksize=lines_in_mem,
names=["gene", "barcode", "count"],
)
assert len(matrixIO.collect_schema().names()) == 3
matrixIO = matrixIO.rename({'column_1': 'gene', 'column_2': 'barcode', 'column_3': 'count'})

header = self.read_header()
nChunks = math.ceil(header["nCounts"][0] / lines_in_mem)
test_counter = 0
valid_idx = []
start = 1
dfs = pl.DataFrame()
for i in tqdmbar(
range(nChunks), desc="Filtering out background barcodes"

dfs = []
for chunk in tqdmbar(
# range(nChunks),
matrixIO,
total=nChunks,
desc="Filtering out background barcodes",
):
chunk = matrixIO.slice(i*lines_in_mem, lines_in_mem).collect()
# Check if we've reached or exceeded the current batch boundary
if (chunk[-1]['barcode'][0] - start) >= batch_size: # If the last "cell id" is greater than the start + batch size
if (
(chunk.iloc[-1]["barcode"] - start) >= batch_size
): # If the last "cell id" is greater than the start + batch size
# Filter rows in the current chunk that belong to the current batch
idx = np.array(chunk['barcode'] < (batch_size + start)) # This is the crucial line. This makes sure that if any cell ID is spread over multiple chunks, it is not missed, as any cell ID that is less than the batch size + start is included.
idx = np.array(
chunk["barcode"].values < (batch_size + start)
) # This is the crucial line. This makes sure that if any cell ID is spread over multiple chunks, it is not missed, as any cell ID that is less than the batch size + start is included.
# If no rows belong to the current batch, move to the next batch.
if idx.sum() == 0:
dfs = pl.concat([dfs, chunk])
dfs.append(chunk)
start += batch_size
test_counter += len(chunk)
continue
# Process the rows belonging to the current batch
mask_pos = np.where(idx)[0]
mask_neg = np.where(~idx)[0]
dfs = pl.concat([dfs, chunk[mask_pos]])
dfs.append(chunk.iloc[mask_pos])
valid_idx.append(self.process_batch(dfs, filtering_cutoff))
# Prepare for the next batch
del dfs
dfs = chunk[mask_neg]
dfs = [chunk.iloc[mask_neg]]
start += batch_size
else:
# If we haven't reached the batch boundary, accumulate the chunk
dfs = pl.concat([dfs, chunk])
dfs.append(chunk)
test_counter += len(chunk)
# Process any remaining data after the main loop
if len(dfs) > 0:
Expand Down Expand Up @@ -512,7 +526,7 @@ def cell_names(self) -> List[str]:

def rename_batches(self, collect: List[pl.DataFrame], batch_size: int) -> List:
df = pl.concat(collect)
barcodes = np.array(df['barcode'])
barcodes = np.array(df["barcode"])
count_hash = {}
for i, x in enumerate(np.unique(barcodes)):
count_hash[x] = i
Expand All @@ -535,14 +549,14 @@ def consume(
dtype: The data type of the matrix.
"""
matrixIO = pl.read_csv_batched(
self.matFn,
has_header=False,
self.matFn,
has_header=False,
separator=self.sep,
comment_prefix="%",
skip_rows_after_header=1,
new_columns=['gene', 'barcode', 'count'],
schema_overrides={'gene': pl.Int64, 'barcode': pl.Int64, 'count': pl.Int64},
batch_size=lines_in_mem
skip_rows_after_header=1,
new_columns=["gene", "barcode", "count"],
schema_overrides={"gene": pl.Int64, "barcode": pl.Int64, "count": pl.Int64},
batch_size=lines_in_mem,
)
unique_list = []
collect = []
Expand All @@ -551,20 +565,20 @@ def consume(
if chunk is None:
break
chunk = chunk[0]
chunk = chunk.filter(pl.col('barcode').is_in(self.validBarcodeIdx))
in_uniques = np.unique(chunk['barcode'])
chunk = chunk.filter(pl.col("barcode").is_in(self.validBarcodeIdx))
in_uniques = np.unique(chunk["barcode"])
unique_list.extend(in_uniques)
unique_list = list(set(unique_list))
if len(unique_list) > batch_size:
diff = batch_size - (len(unique_list) - len(in_uniques))
mask_pos = in_uniques[:diff]
mask_neg = in_uniques[diff:]
extra = chunk.filter(pl.col('barcode').is_in(mask_pos))
extra = chunk.filter(pl.col("barcode").is_in(mask_pos))
collect.append(extra)
collect = self.rename_batches(collect, batch_size)
mtx = self.to_sparse(np.array(collect), dtype=dtype)
yield mtx
left_out = chunk.filter(pl.col('barcode').is_in(mask_neg))
left_out = chunk.filter(pl.col("barcode").is_in(mask_neg))
collect = []
unique_list = list(mask_neg)
collect.append(left_out)
Expand Down Expand Up @@ -635,8 +649,9 @@ def __init__(
self.obsmAttrsKey: self._validate_group(self.obsmAttrsKey),
self.matrixKey: self._validate_group(self.matrixKey),
}
self.nCells, self.nFeatures = self._get_n(self.cellAttrsKey), self._get_n(
self.featureAttrsKey
self.nCells, self.nFeatures = (
self._get_n(self.cellAttrsKey),
self._get_n(self.featureAttrsKey),
)
self.cellIdsKey = self._fix_name_key(self.cellAttrsKey, cell_ids_key)
self.featIdsKey = self._fix_name_key(self.featureAttrsKey, feature_ids_key)
Expand Down Expand Up @@ -809,8 +824,9 @@ def _get_col_data(
if i in ignore_keys:
continue
if isinstance(self.h5[group][i], h5py.Dataset):
yield i, self._replace_category_values(
self.h5[group][i][:], i, group
yield (
i,
self._replace_category_values(self.h5[group][i][:], i, group),
)

def _get_obsm_data(
Expand All @@ -832,7 +848,7 @@ def _get_obsm_data(
yield f"{i}{j+1}", g[:, j]
else:
logger.warning(
f"Reading of obsm failed because it either does not exist or is not in expected format" # noqa: F541
f"Reading of obsm failed because it either does not exist or is not in expected format" # noqa: F541
)

def get_cell_columns(self) -> Generator[Tuple[str, np.ndarray], None, None]:
Expand Down

0 comments on commit ca7d6cc

Please sign in to comment.