diff --git a/src/lib.rs b/src/lib.rs index 29355b9..2cabd81 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,4 +1,5 @@ -use crate::Error::Decrypt; +use std::{error::Error as StdError, fmt}; + use byteorder::{ByteOrder, NetworkEndian, WriteBytesExt}; use pqc_kyber::{ Keypair, KyberError, PublicKey, SecretKey, KYBER_CIPHERTEXTBYTES, KYBER_SECRETKEYBYTES, @@ -52,19 +53,19 @@ pub fn encrypt_into, R: AsRef<[u8]>, V: AsRef<[u8]>, O: AsMut<[u8 let plaintext = plaintext.as_ref(); let plaintext_length = plaintext.len(); let ret = ret.as_mut(); + //, if nonce.len() != 32 { - return Err(Error::Encrypt(format!( - "Nonce must be 32 bytes, got {}", - nonce.len() - ))); + return Err(Error::new( + format!("Nonce must be 32 bytes, got {}", nonce.len()), + ErrorKind::Encrypt, + )); } - if ret.len() < ct_len(plaintext.len()) { - return Err(Error::Encrypt(format!( - "Bad output buffer len {}", - ret.len() - ))); + return Err(Error::new( + format!("Bad output buffer len {}", ret.len()), + ErrorKind::Encrypt, + )); } if plaintext_length != 0 { @@ -106,11 +107,18 @@ pub fn decrypt, R: AsRef<[u8]>>( const CIPHERTEXT_BLOCK_LEN: usize = pqc_kyber::KYBER_CIPHERTEXTBYTES; if ciphertext.len() < CIPHERTEXT_BLOCK_LEN { - return Err(Decrypt("The input ciphertext is too short".to_string())); + return Err(Error::new( + "The input ciphertext is too short".to_string(), + ErrorKind::Decrypt, + )); } - let plaintext_length = plaintext_len(ciphertext) - .ok_or_else(|| Error::Decrypt("Invalid ciphertext input length".to_string()))?; + let plaintext_length = plaintext_len(ciphertext).ok_or_else(|| { + Error::new( + "Invalid ciphertext input length".to_string(), + ErrorKind::Decrypt, + ) + })?; let split_pt = ciphertext.len().saturating_sub(8); let (concatenated_ciphertexts, _) = ciphertext.split_at(split_pt); // pt len < 32: size must be 32 @@ -147,9 +155,33 @@ pub fn kem_keypair() -> Result { } #[derive(Debug, Clone)] -pub enum Error { - Encrypt(String), - Decrypt(String), +pub enum ErrorKind { + Encrypt, + Decrypt, +} + +#[derive(Debug, Clone)] +pub struct Error { + msg: String, + err_kind: ErrorKind, +} + +impl Error { + pub fn new(msg: String, err_kind: ErrorKind) -> Self { + Self { msg, err_kind } + } +} + +impl StdError for Error { + fn description(&self) -> &str { + &self.msg + } +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}, ErrorKind: {:?}", self.msg, self.err_kind) + } } fn div_ceil(a: f32, b: f32) -> usize {