From f2e2c5f8cb5ed2146e680f82263ee345851214ac Mon Sep 17 00:00:00 2001 From: Kai Zhang Date: Wed, 13 Mar 2024 10:36:00 +0800 Subject: [PATCH] fix csc, csr matrices conversion issues --- anndata-hdf5/Cargo.toml | 10 +- anndata/Cargo.toml | 2 +- anndata/src/data.rs | 4 +- anndata/src/data/array.rs | 18 +-- anndata/src/data/array/ndarray.rs | 4 +- anndata/src/data/array/sparse/csc.rs | 91 +++++------ anndata/src/data/array/sparse/csr.rs | 91 +++++------ anndata/src/data/array/sparse/noncanonical.rs | 2 +- pyanndata/Cargo.toml | 4 +- pyanndata/src/anndata/backed.rs | 144 +++++++++++++++--- pyanndata/src/data/array.rs | 16 +- pyanndata/src/data/instance.rs | 4 +- python/tests/test_subset.py | 51 ++++--- 13 files changed, 289 insertions(+), 152 deletions(-) diff --git a/anndata-hdf5/Cargo.toml b/anndata-hdf5/Cargo.toml index cfa40f3..3068f15 100644 --- a/anndata-hdf5/Cargo.toml +++ b/anndata-hdf5/Cargo.toml @@ -11,11 +11,15 @@ repository = "https://github.com/kaizhang/anndata-rs" homepage = "https://github.com/kaizhang/anndata-rs" [dependencies] -#anndata = "0.3" -anndata = { path = '../anndata' } +anndata = "0.3" anyhow = "1.0" hdf5 = { version = "0.8" } hdf5-sys = { version = "0.8", features = ["static", "zlib", "threadsafe"] } #libz-sys = { version = "1", features = ["zlib-ng"], default-features = false } libz-sys = { version = "1", features = ["libc"], default-features = false } -ndarray = { version = "0.15" } \ No newline at end of file +ndarray = { version = "0.15" } + +[dev-dependencies] +tempfile = "3.2" +rand = "0.8.5" +ndarray-rand = "0.14" \ No newline at end of file diff --git a/anndata/Cargo.toml b/anndata/Cargo.toml index 3e4ef49..aa88fe3 100644 --- a/anndata/Cargo.toml +++ b/anndata/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "anndata" -version = "0.3.1" +version = "0.3.2" edition = "2021" rust-version = "1.70" authors = ["Kai Zhang "] diff --git a/anndata/src/data.rs b/anndata/src/data.rs index fec9c22..621c48a 100644 --- a/anndata/src/data.rs +++ b/anndata/src/data.rs @@ -97,7 +97,7 @@ macro_rules! impl_try_from_for_scalar { fn try_from(data: Data) -> Result { match data { Data::Scalar(DynScalar::$from(data)) => Ok(data), - _ => bail!("Cannot convert data to $to"), + _ => bail!("Cannot convert data to {}", stringify!($to)), } } } @@ -107,7 +107,7 @@ macro_rules! impl_try_from_for_scalar { fn try_from(v: Data) -> Result { match v { Data::ArrayData(data) => data.try_into(), - _ => bail!("Cannot convert data to $to Array"), + _ => bail!("Cannot convert data to {} Array", stringify!($to)), } } } diff --git a/anndata/src/data/array.rs b/anndata/src/data/array.rs index 67cbd9b..d7ef9cd 100644 --- a/anndata/src/data/array.rs +++ b/anndata/src/data/array.rs @@ -70,7 +70,7 @@ impl TryFrom for DynArray { fn try_from(value: ArrayData) -> Result { match value { ArrayData::Array(data) => Ok(data), - _ => bail!("Cannot convert {:?} to DynArray", value), + _ => bail!("Cannot convert {:?} to DynArray", value.data_type()), } } } @@ -80,7 +80,7 @@ impl TryFrom for DynCsrMatrix { fn try_from(value: ArrayData) -> Result { match value { ArrayData::CsrMatrix(data) => Ok(data), - _ => bail!("Cannot convert {:?} to DynCsrMatrix", value), + _ => bail!("Cannot convert {:?} to DynCsrMatrix", value.data_type()), } } } @@ -91,7 +91,7 @@ impl TryFrom for DynCsrNonCanonical { match value { ArrayData::CsrNonCanonical(data) => Ok(data), ArrayData::CsrMatrix(data) => Ok(data.into()), - _ => bail!("Cannot convert {:?} to DynCsrNonCanonical", value), + _ => bail!("Cannot convert {:?} to DynCsrNonCanonical", value.data_type()), } } } @@ -101,7 +101,7 @@ impl TryFrom for DynCscMatrix { fn try_from(value: ArrayData) -> Result { match value { ArrayData::CscMatrix(data) => Ok(data), - _ => bail!("Cannot convert {:?} to DynCscMatrix", value), + _ => bail!("Cannot convert {:?} to DynCscMatrix", value.data_type()), } } } @@ -111,7 +111,7 @@ impl TryFrom for DataFrame { fn try_from(value: ArrayData) -> Result { match value { ArrayData::DataFrame(data) => Ok(data), - _ => bail!("Cannot convert {:?} to DataFrame", value), + _ => bail!("Cannot convert {:?} to DataFrame", value.data_type()), } } } @@ -145,7 +145,7 @@ macro_rules! impl_into_array_data { fn try_from(value: ArrayData) -> Result { match value { ArrayData::Array(data) => data.try_into(), - _ => bail!("Cannot convert {:?} to $ty Array", value), + _ => bail!("Cannot convert {:?} to {} Array", value.data_type(), stringify!($ty)), } } } @@ -154,7 +154,7 @@ macro_rules! impl_into_array_data { fn try_from(value: ArrayData) -> Result { match value { ArrayData::CsrMatrix(data) => data.try_into(), - _ => bail!("Cannot convert {:?} to $ty CsrMatrix", value), + _ => bail!("Cannot convert {:?} to {} CsrMatrix", value.data_type(), stringify!($ty)), } } } @@ -164,7 +164,7 @@ macro_rules! impl_into_array_data { match value { ArrayData::CsrNonCanonical(data) => data.try_into(), ArrayData::CsrMatrix(data) => DynCsrNonCanonical::from(data).try_into(), - _ => bail!("Cannot convert {:?} to $ty CsrNonCanonical", value), + _ => bail!("Cannot convert {:?} to {} CsrNonCanonical", value.data_type(), stringify!($ty)), } } } @@ -173,7 +173,7 @@ macro_rules! impl_into_array_data { fn try_from(value: ArrayData) -> Result { match value { ArrayData::CscMatrix(data) => data.try_into(), - _ => bail!("Cannot convert {:?} to $ty CsrMatrix", value), + _ => bail!("Cannot convert {:?} to {} CsrMatrix", value.data_type(), stringify!($ty)), } } } diff --git a/anndata/src/data/array/ndarray.rs b/anndata/src/data/array/ndarray.rs index 1fd0dff..4fd4218 100644 --- a/anndata/src/data/array/ndarray.rs +++ b/anndata/src/data/array/ndarray.rs @@ -43,7 +43,7 @@ impl TryFrom for CategoricalArray { fn try_from(v: DynArray) -> Result { match v { DynArray::Categorical(cat) => Ok(cat), - _ => bail!("Cannot convert {:?} to CategoricalArray", v), + _ => bail!("Cannot convert {:?} to CategoricalArray", v.data_type()), } } } @@ -67,7 +67,7 @@ macro_rules! impl_dyn_array_convert { } Ok(arr.into_dimensionality::()?) }, - _ => bail!("Cannot convert {:?} to ArrayD<$from_type>", v), + _ => bail!("Cannot convert {:?} to {} ArrayD", v.data_type(), stringify!($from_type)), } } } diff --git a/anndata/src/data/array/sparse/csc.rs b/anndata/src/data/array/sparse/csc.rs index 5408f2e..312574c 100644 --- a/anndata/src/data/array/sparse/csc.rs +++ b/anndata/src/data/array/sparse/csc.rs @@ -44,7 +44,7 @@ macro_rules! impl_into_dyn_csc { fn try_from(data: DynCscMatrix) -> Result { match data { DynCscMatrix::$to_type(data) => Ok(data), - _ => bail!("Cannot convert to CscMatrix<$from_type>"), + _ => bail!("Cannot convert {:?} to {} CscMatrix", data.data_type(), stringify!($from_type)), } } } @@ -567,18 +567,18 @@ impl WriteData for CscMatrix { impl ReadData for CscMatrix { fn read(container: &DataContainer) -> Result { - match container.encoding_type()? { - DataType::CscMatrix(_) => { - let group = container.as_group()?; - let shape: Vec = group.read_array_attr("shape")?.to_vec(); - let data = group.open_dataset("data")?.read_array::<_, Ix1>()?.into_raw_vec(); - let indptr: Vec = group.open_dataset("indptr")?.read_array::<_, Ix1>()?.into_raw_vec(); - let indices: Vec = group.open_dataset("indices")?.read_array::<_, Ix1>()?.into_raw_vec(); - CscMatrix::try_from_csc_data( - shape[0], shape[1], indptr, indices, data - ).map_err(|e| anyhow::anyhow!("{}", e)) - }, - ty => bail!("cannot read csc matrix from container with data type {:?}", ty), + let data_type = container.encoding_type()?; + if let DataType::CscMatrix(_) = data_type { + let group = container.as_group()?; + let shape: Vec = group.read_array_attr("shape")?.to_vec(); + let data = group.open_dataset("data")?.read_array::<_, Ix1>()?.into_raw_vec(); + let indptr: Vec = group.open_dataset("indptr")?.read_array::<_, Ix1>()?.into_raw_vec(); + let indices: Vec = group.open_dataset("indices")?.read_array::<_, Ix1>()?.into_raw_vec(); + CscMatrix::try_from_csc_data( + shape[0], shape[1], indptr, indices, data + ).map_err(|e| anyhow::anyhow!("{}", e)) + } else { + bail!("cannot read csc matrix from container with data type {:?}", data_type) } } } @@ -598,41 +598,46 @@ impl ReadArrayData for CscMatrix { B: Backend, S: AsRef, { - if info.as_ref().len() != 2 { - panic!("index must have length 2"); - } + let data_type = container.encoding_type()?; + if let DataType::CscMatrix(_) = data_type { + if info.as_ref().len() != 2 { + panic!("index must have length 2"); + } - if info.iter().all(|s| s.as_ref().is_full()) { - return Self::read(container); - } + if info.iter().all(|s| s.as_ref().is_full()) { + return Self::read(container); + } - let data = if let SelectInfoElem::Slice(s) = info[1].as_ref() { - let group = container.as_group()?; - let indptr_slice = if let Some(end) = s.end { - SelectInfoElem::from(s.start .. end + 1) + let data = if let SelectInfoElem::Slice(s) = info[1].as_ref() { + let group = container.as_group()?; + let indptr_slice = if let Some(end) = s.end { + SelectInfoElem::from(s.start .. end + 1) + } else { + SelectInfoElem::from(s.start ..) + }; + let mut indptr: Vec = group + .open_dataset("indptr")? + .read_array_slice(&[indptr_slice])? + .to_vec(); + let lo = indptr[0]; + let slice = SelectInfoElem::from(lo .. indptr[indptr.len() - 1]); + let data: Vec = group.open_dataset("data")?.read_array_slice(&[&slice])?.to_vec(); + let indices: Vec = group.open_dataset("indices")?.read_array_slice(&[&slice])?.to_vec(); + indptr.iter_mut().for_each(|x| *x -= lo); + CscMatrix::try_from_csc_data( + Self::get_shape(container)?[0], + indptr.len() - 1, + indptr, + indices, + data, + ).unwrap().select_axis(0, info[0].as_ref()) // selct axis 1, then select axis 0 elements } else { - SelectInfoElem::from(s.start ..) + Self::read(container)?.select(info) }; - let mut indptr: Vec = group - .open_dataset("indptr")? - .read_array_slice(&[indptr_slice])? - .to_vec(); - let lo = indptr[0]; - let slice = SelectInfoElem::from(lo .. indptr[indptr.len() - 1]); - let data: Vec = group.open_dataset("data")?.read_array_slice(&[&slice])?.to_vec(); - let indices: Vec = group.open_dataset("indices")?.read_array_slice(&[&slice])?.to_vec(); - indptr.iter_mut().for_each(|x| *x -= lo); - CscMatrix::try_from_csc_data( - Self::get_shape(container)?[0], - indptr.len() - 1, - indptr, - indices, - data, - ).unwrap().select_axis(0, info[0].as_ref()) // selct axis 1, then select axis 0 elements + Ok(data) } else { - Self::read(container)?.select(info) - }; - Ok(data) + bail!("cannot read csc matrix from container with data type {:?}", data_type) + } } } diff --git a/anndata/src/data/array/sparse/csr.rs b/anndata/src/data/array/sparse/csr.rs index 57fab51..b6bba91 100644 --- a/anndata/src/data/array/sparse/csr.rs +++ b/anndata/src/data/array/sparse/csr.rs @@ -44,7 +44,7 @@ macro_rules! impl_into_dyn_csr { fn try_from(data: DynCsrMatrix) -> Result { match data { DynCsrMatrix::$to_type(data) => Ok(data), - _ => bail!("Cannot convert to CsrMatrix<$from_type>"), + _ => bail!("Cannot convert {:?} to {} CsrMatrix", data.data_type(), stringify!($from_type)), } } } @@ -571,18 +571,18 @@ impl WriteData for CsrMatrix { impl ReadData for CsrMatrix { fn read(container: &DataContainer) -> Result { - match container.encoding_type()? { - DataType::CsrMatrix(_) => { - let group = container.as_group()?; - let shape: Vec = group.read_array_attr("shape")?.to_vec(); - let data = group.open_dataset("data")?.read_array::<_, Ix1>()?.into_raw_vec(); - let indptr: Vec = group.open_dataset("indptr")?.read_array::<_, Ix1>()?.into_raw_vec(); - let indices: Vec = group.open_dataset("indices")?.read_array::<_, Ix1>()?.into_raw_vec(); - CsrMatrix::try_from_csr_data( - shape[0], shape[1], indptr, indices, data - ).map_err(|e| anyhow!("cannot read csr matrix: {}", e)) - }, - ty => bail!("cannot read csr matrix from container with data type {:?}", ty), + let data_type = container.encoding_type()?; + if let DataType::CsrMatrix(_) = data_type { + let group = container.as_group()?; + let shape: Vec = group.read_array_attr("shape")?.to_vec(); + let data = group.open_dataset("data")?.read_array::<_, Ix1>()?.into_raw_vec(); + let indptr: Vec = group.open_dataset("indptr")?.read_array::<_, Ix1>()?.into_raw_vec(); + let indices: Vec = group.open_dataset("indices")?.read_array::<_, Ix1>()?.into_raw_vec(); + CsrMatrix::try_from_csr_data( + shape[0], shape[1], indptr, indices, data + ).map_err(|e| anyhow!("cannot read csr matrix: {}", e)) + } else { + bail!("cannot read csr matrix from container with data type {:?}", data_type) } } } @@ -602,41 +602,46 @@ impl ReadArrayData for CsrMatrix { B: Backend, S: AsRef, { - if info.as_ref().len() != 2 { - panic!("index must have length 2"); - } + let data_type = container.encoding_type()?; + if let DataType::CsrMatrix(_) = data_type { + if info.as_ref().len() != 2 { + panic!("index must have length 2"); + } - if info.iter().all(|s| s.as_ref().is_full()) { - return Self::read(container); - } + if info.iter().all(|s| s.as_ref().is_full()) { + return Self::read(container); + } - let data = if let SelectInfoElem::Slice(s) = info[0].as_ref() { - let group = container.as_group()?; - let indptr_slice = if let Some(end) = s.end { - SelectInfoElem::from(s.start .. end + 1) + let data = if let SelectInfoElem::Slice(s) = info[0].as_ref() { + let group = container.as_group()?; + let indptr_slice = if let Some(end) = s.end { + SelectInfoElem::from(s.start .. end + 1) + } else { + SelectInfoElem::from(s.start ..) + }; + let mut indptr: Vec = group + .open_dataset("indptr")? + .read_array_slice(&[indptr_slice])? + .to_vec(); + let lo = indptr[0]; + let slice = SelectInfoElem::from(lo .. indptr[indptr.len() - 1]); + let data: Vec = group.open_dataset("data")?.read_array_slice(&[&slice])?.to_vec(); + let indices: Vec = group.open_dataset("indices")?.read_array_slice(&[&slice])?.to_vec(); + indptr.iter_mut().for_each(|x| *x -= lo); + CsrMatrix::try_from_csr_data( + indptr.len() - 1, + Self::get_shape(container)?[1], + indptr, + indices, + data, + ).unwrap().select_axis(1, info[1].as_ref()) } else { - SelectInfoElem::from(s.start ..) + Self::read(container)?.select(info) }; - let mut indptr: Vec = group - .open_dataset("indptr")? - .read_array_slice(&[indptr_slice])? - .to_vec(); - let lo = indptr[0]; - let slice = SelectInfoElem::from(lo .. indptr[indptr.len() - 1]); - let data: Vec = group.open_dataset("data")?.read_array_slice(&[&slice])?.to_vec(); - let indices: Vec = group.open_dataset("indices")?.read_array_slice(&[&slice])?.to_vec(); - indptr.iter_mut().for_each(|x| *x -= lo); - CsrMatrix::try_from_csr_data( - indptr.len() - 1, - Self::get_shape(container)?[1], - indptr, - indices, - data, - ).unwrap().select_axis(1, info[1].as_ref()) + Ok(data) } else { - Self::read(container)?.select(info) - }; - Ok(data) + bail!("cannot read csr matrix from container with data type {:?}", data_type) + } } } diff --git a/anndata/src/data/array/sparse/noncanonical.rs b/anndata/src/data/array/sparse/noncanonical.rs index 25bef75..bcc34c1 100644 --- a/anndata/src/data/array/sparse/noncanonical.rs +++ b/anndata/src/data/array/sparse/noncanonical.rs @@ -64,7 +64,7 @@ macro_rules! impl_into_dyn_csr { fn try_from(data: DynCsrNonCanonical) -> Result { match data { DynCsrNonCanonical::$to_type(data) => Ok(data), - _ => bail!("Cannot convert to CsrNonCanonical<$from_type>"), + _ => bail!("Cannot convert {:?} to {} CsrNonCanonical", data.data_type(), stringify!($from_type)), } } } diff --git a/pyanndata/Cargo.toml b/pyanndata/Cargo.toml index 2e327af..82efb34 100644 --- a/pyanndata/Cargo.toml +++ b/pyanndata/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "pyanndata" -version = "0.3.1" +version = "0.3.2" edition = "2021" rust-version = "1.70" authors = ["Kai Zhang "] @@ -12,7 +12,7 @@ homepage = "https://github.com/kaizhang/anndata-rs" keywords = ["data"] [dependencies] -anndata = "0.3" +anndata = "0.3.2" anndata-hdf5 = "0.2" anyhow = "1.0" downcast-rs = "1.2" diff --git a/pyanndata/src/anndata/backed.rs b/pyanndata/src/anndata/backed.rs index 74a041e..206a9b2 100644 --- a/pyanndata/src/anndata/backed.rs +++ b/pyanndata/src/anndata/backed.rs @@ -2,9 +2,9 @@ use crate::container::{PyArrayElem, PyAxisArrays, PyDataFrameElem, PyElemCollect use crate::data::{isinstance_of_pandas, to_select_elem, PyArrayData, PyData}; use crate::anndata::PyAnnData; -use anndata; +use anndata::{self, ArrayElemOp, ArrayOp, AxisArraysOp, Data, ElemCollectionOp}; use anndata::container::Slot; -use anndata::data::{DataFrameIndex, SelectInfoElem}; +use anndata::data::{DataFrameIndex, SelectInfoElem, BoundedSelectInfoElem}; use anndata::{AnnDataOp, ArrayData, Backend}; use anndata_hdf5::H5; use anyhow::{bail, Result}; @@ -352,31 +352,37 @@ impl AnnData { /// var_indices /// var indices /// out: Path | None - /// File name of the output `.h5ad` file. If provided, the result will be - /// saved to a new file and the original AnnData object remains unchanged. + /// File name of the output `.h5ad` file. When `inplace=False`, + /// the result is written to this file. If `None`, an in-memory AnnData + /// is returned. This parameter is ignored when `inplace=True`. + /// inplace: bool + /// Whether to modify the AnnData object in place or return a new AnnData object. /// backend: str | None + /// The backend to use. Currently "hdf5" is the only supported backend. /// /// Returns /// ------- /// Optional[AnnData] #[pyo3( - signature = (obs_indices=None, var_indices=None, out=None, backend=None), - text_signature = "($self, obs_indices=None, var_indices=None, out=None, backend=None)", + signature = (obs_indices=None, var_indices=None, *, out=None, inplace=true, backend=None), + text_signature = "($self, obs_indices=None, var_indices=None, *, out=None, inplace=True, backend=None)", )] pub fn subset( &self, + py: Python<'_>, obs_indices: Option<&PyAny>, var_indices: Option<&PyAny>, out: Option, + inplace: bool, backend: Option<&str>, - ) -> Result> { + ) -> Result> { let i = obs_indices .map(|x| self.select_obs(x).unwrap()) .unwrap_or(SelectInfoElem::full()); let j = var_indices .map(|x| self.select_var(x).unwrap()) .unwrap_or(SelectInfoElem::full()); - self.0.subset(&[i, j], out, backend) + self.0.subset(py, &[i, j], out, inplace, backend) } /// Return an iterator over the rows of the data matrix X. @@ -526,10 +532,12 @@ trait AnnDataTrait: Send + Downcast { fn subset( &self, + py: Python<'_>, slice: &[SelectInfoElem], - out: Option, + file: Option, + inplace: bool, backend: Option<&str>, - ) -> Result>; + ) -> Result>; fn chunked_x(&self, chunk_size: usize) -> PyChunkedArray; @@ -785,21 +793,121 @@ impl AnnDataTrait for InnerAnnData { fn subset( &self, + py: Python<'_>, slice: &[SelectInfoElem], - out: Option, + file: Option, + inplace: bool, backend: Option<&str>, - ) -> Result> { - if let Some(out) = out { + ) -> Result> { + let inner = self.adata.inner(); + if inplace { + inner.subset(slice)?; + Ok(None) + } else if let Some(file) = file { match backend.unwrap_or(H5::NAME) { H5::NAME => { - self.adata.inner().write_select::(slice, &out)?; - Ok(Some(AnnData::new_from(out, "r+", backend)?)) + inner.write_select::(slice, &file)?; + Ok(Some(AnnData::new_from(file, "r+", backend)?.into_py(py))) } x => bail!("Unsupported backend: {}", x), } } else { - self.adata.inner().subset(slice)?; - Ok(None) + let adata = PyAnnData::new(py)?; + let obs_slice = BoundedSelectInfoElem::new(&slice[0], inner.n_obs()); + let var_slice = BoundedSelectInfoElem::new(&slice[1], inner.n_vars()); + let n_obs = obs_slice.len(); + let n_vars = var_slice.len(); + adata.set_n_obs(n_obs)?; + adata.set_n_vars(n_vars)?; + + { + if let Some(x) = inner.x().slice::(slice)? { + adata.set_x(x)?; + } + } + { + // Set obs and var + let obs_idx = inner.obs_names(); + if !obs_idx.is_empty() { + adata.set_obs_names(obs_idx.select(&slice[0]))?; + adata.set_obs(inner.read_obs()?.select_axis(0, &slice[0]))?; + } + let var_idx = inner.var_names(); + if !var_idx.is_empty() { + adata.set_var_names(var_idx.select(&slice[1]))?; + adata.set_var(inner.read_var()?.select_axis(0, &slice[1]))?; + } + } + { + // Set uns + inner.uns() + .keys() + .into_iter() + .try_for_each(|k| adata.uns().add(&k, inner.uns().get_item::(&k)?.unwrap()))?; + } + { + // Set obsm + inner.obsm().keys().into_iter().try_for_each(|k| { + adata.obsm().add( + &k, + inner.obsm() + .get(&k) + .unwrap() + .slice_axis::(0, &slice[0])? + .unwrap(), + ) + })?; + } + { + // Set obsp + inner.obsp().keys().into_iter().try_for_each(|k| { + let elem = inner.obsp().get(&k).unwrap(); + let n = elem.shape().unwrap().ndim(); + let mut select = vec![SelectInfoElem::full(); n]; + select[0] = slice[0].clone(); + select[1] = slice[0].clone(); + let data = elem.slice::(select)?.unwrap(); + adata.obsp().add(&k, data) + })?; + } + { + // Set varm + inner.varm().keys().into_iter().try_for_each(|k| { + adata.varm().add( + &k, + inner.varm() + .get(&k) + .unwrap() + .slice_axis::(0, &slice[1])? + .unwrap(), + ) + })?; + } + { + // Set varp + inner.varp().keys().into_iter().try_for_each(|k| { + let elem = inner.varp().get(&k).unwrap(); + let n = elem.shape().unwrap().ndim(); + let mut select = vec![SelectInfoElem::full(); n]; + select[0] = slice[1].clone(); + select[1] = slice[1].clone(); + let data = elem.slice::(select)?.unwrap(); + adata.varp().add(&k, data) + })?; + } + { + // Set layers + inner.layers().keys().into_iter().try_for_each(|k| { + let elem = inner.layers().get(&k).unwrap(); + let n = elem.shape().unwrap().ndim(); + let mut select = vec![SelectInfoElem::full(); n]; + select[0] = slice[0].clone(); + select[1] = slice[1].clone(); + let data = elem.slice::(select)?.unwrap(); + adata.layers().add(&k, data) + })?; + } + Ok(Some(adata.to_object(py))) } } @@ -934,4 +1042,4 @@ impl StackedAnnDataTrait for Slot> { format!("{}", self.inner().deref()) } } -} \ No newline at end of file +} diff --git a/pyanndata/src/data/array.rs b/pyanndata/src/data/array.rs index a3f5227..23a3a0c 100644 --- a/pyanndata/src/data/array.rs +++ b/pyanndata/src/data/array.rs @@ -1,8 +1,8 @@ -use crate::data::{FromPython, IntoPython}; +use crate::data::{isinstance_of_csc, isinstance_of_csr, FromPython, IntoPython}; use ndarray::ArrayD; use nalgebra_sparse::{CsrMatrix, CscMatrix}; -use pyo3::prelude::*; +use pyo3::{exceptions::PyTypeError, prelude::*}; use anndata::data::{DynArray, DynCsrMatrix, DynCscMatrix, DynCsrNonCanonical, CsrNonCanonical}; use numpy::{PyReadonlyArrayDyn, IntoPyArray}; @@ -106,6 +106,10 @@ impl FromPython<'_> for DynCsrMatrix { Ok(res) } + if !isinstance_of_csr(ob.py(), ob)? { + return Err(PyTypeError::new_err("not a csr matrix")) + } + let shape: Vec = ob.getattr("shape")?.extract()?; let indices = extract_csr_indicies(ob.getattr("indices")?)?; let indptr = extract_csr_indicies(ob.getattr("indptr")?)?; @@ -151,6 +155,10 @@ impl FromPython<'_> for DynCsrNonCanonical { Ok(res) } + if !isinstance_of_csr(ob.py(), ob)? { + return Err(PyTypeError::new_err("not a csr matrix")) + } + let shape: Vec = ob.getattr("shape")?.extract()?; let indices = extract_csr_indicies(ob.getattr("indices")?)?; let indptr = extract_csr_indicies(ob.getattr("indptr")?)?; @@ -196,6 +204,10 @@ impl FromPython<'_> for DynCscMatrix { Ok(res) } + if !isinstance_of_csc(ob.py(), ob)? { + return Err(PyTypeError::new_err("not a csc matrix")) + } + let shape: Vec = ob.getattr("shape")?.extract()?; let indices = extract_csc_indicies(ob.getattr("indices")?)?; let indptr = extract_csc_indicies(ob.getattr("indptr")?)?; diff --git a/pyanndata/src/data/instance.rs b/pyanndata/src/data/instance.rs index 9849587..6b00b30 100644 --- a/pyanndata/src/data/instance.rs +++ b/pyanndata/src/data/instance.rs @@ -2,7 +2,7 @@ use pyo3::{prelude::*, types::PyType, PyResult, Python}; pub fn isinstance_of_csr<'py>(py: Python<'py>, obj: &'py PyAny) -> PyResult { obj.is_instance( - py.import("scipy.sparse.csr")? + py.import("scipy.sparse")? .getattr("csr_matrix")? .downcast::() .unwrap(), @@ -11,7 +11,7 @@ pub fn isinstance_of_csr<'py>(py: Python<'py>, obj: &'py PyAny) -> PyResult(py: Python<'py>, obj: &'py PyAny) -> PyResult { obj.is_instance( - py.import("scipy.sparse.csc")? + py.import("scipy.sparse")? .getattr("csc_matrix")? .downcast::() .unwrap(), diff --git a/python/tests/test_subset.py b/python/tests/test_subset.py index 6e54818..ed1e141 100644 --- a/python/tests/test_subset.py +++ b/python/tests/test_subset.py @@ -40,29 +40,32 @@ def test_subset(x, obs, obsm, obsp, varm, varp, indices, indices2, tmp_path): adata.varp = dict(x=varp, y=csr_matrix(varp)) adata.layers["raw"] = x - adata_subset = adata.subset(indices, indices2, out = h5ad(tmp_path)) - np.testing.assert_array_equal(adata_subset.X[:], x[np.ix_(indices, indices2)]) - np.testing.assert_array_equal(adata_subset.obs["txt"], np.array(list(obs[i] for i in indices))) - np.testing.assert_array_equal(adata_subset.obsm["x"], obsm[indices, :]) - np.testing.assert_array_equal(adata_subset.obsm["y"].todense(), obsm[indices, :]) - np.testing.assert_array_equal(adata_subset.obsp["x"], obsp[np.ix_(indices, indices)]) - np.testing.assert_array_equal(adata_subset.obsp["y"].todense(), obsp[np.ix_(indices, indices)]) - np.testing.assert_array_equal(adata_subset.varm["x"], varm[indices2, :]) - np.testing.assert_array_equal(adata_subset.varm["y"].todense(), varm[indices2, :]) - np.testing.assert_array_equal(adata_subset.varp["x"], varp[np.ix_(indices2, indices2)]) - np.testing.assert_array_equal(adata_subset.varp["y"].todense(), varp[np.ix_(indices2, indices2)]) - np.testing.assert_array_equal(adata_subset.layers["raw"], x[np.ix_(indices, indices2)]) - - adata_subset = adata.subset([str(x) for x in indices], out = h5ad(tmp_path)) - np.testing.assert_array_equal(adata_subset.X[:], x[indices, :]) - np.testing.assert_array_equal(adata_subset.obs["txt"], np.array(list(obs[i] for i in indices))) - np.testing.assert_array_equal(adata_subset.obsm["x"], obsm[indices, :]) - np.testing.assert_array_equal(adata_subset.obsm["y"].todense(), obsm[indices, :]) - np.testing.assert_array_equal(adata_subset.layers["raw"], x[indices, :]) - - adata_subset = adata.subset(pl.Series([str(x) for x in indices]), out = h5ad(tmp_path)) - np.testing.assert_array_equal(adata_subset.X[:], x[indices, :]) - np.testing.assert_array_equal(adata_subset.layers["raw"], x[indices, :]) + for adata_subset in [ adata.subset(indices, indices2, out=h5ad(tmp_path), inplace=False), + adata.subset(indices, indices2, inplace=False)]: + np.testing.assert_array_equal(adata_subset.X[:], x[np.ix_(indices, indices2)]) + np.testing.assert_array_equal(adata_subset.obs["txt"], np.array(list(obs[i] for i in indices))) + np.testing.assert_array_equal(adata_subset.obsm["x"], obsm[indices, :]) + np.testing.assert_array_equal(adata_subset.obsm["y"].todense(), obsm[indices, :]) + np.testing.assert_array_equal(adata_subset.obsp["x"], obsp[np.ix_(indices, indices)]) + np.testing.assert_array_equal(adata_subset.obsp["y"].todense(), obsp[np.ix_(indices, indices)]) + np.testing.assert_array_equal(adata_subset.varm["x"], varm[indices2, :]) + np.testing.assert_array_equal(adata_subset.varm["y"].todense(), varm[indices2, :]) + np.testing.assert_array_equal(adata_subset.varp["x"], varp[np.ix_(indices2, indices2)]) + np.testing.assert_array_equal(adata_subset.varp["y"].todense(), varp[np.ix_(indices2, indices2)]) + np.testing.assert_array_equal(adata_subset.layers["raw"], x[np.ix_(indices, indices2)]) + + for adata_subset in [ adata.subset([str(x) for x in indices], out=h5ad(tmp_path), inplace=False), + adata.subset([str(x) for x in indices], inplace=False) ]: + np.testing.assert_array_equal(adata_subset.X[:], x[indices, :]) + np.testing.assert_array_equal(adata_subset.obs["txt"], np.array(list(obs[i] for i in indices))) + np.testing.assert_array_equal(adata_subset.obsm["x"], obsm[indices, :]) + np.testing.assert_array_equal(adata_subset.obsm["y"].todense(), obsm[indices, :]) + np.testing.assert_array_equal(adata_subset.layers["raw"], x[indices, :]) + + for adata_subset in [ adata.subset(pl.Series([str(x) for x in indices]), out=h5ad(tmp_path), inplace=False), + adata.subset(pl.Series([str(x) for x in indices]), inplace=False) ]: + np.testing.assert_array_equal(adata_subset.X[:], x[indices, :]) + np.testing.assert_array_equal(adata_subset.layers["raw"], x[indices, :]) adata.subset(indices) np.testing.assert_array_equal(adata.X[:], x[indices, :]) @@ -195,4 +198,4 @@ def test_anndataset_subset(x1, x2, x3, idx1, idx2, idx3, tmp_path): # Check if open is OK os.chdir(tmp_path) - dataset_subset.subset([], out = "a_copy") \ No newline at end of file + dataset_subset.subset([], out = "a_copy")