diff --git a/scarf/mapping_utils.py b/scarf/mapping_utils.py index 7bfd81e..49da9cd 100644 --- a/scarf/mapping_utils.py +++ b/scarf/mapping_utils.py @@ -96,16 +96,16 @@ def coral(source_data, target_data, assay, feat_key: str, cell_key: str, nthread def _order_features( - s_assay, - t_assay, - s_feat_ids: np.ndarray, + source_ids: pd.Series, + target_ids: pd.Series, + source_hvgs: np.ndarray, + target_assay: Assay, filter_null: bool, exclude_missing: bool, nthreads: int, + target_cell_key: str = "I", ) -> Tuple[np.ndarray, np.ndarray]: - s_ids = pd.Series(s_assay.feats.fetch_all("ids")) - t_ids = pd.Series(t_assay.feats.fetch_all("ids")) - t_idx = t_ids.isin(s_feat_ids) + t_idx = target_ids.isin(source_hvgs) if t_idx.sum() == 0: raise ValueError( "ERROR: None of the features from reference were found in the target data" @@ -118,8 +118,8 @@ def _order_features( else: t_idx[t_idx] = ( controlled_compute( - t_assay.rawData[:, list(t_idx[t_idx].index)][ - t_assay.cells.active_index("I"), : + target_assay.rawData[:, list(t_idx[t_idx].index)][ + target_assay.cells.active_index(target_cell_key), : ].sum(axis=0), nthreads, ) @@ -127,13 +127,13 @@ def _order_features( ) t_idx = t_idx[t_idx].index if exclude_missing: - s_idx = s_ids.isin(t_ids.values[t_idx]) + s_idx = source_ids.isin(target_ids.values[t_idx]) else: - s_idx = s_ids.isin(s_feat_ids) + s_idx = source_ids.isin(source_hvgs) s_idx = s_idx[s_idx].index - t_idx_map = {v: k for k, v in t_ids.to_dict().items()} + t_idx_map = {v: k for k, v in target_ids.to_dict().items()} t_re_idx = np.array( - [t_idx_map[x] if x in t_idx_map else -1 for x in s_ids.values[s_idx]] + [t_idx_map[x] if x in t_idx_map else -1 for x in source_ids.values[s_idx]] ) if len(s_idx) != len(t_re_idx): raise AssertionError( @@ -141,7 +141,10 @@ def _order_features( f"This is an unexpected scenario. Source has {len(s_idx)} features while target has " f"{len(t_re_idx)} features" ) - return s_idx.values, t_re_idx + missing = (t_idx == -1).sum() + total = len(t_idx) + overlap = (total - missing) / total + return s_idx.values, t_re_idx, overlap def align_features( @@ -171,18 +174,66 @@ def align_features( """ from .writers import create_zarr_dataset - source_feat_ids = source_assay.feats.fetch( - "ids", key=source_cell_key + "__" + source_feat_key - ) - s_idx, t_idx = _order_features( - source_assay, - target_assay, - source_feat_ids, - filter_null, - exclude_missing, - nthreads, - ) - logger.info(f"{(t_idx == -1).sum()} features missing in target data") + id_overlap = 0 + name_overlap = 0 + s_idx_ids = None + t_idx_ids = None + s_idx_names = None + t_idx_names = None + + try: + source_hvg_ids = source_assay.feats.fetch( + "ids", key=source_cell_key + "__" + source_feat_key + ) + s_ids = pd.Series(source_assay.feats.fetch_all("ids")) + t_ids = pd.Series(target_assay.feats.fetch_all("ids")) + s_idx_ids, t_idx_ids, id_overlap = _order_features( + s_ids, + t_ids, + source_hvg_ids, + target_assay, + filter_null, + exclude_missing, + nthreads, + target_cell_key, + ) + + except ValueError: + logger.warning("Failed to align features by IDs") + if id_overlap < 0.25: + logger.warning("Attempting feature alignment by names") + try: + source_hvg_names = source_assay.feats.fetch( + "names", key=source_cell_key + "__" + source_feat_key + ) + s_names = pd.Series(source_assay.feats.fetch_all("names")) + t_names = pd.Series(target_assay.feats.fetch_all("names")) + s_idx_names, t_idx_names, name_overlap = _order_features( + s_names, + t_names, + source_hvg_names, + target_assay, + filter_null, + exclude_missing, + nthreads, + target_cell_key, + ) + except ValueError: + logger.warning("Failed to align features by names") + + if name_overlap < 0.25 and id_overlap < 0.25: + raise ValueError( + "More than 75% of the features in the target data are missing in the source data. " + + "Please check the feature keys and try again. " + ) + if id_overlap > 0.25: + s_idx = s_idx_ids + t_idx = t_idx_ids + else: + logger.warning("Falling back to feature alignment by names") + s_idx = s_idx_names + t_idx = t_idx_names + normed_loc = f"normed__{source_cell_key}__{source_feat_key}" norm_params = source_assay.z[normed_loc].attrs["subset_params"] sorted_t_idx = np.array(sorted(t_idx[t_idx != -1]))