Commit 3461a73
Updated merge.py to resolve a critical error of not mapping cells correctly. Used default_rng instead of np.random.seed(). Updated the calculation of rows_offsets. Updated the calculation of cellOrder. Used pandas to index the combined metadata. Major: updated the dump function to use the random order of cells. Added additional tests.
Gautam8387 committed Dec 17, 2024
1 parent 4e93e14 commit 3461a73
Showing 2 changed files with 64 additions and 17 deletions.
55 changes: 38 additions & 17 deletions scarf/merge.py
@@ -9,6 +9,7 @@
 from typing import Dict, List, Optional, Tuple, Union
 
 import numpy as np
+import pandas as pd
 import polars as pl
 import zarr
 from dask.array import from_array
@@ -176,7 +177,8 @@ def perform_randomization_rows(
             seed: Seed for randomization
         Returns:
         """
-        np.random.seed(seed)
+        rng = np.random.default_rng(seed=seed)
+        # np.random.seed(seed)
         chunkSize = np.array([x.rawData.chunksize[0] for x in self.assays])
         nCells = np.array([x.rawData.shape[0] for x in self.assays])
         permutations = {
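A minimal sketch (not from the commit) of what the seeding change buys: default_rng returns a local Generator, so the merge's randomization is reproducible for a given seed without mutating NumPy's global RNG state the way np.random.seed() does.

import numpy as np

seed = 42

# A fresh Generator built from the same seed reproduces the same
# permutation, so re-running the merge reorders cells identically.
perm1 = np.random.default_rng(seed=seed).permutation(10)
perm2 = np.random.default_rng(seed=seed).permutation(10)
assert (perm1 == perm2).all()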
@@ -190,13 +192,20 @@ def perform_randomization_rows(
             permutations_rows[key] = in_dict
 
         permutations_rows_offset = {}
-        for i in range(len(permutations)):
+        # for i in range(len(permutations)):
+        #     in__dict: dict[int, np.ndarray] = {}
+        #     last_key = i - 1 if i > 0 else 0
+        #     offset = nCells[last_key] + offset if i > 0 else 0  # noqa: F821
+        #     for j, arr in enumerate(permutations[i]):
+        #         in__dict[j] = arr + offset
+        #     permutations_rows_offset[i] = in__dict
+        offset = 0
+        for key, val_dict in permutations_rows.items():
             in__dict: dict[int, np.ndarray] = {}
-            last_key = i - 1 if i > 0 else 0
-            offset = nCells[last_key] + offset if i > 0 else 0  # noqa: F821
-            for j, arr in enumerate(permutations[i]):
-                in__dict[j] = arr + offset
-            permutations_rows_offset[i] = in__dict
+            for in_key, arrs in val_dict.items():
+                in__dict[in_key] = arrs + offset
+            permutations_rows_offset[key] = in__dict
+            offset += nCells[key]
 
         coordinates = []
         extra = []
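For context, a toy version of the corrected offset pass (hypothetical cell counts and permutations, not from the repository): each assay's within-assay row permutation is shifted by the running total of cells in all preceding assays.

import numpy as np

# Hypothetical example: two assays with 4 and 3 cells, one chunk each.
nCells = np.array([4, 3])
permutations_rows = {
    0: {0: np.array([2, 0, 3, 1])},  # within-assay permutation, assay 0
    1: {0: np.array([1, 2, 0])},     # within-assay permutation, assay 1
}

permutations_rows_offset = {}
offset = 0
for key, val_dict in permutations_rows.items():
    in__dict = {}
    for in_key, arrs in val_dict.items():
        in__dict[in_key] = arrs + offset  # shift into merged coordinates
    permutations_rows_offset[key] = in__dict
    offset += nCells[key]  # running total of cells seen so far

# Assay 1's rows now land after assay 0's 4 cells: [5, 6, 4]
print(permutations_rows_offset[1][0])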
@@ -207,7 +216,8 @@ def perform_randomization_rows(
                 continue
             coordinates.append([i, j])
 
-        coordinates_permutations = np.random.permutation(coordinates)
+        # coordinates_permutations = np.random.permutation(coordinates)
+        coordinates_permutations = rng.permutation(coordinates)
         if len(coordinates_permutations) > 0:
             coordinates_permutations = np.concatenate(
                 [coordinates_permutations, extra], axis=0
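One subtlety worth noting here (toy values, not from the commit): Generator.permutation on a 2-D input shuffles along the first axis only, so each [assay, chunk] pair survives the shuffle intact.

import numpy as np

rng = np.random.default_rng(seed=0)
coordinates = [[0, 0], [0, 1], [1, 0], [1, 1]]  # hypothetical (assay, chunk) pairs

# The list is coerced to a 2-D array; rows are shuffled as whole units.
coordinates_permutations = rng.permutation(coordinates)
assert sorted(map(tuple, coordinates_permutations)) == sorted(map(tuple, coordinates))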
@@ -250,7 +260,8 @@ def _ref_order_cell_idx(self) -> Dict[int, Dict[int, np.ndarray]]:
 
         offset = 0
         for i, (x, y) in enumerate(self.coordinates_permutations):
-            arr = self.permutations_rows_offset[x][y]
+            # arr = self.permutations_rows_offset[x][y]
+            arr = self.permutations_rows[x][y]
             arr = np.array(range(len(arr)))
             arr = arr + offset
             new_cells[x][y] = arr
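To illustrate this step with hypothetical block sizes: only the length of each block's permutation is used here, so blocks simply claim consecutive row ranges in the merged store, in the shuffled visit order.

import numpy as np

# Hypothetical shuffled visit order of (assay, chunk) blocks and their sizes.
coordinates_permutations = [(1, 0), (0, 1), (0, 0)]
block_sizes = {(0, 0): 4, (0, 1): 2, (1, 0): 3}

new_cells = {0: {}, 1: {}}
offset = 0
for x, y in coordinates_permutations:
    n = block_sizes[(x, y)]
    new_cells[x][y] = np.arange(n) + offset  # contiguous rows for this block
    offset += n

# Assay 1, chunk 0 fills rows 0-2; assay 0, chunk 1 fills rows 3-4; and so on.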
@@ -291,22 +302,25 @@ def _merge_cell_table(
                 a = a.with_columns(
                     [pl.Series("I", np.ones(len(a["ids"])).astype(bool))]
                 )
-            ret_val.append(a)
+            ret_val.append(a.to_pandas())
 
-        ret_val_df = pl.concat(
-            ret_val,
-            how="diagonal",  # Finds a union between the column schemas and fills missing column values with null
-        )
+        # ret_val_df = pl.concat(
+        #     ret_val,
+        #     how="diagonal",  # Finds a union between the column schemas and fills missing column values with null
+        # )
+        ret_val_df = pd.concat(ret_val, axis=0).reset_index(drop=True)
 
+        # Randomize the rows in chunks
         compiled_idx = [
             self.permutations_rows_offset[i][j]
             for i, j in self.coordinates_permutations
         ]
 
         compiled_idx = np.concatenate(compiled_idx)
-        ret_val_df = ret_val_df[
-            compiled_idx
-        ]  # Polars does not support iloc so we have to use this method
+        # ret_val_df = ret_val_df[
+        #     compiled_idx
+        # ]  # Polars does not support iloc so we have to use this method
+        ret_val_df = ret_val_df.iloc[compiled_idx]
         if sum([x.cells.N for x in self.assays]) != ret_val_df.shape[0]:
             raise AssertionError(
                 "Unexpected number of cells in the merged table. This is unexpected, "
@@ -604,6 +618,7 @@ def _dask_to_coo(
         for i, col_data in enumerate(computed_data.T):
             consolidated_idx = consolidation_map[order[i]]
             mat[:, consolidated_idx] += col_data
+
         return coo_matrix(mat)
 
     def dump(self, nthreads=4):
@@ -623,11 +638,17 @@
                 total=assay.rawData.numblocks[0],
                 desc=f"Writing data from assay {i+1}/{len(self.assays)} to merged file",
             ):
+                perm_order = self.permutations_rows[i][j]
+                perm_order = perm_order - perm_order.min()
+                # bring a to same order
+                block = block[perm_order, :]
                 a = self._dask_to_coo(block, feat_order, feat_order_map, nthreads)
                 row_idx = self.cellOrder[i][j]
+                # bring a to same order
                 self.assayGroup.set_coordinate_selection(
                     (a.row + row_idx.min(), a.col), a.data.astype(self.assayGroup.dtype)
                 )
+                # self.assayGroup[row_idx, :] = a
                 counter += a.shape[0]
             try:
                 assert counter == self.nCells
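Putting the dump-side pieces together, a runnable toy of the write path (hypothetical shapes; zarr and scipy assumed, not taken from the repository): rows are first reordered inside the block, then scattered into the rows that were reserved for that block in the merged array.

import numpy as np
import zarr
from scipy.sparse import coo_matrix

# Merged store with room for 7 cells across 3 features.
store = zarr.zeros((7, 3), chunks=(4, 3), dtype="int64")

block = np.arange(12).reshape(4, 3)         # one raw-data block (4 cells)
perm_order = np.array([6, 4, 5, 3])         # permuted positions of these cells
perm_order = perm_order - perm_order.min()  # -> local order [3, 1, 2, 0]
block = block[perm_order, :]                # reorder cells inside the block

a = coo_matrix(block)
row_idx = np.arange(3, 7)                   # rows reserved for this block
store.set_coordinate_selection(
    (a.row + row_idx.min(), a.col), a.data.astype(store.dtype)
)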
26 changes: 26 additions & 0 deletions scarf/tests/test_merger.py
@@ -17,6 +17,7 @@ def test_assay_merge(datastore):
         names=["self1", "self2"],
         merge_assay_name="RNA",
         prepend_text="",
+        overwrite=True,
     )
     writer.dump()
     tmp = zarr.open(fn + "/RNA/counts")
@@ -37,6 +38,7 @@ def test_dataset_merge_2(datastore):
         datasets=[datastore, datastore],
         names=["self1", "self2"],
         prepend_text="",
+        overwrite=True
     )
     writer.dump()
     # Check if the merged file has the correct shape and counts
@@ -68,6 +70,7 @@ def test_dataset_merge_3(datastore):
         datasets=[datastore, datastore, datastore],
         names=["self1", "self2", "self3"],
         prepend_text="",
+        overwrite=True
     )
     writer.dump()
     # Check if the merged file has the correct shape and counts
@@ -84,3 +87,26 @@
         datastore.assay2.rawData.compute().sum() * 3
     )
     remove(fn)
+
+def test_dataset_merge_cells(datastore):
+    from ..merge import DatasetMerge
+    from ..datastore.datastore import DataStore
+
+    fn = full_path("merged_zarr.zarr")
+    writer = DatasetMerge(
+        zarr_path=fn,
+        datasets=[datastore, datastore],
+        names=["self1", "self2"],
+        prepend_text="orig",
+        overwrite=True,
+    )
+    writer.dump()
+
+    ds = DataStore(
+        fn,
+        default_assay="RNA",
+    )
+
+    df = ds.cells.to_pandas_dataframe(ds.cells.columns)
+    df_diff = df[df['orig_RNA_nCounts'] != df['RNA_nCounts']]
+    assert len(df_diff) == 0
