diff --git a/ferveo/Cargo.toml b/ferveo/Cargo.toml index 2d30b924..1d957ccb 100644 --- a/ferveo/Cargo.toml +++ b/ferveo/Cargo.toml @@ -34,13 +34,11 @@ serde_with = "2.2.0" subproductdomain = { package = "subproductdomain-pre-release", path = "../subproductdomain", version = "0.1.0-alpha.0" } thiserror = "1.0" zeroize = { version = "1.6.0", default-features = false, features = ["derive"] } - -# Shared by Python and WASM bindings -derive_more = { version = "0.99", default-features = false, features = ["from", "as_ref", "into"], optional = true } +generic-array = "0.14.7" +derive_more = { version = "0.99", default-features = false, features = ["from", "as_ref", "into"] } # Python bindings pyo3 = { version = "0.18.2", features = ["macros", "multiple-pymethods"], optional = true } -generic-array = { version = "0.14.7", optional = true } # WASM bindings console_error_panic_hook = { version = "0.1.7", optional = true } @@ -60,8 +58,8 @@ serde = { version = "1.0", features = ["derive"] } wasm-bindgen = { version = "0.2.86", features = ["serde-serialize"] } [features] -bindings-python = ["pyo3", "derive_more", "generic-array"] -bindings-wasm = ["console_error_panic_hook", "derive_more", "getrandom", "js-sys", "wasm-bindgen", "wasm-bindgen-derive"] +bindings-python = ["pyo3"] +bindings-wasm = ["console_error_panic_hook", "getrandom", "js-sys", "wasm-bindgen", "wasm-bindgen-derive"] [[example]] name = "bench_primitives_size" diff --git a/ferveo/src/api.rs b/ferveo/src/api.rs index fe8a33e9..04995dfe 100644 --- a/ferveo/src/api.rs +++ b/ferveo/src/api.rs @@ -5,6 +5,7 @@ use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; use ark_std::UniformRand; use bincode; use ferveo_common::serialization; +use generic_array::{typenum::U48, GenericArray}; use group_threshold_cryptography as tpke; use rand::RngCore; use serde::{Deserialize, Serialize}; @@ -75,12 +76,16 @@ pub struct DkgPublicKey( ); impl DkgPublicKey { - pub fn to_bytes(&self) -> Result> { - to_bytes(&self.0) + pub fn to_bytes(&self) -> Result> { + let as_bytes = to_bytes(&self.0)?; + Ok(GenericArray::::from_slice(&as_bytes).to_owned()) } pub fn from_bytes(bytes: &[u8]) -> Result { - from_bytes(bytes).map(DkgPublicKey) + let bytes = + GenericArray::::from_exact_iter(bytes.iter().cloned()) + .ok_or(Error::InvalidByteLength(48, bytes.len()))?; + from_bytes(&bytes).map(DkgPublicKey) } pub fn serialized_size() -> usize { diff --git a/ferveo/src/bindings_python.rs b/ferveo/src/bindings_python.rs index c4880948..2c912e0c 100644 --- a/ferveo/src/bindings_python.rs +++ b/ferveo/src/bindings_python.rs @@ -1,7 +1,6 @@ use std::fmt::{Debug, Formatter}; use ferveo_common::serialization::{FromBytes, ToBytes}; -use generic_array::{typenum::U48, GenericArray}; use pyo3::{ basic::CompareOp, create_exception, @@ -89,6 +88,12 @@ impl From for PyErr { Error::ArkSerializeError(err) => { SerializationError::new_err(err.to_string()) } + Error::InvalidByteLength(expected, actual) => { + InvalidByteLength::new_err(format!( + "expected: {}, actual: {}", + expected, actual + )) + } }, _ => default(), } @@ -122,6 +127,7 @@ create_exception!(exceptions, InvalidTranscriptAggregate, PyValueError); create_exception!(exceptions, ValidatorsNotSorted, PyValueError); create_exception!(exceptions, ValidatorPublicKeyMismatch, PyValueError); create_exception!(exceptions, SerializationError, PyValueError); +create_exception!(exceptions, InvalidByteLength, PyValueError); fn from_py_bytes(bytes: &[u8]) -> PyResult { T::from_bytes(bytes) @@ -343,15 +349,8 @@ pub struct DkgPublicKey(api::DkgPublicKey); impl DkgPublicKey { #[staticmethod] pub fn from_bytes(bytes: &[u8]) -> PyResult { - let bytes = - GenericArray::::from_exact_iter(bytes.iter().cloned()) - .ok_or_else(|| { - FerveoPythonError::Other( - "Invalid length of bytes for DkgPublicKey".to_string(), - ) - })?; Ok(Self( - api::DkgPublicKey::from_bytes(bytes.as_slice()) + api::DkgPublicKey::from_bytes(bytes) .map_err(FerveoPythonError::FerveoError)?, )) } @@ -359,8 +358,7 @@ impl DkgPublicKey { fn __bytes__(&self) -> PyResult { let bytes = self.0.to_bytes().map_err(FerveoPythonError::FerveoError)?; - let bytes = GenericArray::::from_slice(bytes.as_slice()); - as_py_bytes(bytes) + as_py_bytes(&bytes) } #[staticmethod] diff --git a/ferveo/src/bindings_wasm.rs b/ferveo/src/bindings_wasm.rs index 3020a5c7..e4b976a3 100644 --- a/ferveo/src/bindings_wasm.rs +++ b/ferveo/src/bindings_wasm.rs @@ -215,14 +215,36 @@ pub fn decrypt_with_shared_secret( #[wasm_bindgen] pub struct DkgPublicKey(api::DkgPublicKey); -generate_common_methods!(DkgPublicKey); - #[wasm_bindgen] impl DkgPublicKey { + #[wasm_bindgen(js_name = "fromBytes")] + pub fn from_bytes(bytes: &[u8]) -> JsResult { + api::DkgPublicKey::from_bytes(bytes) + .map_err(map_js_err) + .map(Self) + } + + #[wasm_bindgen(js_name = "toBytes")] + pub fn to_bytes(&self) -> JsResult> { + let bytes = self.0.to_bytes().map_err(map_js_err)?; + let bytes: Box<[u8]> = bytes.as_slice().into(); + Ok(bytes) + } + #[wasm_bindgen] pub fn random() -> DkgPublicKey { Self(api::DkgPublicKey::random()) } + + #[wasm_bindgen(js_name = "serializedSize")] + pub fn serialized_size() -> usize { + api::DkgPublicKey::serialized_size() + } + + #[wasm_bindgen] + pub fn equals(&self, other: &DkgPublicKey) -> bool { + self.0 == other.0 + } } #[wasm_bindgen] diff --git a/ferveo/src/lib.rs b/ferveo/src/lib.rs index e3d62105..7e1b3657 100644 --- a/ferveo/src/lib.rs +++ b/ferveo/src/lib.rs @@ -98,6 +98,9 @@ pub enum Error { #[error(transparent)] ArkSerializeError(#[from] ark_serialize::SerializationError), + + #[error("Invalid byte length. Expected {0}, got {1}")] + InvalidByteLength(usize, usize), } pub type Result = std::result::Result;