DatasetMerge fix (#137)
* Updated merge.py to resolve a critical error where cells were not mapped correctly. Used default_rng instead of np.random.seed() (see the seeding sketch below). Updated the calculation of rows_offsets and of cellOrder. Used pandas to index the combined metadata. Major: updated the dump function to use the random order of cells. Added additional tests.

* Added explanation and comments

* Added pyarrow to requirements
Gautam8387 authored Dec 17, 2024
1 parent d1c8deb commit 2794c8f
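
A quick aside on the seeding change referenced above (a minimal sketch, not code from this commit): np.random.seed mutates NumPy's shared global state, which any other code can consume or reset between calls, whereas default_rng returns a self-contained Generator whose stream belongs to the merge alone.

    import numpy as np

    # Legacy API: seeds the shared global state; reproducibility silently
    # breaks if anything else draws from np.random in between.
    np.random.seed(42)
    legacy = np.random.permutation(5)

    # New API: an isolated Generator owns its own stream.
    rng = np.random.default_rng(seed=42)
    isolated = rng.permutation(5)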
Showing 3 changed files with 71 additions and 24 deletions.
1 change: 1 addition & 0 deletions requirements.txt
@@ -20,3 +20,4 @@ setuptools
 packaging
 importlib_metadata
 polars
+pyarrow
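
The pyarrow pin presumably exists because merge.py now hops from polars to pandas, and polars routes that conversion through Arrow. A minimal sketch of the call that would otherwise fail (toy data, not the scarf tables):

    import polars as pl

    df = pl.DataFrame({"ids": ["a", "b"], "n": [1, 2]})
    pdf = df.to_pandas()  # delegated to Arrow; needs pyarrow installed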
68 changes: 44 additions & 24 deletions scarf/merge.py
@@ -9,6 +9,7 @@
 from typing import Dict, List, Optional, Tuple, Union
 
 import numpy as np
+import pandas as pd
 import polars as pl
 import zarr
 from dask.array import from_array
@@ -176,28 +177,39 @@ def perform_randomization_rows(
             seed: Seed for randomization
         Returns:
         """
-        np.random.seed(seed)
+        rng = np.random.default_rng(seed=seed)
         chunkSize = np.array([x.rawData.chunksize[0] for x in self.assays])
         nCells = np.array([x.rawData.shape[0] for x in self.assays])
         permutations = {
             i: permute_into_chunks(nCells[i], chunkSize[i])
             for i in range(len(self.assays))
-        }
+        }  # Randomize the rows in chunks
+
+        # Create a dictionary of arrays. This is the same data as in `permutations`, but the arrays are indexed by chunk number.
+        # Example:
+        # permutations = {0: [array([2, 0, 1]), array([3, 4, 5]), array([8, 7, 6]), array([9])], 1: [array([2, 0, 1]), array([3, 4, 5]), array([8, 7, 6]), array([9])]}
+        # permutations_rows = {0: {0: array([2, 0, 1]), 1: array([3, 4, 5]), 2: array([8, 7, 6]), 3: array([9])}, 1: {0: array([2, 0, 1]), 1: array([3, 4, 5]), 2: array([8, 7, 6]), 3: array([9])}}
         permutations_rows = {}
         for key, arrays in permutations.items():
             in_dict = {i: x for i, x in enumerate(arrays)}
             permutations_rows[key] = in_dict
 
+        # Set the offset for each chunk. The offset is the total number of cells in all previous assays; it is needed when the cell metadata tables are merged at the end.
+        # Example:
+        # {0: {0: array([2, 0, 1]), 1: array([3, 4, 5]), 2: array([8, 7, 6]), 3: array([9])}, 1: {0: array([12, 10, 11]), 1: array([13, 14, 15]), 2: array([18, 17, 16]), 3: array([19])}}
         permutations_rows_offset = {}
-        for i in range(len(permutations)):
-            last_key = i - 1 if i > 0 else 0
-            offset = nCells[last_key] + offset if i > 0 else 0  # noqa: F821
-            for j, arr in enumerate(permutations[i]):
-                in__dict[j] = arr + offset
-            permutations_rows_offset[i] = in__dict
+        offset = 0
+        for key, val_dict in permutations_rows.items():
+            in__dict: dict[int, np.ndarray] = {}
+            for in_key, arrs in val_dict.items():
+                in__dict[in_key] = arrs + offset
+            permutations_rows_offset[key] = in__dict
+            offset += nCells[key]
 
+        # Set the random order in which the rows will be merged. The last chunk of each assay is appended at the end of the list to account for potentially incomplete chunks.
+        # Example:
+        # coordinates_permutations = [[0, 0], [0, 1], [1, 2], [0, 2], [1, 1], [1, 0], [0, 3], [1, 3]]
+        # Here [0, 0] is the first chunk of the first assay, [0, 1] the second chunk of the first assay, [1, 2] the third chunk of the second assay, and so on; this is the order in which the rows will be merged.
         coordinates = []
         extra = []
         for i in range(len(self.assays)):
@@ -206,8 +218,9 @@
                     extra.append([i, j])
                     continue
                 coordinates.append([i, j])
 
-        coordinates_permutations = np.random.permutation(coordinates)
+        coordinates_permutations = rng.permutation(
+            coordinates
+        )  # Randomize the order of the coordinates
         if len(coordinates_permutations) > 0:
             coordinates_permutations = np.concatenate(
                 [coordinates_permutations, extra], axis=0
@@ -241,21 +254,26 @@ def perform_randomization_rows(
         return permutations_rows, permutations_rows_offset, coordinates_permutations
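
To make the three returned structures concrete, here is a compressed re-enactment of the routine. permute_into_chunks is not shown in this diff, so the helper below is a toy stand-in inferred from the examples in the comments (shuffle indices within fixed-size chunks, keeping chunk boundaries); names are simplified.

    import numpy as np

    rng = np.random.default_rng(seed=42)

    def permute_into_chunks_sketch(n_cells, chunk_size):
        # Shuffle within each consecutive chunk; a short tail chunk
        # (e.g. array([9]) for 10 cells in chunks of 3) stays last.
        return [
            rng.permutation(np.arange(s, min(s + chunk_size, n_cells)))
            for s in range(0, n_cells, chunk_size)
        ]

    n_cells = [10, 10]  # two assays, chunk size 3
    permutations = {i: permute_into_chunks_sketch(n, 3) for i, n in enumerate(n_cells)}

    # Offsets shift the second assay past the first assay's cells, so its
    # first chunk becomes e.g. array([12, 10, 11]) rather than array([2, 0, 1]).
    offset, rows_offset = 0, {}
    for key, arrays in permutations.items():
        rows_offset[key] = {j: arr + offset for j, arr in enumerate(arrays)}
        offset += n_cells[key]

    # Full chunks are shuffled across assays; each assay's last chunk is
    # appended afterwards so an incomplete chunk never lands mid-matrix.
    coords = [[i, j] for i in permutations for j in range(len(permutations[i]) - 1)]
    extra = [[i, len(permutations[i]) - 1] for i in permutations]
    order = np.concatenate([rng.permutation(coords), extra], axis=0)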

     def _ref_order_cell_idx(self) -> Dict[int, Dict[int, np.ndarray]]:
         """
         Calculate the order of the cells in the merged assay.
         """
+        # We calculate the order of the cells in the merged assay from permutations_rows and coordinates_permutations. This is essentially the one-to-one mapping of cells in the assays to cells in the merged assay.
+        # Example:
+        # cellOrder = {0: {0: array([0, 1, 2]), 1: array([3, 4, 5]), 2: array([ 9, 10, 11]), 3: array([18])}, 1: {0: array([15, 16, 17]), 1: array([12, 13, 14]), 2: array([6, 7, 8]), 3: array([19])}}
+        # Here the cells [2, 0, 1] from the first chunk of the first assay map to [0, 1, 2] in the merged assay; likewise, the cells [2, 0, 1] from the first chunk of the second assay map to [15, 16, 17].
         new_cells = {}
         for i in range(len(self.assays)):
             in_dict: dict[int, np.ndarray] = {}
             for j in range(len(self.permutations_rows[i])):
                 in_dict[j] = np.array([])
             new_cells[i] = in_dict
 
         offset = 0
         for i, (x, y) in enumerate(self.coordinates_permutations):
-            arr = self.permutations_rows_offset[x][y]
+            arr = self.permutations_rows[x][y]
             arr = np.array(range(len(arr)))
             arr = arr + offset
             new_cells[x][y] = arr
             offset = arr.max() + 1
 
         return new_cells
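
The core of the fix above: destination rows are derived from the chunk length (permutations_rows), not from the offset-shifted source ids that the removed line fetched, so each chunk visited receives the next contiguous run of merged row numbers. A toy rendering of that loop, with names simplified:

    import numpy as np

    # order in which chunks are visited, as in coordinates_permutations
    visit_order = [(0, 0), (1, 2), (0, 1)]
    chunk_lens = {0: {0: 3, 1: 3}, 1: {2: 3}}  # len(permutations_rows[x][y])

    offset, cell_order = 0, {0: {}, 1: {}}
    for assay, chunk in visit_order:
        n = chunk_lens[assay][chunk]
        # this chunk's cells occupy the next n consecutive merged rows
        cell_order[assay][chunk] = np.arange(n) + offset
        offset += n
    # cell_order == {0: {0: array([0, 1, 2]), 1: array([6, 7, 8])}, 1: {2: array([3, 4, 5])}}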

     def _merge_cell_table(
Expand Down Expand Up @@ -291,22 +309,18 @@ def _merge_cell_table(
a = a.with_columns(
[pl.Series("I", np.ones(len(a["ids"])).astype(bool))]
)
ret_val.append(a)
ret_val.append(a.to_pandas())

ret_val_df = pl.concat(
ret_val,
how="diagonal", # Finds a union between the column schemas and fills missing column values with null
)

# Randomize the rows in chunks
# Here we merge the cell metadata tables for each sample. We simply concatenate the tables and reset the index.
ret_val_df = pd.concat(ret_val, axis=0).reset_index(drop=True)
# Now we use the offsets stored in permutations_rows_offset along with the coordinates_permutations to reorder the cells in the merged assay. The offsets are used to bring the cells in the same order as the rows in the merged assay.
compiled_idx = [
self.permutations_rows_offset[i][j]
for i, j in self.coordinates_permutations
]
compiled_idx = np.concatenate(compiled_idx)
ret_val_df = ret_val_df[
compiled_idx
] # Polars does not support iloc so we have to use this method
# Index the merged cell metadata table with the compiled_idx to get the final randomized merged cell metadata table.
ret_val_df = ret_val_df.iloc[compiled_idx]
if sum([x.cells.N for x in self.assays]) != ret_val_df.shape[0]:
raise AssertionError(
"Unexpected number of cells in the merged table. This is unexpected, "
@@ -604,6 +618,7 @@ def _dask_to_coo(
         for i, col_data in enumerate(computed_data.T):
             consolidated_idx = consolidation_map[order[i]]
             mat[:, consolidated_idx] += col_data
+
         return coo_matrix(mat)
 
     def dump(self, nthreads=4):
@@ -623,7 +638,12 @@ def dump(self, nthreads=4):
                 total=assay.rawData.numblocks[0],
                 desc=f"Writing data from assay {i+1}/{len(self.assays)} to merged file",
             ):
+                # Perform the intra-chunk permutation of the rows
+                perm_order = self.permutations_rows[i][j]
+                perm_order = perm_order - perm_order.min()
+                block = block[perm_order, :]
                 a = self._dask_to_coo(block, feat_order, feat_order_map, nthreads)
+                # Use the one-to-one mapping of assay chunks to merged chunks to place the data in the merged order.
                 row_idx = self.cellOrder[i][j]
                 self.assayGroup.set_coordinate_selection(
                     (a.row + row_idx.min(), a.col), a.data.astype(self.assayGroup.dtype)
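
Putting the dump changes together: each block is first reordered within itself (perm_order is made chunk-local by subtracting its minimum), converted to COO, and written at its pre-assigned destination rows. A rough sketch under stated assumptions: an in-memory zarr array stands in for assayGroup, and the toy values mirror the examples above.

    import numpy as np
    import zarr
    from scipy.sparse import coo_matrix

    merged = zarr.zeros((6, 4), chunks=(3, 4), dtype="u4")  # stand-in for assayGroup

    block = np.arange(1, 13, dtype="u4").reshape(3, 4)  # one computed dask chunk
    perm_order = np.array([14, 12, 13])              # within-chunk permutation, global ids
    block = block[perm_order - perm_order.min(), :]  # chunk-local order [2, 0, 1]

    a = coo_matrix(block)
    row_idx = np.array([3, 4, 5])                    # destination rows from cellOrder
    merged.set_coordinate_selection(
        (a.row + row_idx.min(), a.col), a.data.astype(merged.dtype)
    )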
26 changes: 26 additions & 0 deletions scarf/tests/test_merger.py
@@ -17,6 +17,7 @@ def test_assay_merge(datastore):
         names=["self1", "self2"],
         merge_assay_name="RNA",
         prepend_text="",
+        overwrite=True,
     )
     writer.dump()
     tmp = zarr.open(fn + "/RNA/counts")
@@ -37,6 +38,7 @@ def test_dataset_merge_2(datastore):
         datasets=[datastore, datastore],
         names=["self1", "self2"],
         prepend_text="",
+        overwrite=True,
     )
     writer.dump()
     # Check if the merged file has the correct shape and counts
@@ -68,6 +70,7 @@ def test_dataset_merge_3(datastore):
         datasets=[datastore, datastore, datastore],
         names=["self1", "self2", "self3"],
         prepend_text="",
+        overwrite=True,
     )
     writer.dump()
     # Check if the merged file has the correct shape and counts
@@ -84,3 +87,26 @@
         datastore.assay2.rawData.compute().sum() * 3
     )
     remove(fn)
+
+def test_dataset_merge_cells(datastore):
+    from ..merge import DatasetMerge
+    from ..datastore.datastore import DataStore
+
+    fn = full_path("merged_zarr.zarr")
+    writer = DatasetMerge(
+        zarr_path=fn,
+        datasets=[datastore, datastore],
+        names=["self1", "self2"],
+        prepend_text="orig",
+        overwrite=True,
+    )
+    writer.dump()
+
+    ds = DataStore(
+        fn,
+        default_assay="RNA",
+    )
+
+    df = ds.cells.to_pandas_dataframe(ds.cells.columns)
+    df_diff = df[df["orig_RNA_nCounts"] != df["RNA_nCounts"]]
+    assert len(df_diff) == 0