From fc241fe165b4eef9954687a8f8065166f86d962d Mon Sep 17 00:00:00 2001 From: Austin DeNoble Date: Wed, 4 Sep 2024 12:07:13 -0400 Subject: [PATCH] Refactor `PineconeError`: use `thiserror`/`anyhow`, ensure we support `Send + Sync` (#56) ## Problem Currently, `PineconeError` contains a lot of boilerplate for implementing `std::error::Error`. @haruska suggested we could maybe simplify some of the boilerplate using the `thiserror` and `anyhow` crates. Additionally, there was an enhancement filed (https://github.com/pinecone-io/pinecone-rust-client/issues/54) to implement `Send` + `Sync` for `PineconeError` as currently we're unable to use `PineconeError` in a multithreaded context. ## Solution Refactor PineconeError to use thiserror and anyhow to reduce some of our boilerplate for the custom error enum, make sure we can safely use PineconeError with Send and Sync, add unit test for this ## Type of Change - [ ] Bug fix (non-breaking change which fixes an issue) - [X] New feature (non-breaking change which adds functionality) - [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected) - [ ] This change requires a documentation update - [ ] Infrastructure change (CI configs, etc) - [ ] Non-code change (docs, etc) - [ ] None of the above: (explain here) ## Test Plan New unit test added to verify that `PineoneError` has properly implemented the `Send` and `Sync` traits. `cargo test` -> validate CI passes as expected --- - To see the specific tasks where the Asana app for GitHub is being used, see below: - https://app.asana.com/0/0/1208161607942725 - https://app.asana.com/0/0/1208161607942720 --- Cargo.lock | 2 + Cargo.toml | 2 + src/pinecone/data.rs | 12 +-- src/pinecone/mod.rs | 2 +- src/utils/errors.rs | 177 +++++++++++++------------------------------ 5 files changed, 59 insertions(+), 136 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 8518dfc..bf58a45 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1477,6 +1477,7 @@ checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" name = "pinecone-sdk" version = "0.1.1" dependencies = [ + "anyhow", "httpmock", "once_cell", "prost", @@ -1489,6 +1490,7 @@ dependencies = [ "serial_test", "snafu", "temp-env", + "thiserror", "tokio", "tonic", "tonic-build", diff --git a/Cargo.toml b/Cargo.toml index ee6e01b..a959e72 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -32,6 +32,8 @@ serde = { version = "^1.0", features = ["derive"] } url = "^2.5" uuid = { version = "^1.8", features = ["serde", "v4"] } reqwest = { version = "^0.12", features = ["json", "multipart"] } +thiserror = "1.0.63" +anyhow = "1.0.86" [dev-dependencies] temp-env = "0.3" diff --git a/src/pinecone/data.rs b/src/pinecone/data.rs index cdc715f..5dab576 100644 --- a/src/pinecone/data.rs +++ b/src/pinecone/data.rs @@ -634,20 +634,14 @@ impl PineconeClient { // connect to server let endpoint = Channel::from_shared(host) - .map_err(|e| PineconeError::ConnectionError { - source: Box::new(e), - })? + .map_err(|e| PineconeError::ConnectionError { source: e.into() })? .tls_config(tls_config) - .map_err(|e| PineconeError::ConnectionError { - source: Box::new(e), - })?; + .map_err(|e| PineconeError::ConnectionError { source: e.into() })?; let channel = endpoint .connect() .await - .map_err(|e| PineconeError::ConnectionError { - source: Box::new(e), - })?; + .map_err(|e| PineconeError::ConnectionError { source: e.into() })?; // add api key in metadata through interceptor let token: TonicMetadataVal<_> = self.api_key.parse().unwrap(); diff --git a/src/pinecone/mod.rs b/src/pinecone/mod.rs index abffbc7..66b61a9 100644 --- a/src/pinecone/mod.rs +++ b/src/pinecone/mod.rs @@ -124,7 +124,7 @@ impl PineconeClientConfig { let client = reqwest::Client::builder() .default_headers(headers) .build() - .map_err(|e| PineconeError::ReqwestError { source: e })?; + .map_err(|e| PineconeError::ReqwestError { source: e.into() })?; let openapi_config = Configuration { base_path: controller_host.to_string(), diff --git a/src/utils/errors.rs b/src/utils/errors.rs index 6a1a3f7..3182742 100644 --- a/src/utils/errors.rs +++ b/src/utils/errors.rs @@ -1,11 +1,13 @@ use crate::openapi::apis::{Error as OpenApiError, ResponseContent}; - +use anyhow::Error as AnyhowError; use reqwest::{self, StatusCode}; +use thiserror::Error; /// PineconeError is the error type for all Pinecone SDK errors. -#[derive(Debug)] +#[derive(Error, Debug)] pub enum PineconeError { /// UnknownResponseError: Unknown response error. + #[error("Unknown response error: status: {status}, message: {message}")] UnknownResponseError { /// status code status: StatusCode, @@ -14,138 +16,161 @@ pub enum PineconeError { }, /// ActionForbiddenError: Action is forbidden. + #[error("Action forbidden error: {source}")] ActionForbiddenError { /// Source error source: WrappedResponseContent, }, /// APIKeyMissingError: API key is not provided as an argument nor in the environment variable `PINECONE_API_KEY`. + #[error("API key missing error: {message}")] APIKeyMissingError { /// Error message. message: String, }, /// InvalidHeadersError: Provided headers are not valid. Expects JSON. + #[error("Invalid headers error: {message}")] InvalidHeadersError { /// Error message. message: String, }, /// TimeoutError: Request timed out. + #[error("Timeout error: {message}")] TimeoutError { /// Error message. message: String, }, /// ConnectionError: Failed to establish a connection. + #[error("Connection error: {source}")] ConnectionError { - /// inner: Error object for connection error. - source: Box, + /// Source of the error. + source: AnyhowError, }, /// ReqwestError: Error caused by Reqwest + #[error("Reqwest error: {source}")] ReqwestError { - /// Source error - source: reqwest::Error, + /// Source of the error. + source: AnyhowError, }, /// SerdeError: Error caused by Serde + #[error("Serde error: {source}")] SerdeError { /// Source of the error. - source: serde_json::Error, + source: AnyhowError, }, /// IoError: Error caused by IO + #[error("IO error: {message}")] IoError { /// Error message. message: String, }, /// BadRequestError: Bad request. The request body included invalid request parameters + #[error("Bad request error: {source}")] BadRequestError { /// Source error source: WrappedResponseContent, }, /// UnauthorizedError: Unauthorized. Possibly caused by invalid API key + #[error("Unauthorized error: {source}")] UnauthorizedError { /// Source error source: WrappedResponseContent, }, /// PodQuotaExceededError: Pod quota exceeded + #[error("Pod quota exceeded error: {source}")] PodQuotaExceededError { /// Source error source: WrappedResponseContent, }, /// CollectionsQuotaExceededError: Collections quota exceeded + #[error("Collections quota exceeded error: {source}")] CollectionsQuotaExceededError { /// Source error source: WrappedResponseContent, }, /// InvalidCloudError: Provided cloud is not valid. + #[error("Invalid cloud error: {source}")] InvalidCloudError { /// Source error source: WrappedResponseContent, }, /// InvalidRegionError: Provided region is not valid. + #[error("Invalid region error: {source}")] InvalidRegionError { /// Source error source: WrappedResponseContent, }, /// InvalidConfigurationError: Provided configuration is not valid. + #[error("Invalid configuration error: {message}")] InvalidConfigurationError { /// Error message. message: String, }, /// CollectionNotFoundError: Collection of given name does not exist + #[error("Collection not found error: {source}")] CollectionNotFoundError { /// Source error source: WrappedResponseContent, }, /// IndexNotFoundError: Index of given name does not exist + #[error("Index not found error: {source}")] IndexNotFoundError { /// Source error source: WrappedResponseContent, }, /// ResourceAlreadyExistsError: Resource of given name already exists + #[error("Resource already exists error: {source}")] ResourceAlreadyExistsError { /// Source error source: WrappedResponseContent, }, /// Unprocessable entity error: The request body could not be deserialized + #[error("Unprocessable entity error: {source}")] UnprocessableEntityError { /// Source error source: WrappedResponseContent, }, /// PendingCollectionError: There is a pending collection created from this index + #[error("Pending collection error: {source}")] PendingCollectionError { /// Source error source: WrappedResponseContent, }, /// InternalServerError: Internal server error + #[error("Internal server error: {source}")] InternalServerError { /// Source error source: WrappedResponseContent, }, /// DataPlaneError: Failed to perform a data plane operation. + #[error("Data plane error: {status}")] DataPlaneError { /// Error status status: tonic::Status, }, /// InferenceError: Failed to perform an inference operation. + #[error("Inference error: {status}")] InferenceError { /// Error status status: tonic::Status, @@ -156,8 +181,12 @@ pub enum PineconeError { impl From> for PineconeError { fn from(error: OpenApiError) -> Self { match error { - OpenApiError::Reqwest(inner) => PineconeError::ReqwestError { source: inner }, - OpenApiError::Serde(inner) => PineconeError::SerdeError { source: inner }, + OpenApiError::Reqwest(inner) => PineconeError::ReqwestError { + source: inner.into(), + }, + OpenApiError::Serde(inner) => PineconeError::SerdeError { + source: inner.into(), + }, OpenApiError::Io(inner) => PineconeError::IoError { message: inner.to_string(), }, @@ -210,123 +239,6 @@ fn parse_forbidden_error(source: WrappedResponseContent, message: String) -> Pin } } -impl std::fmt::Display for PineconeError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - PineconeError::UnknownResponseError { status, message } => { - write!( - f, - "Unknown response error: status: {}, message: {}", - status, message - ) - } - PineconeError::ResourceAlreadyExistsError { source } => { - write!(f, "Resource already exists error: {}", source) - } - PineconeError::UnprocessableEntityError { source } => { - write!(f, "Unprocessable entity error: {}", source) - } - PineconeError::PendingCollectionError { source } => { - write!(f, "Pending collection error: {}", source) - } - PineconeError::InternalServerError { source } => { - write!(f, "Internal server error: {}", source) - } - PineconeError::ReqwestError { source } => { - write!(f, "Reqwest error: {}", source.to_string()) - } - PineconeError::SerdeError { source } => { - write!(f, "Serde error: {}", source.to_string()) - } - PineconeError::IoError { message } => { - write!(f, "IO error: {}", message) - } - PineconeError::BadRequestError { source } => { - write!(f, "Bad request error: {}", source) - } - PineconeError::UnauthorizedError { source } => { - write!(f, "Unauthorized error: status: {}", source) - } - PineconeError::PodQuotaExceededError { source } => { - write!(f, "Pod quota exceeded error: {}", source) - } - PineconeError::CollectionsQuotaExceededError { source } => { - write!(f, "Collections quota exceeded error: {}", source) - } - PineconeError::InvalidCloudError { source } => { - write!(f, "Invalid cloud error: status: {}", source) - } - PineconeError::InvalidRegionError { source } => { - write!(f, "Invalid region error: {}", source) - } - PineconeError::CollectionNotFoundError { source } => { - write!(f, "Collection not found error: {}", source) - } - PineconeError::IndexNotFoundError { source } => { - write!(f, "Index not found error: status: {}", source) - } - PineconeError::APIKeyMissingError { message } => { - write!(f, "API key missing error: {}", message) - } - PineconeError::InvalidHeadersError { message } => { - write!(f, "Invalid headers error: {}", message) - } - PineconeError::TimeoutError { message } => { - write!(f, "Timeout error: {}", message) - } - PineconeError::ConnectionError { source } => { - write!(f, "Connection error: {}", source) - } - PineconeError::DataPlaneError { status } => { - write!(f, "Data plane error: {}", status) - } - PineconeError::InferenceError { status } => { - write!(f, "Inference error: {}", status) - } - PineconeError::ActionForbiddenError { source } => { - write!(f, "Action forbidden error: {}", source) - } - PineconeError::InvalidConfigurationError { message } => { - write!(f, "Invalid configuration error: {}", message) - } - } - } -} - -impl std::error::Error for PineconeError { - fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { - match self { - PineconeError::UnknownResponseError { - status: _, - message: _, - } => None, - PineconeError::ReqwestError { source } => Some(source), - PineconeError::SerdeError { source } => Some(source), - PineconeError::IoError { message: _ } => None, - PineconeError::BadRequestError { source } => Some(source), - PineconeError::UnauthorizedError { source } => Some(source), - PineconeError::PodQuotaExceededError { source } => Some(source), - PineconeError::CollectionsQuotaExceededError { source } => Some(source), - PineconeError::InvalidCloudError { source } => Some(source), - PineconeError::InvalidRegionError { source } => Some(source), - PineconeError::CollectionNotFoundError { source } => Some(source), - PineconeError::IndexNotFoundError { source } => Some(source), - PineconeError::ResourceAlreadyExistsError { source } => Some(source), - PineconeError::UnprocessableEntityError { source } => Some(source), - PineconeError::PendingCollectionError { source } => Some(source), - PineconeError::InternalServerError { source } => Some(source), - PineconeError::APIKeyMissingError { message: _ } => None, - PineconeError::InvalidHeadersError { message: _ } => None, - PineconeError::TimeoutError { message: _ } => None, - PineconeError::ConnectionError { source } => Some(source.as_ref()), - PineconeError::DataPlaneError { status } => Some(status), - PineconeError::InferenceError { status } => Some(status), - PineconeError::ActionForbiddenError { source } => Some(source), - PineconeError::InvalidConfigurationError { message: _ } => None, - } - } -} - /// WrappedResponseContent is a wrapper around ResponseContent. #[derive(Debug)] pub struct WrappedResponseContent { @@ -356,3 +268,16 @@ impl std::fmt::Display for WrappedResponseContent { write!(f, "status: {} content: {}", self.status, self.content) } } + +#[cfg(test)] +mod tests { + use super::PineconeError; + use tokio; + + fn assert_send_sync() {} + + #[tokio::test] + async fn test_pinecone_error_is_send_sync() { + assert_send_sync::(); + } +}