Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AnnData spatial analysis 1 #1112

Draft
wants to merge 16 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 14 additions & 9 deletions src/ark/analysis/cell_neighborhood_stats.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import anndata
from functools import reduce

import numpy as np
Expand Down Expand Up @@ -193,15 +194,13 @@ def calculate_mean_distance_to_all_cell_types(


def generate_cell_distance_analysis(
cell_table, dist_mat_dir, save_path, k, cell_type_col=settings.CELL_TYPE,
anndata_dir, save_path, k, cell_type_col=settings.CELL_TYPE,
fov_col=settings.FOV_ID, cell_label_col=settings.CELL_LABEL):
""" Creates a dataframe containing the average distance between a cell and other cells of each
phenotype, based on the specified cell_type_col.
Args:
cell_table (pd.DataFrame):
dataframe containing all cells and their cell type
dist_mat_dir (str):
path to directory containing the distance matrix files
anndata_dir (str):
path where the AnnData objects are stored.
save_path (str):
path where to save the results to
k (int):
Expand All @@ -214,17 +213,23 @@ def generate_cell_distance_analysis(
column with the cell labels
"""

io_utils.validate_paths(dist_mat_dir)
fov_list = np.unique(cell_table[fov_col])
io_utils.validate_paths(anndata_dir)
fov_list = io_utils.list_folders(anndata_dir, substrs=".zarr")

cell_dists = []
with tqdm(total=len(fov_list), desc="Calculate Average Distances", unit="FOVs") \
as distance_progress:
for fov in fov_list:
distance_progress.set_postfix(FOV=fov)

fov_cell_table = cell_table[cell_table[fov_col] == fov]
fov_dist_xr = xr.load_dataarray(os.path.join(dist_mat_dir, str(fov) + '_dist_mat.xr'))
fov_adata = anndata.read_zarr(
os.path.join(anndata_dir, fov))

# extract cell table and dist mat from AnnData table
fov_cell_table = fov_adata.obs
centroid_labels = list(fov_cell_table.label)
fov_dist_xr = fov_adata.obsp["distances"]
fov_dist_xr = xr.DataArray(fov_dist_xr, coords=[centroid_labels, centroid_labels])

# get the average distances between cell types
fov_cell_dists = calculate_mean_distance_to_all_cell_types(
Expand Down
75 changes: 41 additions & 34 deletions src/ark/analysis/neighborhood_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,24 +6,24 @@
import pandas as pd
import seaborn as sns
import xarray as xr
import anndata
from tqdm.notebook import tqdm
from alpineer import misc_utils
from alpineer import misc_utils, io_utils

import ark.settings as settings
from ark.analysis import spatial_analysis_utils
from ark.utils.data_utils import load_anndatas


def create_neighborhood_matrix(all_data, dist_mat_dir, included_fovs=None, distlim=50,
def create_neighborhood_matrix(anndata_dir, included_fovs=None, distlim=50,
self_neighbor=False, fov_col=settings.FOV_ID,
cell_label_col=settings.CELL_LABEL,
cell_type_col=settings.CELL_TYPE):
"""Calculates the number of neighbor phenotypes for each cell.

Args:
all_data (pandas.DataFrame):
data for all fovs. Includes the columns for fov, label, and cell phenotype.
dist_mat_dir (str):
directory containing the distance matrices
anndata_dir (str):
path where the AnnData objects are stored.
included_fovs (list):
fovs to include in analysis. If argument is none, default is all fovs used.
distlim (int):
Expand All @@ -45,33 +45,32 @@ def create_neighborhood_matrix(all_data, dist_mat_dir, included_fovs=None, distl

# Set up input and parameters
if included_fovs is None:
included_fovs = all_data[fov_col].unique()
included_fovs = io_utils.list_folders(anndata_dir, substrs=".zarr")
else:
included_fovs = [fov + '.zarr' for fov in included_fovs]

# Check if included fovs found in fov_col
misc_utils.verify_in_list(fov_names=included_fovs,
unique_fovs=all_data[fov_col].unique())
misc_utils.verify_in_list(
fov_names=included_fovs, unique_fovs=io_utils.list_folders(anndata_dir, substrs=".zarr"))

# Load AnnData Collection and extract unique cell phenotypes
fovs_ac = load_anndatas(anndata_dir=anndata_dir, join_obs="inner", join_obsm="inner")

# Subset just the fov, label, and cell phenotype columns
all_neighborhood_data = all_data[
[fov_col, cell_label_col, cell_type_col]
].reset_index(drop=True)
# Extract the cell phenotypes
cluster_names = all_neighborhood_data[cell_type_col].drop_duplicates()
# Get the total number of phenotypes
cluster_names = fovs_ac.obs[cell_type_col].unique()
cluster_num = len(cluster_names)

included_columns = [fov_col, cell_label_col, cell_type_col]

# Initialize empty matrices for cell neighborhood data
cell_neighbor_counts = pd.DataFrame(
np.zeros((all_neighborhood_data.shape[0], cluster_num + len(included_columns)))
np.zeros((fovs_ac.obs.shape[0], cluster_num + len(included_columns)))
)
# Replace the first, second (possibly third) columns of cell_neighbor_counts
cell_neighbor_counts[list(range(len(included_columns)))] = \
all_neighborhood_data[included_columns]
cols = included_columns + list(cluster_names)
cell_neighbor_counts.index = fovs_ac.obs.index

# Rename the columns to match cell phenotypes
# Replace the first three columns of cell_neighbor_counts and rename cols
cell_neighbor_counts[list(range(len(included_columns)))] = fovs_ac.obs[included_columns]
cols = included_columns + list(cluster_names)
cell_neighbor_counts.columns = cols

cell_neighbor_freqs = cell_neighbor_counts.copy(deep=True)
Expand All @@ -81,26 +80,34 @@ def create_neighborhood_matrix(all_data, dist_mat_dir, included_fovs=None, distl
for fov in included_fovs:
neighbor_mat_progress.set_postfix(FOV=fov)

# Subsetting expression matrix to only include patients with correct fov label
current_fov_idx = all_neighborhood_data.loc[:, fov_col] == fov
current_fov_neighborhood_data = all_neighborhood_data[current_fov_idx]
# load in fov AnnData table
fov_adata = anndata.read_zarr(
os.path.join(anndata_dir, fov))
fov_table = fov_adata.obs[included_columns]

# Get the subset of phenotypes included in the current fov
fov_cluster_names = current_fov_neighborhood_data[cell_type_col].drop_duplicates()
fov_cluster_names = fov_adata.obs[cell_type_col].unique()

# Retrieve fov-specific distance matrix from distance matrix dictionary
dist_matrix = xr.load_dataarray(os.path.join(dist_mat_dir, str(fov) + '_dist_mat.xr'))
# Retrieve fov-specific distance matrix from AnnData table
centroid_labels = list(fov_table.label)
dist_matrix = fov_adata.obsp["distances"]
dist_matrix = xr.DataArray(dist_matrix, coords=[centroid_labels, centroid_labels])

# Get cell_neighbor_counts and cell_neighbor_freqs for fovs
counts, freqs = spatial_analysis_utils.compute_neighbor_counts(
current_fov_neighborhood_data, dist_matrix, distlim, self_neighbor,
fov_table, dist_matrix, distlim, self_neighbor,
cell_label_col=cell_label_col, cluster_name_col=cell_type_col)

# Add to neighbor counts+freqs for only matching phenos between fov and whole dataset
cell_neighbor_counts.loc[current_fov_neighborhood_data.index, fov_cluster_names] \
= counts
cell_neighbor_freqs.loc[current_fov_neighborhood_data.index, fov_cluster_names]\
= freqs
cell_neighbor_counts.loc[fov_table.index, fov_cluster_names] = counts
cell_neighbor_freqs.loc[fov_table.index, fov_cluster_names] = freqs

# save neighbors matrix to fov AnnData table
fov_adata.obsm[f"neighbors_counts_{cell_type_col}_{distlim}"] = \
cell_neighbor_counts.loc[fov_table.index]
fov_adata.obsm[f"neighbors_freqs_{cell_type_col}_{distlim}"] = \
cell_neighbor_freqs.loc[fov_table.index]
fov_adata.write_zarr(os.path.join(anndata_dir, fov), chunks=(1000, 1000))

neighbor_mat_progress.update(1)

Expand All @@ -109,6 +116,7 @@ def create_neighborhood_matrix(all_data, dist_mat_dir, included_fovs=None, distl
keep_cells = cell_neighbor_counts.drop(included_columns, axis=1).sum(axis=1) != 0
cell_neighbor_counts = cell_neighbor_counts.loc[keep_cells].reset_index(drop=True)
cell_neighbor_freqs = cell_neighbor_freqs.loc[keep_cells].reset_index(drop=True)

# issue warning if more than 5% of cells are dropped
if (cell_neighbor_counts.shape[0] / total_cell_count) < 0.95:
warnings.warn(UserWarning("More than 5% of cells have no neighbor within the provided "
Expand Down Expand Up @@ -169,7 +177,6 @@ def generate_cluster_matrix_results(all_data, neighbor_mat, cluster_num, seed=42
cluster ids indexed row-wise and markers indexed column-wise,
indicates the mean marker expression for each cluster id
"""

# get fovs
if included_fovs is None:
included_fovs = neighbor_mat[fov_col].unique()
Expand Down Expand Up @@ -216,7 +223,7 @@ def generate_cluster_matrix_results(all_data, neighbor_mat, cluster_num, seed=42
for c in num_cell_type_per_cluster.index]

# subsets the expression matrix to only have channel columns
channel_start = np.where(all_data_clusters.columns == pre_channel_col)[0][0] + 1
channel_start = 0
channel_end = np.where(all_data_clusters.columns == post_channel_col)[0][0]
cluster_label_colnum = np.where(all_data_clusters.columns == cluster_label_col)[0][0]

Expand Down
66 changes: 25 additions & 41 deletions src/ark/analysis/spatial_analysis_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np
import pandas as pd
import scipy
import skimage.measure
import anndata
import sklearn.metrics
import xarray as xr
from alpineer import io_utils, load_utils, misc_utils
Expand All @@ -16,60 +16,44 @@
from ark.utils._bootstrapping import compute_close_num_rand


def calc_dist_matrix(label_dir, save_path, prefix='_whole_cell'):
def calc_dist_matrix(anndata_dir):
"""Generate matrix of distances between center of pairs of cells.

Saves each one individually to `save_path`.

Args:
label_dir (str):
path to segmentation masks indexed by `(fov, cell_id, cell_id, label)`
save_path (str):
path to save the distance matrices
prefix (str):
the prefix used to identify label map files in `label_dir`
anndata_dir (str):
Path where the AnnData objects are stored.
"""

# check that both label_dir and save_path exist
io_utils.validate_paths([label_dir, save_path])

# load all the file names in label_dir
fov_files = io_utils.list_files(label_dir, substrs=prefix + '.tiff')
# load fov names
fov_names = io_utils.list_folders(anndata_dir, substrs=".zarr")

# iterate for each fov
with tqdm(total=len(fov_files), desc="Distance Matrix Generation", unit="FOVs") \
with tqdm(total=len(fov_names), desc="Distance Matrix Generation", unit="FOVs") \
as dist_mat_progress:
for fov_file in fov_files:
dist_mat_progress.set_postfix(FOV=fov_file)

# retrieve the fov name
fov_name = fov_file.replace(prefix + '.tiff', '')
for fov in fov_names:

# load in the data
fov_data = load_utils.load_imgs_from_dir(
label_dir, [fov_file], match_substring=prefix,
trim_suffix=prefix, xr_channel_names=['label']
)

# keep just the middle two dimensions
fov_data = fov_data.loc[fov_name, :, :, 'label'].values
# check for previously generated distance matrices
if os.path.exists(os.path.join(anndata_dir, fov, "obsp", "distances")):
dist_mat_progress.set_postfix(FOV=fov, status="Already Computed")
dist_mat_progress.update(1)
continue
else:
dist_mat_progress.set_postfix(FOV=fov, status="Computing")
dist_mat_progress.update(1)

# extract region properties of label map, then just get centroids
props = skimage.measure.regionprops(fov_data)
centroids = [prop.centroid for prop in props]
centroid_labels = [prop.label for prop in props]
# extract cell spatial information
fov_adata = anndata.read_zarr(os.path.join(anndata_dir, fov))
centroids = fov_adata.obsm["spatial"]
centroid_list = list(centroids.itertuples(index=False, name=None))
centroid_labels = list(fov_adata.obs.label)

# generate the distance matrix, then assign centroid_labels as coords
dist_matrix = cdist(centroids, centroids).astype(np.float32)
dist_matrix = cdist(centroid_list, centroid_list).astype(np.float32)
dist_mat_xarr = xr.DataArray(dist_matrix, coords=[centroid_labels, centroid_labels])

# save the distance matrix to save_path
dist_mat_xarr.to_netcdf(
os.path.join(save_path, fov_name + '_dist_mat.xr'),
format='NETCDF3_64BIT'
)

dist_mat_progress.update(1)
# save distances to AnnData table
fov_adata.obsp["distances"] = dist_mat_xarr.data
fov_adata.write_zarr(os.path.join(anndata_dir, fov), chunks=(1000, 1000))


def append_distance_features_to_dataset(fov, dist_matrix, cell_table, distance_columns):
Expand Down
18 changes: 13 additions & 5 deletions src/ark/utils/data_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import anndata
import numba as nb
import itertools
import os
Expand Down Expand Up @@ -1019,13 +1020,16 @@ class AnnCollectionKwargs(TypedDict):
indices_strict: bool


def load_anndatas(anndata_dir: os.PathLike, **anncollection_kwargs: Unpack[AnnCollectionKwargs]) -> AnnCollection:
"""Lazily loads a directory of `AnnData` objects into an `AnnCollection`. The concatenation happens across the `.obs` axis.

For `AnnCollection` kwargs, see https://anndata.readthedocs.io/en/latest/generated/anndata.experimental.AnnCollection.html
def load_anndatas(anndata_dir: os.PathLike, collection=True,
**anncollection_kwargs: Unpack[AnnCollectionKwargs]) -> AnnCollection:
"""Lazily loads a directory of `AnnData` objects into an `AnnCollection`.
The concatenation happens across the `.obs` axis.
For `AnnCollection` kwargs,
see https://anndata.readthedocs.io/en/latest/generated/anndata.experimental.AnnCollection.html

Args:
anndata_dir (os.PathLike): The directory containing the `AnnData` objects.
collection (bool): Whether to return a collection or a single merged AnnData object.

Returns:
AnnCollection: The `AnnCollection` containing the `AnnData` objects.
Expand All @@ -1034,7 +1038,11 @@ def load_anndatas(anndata_dir: os.PathLike, **anncollection_kwargs: Unpack[AnnCo
anndata_dir = pathlib.Path(anndata_dir)

adata_zarr_stores = {f.stem: read_zarr(f) for f in ns.natsorted(anndata_dir.glob("*.zarr"))}
return AnnCollection(adatas=adata_zarr_stores, **anncollection_kwargs)
if collection:
return AnnCollection(adatas=adata_zarr_stores, **anncollection_kwargs)
else:
adata_zarr_list = list(adata_zarr_stores.values())
return anndata.concat(adata_zarr_list)


class AnnDataIterDataPipe(IterDataPipe):
Expand Down
30 changes: 12 additions & 18 deletions templates/Calculate_Mixing_Scores.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -109,11 +109,8 @@
},
"outputs": [],
"source": [
"cell_table_path = os.path.join(base_dir, \"segmentation/cell_table/cell_table_size_normalized_cell_labels.csv\")\n",
"segmentation_dir = os.path.join(base_dir, \"segmentation/deepcell_output\")\n",
"\n",
"anndata_dir = os.path.join(base_dir, \"anndata\")\n",
"spatial_analysis_dir = os.path.join(base_dir, \"spatial_analysis\")\n",
"dist_mat_dir = os.path.join(spatial_analysis_dir, \"dist_mats\")\n",
"neighbors_mat_dir = os.path.join(spatial_analysis_dir, \"neighborhood_mats\")\n",
"\n",
"# new directory to store mixing score results\n",
Expand All @@ -131,18 +128,13 @@
},
"outputs": [],
"source": [
"# create the dist_mat_dir directory if it doesn't exist\n",
"if not os.path.exists(dist_mat_dir):\n",
" os.makedirs(dist_mat_dir)\n",
" spatial_analysis_utils.calc_dist_matrix(segmentation_dir, dist_mat_dir)\n",
" \n",
"# create the neighbors_mat_dir directory if it doesn't exist\n",
"if not os.path.exists(neighbors_mat_dir):\n",
" os.makedirs(neighbors_mat_dir)\n",
" \n",
"# create mixing directory\n",
"if not os.path.exists(mixing_score_dir):\n",
" os.makedirs(mixing_score_dir)"
"# generate distance matrices if needed\n",
"spatial_analysis_utils.calc_dist_matrix(anndata_dir)\n",
"\n",
"# create neighbors matrix and mixing score directories\n",
"for directory in [neighbors_mat_dir, mixing_score_dir]:\n",
" if not os.path.exists(directory):\n",
" os.makedirs(directory)"
]
},
{
Expand All @@ -164,7 +156,9 @@
},
"outputs": [],
"source": [
"all_data = pd.read_csv(cell_table_path)\n",
"# Read cell table, only fovs in the cell table will be included in the analysis\n",
"anndata_table = load_anndatas(anndata_dir=anndata_dir, collection=False, join_obs=\"inner\", join_obsm=\"inner\")\n",
"all_data = anndata_table.obs\n",
"all_fovs = all_data[settings.FOV_ID].unique()"
]
},
Expand Down Expand Up @@ -443,7 +437,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.4"
"version": "3.11.7"
},
"vscode": {
"interpreter": {
Expand Down
Loading
Loading