Skip to content

Commit

Permalink
fix csc, csr matrices conversion issues
Browse files Browse the repository at this point in the history
  • Loading branch information
kaizhang committed Mar 13, 2024
1 parent 6f7402f commit f2e2c5f
Show file tree
Hide file tree
Showing 13 changed files with 289 additions and 152 deletions.
10 changes: 7 additions & 3 deletions anndata-hdf5/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,15 @@ repository = "https://github.com/kaizhang/anndata-rs"
homepage = "https://github.com/kaizhang/anndata-rs"

[dependencies]
#anndata = "0.3"
anndata = { path = '../anndata' }
anndata = "0.3"
anyhow = "1.0"
hdf5 = { version = "0.8" }
hdf5-sys = { version = "0.8", features = ["static", "zlib", "threadsafe"] }
#libz-sys = { version = "1", features = ["zlib-ng"], default-features = false }
libz-sys = { version = "1", features = ["libc"], default-features = false }
ndarray = { version = "0.15" }
ndarray = { version = "0.15" }

[dev-dependencies]
tempfile = "3.2"
rand = "0.8.5"
ndarray-rand = "0.14"
2 changes: 1 addition & 1 deletion anndata/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "anndata"
version = "0.3.1"
version = "0.3.2"
edition = "2021"
rust-version = "1.70"
authors = ["Kai Zhang <kai@kzhang.org>"]
Expand Down
4 changes: 2 additions & 2 deletions anndata/src/data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ macro_rules! impl_try_from_for_scalar {
fn try_from(data: Data) -> Result<Self> {
match data {
Data::Scalar(DynScalar::$from(data)) => Ok(data),
_ => bail!("Cannot convert data to $to"),
_ => bail!("Cannot convert data to {}", stringify!($to)),
}
}
}
Expand All @@ -107,7 +107,7 @@ macro_rules! impl_try_from_for_scalar {
fn try_from(v: Data) -> Result<Self> {
match v {
Data::ArrayData(data) => data.try_into(),
_ => bail!("Cannot convert data to $to Array"),
_ => bail!("Cannot convert data to {} Array", stringify!($to)),
}
}
}
Expand Down
18 changes: 9 additions & 9 deletions anndata/src/data/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ impl TryFrom<ArrayData> for DynArray {
fn try_from(value: ArrayData) -> Result<Self, Self::Error> {
match value {
ArrayData::Array(data) => Ok(data),
_ => bail!("Cannot convert {:?} to DynArray", value),
_ => bail!("Cannot convert {:?} to DynArray", value.data_type()),
}
}
}
Expand All @@ -80,7 +80,7 @@ impl TryFrom<ArrayData> for DynCsrMatrix {
fn try_from(value: ArrayData) -> Result<Self, Self::Error> {
match value {
ArrayData::CsrMatrix(data) => Ok(data),
_ => bail!("Cannot convert {:?} to DynCsrMatrix", value),
_ => bail!("Cannot convert {:?} to DynCsrMatrix", value.data_type()),
}
}
}
Expand All @@ -91,7 +91,7 @@ impl TryFrom<ArrayData> for DynCsrNonCanonical {
match value {
ArrayData::CsrNonCanonical(data) => Ok(data),
ArrayData::CsrMatrix(data) => Ok(data.into()),
_ => bail!("Cannot convert {:?} to DynCsrNonCanonical", value),
_ => bail!("Cannot convert {:?} to DynCsrNonCanonical", value.data_type()),
}
}
}
Expand All @@ -101,7 +101,7 @@ impl TryFrom<ArrayData> for DynCscMatrix {
fn try_from(value: ArrayData) -> Result<Self, Self::Error> {
match value {
ArrayData::CscMatrix(data) => Ok(data),
_ => bail!("Cannot convert {:?} to DynCscMatrix", value),
_ => bail!("Cannot convert {:?} to DynCscMatrix", value.data_type()),
}
}
}
Expand All @@ -111,7 +111,7 @@ impl TryFrom<ArrayData> for DataFrame {
fn try_from(value: ArrayData) -> Result<Self, Self::Error> {
match value {
ArrayData::DataFrame(data) => Ok(data),
_ => bail!("Cannot convert {:?} to DataFrame", value),
_ => bail!("Cannot convert {:?} to DataFrame", value.data_type()),
}
}
}
Expand Down Expand Up @@ -145,7 +145,7 @@ macro_rules! impl_into_array_data {
fn try_from(value: ArrayData) -> Result<Self, Self::Error> {
match value {
ArrayData::Array(data) => data.try_into(),
_ => bail!("Cannot convert {:?} to $ty Array", value),
_ => bail!("Cannot convert {:?} to {} Array", value.data_type(), stringify!($ty)),
}
}
}
Expand All @@ -154,7 +154,7 @@ macro_rules! impl_into_array_data {
fn try_from(value: ArrayData) -> Result<Self, Self::Error> {
match value {
ArrayData::CsrMatrix(data) => data.try_into(),
_ => bail!("Cannot convert {:?} to $ty CsrMatrix", value),
_ => bail!("Cannot convert {:?} to {} CsrMatrix", value.data_type(), stringify!($ty)),
}
}
}
Expand All @@ -164,7 +164,7 @@ macro_rules! impl_into_array_data {
match value {
ArrayData::CsrNonCanonical(data) => data.try_into(),
ArrayData::CsrMatrix(data) => DynCsrNonCanonical::from(data).try_into(),
_ => bail!("Cannot convert {:?} to $ty CsrNonCanonical", value),
_ => bail!("Cannot convert {:?} to {} CsrNonCanonical", value.data_type(), stringify!($ty)),
}
}
}
Expand All @@ -173,7 +173,7 @@ macro_rules! impl_into_array_data {
fn try_from(value: ArrayData) -> Result<Self, Self::Error> {
match value {
ArrayData::CscMatrix(data) => data.try_into(),
_ => bail!("Cannot convert {:?} to $ty CsrMatrix", value),
_ => bail!("Cannot convert {:?} to {} CsrMatrix", value.data_type(), stringify!($ty)),
}
}
}
Expand Down
4 changes: 2 additions & 2 deletions anndata/src/data/array/ndarray.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ impl TryFrom<DynArray> for CategoricalArray {
fn try_from(v: DynArray) -> Result<Self> {
match v {
DynArray::Categorical(cat) => Ok(cat),
_ => bail!("Cannot convert {:?} to CategoricalArray", v),
_ => bail!("Cannot convert {:?} to CategoricalArray", v.data_type()),
}
}
}
Expand All @@ -67,7 +67,7 @@ macro_rules! impl_dyn_array_convert {
}
Ok(arr.into_dimensionality::<D>()?)
},
_ => bail!("Cannot convert {:?} to ArrayD<$from_type>", v),
_ => bail!("Cannot convert {:?} to {} ArrayD", v.data_type(), stringify!($from_type)),
}
}
}
Expand Down
91 changes: 48 additions & 43 deletions anndata/src/data/array/sparse/csc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ macro_rules! impl_into_dyn_csc {
fn try_from(data: DynCscMatrix) -> Result<Self> {
match data {
DynCscMatrix::$to_type(data) => Ok(data),
_ => bail!("Cannot convert to CscMatrix<$from_type>"),
_ => bail!("Cannot convert {:?} to {} CscMatrix", data.data_type(), stringify!($from_type)),
}
}
}
Expand Down Expand Up @@ -567,18 +567,18 @@ impl<T: BackendData> WriteData for CscMatrix<T> {

impl<T: BackendData> ReadData for CscMatrix<T> {
fn read<B: Backend>(container: &DataContainer<B>) -> Result<Self> {
match container.encoding_type()? {
DataType::CscMatrix(_) => {
let group = container.as_group()?;
let shape: Vec<usize> = group.read_array_attr("shape")?.to_vec();
let data = group.open_dataset("data")?.read_array::<_, Ix1>()?.into_raw_vec();
let indptr: Vec<usize> = group.open_dataset("indptr")?.read_array::<_, Ix1>()?.into_raw_vec();
let indices: Vec<usize> = group.open_dataset("indices")?.read_array::<_, Ix1>()?.into_raw_vec();
CscMatrix::try_from_csc_data(
shape[0], shape[1], indptr, indices, data
).map_err(|e| anyhow::anyhow!("{}", e))
},
ty => bail!("cannot read csc matrix from container with data type {:?}", ty),
let data_type = container.encoding_type()?;
if let DataType::CscMatrix(_) = data_type {
let group = container.as_group()?;
let shape: Vec<usize> = group.read_array_attr("shape")?.to_vec();
let data = group.open_dataset("data")?.read_array::<_, Ix1>()?.into_raw_vec();
let indptr: Vec<usize> = group.open_dataset("indptr")?.read_array::<_, Ix1>()?.into_raw_vec();
let indices: Vec<usize> = group.open_dataset("indices")?.read_array::<_, Ix1>()?.into_raw_vec();
CscMatrix::try_from_csc_data(
shape[0], shape[1], indptr, indices, data
).map_err(|e| anyhow::anyhow!("{}", e))
} else {
bail!("cannot read csc matrix from container with data type {:?}", data_type)
}
}
}
Expand All @@ -598,41 +598,46 @@ impl<T: BackendData> ReadArrayData for CscMatrix<T> {
B: Backend,
S: AsRef<SelectInfoElem>,
{
if info.as_ref().len() != 2 {
panic!("index must have length 2");
}
let data_type = container.encoding_type()?;
if let DataType::CscMatrix(_) = data_type {
if info.as_ref().len() != 2 {
panic!("index must have length 2");
}

if info.iter().all(|s| s.as_ref().is_full()) {
return Self::read(container);
}
if info.iter().all(|s| s.as_ref().is_full()) {
return Self::read(container);
}

let data = if let SelectInfoElem::Slice(s) = info[1].as_ref() {
let group = container.as_group()?;
let indptr_slice = if let Some(end) = s.end {
SelectInfoElem::from(s.start .. end + 1)
let data = if let SelectInfoElem::Slice(s) = info[1].as_ref() {
let group = container.as_group()?;
let indptr_slice = if let Some(end) = s.end {
SelectInfoElem::from(s.start .. end + 1)
} else {
SelectInfoElem::from(s.start ..)
};
let mut indptr: Vec<usize> = group
.open_dataset("indptr")?
.read_array_slice(&[indptr_slice])?
.to_vec();
let lo = indptr[0];
let slice = SelectInfoElem::from(lo .. indptr[indptr.len() - 1]);
let data: Vec<T> = group.open_dataset("data")?.read_array_slice(&[&slice])?.to_vec();
let indices: Vec<usize> = group.open_dataset("indices")?.read_array_slice(&[&slice])?.to_vec();
indptr.iter_mut().for_each(|x| *x -= lo);
CscMatrix::try_from_csc_data(
Self::get_shape(container)?[0],
indptr.len() - 1,
indptr,
indices,
data,
).unwrap().select_axis(0, info[0].as_ref()) // selct axis 1, then select axis 0 elements
} else {
SelectInfoElem::from(s.start ..)
Self::read(container)?.select(info)
};
let mut indptr: Vec<usize> = group
.open_dataset("indptr")?
.read_array_slice(&[indptr_slice])?
.to_vec();
let lo = indptr[0];
let slice = SelectInfoElem::from(lo .. indptr[indptr.len() - 1]);
let data: Vec<T> = group.open_dataset("data")?.read_array_slice(&[&slice])?.to_vec();
let indices: Vec<usize> = group.open_dataset("indices")?.read_array_slice(&[&slice])?.to_vec();
indptr.iter_mut().for_each(|x| *x -= lo);
CscMatrix::try_from_csc_data(
Self::get_shape(container)?[0],
indptr.len() - 1,
indptr,
indices,
data,
).unwrap().select_axis(0, info[0].as_ref()) // selct axis 1, then select axis 0 elements
Ok(data)
} else {
Self::read(container)?.select(info)
};
Ok(data)
bail!("cannot read csc matrix from container with data type {:?}", data_type)
}
}
}

Expand Down
91 changes: 48 additions & 43 deletions anndata/src/data/array/sparse/csr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ macro_rules! impl_into_dyn_csr {
fn try_from(data: DynCsrMatrix) -> Result<Self> {
match data {
DynCsrMatrix::$to_type(data) => Ok(data),
_ => bail!("Cannot convert to CsrMatrix<$from_type>"),
_ => bail!("Cannot convert {:?} to {} CsrMatrix", data.data_type(), stringify!($from_type)),
}
}
}
Expand Down Expand Up @@ -571,18 +571,18 @@ impl<T: BackendData> WriteData for CsrMatrix<T> {

impl<T: BackendData> ReadData for CsrMatrix<T> {
fn read<B: Backend>(container: &DataContainer<B>) -> Result<Self> {
match container.encoding_type()? {
DataType::CsrMatrix(_) => {
let group = container.as_group()?;
let shape: Vec<usize> = group.read_array_attr("shape")?.to_vec();
let data = group.open_dataset("data")?.read_array::<_, Ix1>()?.into_raw_vec();
let indptr: Vec<usize> = group.open_dataset("indptr")?.read_array::<_, Ix1>()?.into_raw_vec();
let indices: Vec<usize> = group.open_dataset("indices")?.read_array::<_, Ix1>()?.into_raw_vec();
CsrMatrix::try_from_csr_data(
shape[0], shape[1], indptr, indices, data
).map_err(|e| anyhow!("cannot read csr matrix: {}", e))
},
ty => bail!("cannot read csr matrix from container with data type {:?}", ty),
let data_type = container.encoding_type()?;
if let DataType::CsrMatrix(_) = data_type {
let group = container.as_group()?;
let shape: Vec<usize> = group.read_array_attr("shape")?.to_vec();
let data = group.open_dataset("data")?.read_array::<_, Ix1>()?.into_raw_vec();
let indptr: Vec<usize> = group.open_dataset("indptr")?.read_array::<_, Ix1>()?.into_raw_vec();
let indices: Vec<usize> = group.open_dataset("indices")?.read_array::<_, Ix1>()?.into_raw_vec();
CsrMatrix::try_from_csr_data(
shape[0], shape[1], indptr, indices, data
).map_err(|e| anyhow!("cannot read csr matrix: {}", e))
} else {
bail!("cannot read csr matrix from container with data type {:?}", data_type)
}
}
}
Expand All @@ -602,41 +602,46 @@ impl<T: BackendData> ReadArrayData for CsrMatrix<T> {
B: Backend,
S: AsRef<SelectInfoElem>,
{
if info.as_ref().len() != 2 {
panic!("index must have length 2");
}
let data_type = container.encoding_type()?;
if let DataType::CsrMatrix(_) = data_type {
if info.as_ref().len() != 2 {
panic!("index must have length 2");
}

if info.iter().all(|s| s.as_ref().is_full()) {
return Self::read(container);
}
if info.iter().all(|s| s.as_ref().is_full()) {
return Self::read(container);
}

let data = if let SelectInfoElem::Slice(s) = info[0].as_ref() {
let group = container.as_group()?;
let indptr_slice = if let Some(end) = s.end {
SelectInfoElem::from(s.start .. end + 1)
let data = if let SelectInfoElem::Slice(s) = info[0].as_ref() {
let group = container.as_group()?;
let indptr_slice = if let Some(end) = s.end {
SelectInfoElem::from(s.start .. end + 1)
} else {
SelectInfoElem::from(s.start ..)
};
let mut indptr: Vec<usize> = group
.open_dataset("indptr")?
.read_array_slice(&[indptr_slice])?
.to_vec();
let lo = indptr[0];
let slice = SelectInfoElem::from(lo .. indptr[indptr.len() - 1]);
let data: Vec<T> = group.open_dataset("data")?.read_array_slice(&[&slice])?.to_vec();
let indices: Vec<usize> = group.open_dataset("indices")?.read_array_slice(&[&slice])?.to_vec();
indptr.iter_mut().for_each(|x| *x -= lo);
CsrMatrix::try_from_csr_data(
indptr.len() - 1,
Self::get_shape(container)?[1],
indptr,
indices,
data,
).unwrap().select_axis(1, info[1].as_ref())
} else {
SelectInfoElem::from(s.start ..)
Self::read(container)?.select(info)
};
let mut indptr: Vec<usize> = group
.open_dataset("indptr")?
.read_array_slice(&[indptr_slice])?
.to_vec();
let lo = indptr[0];
let slice = SelectInfoElem::from(lo .. indptr[indptr.len() - 1]);
let data: Vec<T> = group.open_dataset("data")?.read_array_slice(&[&slice])?.to_vec();
let indices: Vec<usize> = group.open_dataset("indices")?.read_array_slice(&[&slice])?.to_vec();
indptr.iter_mut().for_each(|x| *x -= lo);
CsrMatrix::try_from_csr_data(
indptr.len() - 1,
Self::get_shape(container)?[1],
indptr,
indices,
data,
).unwrap().select_axis(1, info[1].as_ref())
Ok(data)
} else {
Self::read(container)?.select(info)
};
Ok(data)
bail!("cannot read csr matrix from container with data type {:?}", data_type)
}
}
}

Expand Down
2 changes: 1 addition & 1 deletion anndata/src/data/array/sparse/noncanonical.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ macro_rules! impl_into_dyn_csr {
fn try_from(data: DynCsrNonCanonical) -> Result<Self> {
match data {
DynCsrNonCanonical::$to_type(data) => Ok(data),
_ => bail!("Cannot convert to CsrNonCanonical<$from_type>"),
_ => bail!("Cannot convert {:?} to {} CsrNonCanonical", data.data_type(), stringify!($from_type)),
}
}
}
Expand Down
Loading

0 comments on commit f2e2c5f

Please sign in to comment.