From bb8e42f6392284f4a7a39d3eec74144a603b481c Mon Sep 17 00:00:00 2001 From: Raphael Taylor-Davies <1781103+tustvold@users.noreply.github.com> Date: Sun, 15 Oct 2023 11:04:14 +0100 Subject: [PATCH 01/25] Add GetOptions::head (#4931) --- object_store/src/aws/client.rs | 9 ++------ object_store/src/aws/mod.rs | 4 ---- object_store/src/azure/client.rs | 9 ++------ object_store/src/azure/mod.rs | 4 ---- object_store/src/client/get.rs | 24 +++------------------ object_store/src/gcp/mod.rs | 13 ++--------- object_store/src/http/client.rs | 15 +++++-------- object_store/src/http/mod.rs | 4 ---- object_store/src/lib.rs | 12 ++++++++++- object_store/src/local.rs | 37 ++++---------------------------- 10 files changed, 29 insertions(+), 102 deletions(-) diff --git a/object_store/src/aws/client.rs b/object_store/src/aws/client.rs index e3ac60eca060..ac07f9ab9af3 100644 --- a/object_store/src/aws/client.rs +++ b/object_store/src/aws/client.rs @@ -554,15 +554,10 @@ impl GetClient for S3Client { const STORE: &'static str = STORE; /// Make an S3 GET request - async fn get_request( - &self, - path: &Path, - options: GetOptions, - head: bool, - ) -> Result { + async fn get_request(&self, path: &Path, options: GetOptions) -> Result { let credential = self.get_credential().await?; let url = self.config.path_url(path); - let method = match head { + let method = match options.head { true => Method::HEAD, false => Method::GET, }; diff --git a/object_store/src/aws/mod.rs b/object_store/src/aws/mod.rs index 0028be99fa2e..285ee2f59deb 100644 --- a/object_store/src/aws/mod.rs +++ b/object_store/src/aws/mod.rs @@ -307,10 +307,6 @@ impl ObjectStore for AmazonS3 { self.client.get_opts(location, options).await } - async fn head(&self, location: &Path) -> Result { - self.client.head(location).await - } - async fn delete(&self, location: &Path) -> Result<()> { self.client.delete_request(location, &()).await } diff --git a/object_store/src/azure/client.rs b/object_store/src/azure/client.rs index 
cd1a3a10fcc7..f65388b61a80 100644 --- a/object_store/src/azure/client.rs +++ b/object_store/src/azure/client.rs @@ -264,15 +264,10 @@ impl GetClient for AzureClient { /// Make an Azure GET request /// /// - async fn get_request( - &self, - path: &Path, - options: GetOptions, - head: bool, - ) -> Result { + async fn get_request(&self, path: &Path, options: GetOptions) -> Result { let credential = self.get_credential().await?; let url = self.config.path_url(path); - let method = match head { + let method = match options.head { true => Method::HEAD, false => Method::GET, }; diff --git a/object_store/src/azure/mod.rs b/object_store/src/azure/mod.rs index b210d486d9bf..9017634c42da 100644 --- a/object_store/src/azure/mod.rs +++ b/object_store/src/azure/mod.rs @@ -202,10 +202,6 @@ impl ObjectStore for MicrosoftAzure { self.client.get_opts(location, options).await } - async fn head(&self, location: &Path) -> Result { - self.client.head(location).await - } - async fn delete(&self, location: &Path) -> Result<()> { self.client.delete_request(location, &()).await } diff --git a/object_store/src/client/get.rs b/object_store/src/client/get.rs index 333f6fe58475..7f68b6d1225f 100644 --- a/object_store/src/client/get.rs +++ b/object_store/src/client/get.rs @@ -17,7 +17,7 @@ use crate::client::header::{header_meta, HeaderConfig}; use crate::path::Path; -use crate::{Error, GetOptions, GetResult, ObjectMeta}; +use crate::{Error, GetOptions, GetResult}; use crate::{GetResultPayload, Result}; use async_trait::async_trait; use futures::{StreamExt, TryStreamExt}; @@ -34,27 +34,20 @@ pub trait GetClient: Send + Sync + 'static { last_modified_required: true, }; - async fn get_request( - &self, - path: &Path, - options: GetOptions, - head: bool, - ) -> Result; + async fn get_request(&self, path: &Path, options: GetOptions) -> Result; } /// Extension trait for [`GetClient`] that adds common retrieval functionality #[async_trait] pub trait GetClientExt { async fn get_opts(&self, location: 
&Path, options: GetOptions) -> Result; - - async fn head(&self, location: &Path) -> Result; } #[async_trait] impl GetClientExt for T { async fn get_opts(&self, location: &Path, options: GetOptions) -> Result { let range = options.range.clone(); - let response = self.get_request(location, options, false).await?; + let response = self.get_request(location, options).await?; let meta = header_meta(location, response.headers(), T::HEADER_CONFIG).map_err(|e| { Error::Generic { @@ -77,15 +70,4 @@ impl GetClientExt for T { meta, }) } - - async fn head(&self, location: &Path) -> Result { - let options = GetOptions::default(); - let response = self.get_request(location, options, true).await?; - header_meta(location, response.headers(), T::HEADER_CONFIG).map_err(|e| { - Error::Generic { - store: T::STORE, - source: Box::new(e), - } - }) - } } diff --git a/object_store/src/gcp/mod.rs b/object_store/src/gcp/mod.rs index a0a60f27a6aa..f80704b91765 100644 --- a/object_store/src/gcp/mod.rs +++ b/object_store/src/gcp/mod.rs @@ -389,16 +389,11 @@ impl GetClient for GoogleCloudStorageClient { const STORE: &'static str = STORE; /// Perform a get request - async fn get_request( - &self, - path: &Path, - options: GetOptions, - head: bool, - ) -> Result { + async fn get_request(&self, path: &Path, options: GetOptions) -> Result { let credential = self.get_credential().await?; let url = self.object_url(path); - let method = match head { + let method = match options.head { true => Method::HEAD, false => Method::GET, }; @@ -604,10 +599,6 @@ impl ObjectStore for GoogleCloudStorage { self.client.get_opts(location, options).await } - async fn head(&self, location: &Path) -> Result { - self.client.head(location).await - } - async fn delete(&self, location: &Path) -> Result<()> { self.client.delete_request(location).await } diff --git a/object_store/src/http/client.rs b/object_store/src/http/client.rs index 0bd2e5639cb5..b2a6ac0aa34a 100644 --- a/object_store/src/http/client.rs +++ 
b/object_store/src/http/client.rs @@ -288,14 +288,9 @@ impl GetClient for Client { last_modified_required: false, }; - async fn get_request( - &self, - location: &Path, - options: GetOptions, - head: bool, - ) -> Result { - let url = self.path_url(location); - let method = match head { + async fn get_request(&self, path: &Path, options: GetOptions) -> Result { + let url = self.path_url(path); + let method = match options.head { true => Method::HEAD, false => Method::GET, }; @@ -311,7 +306,7 @@ impl GetClient for Client { Some(StatusCode::NOT_FOUND | StatusCode::METHOD_NOT_ALLOWED) => { crate::Error::NotFound { source: Box::new(source), - path: location.to_string(), + path: path.to_string(), } } _ => Error::Request { source }.into(), @@ -322,7 +317,7 @@ impl GetClient for Client { if has_range && res.status() != StatusCode::PARTIAL_CONTENT { return Err(crate::Error::NotSupported { source: Box::new(Error::RangeNotSupported { - href: location.to_string(), + href: path.to_string(), }), }); } diff --git a/object_store/src/http/mod.rs b/object_store/src/http/mod.rs index e9ed5902d8f5..6ffb62358941 100644 --- a/object_store/src/http/mod.rs +++ b/object_store/src/http/mod.rs @@ -118,10 +118,6 @@ impl ObjectStore for HttpStore { self.client.get_opts(location, options).await } - async fn head(&self, location: &Path) -> Result { - self.client.head(location).await - } - async fn delete(&self, location: &Path) -> Result<()> { self.client.delete(location).await } diff --git a/object_store/src/lib.rs b/object_store/src/lib.rs index 68e785b3a31e..ff0a46533dda 100644 --- a/object_store/src/lib.rs +++ b/object_store/src/lib.rs @@ -410,7 +410,13 @@ pub trait ObjectStore: std::fmt::Display + Send + Sync + Debug + 'static { } /// Return the metadata for the specified location - async fn head(&self, location: &Path) -> Result; + async fn head(&self, location: &Path) -> Result { + let options = GetOptions { + head: true, + ..Default::default() + }; + Ok(self.get_opts(location, 
options).await?.meta) + } /// Delete the object at the specified location. async fn delete(&self, location: &Path) -> Result<()>; @@ -716,6 +722,10 @@ pub struct GetOptions { /// /// pub range: Option>, + /// Request transfer of no content + /// + /// + pub head: bool, } impl GetOptions { diff --git a/object_store/src/local.rs b/object_store/src/local.rs index 69da170b0872..3ed63a410815 100644 --- a/object_store/src/local.rs +++ b/object_store/src/local.rs @@ -419,35 +419,6 @@ impl ObjectStore for LocalFileSystem { .await } - async fn head(&self, location: &Path) -> Result { - let path = self.config.path_to_filesystem(location)?; - let location = location.clone(); - - maybe_spawn_blocking(move || { - let metadata = match metadata(&path) { - Err(e) => Err(match e.kind() { - ErrorKind::NotFound => Error::NotFound { - path: path.clone(), - source: e, - }, - _ => Error::Metadata { - source: e.into(), - path: location.to_string(), - }, - }), - Ok(m) => match !m.is_dir() { - true => Ok(m), - false => Err(Error::NotFound { - path, - source: io::Error::new(ErrorKind::NotFound, "is directory"), - }), - }, - }?; - convert_metadata(metadata, location) - }) - .await - } - async fn delete(&self, location: &Path) -> Result<()> { let path = self.config.path_to_filesystem(location)?; maybe_spawn_blocking(move || match std::fs::remove_file(&path) { @@ -1604,15 +1575,15 @@ mod unix_test { let path = root.path().join(filename); unistd::mkfifo(&path, stat::Mode::S_IRWXU).unwrap(); - let location = Path::from(filename); - integration.head(&location).await.unwrap(); - // Need to open read and write side in parallel let spawned = tokio::task::spawn_blocking(|| { - OpenOptions::new().write(true).open(path).unwrap(); + OpenOptions::new().write(true).open(path).unwrap() }); + let location = Path::from(filename); + integration.head(&location).await.unwrap(); integration.get(&location).await.unwrap(); + spawned.await.unwrap(); } } From 57cd0945db863059d30d31a890b692a6844038fd Mon Sep 17 
00:00:00 2001 From: Raphael Taylor-Davies <1781103+tustvold@users.noreply.github.com> Date: Mon, 16 Oct 2023 10:56:10 +0100 Subject: [PATCH 02/25] Allow opting out of request signing (#4927) (#4929) --- object_store/src/aws/client.rs | 24 +++++++++------- object_store/src/aws/credential.rs | 21 ++++++++------ object_store/src/aws/mod.rs | 44 ++++++++++++++++++++++++++++++ 3 files changed, 71 insertions(+), 18 deletions(-) diff --git a/object_store/src/aws/client.rs b/object_store/src/aws/client.rs index ac07f9ab9af3..8a45a9f3ac47 100644 --- a/object_store/src/aws/client.rs +++ b/object_store/src/aws/client.rs @@ -207,6 +207,7 @@ pub struct S3Config { pub retry_config: RetryConfig, pub client_options: ClientOptions, pub sign_payload: bool, + pub skip_signature: bool, pub checksum: Option, pub copy_if_not_exists: Option, } @@ -234,8 +235,11 @@ impl S3Client { &self.config } - async fn get_credential(&self) -> Result> { - self.config.credentials.get_credential().await + async fn get_credential(&self) -> Result>> { + Ok(match self.config.skip_signature { + false => Some(self.config.credentials.get_credential().await?), + true => None, + }) } /// Make an S3 PUT request @@ -271,7 +275,7 @@ impl S3Client { let response = builder .query(query) .with_aws_sigv4( - credential.as_ref(), + credential.as_deref(), &self.config.region, "s3", self.config.sign_payload, @@ -299,7 +303,7 @@ impl S3Client { .request(Method::DELETE, url) .query(query) .with_aws_sigv4( - credential.as_ref(), + credential.as_deref(), &self.config.region, "s3", self.config.sign_payload, @@ -390,7 +394,7 @@ impl S3Client { .header(CONTENT_TYPE, "application/xml") .body(body) .with_aws_sigv4( - credential.as_ref(), + credential.as_deref(), &self.config.region, "s3", self.config.sign_payload, @@ -459,7 +463,7 @@ impl S3Client { builder .with_aws_sigv4( - credential.as_ref(), + credential.as_deref(), &self.config.region, "s3", self.config.sign_payload, @@ -490,7 +494,7 @@ impl S3Client { .client 
.request(Method::POST, url) .with_aws_sigv4( - credential.as_ref(), + credential.as_deref(), &self.config.region, "s3", self.config.sign_payload, @@ -535,7 +539,7 @@ impl S3Client { .query(&[("uploadId", upload_id)]) .body(body) .with_aws_sigv4( - credential.as_ref(), + credential.as_deref(), &self.config.region, "s3", self.config.sign_payload, @@ -567,7 +571,7 @@ impl GetClient for S3Client { let response = builder .with_get_options(options) .with_aws_sigv4( - credential.as_ref(), + credential.as_deref(), &self.config.region, "s3", self.config.sign_payload, @@ -621,7 +625,7 @@ impl ListClient for S3Client { .request(Method::GET, &url) .query(&query) .with_aws_sigv4( - credential.as_ref(), + credential.as_deref(), &self.config.region, "s3", self.config.sign_payload, diff --git a/object_store/src/aws/credential.rs b/object_store/src/aws/credential.rs index e27b71f7c411..e0c5de5fe784 100644 --- a/object_store/src/aws/credential.rs +++ b/object_store/src/aws/credential.rs @@ -291,7 +291,7 @@ pub trait CredentialExt { /// Sign a request fn with_aws_sigv4( self, - credential: &AwsCredential, + credential: Option<&AwsCredential>, region: &str, service: &str, sign_payload: bool, @@ -302,20 +302,25 @@ pub trait CredentialExt { impl CredentialExt for RequestBuilder { fn with_aws_sigv4( self, - credential: &AwsCredential, + credential: Option<&AwsCredential>, region: &str, service: &str, sign_payload: bool, payload_sha256: Option<&[u8]>, ) -> Self { - let (client, request) = self.build_split(); - let mut request = request.expect("request valid"); + match credential { + Some(credential) => { + let (client, request) = self.build_split(); + let mut request = request.expect("request valid"); - AwsAuthorizer::new(credential, service, region) - .with_sign_payload(sign_payload) - .authorize(&mut request, payload_sha256); + AwsAuthorizer::new(credential, service, region) + .with_sign_payload(sign_payload) + .authorize(&mut request, payload_sha256); - Self::from_parts(client, 
request) + Self::from_parts(client, request) + } + None => self, + } } } diff --git a/object_store/src/aws/mod.rs b/object_store/src/aws/mod.rs index 285ee2f59deb..70170a3cf48a 100644 --- a/object_store/src/aws/mod.rs +++ b/object_store/src/aws/mod.rs @@ -448,6 +448,8 @@ pub struct AmazonS3Builder { client_options: ClientOptions, /// Credentials credentials: Option, + /// Skip signing requests + skip_signature: ConfigValue, /// Copy if not exists copy_if_not_exists: Option>, } @@ -586,6 +588,9 @@ pub enum AmazonS3ConfigKey { /// See [`S3CopyIfNotExists`] CopyIfNotExists, + /// Skip signing request + SkipSignature, + /// Client options Client(ClientConfigKey), } @@ -608,6 +613,7 @@ impl AsRef for AmazonS3ConfigKey { Self::ContainerCredentialsRelativeUri => { "aws_container_credentials_relative_uri" } + Self::SkipSignature => "aws_skip_signature", Self::CopyIfNotExists => "copy_if_not_exists", Self::Client(opt) => opt.as_ref(), } @@ -642,6 +648,7 @@ impl FromStr for AmazonS3ConfigKey { "aws_container_credentials_relative_uri" => { Ok(Self::ContainerCredentialsRelativeUri) } + "aws_skip_signature" | "skip_signature" => Ok(Self::SkipSignature), "copy_if_not_exists" => Ok(Self::CopyIfNotExists), // Backwards compatibility "aws_allow_http" => Ok(Self::Client(ClientConfigKey::AllowHttp)), @@ -753,6 +760,7 @@ impl AmazonS3Builder { AmazonS3ConfigKey::Client(key) => { self.client_options = self.client_options.with_config(key, value) } + AmazonS3ConfigKey::SkipSignature => self.skip_signature.parse(value), AmazonS3ConfigKey::CopyIfNotExists => { self.copy_if_not_exists = Some(ConfigValue::Deferred(value.into())) } @@ -823,6 +831,7 @@ impl AmazonS3Builder { AmazonS3ConfigKey::ContainerCredentialsRelativeUri => { self.container_credentials_relative_uri.clone() } + AmazonS3ConfigKey::SkipSignature => Some(self.skip_signature.to_string()), AmazonS3ConfigKey::CopyIfNotExists => { self.copy_if_not_exists.as_ref().map(ToString::to_string) } @@ -977,6 +986,14 @@ impl AmazonS3Builder 
{ self } + /// If enabled, [`AmazonS3`] will not fetch credentials and will not sign requests + /// + /// This can be useful when interacting with public S3 buckets that deny authorized requests + pub fn with_skip_signature(mut self, skip_signature: bool) -> Self { + self.skip_signature = skip_signature.into(); + self + } + /// Sets the [checksum algorithm] which has to be used for object integrity check during upload. /// /// [checksum algorithm]: https://docs.aws.amazon.com/AmazonS3/latest/userguide/checking-object-integrity.html @@ -1146,6 +1163,7 @@ impl AmazonS3Builder { retry_config: self.retry_config, client_options: self.client_options, sign_payload: !self.unsigned_payload.get()?, + skip_signature: self.skip_signature.get()?, checksum, copy_if_not_exists, }; @@ -1505,4 +1523,30 @@ mod s3_resolve_bucket_region_tests { assert!(result.is_err()); } + + #[tokio::test] + #[ignore = "Tests shouldn't call use remote services by default"] + async fn test_disable_creds() { + // https://registry.opendata.aws/daylight-osm/ + let v1 = AmazonS3Builder::new() + .with_bucket_name("daylight-map-distribution") + .with_region("us-west-1") + .with_access_key_id("local") + .with_secret_access_key("development") + .build() + .unwrap(); + + let prefix = Path::from("release"); + + v1.list_with_delimiter(Some(&prefix)).await.unwrap_err(); + + let v2 = AmazonS3Builder::new() + .with_bucket_name("daylight-map-distribution") + .with_region("us-west-1") + .with_skip_signature(true) + .build() + .unwrap(); + + v2.list_with_delimiter(Some(&prefix)).await.unwrap(); + } } From 31bc84c91e7d6c509443f6e73bda0df32a0a5cba Mon Sep 17 00:00:00 2001 From: Raphael Taylor-Davies <1781103+tustvold@users.noreply.github.com> Date: Mon, 16 Oct 2023 10:56:25 +0100 Subject: [PATCH 03/25] Default connection and request timeouts of 5 seconds (#4928) * Default connection and request timeouts of 5 seconds * Clippy * Allow disabling timeouts --- object_store/src/aws/mod.rs | 3 +- object_store/src/azure/mod.rs 
| 2 +- object_store/src/client/mod.rs | 66 ++++++++++++++++++++++++++++++++-- object_store/src/gcp/mod.rs | 2 +- 4 files changed, 67 insertions(+), 6 deletions(-) diff --git a/object_store/src/aws/mod.rs b/object_store/src/aws/mod.rs index 70170a3cf48a..3ddce08002c4 100644 --- a/object_store/src/aws/mod.rs +++ b/object_store/src/aws/mod.rs @@ -1130,8 +1130,7 @@ impl AmazonS3Builder { Arc::new(TokenCredentialProvider::new( token, - // The instance metadata endpoint is access over HTTP - self.client_options.clone().with_allow_http(true).client()?, + self.client_options.metadata_client()?, self.retry_config.clone(), )) as _ }; diff --git a/object_store/src/azure/mod.rs b/object_store/src/azure/mod.rs index 9017634c42da..190b73bf9490 100644 --- a/object_store/src/azure/mod.rs +++ b/object_store/src/azure/mod.rs @@ -1070,7 +1070,7 @@ impl MicrosoftAzureBuilder { ); Arc::new(TokenCredentialProvider::new( msi_credential, - self.client_options.clone().with_allow_http(true).client()?, + self.client_options.metadata_client()?, self.retry_config.clone(), )) as _ }; diff --git a/object_store/src/client/mod.rs b/object_store/src/client/mod.rs index ee9d62a44f0c..137da2b37594 100644 --- a/object_store/src/client/mod.rs +++ b/object_store/src/client/mod.rs @@ -166,7 +166,7 @@ impl FromStr for ClientConfigKey { } /// HTTP client configuration for remote object stores -#[derive(Debug, Clone, Default)] +#[derive(Debug, Clone)] pub struct ClientOptions { user_agent: Option>, content_type_map: HashMap, @@ -188,6 +188,35 @@ pub struct ClientOptions { http2_only: ConfigValue, } +impl Default for ClientOptions { + fn default() -> Self { + // Defaults based on + // + // + // Which recommend a connection timeout of 3.1s and a request timeout of 2s + Self { + user_agent: None, + content_type_map: Default::default(), + default_content_type: None, + default_headers: None, + proxy_url: None, + proxy_ca_certificate: None, + proxy_excludes: None, + allow_http: Default::default(), + 
allow_insecure: Default::default(), + timeout: Some(Duration::from_secs(5).into()), + connect_timeout: Some(Duration::from_secs(5).into()), + pool_idle_timeout: None, + pool_max_idle_per_host: None, + http2_keep_alive_interval: None, + http2_keep_alive_timeout: None, + http2_keep_alive_while_idle: Default::default(), + http1_only: Default::default(), + http2_only: Default::default(), + } + } +} + impl ClientOptions { /// Create a new [`ClientOptions`] with default values pub fn new() -> Self { @@ -367,17 +396,37 @@ impl ClientOptions { /// /// The timeout is applied from when the request starts connecting until the /// response body has finished + /// + /// Default is 5 seconds pub fn with_timeout(mut self, timeout: Duration) -> Self { self.timeout = Some(ConfigValue::Parsed(timeout)); self } + /// Disables the request timeout + /// + /// See [`Self::with_timeout`] + pub fn with_timeout_disabled(mut self) -> Self { + self.timeout = None; + self + } + /// Set a timeout for only the connect phase of a Client + /// + /// Default is 5 seconds pub fn with_connect_timeout(mut self, timeout: Duration) -> Self { self.connect_timeout = Some(ConfigValue::Parsed(timeout)); self } + /// Disables the connection timeout + /// + /// See [`Self::with_connect_timeout`] + pub fn with_connect_timeout_disabled(mut self) -> Self { + self.connect_timeout = None; + self + } + /// Set the pool max idle timeout /// /// This is the length of time an idle connection will be kept alive @@ -444,7 +493,20 @@ impl ClientOptions { } } - pub(crate) fn client(&self) -> super::Result { + /// Create a [`Client`] with overrides optimised for metadata endpoint access + /// + /// In particular: + /// * Allows HTTP as metadata endpoints do not use TLS + /// * Configures a low connection timeout to provide quick feedback if not present + #[cfg(any(feature = "aws", feature = "gcp", feature = "azure"))] + pub(crate) fn metadata_client(&self) -> Result { + self.clone() + .with_allow_http(true) +
.with_connect_timeout(Duration::from_secs(1)) + .client() + } + + pub(crate) fn client(&self) -> Result { let mut builder = ClientBuilder::new(); match &self.user_agent { diff --git a/object_store/src/gcp/mod.rs b/object_store/src/gcp/mod.rs index f80704b91765..f8a16310dd1e 100644 --- a/object_store/src/gcp/mod.rs +++ b/object_store/src/gcp/mod.rs @@ -1071,7 +1071,7 @@ impl GoogleCloudStorageBuilder { } else { Arc::new(TokenCredentialProvider::new( InstanceCredentialProvider::new(audience), - self.client_options.clone().with_allow_http(true).client()?, + self.client_options.metadata_client()?, self.retry_config.clone(), )) as _ }; From 4a23ab93336fbdbc96b9e9f29fe46c44e40b57d6 Mon Sep 17 00:00:00 2001 From: Marco Neumann Date: Mon, 16 Oct 2023 15:16:45 +0200 Subject: [PATCH 04/25] Update pyo3 requirement from 0.19 to 0.20 (#4941) Updates the requirements on [pyo3](https://github.com/pyo3/pyo3) to permit the latest version. - [Release notes](https://github.com/pyo3/pyo3/releases) - [Changelog](https://github.com/PyO3/pyo3/blob/main/CHANGELOG.md) - [Commits](https://github.com/pyo3/pyo3/compare/v0.19.0...v0.20.0) --- updated-dependencies: - dependency-name: pyo3 dependency-type: direct:production ... 
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- arrow-pyarrow-integration-testing/Cargo.toml | 2 +- arrow-pyarrow-integration-testing/pyproject.toml | 2 +- arrow/Cargo.toml | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/arrow-pyarrow-integration-testing/Cargo.toml b/arrow-pyarrow-integration-testing/Cargo.toml index 50987b03ca9e..8c60c086c29a 100644 --- a/arrow-pyarrow-integration-testing/Cargo.toml +++ b/arrow-pyarrow-integration-testing/Cargo.toml @@ -34,4 +34,4 @@ crate-type = ["cdylib"] [dependencies] arrow = { path = "../arrow", features = ["pyarrow"] } -pyo3 = { version = "0.19", features = ["extension-module"] } +pyo3 = { version = "0.20", features = ["extension-module"] } diff --git a/arrow-pyarrow-integration-testing/pyproject.toml b/arrow-pyarrow-integration-testing/pyproject.toml index d75f8de1ac4c..d85db24c2e18 100644 --- a/arrow-pyarrow-integration-testing/pyproject.toml +++ b/arrow-pyarrow-integration-testing/pyproject.toml @@ -16,7 +16,7 @@ # under the License. 
[build-system] -requires = ["maturin"] +requires = ["maturin>=1.0,<2.0"] build-backend = "maturin" dependencies = ["pyarrow>=1"] diff --git a/arrow/Cargo.toml b/arrow/Cargo.toml index 8abb4f73a384..37f03a05b3fa 100644 --- a/arrow/Cargo.toml +++ b/arrow/Cargo.toml @@ -60,7 +60,7 @@ arrow-select = { workspace = true } arrow-string = { workspace = true } rand = { version = "0.8", default-features = false, features = ["std", "std_rng"], optional = true } -pyo3 = { version = "0.19", default-features = false, optional = true } +pyo3 = { version = "0.20", default-features = false, optional = true } [package.metadata.docs.rs] features = ["prettyprint", "ipc_compression", "ffi", "pyarrow"] From 69c937565f7404dc1576bc22d153ce79bf107cfb Mon Sep 17 00:00:00 2001 From: Raphael Taylor-Davies <1781103+tustvold@users.noreply.github.com> Date: Mon, 16 Oct 2023 14:18:53 +0100 Subject: [PATCH 05/25] Support service_account in ApplicationDefaultCredentials and Use SelfSignedJwt (#4926) * Support service_account in ApplicationDefaultCredentials * Use SelfSignedJwt for Service Accounts * Update CI * Apply suggestions from code review Co-authored-by: Marco Neumann --------- Co-authored-by: Marco Neumann --- .github/workflows/object_store.yml | 2 +- object_store/src/gcp/credential.rs | 219 +++++++++++------------------ object_store/src/gcp/mod.rs | 45 +++--- 3 files changed, 108 insertions(+), 158 deletions(-) diff --git a/.github/workflows/object_store.yml b/.github/workflows/object_store.yml index c28f8037a307..1b991e33c097 100644 --- a/.github/workflows/object_store.yml +++ b/.github/workflows/object_store.yml @@ -126,7 +126,7 @@ jobs: # Give the container a moment to start up prior to configuring it sleep 1 curl -v -X POST --data-binary '{"name":"test-bucket"}' -H "Content-Type: application/json" "http://localhost:4443/storage/v1/b" - echo '{"gcs_base_url": "http://localhost:4443", "disable_oauth": true, "client_email": "", "private_key": ""}' > "$GOOGLE_SERVICE_ACCOUNT" + echo 
'{"gcs_base_url": "http://localhost:4443", "disable_oauth": true, "client_email": "", "private_key": "", "private_key_id": ""}' > "$GOOGLE_SERVICE_ACCOUNT" - name: Setup WebDav run: docker run -d -p 8080:80 rclone/rclone serve webdav /data --addr :80 diff --git a/object_store/src/gcp/credential.rs b/object_store/src/gcp/credential.rs index ad21c33b8b9d..87f8e244f21c 100644 --- a/object_store/src/gcp/credential.rs +++ b/object_store/src/gcp/credential.rs @@ -17,10 +17,8 @@ use crate::client::retry::RetryExt; use crate::client::token::TemporaryToken; -use crate::client::{TokenCredentialProvider, TokenProvider}; -use crate::gcp::credential::Error::UnsupportedCredentialsType; -use crate::gcp::{GcpCredentialProvider, STORE}; -use crate::ClientOptions; +use crate::client::TokenProvider; +use crate::gcp::STORE; use crate::RetryConfig; use async_trait::async_trait; use base64::prelude::BASE64_URL_SAFE_NO_PAD; @@ -28,6 +26,7 @@ use base64::Engine; use futures::TryFutureExt; use reqwest::{Client, Method}; use ring::signature::RsaKeyPair; +use serde::Deserialize; use snafu::{ResultExt, Snafu}; use std::env; use std::fs::File; @@ -37,6 +36,10 @@ use std::sync::Arc; use std::time::{Duration, Instant}; use tracing::info; +pub const DEFAULT_SCOPE: &str = "https://www.googleapis.com/auth/devstorage.full_control"; + +pub const DEFAULT_GCS_BASE_URL: &str = "https://storage.googleapis.com"; + #[derive(Debug, Snafu)] pub enum Error { #[snafu(display("Unable to open service account file from {}: {}", path.display(), source))] @@ -68,9 +71,6 @@ pub enum Error { #[snafu(display("Error getting token response body: {}", source))] TokenResponseBody { source: reqwest::Error }, - - #[snafu(display("Unsupported ApplicationCredentials type: {}", type_))] - UnsupportedCredentialsType { type_: String }, } impl From for crate::Error { @@ -92,48 +92,48 @@ pub struct GcpCredential { pub type Result = std::result::Result; #[derive(Debug, Default, serde::Serialize)] -pub struct JwtHeader { +pub struct 
JwtHeader<'a> { /// The type of JWS: it can only be "JWT" here /// /// Defined in [RFC7515#4.1.9](https://tools.ietf.org/html/rfc7515#section-4.1.9). #[serde(skip_serializing_if = "Option::is_none")] - pub typ: Option, + pub typ: Option<&'a str>, /// The algorithm used /// /// Defined in [RFC7515#4.1.1](https://tools.ietf.org/html/rfc7515#section-4.1.1). - pub alg: String, + pub alg: &'a str, /// Content type /// /// Defined in [RFC7519#5.2](https://tools.ietf.org/html/rfc7519#section-5.2). #[serde(skip_serializing_if = "Option::is_none")] - pub cty: Option, + pub cty: Option<&'a str>, /// JSON Key URL /// /// Defined in [RFC7515#4.1.2](https://tools.ietf.org/html/rfc7515#section-4.1.2). #[serde(skip_serializing_if = "Option::is_none")] - pub jku: Option, + pub jku: Option<&'a str>, /// Key ID /// /// Defined in [RFC7515#4.1.4](https://tools.ietf.org/html/rfc7515#section-4.1.4). #[serde(skip_serializing_if = "Option::is_none")] - pub kid: Option, + pub kid: Option<&'a str>, /// X.509 URL /// /// Defined in [RFC7515#4.1.5](https://tools.ietf.org/html/rfc7515#section-4.1.5). #[serde(skip_serializing_if = "Option::is_none")] - pub x5u: Option, + pub x5u: Option<&'a str>, /// X.509 certificate thumbprint /// /// Defined in [RFC7515#4.1.7](https://tools.ietf.org/html/rfc7515#section-4.1.7). #[serde(skip_serializing_if = "Option::is_none")] - pub x5t: Option, + pub x5t: Option<&'a str>, } #[derive(serde::Serialize)] struct TokenClaims<'a> { iss: &'a str, + sub: &'a str, scope: &'a str, - aud: &'a str, exp: u64, iat: u64, } @@ -144,28 +144,32 @@ struct TokenResponse { expires_in: u64, } -/// Encapsulates the logic to perform an OAuth token challenge +/// Self-signed JWT (JSON Web Token). 
+/// +/// # References +/// - #[derive(Debug)] -pub struct OAuthProvider { +pub struct SelfSignedJwt { issuer: String, scope: String, - audience: String, key_pair: RsaKeyPair, jwt_header: String, random: ring::rand::SystemRandom, } -impl OAuthProvider { - /// Create a new [`OAuthProvider`] +impl SelfSignedJwt { + /// Create a new [`SelfSignedJwt`] pub fn new( + key_id: String, issuer: String, private_key_pem: String, scope: String, - audience: String, ) -> Result { let key_pair = decode_first_rsa_key(private_key_pem)?; let jwt_header = b64_encode_obj(&JwtHeader { - alg: "RS256".to_string(), + alg: "RS256", + typ: Some("JWT"), + kid: Some(&key_id), ..Default::default() })?; @@ -173,7 +177,6 @@ impl OAuthProvider { issuer, key_pair, scope, - audience, jwt_header, random: ring::rand::SystemRandom::new(), }) @@ -181,24 +184,24 @@ impl OAuthProvider { } #[async_trait] -impl TokenProvider for OAuthProvider { +impl TokenProvider for SelfSignedJwt { type Credential = GcpCredential; /// Fetch a fresh token async fn fetch_token( &self, - client: &Client, - retry: &RetryConfig, + _client: &Client, + _retry: &RetryConfig, ) -> crate::Result>> { let now = seconds_since_epoch(); let exp = now + 3600; let claims = TokenClaims { iss: &self.issuer, + sub: &self.issuer, scope: &self.scope, - aud: &self.audience, - exp, iat: now, + exp, }; let claim_str = b64_encode_obj(&claims)?; @@ -214,28 +217,11 @@ impl TokenProvider for OAuthProvider { .context(SignSnafu)?; let signature = BASE64_URL_SAFE_NO_PAD.encode(sig_bytes); - let jwt = [message, signature].join("."); - - let body = [ - ("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer"), - ("assertion", &jwt), - ]; - - let response: TokenResponse = client - .request(Method::POST, &self.audience) - .form(&body) - .send_retry(retry) - .await - .context(TokenRequestSnafu)? 
- .json() - .await - .context(TokenResponseBodySnafu)?; + let bearer = [message, signature].join("."); Ok(TemporaryToken { - token: Arc::new(GcpCredential { - bearer: response.access_token, - }), - expiry: Some(Instant::now() + Duration::from_secs(response.expires_in)), + token: Arc::new(GcpCredential { bearer }), + expiry: Some(Instant::now() + Duration::from_secs(3600)), }) } } @@ -259,29 +245,24 @@ pub struct ServiceAccountCredentials { /// The private key in RSA format. pub private_key: String, + /// The private key ID + pub private_key_id: String, + /// The email address associated with the service account. pub client_email: String, /// Base URL for GCS - #[serde(default = "default_gcs_base_url")] - pub gcs_base_url: String, + #[serde(default)] + pub gcs_base_url: Option, /// Disable oauth and use empty tokens. - #[serde(default = "default_disable_oauth")] + #[serde(default)] pub disable_oauth: bool, } -pub fn default_gcs_base_url() -> String { - "https://storage.googleapis.com".to_owned() -} - -pub fn default_disable_oauth() -> bool { - false -} - impl ServiceAccountCredentials { /// Create a new [`ServiceAccountCredentials`] from a file. - pub fn from_file>(path: P) -> Result { + pub fn from_file>(path: P) -> Result { read_credentials_file(path) } @@ -290,17 +271,20 @@ impl ServiceAccountCredentials { serde_json::from_str(key).context(DecodeCredentialsSnafu) } - /// Create an [`OAuthProvider`] from this credentials struct. - pub fn oauth_provider( - self, - scope: &str, - audience: &str, - ) -> crate::Result { - Ok(OAuthProvider::new( + /// Create a [`SelfSignedJwt`] from this credentials struct. 
+ /// + /// We use a scope of [`DEFAULT_SCOPE`] as opposed to an audience + /// as GCS appears to not support audience + /// + /// # References + /// - + /// - + pub fn token_provider(self) -> crate::Result { + Ok(SelfSignedJwt::new( + self.private_key_id, self.client_email, self.private_key, - scope.to_string(), - audience.to_string(), + DEFAULT_SCOPE.to_string(), )?) } } @@ -337,25 +321,13 @@ fn b64_encode_obj(obj: &T) -> Result { /// /// #[derive(Debug, Default)] -pub struct InstanceCredentialProvider { - audience: String, -} - -impl InstanceCredentialProvider { - /// Create a new [`InstanceCredentialProvider`], we need to control the client in order to enable http access so save the options. - pub fn new>(audience: T) -> Self { - Self { - audience: audience.into(), - } - } -} +pub struct InstanceCredentialProvider {} /// Make a request to the metadata server to fetch a token, using a a given hostname. async fn make_metadata_request( client: &Client, hostname: &str, retry: &RetryConfig, - audience: &str, ) -> crate::Result { let url = format!( "http://{hostname}/computeMetadata/v1/instance/service-accounts/default/token" @@ -363,7 +335,7 @@ async fn make_metadata_request( let response: TokenResponse = client .request(Method::GET, url) .header("Metadata-Flavor", "Google") - .query(&[("audience", audience)]) + .query(&[("audience", "https://www.googleapis.com/oauth2/v4/token")]) .send_retry(retry) .await .context(TokenRequestSnafu)? 
@@ -388,12 +360,9 @@ impl TokenProvider for InstanceCredentialProvider { const METADATA_HOST: &str = "metadata"; info!("fetching token from metadata server"); - let response = - make_metadata_request(client, METADATA_HOST, retry, &self.audience) - .or_else(|_| { - make_metadata_request(client, METADATA_IP, retry, &self.audience) - }) - .await?; + let response = make_metadata_request(client, METADATA_HOST, retry) + .or_else(|_| make_metadata_request(client, METADATA_IP, retry)) + .await?; let token = TemporaryToken { token: Arc::new(GcpCredential { bearer: response.access_token, @@ -404,62 +373,36 @@ impl TokenProvider for InstanceCredentialProvider { } } -/// ApplicationDefaultCredentials -/// -pub fn application_default_credentials( - path: Option<&str>, - client: &ClientOptions, - retry: &RetryConfig, -) -> crate::Result> { - let file = match ApplicationDefaultCredentialsFile::read(path)? { - Some(x) => x, - None => return Ok(None), - }; - - match file.type_.as_str() { - // - "authorized_user" => { - let token = AuthorizedUserCredentials { - client_id: file.client_id, - client_secret: file.client_secret, - refresh_token: file.refresh_token, - }; - - Ok(Some(Arc::new(TokenCredentialProvider::new( - token, - client.client()?, - retry.clone(), - )))) - } - type_ => Err(UnsupportedCredentialsType { - type_: type_.to_string(), - } - .into()), - } -} - /// A deserialized `application_default_credentials.json`-file. -/// +/// +/// # References +/// - +/// - #[derive(serde::Deserialize)] -struct ApplicationDefaultCredentialsFile { - #[serde(default)] - client_id: String, - #[serde(default)] - client_secret: String, - #[serde(default)] - refresh_token: String, - #[serde(rename = "type")] - type_: String, +#[serde(tag = "type")] +pub enum ApplicationDefaultCredentials { + /// Service Account. + /// + /// # References + /// - + #[serde(rename = "service_account")] + ServiceAccount(ServiceAccountCredentials), + /// Authorized user via "gcloud CLI Integration". 
+ /// + /// # References + /// - + #[serde(rename = "authorized_user")] + AuthorizedUser(AuthorizedUserCredentials), } -impl ApplicationDefaultCredentialsFile { +impl ApplicationDefaultCredentials { const CREDENTIALS_PATH: &'static str = ".config/gcloud/application_default_credentials.json"; // Create a new application default credential in the following situations: // 1. a file is passed in and the type matches. // 2. without argument if the well-known configuration file is present. - fn read(path: Option<&str>) -> Result, Error> { + pub fn read(path: Option<&str>) -> Result, Error> { if let Some(path) = path { return read_credentials_file::(path).map(Some); } @@ -478,8 +421,8 @@ impl ApplicationDefaultCredentialsFile { const DEFAULT_TOKEN_GCP_URI: &str = "https://accounts.google.com/o/oauth2/token"; /// -#[derive(Debug)] -struct AuthorizedUserCredentials { +#[derive(Debug, Deserialize)] +pub struct AuthorizedUserCredentials { client_id: String, client_secret: String, refresh_token: String, diff --git a/object_store/src/gcp/mod.rs b/object_store/src/gcp/mod.rs index f8a16310dd1e..a75527fe7b9f 100644 --- a/object_store/src/gcp/mod.rs +++ b/object_store/src/gcp/mod.rs @@ -57,10 +57,7 @@ use crate::{ ObjectStore, Result, RetryConfig, }; -use credential::{ - application_default_credentials, default_gcs_base_url, InstanceCredentialProvider, - ServiceAccountCredentials, -}; +use credential::{InstanceCredentialProvider, ServiceAccountCredentials}; mod credential; @@ -68,6 +65,7 @@ const STORE: &str = "GCS"; /// [`CredentialProvider`] for [`GoogleCloudStorage`] pub type GcpCredentialProvider = Arc>; +use crate::gcp::credential::{ApplicationDefaultCredentials, DEFAULT_GCS_BASE_URL}; pub use credential::GcpCredential; #[derive(Debug, Snafu)] @@ -1034,10 +1032,8 @@ impl GoogleCloudStorageBuilder { }; // Then try to initialize from the application credentials file, or the environment. 
- let application_default_credentials = application_default_credentials( + let application_default_credentials = ApplicationDefaultCredentials::read( self.application_credentials_path.as_deref(), - &self.client_options, - &self.retry_config, )?; let disable_oauth = service_account_credentials @@ -1045,14 +1041,10 @@ impl GoogleCloudStorageBuilder { .map(|c| c.disable_oauth) .unwrap_or(false); - let gcs_base_url = service_account_credentials + let gcs_base_url: String = service_account_credentials .as_ref() - .map(|c| c.gcs_base_url.clone()) - .unwrap_or_else(default_gcs_base_url); - - // TODO: https://cloud.google.com/storage/docs/authentication#oauth-scopes - let scope = "https://www.googleapis.com/auth/devstorage.full_control"; - let audience = "https://www.googleapis.com/oauth2/v4/token"; + .and_then(|c| c.gcs_base_url.clone()) + .unwrap_or_else(|| DEFAULT_GCS_BASE_URL.to_string()); let credentials = if let Some(credentials) = self.credentials { credentials @@ -1062,15 +1054,30 @@ impl GoogleCloudStorageBuilder { })) as _ } else if let Some(credentials) = service_account_credentials { Arc::new(TokenCredentialProvider::new( - credentials.oauth_provider(scope, audience)?, + credentials.token_provider()?, self.client_options.client()?, self.retry_config.clone(), )) as _ } else if let Some(credentials) = application_default_credentials { - credentials + match credentials { + ApplicationDefaultCredentials::AuthorizedUser(token) => { + Arc::new(TokenCredentialProvider::new( + token, + self.client_options.client()?, + self.retry_config.clone(), + )) as _ + } + ApplicationDefaultCredentials::ServiceAccount(token) => { + Arc::new(TokenCredentialProvider::new( + token.token_provider()?, + self.client_options.client()?, + self.retry_config.clone(), + )) as _ + } + } } else { Arc::new(TokenCredentialProvider::new( - InstanceCredentialProvider::new(audience), + InstanceCredentialProvider::default(), self.client_options.metadata_client()?, self.retry_config.clone(), )) as _ 
@@ -1105,7 +1112,7 @@ mod test { use super::*; - const FAKE_KEY: &str = r#"{"private_key": "private_key", "client_email":"client_email", "disable_oauth":true}"#; + const FAKE_KEY: &str = r#"{"private_key": "private_key", "private_key_id": "private_key_id", "client_email":"client_email", "disable_oauth":true}"#; const NON_EXISTENT_NAME: &str = "nonexistentname"; #[tokio::test] @@ -1117,7 +1124,7 @@ mod test { list_uses_directories_correctly(&integration).await; list_with_delimiter(&integration).await; rename_and_copy(&integration).await; - if integration.client.base_url == default_gcs_base_url() { + if integration.client.base_url == DEFAULT_GCS_BASE_URL { // Fake GCS server doesn't currently honor ifGenerationMatch // https://github.com/fsouza/fake-gcs-server/issues/994 copy_if_not_exists(&integration).await; From ce2a9580556c33261eba39a96d597db9600cc682 Mon Sep 17 00:00:00 2001 From: Haixuan Xavier Tao Date: Mon, 16 Oct 2023 23:04:55 +0800 Subject: [PATCH 06/25] Add `FileWriter` schema getter (#4940) * Add `FileWriter` schema getter * Make schema getter consistent with the parquet implementation --- arrow-ipc/src/writer.rs | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/arrow-ipc/src/writer.rs b/arrow-ipc/src/writer.rs index 0e01e51231d6..567fa2e94171 100644 --- a/arrow-ipc/src/writer.rs +++ b/arrow-ipc/src/writer.rs @@ -23,6 +23,7 @@ use std::cmp::min; use std::collections::HashMap; use std::io::{BufWriter, Write}; +use std::sync::Arc; use flatbuffers::FlatBufferBuilder; @@ -696,7 +697,7 @@ pub struct FileWriter { /// IPC write options write_options: IpcWriteOptions, /// A reference to the schema, used in validating record batches - schema: Schema, + schema: SchemaRef, /// The number of bytes between each block of bytes, as an offset for random access block_offsets: usize, /// Dictionary blocks that will be written as part of the IPC footer @@ -739,7 +740,7 @@ impl FileWriter { Ok(Self { writer, write_options, - schema: schema.clone(), 
+ schema: Arc::new(schema.clone()), block_offsets: meta + data + header_size, dictionary_blocks: vec![], record_blocks: vec![], @@ -832,6 +833,11 @@ impl FileWriter { Ok(()) } + /// Returns the arrow [`SchemaRef`] for this arrow file. + pub fn schema(&self) -> &SchemaRef { + &self.schema + } + /// Gets a reference to the underlying writer. pub fn get_ref(&self) -> &W { self.writer.get_ref() From 95b015cf7b5d57c7fe66a8feada4f48a987cb020 Mon Sep 17 00:00:00 2001 From: Huaijin Date: Tue, 17 Oct 2023 01:52:27 +0800 Subject: [PATCH 07/25] Evaluate null_regex for string type in csv (now such values will be parsed as `Null` rather than `""`) (#4942) * fix: add null_regex for string type in csv * Update arrow-csv/src/reader/mod.rs Co-authored-by: Raphael Taylor-Davies <1781103+tustvold@users.noreply.github.com> --------- Co-authored-by: Raphael Taylor-Davies <1781103+tustvold@users.noreply.github.com> --- arrow-csv/src/reader/mod.rs | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/arrow-csv/src/reader/mod.rs b/arrow-csv/src/reader/mod.rs index 2ba49cadc73f..1106b16bc46f 100644 --- a/arrow-csv/src/reader/mod.rs +++ b/arrow-csv/src/reader/mod.rs @@ -791,7 +791,10 @@ fn parse( } DataType::Utf8 => Ok(Arc::new( rows.iter() - .map(|row| Some(row.get(i))) + .map(|row| { + let s = row.get(i); + (!null_regex.is_null(s)).then_some(s) + }) .collect::(), ) as ArrayRef), DataType::Dictionary(key_type, value_type) @@ -1495,7 +1498,7 @@ mod tests { let schema = Arc::new(Schema::new(vec![ Field::new("c_int", DataType::UInt64, false), Field::new("c_float", DataType::Float32, true), - Field::new("c_string", DataType::Utf8, false), + Field::new("c_string", DataType::Utf8, true), Field::new("c_bool", DataType::Boolean, false), ])); @@ -1596,8 +1599,7 @@ mod tests { assert!(batch.column(0).is_null(1)); assert!(batch.column(1).is_null(2)); assert!(batch.column(3).is_null(4)); - // String won't be empty - assert!(!batch.column(2).is_null(3)); + 
assert!(batch.column(2).is_null(3)); assert!(!batch.column(2).is_null(4)); } @@ -2237,8 +2239,8 @@ mod tests { fn err_test(csv: &[u8], expected: &str) { let schema = Arc::new(Schema::new(vec![ - Field::new("text1", DataType::Utf8, false), - Field::new("text2", DataType::Utf8, false), + Field::new("text1", DataType::Utf8, true), + Field::new("text2", DataType::Utf8, true), ])); let buffer = std::io::BufReader::with_capacity(2, Cursor::new(csv)); let b = ReaderBuilder::new(schema) From ab87abdd69ab787fdf247cf36f04abc1fbfa6266 Mon Sep 17 00:00:00 2001 From: Raphael Taylor-Davies <1781103+tustvold@users.noreply.github.com> Date: Tue, 17 Oct 2023 12:39:34 +0100 Subject: [PATCH 08/25] Generate `ETag`s for `InMemory` and `LocalFileSystem` (#4879) (#4922) * Support ETag in InMemory (#4879) * Add LocalFileSystem Etag * Review feedback * Review feedback --- object_store/src/lib.rs | 206 ++++++++++++++++++++++++++++--------- object_store/src/local.rs | 37 ++++--- object_store/src/memory.rs | 149 ++++++++++++++++----------- 3 files changed, 268 insertions(+), 124 deletions(-) diff --git a/object_store/src/lib.rs b/object_store/src/lib.rs index ff0a46533dda..b79042e3cda8 100644 --- a/object_store/src/lib.rs +++ b/object_store/src/lib.rs @@ -698,12 +698,28 @@ pub struct GetOptions { /// Request will succeed if the `ObjectMeta::e_tag` matches /// otherwise returning [`Error::Precondition`] /// - /// + /// See + /// + /// Examples: + /// + /// ```text + /// If-Match: "xyzzy" + /// If-Match: "xyzzy", "r2d2xxxx", "c3piozzzz" + /// If-Match: * + /// ``` pub if_match: Option, /// Request will succeed if the `ObjectMeta::e_tag` does not match /// otherwise returning [`Error::NotModified`] /// - /// + /// See + /// + /// Examples: + /// + /// ```text + /// If-None-Match: "xyzzy" + /// If-None-Match: "xyzzy", "r2d2xxxx", "c3piozzzz" + /// If-None-Match: * + /// ``` pub if_none_match: Option, /// Request will succeed if the object has been modified since /// @@ -730,25 +746,41 @@ pub 
struct GetOptions { impl GetOptions { /// Returns an error if the modification conditions on this request are not satisfied - fn check_modified( - &self, - location: &Path, - last_modified: DateTime, - ) -> Result<()> { - if let Some(date) = self.if_modified_since { - if last_modified <= date { - return Err(Error::NotModified { - path: location.to_string(), - source: format!("{} >= {}", date, last_modified).into(), + /// + /// + fn check_preconditions(&self, meta: &ObjectMeta) -> Result<()> { + // The use of the invalid etag "*" means no ETag is equivalent to never matching + let etag = meta.e_tag.as_deref().unwrap_or("*"); + let last_modified = meta.last_modified; + + if let Some(m) = &self.if_match { + if m != "*" && m.split(',').map(str::trim).all(|x| x != etag) { + return Err(Error::Precondition { + path: meta.location.to_string(), + source: format!("{etag} does not match {m}").into(), }); } - } - - if let Some(date) = self.if_unmodified_since { + } else if let Some(date) = self.if_unmodified_since { if last_modified > date { return Err(Error::Precondition { - path: location.to_string(), - source: format!("{} < {}", date, last_modified).into(), + path: meta.location.to_string(), + source: format!("{date} < {last_modified}").into(), + }); + } + } + + if let Some(m) = &self.if_none_match { + if m == "*" || m.split(',').map(str::trim).any(|x| x == etag) { + return Err(Error::NotModified { + path: meta.location.to_string(), + source: format!("{etag} matches {m}").into(), + }); + } + } else if let Some(date) = self.if_modified_since { + if last_modified <= date { + return Err(Error::NotModified { + path: meta.location.to_string(), + source: format!("{date} >= {last_modified}").into(), }); } } @@ -952,6 +984,7 @@ mod test_util { mod tests { use super::*; use crate::test_util::flatten_list_stream; + use chrono::TimeZone; use rand::{thread_rng, Rng}; use tokio::io::AsyncWriteExt; @@ -1359,33 +1392,32 @@ mod tests { Err(e) => panic!("{e}"), } - if let Some(tag) = 
meta.e_tag { - let options = GetOptions { - if_match: Some(tag.clone()), - ..GetOptions::default() - }; - storage.get_opts(&path, options).await.unwrap(); - - let options = GetOptions { - if_match: Some("invalid".to_string()), - ..GetOptions::default() - }; - let err = storage.get_opts(&path, options).await.unwrap_err(); - assert!(matches!(err, Error::Precondition { .. }), "{err}"); - - let options = GetOptions { - if_none_match: Some(tag.clone()), - ..GetOptions::default() - }; - let err = storage.get_opts(&path, options).await.unwrap_err(); - assert!(matches!(err, Error::NotModified { .. }), "{err}"); - - let options = GetOptions { - if_none_match: Some("invalid".to_string()), - ..GetOptions::default() - }; - storage.get_opts(&path, options).await.unwrap(); - } + let tag = meta.e_tag.unwrap(); + let options = GetOptions { + if_match: Some(tag.clone()), + ..GetOptions::default() + }; + storage.get_opts(&path, options).await.unwrap(); + + let options = GetOptions { + if_match: Some("invalid".to_string()), + ..GetOptions::default() + }; + let err = storage.get_opts(&path, options).await.unwrap_err(); + assert!(matches!(err, Error::Precondition { .. }), "{err}"); + + let options = GetOptions { + if_none_match: Some(tag.clone()), + ..GetOptions::default() + }; + let err = storage.get_opts(&path, options).await.unwrap_err(); + assert!(matches!(err, Error::NotModified { .. 
}), "{err}"); + + let options = GetOptions { + if_none_match: Some("invalid".to_string()), + ..GetOptions::default() + }; + storage.get_opts(&path, options).await.unwrap(); } /// Returns a chunk of length `chunk_length` @@ -1697,8 +1729,86 @@ mod tests { assert!(stream.next().await.is_none()); } - // Tests TODO: - // GET nonexisting location (in_memory/file) - // DELETE nonexisting location - // PUT overwriting + #[test] + fn test_preconditions() { + let mut meta = ObjectMeta { + location: Path::from("test"), + last_modified: Utc.timestamp_nanos(100), + size: 100, + e_tag: Some("123".to_string()), + }; + + let mut options = GetOptions::default(); + options.check_preconditions(&meta).unwrap(); + + options.if_modified_since = Some(Utc.timestamp_nanos(50)); + options.check_preconditions(&meta).unwrap(); + + options.if_modified_since = Some(Utc.timestamp_nanos(100)); + options.check_preconditions(&meta).unwrap_err(); + + options.if_modified_since = Some(Utc.timestamp_nanos(101)); + options.check_preconditions(&meta).unwrap_err(); + + options = GetOptions::default(); + + options.if_unmodified_since = Some(Utc.timestamp_nanos(50)); + options.check_preconditions(&meta).unwrap_err(); + + options.if_unmodified_since = Some(Utc.timestamp_nanos(100)); + options.check_preconditions(&meta).unwrap(); + + options.if_unmodified_since = Some(Utc.timestamp_nanos(101)); + options.check_preconditions(&meta).unwrap(); + + options = GetOptions::default(); + + options.if_match = Some("123".to_string()); + options.check_preconditions(&meta).unwrap(); + + options.if_match = Some("123,354".to_string()); + options.check_preconditions(&meta).unwrap(); + + options.if_match = Some("354, 123,".to_string()); + options.check_preconditions(&meta).unwrap(); + + options.if_match = Some("354".to_string()); + options.check_preconditions(&meta).unwrap_err(); + + options.if_match = Some("*".to_string()); + options.check_preconditions(&meta).unwrap(); + + // If-Match takes precedence + 
options.if_unmodified_since = Some(Utc.timestamp_nanos(200)); + options.check_preconditions(&meta).unwrap(); + + options = GetOptions::default(); + + options.if_none_match = Some("123".to_string()); + options.check_preconditions(&meta).unwrap_err(); + + options.if_none_match = Some("*".to_string()); + options.check_preconditions(&meta).unwrap_err(); + + options.if_none_match = Some("1232".to_string()); + options.check_preconditions(&meta).unwrap(); + + options.if_none_match = Some("23, 123".to_string()); + options.check_preconditions(&meta).unwrap_err(); + + // If-None-Match takes precedence + options.if_modified_since = Some(Utc.timestamp_nanos(10)); + options.check_preconditions(&meta).unwrap_err(); + + // Check missing ETag + meta.e_tag = None; + options = GetOptions::default(); + + options.if_none_match = Some("*".to_string()); // Fails if any file exists + options.check_preconditions(&meta).unwrap_err(); + + options = GetOptions::default(); + options.if_match = Some("*".to_string()); // Passes if file exists + options.check_preconditions(&meta).unwrap(); + } } diff --git a/object_store/src/local.rs b/object_store/src/local.rs index 3ed63a410815..3d4a02a1e9e9 100644 --- a/object_store/src/local.rs +++ b/object_store/src/local.rs @@ -365,23 +365,12 @@ impl ObjectStore for LocalFileSystem { } async fn get_opts(&self, location: &Path, options: GetOptions) -> Result { - if options.if_match.is_some() || options.if_none_match.is_some() { - return Err(super::Error::NotSupported { - source: "ETags not supported by LocalFileSystem".to_string().into(), - }); - } - let location = location.clone(); let path = self.config.path_to_filesystem(&location)?; maybe_spawn_blocking(move || { let (file, metadata) = open_file(&path)?; - if options.if_unmodified_since.is_some() - || options.if_modified_since.is_some() - { - options.check_modified(&location, last_modified(&metadata))?; - } - let meta = convert_metadata(metadata, location)?; + options.check_preconditions(&meta)?; 
Ok(GetResult { payload: GetResultPayload::File(file, path), @@ -965,7 +954,7 @@ fn convert_entry(entry: DirEntry, location: Path) -> Result { convert_metadata(metadata, location) } -fn last_modified(metadata: &std::fs::Metadata) -> DateTime { +fn last_modified(metadata: &Metadata) -> DateTime { metadata .modified() .expect("Modified file time should be supported on this platform") @@ -977,15 +966,35 @@ fn convert_metadata(metadata: Metadata, location: Path) -> Result { let size = usize::try_from(metadata.len()).context(FileSizeOverflowedUsizeSnafu { path: location.as_ref(), })?; + let inode = get_inode(&metadata); + let mtime = last_modified.timestamp_micros(); + + // Use an ETag scheme based on that used by many popular HTTP servers + // + // + let etag = format!("{inode:x}-{mtime:x}-{size:x}"); Ok(ObjectMeta { location, last_modified, size, - e_tag: None, + e_tag: Some(etag), }) } +#[cfg(unix)] +/// We include the inode when available to yield an ETag more resistant to collisions +/// and as used by popular web servers such as [Apache](https://httpd.apache.org/docs/2.2/mod/core.html#fileetag) +fn get_inode(metadata: &Metadata) -> u64 { + std::os::unix::fs::MetadataExt::ino(metadata) +} + +#[cfg(not(unix))] +/// On platforms where an inode isn't available, fallback to just relying on size and mtime +fn get_inode(metadata: &Metadata) -> u64 { + 0 +} + /// Convert walkdir results and converts not-found errors into `None`. /// Convert broken symlinks to `None`. 
fn convert_walkdir_result( diff --git a/object_store/src/memory.rs b/object_store/src/memory.rs index 0e229885b006..f638ed6d7a55 100644 --- a/object_store/src/memory.rs +++ b/object_store/src/memory.rs @@ -35,9 +35,6 @@ use std::sync::Arc; use std::task::Poll; use tokio::io::AsyncWrite; -type Entry = (Bytes, DateTime); -type StorageType = Arc>>; - /// A specialized `Error` for in-memory object store-related errors #[derive(Debug, Snafu)] #[allow(missing_docs)] @@ -80,7 +77,41 @@ impl From for super::Error { /// storage provider. #[derive(Debug, Default)] pub struct InMemory { - storage: StorageType, + storage: SharedStorage, +} + +#[derive(Debug, Clone)] +struct Entry { + data: Bytes, + last_modified: DateTime, + e_tag: usize, +} + +impl Entry { + fn new(data: Bytes, last_modified: DateTime, e_tag: usize) -> Self { + Self { + data, + last_modified, + e_tag, + } + } +} + +#[derive(Debug, Default, Clone)] +struct Storage { + next_etag: usize, + map: BTreeMap, +} + +type SharedStorage = Arc>; + +impl Storage { + fn insert(&mut self, location: &Path, bytes: Bytes) { + let etag = self.next_etag; + self.next_etag += 1; + let entry = Entry::new(bytes, Utc::now(), etag); + self.map.insert(location.clone(), entry); + } } impl std::fmt::Display for InMemory { @@ -92,9 +123,7 @@ impl std::fmt::Display for InMemory { #[async_trait] impl ObjectStore for InMemory { async fn put(&self, location: &Path, bytes: Bytes) -> Result<()> { - self.storage - .write() - .insert(location.clone(), (bytes, Utc::now())); + self.storage.write().insert(location, bytes); Ok(()) } @@ -128,33 +157,30 @@ impl ObjectStore for InMemory { Ok(Box::new(InMemoryAppend { location: location.clone(), data: Vec::::new(), - storage: StorageType::clone(&self.storage), + storage: SharedStorage::clone(&self.storage), })) } async fn get_opts(&self, location: &Path, options: GetOptions) -> Result { - if options.if_match.is_some() || options.if_none_match.is_some() { - return Err(super::Error::NotSupported { - 
source: "ETags not supported by InMemory".to_string().into(), - }); - } - let (data, last_modified) = self.entry(location).await?; - options.check_modified(location, last_modified)?; + let entry = self.entry(location).await?; + let e_tag = entry.e_tag.to_string(); + let meta = ObjectMeta { location: location.clone(), - last_modified, - size: data.len(), - e_tag: None, + last_modified: entry.last_modified, + size: entry.data.len(), + e_tag: Some(e_tag), }; + options.check_preconditions(&meta)?; let (range, data) = match options.range { Some(range) => { - let len = data.len(); + let len = entry.data.len(); ensure!(range.end <= len, OutOfRangeSnafu { range, len }); ensure!(range.start <= range.end, BadRangeSnafu { range }); - (range.clone(), data.slice(range)) + (range.clone(), entry.data.slice(range)) } - None => (0..data.len(), data), + None => (0..entry.data.len(), entry.data), }; let stream = futures::stream::once(futures::future::ready(Ok(data))); @@ -170,15 +196,18 @@ impl ObjectStore for InMemory { location: &Path, ranges: &[Range], ) -> Result> { - let data = self.entry(location).await?; + let entry = self.entry(location).await?; ranges .iter() .map(|range| { let range = range.clone(); - let len = data.0.len(); - ensure!(range.end <= data.0.len(), OutOfRangeSnafu { range, len }); + let len = entry.data.len(); + ensure!( + range.end <= entry.data.len(), + OutOfRangeSnafu { range, len } + ); ensure!(range.start <= range.end, BadRangeSnafu { range }); - Ok(data.0.slice(range)) + Ok(entry.data.slice(range)) }) .collect() } @@ -188,14 +217,14 @@ impl ObjectStore for InMemory { Ok(ObjectMeta { location: location.clone(), - last_modified: entry.1, - size: entry.0.len(), - e_tag: None, + last_modified: entry.last_modified, + size: entry.data.len(), + e_tag: Some(entry.e_tag.to_string()), }) } async fn delete(&self, location: &Path) -> Result<()> { - self.storage.write().remove(location); + self.storage.write().map.remove(location); Ok(()) } @@ -208,6 +237,7 @@ impl 
ObjectStore for InMemory { let storage = self.storage.read(); let values: Vec<_> = storage + .map .range((prefix)..) .take_while(|(key, _)| key.as_ref().starts_with(prefix.as_ref())) .filter(|(key, _)| { @@ -219,9 +249,9 @@ impl ObjectStore for InMemory { .map(|(key, value)| { Ok(ObjectMeta { location: key.clone(), - last_modified: value.1, - size: value.0.len(), - e_tag: None, + last_modified: value.last_modified, + size: value.data.len(), + e_tag: Some(value.e_tag.to_string()), }) }) .collect(); @@ -241,7 +271,7 @@ impl ObjectStore for InMemory { // Only objects in this base level should be returned in the // response. Otherwise, we just collect the common prefixes. let mut objects = vec![]; - for (k, v) in self.storage.read().range((prefix)..) { + for (k, v) in self.storage.read().map.range((prefix)..) { if !k.as_ref().starts_with(prefix.as_ref()) { break; } @@ -263,9 +293,9 @@ impl ObjectStore for InMemory { } else { let object = ObjectMeta { location: k.clone(), - last_modified: v.1, - size: v.0.len(), - e_tag: None, + last_modified: v.last_modified, + size: v.data.len(), + e_tag: Some(v.e_tag.to_string()), }; objects.push(object); } @@ -278,23 +308,21 @@ impl ObjectStore for InMemory { } async fn copy(&self, from: &Path, to: &Path) -> Result<()> { - let data = self.entry(from).await?; - self.storage - .write() - .insert(to.clone(), (data.0, Utc::now())); + let entry = self.entry(from).await?; + self.storage.write().insert(to, entry.data); Ok(()) } async fn copy_if_not_exists(&self, from: &Path, to: &Path) -> Result<()> { - let data = self.entry(from).await?; + let entry = self.entry(from).await?; let mut storage = self.storage.write(); - if storage.contains_key(to) { + if storage.map.contains_key(to) { return Err(Error::AlreadyExists { path: to.to_string(), } .into()); } - storage.insert(to.clone(), (data.0, Utc::now())); + storage.insert(to, entry.data); Ok(()) } } @@ -319,9 +347,10 @@ impl InMemory { self.fork() } - async fn entry(&self, location: &Path) -> 
Result<(Bytes, DateTime)> { + async fn entry(&self, location: &Path) -> Result { let storage = self.storage.read(); let value = storage + .map .get(location) .cloned() .context(NoDataInMemorySnafu { @@ -335,7 +364,7 @@ impl InMemory { struct InMemoryUpload { location: Path, data: Vec, - storage: StorageType, + storage: Arc>, } impl AsyncWrite for InMemoryUpload { @@ -343,7 +372,7 @@ impl AsyncWrite for InMemoryUpload { mut self: Pin<&mut Self>, _cx: &mut std::task::Context<'_>, buf: &[u8], - ) -> std::task::Poll> { + ) -> Poll> { self.data.extend_from_slice(buf); Poll::Ready(Ok(buf.len())) } @@ -351,18 +380,16 @@ impl AsyncWrite for InMemoryUpload { fn poll_flush( self: Pin<&mut Self>, _cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { + ) -> Poll> { Poll::Ready(Ok(())) } fn poll_shutdown( mut self: Pin<&mut Self>, _cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { + ) -> Poll> { let data = Bytes::from(std::mem::take(&mut self.data)); - self.storage - .write() - .insert(self.location.clone(), (data, Utc::now())); + self.storage.write().insert(&self.location, data); Poll::Ready(Ok(())) } } @@ -370,7 +397,7 @@ impl AsyncWrite for InMemoryUpload { struct InMemoryAppend { location: Path, data: Vec, - storage: StorageType, + storage: Arc>, } impl AsyncWrite for InMemoryAppend { @@ -378,7 +405,7 @@ impl AsyncWrite for InMemoryAppend { mut self: Pin<&mut Self>, _cx: &mut std::task::Context<'_>, buf: &[u8], - ) -> std::task::Poll> { + ) -> Poll> { self.data.extend_from_slice(buf); Poll::Ready(Ok(buf.len())) } @@ -386,20 +413,18 @@ impl AsyncWrite for InMemoryAppend { fn poll_flush( mut self: Pin<&mut Self>, _cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - let storage = StorageType::clone(&self.storage); + ) -> Poll> { + let storage = Arc::clone(&self.storage); let mut writer = storage.write(); - if let Some((bytes, _)) = writer.remove(&self.location) { + if let Some(entry) = writer.map.remove(&self.location) { let buf = std::mem::take(&mut 
self.data); - let concat = Bytes::from_iter(bytes.into_iter().chain(buf)); - writer.insert(self.location.clone(), (concat, Utc::now())); + let concat = Bytes::from_iter(entry.data.into_iter().chain(buf)); + writer.insert(&self.location, concat); } else { - writer.insert( - self.location.clone(), - (Bytes::from(std::mem::take(&mut self.data)), Utc::now()), - ); + let data = Bytes::from(std::mem::take(&mut self.data)); + writer.insert(&self.location, data); }; Poll::Ready(Ok(())) } From d4d11fe7a47b529429020848f2ac0f63659500d6 Mon Sep 17 00:00:00 2001 From: Raphael Taylor-Davies <1781103+tustvold@users.noreply.github.com> Date: Tue, 17 Oct 2023 22:09:12 +0100 Subject: [PATCH 09/25] Assume Pages Delimit Records When Offset Index Loaded (#4921) (#4943) * Assume records not split across pages (#4921) * More test * Add PageReader::at_record_boundary * Fix flush partial --- parquet/src/arrow/array_reader/mod.rs | 2 +- parquet/src/arrow/async_reader/mod.rs | 96 ++++++++++++++++++++++++++- parquet/src/column/page.rs | 14 ++++ parquet/src/column/reader.rs | 8 +-- parquet/src/column/reader/decoder.rs | 7 ++ parquet/src/file/serialized_reader.rs | 9 +++ 6 files changed, 129 insertions(+), 7 deletions(-) diff --git a/parquet/src/arrow/array_reader/mod.rs b/parquet/src/arrow/array_reader/mod.rs index 625ac034ef47..a4ee5040590e 100644 --- a/parquet/src/arrow/array_reader/mod.rs +++ b/parquet/src/arrow/array_reader/mod.rs @@ -152,7 +152,7 @@ where Ok(records_read) } -/// Uses `record_reader` to skip up to `batch_size` records from`pages` +/// Uses `record_reader` to skip up to `batch_size` records from `pages` /// /// Returns the number of records skipped, which can be less than `batch_size` if /// pages is exhausted diff --git a/parquet/src/arrow/async_reader/mod.rs b/parquet/src/arrow/async_reader/mod.rs index 4b3eebf2e67e..875fff4dac57 100644 --- a/parquet/src/arrow/async_reader/mod.rs +++ b/parquet/src/arrow/async_reader/mod.rs @@ -878,12 +878,17 @@ mod tests { use 
crate::file::properties::WriterProperties; use arrow::compute::kernels::cmp::eq; use arrow::error::Result as ArrowResult; + use arrow_array::builder::{ListBuilder, StringBuilder}; use arrow_array::cast::AsArray; use arrow_array::types::Int32Type; - use arrow_array::{Array, ArrayRef, Int32Array, Int8Array, Scalar, StringArray}; - use futures::TryStreamExt; + use arrow_array::{ + Array, ArrayRef, Int32Array, Int8Array, Scalar, StringArray, UInt64Array, + }; + use arrow_schema::{DataType, Field, Schema}; + use futures::{StreamExt, TryStreamExt}; use rand::{thread_rng, Rng}; use std::sync::Mutex; + use tempfile::tempfile; #[derive(Clone)] struct TestReader { @@ -1677,4 +1682,91 @@ mod tests { assert!(sbbf.check(&"Hello")); assert!(!sbbf.check(&"Hello_Not_Exists")); } + + #[tokio::test] + async fn test_nested_skip() { + let schema = Arc::new(Schema::new(vec![ + Field::new("col_1", DataType::UInt64, false), + Field::new_list("col_2", Field::new("item", DataType::Utf8, true), true), + ])); + + // Default writer properties + let props = WriterProperties::builder() + .set_data_page_row_count_limit(256) + .set_write_batch_size(256) + .set_max_row_group_size(1024); + + // Write data + let mut file = tempfile().unwrap(); + let mut writer = + ArrowWriter::try_new(&mut file, schema.clone(), Some(props.build())).unwrap(); + + let mut builder = ListBuilder::new(StringBuilder::new()); + for id in 0..1024 { + match id % 3 { + 0 => builder + .append_value([Some("val_1".to_string()), Some(format!("id_{id}"))]), + 1 => builder.append_value([Some(format!("id_{id}"))]), + _ => builder.append_null(), + } + } + let refs = vec![ + Arc::new(UInt64Array::from_iter_values(0..1024)) as ArrayRef, + Arc::new(builder.finish()) as ArrayRef, + ]; + + let batch = RecordBatch::try_new(schema.clone(), refs).unwrap(); + writer.write(&batch).unwrap(); + writer.close().unwrap(); + + let selections = [ + RowSelection::from(vec![ + RowSelector::skip(313), + RowSelector::select(1), + RowSelector::skip(709), 
+ RowSelector::select(1), + ]), + RowSelection::from(vec![ + RowSelector::skip(255), + RowSelector::select(1), + RowSelector::skip(767), + RowSelector::select(1), + ]), + RowSelection::from(vec![ + RowSelector::select(255), + RowSelector::skip(1), + RowSelector::select(767), + RowSelector::skip(1), + ]), + RowSelection::from(vec![ + RowSelector::skip(254), + RowSelector::select(1), + RowSelector::select(1), + RowSelector::skip(767), + RowSelector::select(1), + ]), + ]; + + for selection in selections { + let expected = selection.row_count(); + // Read data + let mut reader = ParquetRecordBatchStreamBuilder::new_with_options( + tokio::fs::File::from_std(file.try_clone().unwrap()), + ArrowReaderOptions::new().with_page_index(true), + ) + .await + .unwrap(); + + reader = reader.with_row_selection(selection); + + let mut stream = reader.build().unwrap(); + + let mut total_rows = 0; + while let Some(rb) = stream.next().await { + let rb = rb.unwrap(); + total_rows += rb.num_rows(); + } + assert_eq!(total_rows, expected); + } + } } diff --git a/parquet/src/column/page.rs b/parquet/src/column/page.rs index ec9af2aa271a..933e42386272 100644 --- a/parquet/src/column/page.rs +++ b/parquet/src/column/page.rs @@ -320,6 +320,20 @@ pub trait PageReader: Iterator> + Send { /// Skips reading the next page, returns an error if no /// column index information fn skip_next_page(&mut self) -> Result<()>; + + /// Returns `true` if the next page can be assumed to contain the start of a new record + /// + /// Prior to parquet V2 the specification was ambiguous as to whether a single record + /// could be split across multiple pages, and prior to [(#4327)] the Rust writer would do + /// this in certain situations. 
However, correctly interpreting the offset index relies on + /// this assumption holding [(#4943)], and so this mechanism is provided for a [`PageReader`] + /// to signal this to the calling context + /// + /// [(#4327)]: https://github.com/apache/arrow-rs/pull/4327 + /// [(#4943)]: https://github.com/apache/arrow-rs/pull/4943 + fn at_record_boundary(&mut self) -> Result { + Ok(self.peek_next_page()?.is_none()) + } } /// API for writing pages in a column chunk. diff --git a/parquet/src/column/reader.rs b/parquet/src/column/reader.rs index 3ce00622e953..52ad4d644c95 100644 --- a/parquet/src/column/reader.rs +++ b/parquet/src/column/reader.rs @@ -269,7 +269,7 @@ where // Reached end of page, which implies records_read < remaining_records // as otherwise would have stopped reading before reaching the end assert!(records_read < remaining_records); // Sanity check - records_read += 1; + records_read += reader.flush_partial() as usize; } (records_read, levels_read) } @@ -380,7 +380,7 @@ where // Reached end of page, which implies records_read < remaining_records // as otherwise would have stopped reading before reaching the end assert!(records_read < remaining_records); // Sanity check - records_read += 1; + records_read += decoder.flush_partial() as usize; } (records_read, levels_read) @@ -491,7 +491,7 @@ where offset += bytes_read; self.has_record_delimiter = - self.page_reader.peek_next_page()?.is_none(); + self.page_reader.at_record_boundary()?; self.rep_level_decoder .as_mut() @@ -548,7 +548,7 @@ where // across multiple pages, however, the parquet writer // used to do this so we preserve backwards compatibility self.has_record_delimiter = - self.page_reader.peek_next_page()?.is_none(); + self.page_reader.at_record_boundary()?; self.rep_level_decoder.as_mut().unwrap().set_data( Encoding::RLE, diff --git a/parquet/src/column/reader/decoder.rs b/parquet/src/column/reader/decoder.rs index 369b335dc98f..27ffb7637e18 100644 --- a/parquet/src/column/reader/decoder.rs +++ 
b/parquet/src/column/reader/decoder.rs @@ -102,6 +102,9 @@ pub trait RepetitionLevelDecoder: ColumnLevelDecoder { num_records: usize, num_levels: usize, ) -> Result<(usize, usize)>; + + /// Flush any partially read or skipped record + fn flush_partial(&mut self) -> bool; } pub trait DefinitionLevelDecoder: ColumnLevelDecoder { @@ -519,6 +522,10 @@ impl RepetitionLevelDecoder for RepetitionLevelDecoderImpl { } Ok((total_records_read, total_levels_read)) } + + fn flush_partial(&mut self) -> bool { + std::mem::take(&mut self.has_partial) + } } #[cfg(test)] diff --git a/parquet/src/file/serialized_reader.rs b/parquet/src/file/serialized_reader.rs index 4bc484144a81..b60d30ffea23 100644 --- a/parquet/src/file/serialized_reader.rs +++ b/parquet/src/file/serialized_reader.rs @@ -770,6 +770,15 @@ impl PageReader for SerializedPageReader { } } } + + fn at_record_boundary(&mut self) -> Result { + match &mut self.state { + SerializedPageReaderState::Values { .. } => { + Ok(self.peek_next_page()?.is_none()) + } + SerializedPageReaderState::Pages { .. 
} => Ok(true), + } + } } #[cfg(test)] From fa7a61a4b074ca4ec9bf429cc84b6c325057d96e Mon Sep 17 00:00:00 2001 From: Raphael Taylor-Davies <1781103+tustvold@users.noreply.github.com> Date: Tue, 17 Oct 2023 22:10:31 +0100 Subject: [PATCH 10/25] Remove Nested async and Fallibility from ObjectStore::list (#4930) * Remove nested async and fallibility from ObjectStore::list * Clippy * Update limit test * Update docs --- object_store/src/aws/mod.rs | 13 +- object_store/src/azure/mod.rs | 7 +- object_store/src/chunked.rs | 13 +- object_store/src/client/list.rs | 32 ++--- object_store/src/gcp/mod.rs | 7 +- object_store/src/http/mod.rs | 24 ++-- object_store/src/lib.rs | 178 +++++++++++---------------- object_store/src/limit.rs | 44 ++++--- object_store/src/local.rs | 82 +++++------- object_store/src/memory.rs | 7 +- object_store/src/prefix.rs | 17 ++- object_store/src/throttle.rs | 47 +++---- object_store/tests/get_range_file.rs | 5 +- 13 files changed, 197 insertions(+), 279 deletions(-) diff --git a/object_store/src/aws/mod.rs b/object_store/src/aws/mod.rs index 3ddce08002c4..d3c50861c122 100644 --- a/object_store/src/aws/mod.rs +++ b/object_store/src/aws/mod.rs @@ -331,19 +331,16 @@ impl ObjectStore for AmazonS3 { .boxed() } - async fn list( - &self, - prefix: Option<&Path>, - ) -> Result>> { - self.client.list(prefix).await + fn list(&self, prefix: Option<&Path>) -> BoxStream<'_, Result> { + self.client.list(prefix) } - async fn list_with_offset( + fn list_with_offset( &self, prefix: Option<&Path>, offset: &Path, - ) -> Result>> { - self.client.list_with_offset(prefix, offset).await + ) -> BoxStream<'_, Result> { + self.client.list_with_offset(prefix, offset) } async fn list_with_delimiter(&self, prefix: Option<&Path>) -> Result { diff --git a/object_store/src/azure/mod.rs b/object_store/src/azure/mod.rs index 190b73bf9490..2a08c6775807 100644 --- a/object_store/src/azure/mod.rs +++ b/object_store/src/azure/mod.rs @@ -206,11 +206,8 @@ impl ObjectStore for MicrosoftAzure 
{ self.client.delete_request(location, &()).await } - async fn list( - &self, - prefix: Option<&Path>, - ) -> Result>> { - self.client.list(prefix).await + fn list(&self, prefix: Option<&Path>) -> BoxStream<'_, Result> { + self.client.list(prefix) } async fn list_with_delimiter(&self, prefix: Option<&Path>) -> Result { diff --git a/object_store/src/chunked.rs b/object_store/src/chunked.rs index 008dec679413..d3e02b412725 100644 --- a/object_store/src/chunked.rs +++ b/object_store/src/chunked.rs @@ -147,19 +147,16 @@ impl ObjectStore for ChunkedStore { self.inner.delete(location).await } - async fn list( - &self, - prefix: Option<&Path>, - ) -> Result>> { - self.inner.list(prefix).await + fn list(&self, prefix: Option<&Path>) -> BoxStream<'_, Result> { + self.inner.list(prefix) } - async fn list_with_offset( + fn list_with_offset( &self, prefix: Option<&Path>, offset: &Path, - ) -> Result>> { - self.inner.list_with_offset(prefix, offset).await + ) -> BoxStream<'_, Result> { + self.inner.list_with_offset(prefix, offset) } async fn list_with_delimiter(&self, prefix: Option<&Path>) -> Result { diff --git a/object_store/src/client/list.rs b/object_store/src/client/list.rs index b2dbee27f14d..371894dfeb71 100644 --- a/object_store/src/client/list.rs +++ b/object_store/src/client/list.rs @@ -46,16 +46,13 @@ pub trait ListClientExt { offset: Option<&Path>, ) -> BoxStream<'_, Result>; - async fn list( - &self, - prefix: Option<&Path>, - ) -> Result>>; + fn list(&self, prefix: Option<&Path>) -> BoxStream<'_, Result>; - async fn list_with_offset( + fn list_with_offset( &self, prefix: Option<&Path>, offset: &Path, - ) -> Result>>; + ) -> BoxStream<'_, Result>; async fn list_with_delimiter(&self, prefix: Option<&Path>) -> Result; } @@ -90,31 +87,22 @@ impl ListClientExt for T { .boxed() } - async fn list( - &self, - prefix: Option<&Path>, - ) -> Result>> { - let stream = self - .list_paginated(prefix, false, None) + fn list(&self, prefix: Option<&Path>) -> BoxStream<'_, Result> 
{ + self.list_paginated(prefix, false, None) .map_ok(|r| futures::stream::iter(r.objects.into_iter().map(Ok))) .try_flatten() - .boxed(); - - Ok(stream) + .boxed() } - async fn list_with_offset( + fn list_with_offset( &self, prefix: Option<&Path>, offset: &Path, - ) -> Result>> { - let stream = self - .list_paginated(prefix, false, Some(offset)) + ) -> BoxStream<'_, Result> { + self.list_paginated(prefix, false, Some(offset)) .map_ok(|r| futures::stream::iter(r.objects.into_iter().map(Ok))) .try_flatten() - .boxed(); - - Ok(stream) + .boxed() } async fn list_with_delimiter(&self, prefix: Option<&Path>) -> Result { diff --git a/object_store/src/gcp/mod.rs b/object_store/src/gcp/mod.rs index a75527fe7b9f..513e396cbae6 100644 --- a/object_store/src/gcp/mod.rs +++ b/object_store/src/gcp/mod.rs @@ -601,11 +601,8 @@ impl ObjectStore for GoogleCloudStorage { self.client.delete_request(location).await } - async fn list( - &self, - prefix: Option<&Path>, - ) -> Result>> { - self.client.list(prefix).await + fn list(&self, prefix: Option<&Path>) -> BoxStream<'_, Result> { + self.client.list(prefix) } async fn list_with_delimiter(&self, prefix: Option<&Path>) -> Result { diff --git a/object_store/src/http/mod.rs b/object_store/src/http/mod.rs index 6ffb62358941..2fd7850b6bbf 100644 --- a/object_store/src/http/mod.rs +++ b/object_store/src/http/mod.rs @@ -34,7 +34,7 @@ use async_trait::async_trait; use bytes::Bytes; use futures::stream::BoxStream; -use futures::StreamExt; +use futures::{StreamExt, TryStreamExt}; use itertools::Itertools; use snafu::{OptionExt, ResultExt, Snafu}; use tokio::io::AsyncWrite; @@ -122,14 +122,13 @@ impl ObjectStore for HttpStore { self.client.delete(location).await } - async fn list( - &self, - prefix: Option<&Path>, - ) -> Result>> { + fn list(&self, prefix: Option<&Path>) -> BoxStream<'_, Result> { let prefix_len = prefix.map(|p| p.as_ref().len()).unwrap_or_default(); - let status = self.client.list(prefix, "infinity").await?; - 
Ok(futures::stream::iter( - status + let prefix = prefix.cloned(); + futures::stream::once(async move { + let status = self.client.list(prefix.as_ref(), "infinity").await?; + + let iter = status .response .into_iter() .filter(|r| !r.is_dir()) @@ -138,9 +137,12 @@ impl ObjectStore for HttpStore { response.object_meta(self.client.base_url()) }) // Filter out exact prefix matches - .filter_ok(move |r| r.location.as_ref().len() > prefix_len), - ) - .boxed()) + .filter_ok(move |r| r.location.as_ref().len() > prefix_len); + + Ok::<_, crate::Error>(futures::stream::iter(iter)) + }) + .try_flatten() + .boxed() } async fn list_with_delimiter(&self, prefix: Option<&Path>) -> Result { diff --git a/object_store/src/lib.rs b/object_store/src/lib.rs index b79042e3cda8..9b396444fa0d 100644 --- a/object_store/src/lib.rs +++ b/object_store/src/lib.rs @@ -95,18 +95,18 @@ //! //! ``` //! # use object_store::local::LocalFileSystem; +//! # use std::sync::Arc; +//! # use object_store::{path::Path, ObjectStore}; +//! # use futures::stream::StreamExt; //! # // use LocalFileSystem for example -//! # fn get_object_store() -> LocalFileSystem { -//! # LocalFileSystem::new_with_prefix("/tmp").unwrap() +//! # fn get_object_store() -> Arc { +//! # Arc::new(LocalFileSystem::new()) //! # } -//! +//! # //! # async fn example() { -//! use std::sync::Arc; -//! use object_store::{path::Path, ObjectStore}; -//! use futures::stream::StreamExt; -//! +//! # //! // create an ObjectStore -//! let object_store: Arc = Arc::new(get_object_store()); +//! let object_store: Arc = get_object_store(); //! //! // Recursively list all files below the 'data' path. //! // 1. On AWS S3 this would be the 'data/' prefix @@ -114,21 +114,12 @@ //! let prefix: Path = "data".try_into().unwrap(); //! //! // Get an `async` stream of Metadata objects: -//! let list_stream = object_store -//! .list(Some(&prefix)) -//! .await -//! .expect("Error listing files"); +//! let mut list_stream = object_store.list(Some(&prefix)); //! -//! 
// Print a line about each object based on its metadata -//! // using for_each from `StreamExt` trait. -//! list_stream -//! .for_each(move |meta| { -//! async { -//! let meta = meta.expect("Error listing"); -//! println!("Name: {}, size: {}", meta.location, meta.size); -//! } -//! }) -//! .await; +//! // Print a line about each object +//! while let Some(meta) = list_stream.next().await.transpose().unwrap() { +//! println!("Name: {}, size: {}", meta.location, meta.size); +//! } //! # } //! ``` //! @@ -147,19 +138,18 @@ //! from remote storage or files in the local filesystem as a stream. //! //! ``` +//! # use futures::TryStreamExt; //! # use object_store::local::LocalFileSystem; -//! # // use LocalFileSystem for example -//! # fn get_object_store() -> LocalFileSystem { -//! # LocalFileSystem::new_with_prefix("/tmp").unwrap() +//! # use std::sync::Arc; +//! # use object_store::{path::Path, ObjectStore}; +//! # fn get_object_store() -> Arc { +//! # Arc::new(LocalFileSystem::new()) //! # } -//! +//! # //! # async fn example() { -//! use std::sync::Arc; -//! use object_store::{path::Path, ObjectStore}; -//! use futures::stream::StreamExt; -//! +//! # //! // create an ObjectStore -//! let object_store: Arc = Arc::new(get_object_store()); +//! let object_store: Arc = get_object_store(); //! //! // Retrieve a specific file //! let path: Path = "data/file01.parquet".try_into().unwrap(); @@ -171,16 +161,11 @@ //! .unwrap() //! .into_stream(); //! -//! // Count the '0's using `map` from `StreamExt` trait +//! // Count the '0's using `try_fold` from `TryStreamExt` trait //! let num_zeros = stream -//! .map(|bytes| { -//! let bytes = bytes.unwrap(); -//! bytes.iter().filter(|b| **b == 0).count() -//! }) -//! .collect::>() -//! .await -//! .into_iter() -//! .sum::(); +//! .try_fold(0, |acc, bytes| async move { +//! Ok(acc + bytes.iter().filter(|b| **b == 0).count()) +//! }).await.unwrap(); //! //! println!("Num zeros in {} is {}", path, num_zeros); //! 
# } @@ -196,22 +181,19 @@ //! //! ``` //! # use object_store::local::LocalFileSystem; -//! # fn get_object_store() -> LocalFileSystem { -//! # LocalFileSystem::new_with_prefix("/tmp").unwrap() +//! # use object_store::ObjectStore; +//! # use std::sync::Arc; +//! # use bytes::Bytes; +//! # use object_store::path::Path; +//! # fn get_object_store() -> Arc { +//! # Arc::new(LocalFileSystem::new()) //! # } //! # async fn put() { -//! use object_store::ObjectStore; -//! use std::sync::Arc; -//! use bytes::Bytes; -//! use object_store::path::Path; -//! -//! let object_store: Arc = Arc::new(get_object_store()); +//! # +//! let object_store: Arc = get_object_store(); //! let path: Path = "data/file1".try_into().unwrap(); //! let bytes = Bytes::from_static(b"hello"); -//! object_store -//! .put(&path, bytes) -//! .await -//! .unwrap(); +//! object_store.put(&path, bytes).await.unwrap(); //! # } //! ``` //! @@ -220,22 +202,20 @@ //! //! ``` //! # use object_store::local::LocalFileSystem; -//! # fn get_object_store() -> LocalFileSystem { -//! # LocalFileSystem::new_with_prefix("/tmp").unwrap() +//! # use object_store::ObjectStore; +//! # use std::sync::Arc; +//! # use bytes::Bytes; +//! # use tokio::io::AsyncWriteExt; +//! # use object_store::path::Path; +//! # fn get_object_store() -> Arc { +//! # Arc::new(LocalFileSystem::new()) //! # } //! # async fn multi_upload() { -//! use object_store::ObjectStore; -//! use std::sync::Arc; -//! use bytes::Bytes; -//! use tokio::io::AsyncWriteExt; -//! use object_store::path::Path; -//! -//! let object_store: Arc = Arc::new(get_object_store()); +//! # +//! let object_store: Arc = get_object_store(); //! let path: Path = "data/large_file".try_into().unwrap(); -//! let (_id, mut writer) = object_store -//! .put_multipart(&path) -//! .await -//! .unwrap(); +//! let (_id, mut writer) = object_store.put_multipart(&path).await.unwrap(); +//! //! let bytes = Bytes::from_static(b"hello"); //! writer.write_all(&bytes).await.unwrap(); //! 
writer.flush().await.unwrap(); @@ -439,23 +419,22 @@ pub trait ObjectStore: std::fmt::Display + Send + Sync + Debug + 'static { /// return Ok. If it is an error, it will be [`Error::NotFound`]. /// /// ``` + /// # use futures::{StreamExt, TryStreamExt}; /// # use object_store::local::LocalFileSystem; /// # async fn example() -> Result<(), Box> { /// # let root = tempfile::TempDir::new().unwrap(); /// # let store = LocalFileSystem::new_with_prefix(root.path()).unwrap(); - /// use object_store::{ObjectStore, ObjectMeta}; - /// use object_store::path::Path; - /// use futures::{StreamExt, TryStreamExt}; - /// use bytes::Bytes; - /// + /// # use object_store::{ObjectStore, ObjectMeta}; + /// # use object_store::path::Path; + /// # use futures::{StreamExt, TryStreamExt}; + /// # use bytes::Bytes; + /// # /// // Create two objects /// store.put(&Path::from("foo"), Bytes::from("foo")).await?; /// store.put(&Path::from("bar"), Bytes::from("bar")).await?; /// /// // List object - /// let locations = store.list(None).await? - /// .map(|meta: Result| meta.map(|m| m.location)) - /// .boxed(); + /// let locations = store.list(None).map_ok(|m| m.location).boxed(); /// /// // Delete them /// store.delete_stream(locations).try_collect::>().await?; @@ -484,10 +463,7 @@ pub trait ObjectStore: std::fmt::Display + Send + Sync + Debug + 'static { /// `foo/bar_baz/x`. 
/// /// Note: the order of returned [`ObjectMeta`] is not guaranteed - async fn list( - &self, - prefix: Option<&Path>, - ) -> Result>>; + fn list(&self, prefix: Option<&Path>) -> BoxStream<'_, Result>; /// List all the objects with the given prefix and a location greater than `offset` /// @@ -495,18 +471,15 @@ pub trait ObjectStore: std::fmt::Display + Send + Sync + Debug + 'static { /// the number of network requests required /// /// Note: the order of returned [`ObjectMeta`] is not guaranteed - async fn list_with_offset( + fn list_with_offset( &self, prefix: Option<&Path>, offset: &Path, - ) -> Result>> { + ) -> BoxStream<'_, Result> { let offset = offset.clone(); - let stream = self - .list(prefix) - .await? + self.list(prefix) .try_filter(move |f| futures::future::ready(f.location > offset)) - .boxed(); - Ok(stream) + .boxed() } /// List objects with the given prefix and an implementation specific @@ -624,19 +597,16 @@ macro_rules! as_ref_impl { self.as_ref().delete_stream(locations) } - async fn list( - &self, - prefix: Option<&Path>, - ) -> Result>> { - self.as_ref().list(prefix).await + fn list(&self, prefix: Option<&Path>) -> BoxStream<'_, Result> { + self.as_ref().list(prefix) } - async fn list_with_offset( + fn list_with_offset( &self, prefix: Option<&Path>, offset: &Path, - ) -> Result>> { - self.as_ref().list_with_offset(prefix, offset).await + ) -> BoxStream<'_, Result> { + self.as_ref().list_with_offset(prefix, offset) } async fn list_with_delimiter( @@ -973,7 +943,6 @@ mod test_util { ) -> Result> { storage .list(prefix) - .await? 
.map_ok(|meta| meta.location) .try_collect::>() .await @@ -1264,11 +1233,7 @@ mod tests { ]; for (prefix, offset) in cases { - let s = storage - .list_with_offset(prefix.as_ref(), &offset) - .await - .unwrap(); - + let s = storage.list_with_offset(prefix.as_ref(), &offset); let mut actual: Vec<_> = s.map_ok(|x| x.location).try_collect().await.unwrap(); @@ -1700,12 +1665,7 @@ mod tests { } async fn delete_fixtures(storage: &DynObjectStore) { - let paths = storage - .list(None) - .await - .unwrap() - .map_ok(|meta| meta.location) - .boxed(); + let paths = storage.list(None).map_ok(|meta| meta.location).boxed(); storage .delete_stream(paths) .try_collect::>() @@ -1714,18 +1674,18 @@ mod tests { } /// Test that the returned stream does not borrow the lifetime of Path - async fn list_store<'a, 'b>( + fn list_store<'a>( store: &'a dyn ObjectStore, - path_str: &'b str, - ) -> super::Result>> { + path_str: &str, + ) -> BoxStream<'a, Result> { let path = Path::from(path_str); - store.list(Some(&path)).await + store.list(Some(&path)) } #[tokio::test] async fn test_list_lifetimes() { let store = memory::InMemory::new(); - let mut stream = list_store(&store, "path").await.unwrap(); + let mut stream = list_store(&store, "path"); assert!(stream.next().await.is_none()); } diff --git a/object_store/src/limit.rs b/object_store/src/limit.rs index a9b8c4b05020..00cbce023c3d 100644 --- a/object_store/src/limit.rs +++ b/object_store/src/limit.rs @@ -23,7 +23,7 @@ use crate::{ }; use async_trait::async_trait; use bytes::Bytes; -use futures::Stream; +use futures::{FutureExt, Stream}; use std::io::{Error, IoSlice}; use std::ops::Range; use std::pin::Pin; @@ -147,23 +147,31 @@ impl ObjectStore for LimitStore { self.inner.delete_stream(locations) } - async fn list( - &self, - prefix: Option<&Path>, - ) -> Result>> { - let permit = Arc::clone(&self.semaphore).acquire_owned().await.unwrap(); - let s = self.inner.list(prefix).await?; - Ok(PermitWrapper::new(s, permit).boxed()) + fn list(&self, 
prefix: Option<&Path>) -> BoxStream<'_, Result> { + let prefix = prefix.cloned(); + let fut = Arc::clone(&self.semaphore) + .acquire_owned() + .map(move |permit| { + let s = self.inner.list(prefix.as_ref()); + PermitWrapper::new(s, permit.unwrap()) + }); + fut.into_stream().flatten().boxed() } - async fn list_with_offset( + fn list_with_offset( &self, prefix: Option<&Path>, offset: &Path, - ) -> Result>> { - let permit = Arc::clone(&self.semaphore).acquire_owned().await.unwrap(); - let s = self.inner.list_with_offset(prefix, offset).await?; - Ok(PermitWrapper::new(s, permit).boxed()) + ) -> BoxStream<'_, Result> { + let prefix = prefix.cloned(); + let offset = offset.clone(); + let fut = Arc::clone(&self.semaphore) + .acquire_owned() + .map(move |permit| { + let s = self.inner.list_with_offset(prefix.as_ref(), &offset); + PermitWrapper::new(s, permit.unwrap()) + }); + fut.into_stream().flatten().boxed() } async fn list_with_delimiter(&self, prefix: Option<&Path>) -> Result { @@ -272,6 +280,8 @@ mod tests { use crate::memory::InMemory; use crate::tests::*; use crate::ObjectStore; + use futures::stream::StreamExt; + use std::pin::Pin; use std::time::Duration; use tokio::time::timeout; @@ -290,19 +300,21 @@ mod tests { let mut streams = Vec::with_capacity(max_requests); for _ in 0..max_requests { - let stream = integration.list(None).await.unwrap(); + let mut stream = integration.list(None).peekable(); + Pin::new(&mut stream).peek().await; // Ensure semaphore is acquired streams.push(stream); } let t = Duration::from_millis(20); // Expect to not be able to make another request - assert!(timeout(t, integration.list(None)).await.is_err()); + let fut = integration.list(None).collect::>(); + assert!(timeout(t, fut).await.is_err()); // Drop one of the streams streams.pop(); // Can now make another request - integration.list(None).await.unwrap(); + integration.list(None).collect::>().await; } } diff --git a/object_store/src/local.rs b/object_store/src/local.rs index 
3d4a02a1e9e9..38467c3a9e7c 100644 --- a/object_store/src/local.rs +++ b/object_store/src/local.rs @@ -420,14 +420,14 @@ impl ObjectStore for LocalFileSystem { .await } - async fn list( - &self, - prefix: Option<&Path>, - ) -> Result>> { + fn list(&self, prefix: Option<&Path>) -> BoxStream<'_, Result> { let config = Arc::clone(&self.config); let root_path = match prefix { - Some(prefix) => config.path_to_filesystem(prefix)?, + Some(prefix) => match config.path_to_filesystem(prefix) { + Ok(path) => path, + Err(e) => return futures::future::ready(Err(e)).into_stream().boxed(), + }, None => self.config.root.to_file_path().unwrap(), }; @@ -457,36 +457,34 @@ impl ObjectStore for LocalFileSystem { // If no tokio context, return iterator directly as no // need to perform chunked spawn_blocking reads if tokio::runtime::Handle::try_current().is_err() { - return Ok(futures::stream::iter(s).boxed()); + return futures::stream::iter(s).boxed(); } // Otherwise list in batches of CHUNK_SIZE const CHUNK_SIZE: usize = 1024; let buffer = VecDeque::with_capacity(CHUNK_SIZE); - let stream = - futures::stream::try_unfold((s, buffer), |(mut s, mut buffer)| async move { - if buffer.is_empty() { - (s, buffer) = tokio::task::spawn_blocking(move || { - for _ in 0..CHUNK_SIZE { - match s.next() { - Some(r) => buffer.push_back(r), - None => break, - } + futures::stream::try_unfold((s, buffer), |(mut s, mut buffer)| async move { + if buffer.is_empty() { + (s, buffer) = tokio::task::spawn_blocking(move || { + for _ in 0..CHUNK_SIZE { + match s.next() { + Some(r) => buffer.push_back(r), + None => break, } - (s, buffer) - }) - .await?; - } - - match buffer.pop_front() { - Some(Err(e)) => Err(e), - Some(Ok(meta)) => Ok(Some((meta, (s, buffer)))), - None => Ok(None), - } - }); + } + (s, buffer) + }) + .await?; + } - Ok(stream.boxed()) + match buffer.pop_front() { + Some(Err(e)) => Err(e), + Some(Ok(meta)) => Ok(Some((meta, (s, buffer)))), + None => Ok(None), + } + }) + .boxed() } async fn 
list_with_delimiter(&self, prefix: Option<&Path>) -> Result { @@ -1138,21 +1136,14 @@ mod tests { let store = LocalFileSystem::new_with_prefix(root.path()).unwrap(); - // `list` must fail - match store.list(None).await { - Err(_) => { - // ok, error found - } - Ok(mut stream) => { - let mut any_err = false; - while let Some(res) = stream.next().await { - if res.is_err() { - any_err = true; - } - } - assert!(any_err); + let mut stream = store.list(None); + let mut any_err = false; + while let Some(res) = stream.next().await { + if res.is_err() { + any_err = true; } } + assert!(any_err); // `list_with_delimiter assert!(store.list_with_delimiter(None).await.is_err()); @@ -1226,13 +1217,7 @@ mod tests { prefix: Option<&Path>, expected: &[&str], ) { - let result: Vec<_> = integration - .list(prefix) - .await - .unwrap() - .try_collect() - .await - .unwrap(); + let result: Vec<_> = integration.list(prefix).try_collect().await.unwrap(); let mut strings: Vec<_> = result.iter().map(|x| x.location.as_ref()).collect(); strings.sort_unstable(); @@ -1428,8 +1413,7 @@ mod tests { std::fs::write(temp_dir.path().join(filename), "foo").unwrap(); - let list_stream = integration.list(None).await.unwrap(); - let res: Vec<_> = list_stream.try_collect().await.unwrap(); + let res: Vec<_> = integration.list(None).try_collect().await.unwrap(); assert_eq!(res.len(), 1); assert_eq!(res[0].location.as_ref(), filename); diff --git a/object_store/src/memory.rs b/object_store/src/memory.rs index f638ed6d7a55..00b330b5eb94 100644 --- a/object_store/src/memory.rs +++ b/object_store/src/memory.rs @@ -228,10 +228,7 @@ impl ObjectStore for InMemory { Ok(()) } - async fn list( - &self, - prefix: Option<&Path>, - ) -> Result>> { + fn list(&self, prefix: Option<&Path>) -> BoxStream<'_, Result> { let root = Path::default(); let prefix = prefix.unwrap_or(&root); @@ -256,7 +253,7 @@ impl ObjectStore for InMemory { }) .collect(); - Ok(futures::stream::iter(values).boxed()) + 
futures::stream::iter(values).boxed() } /// The memory implementation returns all results, as opposed to the cloud diff --git a/object_store/src/prefix.rs b/object_store/src/prefix.rs index 39585f73b692..3776dec2e872 100644 --- a/object_store/src/prefix.rs +++ b/object_store/src/prefix.rs @@ -144,24 +144,21 @@ impl ObjectStore for PrefixStore { self.inner.delete(&full_path).await } - async fn list( - &self, - prefix: Option<&Path>, - ) -> Result>> { + fn list(&self, prefix: Option<&Path>) -> BoxStream<'_, Result> { let prefix = self.full_path(prefix.unwrap_or(&Path::default())); - let s = self.inner.list(Some(&prefix)).await?; - Ok(s.map_ok(|meta| self.strip_meta(meta)).boxed()) + let s = self.inner.list(Some(&prefix)); + s.map_ok(|meta| self.strip_meta(meta)).boxed() } - async fn list_with_offset( + fn list_with_offset( &self, prefix: Option<&Path>, offset: &Path, - ) -> Result>> { + ) -> BoxStream<'_, Result> { let offset = self.full_path(offset); let prefix = self.full_path(prefix.unwrap_or(&Path::default())); - let s = self.inner.list_with_offset(Some(&prefix), &offset).await?; - Ok(s.map_ok(|meta| self.strip_meta(meta)).boxed()) + let s = self.inner.list_with_offset(Some(&prefix), &offset); + s.map_ok(|meta| self.strip_meta(meta)).boxed() } async fn list_with_delimiter(&self, prefix: Option<&Path>) -> Result { diff --git a/object_store/src/throttle.rs b/object_store/src/throttle.rs index 58c476ab4530..f716a11f8a05 100644 --- a/object_store/src/throttle.rs +++ b/object_store/src/throttle.rs @@ -233,29 +233,30 @@ impl ObjectStore for ThrottledStore { self.inner.delete(location).await } - async fn list( - &self, - prefix: Option<&Path>, - ) -> Result>> { - sleep(self.config().wait_list_per_call).await; - - // need to copy to avoid moving / referencing `self` - let wait_list_per_entry = self.config().wait_list_per_entry; - let stream = self.inner.list(prefix).await?; - Ok(throttle_stream(stream, move |_| wait_list_per_entry)) + fn list(&self, prefix: 
Option<&Path>) -> BoxStream<'_, Result> { + let stream = self.inner.list(prefix); + futures::stream::once(async move { + let wait_list_per_entry = self.config().wait_list_per_entry; + sleep(self.config().wait_list_per_call).await; + throttle_stream(stream, move |_| wait_list_per_entry) + }) + .flatten() + .boxed() } - async fn list_with_offset( + fn list_with_offset( &self, prefix: Option<&Path>, offset: &Path, - ) -> Result>> { - sleep(self.config().wait_list_per_call).await; - - // need to copy to avoid moving / referencing `self` - let wait_list_per_entry = self.config().wait_list_per_entry; - let stream = self.inner.list_with_offset(prefix, offset).await?; - Ok(throttle_stream(stream, move |_| wait_list_per_entry)) + ) -> BoxStream<'_, Result> { + let stream = self.inner.list_with_offset(prefix, offset); + futures::stream::once(async move { + let wait_list_per_entry = self.config().wait_list_per_entry; + sleep(self.config().wait_list_per_call).await; + throttle_stream(stream, move |_| wait_list_per_entry) + }) + .flatten() + .boxed() } async fn list_with_delimiter(&self, prefix: Option<&Path>) -> Result { @@ -511,13 +512,7 @@ mod tests { let prefix = Path::from("foo"); // clean up store - let entries: Vec<_> = store - .list(Some(&prefix)) - .await - .unwrap() - .try_collect() - .await - .unwrap(); + let entries: Vec<_> = store.list(Some(&prefix)).try_collect().await.unwrap(); for entry in entries { store.delete(&entry.location).await.unwrap(); @@ -583,8 +578,6 @@ mod tests { let t0 = Instant::now(); store .list(Some(&prefix)) - .await - .unwrap() .try_collect::>() .await .unwrap(); diff --git a/object_store/tests/get_range_file.rs b/object_store/tests/get_range_file.rs index f926e3b07f2a..25c469260675 100644 --- a/object_store/tests/get_range_file.rs +++ b/object_store/tests/get_range_file.rs @@ -75,10 +75,7 @@ impl ObjectStore for MyStore { todo!() } - async fn list( - &self, - _: Option<&Path>, - ) -> object_store::Result>> { + fn list(&self, _: 
Option<&Path>) -> BoxStream<'_, object_store::Result> { todo!() } From 511ac44cf94ffe3f35e4efd3d1e816a8657a5061 Mon Sep 17 00:00:00 2001 From: Raphael Taylor-Davies <1781103+tustvold@users.noreply.github.com> Date: Tue, 17 Oct 2023 22:27:16 +0100 Subject: [PATCH 11/25] Fix object_store docs (#4947) --- object_store/src/parse.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/object_store/src/parse.rs b/object_store/src/parse.rs index 1159e9a1af17..2e72a710ac75 100644 --- a/object_store/src/parse.rs +++ b/object_store/src/parse.rs @@ -47,12 +47,12 @@ impl From for super::Error { } } -/// Recognises various URL formats, identifying the relevant [`ObjectStore`](crate::ObjectStore) +/// Recognises various URL formats, identifying the relevant [`ObjectStore`] #[derive(Debug, Eq, PartialEq)] enum ObjectStoreScheme { - /// Url corresponding to [`LocalFileSystem`](crate::local::LocalFileSystem) + /// Url corresponding to [`LocalFileSystem`] Local, - /// Url corresponding to [`InMemory`](crate::memory::InMemory) + /// Url corresponding to [`InMemory`] Memory, /// Url corresponding to [`AmazonS3`](crate::aws::AmazonS3) AmazonS3, From 952cd2efcb787385c6368acc8c582ffc5a7dfd95 Mon Sep 17 00:00:00 2001 From: Andre Martins <38951957+amartins23@users.noreply.github.com> Date: Tue, 17 Oct 2023 22:31:23 +0100 Subject: [PATCH 12/25] Expose SubstraitPlan structure in arrow_flight::sql (#4932) (#4933) --- arrow-flight/src/sql/mod.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/arrow-flight/src/sql/mod.rs b/arrow-flight/src/sql/mod.rs index 4bb8ce8b36e5..4042ce8efc46 100644 --- a/arrow-flight/src/sql/mod.rs +++ b/arrow-flight/src/sql/mod.rs @@ -93,6 +93,7 @@ pub use gen::SqlSupportedTransactions; pub use gen::SqlSupportedUnions; pub use gen::SqlSupportsConvert; pub use gen::SqlTransactionIsolationLevel; +pub use gen::SubstraitPlan; pub use gen::SupportedSqlGrammar; pub use gen::TicketStatementQuery; pub use gen::UpdateDeleteRules; From 
a94ccff9deac04ca075f6f05f81a5755af81348e Mon Sep 17 00:00:00 2001 From: fan <75058860+fansehep@users.noreply.github.com> Date: Wed, 18 Oct 2023 17:36:43 +0800 Subject: [PATCH 13/25] feat: support parsing for parquet writer option (#4938) * feat: support parsing for parquet writer option Signed-off-by: fan * fix clippy warning Signed-off-by: fan * add tests Signed-off-by: fan * follow reviews Signed-off-by: fan * fix only support lower and uppercase Signed-off-by: fan --------- Signed-off-by: fan --- parquet/src/basic.rs | 185 +++++++++++++++++++++++++++++++++ parquet/src/file/properties.rs | 68 ++++++++++++ 2 files changed, 253 insertions(+) diff --git a/parquet/src/basic.rs b/parquet/src/basic.rs index cc8d033f42a4..cdad3597ffef 100644 --- a/parquet/src/basic.rs +++ b/parquet/src/basic.rs @@ -18,6 +18,7 @@ //! Contains Rust mappings for Thrift definition. //! Refer to [`parquet.thrift`](https://github.com/apache/parquet-format/blob/master/src/main/thrift/parquet.thrift) file to see raw definitions. 
+use std::str::FromStr; use std::{fmt, str}; pub use crate::compression::{BrotliLevel, GzipLevel, ZstdLevel}; @@ -278,6 +279,29 @@ pub enum Encoding { BYTE_STREAM_SPLIT, } +impl FromStr for Encoding { + type Err = ParquetError; + + fn from_str(s: &str) -> Result { + match s { + "PLAIN" | "plain" => Ok(Encoding::PLAIN), + "PLAIN_DICTIONARY" | "plain_dictionary" => Ok(Encoding::PLAIN_DICTIONARY), + "RLE" | "rle" => Ok(Encoding::RLE), + "BIT_PACKED" | "bit_packed" => Ok(Encoding::BIT_PACKED), + "DELTA_BINARY_PACKED" | "delta_binary_packed" => { + Ok(Encoding::DELTA_BINARY_PACKED) + } + "DELTA_LENGTH_BYTE_ARRAY" | "delta_length_byte_array" => { + Ok(Encoding::DELTA_LENGTH_BYTE_ARRAY) + } + "DELTA_BYTE_ARRAY" | "delta_byte_array" => Ok(Encoding::DELTA_BYTE_ARRAY), + "RLE_DICTIONARY" | "rle_dictionary" => Ok(Encoding::RLE_DICTIONARY), + "BYTE_STREAM_SPLIT" | "byte_stream_split" => Ok(Encoding::BYTE_STREAM_SPLIT), + _ => Err(general_err!("unknown encoding: {}", s)), + } + } +} + // ---------------------------------------------------------------------- // Mirrors `parquet::CompressionCodec` @@ -295,6 +319,90 @@ pub enum Compression { LZ4_RAW, } +fn split_compression_string( + str_setting: &str, +) -> Result<(&str, Option), ParquetError> { + let split_setting = str_setting.split_once('('); + + match split_setting { + Some((codec, level_str)) => { + let level = + &level_str[..level_str.len() - 1] + .parse::() + .map_err(|_| { + ParquetError::General(format!( + "invalid compression level: {}", + level_str + )) + })?; + Ok((codec, Some(*level))) + } + None => Ok((str_setting, None)), + } +} + +fn check_level_is_none(level: &Option) -> Result<(), ParquetError> { + if level.is_some() { + return Err(ParquetError::General("level is not support".to_string())); + } + + Ok(()) +} + +fn require_level(codec: &str, level: Option) -> Result { + level.ok_or(ParquetError::General(format!("{} require level", codec))) +} + +impl FromStr for Compression { + type Err = ParquetError; + + fn 
from_str(s: &str) -> std::result::Result { + let (codec, level) = split_compression_string(s)?; + + let c = match codec { + "UNCOMPRESSED" | "uncompressed" => { + check_level_is_none(&level)?; + Compression::UNCOMPRESSED + } + "SNAPPY" | "snappy" => { + check_level_is_none(&level)?; + Compression::SNAPPY + } + "GZIP" | "gzip" => { + let level = require_level(codec, level)?; + Compression::GZIP(GzipLevel::try_new(level)?) + } + "LZO" | "lzo" => { + check_level_is_none(&level)?; + Compression::LZO + } + "BROTLI" | "brotli" => { + let level = require_level(codec, level)?; + Compression::BROTLI(BrotliLevel::try_new(level)?) + } + "LZ4" | "lz4" => { + check_level_is_none(&level)?; + Compression::LZ4 + } + "ZSTD" | "zstd" => { + let level = require_level(codec, level)?; + Compression::ZSTD(ZstdLevel::try_new(level as i32)?) + } + "LZ4_RAW" | "lz4_raw" => { + check_level_is_none(&level)?; + Compression::LZ4_RAW + } + _ => { + return Err(ParquetError::General(format!( + "unsupport compression {codec}" + ))); + } + }; + + Ok(c) + } +} + // ---------------------------------------------------------------------- // Mirrors `parquet::PageType` @@ -2130,4 +2238,81 @@ mod tests { ); assert_eq!(ColumnOrder::UNDEFINED.sort_order(), SortOrder::SIGNED); } + + #[test] + fn test_parse_encoding() { + let mut encoding: Encoding = "PLAIN".parse().unwrap(); + assert_eq!(encoding, Encoding::PLAIN); + encoding = "PLAIN_DICTIONARY".parse().unwrap(); + assert_eq!(encoding, Encoding::PLAIN_DICTIONARY); + encoding = "RLE".parse().unwrap(); + assert_eq!(encoding, Encoding::RLE); + encoding = "BIT_PACKED".parse().unwrap(); + assert_eq!(encoding, Encoding::BIT_PACKED); + encoding = "DELTA_BINARY_PACKED".parse().unwrap(); + assert_eq!(encoding, Encoding::DELTA_BINARY_PACKED); + encoding = "DELTA_LENGTH_BYTE_ARRAY".parse().unwrap(); + assert_eq!(encoding, Encoding::DELTA_LENGTH_BYTE_ARRAY); + encoding = "DELTA_BYTE_ARRAY".parse().unwrap(); + assert_eq!(encoding, Encoding::DELTA_BYTE_ARRAY); + 
encoding = "RLE_DICTIONARY".parse().unwrap(); + assert_eq!(encoding, Encoding::RLE_DICTIONARY); + encoding = "BYTE_STREAM_SPLIT".parse().unwrap(); + assert_eq!(encoding, Encoding::BYTE_STREAM_SPLIT); + + // test lowercase + encoding = "byte_stream_split".parse().unwrap(); + assert_eq!(encoding, Encoding::BYTE_STREAM_SPLIT); + + // test unknown string + match "plain_xxx".parse::() { + Ok(e) => { + panic!("Should not be able to parse {:?}", e); + } + Err(e) => { + assert_eq!(e.to_string(), "Parquet error: unknown encoding: plain_xxx"); + } + } + } + + #[test] + fn test_parse_compression() { + let mut compress: Compression = "snappy".parse().unwrap(); + assert_eq!(compress, Compression::SNAPPY); + compress = "lzo".parse().unwrap(); + assert_eq!(compress, Compression::LZO); + compress = "zstd(3)".parse().unwrap(); + assert_eq!(compress, Compression::ZSTD(ZstdLevel::try_new(3).unwrap())); + compress = "LZ4_RAW".parse().unwrap(); + assert_eq!(compress, Compression::LZ4_RAW); + compress = "uncompressed".parse().unwrap(); + assert_eq!(compress, Compression::UNCOMPRESSED); + compress = "snappy".parse().unwrap(); + assert_eq!(compress, Compression::SNAPPY); + compress = "gzip(9)".parse().unwrap(); + assert_eq!(compress, Compression::GZIP(GzipLevel::try_new(9).unwrap())); + compress = "lzo".parse().unwrap(); + assert_eq!(compress, Compression::LZO); + compress = "brotli(3)".parse().unwrap(); + assert_eq!( + compress, + Compression::BROTLI(BrotliLevel::try_new(3).unwrap()) + ); + compress = "lz4".parse().unwrap(); + assert_eq!(compress, Compression::LZ4); + + // test unknown compression + let mut err = "plain_xxx".parse::().unwrap_err(); + assert_eq!( + err.to_string(), + "Parquet error: unknown encoding: plain_xxx" + ); + + // test invalid compress level + err = "gzip(-10)".parse::().unwrap_err(); + assert_eq!( + err.to_string(), + "Parquet error: unknown encoding: gzip(-10)" + ); + } } diff --git a/parquet/src/file/properties.rs b/parquet/src/file/properties.rs index 
c83fea3f9b92..93b034cf4f60 100644 --- a/parquet/src/file/properties.rs +++ b/parquet/src/file/properties.rs @@ -16,6 +16,7 @@ // under the License. //! Configuration via [`WriterProperties`] and [`ReaderProperties`] +use std::str::FromStr; use std::{collections::HashMap, sync::Arc}; use crate::basic::{Compression, Encoding}; @@ -72,6 +73,18 @@ impl WriterVersion { } } +impl FromStr for WriterVersion { + type Err = String; + + fn from_str(s: &str) -> Result { + match s { + "PARQUET_1_0" | "parquet_1_0" => Ok(WriterVersion::PARQUET_1_0), + "PARQUET_2_0" | "parquet_2_0" => Ok(WriterVersion::PARQUET_2_0), + _ => Err(format!("Invalid writer version: {}", s)), + } + } +} + /// Reference counted writer properties. pub type WriterPropertiesPtr = Arc; @@ -655,6 +668,19 @@ pub enum EnabledStatistics { Page, } +impl FromStr for EnabledStatistics { + type Err = String; + + fn from_str(s: &str) -> Result { + match s { + "NONE" | "none" => Ok(EnabledStatistics::None), + "CHUNK" | "chunk" => Ok(EnabledStatistics::Chunk), + "PAGE" | "page" => Ok(EnabledStatistics::Page), + _ => Err(format!("Invalid statistics arg: {}", s)), + } + } +} + impl Default for EnabledStatistics { fn default() -> Self { DEFAULT_STATISTICS_ENABLED @@ -1182,4 +1208,46 @@ mod tests { assert_eq!(props.codec_options(), &codec_options); } + + #[test] + fn test_parse_writerversion() { + let mut writer_version = "PARQUET_1_0".parse::().unwrap(); + assert_eq!(writer_version, WriterVersion::PARQUET_1_0); + writer_version = "PARQUET_2_0".parse::().unwrap(); + assert_eq!(writer_version, WriterVersion::PARQUET_2_0); + + // test lowercase + writer_version = "parquet_1_0".parse::().unwrap(); + assert_eq!(writer_version, WriterVersion::PARQUET_1_0); + + // test invalid version + match "PARQUET_-1_0".parse::() { + Ok(_) => panic!("Should not be able to parse PARQUET_-1_0"), + Err(e) => { + assert_eq!(e, "Invalid writer version: PARQUET_-1_0"); + } + } + } + + #[test] + fn test_parse_enabledstatistics() { + let mut 
enabled_statistics = "NONE".parse::().unwrap(); + assert_eq!(enabled_statistics, EnabledStatistics::None); + enabled_statistics = "CHUNK".parse::().unwrap(); + assert_eq!(enabled_statistics, EnabledStatistics::Chunk); + enabled_statistics = "PAGE".parse::().unwrap(); + assert_eq!(enabled_statistics, EnabledStatistics::Page); + + // test lowercase + enabled_statistics = "none".parse::().unwrap(); + assert_eq!(enabled_statistics, EnabledStatistics::None); + + //test invalid statistics + match "ChunkAndPage".parse::() { + Ok(_) => panic!("Should not be able to parse ChunkAndPage"), + Err(e) => { + assert_eq!(e, "Invalid statistics arg: ChunkAndPage"); + } + } + } } From 4964d844313d5e62cf102616d26864dca6fe286e Mon Sep 17 00:00:00 2001 From: Raphael Taylor-Davies <1781103+tustvold@users.noreply.github.com> Date: Wed, 18 Oct 2023 14:18:52 +0100 Subject: [PATCH 14/25] Add `ReaderBuilder::with_header` for csv reader (#4949) * Add ReaderBuilder::with_header * Update test --- arrow-csv/examples/csv_calculation.rs | 2 +- arrow-csv/src/reader/mod.rs | 48 ++++++++++++++++----------- arrow/benches/csv_reader.rs | 2 +- parquet/src/bin/parquet-fromcsv.rs | 6 ++-- 4 files changed, 33 insertions(+), 25 deletions(-) diff --git a/arrow-csv/examples/csv_calculation.rs b/arrow-csv/examples/csv_calculation.rs index 12aaadde4415..6ce963e2b012 100644 --- a/arrow-csv/examples/csv_calculation.rs +++ b/arrow-csv/examples/csv_calculation.rs @@ -33,7 +33,7 @@ fn main() { Field::new("c4", DataType::Boolean, true), ]); let mut reader = ReaderBuilder::new(Arc::new(csv_schema)) - .has_header(true) + .with_header(true) .build(file) .unwrap(); diff --git a/arrow-csv/src/reader/mod.rs b/arrow-csv/src/reader/mod.rs index 1106b16bc46f..a194b35ffa46 100644 --- a/arrow-csv/src/reader/mod.rs +++ b/arrow-csv/src/reader/mod.rs @@ -225,7 +225,7 @@ impl InferredDataType { /// The format specification for the CSV file #[derive(Debug, Clone, Default)] pub struct Format { - has_header: bool, + header: bool, 
delimiter: Option, escape: Option, quote: Option, @@ -235,7 +235,7 @@ pub struct Format { impl Format { pub fn with_header(mut self, has_header: bool) -> Self { - self.has_header = has_header; + self.header = has_header; self } @@ -280,7 +280,7 @@ impl Format { // get or create header names // when has_header is false, creates default column names with column_ prefix - let headers: Vec = if self.has_header { + let headers: Vec = if self.header { let headers = &csv_reader.headers().map_err(map_csv_error)?.clone(); headers.iter().map(|s| s.to_string()).collect() } else { @@ -331,7 +331,7 @@ impl Format { /// Build a [`csv::Reader`] for this [`Format`] fn build_reader(&self, reader: R) -> csv::Reader { let mut builder = csv::ReaderBuilder::new(); - builder.has_headers(self.has_header); + builder.has_headers(self.header); if let Some(c) = self.delimiter { builder.delimiter(c); @@ -403,7 +403,7 @@ pub fn infer_reader_schema( ) -> Result<(Schema, usize), ArrowError> { let format = Format { delimiter: Some(delimiter), - has_header, + header: has_header, ..Default::default() }; format.infer_schema(reader, max_read_records) @@ -425,7 +425,7 @@ pub fn infer_schema_from_files( let mut records_to_read = max_read_records.unwrap_or(usize::MAX); let format = Format { delimiter: Some(delimiter), - has_header, + header: has_header, ..Default::default() }; @@ -1095,8 +1095,16 @@ impl ReaderBuilder { } /// Set whether the CSV file has headers + #[deprecated(note = "Use with_header")] + #[doc(hidden)] pub fn has_header(mut self, has_header: bool) -> Self { - self.format.has_header = has_header; + self.format.header = has_header; + self + } + + /// Set whether the CSV file has a header + pub fn with_header(mut self, has_header: bool) -> Self { + self.format.header = has_header; self } @@ -1176,7 +1184,7 @@ impl ReaderBuilder { let delimiter = self.format.build_parser(); let record_decoder = RecordDecoder::new(delimiter, self.schema.fields().len()); - let header = self.format.has_header 
as usize; + let header = self.format.header as usize; let (start, end) = match self.bounds { Some((start, end)) => (start + header, end + header), @@ -1317,7 +1325,7 @@ mod tests { .chain(Cursor::new("\n".to_string())) .chain(file_without_headers); let mut csv = ReaderBuilder::new(Arc::new(schema)) - .has_header(true) + .with_header(true) .build(both_files) .unwrap(); let batch = csv.next().unwrap().unwrap(); @@ -1335,7 +1343,7 @@ mod tests { .unwrap(); file.rewind().unwrap(); - let builder = ReaderBuilder::new(Arc::new(schema)).has_header(true); + let builder = ReaderBuilder::new(Arc::new(schema)).with_header(true); let mut csv = builder.build(file).unwrap(); let expected_schema = Schema::new(vec![ @@ -1505,7 +1513,7 @@ mod tests { let file = File::open("test/data/null_test.csv").unwrap(); let mut csv = ReaderBuilder::new(schema) - .has_header(true) + .with_header(true) .build(file) .unwrap(); @@ -1530,7 +1538,7 @@ mod tests { let file = File::open("test/data/init_null_test.csv").unwrap(); let mut csv = ReaderBuilder::new(schema) - .has_header(true) + .with_header(true) .build(file) .unwrap(); @@ -1588,7 +1596,7 @@ mod tests { let null_regex = Regex::new("^nil$").unwrap(); let mut csv = ReaderBuilder::new(schema) - .has_header(true) + .with_header(true) .with_null_regex(null_regex) .build(file) .unwrap(); @@ -1710,7 +1718,7 @@ mod tests { ]); let builder = ReaderBuilder::new(Arc::new(schema)) - .has_header(true) + .with_header(true) .with_delimiter(b'|') .with_batch_size(512) .with_projection(vec![0, 1, 2, 3]); @@ -2037,7 +2045,7 @@ mod tests { Field::new("text2", DataType::Utf8, false), ]); let builder = ReaderBuilder::new(Arc::new(schema)) - .has_header(false) + .with_header(false) .with_quote(b'~'); // default is ", change to ~ let mut csv_text = Vec::new(); @@ -2069,7 +2077,7 @@ mod tests { Field::new("text2", DataType::Utf8, false), ]); let builder = ReaderBuilder::new(Arc::new(schema)) - .has_header(false) + .with_header(false) .with_escape(b'\\'); // 
default is None, change to \ let mut csv_text = Vec::new(); @@ -2101,7 +2109,7 @@ mod tests { Field::new("text2", DataType::Utf8, false), ]); let builder = ReaderBuilder::new(Arc::new(schema)) - .has_header(false) + .with_header(false) .with_terminator(b'\n'); // default is CRLF, change to LF let mut csv_text = Vec::new(); @@ -2143,7 +2151,7 @@ mod tests { ])); for (idx, (bounds, has_header, expected)) in tests.into_iter().enumerate() { - let mut reader = ReaderBuilder::new(schema.clone()).has_header(has_header); + let mut reader = ReaderBuilder::new(schema.clone()).with_header(has_header); if let Some((start, end)) = bounds { reader = reader.with_bounds(start, end); } @@ -2208,7 +2216,7 @@ mod tests { for capacity in [1, 3, 7, 100] { let reader = ReaderBuilder::new(schema.clone()) .with_batch_size(batch_size) - .has_header(has_header) + .with_header(has_header) .build(File::open(path).unwrap()) .unwrap(); @@ -2226,7 +2234,7 @@ mod tests { let reader = ReaderBuilder::new(schema.clone()) .with_batch_size(batch_size) - .has_header(has_header) + .with_header(has_header) .build_buffered(buffered) .unwrap(); diff --git a/arrow/benches/csv_reader.rs b/arrow/benches/csv_reader.rs index 4c3f663bf741..5a91dfe0a6ff 100644 --- a/arrow/benches/csv_reader.rs +++ b/arrow/benches/csv_reader.rs @@ -45,7 +45,7 @@ fn do_bench(c: &mut Criterion, name: &str, cols: Vec) { let cursor = Cursor::new(buf.as_slice()); let reader = csv::ReaderBuilder::new(batch.schema()) .with_batch_size(batch_size) - .has_header(true) + .with_header(true) .build_buffered(cursor) .unwrap(); diff --git a/parquet/src/bin/parquet-fromcsv.rs b/parquet/src/bin/parquet-fromcsv.rs index 548bbdbfb8f1..1f5d0a62bbfa 100644 --- a/parquet/src/bin/parquet-fromcsv.rs +++ b/parquet/src/bin/parquet-fromcsv.rs @@ -321,7 +321,7 @@ fn configure_reader_builder(args: &Args, arrow_schema: Arc) -> ReaderBui let mut builder = ReaderBuilder::new(arrow_schema) .with_batch_size(args.batch_size) - .has_header(args.has_header) + 
.with_header(args.has_header) .with_delimiter(args.get_delimiter()); builder = configure_reader( @@ -606,7 +606,7 @@ mod tests { let reader_builder = configure_reader_builder(&args, arrow_schema); let builder_debug = format!("{reader_builder:?}"); - assert_debug_text(&builder_debug, "has_header", "false"); + assert_debug_text(&builder_debug, "header", "false"); assert_debug_text(&builder_debug, "delimiter", "Some(44)"); assert_debug_text(&builder_debug, "quote", "Some(34)"); assert_debug_text(&builder_debug, "terminator", "None"); @@ -641,7 +641,7 @@ mod tests { ])); let reader_builder = configure_reader_builder(&args, arrow_schema); let builder_debug = format!("{reader_builder:?}"); - assert_debug_text(&builder_debug, "has_header", "true"); + assert_debug_text(&builder_debug, "header", "true"); assert_debug_text(&builder_debug, "delimiter", "Some(9)"); assert_debug_text(&builder_debug, "quote", "None"); assert_debug_text(&builder_debug, "terminator", "Some(10)"); From 6e332b8f570d53bdc906159a97b2c5f95db670e5 Mon Sep 17 00:00:00 2001 From: Raphael Taylor-Davies <1781103+tustvold@users.noreply.github.com> Date: Wed, 18 Oct 2023 14:20:57 +0100 Subject: [PATCH 15/25] Prepare arrow 48.0.0 (#4948) --- CHANGELOG-old.md | 69 ++++++++++++++++++ CHANGELOG.md | 117 +++++++++++++++++-------------- Cargo.toml | 32 ++++----- dev/release/update_change_log.sh | 4 +- 4 files changed, 151 insertions(+), 71 deletions(-) diff --git a/CHANGELOG-old.md b/CHANGELOG-old.md index bac7847bdac5..cde9b8f3b521 100644 --- a/CHANGELOG-old.md +++ b/CHANGELOG-old.md @@ -19,6 +19,75 @@ # Historical Changelog +## [47.0.0](https://github.com/apache/arrow-rs/tree/47.0.0) (2023-09-19) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/46.0.0...47.0.0) + +**Breaking changes:** + +- Make FixedSizeBinaryArray value\_data return a reference [\#4820](https://github.com/apache/arrow-rs/issues/4820) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Update prost to v0.12.1 
[\#4825](https://github.com/apache/arrow-rs/pull/4825) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([tustvold](https://github.com/tustvold)) +- feat: FixedSizeBinaryArray::value\_data return reference [\#4821](https://github.com/apache/arrow-rs/pull/4821) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([wjones127](https://github.com/wjones127)) +- Stateless Row Encoding / Don't Preserve Dictionaries in `RowConverter` \(\#4811\) [\#4819](https://github.com/apache/arrow-rs/pull/4819) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([tustvold](https://github.com/tustvold)) +- fix: entries field is non-nullable [\#4808](https://github.com/apache/arrow-rs/pull/4808) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([wjones127](https://github.com/wjones127)) +- Fix flight sql do put handling, add bind parameter support to FlightSQL cli client [\#4797](https://github.com/apache/arrow-rs/pull/4797) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([suremarc](https://github.com/suremarc)) +- Remove unused dyn\_cmp\_dict feature [\#4766](https://github.com/apache/arrow-rs/pull/4766) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add underlying `std::io::Error` to `IoError` and add `IpcError` variant [\#4726](https://github.com/apache/arrow-rs/pull/4726) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([alexandreyc](https://github.com/alexandreyc)) + +**Implemented enhancements:** + +- Row Format Adapative Block Size [\#4812](https://github.com/apache/arrow-rs/issues/4812) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Stateless Row Conversion 
[\#4811](https://github.com/apache/arrow-rs/issues/4811) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] +- Add option to specify custom null values for CSV reader [\#4794](https://github.com/apache/arrow-rs/issues/4794) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- parquet::record::RowIter cannot be customized with batch\_size and defaults to 1024 [\#4782](https://github.com/apache/arrow-rs/issues/4782) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- `DynScalar` abstraction \(something that makes it easy to create scalar `Datum`s\) [\#4781](https://github.com/apache/arrow-rs/issues/4781) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- `Datum` is not exported as part of `arrow` \(it is only exported in `arrow_array`\) [\#4780](https://github.com/apache/arrow-rs/issues/4780) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- `Scalar` is not exported as part of `arrow` \(it is only exported in `arrow_array`\) [\#4779](https://github.com/apache/arrow-rs/issues/4779) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support IntoPyArrow for impl RecordBatchReader [\#4730](https://github.com/apache/arrow-rs/issues/4730) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Datum Based String Kernels [\#4595](https://github.com/apache/arrow-rs/issues/4595) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] + +**Fixed bugs:** + +- MapArray::new\_from\_strings creates nullable entries field [\#4807](https://github.com/apache/arrow-rs/issues/4807) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- pyarrow module can't roundtrip tensor arrays [\#4805](https://github.com/apache/arrow-rs/issues/4805) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- `concat_batches` errors with "schema mismatch" error when only 
metadata differs [\#4799](https://github.com/apache/arrow-rs/issues/4799) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- panic in `cmp` kernels with DictionaryArrays: `Option::unwrap()` on a `None` value' [\#4788](https://github.com/apache/arrow-rs/issues/4788) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- stream ffi panics if schema metadata values aren't valid utf8 [\#4750](https://github.com/apache/arrow-rs/issues/4750) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Regression: Incorrect Sorting of `*ListArray` in 46.0.0 [\#4746](https://github.com/apache/arrow-rs/issues/4746) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Row is no longer comparable after reuse [\#4741](https://github.com/apache/arrow-rs/issues/4741) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- DoPut FlightSQL handler inadvertently consumes schema at start of Request\\> [\#4658](https://github.com/apache/arrow-rs/issues/4658) +- Return error when converting schema [\#4752](https://github.com/apache/arrow-rs/pull/4752) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([wjones127](https://github.com/wjones127)) +- Implement PyArrowType for `Box` [\#4751](https://github.com/apache/arrow-rs/pull/4751) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([wjones127](https://github.com/wjones127)) + +**Closed issues:** + +- Building arrow-rust for target wasm32-wasi falied to compile packed\_simd\_2 [\#4717](https://github.com/apache/arrow-rs/issues/4717) + +**Merged pull requests:** + +- Respect FormatOption::nulls for NullArray [\#4836](https://github.com/apache/arrow-rs/pull/4836) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Fix merge\_dictionary\_values in selection kernels [\#4833](https://github.com/apache/arrow-rs/pull/4833) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Fix like scalar 
null [\#4832](https://github.com/apache/arrow-rs/pull/4832) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- More chrono deprecations [\#4822](https://github.com/apache/arrow-rs/pull/4822) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Adaptive Row Block Size \(\#4812\) [\#4818](https://github.com/apache/arrow-rs/pull/4818) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Update proc-macro2 requirement from =1.0.66 to =1.0.67 [\#4816](https://github.com/apache/arrow-rs/pull/4816) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Do not check schema for equality in concat\_batches [\#4815](https://github.com/apache/arrow-rs/pull/4815) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- fix: export record batch through stream [\#4806](https://github.com/apache/arrow-rs/pull/4806) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([wjones127](https://github.com/wjones127)) +- Improve CSV Reader Benchmark Coverage of Small Primitives [\#4803](https://github.com/apache/arrow-rs/pull/4803) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- csv: Add option to specify custom null values [\#4795](https://github.com/apache/arrow-rs/pull/4795) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([vrongmeal](https://github.com/vrongmeal)) +- Expand docstring and add example to `Scalar` [\#4793](https://github.com/apache/arrow-rs/pull/4793) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Re-export array crate root \(\#4780\) \(\#4779\) [\#4791](https://github.com/apache/arrow-rs/pull/4791) 
[[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Fix DictionaryArray::normalized\_keys \(\#4788\) [\#4789](https://github.com/apache/arrow-rs/pull/4789) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Allow custom tree builder for parquet::record::RowIter [\#4783](https://github.com/apache/arrow-rs/pull/4783) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([YuraKotov](https://github.com/YuraKotov)) +- Bump actions/checkout from 3 to 4 [\#4767](https://github.com/apache/arrow-rs/pull/4767) ([dependabot[bot]](https://github.com/apps/dependabot)) +- fix: avoid panic if offset index not exists. [\#4761](https://github.com/apache/arrow-rs/pull/4761) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([RinChanNOWWW](https://github.com/RinChanNOWWW)) +- Relax constraints on PyArrowType [\#4757](https://github.com/apache/arrow-rs/pull/4757) ([tustvold](https://github.com/tustvold)) +- Chrono deprecations [\#4748](https://github.com/apache/arrow-rs/pull/4748) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Fix List Sorting, Revert Removal of Rank Kernels [\#4747](https://github.com/apache/arrow-rs/pull/4747) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Clear row buffer before reuse [\#4742](https://github.com/apache/arrow-rs/pull/4742) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([yjshen](https://github.com/yjshen)) +- Datum based like kernels \(\#4595\) [\#4732](https://github.com/apache/arrow-rs/pull/4732) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([tustvold](https://github.com/tustvold)) +- feat: expose DoGet response headers & trailers [\#4727](https://github.com/apache/arrow-rs/pull/4727) 
[[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([crepererum](https://github.com/crepererum)) +- Cleanup length and bit\_length kernels [\#4718](https://github.com/apache/arrow-rs/pull/4718) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) ## [46.0.0](https://github.com/apache/arrow-rs/tree/46.0.0) (2023-08-21) [Full Changelog](https://github.com/apache/arrow-rs/compare/45.0.0...46.0.0) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1f97055a9c0d..8c5351708c0b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,75 +19,86 @@ # Changelog -## [47.0.0](https://github.com/apache/arrow-rs/tree/47.0.0) (2023-09-19) +## [48.0.0](https://github.com/apache/arrow-rs/tree/48.0.0) (2023-10-18) -[Full Changelog](https://github.com/apache/arrow-rs/compare/46.0.0...47.0.0) +[Full Changelog](https://github.com/apache/arrow-rs/compare/47.0.0...48.0.0) **Breaking changes:** -- Make FixedSizeBinaryArray value\_data return a reference [\#4820](https://github.com/apache/arrow-rs/issues/4820) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Update prost to v0.12.1 [\#4825](https://github.com/apache/arrow-rs/pull/4825) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([tustvold](https://github.com/tustvold)) -- feat: FixedSizeBinaryArray::value\_data return reference [\#4821](https://github.com/apache/arrow-rs/pull/4821) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([wjones127](https://github.com/wjones127)) -- Stateless Row Encoding / Don't Preserve Dictionaries in `RowConverter` \(\#4811\) [\#4819](https://github.com/apache/arrow-rs/pull/4819) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([tustvold](https://github.com/tustvold)) -- fix: entries field is 
non-nullable [\#4808](https://github.com/apache/arrow-rs/pull/4808) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([wjones127](https://github.com/wjones127)) -- Fix flight sql do put handling, add bind parameter support to FlightSQL cli client [\#4797](https://github.com/apache/arrow-rs/pull/4797) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([suremarc](https://github.com/suremarc)) -- Remove unused dyn\_cmp\_dict feature [\#4766](https://github.com/apache/arrow-rs/pull/4766) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) -- Add underlying `std::io::Error` to `IoError` and add `IpcError` variant [\#4726](https://github.com/apache/arrow-rs/pull/4726) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([alexandreyc](https://github.com/alexandreyc)) +- Evaluate null\_regex for string type in csv \(now such values will be parsed as `Null` rather than `""`\) [\#4942](https://github.com/apache/arrow-rs/pull/4942) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([haohuaijin](https://github.com/haohuaijin)) +- fix\(csv\)!: infer null for empty column. 
[\#4910](https://github.com/apache/arrow-rs/pull/4910) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([kskalski](https://github.com/kskalski)) +- feat: log headers/trailers in flight CLI \(+ minor fixes\) [\#4898](https://github.com/apache/arrow-rs/pull/4898) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([crepererum](https://github.com/crepererum)) +- fix\(arrow-json\)!: include null fields in schema inference with a type of Null [\#4894](https://github.com/apache/arrow-rs/pull/4894) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([kskalski](https://github.com/kskalski)) +- Mark OnCloseRowGroup Send [\#4893](https://github.com/apache/arrow-rs/pull/4893) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([devinjdangelo](https://github.com/devinjdangelo)) +- Specialize Thrift Decoding \(~40% Faster\) \(\#4891\) [\#4892](https://github.com/apache/arrow-rs/pull/4892) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Make ArrowRowGroupWriter Public and SerializedRowGroupWriter Send [\#4850](https://github.com/apache/arrow-rs/pull/4850) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([devinjdangelo](https://github.com/devinjdangelo)) **Implemented enhancements:** -- Row Format Adapative Block Size [\#4812](https://github.com/apache/arrow-rs/issues/4812) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Stateless Row Conversion [\#4811](https://github.com/apache/arrow-rs/issues/4811) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] -- Add option to specify custom null values for CSV reader [\#4794](https://github.com/apache/arrow-rs/issues/4794) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- parquet::record::RowIter cannot be customized with batch\_size and defaults to 
1024 [\#4782](https://github.com/apache/arrow-rs/issues/4782) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] -- `DynScalar` abstraction \(something that makes it easy to create scalar `Datum`s\) [\#4781](https://github.com/apache/arrow-rs/issues/4781) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- `Datum` is not exported as part of `arrow` \(it is only exported in `arrow_array`\) [\#4780](https://github.com/apache/arrow-rs/issues/4780) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- `Scalar` is not exported as part of `arrow` \(it is only exported in `arrow_array`\) [\#4779](https://github.com/apache/arrow-rs/issues/4779) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Support IntoPyArrow for impl RecordBatchReader [\#4730](https://github.com/apache/arrow-rs/issues/4730) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Datum Based String Kernels [\#4595](https://github.com/apache/arrow-rs/issues/4595) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] +- Allow schema fields to merge with `Null` datatype [\#4901](https://github.com/apache/arrow-rs/issues/4901) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add option to FlightDataEncoder to always send dictionaries [\#4895](https://github.com/apache/arrow-rs/issues/4895) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] +- Rework Thrift Encoding / Decoding of Parquet Metadata [\#4891](https://github.com/apache/arrow-rs/issues/4891) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Plans for supporting Extension Array to support Fixed shape tensor Array [\#4890](https://github.com/apache/arrow-rs/issues/4890) +- Implement Take for UnionArray [\#4882](https://github.com/apache/arrow-rs/issues/4882) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] 
+- Check precision overflow for casting floating to decimal [\#4865](https://github.com/apache/arrow-rs/issues/4865) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Replace lexical [\#4774](https://github.com/apache/arrow-rs/issues/4774) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add read access to settings in `csv::WriterBuilder` [\#4735](https://github.com/apache/arrow-rs/issues/4735) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Improve the performance of "DictionaryValue" row encoding [\#4712](https://github.com/apache/arrow-rs/issues/4712) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] **Fixed bugs:** -- MapArray::new\_from\_strings creates nullable entries field [\#4807](https://github.com/apache/arrow-rs/issues/4807) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- pyarrow module can't roundtrip tensor arrays [\#4805](https://github.com/apache/arrow-rs/issues/4805) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- `concat_batches` errors with "schema mismatch" error when only metadata differs [\#4799](https://github.com/apache/arrow-rs/issues/4799) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- panic in `cmp` kernels with DictionaryArrays: `Option::unwrap()` on a `None` value' [\#4788](https://github.com/apache/arrow-rs/issues/4788) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- stream ffi panics if schema metadata values aren't valid utf8 [\#4750](https://github.com/apache/arrow-rs/issues/4750) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Regression: Incorrect Sorting of `*ListArray` in 46.0.0 [\#4746](https://github.com/apache/arrow-rs/issues/4746) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Row is no longer comparable after reuse [\#4741](https://github.com/apache/arrow-rs/issues/4741) 
[[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- DoPut FlightSQL handler inadvertently consumes schema at start of Request\\> [\#4658](https://github.com/apache/arrow-rs/issues/4658) -- Return error when converting schema [\#4752](https://github.com/apache/arrow-rs/pull/4752) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([wjones127](https://github.com/wjones127)) -- Implement PyArrowType for `Box` [\#4751](https://github.com/apache/arrow-rs/pull/4751) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([wjones127](https://github.com/wjones127)) +- Should we make blank values and empty string to `None` in csv? [\#4939](https://github.com/apache/arrow-rs/issues/4939) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- \[FlightSQL\] SubstraitPlan structure is not exported [\#4932](https://github.com/apache/arrow-rs/issues/4932) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] +- Loading page index breaks skipping of pages with nested types [\#4921](https://github.com/apache/arrow-rs/issues/4921) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- CSV schema inference assumes `Utf8` for empty columns [\#4903](https://github.com/apache/arrow-rs/issues/4903) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- parquet: Field Ids are not read from a Parquet file without serialized arrow schema [\#4877](https://github.com/apache/arrow-rs/issues/4877) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- make\_primitive\_scalar function loses DataType Internal information [\#4851](https://github.com/apache/arrow-rs/issues/4851) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- StructBuilder doesn't handle nulls correctly for empty structs [\#4842](https://github.com/apache/arrow-rs/issues/4842) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- `NullArray::is_null()` returns `false` 
incorrectly [\#4835](https://github.com/apache/arrow-rs/issues/4835) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- cast\_string\_to\_decimal should check precision overflow [\#4829](https://github.com/apache/arrow-rs/issues/4829) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Null fields are omitted by `infer_json_schema_from_seekable` [\#4814](https://github.com/apache/arrow-rs/issues/4814) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] **Closed issues:** -- Building arrow-rust for target wasm32-wasi falied to compile packed\_simd\_2 [\#4717](https://github.com/apache/arrow-rs/issues/4717) +- Support for reading JSON Array to Arrow [\#4905](https://github.com/apache/arrow-rs/issues/4905) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] **Merged pull requests:** -- Respect FormatOption::nulls for NullArray [\#4836](https://github.com/apache/arrow-rs/pull/4836) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) -- Fix merge\_dictionary\_values in selection kernels [\#4833](https://github.com/apache/arrow-rs/pull/4833) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) -- Fix like scalar null [\#4832](https://github.com/apache/arrow-rs/pull/4832) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) -- More chrono deprecations [\#4822](https://github.com/apache/arrow-rs/pull/4822) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) -- Adaptive Row Block Size \(\#4812\) [\#4818](https://github.com/apache/arrow-rs/pull/4818) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) -- Update proc-macro2 requirement from =1.0.66 to =1.0.67 [\#4816](https://github.com/apache/arrow-rs/pull/4816) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] 
[[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([dependabot[bot]](https://github.com/apps/dependabot)) -- Do not check schema for equality in concat\_batches [\#4815](https://github.com/apache/arrow-rs/pull/4815) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) -- fix: export record batch through stream [\#4806](https://github.com/apache/arrow-rs/pull/4806) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([wjones127](https://github.com/wjones127)) -- Improve CSV Reader Benchmark Coverage of Small Primitives [\#4803](https://github.com/apache/arrow-rs/pull/4803) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) -- csv: Add option to specify custom null values [\#4795](https://github.com/apache/arrow-rs/pull/4795) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([vrongmeal](https://github.com/vrongmeal)) -- Expand docstring and add example to `Scalar` [\#4793](https://github.com/apache/arrow-rs/pull/4793) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) -- Re-export array crate root \(\#4780\) \(\#4779\) [\#4791](https://github.com/apache/arrow-rs/pull/4791) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) -- Fix DictionaryArray::normalized\_keys \(\#4788\) [\#4789](https://github.com/apache/arrow-rs/pull/4789) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) -- Allow custom tree builder for parquet::record::RowIter [\#4783](https://github.com/apache/arrow-rs/pull/4783) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([YuraKotov](https://github.com/YuraKotov)) -- Bump actions/checkout from 3 to 4 [\#4767](https://github.com/apache/arrow-rs/pull/4767) ([dependabot[bot]](https://github.com/apps/dependabot)) -- fix: avoid panic if offset index not exists. 
[\#4761](https://github.com/apache/arrow-rs/pull/4761) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([RinChanNOWWW](https://github.com/RinChanNOWWW)) -- Relax constraints on PyArrowType [\#4757](https://github.com/apache/arrow-rs/pull/4757) ([tustvold](https://github.com/tustvold)) -- Chrono deprecations [\#4748](https://github.com/apache/arrow-rs/pull/4748) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) -- Fix List Sorting, Revert Removal of Rank Kernels [\#4747](https://github.com/apache/arrow-rs/pull/4747) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) -- Clear row buffer before reuse [\#4742](https://github.com/apache/arrow-rs/pull/4742) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([yjshen](https://github.com/yjshen)) -- Datum based like kernels \(\#4595\) [\#4732](https://github.com/apache/arrow-rs/pull/4732) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([tustvold](https://github.com/tustvold)) -- feat: expose DoGet response headers & trailers [\#4727](https://github.com/apache/arrow-rs/pull/4727) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([crepererum](https://github.com/crepererum)) -- Cleanup length and bit\_length kernels [\#4718](https://github.com/apache/arrow-rs/pull/4718) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Assume Pages Delimit Records When Offset Index Loaded \(\#4921\) [\#4943](https://github.com/apache/arrow-rs/pull/4943) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Update pyo3 requirement from 0.19 to 0.20 [\#4941](https://github.com/apache/arrow-rs/pull/4941) 
[[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([crepererum](https://github.com/crepererum)) +- Add `FileWriter` schema getter [\#4940](https://github.com/apache/arrow-rs/pull/4940) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([haixuanTao](https://github.com/haixuanTao)) +- feat: support parsing for parquet writer option [\#4938](https://github.com/apache/arrow-rs/pull/4938) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([fansehep](https://github.com/fansehep)) +- Export `SubstraitPlan` structure in arrow\_flight::sql \(\#4932\) [\#4933](https://github.com/apache/arrow-rs/pull/4933) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([amartins23](https://github.com/amartins23)) +- Update zstd requirement from 0.12.0 to 0.13.0 [\#4923](https://github.com/apache/arrow-rs/pull/4923) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- feat: add method for async read bloom filter [\#4917](https://github.com/apache/arrow-rs/pull/4917) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([hengfeiyang](https://github.com/hengfeiyang)) +- Minor: Clarify rationale for `FlightDataEncoder` API, add examples [\#4916](https://github.com/apache/arrow-rs/pull/4916) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([alamb](https://github.com/alamb)) +- Update regex-syntax requirement from 0.7.1 to 0.8.0 [\#4914](https://github.com/apache/arrow-rs/pull/4914) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- feat: document & streamline flight SQL CLI [\#4912](https://github.com/apache/arrow-rs/pull/4912) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] 
[[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([crepererum](https://github.com/crepererum)) +- Support Arbitrary JSON values in JSON Reader \(\#4905\) [\#4911](https://github.com/apache/arrow-rs/pull/4911) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Cleanup CSV WriterBuilder, Default to AutoSI Second Precision \(\#4735\) [\#4909](https://github.com/apache/arrow-rs/pull/4909) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Update proc-macro2 requirement from =1.0.68 to =1.0.69 [\#4907](https://github.com/apache/arrow-rs/pull/4907) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- chore: add csv example [\#4904](https://github.com/apache/arrow-rs/pull/4904) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([fansehep](https://github.com/fansehep)) +- feat\(schema\): allow null fields to be merged with other datatypes [\#4902](https://github.com/apache/arrow-rs/pull/4902) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([kskalski](https://github.com/kskalski)) +- Update proc-macro2 requirement from =1.0.67 to =1.0.68 [\#4900](https://github.com/apache/arrow-rs/pull/4900) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Add option to `FlightDataEncoder` to always resend batch dictionaries [\#4896](https://github.com/apache/arrow-rs/pull/4896) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([alexwilcoxson-rel](https://github.com/alexwilcoxson-rel)) +- Fix integration tests [\#4889](https://github.com/apache/arrow-rs/pull/4889) 
([tustvold](https://github.com/tustvold)) +- Support Parsing Avro File Headers [\#4888](https://github.com/apache/arrow-rs/pull/4888) ([tustvold](https://github.com/tustvold)) +- Support parquet bloom filter length [\#4885](https://github.com/apache/arrow-rs/pull/4885) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([letian-jiang](https://github.com/letian-jiang)) +- Replace lz4 with lz4\_flex Allowing Compilation for WASM [\#4884](https://github.com/apache/arrow-rs/pull/4884) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Implement Take for UnionArray [\#4883](https://github.com/apache/arrow-rs/pull/4883) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([avantgardnerio](https://github.com/avantgardnerio)) +- Update tonic-build requirement from =0.10.1 to =0.10.2 [\#4881](https://github.com/apache/arrow-rs/pull/4881) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- parquet: Read field IDs from Parquet Schema [\#4878](https://github.com/apache/arrow-rs/pull/4878) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([Samrose-Ahmed](https://github.com/Samrose-Ahmed)) +- feat: improve flight CLI error handling [\#4873](https://github.com/apache/arrow-rs/pull/4873) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([crepererum](https://github.com/crepererum)) +- Support Encoding Parquet Columns in Parallel [\#4871](https://github.com/apache/arrow-rs/pull/4871) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Check precision overflow for casting floating to decimal [\#4866](https://github.com/apache/arrow-rs/pull/4866) 
[[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Make align\_buffers as public API [\#4863](https://github.com/apache/arrow-rs/pull/4863) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Enable new integration tests \(\#4828\) [\#4862](https://github.com/apache/arrow-rs/pull/4862) ([tustvold](https://github.com/tustvold)) +- Faster Serde Integration \(~80% faster\) [\#4861](https://github.com/apache/arrow-rs/pull/4861) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- fix: make\_primitive\_scalar bug [\#4852](https://github.com/apache/arrow-rs/pull/4852) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([JasonLi-cn](https://github.com/JasonLi-cn)) +- Update tonic-build requirement from =0.10.0 to =0.10.1 [\#4846](https://github.com/apache/arrow-rs/pull/4846) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Allow Constructing Non-Empty StructArray with no Fields \(\#4842\) [\#4845](https://github.com/apache/arrow-rs/pull/4845) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Refine documentation to `Array::is_null` [\#4838](https://github.com/apache/arrow-rs/pull/4838) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- fix: add missing precision overflow checking for `cast_string_to_decimal` [\#4830](https://github.com/apache/arrow-rs/pull/4830) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jonahgao](https://github.com/jonahgao)) diff --git a/Cargo.toml b/Cargo.toml index d874e335eeae..d59a5af68a19 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -62,7 +62,7 @@ exclude = [ ] [workspace.package] -version = "47.0.0" +version = "48.0.0" homepage = 
"https://github.com/apache/arrow-rs" repository = "https://github.com/apache/arrow-rs" authors = ["Apache Arrow "] @@ -77,20 +77,20 @@ edition = "2021" rust-version = "1.62" [workspace.dependencies] -arrow = { version = "47.0.0", path = "./arrow", default-features = false } -arrow-arith = { version = "47.0.0", path = "./arrow-arith" } -arrow-array = { version = "47.0.0", path = "./arrow-array" } -arrow-buffer = { version = "47.0.0", path = "./arrow-buffer" } -arrow-cast = { version = "47.0.0", path = "./arrow-cast" } -arrow-csv = { version = "47.0.0", path = "./arrow-csv" } -arrow-data = { version = "47.0.0", path = "./arrow-data" } -arrow-ipc = { version = "47.0.0", path = "./arrow-ipc" } -arrow-json = { version = "47.0.0", path = "./arrow-json" } -arrow-ord = { version = "47.0.0", path = "./arrow-ord" } -arrow-row = { version = "47.0.0", path = "./arrow-row" } -arrow-schema = { version = "47.0.0", path = "./arrow-schema" } -arrow-select = { version = "47.0.0", path = "./arrow-select" } -arrow-string = { version = "47.0.0", path = "./arrow-string" } -parquet = { version = "47.0.0", path = "./parquet", default-features = false } +arrow = { version = "48.0.0", path = "./arrow", default-features = false } +arrow-arith = { version = "48.0.0", path = "./arrow-arith" } +arrow-array = { version = "48.0.0", path = "./arrow-array" } +arrow-buffer = { version = "48.0.0", path = "./arrow-buffer" } +arrow-cast = { version = "48.0.0", path = "./arrow-cast" } +arrow-csv = { version = "48.0.0", path = "./arrow-csv" } +arrow-data = { version = "48.0.0", path = "./arrow-data" } +arrow-ipc = { version = "48.0.0", path = "./arrow-ipc" } +arrow-json = { version = "48.0.0", path = "./arrow-json" } +arrow-ord = { version = "48.0.0", path = "./arrow-ord" } +arrow-row = { version = "48.0.0", path = "./arrow-row" } +arrow-schema = { version = "48.0.0", path = "./arrow-schema" } +arrow-select = { version = "48.0.0", path = "./arrow-select" } +arrow-string = { version = "48.0.0", path = 
"./arrow-string" } +parquet = { version = "48.0.0", path = "./parquet", default-features = false } chrono = { version = "0.4.31", default-features = false, features = ["clock"] } diff --git a/dev/release/update_change_log.sh b/dev/release/update_change_log.sh index 74bbb4ac1e8d..c1627ebb8cf2 100755 --- a/dev/release/update_change_log.sh +++ b/dev/release/update_change_log.sh @@ -29,8 +29,8 @@ set -e -SINCE_TAG="46.0.0" -FUTURE_RELEASE="47.0.0" +SINCE_TAG="47.0.0" +FUTURE_RELEASE="48.0.0" SOURCE_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" SOURCE_TOP_DIR="$(cd "${SOURCE_DIR}/../../" && pwd)" From 51ac6fec8755147cd6b1dfe7d76bfdcfacad0463 Mon Sep 17 00:00:00 2001 From: Raphael Taylor-Davies <1781103+tustvold@users.noreply.github.com> Date: Wed, 18 Oct 2023 14:52:36 +0100 Subject: [PATCH 16/25] Respect ARROW_TEST_DATA in apache-avro tests (#4950) --- arrow-avro/src/lib.rs | 10 ++++++++++ arrow-avro/src/reader/header.rs | 5 +++-- arrow-avro/src/reader/mod.rs | 3 ++- 3 files changed, 15 insertions(+), 3 deletions(-) diff --git a/arrow-avro/src/lib.rs b/arrow-avro/src/lib.rs index e134d9d798f2..c76ecb399a45 100644 --- a/arrow-avro/src/lib.rs +++ b/arrow-avro/src/lib.rs @@ -26,3 +26,13 @@ pub mod reader; mod schema; mod compression; + +#[cfg(test)] +mod test_util { + pub fn arrow_test_data(path: &str) -> String { + match std::env::var("ARROW_TEST_DATA") { + Ok(dir) => format!("{dir}/{path}"), + Err(_) => format!("../testing/data/{path}"), + } + } +} diff --git a/arrow-avro/src/reader/header.rs b/arrow-avro/src/reader/header.rs index 92db8b1dc76d..2d443175a7aa 100644 --- a/arrow-avro/src/reader/header.rs +++ b/arrow-avro/src/reader/header.rs @@ -240,6 +240,7 @@ mod test { use super::*; use crate::reader::read_header; use crate::schema::SCHEMA_METADATA_KEY; + use crate::test_util::arrow_test_data; use std::fs::File; use std::io::{BufRead, BufReader}; @@ -266,7 +267,7 @@ mod test { #[test] fn test_header() { - let header = 
decode_file("../testing/data/avro/alltypes_plain.avro"); + let header = decode_file(&arrow_test_data("avro/alltypes_plain.avro")); let schema_json = header.get(SCHEMA_METADATA_KEY).unwrap(); let expected = br#"{"type":"record","name":"topLevelRecord","fields":[{"name":"id","type":["int","null"]},{"name":"bool_col","type":["boolean","null"]},{"name":"tinyint_col","type":["int","null"]},{"name":"smallint_col","type":["int","null"]},{"name":"int_col","type":["int","null"]},{"name":"bigint_col","type":["long","null"]},{"name":"float_col","type":["float","null"]},{"name":"double_col","type":["double","null"]},{"name":"date_string_col","type":["bytes","null"]},{"name":"string_col","type":["bytes","null"]},{"name":"timestamp_col","type":[{"type":"long","logicalType":"timestamp-micros"},"null"]}]}"#; assert_eq!(schema_json, expected); @@ -276,7 +277,7 @@ mod test { 226966037233754408753420635932530907102 ); - let header = decode_file("../testing/data/avro/fixed_length_decimal.avro"); + let header = decode_file(&arrow_test_data("avro/fixed_length_decimal.avro")); let schema_json = header.get(SCHEMA_METADATA_KEY).unwrap(); let expected = br#"{"type":"record","name":"topLevelRecord","fields":[{"name":"value","type":[{"type":"fixed","name":"fixed","namespace":"topLevelRecord.value","size":11,"logicalType":"decimal","precision":25,"scale":2},"null"]}]}"#; assert_eq!(schema_json, expected); diff --git a/arrow-avro/src/reader/mod.rs b/arrow-avro/src/reader/mod.rs index a42011e3b2ad..91e2dbf9835b 100644 --- a/arrow-avro/src/reader/mod.rs +++ b/arrow-avro/src/reader/mod.rs @@ -76,12 +76,13 @@ fn read_blocks( #[cfg(test)] mod test { use crate::reader::{read_blocks, read_header}; + use crate::test_util::arrow_test_data; use std::fs::File; use std::io::BufReader; #[test] fn test_mux() { - let file = File::open("../testing/data/avro/alltypes_plain.avro").unwrap(); + let file = File::open(arrow_test_data("avro/alltypes_plain.avro")).unwrap(); let mut reader = BufReader::new(file); let 
header = read_header(&mut reader).unwrap(); for result in read_blocks(reader) { From 4cca0291441fe622f13db6724f8bc3efb1a31b5b Mon Sep 17 00:00:00 2001 From: Raphael Taylor-Davies <1781103+tustvold@users.noreply.github.com> Date: Thu, 19 Oct 2023 09:44:46 +0100 Subject: [PATCH 17/25] Return `PutResult` with an ETag from ObjectStore::put (#4934) (#4944) * Return ETag from ObjectStore::put (#4934) * Further tests * Clippy * Review feedback --- object_store/src/aws/client.rs | 12 +++- object_store/src/aws/mod.rs | 25 ++------ object_store/src/azure/mod.rs | 20 ++++--- object_store/src/chunked.rs | 3 +- object_store/src/client/header.rs | 17 +++--- object_store/src/gcp/mod.rs | 87 +++++++++++----------------- object_store/src/http/client.rs | 4 +- object_store/src/http/mod.rs | 13 ++++- object_store/src/lib.rs | 35 ++++++++++- object_store/src/limit.rs | 4 +- object_store/src/local.rs | 43 ++++++++++---- object_store/src/memory.rs | 14 +++-- object_store/src/prefix.rs | 5 +- object_store/src/throttle.rs | 5 +- object_store/tests/get_range_file.rs | 4 +- 15 files changed, 169 insertions(+), 122 deletions(-) diff --git a/object_store/src/aws/client.rs b/object_store/src/aws/client.rs index 8a45a9f3ac47..eb81e92fb932 100644 --- a/object_store/src/aws/client.rs +++ b/object_store/src/aws/client.rs @@ -21,6 +21,7 @@ use crate::aws::{ AwsCredentialProvider, S3CopyIfNotExists, STORE, STRICT_PATH_ENCODE_SET, }; use crate::client::get::GetClient; +use crate::client::header::get_etag; use crate::client::list::ListClient; use crate::client::list_response::ListResponse; use crate::client::retry::RetryExt; @@ -122,6 +123,11 @@ pub(crate) enum Error { #[snafu(display("Got invalid multipart response: {}", source))] InvalidMultipartResponse { source: quick_xml::de::DeError }, + + #[snafu(display("Unable to extract metadata from headers: {}", source))] + Metadata { + source: crate::client::header::Error, + }, } impl From for crate::Error { @@ -243,12 +249,14 @@ impl S3Client { } /// 
Make an S3 PUT request + /// + /// Returns the ETag pub async fn put_request( &self, path: &Path, bytes: Bytes, query: &T, - ) -> Result { + ) -> Result { let credential = self.get_credential().await?; let url = self.config.path_url(path); let mut builder = self.client.request(Method::PUT, url); @@ -287,7 +295,7 @@ impl S3Client { path: path.as_ref(), })?; - Ok(response) + Ok(get_etag(response.headers()).context(MetadataSnafu)?) } /// Make an S3 Delete request diff --git a/object_store/src/aws/mod.rs b/object_store/src/aws/mod.rs index d3c50861c122..6d5aecea2d17 100644 --- a/object_store/src/aws/mod.rs +++ b/object_store/src/aws/mod.rs @@ -59,7 +59,7 @@ use crate::multipart::{PartId, PutPart, WriteMultiPart}; use crate::signer::Signer; use crate::{ ClientOptions, GetOptions, GetResult, ListResult, MultipartId, ObjectMeta, - ObjectStore, Path, Result, RetryConfig, + ObjectStore, Path, PutResult, Result, RetryConfig, }; mod checksum; @@ -109,12 +109,6 @@ enum Error { #[snafu(display("Missing SecretAccessKey"))] MissingSecretAccessKey, - #[snafu(display("ETag Header missing from response"))] - MissingEtag, - - #[snafu(display("Received header containing non-ASCII data"))] - BadHeader { source: reqwest::header::ToStrError }, - #[snafu(display("Unable parse source url. 
Url: {}, Error: {}", url, source))] UnableToParseUrl { source: url::ParseError, @@ -273,9 +267,9 @@ impl Signer for AmazonS3 { #[async_trait] impl ObjectStore for AmazonS3 { - async fn put(&self, location: &Path, bytes: Bytes) -> Result<()> { - self.client.put_request(location, bytes, &()).await?; - Ok(()) + async fn put(&self, location: &Path, bytes: Bytes) -> Result { + let e_tag = self.client.put_request(location, bytes, &()).await?; + Ok(PutResult { e_tag: Some(e_tag) }) } async fn put_multipart( @@ -365,10 +359,9 @@ struct S3MultiPartUpload { #[async_trait] impl PutPart for S3MultiPartUpload { async fn put_part(&self, buf: Vec, part_idx: usize) -> Result { - use reqwest::header::ETAG; let part = (part_idx + 1).to_string(); - let response = self + let content_id = self .client .put_request( &self.location, @@ -377,13 +370,7 @@ impl PutPart for S3MultiPartUpload { ) .await?; - let etag = response.headers().get(ETAG).context(MissingEtagSnafu)?; - - let etag = etag.to_str().context(BadHeaderSnafu)?; - - Ok(PartId { - content_id: etag.to_string(), - }) + Ok(PartId { content_id }) } async fn complete(&self, completed_parts: Vec) -> Result<()> { diff --git a/object_store/src/azure/mod.rs b/object_store/src/azure/mod.rs index 2a08c6775807..0e638efc399f 100644 --- a/object_store/src/azure/mod.rs +++ b/object_store/src/azure/mod.rs @@ -31,7 +31,7 @@ use crate::{ multipart::{PartId, PutPart, WriteMultiPart}, path::Path, ClientOptions, GetOptions, GetResult, ListResult, MultipartId, ObjectMeta, - ObjectStore, Result, RetryConfig, + ObjectStore, PutResult, Result, RetryConfig, }; use async_trait::async_trait; use base64::prelude::BASE64_STANDARD; @@ -62,6 +62,7 @@ mod credential; /// [`CredentialProvider`] for [`MicrosoftAzure`] pub type AzureCredentialProvider = Arc>; +use crate::client::header::get_etag; pub use credential::AzureCredential; const STORE: &str = "MicrosoftAzure"; @@ -81,9 +82,6 @@ const MSI_ENDPOINT_ENV_KEY: &str = "IDENTITY_ENDPOINT"; #[derive(Debug, 
Snafu)] #[allow(missing_docs)] enum Error { - #[snafu(display("Received header containing non-ASCII data"))] - BadHeader { source: reqwest::header::ToStrError }, - #[snafu(display("Unable parse source url. Url: {}, Error: {}", url, source))] UnableToParseUrl { source: url::ParseError, @@ -126,8 +124,10 @@ enum Error { #[snafu(display("Configuration key: '{}' is not known.", key))] UnknownConfigurationKey { key: String }, - #[snafu(display("ETag Header missing from response"))] - MissingEtag, + #[snafu(display("Unable to extract metadata from headers: {}", source))] + Metadata { + source: crate::client::header::Error, + }, } impl From for super::Error { @@ -170,11 +170,13 @@ impl std::fmt::Display for MicrosoftAzure { #[async_trait] impl ObjectStore for MicrosoftAzure { - async fn put(&self, location: &Path, bytes: Bytes) -> Result<()> { - self.client + async fn put(&self, location: &Path, bytes: Bytes) -> Result { + let response = self + .client .put_request(location, Some(bytes), false, &()) .await?; - Ok(()) + let e_tag = Some(get_etag(response.headers()).context(MetadataSnafu)?); + Ok(PutResult { e_tag }) } async fn put_multipart( diff --git a/object_store/src/chunked.rs b/object_store/src/chunked.rs index d3e02b412725..5694c55d787f 100644 --- a/object_store/src/chunked.rs +++ b/object_store/src/chunked.rs @@ -30,6 +30,7 @@ use tokio::io::AsyncWrite; use crate::path::Path; use crate::{ GetOptions, GetResult, GetResultPayload, ListResult, ObjectMeta, ObjectStore, + PutResult, }; use crate::{MultipartId, Result}; @@ -62,7 +63,7 @@ impl Display for ChunkedStore { #[async_trait] impl ObjectStore for ChunkedStore { - async fn put(&self, location: &Path, bytes: Bytes) -> Result<()> { + async fn put(&self, location: &Path, bytes: Bytes) -> Result { self.inner.put(location, bytes).await } diff --git a/object_store/src/client/header.rs b/object_store/src/client/header.rs index 6499eff5aebe..17f83a2ba8c8 100644 --- a/object_store/src/client/header.rs +++ 
b/object_store/src/client/header.rs @@ -64,6 +64,12 @@ pub enum Error { }, } +/// Extracts an etag from the provided [`HeaderMap`] +pub fn get_etag(headers: &HeaderMap) -> Result { + let e_tag = headers.get(ETAG).ok_or(Error::MissingEtag)?; + Ok(e_tag.to_str().context(BadHeaderSnafu)?.to_string()) +} + /// Extracts [`ObjectMeta`] from the provided [`HeaderMap`] pub fn header_meta( location: &Path, @@ -81,13 +87,10 @@ pub fn header_meta( None => Utc.timestamp_nanos(0), }; - let e_tag = match headers.get(ETAG) { - Some(e_tag) => { - let e_tag = e_tag.to_str().context(BadHeaderSnafu)?; - Some(e_tag.to_string()) - } - None if cfg.etag_required => return Err(Error::MissingEtag), - None => None, + let e_tag = match get_etag(headers) { + Ok(e_tag) => Some(e_tag), + Err(Error::MissingEtag) if !cfg.etag_required => None, + Err(e) => return Err(e), }; let content_length = headers diff --git a/object_store/src/gcp/mod.rs b/object_store/src/gcp/mod.rs index 513e396cbae6..97755c07c671 100644 --- a/object_store/src/gcp/mod.rs +++ b/object_store/src/gcp/mod.rs @@ -54,7 +54,7 @@ use crate::{ multipart::{PartId, PutPart, WriteMultiPart}, path::{Path, DELIMITER}, ClientOptions, GetOptions, GetResult, ListResult, MultipartId, ObjectMeta, - ObjectStore, Result, RetryConfig, + ObjectStore, PutResult, Result, RetryConfig, }; use credential::{InstanceCredentialProvider, ServiceAccountCredentials}; @@ -65,6 +65,7 @@ const STORE: &str = "GCS"; /// [`CredentialProvider`] for [`GoogleCloudStorage`] pub type GcpCredentialProvider = Arc>; +use crate::client::header::get_etag; use crate::gcp::credential::{ApplicationDefaultCredentials, DEFAULT_GCS_BASE_URL}; pub use credential::GcpCredential; @@ -155,11 +156,10 @@ enum Error { #[snafu(display("Configuration key: '{}' is not known.", key))] UnknownConfigurationKey { key: String }, - #[snafu(display("ETag Header missing from response"))] - MissingEtag, - - #[snafu(display("Received header containing non-ASCII data"))] - BadHeader { source: 
header::ToStrError }, + #[snafu(display("Unable to extract metadata from headers: {}", source))] + Metadata { + source: crate::client::header::Error, + }, } impl From for super::Error { @@ -247,7 +247,14 @@ impl GoogleCloudStorageClient { } /// Perform a put request - async fn put_request(&self, path: &Path, payload: Bytes) -> Result<()> { + /// + /// Returns the new ETag + async fn put_request( + &self, + path: &Path, + payload: Bytes, + query: &T, + ) -> Result { let credential = self.get_credential().await?; let url = self.object_url(path); @@ -256,8 +263,10 @@ impl GoogleCloudStorageClient { .get_content_type(path) .unwrap_or("application/octet-stream"); - self.client + let response = self + .client .request(Method::PUT, url) + .query(query) .bearer_auth(&credential.bearer) .header(header::CONTENT_TYPE, content_type) .header(header::CONTENT_LENGTH, payload.len()) @@ -268,7 +277,7 @@ impl GoogleCloudStorageClient { path: path.as_ref(), })?; - Ok(()) + Ok(get_etag(response.headers()).context(MetadataSnafu)?) 
} /// Initiate a multi-part upload @@ -469,7 +478,7 @@ impl ListClient for GoogleCloudStorageClient { struct GCSMultipartUpload { client: Arc, - encoded_path: String, + path: Path, multipart_id: MultipartId, } @@ -478,38 +487,17 @@ impl PutPart for GCSMultipartUpload { /// Upload an object part async fn put_part(&self, buf: Vec, part_idx: usize) -> Result { let upload_id = self.multipart_id.clone(); - let url = format!( - "{}/{}/{}", - self.client.base_url, self.client.bucket_name_encoded, self.encoded_path - ); - - let credential = self.client.get_credential().await?; - - let response = self + let content_id = self .client - .client - .request(Method::PUT, &url) - .bearer_auth(&credential.bearer) - .query(&[ - ("partNumber", format!("{}", part_idx + 1)), - ("uploadId", upload_id), - ]) - .header(header::CONTENT_TYPE, "application/octet-stream") - .header(header::CONTENT_LENGTH, format!("{}", buf.len())) - .body(buf) - .send_retry(&self.client.retry_config) - .await - .context(PutRequestSnafu { - path: &self.encoded_path, - })?; - - let content_id = response - .headers() - .get("ETag") - .context(MissingEtagSnafu)? - .to_str() - .context(BadHeaderSnafu)? 
- .to_string(); + .put_request( + &self.path, + buf.into(), + &[ + ("partNumber", format!("{}", part_idx + 1)), + ("uploadId", upload_id), + ], + ) + .await?; Ok(PartId { content_id }) } @@ -517,10 +505,7 @@ impl PutPart for GCSMultipartUpload { /// Complete a multipart upload async fn complete(&self, completed_parts: Vec) -> Result<()> { let upload_id = self.multipart_id.clone(); - let url = format!( - "{}/{}/{}", - self.client.base_url, self.client.bucket_name_encoded, self.encoded_path - ); + let url = self.client.object_url(&self.path); let parts = completed_parts .into_iter() @@ -550,7 +535,7 @@ impl PutPart for GCSMultipartUpload { .send_retry(&self.client.retry_config) .await .context(PostRequestSnafu { - path: &self.encoded_path, + path: self.path.as_ref(), })?; Ok(()) @@ -559,8 +544,9 @@ impl PutPart for GCSMultipartUpload { #[async_trait] impl ObjectStore for GoogleCloudStorage { - async fn put(&self, location: &Path, bytes: Bytes) -> Result<()> { - self.client.put_request(location, bytes).await + async fn put(&self, location: &Path, bytes: Bytes) -> Result { + let e_tag = self.client.put_request(location, bytes, &()).await?; + Ok(PutResult { e_tag: Some(e_tag) }) } async fn put_multipart( @@ -569,12 +555,9 @@ impl ObjectStore for GoogleCloudStorage { ) -> Result<(MultipartId, Box)> { let upload_id = self.client.multipart_initiate(location).await?; - let encoded_path = - percent_encode(location.to_string().as_bytes(), NON_ALPHANUMERIC).to_string(); - let inner = GCSMultipartUpload { client: Arc::clone(&self.client), - encoded_path, + path: location.clone(), multipart_id: upload_id.clone(), }; diff --git a/object_store/src/http/client.rs b/object_store/src/http/client.rs index b2a6ac0aa34a..4c2a7fcf8db3 100644 --- a/object_store/src/http/client.rs +++ b/object_store/src/http/client.rs @@ -160,7 +160,7 @@ impl Client { Ok(()) } - pub async fn put(&self, location: &Path, bytes: Bytes) -> Result<()> { + pub async fn put(&self, location: &Path, bytes: Bytes) 
-> Result { let mut retry = false; loop { let url = self.path_url(location); @@ -170,7 +170,7 @@ impl Client { } match builder.send_retry(&self.retry_config).await { - Ok(_) => return Ok(()), + Ok(response) => return Ok(response), Err(source) => match source.status() { // Some implementations return 404 instead of 409 Some(StatusCode::CONFLICT | StatusCode::NOT_FOUND) if !retry => { diff --git a/object_store/src/http/mod.rs b/object_store/src/http/mod.rs index 2fd7850b6bbf..e41e4f990110 100644 --- a/object_store/src/http/mod.rs +++ b/object_store/src/http/mod.rs @@ -41,11 +41,12 @@ use tokio::io::AsyncWrite; use url::Url; use crate::client::get::GetClientExt; +use crate::client::header::get_etag; use crate::http::client::Client; use crate::path::Path; use crate::{ ClientConfigKey, ClientOptions, GetOptions, GetResult, ListResult, MultipartId, - ObjectMeta, ObjectStore, Result, RetryConfig, + ObjectMeta, ObjectStore, PutResult, Result, RetryConfig, }; mod client; @@ -95,8 +96,14 @@ impl std::fmt::Display for HttpStore { #[async_trait] impl ObjectStore for HttpStore { - async fn put(&self, location: &Path, bytes: Bytes) -> Result<()> { - self.client.put(location, bytes).await + async fn put(&self, location: &Path, bytes: Bytes) -> Result { + let response = self.client.put(location, bytes).await?; + let e_tag = match get_etag(response.headers()) { + Ok(e_tag) => Some(e_tag), + Err(crate::client::header::Error::MissingEtag) => None, + Err(source) => return Err(Error::Metadata { source }.into()), + }; + Ok(PutResult { e_tag }) } async fn put_multipart( diff --git a/object_store/src/lib.rs b/object_store/src/lib.rs index 9b396444fa0d..018f0f5e8dec 100644 --- a/object_store/src/lib.rs +++ b/object_store/src/lib.rs @@ -300,7 +300,7 @@ pub trait ObjectStore: std::fmt::Display + Send + Sync + Debug + 'static { /// The operation is guaranteed to be atomic, it will either successfully /// write the entirety of `bytes` to `location`, or fail. 
No clients /// should be able to observe a partially written object - async fn put(&self, location: &Path, bytes: Bytes) -> Result<()>; + async fn put(&self, location: &Path, bytes: Bytes) -> Result; /// Get a multi-part upload that allows writing data in chunks /// @@ -528,7 +528,7 @@ macro_rules! as_ref_impl { ($type:ty) => { #[async_trait] impl ObjectStore for $type { - async fn put(&self, location: &Path, bytes: Bytes) -> Result<()> { + async fn put(&self, location: &Path, bytes: Bytes) -> Result { self.as_ref().put(location, bytes).await } @@ -659,6 +659,8 @@ pub struct ObjectMeta { /// The size in bytes of the object pub size: usize, /// The unique identifier for the object + /// + /// pub e_tag: Option, } @@ -850,6 +852,15 @@ impl GetResult { } } +/// Result for a put request +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct PutResult { + /// The unique identifier for the object + /// + /// + pub e_tag: Option, +} + /// A specialized `Result` for object store-related errors pub type Result = std::result::Result; @@ -1383,6 +1394,26 @@ mod tests { ..GetOptions::default() }; storage.get_opts(&path, options).await.unwrap(); + + let result = storage.put(&path, "test".into()).await.unwrap(); + let new_tag = result.e_tag.unwrap(); + assert_ne!(tag, new_tag); + + let meta = storage.head(&path).await.unwrap(); + assert_eq!(meta.e_tag.unwrap(), new_tag); + + let options = GetOptions { + if_match: Some(new_tag), + ..GetOptions::default() + }; + storage.get_opts(&path, options).await.unwrap(); + + let options = GetOptions { + if_match: Some(tag), + ..GetOptions::default() + }; + let err = storage.get_opts(&path, options).await.unwrap_err(); + assert!(matches!(err, Error::Precondition { .. 
}), "{err}"); } /// Returns a chunk of length `chunk_length` diff --git a/object_store/src/limit.rs b/object_store/src/limit.rs index 00cbce023c3d..8a453813c24e 100644 --- a/object_store/src/limit.rs +++ b/object_store/src/limit.rs @@ -19,7 +19,7 @@ use crate::{ BoxStream, GetOptions, GetResult, GetResultPayload, ListResult, MultipartId, - ObjectMeta, ObjectStore, Path, Result, StreamExt, + ObjectMeta, ObjectStore, Path, PutResult, Result, StreamExt, }; use async_trait::async_trait; use bytes::Bytes; @@ -72,7 +72,7 @@ impl std::fmt::Display for LimitStore { #[async_trait] impl ObjectStore for LimitStore { - async fn put(&self, location: &Path, bytes: Bytes) -> Result<()> { + async fn put(&self, location: &Path, bytes: Bytes) -> Result { let _permit = self.semaphore.acquire().await.unwrap(); self.inner.put(location, bytes).await } diff --git a/object_store/src/local.rs b/object_store/src/local.rs index 38467c3a9e7c..4b7c96346e4d 100644 --- a/object_store/src/local.rs +++ b/object_store/src/local.rs @@ -20,7 +20,7 @@ use crate::{ maybe_spawn_blocking, path::{absolute_path_to_url, Path}, GetOptions, GetResult, GetResultPayload, ListResult, MultipartId, ObjectMeta, - ObjectStore, Result, + ObjectStore, PutResult, Result, }; use async_trait::async_trait; use bytes::Bytes; @@ -36,6 +36,7 @@ use std::ops::Range; use std::pin::Pin; use std::sync::Arc; use std::task::Poll; +use std::time::SystemTime; use std::{collections::BTreeSet, convert::TryFrom, io}; use std::{collections::VecDeque, path::PathBuf}; use tokio::io::AsyncWrite; @@ -270,7 +271,7 @@ impl Config { #[async_trait] impl ObjectStore for LocalFileSystem { - async fn put(&self, location: &Path, bytes: Bytes) -> Result<()> { + async fn put(&self, location: &Path, bytes: Bytes) -> Result { let path = self.config.path_to_filesystem(location)?; maybe_spawn_blocking(move || { let (mut file, suffix) = new_staged_upload(&path)?; @@ -282,8 +283,17 @@ impl ObjectStore for LocalFileSystem { }) .map_err(|e| { let _ = 
std::fs::remove_file(&staging_path); // Attempt to cleanup - e.into() - }) + e + })?; + + let metadata = file.metadata().map_err(|e| Error::Metadata { + source: e.into(), + path: path.to_string_lossy().to_string(), + })?; + + Ok(PutResult { + e_tag: Some(get_etag(&metadata)), + }) }) .await } @@ -959,24 +969,33 @@ fn last_modified(metadata: &Metadata) -> DateTime { .into() } +fn get_etag(metadata: &Metadata) -> String { + let inode = get_inode(metadata); + let size = metadata.len(); + let mtime = metadata + .modified() + .ok() + .and_then(|mtime| mtime.duration_since(SystemTime::UNIX_EPOCH).ok()) + .unwrap_or_default() + .as_micros(); + + // Use an ETag scheme based on that used by many popular HTTP servers + // + // + format!("{inode:x}-{mtime:x}-{size:x}") +} + fn convert_metadata(metadata: Metadata, location: Path) -> Result { let last_modified = last_modified(&metadata); let size = usize::try_from(metadata.len()).context(FileSizeOverflowedUsizeSnafu { path: location.as_ref(), })?; - let inode = get_inode(&metadata); - let mtime = last_modified.timestamp_micros(); - - // Use an ETag scheme based on that used by many popular HTTP servers - // - // - let etag = format!("{inode:x}-{mtime:x}-{size:x}"); Ok(ObjectMeta { location, last_modified, size, - e_tag: Some(etag), + e_tag: Some(get_etag(&metadata)), }) } diff --git a/object_store/src/memory.rs b/object_store/src/memory.rs index 00b330b5eb94..952b45739759 100644 --- a/object_store/src/memory.rs +++ b/object_store/src/memory.rs @@ -17,7 +17,8 @@ //! 
An in-memory object store implementation use crate::{ - path::Path, GetResult, GetResultPayload, ListResult, ObjectMeta, ObjectStore, Result, + path::Path, GetResult, GetResultPayload, ListResult, ObjectMeta, ObjectStore, + PutResult, Result, }; use crate::{GetOptions, MultipartId}; use async_trait::async_trait; @@ -106,11 +107,12 @@ struct Storage { type SharedStorage = Arc>; impl Storage { - fn insert(&mut self, location: &Path, bytes: Bytes) { + fn insert(&mut self, location: &Path, bytes: Bytes) -> usize { let etag = self.next_etag; self.next_etag += 1; let entry = Entry::new(bytes, Utc::now(), etag); self.map.insert(location.clone(), entry); + etag } } @@ -122,9 +124,11 @@ impl std::fmt::Display for InMemory { #[async_trait] impl ObjectStore for InMemory { - async fn put(&self, location: &Path, bytes: Bytes) -> Result<()> { - self.storage.write().insert(location, bytes); - Ok(()) + async fn put(&self, location: &Path, bytes: Bytes) -> Result { + let etag = self.storage.write().insert(location, bytes); + Ok(PutResult { + e_tag: Some(etag.to_string()), + }) } async fn put_multipart( diff --git a/object_store/src/prefix.rs b/object_store/src/prefix.rs index 3776dec2e872..21f6c1d99dc9 100644 --- a/object_store/src/prefix.rs +++ b/object_store/src/prefix.rs @@ -23,7 +23,8 @@ use tokio::io::AsyncWrite; use crate::path::Path; use crate::{ - GetOptions, GetResult, ListResult, MultipartId, ObjectMeta, ObjectStore, Result, + GetOptions, GetResult, ListResult, MultipartId, ObjectMeta, ObjectStore, PutResult, + Result, }; #[doc(hidden)] @@ -79,7 +80,7 @@ impl PrefixStore { #[async_trait::async_trait] impl ObjectStore for PrefixStore { - async fn put(&self, location: &Path, bytes: Bytes) -> Result<()> { + async fn put(&self, location: &Path, bytes: Bytes) -> Result { let full_path = self.full_path(location); self.inner.put(&full_path, bytes).await } diff --git a/object_store/src/throttle.rs b/object_store/src/throttle.rs index f716a11f8a05..d6f191baf82e 100644 --- 
a/object_store/src/throttle.rs +++ b/object_store/src/throttle.rs @@ -21,7 +21,8 @@ use std::ops::Range; use std::{convert::TryInto, sync::Arc}; use crate::{ - path::Path, GetResult, GetResultPayload, ListResult, ObjectMeta, ObjectStore, Result, + path::Path, GetResult, GetResultPayload, ListResult, ObjectMeta, ObjectStore, + PutResult, Result, }; use crate::{GetOptions, MultipartId}; use async_trait::async_trait; @@ -147,7 +148,7 @@ impl std::fmt::Display for ThrottledStore { #[async_trait] impl ObjectStore for ThrottledStore { - async fn put(&self, location: &Path, bytes: Bytes) -> Result<()> { + async fn put(&self, location: &Path, bytes: Bytes) -> Result { sleep(self.config().wait_put_per_call).await; self.inner.put(location, bytes).await diff --git a/object_store/tests/get_range_file.rs b/object_store/tests/get_range_file.rs index 25c469260675..5703d7f24844 100644 --- a/object_store/tests/get_range_file.rs +++ b/object_store/tests/get_range_file.rs @@ -23,7 +23,7 @@ use futures::stream::BoxStream; use object_store::local::LocalFileSystem; use object_store::path::Path; use object_store::{ - GetOptions, GetResult, ListResult, MultipartId, ObjectMeta, ObjectStore, + GetOptions, GetResult, ListResult, MultipartId, ObjectMeta, ObjectStore, PutResult, }; use std::fmt::Formatter; use tempfile::tempdir; @@ -40,7 +40,7 @@ impl std::fmt::Display for MyStore { #[async_trait] impl ObjectStore for MyStore { - async fn put(&self, path: &Path, data: Bytes) -> object_store::Result<()> { + async fn put(&self, path: &Path, data: Bytes) -> object_store::Result { self.0.put(path, data).await } From 62ca5f37d143db172a73b3f0365f48f8bc3e2c72 Mon Sep 17 00:00:00 2001 From: Raphael Taylor-Davies <1781103+tustvold@users.noreply.github.com> Date: Thu, 19 Oct 2023 13:40:49 +0100 Subject: [PATCH 18/25] Split aws Module (#4953) * Split aws module * Clippy * Fix doc --- object_store/src/aws/builder.rs | 1098 +++++++++++++++++++++++++++++ object_store/src/aws/mod.rs | 1169 
+------------------------------ object_store/src/aws/resolve.rs | 106 +++ 3 files changed, 1225 insertions(+), 1148 deletions(-) create mode 100644 object_store/src/aws/builder.rs create mode 100644 object_store/src/aws/resolve.rs diff --git a/object_store/src/aws/builder.rs b/object_store/src/aws/builder.rs new file mode 100644 index 000000000000..422ba15efa52 --- /dev/null +++ b/object_store/src/aws/builder.rs @@ -0,0 +1,1098 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use crate::aws::client::{S3Client, S3Config}; +use crate::aws::credential::{ + InstanceCredentialProvider, TaskCredentialProvider, WebIdentityProvider, +}; +use crate::aws::{ + AmazonS3, AwsCredential, AwsCredentialProvider, Checksum, S3CopyIfNotExists, STORE, +}; +use crate::client::TokenCredentialProvider; +use crate::config::ConfigValue; +use crate::{ + ClientConfigKey, ClientOptions, Result, RetryConfig, StaticCredentialProvider, +}; +use itertools::Itertools; +use serde::{Deserialize, Serialize}; +use snafu::{OptionExt, ResultExt, Snafu}; +use std::str::FromStr; +use std::sync::Arc; +use tracing::info; +use url::Url; + +/// Default metadata endpoint +static DEFAULT_METADATA_ENDPOINT: &str = "http://169.254.169.254"; + +/// A specialized `Error` for object store-related errors +#[derive(Debug, Snafu)] +#[allow(missing_docs)] +enum Error { + #[snafu(display("Missing region"))] + MissingRegion, + + #[snafu(display("Missing bucket name"))] + MissingBucketName, + + #[snafu(display("Missing AccessKeyId"))] + MissingAccessKeyId, + + #[snafu(display("Missing SecretAccessKey"))] + MissingSecretAccessKey, + + #[snafu(display("Unable parse source url. 
Url: {}, Error: {}", url, source))] + UnableToParseUrl { + source: url::ParseError, + url: String, + }, + + #[snafu(display( + "Unknown url scheme cannot be parsed into storage location: {}", + scheme + ))] + UnknownUrlScheme { scheme: String }, + + #[snafu(display("URL did not match any known pattern for scheme: {}", url))] + UrlNotRecognised { url: String }, + + #[snafu(display("Configuration key: '{}' is not known.", key))] + UnknownConfigurationKey { key: String }, + + #[snafu(display("Bucket '{}' not found", bucket))] + BucketNotFound { bucket: String }, + + #[snafu(display("Failed to resolve region for bucket '{}'", bucket))] + ResolveRegion { + bucket: String, + source: reqwest::Error, + }, + + #[snafu(display("Failed to parse the region for bucket '{}'", bucket))] + RegionParse { bucket: String }, +} + +impl From for crate::Error { + fn from(source: Error) -> Self { + match source { + Error::UnknownConfigurationKey { key } => { + Self::UnknownConfigurationKey { store: STORE, key } + } + _ => Self::Generic { + store: STORE, + source: Box::new(source), + }, + } + } +} + +/// Configure a connection to Amazon S3 using the specified credentials in +/// the specified Amazon region and bucket. 
+/// +/// # Example +/// ``` +/// # let REGION = "foo"; +/// # let BUCKET_NAME = "foo"; +/// # let ACCESS_KEY_ID = "foo"; +/// # let SECRET_KEY = "foo"; +/// # use object_store::aws::AmazonS3Builder; +/// let s3 = AmazonS3Builder::new() +/// .with_region(REGION) +/// .with_bucket_name(BUCKET_NAME) +/// .with_access_key_id(ACCESS_KEY_ID) +/// .with_secret_access_key(SECRET_KEY) +/// .build(); +/// ``` +#[derive(Debug, Default, Clone)] +pub struct AmazonS3Builder { + /// Access key id + access_key_id: Option, + /// Secret access_key + secret_access_key: Option, + /// Region + region: Option, + /// Bucket name + bucket_name: Option, + /// Endpoint for communicating with AWS S3 + endpoint: Option, + /// Token to use for requests + token: Option, + /// Url + url: Option, + /// Retry config + retry_config: RetryConfig, + /// When set to true, fallback to IMDSv1 + imdsv1_fallback: ConfigValue, + /// When set to true, virtual hosted style request has to be used + virtual_hosted_style_request: ConfigValue, + /// When set to true, unsigned payload option has to be used + unsigned_payload: ConfigValue, + /// Checksum algorithm which has to be used for object integrity check during upload + checksum_algorithm: Option>, + /// Metadata endpoint, see + metadata_endpoint: Option, + /// Container credentials URL, see + container_credentials_relative_uri: Option, + /// Client options + client_options: ClientOptions, + /// Credentials + credentials: Option, + /// Skip signing requests + skip_signature: ConfigValue, + /// Copy if not exists + copy_if_not_exists: Option>, +} + +/// Configuration keys for [`AmazonS3Builder`] +/// +/// Configuration via keys can be done via [`AmazonS3Builder::with_config`] +/// +/// # Example +/// ``` +/// # use object_store::aws::{AmazonS3Builder, AmazonS3ConfigKey}; +/// let builder = AmazonS3Builder::new() +/// .with_config("aws_access_key_id".parse().unwrap(), "my-access-key-id") +/// .with_config(AmazonS3ConfigKey::DefaultRegion, 
"my-default-region"); +/// ``` +#[derive(PartialEq, Eq, Hash, Clone, Debug, Copy, Serialize, Deserialize)] +#[non_exhaustive] +pub enum AmazonS3ConfigKey { + /// AWS Access Key + /// + /// See [`AmazonS3Builder::with_access_key_id`] for details. + /// + /// Supported keys: + /// - `aws_access_key_id` + /// - `access_key_id` + AccessKeyId, + + /// Secret Access Key + /// + /// See [`AmazonS3Builder::with_secret_access_key`] for details. + /// + /// Supported keys: + /// - `aws_secret_access_key` + /// - `secret_access_key` + SecretAccessKey, + + /// Region + /// + /// See [`AmazonS3Builder::with_region`] for details. + /// + /// Supported keys: + /// - `aws_region` + /// - `region` + Region, + + /// Default region + /// + /// See [`AmazonS3Builder::with_region`] for details. + /// + /// Supported keys: + /// - `aws_default_region` + /// - `default_region` + DefaultRegion, + + /// Bucket name + /// + /// See [`AmazonS3Builder::with_bucket_name`] for details. + /// + /// Supported keys: + /// - `aws_bucket` + /// - `aws_bucket_name` + /// - `bucket` + /// - `bucket_name` + Bucket, + + /// Sets custom endpoint for communicating with AWS S3. + /// + /// See [`AmazonS3Builder::with_endpoint`] for details. + /// + /// Supported keys: + /// - `aws_endpoint` + /// - `aws_endpoint_url` + /// - `endpoint` + /// - `endpoint_url` + Endpoint, + + /// Token to use for requests (passed to underlying provider) + /// + /// See [`AmazonS3Builder::with_token`] for details. + /// + /// Supported keys: + /// - `aws_session_token` + /// - `aws_token` + /// - `session_token` + /// - `token` + Token, + + /// Fall back to ImdsV1 + /// + /// See [`AmazonS3Builder::with_imdsv1_fallback`] for details. + /// + /// Supported keys: + /// - `aws_imdsv1_fallback` + /// - `imdsv1_fallback` + ImdsV1Fallback, + + /// If virtual hosted style request has to be used + /// + /// See [`AmazonS3Builder::with_virtual_hosted_style_request`] for details. 
+ /// + /// Supported keys: + /// - `aws_virtual_hosted_style_request` + /// - `virtual_hosted_style_request` + VirtualHostedStyleRequest, + + /// Avoid computing payload checksum when calculating signature. + /// + /// See [`AmazonS3Builder::with_unsigned_payload`] for details. + /// + /// Supported keys: + /// - `aws_unsigned_payload` + /// - `unsigned_payload` + UnsignedPayload, + + /// Set the checksum algorithm for this client + /// + /// See [`AmazonS3Builder::with_checksum_algorithm`] + Checksum, + + /// Set the instance metadata endpoint + /// + /// See [`AmazonS3Builder::with_metadata_endpoint`] for details. + /// + /// Supported keys: + /// - `aws_metadata_endpoint` + /// - `metadata_endpoint` + MetadataEndpoint, + + /// Set the container credentials relative URI + /// + /// + ContainerCredentialsRelativeUri, + + /// Configure how to provide `copy_if_not_exists` + /// + /// See [`S3CopyIfNotExists`] + CopyIfNotExists, + + /// Skip signing request + SkipSignature, + + /// Client options + Client(ClientConfigKey), +} + +impl AsRef for AmazonS3ConfigKey { + fn as_ref(&self) -> &str { + match self { + Self::AccessKeyId => "aws_access_key_id", + Self::SecretAccessKey => "aws_secret_access_key", + Self::Region => "aws_region", + Self::Bucket => "aws_bucket", + Self::Endpoint => "aws_endpoint", + Self::Token => "aws_session_token", + Self::ImdsV1Fallback => "aws_imdsv1_fallback", + Self::VirtualHostedStyleRequest => "aws_virtual_hosted_style_request", + Self::DefaultRegion => "aws_default_region", + Self::MetadataEndpoint => "aws_metadata_endpoint", + Self::UnsignedPayload => "aws_unsigned_payload", + Self::Checksum => "aws_checksum_algorithm", + Self::ContainerCredentialsRelativeUri => { + "aws_container_credentials_relative_uri" + } + Self::SkipSignature => "aws_skip_signature", + Self::CopyIfNotExists => "copy_if_not_exists", + Self::Client(opt) => opt.as_ref(), + } + } +} + +impl FromStr for AmazonS3ConfigKey { + type Err = crate::Error; + + fn from_str(s: 
&str) -> Result { + match s { + "aws_access_key_id" | "access_key_id" => Ok(Self::AccessKeyId), + "aws_secret_access_key" | "secret_access_key" => Ok(Self::SecretAccessKey), + "aws_default_region" | "default_region" => Ok(Self::DefaultRegion), + "aws_region" | "region" => Ok(Self::Region), + "aws_bucket" | "aws_bucket_name" | "bucket_name" | "bucket" => { + Ok(Self::Bucket) + } + "aws_endpoint_url" | "aws_endpoint" | "endpoint_url" | "endpoint" => { + Ok(Self::Endpoint) + } + "aws_session_token" | "aws_token" | "session_token" | "token" => { + Ok(Self::Token) + } + "aws_virtual_hosted_style_request" | "virtual_hosted_style_request" => { + Ok(Self::VirtualHostedStyleRequest) + } + "aws_imdsv1_fallback" | "imdsv1_fallback" => Ok(Self::ImdsV1Fallback), + "aws_metadata_endpoint" | "metadata_endpoint" => Ok(Self::MetadataEndpoint), + "aws_unsigned_payload" | "unsigned_payload" => Ok(Self::UnsignedPayload), + "aws_checksum_algorithm" | "checksum_algorithm" => Ok(Self::Checksum), + "aws_container_credentials_relative_uri" => { + Ok(Self::ContainerCredentialsRelativeUri) + } + "aws_skip_signature" | "skip_signature" => Ok(Self::SkipSignature), + "copy_if_not_exists" => Ok(Self::CopyIfNotExists), + // Backwards compatibility + "aws_allow_http" => Ok(Self::Client(ClientConfigKey::AllowHttp)), + _ => match s.parse() { + Ok(key) => Ok(Self::Client(key)), + Err(_) => Err(Error::UnknownConfigurationKey { key: s.into() }.into()), + }, + } + } +} + +impl AmazonS3Builder { + /// Create a new [`AmazonS3Builder`] with default values. 
+ pub fn new() -> Self { + Default::default() + } + + /// Fill the [`AmazonS3Builder`] with regular AWS environment variables + /// + /// Variables extracted from environment: + /// * `AWS_ACCESS_KEY_ID` -> access_key_id + /// * `AWS_SECRET_ACCESS_KEY` -> secret_access_key + /// * `AWS_DEFAULT_REGION` -> region + /// * `AWS_ENDPOINT` -> endpoint + /// * `AWS_SESSION_TOKEN` -> token + /// * `AWS_CONTAINER_CREDENTIALS_RELATIVE_URI` -> + /// * `AWS_ALLOW_HTTP` -> set to "true" to permit HTTP connections without TLS + /// # Example + /// ``` + /// use object_store::aws::AmazonS3Builder; + /// + /// let s3 = AmazonS3Builder::from_env() + /// .with_bucket_name("foo") + /// .build(); + /// ``` + pub fn from_env() -> Self { + let mut builder: Self = Default::default(); + + for (os_key, os_value) in std::env::vars_os() { + if let (Some(key), Some(value)) = (os_key.to_str(), os_value.to_str()) { + if key.starts_with("AWS_") { + if let Ok(config_key) = key.to_ascii_lowercase().parse() { + builder = builder.with_config(config_key, value); + } + } + } + } + + builder + } + + /// Parse available connection info form a well-known storage URL. + /// + /// The supported url schemes are: + /// + /// - `s3:///` + /// - `s3a:///` + /// - `https://s3..amazonaws.com/` + /// - `https://.s3..amazonaws.com` + /// - `https://ACCOUNT_ID.r2.cloudflarestorage.com/bucket` + /// + /// Note: Settings derived from the URL will override any others set on this builder + /// + /// # Example + /// ``` + /// use object_store::aws::AmazonS3Builder; + /// + /// let s3 = AmazonS3Builder::from_env() + /// .with_url("s3://bucket/path") + /// .build(); + /// ``` + pub fn with_url(mut self, url: impl Into) -> Self { + self.url = Some(url.into()); + self + } + + /// Set an option on the builder via a key - value pair. 
+ pub fn with_config( + mut self, + key: AmazonS3ConfigKey, + value: impl Into, + ) -> Self { + match key { + AmazonS3ConfigKey::AccessKeyId => self.access_key_id = Some(value.into()), + AmazonS3ConfigKey::SecretAccessKey => { + self.secret_access_key = Some(value.into()) + } + AmazonS3ConfigKey::Region => self.region = Some(value.into()), + AmazonS3ConfigKey::Bucket => self.bucket_name = Some(value.into()), + AmazonS3ConfigKey::Endpoint => self.endpoint = Some(value.into()), + AmazonS3ConfigKey::Token => self.token = Some(value.into()), + AmazonS3ConfigKey::ImdsV1Fallback => self.imdsv1_fallback.parse(value), + AmazonS3ConfigKey::VirtualHostedStyleRequest => { + self.virtual_hosted_style_request.parse(value) + } + AmazonS3ConfigKey::DefaultRegion => { + self.region = self.region.or_else(|| Some(value.into())) + } + AmazonS3ConfigKey::MetadataEndpoint => { + self.metadata_endpoint = Some(value.into()) + } + AmazonS3ConfigKey::UnsignedPayload => self.unsigned_payload.parse(value), + AmazonS3ConfigKey::Checksum => { + self.checksum_algorithm = Some(ConfigValue::Deferred(value.into())) + } + AmazonS3ConfigKey::ContainerCredentialsRelativeUri => { + self.container_credentials_relative_uri = Some(value.into()) + } + AmazonS3ConfigKey::Client(key) => { + self.client_options = self.client_options.with_config(key, value) + } + AmazonS3ConfigKey::SkipSignature => self.skip_signature.parse(value), + AmazonS3ConfigKey::CopyIfNotExists => { + self.copy_if_not_exists = Some(ConfigValue::Deferred(value.into())) + } + }; + self + } + + /// Set an option on the builder via a key - value pair. + /// + /// This method will return an `UnknownConfigKey` error if key cannot be parsed into [`AmazonS3ConfigKey`]. 
+ #[deprecated(note = "Use with_config")] + pub fn try_with_option( + self, + key: impl AsRef, + value: impl Into, + ) -> Result { + Ok(self.with_config(key.as_ref().parse()?, value)) + } + + /// Hydrate builder from key value pairs + /// + /// This method will return an `UnknownConfigKey` error if any key cannot be parsed into [`AmazonS3ConfigKey`]. + #[deprecated(note = "Use with_config")] + #[allow(deprecated)] + pub fn try_with_options< + I: IntoIterator, impl Into)>, + >( + mut self, + options: I, + ) -> Result { + for (key, value) in options { + self = self.try_with_option(key, value)?; + } + Ok(self) + } + + /// Get config value via a [`AmazonS3ConfigKey`]. + /// + /// # Example + /// ``` + /// use object_store::aws::{AmazonS3Builder, AmazonS3ConfigKey}; + /// + /// let builder = AmazonS3Builder::from_env() + /// .with_bucket_name("foo"); + /// let bucket_name = builder.get_config_value(&AmazonS3ConfigKey::Bucket).unwrap_or_default(); + /// assert_eq!("foo", &bucket_name); + /// ``` + pub fn get_config_value(&self, key: &AmazonS3ConfigKey) -> Option { + match key { + AmazonS3ConfigKey::AccessKeyId => self.access_key_id.clone(), + AmazonS3ConfigKey::SecretAccessKey => self.secret_access_key.clone(), + AmazonS3ConfigKey::Region | AmazonS3ConfigKey::DefaultRegion => { + self.region.clone() + } + AmazonS3ConfigKey::Bucket => self.bucket_name.clone(), + AmazonS3ConfigKey::Endpoint => self.endpoint.clone(), + AmazonS3ConfigKey::Token => self.token.clone(), + AmazonS3ConfigKey::ImdsV1Fallback => Some(self.imdsv1_fallback.to_string()), + AmazonS3ConfigKey::VirtualHostedStyleRequest => { + Some(self.virtual_hosted_style_request.to_string()) + } + AmazonS3ConfigKey::MetadataEndpoint => self.metadata_endpoint.clone(), + AmazonS3ConfigKey::UnsignedPayload => Some(self.unsigned_payload.to_string()), + AmazonS3ConfigKey::Checksum => { + self.checksum_algorithm.as_ref().map(ToString::to_string) + } + AmazonS3ConfigKey::Client(key) => 
self.client_options.get_config_value(key), + AmazonS3ConfigKey::ContainerCredentialsRelativeUri => { + self.container_credentials_relative_uri.clone() + } + AmazonS3ConfigKey::SkipSignature => Some(self.skip_signature.to_string()), + AmazonS3ConfigKey::CopyIfNotExists => { + self.copy_if_not_exists.as_ref().map(ToString::to_string) + } + } + } + + /// Sets properties on this builder based on a URL + /// + /// This is a separate member function to allow fallible computation to + /// be deferred until [`Self::build`] which in turn allows deriving [`Clone`] + fn parse_url(&mut self, url: &str) -> Result<()> { + let parsed = Url::parse(url).context(UnableToParseUrlSnafu { url })?; + let host = parsed.host_str().context(UrlNotRecognisedSnafu { url })?; + match parsed.scheme() { + "s3" | "s3a" => self.bucket_name = Some(host.to_string()), + "https" => match host.splitn(4, '.').collect_tuple() { + Some(("s3", region, "amazonaws", "com")) => { + self.region = Some(region.to_string()); + let bucket = parsed.path_segments().into_iter().flatten().next(); + if let Some(bucket) = bucket { + self.bucket_name = Some(bucket.into()); + } + } + Some((bucket, "s3", region, "amazonaws.com")) => { + self.bucket_name = Some(bucket.to_string()); + self.region = Some(region.to_string()); + self.virtual_hosted_style_request = true.into(); + } + Some((account, "r2", "cloudflarestorage", "com")) => { + self.region = Some("auto".to_string()); + let endpoint = format!("https://{account}.r2.cloudflarestorage.com"); + self.endpoint = Some(endpoint); + + let bucket = parsed.path_segments().into_iter().flatten().next(); + if let Some(bucket) = bucket { + self.bucket_name = Some(bucket.into()); + } + } + _ => return Err(UrlNotRecognisedSnafu { url }.build().into()), + }, + scheme => return Err(UnknownUrlSchemeSnafu { scheme }.build().into()), + }; + Ok(()) + } + + /// Set the AWS Access Key (required) + pub fn with_access_key_id(mut self, access_key_id: impl Into) -> Self { + self.access_key_id = 
Some(access_key_id.into()); + self + } + + /// Set the AWS Secret Access Key (required) + pub fn with_secret_access_key( + mut self, + secret_access_key: impl Into, + ) -> Self { + self.secret_access_key = Some(secret_access_key.into()); + self + } + + /// Set the region (e.g. `us-east-1`) (required) + pub fn with_region(mut self, region: impl Into) -> Self { + self.region = Some(region.into()); + self + } + + /// Set the bucket_name (required) + pub fn with_bucket_name(mut self, bucket_name: impl Into) -> Self { + self.bucket_name = Some(bucket_name.into()); + self + } + + /// Sets the endpoint for communicating with AWS S3. Default value + /// is based on region. The `endpoint` field should be consistent with + /// the field `virtual_hosted_style_request'. + /// + /// For example, this might be set to `"http://localhost:4566:` + /// for testing against a localstack instance. + /// If `virtual_hosted_style_request` is set to true then `endpoint` + /// should have bucket name included. + pub fn with_endpoint(mut self, endpoint: impl Into) -> Self { + self.endpoint = Some(endpoint.into()); + self + } + + /// Set the token to use for requests (passed to underlying provider) + pub fn with_token(mut self, token: impl Into) -> Self { + self.token = Some(token.into()); + self + } + + /// Set the credential provider overriding any other options + pub fn with_credentials(mut self, credentials: AwsCredentialProvider) -> Self { + self.credentials = Some(credentials); + self + } + + /// Sets what protocol is allowed. If `allow_http` is : + /// * false (default): Only HTTPS are allowed + /// * true: HTTP and HTTPS are allowed + pub fn with_allow_http(mut self, allow_http: bool) -> Self { + self.client_options = self.client_options.with_allow_http(allow_http); + self + } + + /// Sets if virtual hosted style request has to be used. 
+ /// If `virtual_hosted_style_request` is : + /// * false (default): Path style request is used + /// * true: Virtual hosted style request is used + /// + /// If the `endpoint` is provided then it should be + /// consistent with `virtual_hosted_style_request`. + /// i.e. if `virtual_hosted_style_request` is set to true + /// then `endpoint` should have bucket name included. + pub fn with_virtual_hosted_style_request( + mut self, + virtual_hosted_style_request: bool, + ) -> Self { + self.virtual_hosted_style_request = virtual_hosted_style_request.into(); + self + } + + /// Set the retry configuration + pub fn with_retry(mut self, retry_config: RetryConfig) -> Self { + self.retry_config = retry_config; + self + } + + /// By default instance credentials will only be fetched over [IMDSv2], as AWS recommends + /// against having IMDSv1 enabled on EC2 instances as it is vulnerable to [SSRF attack] + /// + /// However, certain deployment environments, such as those running old versions of kube2iam, + /// may not support IMDSv2. This option will enable automatic fallback to using IMDSv1 + /// if the token endpoint returns a 403 error indicating that IMDSv2 is not supported. + /// + /// This option has no effect if not using instance credentials + /// + /// [IMDSv2]: https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/configuring-instance-metadata-service.html + /// [SSRF attack]: https://aws.amazon.com/blogs/security/defense-in-depth-open-firewalls-reverse-proxies-ssrf-vulnerabilities-ec2-instance-metadata-service/ + /// + pub fn with_imdsv1_fallback(mut self) -> Self { + self.imdsv1_fallback = true.into(); + self + } + + /// Sets if unsigned payload option has to be used. + /// See [unsigned payload option](https://docs.aws.amazon.com/AmazonS3/latest/API/sig-v4-header-based-auth.html) + /// * false (default): Signed payload option is used, where the checksum for the request body is computed and included when constructing a canonical request. 
+ /// * true: Unsigned payload option is used. `UNSIGNED-PAYLOAD` literal is included when constructing a canonical request, + pub fn with_unsigned_payload(mut self, unsigned_payload: bool) -> Self { + self.unsigned_payload = unsigned_payload.into(); + self + } + + /// If enabled, [`AmazonS3`] will not fetch credentials and will not sign requests + /// + /// This can be useful when interacting with public S3 buckets that deny authorized requests + pub fn with_skip_signature(mut self, skip_signature: bool) -> Self { + self.skip_signature = skip_signature.into(); + self + } + + /// Sets the [checksum algorithm] which has to be used for object integrity check during upload. + /// + /// [checksum algorithm]: https://docs.aws.amazon.com/AmazonS3/latest/userguide/checking-object-integrity.html + pub fn with_checksum_algorithm(mut self, checksum_algorithm: Checksum) -> Self { + // Convert to String to enable deferred parsing of config + self.checksum_algorithm = Some(checksum_algorithm.into()); + self + } + + /// Set the [instance metadata endpoint](https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/ec2-instance-metadata.html), + /// used primarily within AWS EC2. + /// + /// This defaults to the IPv4 endpoint: http://169.254.169.254. One can alternatively use the IPv6 + /// endpoint http://fd00:ec2::254. 
+ pub fn with_metadata_endpoint(mut self, endpoint: impl Into) -> Self { + self.metadata_endpoint = Some(endpoint.into()); + self + } + + /// Set the proxy_url to be used by the underlying client + pub fn with_proxy_url(mut self, proxy_url: impl Into) -> Self { + self.client_options = self.client_options.with_proxy_url(proxy_url); + self + } + + /// Set a trusted proxy CA certificate + pub fn with_proxy_ca_certificate( + mut self, + proxy_ca_certificate: impl Into, + ) -> Self { + self.client_options = self + .client_options + .with_proxy_ca_certificate(proxy_ca_certificate); + self + } + + /// Set a list of hosts to exclude from proxy connections + pub fn with_proxy_excludes(mut self, proxy_excludes: impl Into) -> Self { + self.client_options = self.client_options.with_proxy_excludes(proxy_excludes); + self + } + + /// Sets the client options, overriding any already set + pub fn with_client_options(mut self, options: ClientOptions) -> Self { + self.client_options = options; + self + } + + /// Configure how to provide `copy_if_not_exists` + pub fn with_copy_if_not_exists(mut self, config: S3CopyIfNotExists) -> Self { + self.copy_if_not_exists = Some(config.into()); + self + } + + /// Create a [`AmazonS3`] instance from the provided values, + /// consuming `self`. 
+ pub fn build(mut self) -> Result { + if let Some(url) = self.url.take() { + self.parse_url(&url)?; + } + + let bucket = self.bucket_name.context(MissingBucketNameSnafu)?; + let region = self.region.context(MissingRegionSnafu)?; + let checksum = self.checksum_algorithm.map(|x| x.get()).transpose()?; + let copy_if_not_exists = self.copy_if_not_exists.map(|x| x.get()).transpose()?; + + let credentials = if let Some(credentials) = self.credentials { + credentials + } else if self.access_key_id.is_some() || self.secret_access_key.is_some() { + match (self.access_key_id, self.secret_access_key, self.token) { + (Some(key_id), Some(secret_key), token) => { + info!("Using Static credential provider"); + let credential = AwsCredential { + key_id, + secret_key, + token, + }; + Arc::new(StaticCredentialProvider::new(credential)) as _ + } + (None, Some(_), _) => return Err(Error::MissingAccessKeyId.into()), + (Some(_), None, _) => return Err(Error::MissingSecretAccessKey.into()), + (None, None, _) => unreachable!(), + } + } else if let (Ok(token_path), Ok(role_arn)) = ( + std::env::var("AWS_WEB_IDENTITY_TOKEN_FILE"), + std::env::var("AWS_ROLE_ARN"), + ) { + // TODO: Replace with `AmazonS3Builder::credentials_from_env` + info!("Using WebIdentity credential provider"); + + let session_name = std::env::var("AWS_ROLE_SESSION_NAME") + .unwrap_or_else(|_| "WebIdentitySession".to_string()); + + let endpoint = format!("https://sts.{region}.amazonaws.com"); + + // Disallow non-HTTPs requests + let client = self + .client_options + .clone() + .with_allow_http(false) + .client()?; + + let token = WebIdentityProvider { + token_path, + session_name, + role_arn, + endpoint, + }; + + Arc::new(TokenCredentialProvider::new( + token, + client, + self.retry_config.clone(), + )) as _ + } else if let Some(uri) = self.container_credentials_relative_uri { + info!("Using Task credential provider"); + Arc::new(TaskCredentialProvider { + url: format!("http://169.254.170.2{uri}"), + retry: 
self.retry_config.clone(), + // The instance metadata endpoint is access over HTTP + client: self.client_options.clone().with_allow_http(true).client()?, + cache: Default::default(), + }) as _ + } else { + info!("Using Instance credential provider"); + + let token = InstanceCredentialProvider { + cache: Default::default(), + imdsv1_fallback: self.imdsv1_fallback.get()?, + metadata_endpoint: self + .metadata_endpoint + .unwrap_or_else(|| DEFAULT_METADATA_ENDPOINT.into()), + }; + + Arc::new(TokenCredentialProvider::new( + token, + self.client_options.metadata_client()?, + self.retry_config.clone(), + )) as _ + }; + + let endpoint: String; + let bucket_endpoint: String; + + // If `endpoint` is provided then its assumed to be consistent with + // `virtual_hosted_style_request`. i.e. if `virtual_hosted_style_request` is true then + // `endpoint` should have bucket name included. + if self.virtual_hosted_style_request.get()? { + endpoint = self + .endpoint + .unwrap_or_else(|| format!("https://{bucket}.s3.{region}.amazonaws.com")); + bucket_endpoint = endpoint.clone(); + } else { + endpoint = self + .endpoint + .unwrap_or_else(|| format!("https://s3.{region}.amazonaws.com")); + bucket_endpoint = format!("{endpoint}/{bucket}"); + } + + let config = S3Config { + region, + endpoint, + bucket, + bucket_endpoint, + credentials, + retry_config: self.retry_config, + client_options: self.client_options, + sign_payload: !self.unsigned_payload.get()?, + skip_signature: self.skip_signature.get()?, + checksum, + copy_if_not_exists, + }; + + let client = Arc::new(S3Client::new(config)?); + + Ok(AmazonS3 { client }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::collections::HashMap; + + #[test] + fn s3_test_config_from_map() { + let aws_access_key_id = "object_store:fake_access_key_id".to_string(); + let aws_secret_access_key = "object_store:fake_secret_key".to_string(); + let aws_default_region = "object_store:fake_default_region".to_string(); + let aws_endpoint = 
"object_store:fake_endpoint".to_string(); + let aws_session_token = "object_store:fake_session_token".to_string(); + let options = HashMap::from([ + ("aws_access_key_id", aws_access_key_id.clone()), + ("aws_secret_access_key", aws_secret_access_key), + ("aws_default_region", aws_default_region.clone()), + ("aws_endpoint", aws_endpoint.clone()), + ("aws_session_token", aws_session_token.clone()), + ("aws_unsigned_payload", "true".to_string()), + ("aws_checksum_algorithm", "sha256".to_string()), + ]); + + let builder = options + .into_iter() + .fold(AmazonS3Builder::new(), |builder, (key, value)| { + builder.with_config(key.parse().unwrap(), value) + }) + .with_config(AmazonS3ConfigKey::SecretAccessKey, "new-secret-key"); + + assert_eq!(builder.access_key_id.unwrap(), aws_access_key_id.as_str()); + assert_eq!(builder.secret_access_key.unwrap(), "new-secret-key"); + assert_eq!(builder.region.unwrap(), aws_default_region); + assert_eq!(builder.endpoint.unwrap(), aws_endpoint); + assert_eq!(builder.token.unwrap(), aws_session_token); + assert_eq!( + builder.checksum_algorithm.unwrap().get().unwrap(), + Checksum::SHA256 + ); + assert!(builder.unsigned_payload.get().unwrap()); + } + + #[test] + fn s3_test_config_get_value() { + let aws_access_key_id = "object_store:fake_access_key_id".to_string(); + let aws_secret_access_key = "object_store:fake_secret_key".to_string(); + let aws_default_region = "object_store:fake_default_region".to_string(); + let aws_endpoint = "object_store:fake_endpoint".to_string(); + let aws_session_token = "object_store:fake_session_token".to_string(); + + let builder = AmazonS3Builder::new() + .with_config(AmazonS3ConfigKey::AccessKeyId, &aws_access_key_id) + .with_config(AmazonS3ConfigKey::SecretAccessKey, &aws_secret_access_key) + .with_config(AmazonS3ConfigKey::DefaultRegion, &aws_default_region) + .with_config(AmazonS3ConfigKey::Endpoint, &aws_endpoint) + .with_config(AmazonS3ConfigKey::Token, &aws_session_token) + 
.with_config(AmazonS3ConfigKey::UnsignedPayload, "true"); + + assert_eq!( + builder + .get_config_value(&AmazonS3ConfigKey::AccessKeyId) + .unwrap(), + aws_access_key_id + ); + assert_eq!( + builder + .get_config_value(&AmazonS3ConfigKey::SecretAccessKey) + .unwrap(), + aws_secret_access_key + ); + assert_eq!( + builder + .get_config_value(&AmazonS3ConfigKey::DefaultRegion) + .unwrap(), + aws_default_region + ); + assert_eq!( + builder + .get_config_value(&AmazonS3ConfigKey::Endpoint) + .unwrap(), + aws_endpoint + ); + assert_eq!( + builder.get_config_value(&AmazonS3ConfigKey::Token).unwrap(), + aws_session_token + ); + assert_eq!( + builder + .get_config_value(&AmazonS3ConfigKey::UnsignedPayload) + .unwrap(), + "true" + ); + } + + #[test] + fn s3_test_urls() { + let mut builder = AmazonS3Builder::new(); + builder.parse_url("s3://bucket/path").unwrap(); + assert_eq!(builder.bucket_name, Some("bucket".to_string())); + + let mut builder = AmazonS3Builder::new(); + builder + .parse_url("s3://buckets.can.have.dots/path") + .unwrap(); + assert_eq!( + builder.bucket_name, + Some("buckets.can.have.dots".to_string()) + ); + + let mut builder = AmazonS3Builder::new(); + builder + .parse_url("https://s3.region.amazonaws.com") + .unwrap(); + assert_eq!(builder.region, Some("region".to_string())); + + let mut builder = AmazonS3Builder::new(); + builder + .parse_url("https://s3.region.amazonaws.com/bucket") + .unwrap(); + assert_eq!(builder.region, Some("region".to_string())); + assert_eq!(builder.bucket_name, Some("bucket".to_string())); + + let mut builder = AmazonS3Builder::new(); + builder + .parse_url("https://s3.region.amazonaws.com/bucket.with.dot/path") + .unwrap(); + assert_eq!(builder.region, Some("region".to_string())); + assert_eq!(builder.bucket_name, Some("bucket.with.dot".to_string())); + + let mut builder = AmazonS3Builder::new(); + builder + .parse_url("https://bucket.s3.region.amazonaws.com") + .unwrap(); + assert_eq!(builder.bucket_name, 
Some("bucket".to_string())); + assert_eq!(builder.region, Some("region".to_string())); + assert!(builder.virtual_hosted_style_request.get().unwrap()); + + let mut builder = AmazonS3Builder::new(); + builder + .parse_url("https://account123.r2.cloudflarestorage.com/bucket-123") + .unwrap(); + + assert_eq!(builder.bucket_name, Some("bucket-123".to_string())); + assert_eq!(builder.region, Some("auto".to_string())); + assert_eq!( + builder.endpoint, + Some("https://account123.r2.cloudflarestorage.com".to_string()) + ); + + let err_cases = [ + "mailto://bucket/path", + "https://s3.bucket.mydomain.com", + "https://s3.bucket.foo.amazonaws.com", + "https://bucket.mydomain.region.amazonaws.com", + "https://bucket.s3.region.bar.amazonaws.com", + "https://bucket.foo.s3.amazonaws.com", + ]; + let mut builder = AmazonS3Builder::new(); + for case in err_cases { + builder.parse_url(case).unwrap_err(); + } + } + + #[tokio::test] + async fn s3_test_proxy_url() { + let s3 = AmazonS3Builder::new() + .with_access_key_id("access_key_id") + .with_secret_access_key("secret_access_key") + .with_region("region") + .with_bucket_name("bucket_name") + .with_allow_http(true) + .with_proxy_url("https://example.com") + .build(); + + assert!(s3.is_ok()); + + let err = AmazonS3Builder::new() + .with_access_key_id("access_key_id") + .with_secret_access_key("secret_access_key") + .with_region("region") + .with_bucket_name("bucket_name") + .with_allow_http(true) + .with_proxy_url("asdf://example.com") + .build() + .unwrap_err() + .to_string(); + + assert_eq!( + "Generic HTTP client error: builder error: unknown proxy scheme", + err + ); + } + + #[test] + fn test_invalid_config() { + let err = AmazonS3Builder::new() + .with_config(AmazonS3ConfigKey::ImdsV1Fallback, "enabled") + .with_bucket_name("bucket") + .with_region("region") + .build() + .unwrap_err() + .to_string(); + + assert_eq!( + err, + "Generic Config error: failed to parse \"enabled\" as boolean" + ); + + let err = AmazonS3Builder::new() + 
.with_config(AmazonS3ConfigKey::Checksum, "md5") + .with_bucket_name("bucket") + .with_region("region") + .build() + .unwrap_err() + .to_string(); + + assert_eq!( + err, + "Generic Config error: \"md5\" is not a valid checksum algorithm" + ); + } +} diff --git a/object_store/src/aws/mod.rs b/object_store/src/aws/mod.rs index 6d5aecea2d17..a4e39c3b88dd 100644 --- a/object_store/src/aws/mod.rs +++ b/object_store/src/aws/mod.rs @@ -35,40 +35,33 @@ use async_trait::async_trait; use bytes::Bytes; use futures::stream::BoxStream; use futures::{StreamExt, TryStreamExt}; -use itertools::Itertools; use reqwest::Method; -use serde::{Deserialize, Serialize}; -use snafu::{ensure, OptionExt, ResultExt, Snafu}; -use std::{str::FromStr, sync::Arc, time::Duration}; +use std::{sync::Arc, time::Duration}; use tokio::io::AsyncWrite; -use tracing::info; use url::Url; -use crate::aws::client::{S3Client, S3Config}; -use crate::aws::credential::{ - InstanceCredentialProvider, TaskCredentialProvider, WebIdentityProvider, -}; +use crate::aws::client::S3Client; use crate::client::get::GetClientExt; use crate::client::list::ListClientExt; -use crate::client::{ - ClientConfigKey, CredentialProvider, StaticCredentialProvider, - TokenCredentialProvider, -}; -use crate::config::ConfigValue; +use crate::client::CredentialProvider; use crate::multipart::{PartId, PutPart, WriteMultiPart}; use crate::signer::Signer; use crate::{ - ClientOptions, GetOptions, GetResult, ListResult, MultipartId, ObjectMeta, - ObjectStore, Path, PutResult, Result, RetryConfig, + GetOptions, GetResult, ListResult, MultipartId, ObjectMeta, ObjectStore, Path, + PutResult, Result, }; +mod builder; mod checksum; mod client; mod copy; mod credential; +mod resolve; +pub use builder::{AmazonS3Builder, AmazonS3ConfigKey}; pub use checksum::Checksum; pub use copy::S3CopyIfNotExists; +pub use resolve::resolve_bucket_region; // http://docs.aws.amazon.com/general/latest/gr/sigv4-create-canonical-request.html // @@ -90,103 +83,6 @@ 
const STORE: &str = "S3"; pub type AwsCredentialProvider = Arc>; pub use credential::{AwsAuthorizer, AwsCredential}; -/// Default metadata endpoint -static DEFAULT_METADATA_ENDPOINT: &str = "http://169.254.169.254"; - -/// A specialized `Error` for object store-related errors -#[derive(Debug, Snafu)] -#[allow(missing_docs)] -enum Error { - #[snafu(display("Missing region"))] - MissingRegion, - - #[snafu(display("Missing bucket name"))] - MissingBucketName, - - #[snafu(display("Missing AccessKeyId"))] - MissingAccessKeyId, - - #[snafu(display("Missing SecretAccessKey"))] - MissingSecretAccessKey, - - #[snafu(display("Unable parse source url. Url: {}, Error: {}", url, source))] - UnableToParseUrl { - source: url::ParseError, - url: String, - }, - - #[snafu(display( - "Unknown url scheme cannot be parsed into storage location: {}", - scheme - ))] - UnknownUrlScheme { scheme: String }, - - #[snafu(display("URL did not match any known pattern for scheme: {}", url))] - UrlNotRecognised { url: String }, - - #[snafu(display("Configuration key: '{}' is not known.", key))] - UnknownConfigurationKey { key: String }, - - #[snafu(display("Bucket '{}' not found", bucket))] - BucketNotFound { bucket: String }, - - #[snafu(display("Failed to resolve region for bucket '{}'", bucket))] - ResolveRegion { - bucket: String, - source: reqwest::Error, - }, - - #[snafu(display("Failed to parse the region for bucket '{}'", bucket))] - RegionParse { bucket: String }, -} - -impl From for super::Error { - fn from(source: Error) -> Self { - match source { - Error::UnknownConfigurationKey { key } => { - Self::UnknownConfigurationKey { store: STORE, key } - } - _ => Self::Generic { - store: STORE, - source: Box::new(source), - }, - } - } -} - -/// Get the bucket region using the [HeadBucket API]. This will fail if the bucket does not exist. 
-/// -/// [HeadBucket API]: https://docs.aws.amazon.com/AmazonS3/latest/API/API_HeadBucket.html -pub async fn resolve_bucket_region( - bucket: &str, - client_options: &ClientOptions, -) -> Result { - use reqwest::StatusCode; - - let endpoint = format!("https://{}.s3.amazonaws.com", bucket); - - let client = client_options.client()?; - - let response = client - .head(&endpoint) - .send() - .await - .context(ResolveRegionSnafu { bucket })?; - - ensure!( - response.status() != StatusCode::NOT_FOUND, - BucketNotFoundSnafu { bucket } - ); - - let region = response - .headers() - .get("x-amz-bucket-region") - .and_then(|x| x.to_str().ok()) - .context(RegionParseSnafu { bucket })?; - - Ok(region.to_string()) -} - /// Interface for [Amazon S3](https://aws.amazon.com/s3/). #[derive(Debug)] pub struct AmazonS3 { @@ -256,8 +152,10 @@ impl Signer for AmazonS3 { AwsAuthorizer::new(&credential, "s3", &self.client.config().region); let path_url = self.path_url(path); - let mut url = - Url::parse(&path_url).context(UnableToParseUrlSnafu { url: path_url })?; + let mut url = Url::parse(&path_url).map_err(|e| crate::Error::Generic { + store: STORE, + source: format!("Unable to parse url {path_url}: {e}").into(), + })?; authorizer.sign(method, &mut url, expires_in); @@ -381,891 +279,23 @@ impl PutPart for S3MultiPartUpload { } } -/// Configure a connection to Amazon S3 using the specified credentials in -/// the specified Amazon region and bucket. 
-/// -/// # Example -/// ``` -/// # let REGION = "foo"; -/// # let BUCKET_NAME = "foo"; -/// # let ACCESS_KEY_ID = "foo"; -/// # let SECRET_KEY = "foo"; -/// # use object_store::aws::AmazonS3Builder; -/// let s3 = AmazonS3Builder::new() -/// .with_region(REGION) -/// .with_bucket_name(BUCKET_NAME) -/// .with_access_key_id(ACCESS_KEY_ID) -/// .with_secret_access_key(SECRET_KEY) -/// .build(); -/// ``` -#[derive(Debug, Default, Clone)] -pub struct AmazonS3Builder { - /// Access key id - access_key_id: Option, - /// Secret access_key - secret_access_key: Option, - /// Region - region: Option, - /// Bucket name - bucket_name: Option, - /// Endpoint for communicating with AWS S3 - endpoint: Option, - /// Token to use for requests - token: Option, - /// Url - url: Option, - /// Retry config - retry_config: RetryConfig, - /// When set to true, fallback to IMDSv1 - imdsv1_fallback: ConfigValue, - /// When set to true, virtual hosted style request has to be used - virtual_hosted_style_request: ConfigValue, - /// When set to true, unsigned payload option has to be used - unsigned_payload: ConfigValue, - /// Checksum algorithm which has to be used for object integrity check during upload - checksum_algorithm: Option>, - /// Metadata endpoint, see - metadata_endpoint: Option, - /// Container credentials URL, see - container_credentials_relative_uri: Option, - /// Client options - client_options: ClientOptions, - /// Credentials - credentials: Option, - /// Skip signing requests - skip_signature: ConfigValue, - /// Copy if not exists - copy_if_not_exists: Option>, -} - -/// Configuration keys for [`AmazonS3Builder`] -/// -/// Configuration via keys can be done via [`AmazonS3Builder::with_config`] -/// -/// # Example -/// ``` -/// # use object_store::aws::{AmazonS3Builder, AmazonS3ConfigKey}; -/// let builder = AmazonS3Builder::new() -/// .with_config("aws_access_key_id".parse().unwrap(), "my-access-key-id") -/// .with_config(AmazonS3ConfigKey::DefaultRegion, 
"my-default-region"); -/// ``` -#[derive(PartialEq, Eq, Hash, Clone, Debug, Copy, Serialize, Deserialize)] -#[non_exhaustive] -pub enum AmazonS3ConfigKey { - /// AWS Access Key - /// - /// See [`AmazonS3Builder::with_access_key_id`] for details. - /// - /// Supported keys: - /// - `aws_access_key_id` - /// - `access_key_id` - AccessKeyId, - - /// Secret Access Key - /// - /// See [`AmazonS3Builder::with_secret_access_key`] for details. - /// - /// Supported keys: - /// - `aws_secret_access_key` - /// - `secret_access_key` - SecretAccessKey, - - /// Region - /// - /// See [`AmazonS3Builder::with_region`] for details. - /// - /// Supported keys: - /// - `aws_region` - /// - `region` - Region, - - /// Default region - /// - /// See [`AmazonS3Builder::with_region`] for details. - /// - /// Supported keys: - /// - `aws_default_region` - /// - `default_region` - DefaultRegion, - - /// Bucket name - /// - /// See [`AmazonS3Builder::with_bucket_name`] for details. - /// - /// Supported keys: - /// - `aws_bucket` - /// - `aws_bucket_name` - /// - `bucket` - /// - `bucket_name` - Bucket, - - /// Sets custom endpoint for communicating with AWS S3. - /// - /// See [`AmazonS3Builder::with_endpoint`] for details. - /// - /// Supported keys: - /// - `aws_endpoint` - /// - `aws_endpoint_url` - /// - `endpoint` - /// - `endpoint_url` - Endpoint, - - /// Token to use for requests (passed to underlying provider) - /// - /// See [`AmazonS3Builder::with_token`] for details. - /// - /// Supported keys: - /// - `aws_session_token` - /// - `aws_token` - /// - `session_token` - /// - `token` - Token, - - /// Fall back to ImdsV1 - /// - /// See [`AmazonS3Builder::with_imdsv1_fallback`] for details. - /// - /// Supported keys: - /// - `aws_imdsv1_fallback` - /// - `imdsv1_fallback` - ImdsV1Fallback, - - /// If virtual hosted style request has to be used - /// - /// See [`AmazonS3Builder::with_virtual_hosted_style_request`] for details. 
- /// - /// Supported keys: - /// - `aws_virtual_hosted_style_request` - /// - `virtual_hosted_style_request` - VirtualHostedStyleRequest, - - /// Avoid computing payload checksum when calculating signature. - /// - /// See [`AmazonS3Builder::with_unsigned_payload`] for details. - /// - /// Supported keys: - /// - `aws_unsigned_payload` - /// - `unsigned_payload` - UnsignedPayload, - - /// Set the checksum algorithm for this client - /// - /// See [`AmazonS3Builder::with_checksum_algorithm`] - Checksum, - - /// Set the instance metadata endpoint - /// - /// See [`AmazonS3Builder::with_metadata_endpoint`] for details. - /// - /// Supported keys: - /// - `aws_metadata_endpoint` - /// - `metadata_endpoint` - MetadataEndpoint, - - /// Set the container credentials relative URI - /// - /// - ContainerCredentialsRelativeUri, - - /// Configure how to provide [`ObjectStore::copy_if_not_exists`] - /// - /// See [`S3CopyIfNotExists`] - CopyIfNotExists, - - /// Skip signing request - SkipSignature, - - /// Client options - Client(ClientConfigKey), -} - -impl AsRef for AmazonS3ConfigKey { - fn as_ref(&self) -> &str { - match self { - Self::AccessKeyId => "aws_access_key_id", - Self::SecretAccessKey => "aws_secret_access_key", - Self::Region => "aws_region", - Self::Bucket => "aws_bucket", - Self::Endpoint => "aws_endpoint", - Self::Token => "aws_session_token", - Self::ImdsV1Fallback => "aws_imdsv1_fallback", - Self::VirtualHostedStyleRequest => "aws_virtual_hosted_style_request", - Self::DefaultRegion => "aws_default_region", - Self::MetadataEndpoint => "aws_metadata_endpoint", - Self::UnsignedPayload => "aws_unsigned_payload", - Self::Checksum => "aws_checksum_algorithm", - Self::ContainerCredentialsRelativeUri => { - "aws_container_credentials_relative_uri" - } - Self::SkipSignature => "aws_skip_signature", - Self::CopyIfNotExists => "copy_if_not_exists", - Self::Client(opt) => opt.as_ref(), - } - } -} - -impl FromStr for AmazonS3ConfigKey { - type Err = super::Error; - - 
fn from_str(s: &str) -> Result { - match s { - "aws_access_key_id" | "access_key_id" => Ok(Self::AccessKeyId), - "aws_secret_access_key" | "secret_access_key" => Ok(Self::SecretAccessKey), - "aws_default_region" | "default_region" => Ok(Self::DefaultRegion), - "aws_region" | "region" => Ok(Self::Region), - "aws_bucket" | "aws_bucket_name" | "bucket_name" | "bucket" => { - Ok(Self::Bucket) - } - "aws_endpoint_url" | "aws_endpoint" | "endpoint_url" | "endpoint" => { - Ok(Self::Endpoint) - } - "aws_session_token" | "aws_token" | "session_token" | "token" => { - Ok(Self::Token) - } - "aws_virtual_hosted_style_request" | "virtual_hosted_style_request" => { - Ok(Self::VirtualHostedStyleRequest) - } - "aws_imdsv1_fallback" | "imdsv1_fallback" => Ok(Self::ImdsV1Fallback), - "aws_metadata_endpoint" | "metadata_endpoint" => Ok(Self::MetadataEndpoint), - "aws_unsigned_payload" | "unsigned_payload" => Ok(Self::UnsignedPayload), - "aws_checksum_algorithm" | "checksum_algorithm" => Ok(Self::Checksum), - "aws_container_credentials_relative_uri" => { - Ok(Self::ContainerCredentialsRelativeUri) - } - "aws_skip_signature" | "skip_signature" => Ok(Self::SkipSignature), - "copy_if_not_exists" => Ok(Self::CopyIfNotExists), - // Backwards compatibility - "aws_allow_http" => Ok(Self::Client(ClientConfigKey::AllowHttp)), - _ => match s.parse() { - Ok(key) => Ok(Self::Client(key)), - Err(_) => Err(Error::UnknownConfigurationKey { key: s.into() }.into()), - }, - } - } -} - -impl AmazonS3Builder { - /// Create a new [`AmazonS3Builder`] with default values. 
- pub fn new() -> Self { - Default::default() - } - - /// Fill the [`AmazonS3Builder`] with regular AWS environment variables - /// - /// Variables extracted from environment: - /// * `AWS_ACCESS_KEY_ID` -> access_key_id - /// * `AWS_SECRET_ACCESS_KEY` -> secret_access_key - /// * `AWS_DEFAULT_REGION` -> region - /// * `AWS_ENDPOINT` -> endpoint - /// * `AWS_SESSION_TOKEN` -> token - /// * `AWS_CONTAINER_CREDENTIALS_RELATIVE_URI` -> - /// * `AWS_ALLOW_HTTP` -> set to "true" to permit HTTP connections without TLS - /// # Example - /// ``` - /// use object_store::aws::AmazonS3Builder; - /// - /// let s3 = AmazonS3Builder::from_env() - /// .with_bucket_name("foo") - /// .build(); - /// ``` - pub fn from_env() -> Self { - let mut builder: Self = Default::default(); - - for (os_key, os_value) in std::env::vars_os() { - if let (Some(key), Some(value)) = (os_key.to_str(), os_value.to_str()) { - if key.starts_with("AWS_") { - if let Ok(config_key) = key.to_ascii_lowercase().parse() { - builder = builder.with_config(config_key, value); - } - } - } - } - - builder - } - - /// Parse available connection info form a well-known storage URL. - /// - /// The supported url schemes are: - /// - /// - `s3:///` - /// - `s3a:///` - /// - `https://s3..amazonaws.com/` - /// - `https://.s3..amazonaws.com` - /// - `https://ACCOUNT_ID.r2.cloudflarestorage.com/bucket` - /// - /// Note: Settings derived from the URL will override any others set on this builder - /// - /// # Example - /// ``` - /// use object_store::aws::AmazonS3Builder; - /// - /// let s3 = AmazonS3Builder::from_env() - /// .with_url("s3://bucket/path") - /// .build(); - /// ``` - pub fn with_url(mut self, url: impl Into) -> Self { - self.url = Some(url.into()); - self - } - - /// Set an option on the builder via a key - value pair. 
- pub fn with_config( - mut self, - key: AmazonS3ConfigKey, - value: impl Into, - ) -> Self { - match key { - AmazonS3ConfigKey::AccessKeyId => self.access_key_id = Some(value.into()), - AmazonS3ConfigKey::SecretAccessKey => { - self.secret_access_key = Some(value.into()) - } - AmazonS3ConfigKey::Region => self.region = Some(value.into()), - AmazonS3ConfigKey::Bucket => self.bucket_name = Some(value.into()), - AmazonS3ConfigKey::Endpoint => self.endpoint = Some(value.into()), - AmazonS3ConfigKey::Token => self.token = Some(value.into()), - AmazonS3ConfigKey::ImdsV1Fallback => self.imdsv1_fallback.parse(value), - AmazonS3ConfigKey::VirtualHostedStyleRequest => { - self.virtual_hosted_style_request.parse(value) - } - AmazonS3ConfigKey::DefaultRegion => { - self.region = self.region.or_else(|| Some(value.into())) - } - AmazonS3ConfigKey::MetadataEndpoint => { - self.metadata_endpoint = Some(value.into()) - } - AmazonS3ConfigKey::UnsignedPayload => self.unsigned_payload.parse(value), - AmazonS3ConfigKey::Checksum => { - self.checksum_algorithm = Some(ConfigValue::Deferred(value.into())) - } - AmazonS3ConfigKey::ContainerCredentialsRelativeUri => { - self.container_credentials_relative_uri = Some(value.into()) - } - AmazonS3ConfigKey::Client(key) => { - self.client_options = self.client_options.with_config(key, value) - } - AmazonS3ConfigKey::SkipSignature => self.skip_signature.parse(value), - AmazonS3ConfigKey::CopyIfNotExists => { - self.copy_if_not_exists = Some(ConfigValue::Deferred(value.into())) - } - }; - self - } - - /// Set an option on the builder via a key - value pair. - /// - /// This method will return an `UnknownConfigKey` error if key cannot be parsed into [`AmazonS3ConfigKey`]. 
- #[deprecated(note = "Use with_config")] - pub fn try_with_option( - self, - key: impl AsRef, - value: impl Into, - ) -> Result { - Ok(self.with_config(key.as_ref().parse()?, value)) - } - - /// Hydrate builder from key value pairs - /// - /// This method will return an `UnknownConfigKey` error if any key cannot be parsed into [`AmazonS3ConfigKey`]. - #[deprecated(note = "Use with_config")] - #[allow(deprecated)] - pub fn try_with_options< - I: IntoIterator, impl Into)>, - >( - mut self, - options: I, - ) -> Result { - for (key, value) in options { - self = self.try_with_option(key, value)?; - } - Ok(self) - } - - /// Get config value via a [`AmazonS3ConfigKey`]. - /// - /// # Example - /// ``` - /// use object_store::aws::{AmazonS3Builder, AmazonS3ConfigKey}; - /// - /// let builder = AmazonS3Builder::from_env() - /// .with_bucket_name("foo"); - /// let bucket_name = builder.get_config_value(&AmazonS3ConfigKey::Bucket).unwrap_or_default(); - /// assert_eq!("foo", &bucket_name); - /// ``` - pub fn get_config_value(&self, key: &AmazonS3ConfigKey) -> Option { - match key { - AmazonS3ConfigKey::AccessKeyId => self.access_key_id.clone(), - AmazonS3ConfigKey::SecretAccessKey => self.secret_access_key.clone(), - AmazonS3ConfigKey::Region | AmazonS3ConfigKey::DefaultRegion => { - self.region.clone() - } - AmazonS3ConfigKey::Bucket => self.bucket_name.clone(), - AmazonS3ConfigKey::Endpoint => self.endpoint.clone(), - AmazonS3ConfigKey::Token => self.token.clone(), - AmazonS3ConfigKey::ImdsV1Fallback => Some(self.imdsv1_fallback.to_string()), - AmazonS3ConfigKey::VirtualHostedStyleRequest => { - Some(self.virtual_hosted_style_request.to_string()) - } - AmazonS3ConfigKey::MetadataEndpoint => self.metadata_endpoint.clone(), - AmazonS3ConfigKey::UnsignedPayload => Some(self.unsigned_payload.to_string()), - AmazonS3ConfigKey::Checksum => { - self.checksum_algorithm.as_ref().map(ToString::to_string) - } - AmazonS3ConfigKey::Client(key) => 
self.client_options.get_config_value(key), - AmazonS3ConfigKey::ContainerCredentialsRelativeUri => { - self.container_credentials_relative_uri.clone() - } - AmazonS3ConfigKey::SkipSignature => Some(self.skip_signature.to_string()), - AmazonS3ConfigKey::CopyIfNotExists => { - self.copy_if_not_exists.as_ref().map(ToString::to_string) - } - } - } - - /// Sets properties on this builder based on a URL - /// - /// This is a separate member function to allow fallible computation to - /// be deferred until [`Self::build`] which in turn allows deriving [`Clone`] - fn parse_url(&mut self, url: &str) -> Result<()> { - let parsed = Url::parse(url).context(UnableToParseUrlSnafu { url })?; - let host = parsed.host_str().context(UrlNotRecognisedSnafu { url })?; - match parsed.scheme() { - "s3" | "s3a" => self.bucket_name = Some(host.to_string()), - "https" => match host.splitn(4, '.').collect_tuple() { - Some(("s3", region, "amazonaws", "com")) => { - self.region = Some(region.to_string()); - let bucket = parsed.path_segments().into_iter().flatten().next(); - if let Some(bucket) = bucket { - self.bucket_name = Some(bucket.into()); - } - } - Some((bucket, "s3", region, "amazonaws.com")) => { - self.bucket_name = Some(bucket.to_string()); - self.region = Some(region.to_string()); - self.virtual_hosted_style_request = true.into(); - } - Some((account, "r2", "cloudflarestorage", "com")) => { - self.region = Some("auto".to_string()); - let endpoint = format!("https://{account}.r2.cloudflarestorage.com"); - self.endpoint = Some(endpoint); - - let bucket = parsed.path_segments().into_iter().flatten().next(); - if let Some(bucket) = bucket { - self.bucket_name = Some(bucket.into()); - } - } - _ => return Err(UrlNotRecognisedSnafu { url }.build().into()), - }, - scheme => return Err(UnknownUrlSchemeSnafu { scheme }.build().into()), - }; - Ok(()) - } - - /// Set the AWS Access Key (required) - pub fn with_access_key_id(mut self, access_key_id: impl Into) -> Self { - self.access_key_id = 
Some(access_key_id.into()); - self - } - - /// Set the AWS Secret Access Key (required) - pub fn with_secret_access_key( - mut self, - secret_access_key: impl Into, - ) -> Self { - self.secret_access_key = Some(secret_access_key.into()); - self - } - - /// Set the region (e.g. `us-east-1`) (required) - pub fn with_region(mut self, region: impl Into) -> Self { - self.region = Some(region.into()); - self - } - - /// Set the bucket_name (required) - pub fn with_bucket_name(mut self, bucket_name: impl Into) -> Self { - self.bucket_name = Some(bucket_name.into()); - self - } - - /// Sets the endpoint for communicating with AWS S3. Default value - /// is based on region. The `endpoint` field should be consistent with - /// the field `virtual_hosted_style_request'. - /// - /// For example, this might be set to `"http://localhost:4566:` - /// for testing against a localstack instance. - /// If `virtual_hosted_style_request` is set to true then `endpoint` - /// should have bucket name included. - pub fn with_endpoint(mut self, endpoint: impl Into) -> Self { - self.endpoint = Some(endpoint.into()); - self - } - - /// Set the token to use for requests (passed to underlying provider) - pub fn with_token(mut self, token: impl Into) -> Self { - self.token = Some(token.into()); - self - } - - /// Set the credential provider overriding any other options - pub fn with_credentials(mut self, credentials: AwsCredentialProvider) -> Self { - self.credentials = Some(credentials); - self - } - - /// Sets what protocol is allowed. If `allow_http` is : - /// * false (default): Only HTTPS are allowed - /// * true: HTTP and HTTPS are allowed - pub fn with_allow_http(mut self, allow_http: bool) -> Self { - self.client_options = self.client_options.with_allow_http(allow_http); - self - } - - /// Sets if virtual hosted style request has to be used. 
- /// If `virtual_hosted_style_request` is : - /// * false (default): Path style request is used - /// * true: Virtual hosted style request is used - /// - /// If the `endpoint` is provided then it should be - /// consistent with `virtual_hosted_style_request`. - /// i.e. if `virtual_hosted_style_request` is set to true - /// then `endpoint` should have bucket name included. - pub fn with_virtual_hosted_style_request( - mut self, - virtual_hosted_style_request: bool, - ) -> Self { - self.virtual_hosted_style_request = virtual_hosted_style_request.into(); - self - } - - /// Set the retry configuration - pub fn with_retry(mut self, retry_config: RetryConfig) -> Self { - self.retry_config = retry_config; - self - } - - /// By default instance credentials will only be fetched over [IMDSv2], as AWS recommends - /// against having IMDSv1 enabled on EC2 instances as it is vulnerable to [SSRF attack] - /// - /// However, certain deployment environments, such as those running old versions of kube2iam, - /// may not support IMDSv2. This option will enable automatic fallback to using IMDSv1 - /// if the token endpoint returns a 403 error indicating that IMDSv2 is not supported. - /// - /// This option has no effect if not using instance credentials - /// - /// [IMDSv2]: https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/configuring-instance-metadata-service.html - /// [SSRF attack]: https://aws.amazon.com/blogs/security/defense-in-depth-open-firewalls-reverse-proxies-ssrf-vulnerabilities-ec2-instance-metadata-service/ - /// - pub fn with_imdsv1_fallback(mut self) -> Self { - self.imdsv1_fallback = true.into(); - self - } - - /// Sets if unsigned payload option has to be used. - /// See [unsigned payload option](https://docs.aws.amazon.com/AmazonS3/latest/API/sig-v4-header-based-auth.html) - /// * false (default): Signed payload option is used, where the checksum for the request body is computed and included when constructing a canonical request. 
- /// * true: Unsigned payload option is used. `UNSIGNED-PAYLOAD` literal is included when constructing a canonical request, - pub fn with_unsigned_payload(mut self, unsigned_payload: bool) -> Self { - self.unsigned_payload = unsigned_payload.into(); - self - } - - /// If enabled, [`AmazonS3`] will not fetch credentials and will not sign requests - /// - /// This can be useful when interacting with public S3 buckets that deny authorized requests - pub fn with_skip_signature(mut self, skip_signature: bool) -> Self { - self.skip_signature = skip_signature.into(); - self - } - - /// Sets the [checksum algorithm] which has to be used for object integrity check during upload. - /// - /// [checksum algorithm]: https://docs.aws.amazon.com/AmazonS3/latest/userguide/checking-object-integrity.html - pub fn with_checksum_algorithm(mut self, checksum_algorithm: Checksum) -> Self { - // Convert to String to enable deferred parsing of config - self.checksum_algorithm = Some(checksum_algorithm.into()); - self - } - - /// Set the [instance metadata endpoint](https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/ec2-instance-metadata.html), - /// used primarily within AWS EC2. - /// - /// This defaults to the IPv4 endpoint: http://169.254.169.254. One can alternatively use the IPv6 - /// endpoint http://fd00:ec2::254. 
- pub fn with_metadata_endpoint(mut self, endpoint: impl Into) -> Self { - self.metadata_endpoint = Some(endpoint.into()); - self - } - - /// Set the proxy_url to be used by the underlying client - pub fn with_proxy_url(mut self, proxy_url: impl Into) -> Self { - self.client_options = self.client_options.with_proxy_url(proxy_url); - self - } - - /// Set a trusted proxy CA certificate - pub fn with_proxy_ca_certificate( - mut self, - proxy_ca_certificate: impl Into, - ) -> Self { - self.client_options = self - .client_options - .with_proxy_ca_certificate(proxy_ca_certificate); - self - } - - /// Set a list of hosts to exclude from proxy connections - pub fn with_proxy_excludes(mut self, proxy_excludes: impl Into) -> Self { - self.client_options = self.client_options.with_proxy_excludes(proxy_excludes); - self - } - - /// Sets the client options, overriding any already set - pub fn with_client_options(mut self, options: ClientOptions) -> Self { - self.client_options = options; - self - } - - /// Configure how to provide [`ObjectStore::copy_if_not_exists`] - pub fn with_copy_if_not_exists(mut self, config: S3CopyIfNotExists) -> Self { - self.copy_if_not_exists = Some(config.into()); - self - } - - /// Create a [`AmazonS3`] instance from the provided values, - /// consuming `self`. 
- pub fn build(mut self) -> Result { - if let Some(url) = self.url.take() { - self.parse_url(&url)?; - } - - let bucket = self.bucket_name.context(MissingBucketNameSnafu)?; - let region = self.region.context(MissingRegionSnafu)?; - let checksum = self.checksum_algorithm.map(|x| x.get()).transpose()?; - let copy_if_not_exists = self.copy_if_not_exists.map(|x| x.get()).transpose()?; - - let credentials = if let Some(credentials) = self.credentials { - credentials - } else if self.access_key_id.is_some() || self.secret_access_key.is_some() { - match (self.access_key_id, self.secret_access_key, self.token) { - (Some(key_id), Some(secret_key), token) => { - info!("Using Static credential provider"); - let credential = AwsCredential { - key_id, - secret_key, - token, - }; - Arc::new(StaticCredentialProvider::new(credential)) as _ - } - (None, Some(_), _) => return Err(Error::MissingAccessKeyId.into()), - (Some(_), None, _) => return Err(Error::MissingSecretAccessKey.into()), - (None, None, _) => unreachable!(), - } - } else if let (Ok(token_path), Ok(role_arn)) = ( - std::env::var("AWS_WEB_IDENTITY_TOKEN_FILE"), - std::env::var("AWS_ROLE_ARN"), - ) { - // TODO: Replace with `AmazonS3Builder::credentials_from_env` - info!("Using WebIdentity credential provider"); - - let session_name = std::env::var("AWS_ROLE_SESSION_NAME") - .unwrap_or_else(|_| "WebIdentitySession".to_string()); - - let endpoint = format!("https://sts.{region}.amazonaws.com"); - - // Disallow non-HTTPs requests - let client = self - .client_options - .clone() - .with_allow_http(false) - .client()?; - - let token = WebIdentityProvider { - token_path, - session_name, - role_arn, - endpoint, - }; - - Arc::new(TokenCredentialProvider::new( - token, - client, - self.retry_config.clone(), - )) as _ - } else if let Some(uri) = self.container_credentials_relative_uri { - info!("Using Task credential provider"); - Arc::new(TaskCredentialProvider { - url: format!("http://169.254.170.2{uri}"), - retry: 
self.retry_config.clone(), - // The instance metadata endpoint is access over HTTP - client: self.client_options.clone().with_allow_http(true).client()?, - cache: Default::default(), - }) as _ - } else { - info!("Using Instance credential provider"); - - let token = InstanceCredentialProvider { - cache: Default::default(), - imdsv1_fallback: self.imdsv1_fallback.get()?, - metadata_endpoint: self - .metadata_endpoint - .unwrap_or_else(|| DEFAULT_METADATA_ENDPOINT.into()), - }; - - Arc::new(TokenCredentialProvider::new( - token, - self.client_options.metadata_client()?, - self.retry_config.clone(), - )) as _ - }; - - let endpoint: String; - let bucket_endpoint: String; - - // If `endpoint` is provided then its assumed to be consistent with - // `virtual_hosted_style_request`. i.e. if `virtual_hosted_style_request` is true then - // `endpoint` should have bucket name included. - if self.virtual_hosted_style_request.get()? { - endpoint = self - .endpoint - .unwrap_or_else(|| format!("https://{bucket}.s3.{region}.amazonaws.com")); - bucket_endpoint = endpoint.clone(); - } else { - endpoint = self - .endpoint - .unwrap_or_else(|| format!("https://s3.{region}.amazonaws.com")); - bucket_endpoint = format!("{endpoint}/{bucket}"); - } - - let config = S3Config { - region, - endpoint, - bucket, - bucket_endpoint, - credentials, - retry_config: self.retry_config, - client_options: self.client_options, - sign_payload: !self.unsigned_payload.get()?, - skip_signature: self.skip_signature.get()?, - checksum, - copy_if_not_exists, - }; - - let client = Arc::new(S3Client::new(config)?); - - Ok(AmazonS3 { client }) - } -} - #[cfg(test)] mod tests { use super::*; - use crate::tests::{ - copy_if_not_exists, get_nonexistent_object, get_opts, - list_uses_directories_correctly, list_with_delimiter, put_get_delete_list_opts, - rename_and_copy, stream_get, - }; + use crate::tests::*; use bytes::Bytes; - use std::collections::HashMap; const NON_EXISTENT_NAME: &str = "nonexistentname"; - 
#[test] - fn s3_test_config_from_map() { - let aws_access_key_id = "object_store:fake_access_key_id".to_string(); - let aws_secret_access_key = "object_store:fake_secret_key".to_string(); - let aws_default_region = "object_store:fake_default_region".to_string(); - let aws_endpoint = "object_store:fake_endpoint".to_string(); - let aws_session_token = "object_store:fake_session_token".to_string(); - let options = HashMap::from([ - ("aws_access_key_id", aws_access_key_id.clone()), - ("aws_secret_access_key", aws_secret_access_key), - ("aws_default_region", aws_default_region.clone()), - ("aws_endpoint", aws_endpoint.clone()), - ("aws_session_token", aws_session_token.clone()), - ("aws_unsigned_payload", "true".to_string()), - ("aws_checksum_algorithm", "sha256".to_string()), - ]); - - let builder = options - .into_iter() - .fold(AmazonS3Builder::new(), |builder, (key, value)| { - builder.with_config(key.parse().unwrap(), value) - }) - .with_config(AmazonS3ConfigKey::SecretAccessKey, "new-secret-key"); - - assert_eq!(builder.access_key_id.unwrap(), aws_access_key_id.as_str()); - assert_eq!(builder.secret_access_key.unwrap(), "new-secret-key"); - assert_eq!(builder.region.unwrap(), aws_default_region); - assert_eq!(builder.endpoint.unwrap(), aws_endpoint); - assert_eq!(builder.token.unwrap(), aws_session_token); - assert_eq!( - builder.checksum_algorithm.unwrap().get().unwrap(), - Checksum::SHA256 - ); - assert!(builder.unsigned_payload.get().unwrap()); - } - - #[test] - fn s3_test_config_get_value() { - let aws_access_key_id = "object_store:fake_access_key_id".to_string(); - let aws_secret_access_key = "object_store:fake_secret_key".to_string(); - let aws_default_region = "object_store:fake_default_region".to_string(); - let aws_endpoint = "object_store:fake_endpoint".to_string(); - let aws_session_token = "object_store:fake_session_token".to_string(); - - let builder = AmazonS3Builder::new() - .with_config(AmazonS3ConfigKey::AccessKeyId, &aws_access_key_id) - 
.with_config(AmazonS3ConfigKey::SecretAccessKey, &aws_secret_access_key) - .with_config(AmazonS3ConfigKey::DefaultRegion, &aws_default_region) - .with_config(AmazonS3ConfigKey::Endpoint, &aws_endpoint) - .with_config(AmazonS3ConfigKey::Token, &aws_session_token) - .with_config(AmazonS3ConfigKey::UnsignedPayload, "true"); - - assert_eq!( - builder - .get_config_value(&AmazonS3ConfigKey::AccessKeyId) - .unwrap(), - aws_access_key_id - ); - assert_eq!( - builder - .get_config_value(&AmazonS3ConfigKey::SecretAccessKey) - .unwrap(), - aws_secret_access_key - ); - assert_eq!( - builder - .get_config_value(&AmazonS3ConfigKey::DefaultRegion) - .unwrap(), - aws_default_region - ); - assert_eq!( - builder - .get_config_value(&AmazonS3ConfigKey::Endpoint) - .unwrap(), - aws_endpoint - ); - assert_eq!( - builder.get_config_value(&AmazonS3ConfigKey::Token).unwrap(), - aws_session_token - ); - assert_eq!( - builder - .get_config_value(&AmazonS3ConfigKey::UnsignedPayload) - .unwrap(), - "true" - ); - } - #[tokio::test] async fn s3_test() { crate::test_util::maybe_skip_integration!(); let config = AmazonS3Builder::from_env(); - let is_local = matches!(&config.endpoint, Some(e) if e.starts_with("http://")); - let test_not_exists = config.copy_if_not_exists.is_some(); let integration = config.build().unwrap(); + let config = integration.client.config(); + let is_local = config.endpoint.starts_with("http://"); + let test_not_exists = config.copy_if_not_exists.is_some(); // Localstack doesn't support listing with spaces https://github.com/localstack/localstack/issues/6328 put_get_delete_list_opts(&integration, is_local).await; @@ -1279,16 +309,14 @@ mod tests { } // run integration test with unsigned payload enabled - let config = AmazonS3Builder::from_env().with_unsigned_payload(true); - let is_local = matches!(&config.endpoint, Some(e) if e.starts_with("http://")); - let integration = config.build().unwrap(); + let builder = AmazonS3Builder::from_env().with_unsigned_payload(true); + 
let integration = builder.build().unwrap(); put_get_delete_list_opts(&integration, is_local).await; // run integration test with checksum set to sha256 - let config = + let builder = AmazonS3Builder::from_env().with_checksum_algorithm(Checksum::SHA256); - let is_local = matches!(&config.endpoint, Some(e) if e.starts_with("http://")); - let integration = config.build().unwrap(); + let integration = builder.build().unwrap(); put_get_delete_list_opts(&integration, is_local).await; } @@ -1352,161 +380,6 @@ mod tests { assert!(matches!(err, crate::Error::NotFound { .. }), "{}", err); } - #[tokio::test] - async fn s3_test_proxy_url() { - let s3 = AmazonS3Builder::new() - .with_access_key_id("access_key_id") - .with_secret_access_key("secret_access_key") - .with_region("region") - .with_bucket_name("bucket_name") - .with_allow_http(true) - .with_proxy_url("https://example.com") - .build(); - - assert!(s3.is_ok()); - - let err = AmazonS3Builder::new() - .with_access_key_id("access_key_id") - .with_secret_access_key("secret_access_key") - .with_region("region") - .with_bucket_name("bucket_name") - .with_allow_http(true) - .with_proxy_url("asdf://example.com") - .build() - .unwrap_err() - .to_string(); - - assert_eq!( - "Generic HTTP client error: builder error: unknown proxy scheme", - err - ); - } - - #[test] - fn s3_test_urls() { - let mut builder = AmazonS3Builder::new(); - builder.parse_url("s3://bucket/path").unwrap(); - assert_eq!(builder.bucket_name, Some("bucket".to_string())); - - let mut builder = AmazonS3Builder::new(); - builder - .parse_url("s3://buckets.can.have.dots/path") - .unwrap(); - assert_eq!( - builder.bucket_name, - Some("buckets.can.have.dots".to_string()) - ); - - let mut builder = AmazonS3Builder::new(); - builder - .parse_url("https://s3.region.amazonaws.com") - .unwrap(); - assert_eq!(builder.region, Some("region".to_string())); - - let mut builder = AmazonS3Builder::new(); - builder - .parse_url("https://s3.region.amazonaws.com/bucket") - 
.unwrap(); - assert_eq!(builder.region, Some("region".to_string())); - assert_eq!(builder.bucket_name, Some("bucket".to_string())); - - let mut builder = AmazonS3Builder::new(); - builder - .parse_url("https://s3.region.amazonaws.com/bucket.with.dot/path") - .unwrap(); - assert_eq!(builder.region, Some("region".to_string())); - assert_eq!(builder.bucket_name, Some("bucket.with.dot".to_string())); - - let mut builder = AmazonS3Builder::new(); - builder - .parse_url("https://bucket.s3.region.amazonaws.com") - .unwrap(); - assert_eq!(builder.bucket_name, Some("bucket".to_string())); - assert_eq!(builder.region, Some("region".to_string())); - assert!(builder.virtual_hosted_style_request.get().unwrap()); - - let mut builder = AmazonS3Builder::new(); - builder - .parse_url("https://account123.r2.cloudflarestorage.com/bucket-123") - .unwrap(); - - assert_eq!(builder.bucket_name, Some("bucket-123".to_string())); - assert_eq!(builder.region, Some("auto".to_string())); - assert_eq!( - builder.endpoint, - Some("https://account123.r2.cloudflarestorage.com".to_string()) - ); - - let err_cases = [ - "mailto://bucket/path", - "https://s3.bucket.mydomain.com", - "https://s3.bucket.foo.amazonaws.com", - "https://bucket.mydomain.region.amazonaws.com", - "https://bucket.s3.region.bar.amazonaws.com", - "https://bucket.foo.s3.amazonaws.com", - ]; - let mut builder = AmazonS3Builder::new(); - for case in err_cases { - builder.parse_url(case).unwrap_err(); - } - } - - #[test] - fn test_invalid_config() { - let err = AmazonS3Builder::new() - .with_config(AmazonS3ConfigKey::ImdsV1Fallback, "enabled") - .with_bucket_name("bucket") - .with_region("region") - .build() - .unwrap_err() - .to_string(); - - assert_eq!( - err, - "Generic Config error: failed to parse \"enabled\" as boolean" - ); - - let err = AmazonS3Builder::new() - .with_config(AmazonS3ConfigKey::Checksum, "md5") - .with_bucket_name("bucket") - .with_region("region") - .build() - .unwrap_err() - .to_string(); - - assert_eq!( - 
err, - "Generic Config error: \"md5\" is not a valid checksum algorithm" - ); - } -} - -#[cfg(test)] -mod s3_resolve_bucket_region_tests { - use super::*; - - #[tokio::test] - async fn test_private_bucket() { - let bucket = "bloxbender"; - - let region = resolve_bucket_region(bucket, &ClientOptions::new()) - .await - .unwrap(); - - let expected = "us-west-2".to_string(); - - assert_eq!(region, expected); - } - - #[tokio::test] - async fn test_bucket_does_not_exist() { - let bucket = "please-dont-exist"; - - let result = resolve_bucket_region(bucket, &ClientOptions::new()).await; - - assert!(result.is_err()); - } - #[tokio::test] #[ignore = "Tests shouldn't call use remote services by default"] async fn test_disable_creds() { diff --git a/object_store/src/aws/resolve.rs b/object_store/src/aws/resolve.rs new file mode 100644 index 000000000000..2b21fabd34ab --- /dev/null +++ b/object_store/src/aws/resolve.rs @@ -0,0 +1,106 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use crate::aws::STORE; +use crate::{ClientOptions, Result}; +use snafu::{ensure, OptionExt, ResultExt, Snafu}; + +/// A specialized `Error` for object store-related errors +#[derive(Debug, Snafu)] +#[allow(missing_docs)] +enum Error { + #[snafu(display("Bucket '{}' not found", bucket))] + BucketNotFound { bucket: String }, + + #[snafu(display("Failed to resolve region for bucket '{}'", bucket))] + ResolveRegion { + bucket: String, + source: reqwest::Error, + }, + + #[snafu(display("Failed to parse the region for bucket '{}'", bucket))] + RegionParse { bucket: String }, +} + +impl From for crate::Error { + fn from(source: Error) -> Self { + Self::Generic { + store: STORE, + source: Box::new(source), + } + } +} + +/// Get the bucket region using the [HeadBucket API]. This will fail if the bucket does not exist. +/// +/// [HeadBucket API]: https://docs.aws.amazon.com/AmazonS3/latest/API/API_HeadBucket.html +pub async fn resolve_bucket_region( + bucket: &str, + client_options: &ClientOptions, +) -> Result { + use reqwest::StatusCode; + + let endpoint = format!("https://{}.s3.amazonaws.com", bucket); + + let client = client_options.client()?; + + let response = client + .head(&endpoint) + .send() + .await + .context(ResolveRegionSnafu { bucket })?; + + ensure!( + response.status() != StatusCode::NOT_FOUND, + BucketNotFoundSnafu { bucket } + ); + + let region = response + .headers() + .get("x-amz-bucket-region") + .and_then(|x| x.to_str().ok()) + .context(RegionParseSnafu { bucket })?; + + Ok(region.to_string()) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn test_private_bucket() { + let bucket = "bloxbender"; + + let region = resolve_bucket_region(bucket, &ClientOptions::new()) + .await + .unwrap(); + + let expected = "us-west-2".to_string(); + + assert_eq!(region, expected); + } + + #[tokio::test] + async fn test_bucket_does_not_exist() { + let bucket = "please-dont-exist"; + + let result = resolve_bucket_region(bucket, 
&ClientOptions::new()).await; + + assert!(result.is_err()); + } +} From a425e7e7faf82032abc85cf570c863974db5bb66 Mon Sep 17 00:00:00 2001 From: Raphael Taylor-Davies <1781103+tustvold@users.noreply.github.com> Date: Thu, 19 Oct 2023 13:41:03 +0100 Subject: [PATCH 19/25] Split azure Module (#4954) * Split azure module * Format * Docs --- object_store/src/azure/builder.rs | 1101 +++++++++++++++++++++++++++++ object_store/src/azure/mod.rs | 1081 +--------------------------- 2 files changed, 1112 insertions(+), 1070 deletions(-) create mode 100644 object_store/src/azure/builder.rs diff --git a/object_store/src/azure/builder.rs b/object_store/src/azure/builder.rs new file mode 100644 index 000000000000..eb2de147f3ad --- /dev/null +++ b/object_store/src/azure/builder.rs @@ -0,0 +1,1101 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use crate::azure::client::{AzureClient, AzureConfig}; +use crate::azure::credential::{ + AzureCliCredential, ClientSecretOAuthProvider, ImdsManagedIdentityProvider, + WorkloadIdentityOAuthProvider, +}; +use crate::azure::{AzureCredential, AzureCredentialProvider, MicrosoftAzure, STORE}; +use crate::client::TokenCredentialProvider; +use crate::config::ConfigValue; +use crate::{ + ClientConfigKey, ClientOptions, Result, RetryConfig, StaticCredentialProvider, +}; +use percent_encoding::percent_decode_str; +use serde::{Deserialize, Serialize}; +use snafu::{OptionExt, ResultExt, Snafu}; +use std::str::FromStr; +use std::sync::Arc; +use url::Url; + +/// The well-known account used by Azurite and the legacy Azure Storage Emulator. +/// +/// +const EMULATOR_ACCOUNT: &str = "devstoreaccount1"; + +/// The well-known account key used by Azurite and the legacy Azure Storage Emulator. +/// +/// +const EMULATOR_ACCOUNT_KEY: &str = + "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw=="; + +const MSI_ENDPOINT_ENV_KEY: &str = "IDENTITY_ENDPOINT"; + +/// A specialized `Error` for Azure builder-related errors +#[derive(Debug, Snafu)] +#[allow(missing_docs)] +enum Error { + #[snafu(display("Unable parse source url. 
Url: {}, Error: {}", url, source))] + UnableToParseUrl { + source: url::ParseError, + url: String, + }, + + #[snafu(display( + "Unable parse emulator url {}={}, Error: {}", + env_name, + env_value, + source + ))] + UnableToParseEmulatorUrl { + env_name: String, + env_value: String, + source: url::ParseError, + }, + + #[snafu(display("Account must be specified"))] + MissingAccount {}, + + #[snafu(display("Container name must be specified"))] + MissingContainerName {}, + + #[snafu(display( + "Unknown url scheme cannot be parsed into storage location: {}", + scheme + ))] + UnknownUrlScheme { scheme: String }, + + #[snafu(display("URL did not match any known pattern for scheme: {}", url))] + UrlNotRecognised { url: String }, + + #[snafu(display("Failed parsing an SAS key"))] + DecodeSasKey { source: std::str::Utf8Error }, + + #[snafu(display("Missing component in SAS query pair"))] + MissingSasComponent {}, + + #[snafu(display("Configuration key: '{}' is not known.", key))] + UnknownConfigurationKey { key: String }, + + #[snafu(display("Unable to extract metadata from headers: {}", source))] + Metadata { + source: crate::client::header::Error, + }, +} + +impl From for crate::Error { + fn from(source: Error) -> Self { + match source { + Error::UnknownConfigurationKey { key } => { + Self::UnknownConfigurationKey { store: STORE, key } + } + _ => Self::Generic { + store: STORE, + source: Box::new(source), + }, + } + } +} + +/// Configure a connection to Microsoft Azure Blob Storage container using +/// the specified credentials. 
+/// +/// # Example +/// ``` +/// # let ACCOUNT = "foo"; +/// # let BUCKET_NAME = "foo"; +/// # let ACCESS_KEY = "foo"; +/// # use object_store::azure::MicrosoftAzureBuilder; +/// let azure = MicrosoftAzureBuilder::new() +/// .with_account(ACCOUNT) +/// .with_access_key(ACCESS_KEY) +/// .with_container_name(BUCKET_NAME) +/// .build(); +/// ``` +#[derive(Default, Clone)] +pub struct MicrosoftAzureBuilder { + /// Account name + account_name: Option, + /// Access key + access_key: Option, + /// Container name + container_name: Option, + /// Bearer token + bearer_token: Option, + /// Client id + client_id: Option, + /// Client secret + client_secret: Option, + /// Tenant id + tenant_id: Option, + /// Query pairs for shared access signature authorization + sas_query_pairs: Option>, + /// Shared access signature + sas_key: Option, + /// Authority host + authority_host: Option, + /// Url + url: Option, + /// When set to true, azurite storage emulator has to be used + use_emulator: ConfigValue, + /// Storage endpoint + endpoint: Option, + /// Msi endpoint for acquiring managed identity token + msi_endpoint: Option, + /// Object id for use with managed identity authentication + object_id: Option, + /// Msi resource id for use with managed identity authentication + msi_resource_id: Option, + /// File containing token for Azure AD workload identity federation + federated_token_file: Option, + /// When set to true, azure cli has to be used for acquiring access token + use_azure_cli: ConfigValue, + /// Retry config + retry_config: RetryConfig, + /// Client options + client_options: ClientOptions, + /// Credentials + credentials: Option, + /// When set to true, fabric url scheme will be used + /// + /// i.e. 
https://{account_name}.dfs.fabric.microsoft.com + use_fabric_endpoint: ConfigValue, +} + +/// Configuration keys for [`MicrosoftAzureBuilder`] +/// +/// Configuration via keys can be done via [`MicrosoftAzureBuilder::with_config`] +/// +/// # Example +/// ``` +/// # use object_store::azure::{MicrosoftAzureBuilder, AzureConfigKey}; +/// let builder = MicrosoftAzureBuilder::new() +/// .with_config("azure_client_id".parse().unwrap(), "my-client-id") +/// .with_config(AzureConfigKey::AuthorityId, "my-tenant-id"); +/// ``` +#[derive(PartialEq, Eq, Hash, Clone, Debug, Copy, Deserialize, Serialize)] +#[non_exhaustive] +pub enum AzureConfigKey { + /// The name of the azure storage account + /// + /// Supported keys: + /// - `azure_storage_account_name` + /// - `account_name` + AccountName, + + /// Master key for accessing storage account + /// + /// Supported keys: + /// - `azure_storage_account_key` + /// - `azure_storage_access_key` + /// - `azure_storage_master_key` + /// - `access_key` + /// - `account_key` + /// - `master_key` + AccessKey, + + /// Service principal client id for authorizing requests + /// + /// Supported keys: + /// - `azure_storage_client_id` + /// - `azure_client_id` + /// - `client_id` + ClientId, + + /// Service principal client secret for authorizing requests + /// + /// Supported keys: + /// - `azure_storage_client_secret` + /// - `azure_client_secret` + /// - `client_secret` + ClientSecret, + + /// Tenant id used in oauth flows + /// + /// Supported keys: + /// - `azure_storage_tenant_id` + /// - `azure_storage_authority_id` + /// - `azure_tenant_id` + /// - `azure_authority_id` + /// - `tenant_id` + /// - `authority_id` + AuthorityId, + + /// Shared access signature. + /// + /// The signature is expected to be percent-encoded, much like they are provided + /// in the azure storage explorer or azure portal. 
+ /// + /// Supported keys: + /// - `azure_storage_sas_key` + /// - `azure_storage_sas_token` + /// - `sas_key` + /// - `sas_token` + SasKey, + + /// Bearer token + /// + /// Supported keys: + /// - `azure_storage_token` + /// - `bearer_token` + /// - `token` + Token, + + /// Use object store with azurite storage emulator + /// + /// Supported keys: + /// - `azure_storage_use_emulator` + /// - `object_store_use_emulator` + /// - `use_emulator` + UseEmulator, + + /// Override the endpoint used to communicate with blob storage + /// + /// Supported keys: + /// - `azure_storage_endpoint` + /// - `azure_endpoint` + /// - `endpoint` + Endpoint, + + /// Use object store with url scheme account.dfs.fabric.microsoft.com + /// + /// Supported keys: + /// - `azure_use_fabric_endpoint` + /// - `use_fabric_endpoint` + UseFabricEndpoint, + + /// Endpoint to request a imds managed identity token + /// + /// Supported keys: + /// - `azure_msi_endpoint` + /// - `azure_identity_endpoint` + /// - `identity_endpoint` + /// - `msi_endpoint` + MsiEndpoint, + + /// Object id for use with managed identity authentication + /// + /// Supported keys: + /// - `azure_object_id` + /// - `object_id` + ObjectId, + + /// Msi resource id for use with managed identity authentication + /// + /// Supported keys: + /// - `azure_msi_resource_id` + /// - `msi_resource_id` + MsiResourceId, + + /// File containing token for Azure AD workload identity federation + /// + /// Supported keys: + /// - `azure_federated_token_file` + /// - `federated_token_file` + FederatedTokenFile, + + /// Use azure cli for acquiring access token + /// + /// Supported keys: + /// - `azure_use_azure_cli` + /// - `use_azure_cli` + UseAzureCli, + + /// Container name + /// + /// Supported keys: + /// - `azure_container_name` + /// - `container_name` + ContainerName, + + /// Client options + Client(ClientConfigKey), +} + +impl AsRef for AzureConfigKey { + fn as_ref(&self) -> &str { + match self { + Self::AccountName => 
"azure_storage_account_name", + Self::AccessKey => "azure_storage_account_key", + Self::ClientId => "azure_storage_client_id", + Self::ClientSecret => "azure_storage_client_secret", + Self::AuthorityId => "azure_storage_tenant_id", + Self::SasKey => "azure_storage_sas_key", + Self::Token => "azure_storage_token", + Self::UseEmulator => "azure_storage_use_emulator", + Self::UseFabricEndpoint => "azure_use_fabric_endpoint", + Self::Endpoint => "azure_storage_endpoint", + Self::MsiEndpoint => "azure_msi_endpoint", + Self::ObjectId => "azure_object_id", + Self::MsiResourceId => "azure_msi_resource_id", + Self::FederatedTokenFile => "azure_federated_token_file", + Self::UseAzureCli => "azure_use_azure_cli", + Self::ContainerName => "azure_container_name", + Self::Client(key) => key.as_ref(), + } + } +} + +impl FromStr for AzureConfigKey { + type Err = crate::Error; + + fn from_str(s: &str) -> Result { + match s { + "azure_storage_account_key" + | "azure_storage_access_key" + | "azure_storage_master_key" + | "master_key" + | "account_key" + | "access_key" => Ok(Self::AccessKey), + "azure_storage_account_name" | "account_name" => Ok(Self::AccountName), + "azure_storage_client_id" | "azure_client_id" | "client_id" => { + Ok(Self::ClientId) + } + "azure_storage_client_secret" | "azure_client_secret" | "client_secret" => { + Ok(Self::ClientSecret) + } + "azure_storage_tenant_id" + | "azure_storage_authority_id" + | "azure_tenant_id" + | "azure_authority_id" + | "tenant_id" + | "authority_id" => Ok(Self::AuthorityId), + "azure_storage_sas_key" + | "azure_storage_sas_token" + | "sas_key" + | "sas_token" => Ok(Self::SasKey), + "azure_storage_token" | "bearer_token" | "token" => Ok(Self::Token), + "azure_storage_use_emulator" | "use_emulator" => Ok(Self::UseEmulator), + "azure_storage_endpoint" | "azure_endpoint" | "endpoint" => { + Ok(Self::Endpoint) + } + "azure_msi_endpoint" + | "azure_identity_endpoint" + | "identity_endpoint" + | "msi_endpoint" => Ok(Self::MsiEndpoint), + 
"azure_object_id" | "object_id" => Ok(Self::ObjectId), + "azure_msi_resource_id" | "msi_resource_id" => Ok(Self::MsiResourceId), + "azure_federated_token_file" | "federated_token_file" => { + Ok(Self::FederatedTokenFile) + } + "azure_use_fabric_endpoint" | "use_fabric_endpoint" => { + Ok(Self::UseFabricEndpoint) + } + "azure_use_azure_cli" | "use_azure_cli" => Ok(Self::UseAzureCli), + "azure_container_name" | "container_name" => Ok(Self::ContainerName), + // Backwards compatibility + "azure_allow_http" => Ok(Self::Client(ClientConfigKey::AllowHttp)), + _ => match s.parse() { + Ok(key) => Ok(Self::Client(key)), + Err(_) => Err(Error::UnknownConfigurationKey { key: s.into() }.into()), + }, + } + } +} + +impl std::fmt::Debug for MicrosoftAzureBuilder { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "MicrosoftAzureBuilder {{ account: {:?}, container_name: {:?} }}", + self.account_name, self.container_name + ) + } +} + +impl MicrosoftAzureBuilder { + /// Create a new [`MicrosoftAzureBuilder`] with default values. + pub fn new() -> Self { + Default::default() + } + + /// Create an instance of [`MicrosoftAzureBuilder`] with values pre-populated from environment variables. 
+ /// + /// Variables extracted from environment: + /// * AZURE_STORAGE_ACCOUNT_NAME: storage account name + /// * AZURE_STORAGE_ACCOUNT_KEY: storage account master key + /// * AZURE_STORAGE_ACCESS_KEY: alias for AZURE_STORAGE_ACCOUNT_KEY + /// * AZURE_STORAGE_CLIENT_ID -> client id for service principal authorization + /// * AZURE_STORAGE_CLIENT_SECRET -> client secret for service principal authorization + /// * AZURE_STORAGE_TENANT_ID -> tenant id used in oauth flows + /// # Example + /// ``` + /// use object_store::azure::MicrosoftAzureBuilder; + /// + /// let azure = MicrosoftAzureBuilder::from_env() + /// .with_container_name("foo") + /// .build(); + /// ``` + pub fn from_env() -> Self { + let mut builder = Self::default(); + for (os_key, os_value) in std::env::vars_os() { + if let (Some(key), Some(value)) = (os_key.to_str(), os_value.to_str()) { + if key.starts_with("AZURE_") { + if let Ok(config_key) = key.to_ascii_lowercase().parse() { + builder = builder.with_config(config_key, value); + } + } + } + } + + if let Ok(text) = std::env::var(MSI_ENDPOINT_ENV_KEY) { + builder = builder.with_msi_endpoint(text); + } + + builder + } + + /// Parse available connection info form a well-known storage URL. 
+ /// + /// The supported url schemes are: + /// + /// - `abfs[s]:///` (according to [fsspec](https://github.com/fsspec/adlfs)) + /// - `abfs[s]://@.dfs.core.windows.net/` + /// - `abfs[s]://@.dfs.fabric.microsoft.com/` + /// - `az:///` (according to [fsspec](https://github.com/fsspec/adlfs)) + /// - `adl:///` (according to [fsspec](https://github.com/fsspec/adlfs)) + /// - `azure:///` (custom) + /// - `https://.dfs.core.windows.net` + /// - `https://.blob.core.windows.net` + /// - `https://.dfs.fabric.microsoft.com` + /// - `https://.dfs.fabric.microsoft.com/` + /// - `https://.blob.fabric.microsoft.com` + /// - `https://.blob.fabric.microsoft.com/` + /// + /// Note: Settings derived from the URL will override any others set on this builder + /// + /// # Example + /// ``` + /// use object_store::azure::MicrosoftAzureBuilder; + /// + /// let azure = MicrosoftAzureBuilder::from_env() + /// .with_url("abfss://file_system@account.dfs.core.windows.net/") + /// .build(); + /// ``` + pub fn with_url(mut self, url: impl Into) -> Self { + self.url = Some(url.into()); + self + } + + /// Set an option on the builder via a key - value pair. 
+ pub fn with_config(mut self, key: AzureConfigKey, value: impl Into) -> Self { + match key { + AzureConfigKey::AccessKey => self.access_key = Some(value.into()), + AzureConfigKey::AccountName => self.account_name = Some(value.into()), + AzureConfigKey::ClientId => self.client_id = Some(value.into()), + AzureConfigKey::ClientSecret => self.client_secret = Some(value.into()), + AzureConfigKey::AuthorityId => self.tenant_id = Some(value.into()), + AzureConfigKey::SasKey => self.sas_key = Some(value.into()), + AzureConfigKey::Token => self.bearer_token = Some(value.into()), + AzureConfigKey::MsiEndpoint => self.msi_endpoint = Some(value.into()), + AzureConfigKey::ObjectId => self.object_id = Some(value.into()), + AzureConfigKey::MsiResourceId => self.msi_resource_id = Some(value.into()), + AzureConfigKey::FederatedTokenFile => { + self.federated_token_file = Some(value.into()) + } + AzureConfigKey::UseAzureCli => self.use_azure_cli.parse(value), + AzureConfigKey::UseEmulator => self.use_emulator.parse(value), + AzureConfigKey::Endpoint => self.endpoint = Some(value.into()), + AzureConfigKey::UseFabricEndpoint => self.use_fabric_endpoint.parse(value), + AzureConfigKey::Client(key) => { + self.client_options = self.client_options.with_config(key, value) + } + AzureConfigKey::ContainerName => self.container_name = Some(value.into()), + }; + self + } + + /// Set an option on the builder via a key - value pair. + #[deprecated(note = "Use with_config")] + pub fn try_with_option( + self, + key: impl AsRef, + value: impl Into, + ) -> Result { + Ok(self.with_config(key.as_ref().parse()?, value)) + } + + /// Hydrate builder from key value pairs + #[deprecated(note = "Use with_config")] + #[allow(deprecated)] + pub fn try_with_options< + I: IntoIterator, impl Into)>, + >( + mut self, + options: I, + ) -> Result { + for (key, value) in options { + self = self.try_with_option(key, value)?; + } + Ok(self) + } + + /// Get config value via a [`AzureConfigKey`]. 
+ /// + /// # Example + /// ``` + /// use object_store::azure::{MicrosoftAzureBuilder, AzureConfigKey}; + /// + /// let builder = MicrosoftAzureBuilder::from_env() + /// .with_account("foo"); + /// let account_name = builder.get_config_value(&AzureConfigKey::AccountName).unwrap_or_default(); + /// assert_eq!("foo", &account_name); + /// ``` + pub fn get_config_value(&self, key: &AzureConfigKey) -> Option { + match key { + AzureConfigKey::AccountName => self.account_name.clone(), + AzureConfigKey::AccessKey => self.access_key.clone(), + AzureConfigKey::ClientId => self.client_id.clone(), + AzureConfigKey::ClientSecret => self.client_secret.clone(), + AzureConfigKey::AuthorityId => self.tenant_id.clone(), + AzureConfigKey::SasKey => self.sas_key.clone(), + AzureConfigKey::Token => self.bearer_token.clone(), + AzureConfigKey::UseEmulator => Some(self.use_emulator.to_string()), + AzureConfigKey::UseFabricEndpoint => { + Some(self.use_fabric_endpoint.to_string()) + } + AzureConfigKey::Endpoint => self.endpoint.clone(), + AzureConfigKey::MsiEndpoint => self.msi_endpoint.clone(), + AzureConfigKey::ObjectId => self.object_id.clone(), + AzureConfigKey::MsiResourceId => self.msi_resource_id.clone(), + AzureConfigKey::FederatedTokenFile => self.federated_token_file.clone(), + AzureConfigKey::UseAzureCli => Some(self.use_azure_cli.to_string()), + AzureConfigKey::Client(key) => self.client_options.get_config_value(key), + AzureConfigKey::ContainerName => self.container_name.clone(), + } + } + + /// Sets properties on this builder based on a URL + /// + /// This is a separate member function to allow fallible computation to + /// be deferred until [`Self::build`] which in turn allows deriving [`Clone`] + fn parse_url(&mut self, url: &str) -> Result<()> { + let parsed = Url::parse(url).context(UnableToParseUrlSnafu { url })?; + let host = parsed.host_str().context(UrlNotRecognisedSnafu { url })?; + + let validate = |s: &str| match s.contains('.') { + true => 
Err(UrlNotRecognisedSnafu { url }.build()), + false => Ok(s.to_string()), + }; + + match parsed.scheme() { + "az" | "adl" | "azure" => self.container_name = Some(validate(host)?), + "abfs" | "abfss" => { + // abfs(s) might refer to the fsspec convention abfs:/// + // or the convention for the hadoop driver abfs[s]://@.dfs.core.windows.net/ + if parsed.username().is_empty() { + self.container_name = Some(validate(host)?); + } else if let Some(a) = host.strip_suffix(".dfs.core.windows.net") { + self.container_name = Some(validate(parsed.username())?); + self.account_name = Some(validate(a)?); + } else if let Some(a) = host.strip_suffix(".dfs.fabric.microsoft.com") { + self.container_name = Some(validate(parsed.username())?); + self.account_name = Some(validate(a)?); + self.use_fabric_endpoint = true.into(); + } else { + return Err(UrlNotRecognisedSnafu { url }.build().into()); + } + } + "https" => match host.split_once('.') { + Some((a, "dfs.core.windows.net")) + | Some((a, "blob.core.windows.net")) => { + self.account_name = Some(validate(a)?); + } + Some((a, "dfs.fabric.microsoft.com")) + | Some((a, "blob.fabric.microsoft.com")) => { + self.account_name = Some(validate(a)?); + // Attempt to infer the container name from the URL + // - https://onelake.dfs.fabric.microsoft.com///Files/test.csv + // - https://onelake.dfs.fabric.microsoft.com//.// + // + // See + if let Some(workspace) = parsed.path_segments().unwrap().next() { + if !workspace.is_empty() { + self.container_name = Some(workspace.to_string()) + } + } + self.use_fabric_endpoint = true.into(); + } + _ => return Err(UrlNotRecognisedSnafu { url }.build().into()), + }, + scheme => return Err(UnknownUrlSchemeSnafu { scheme }.build().into()), + } + Ok(()) + } + + /// Set the Azure Account (required) + pub fn with_account(mut self, account: impl Into) -> Self { + self.account_name = Some(account.into()); + self + } + + /// Set the Azure Container Name (required) + pub fn with_container_name(mut self, 
container_name: impl Into) -> Self { + self.container_name = Some(container_name.into()); + self + } + + /// Set the Azure Access Key (required - one of access key, bearer token, or client credentials) + pub fn with_access_key(mut self, access_key: impl Into) -> Self { + self.access_key = Some(access_key.into()); + self + } + + /// Set a static bearer token to be used for authorizing requests + pub fn with_bearer_token_authorization( + mut self, + bearer_token: impl Into, + ) -> Self { + self.bearer_token = Some(bearer_token.into()); + self + } + + /// Set a client secret used for client secret authorization + pub fn with_client_secret_authorization( + mut self, + client_id: impl Into, + client_secret: impl Into, + tenant_id: impl Into, + ) -> Self { + self.client_id = Some(client_id.into()); + self.client_secret = Some(client_secret.into()); + self.tenant_id = Some(tenant_id.into()); + self + } + + /// Sets the client id for use in client secret or k8s federated credential flow + pub fn with_client_id(mut self, client_id: impl Into) -> Self { + self.client_id = Some(client_id.into()); + self + } + + /// Sets the client secret for use in client secret flow + pub fn with_client_secret(mut self, client_secret: impl Into) -> Self { + self.client_secret = Some(client_secret.into()); + self + } + + /// Sets the tenant id for use in client secret or k8s federated credential flow + pub fn with_tenant_id(mut self, tenant_id: impl Into) -> Self { + self.tenant_id = Some(tenant_id.into()); + self + } + + /// Set query pairs appended to the url for shared access signature authorization + pub fn with_sas_authorization( + mut self, + query_pairs: impl Into>, + ) -> Self { + self.sas_query_pairs = Some(query_pairs.into()); + self + } + + /// Set the credential provider overriding any other options + pub fn with_credentials(mut self, credentials: AzureCredentialProvider) -> Self { + self.credentials = Some(credentials); + self + } + + /// Set if the Azure emulator should be used 
(defaults to false) + pub fn with_use_emulator(mut self, use_emulator: bool) -> Self { + self.use_emulator = use_emulator.into(); + self + } + + /// Override the endpoint used to communicate with blob storage + /// + /// Defaults to `https://{account}.blob.core.windows.net` + pub fn with_endpoint(mut self, endpoint: String) -> Self { + self.endpoint = Some(endpoint); + self + } + + /// Set if Microsoft Fabric url scheme should be used (defaults to false) + /// + /// When disabled the url scheme used is `https://{account}.blob.core.windows.net` + /// When enabled the url scheme used is `https://{account}.dfs.fabric.microsoft.com` + /// + /// Note: [`Self::with_endpoint`] will take precedence over this option + pub fn with_use_fabric_endpoint(mut self, use_fabric_endpoint: bool) -> Self { + self.use_fabric_endpoint = use_fabric_endpoint.into(); + self + } + + /// Sets what protocol is allowed + /// + /// If `allow_http` is : + /// * false (default): Only HTTPS are allowed + /// * true: HTTP and HTTPS are allowed + pub fn with_allow_http(mut self, allow_http: bool) -> Self { + self.client_options = self.client_options.with_allow_http(allow_http); + self + } + + /// Sets an alternative authority host for OAuth based authorization + /// + /// Common hosts for azure clouds are defined in [authority_hosts](crate::azure::authority_hosts). 
+ /// + /// Defaults to + pub fn with_authority_host(mut self, authority_host: impl Into) -> Self { + self.authority_host = Some(authority_host.into()); + self + } + + /// Set the retry configuration + pub fn with_retry(mut self, retry_config: RetryConfig) -> Self { + self.retry_config = retry_config; + self + } + + /// Set the proxy_url to be used by the underlying client + pub fn with_proxy_url(mut self, proxy_url: impl Into) -> Self { + self.client_options = self.client_options.with_proxy_url(proxy_url); + self + } + + /// Set a trusted proxy CA certificate + pub fn with_proxy_ca_certificate( + mut self, + proxy_ca_certificate: impl Into, + ) -> Self { + self.client_options = self + .client_options + .with_proxy_ca_certificate(proxy_ca_certificate); + self + } + + /// Set a list of hosts to exclude from proxy connections + pub fn with_proxy_excludes(mut self, proxy_excludes: impl Into) -> Self { + self.client_options = self.client_options.with_proxy_excludes(proxy_excludes); + self + } + + /// Sets the client options, overriding any already set + pub fn with_client_options(mut self, options: ClientOptions) -> Self { + self.client_options = options; + self + } + + /// Sets the endpoint for acquiring managed identity token + pub fn with_msi_endpoint(mut self, msi_endpoint: impl Into) -> Self { + self.msi_endpoint = Some(msi_endpoint.into()); + self + } + + /// Sets a file path for acquiring azure federated identity token in k8s + /// + /// requires `client_id` and `tenant_id` to be set + pub fn with_federated_token_file( + mut self, + federated_token_file: impl Into, + ) -> Self { + self.federated_token_file = Some(federated_token_file.into()); + self + } + + /// Set if the Azure Cli should be used for acquiring access token + /// + /// + pub fn with_use_azure_cli(mut self, use_azure_cli: bool) -> Self { + self.use_azure_cli = use_azure_cli.into(); + self + } + + /// Configure a connection to container with given name on Microsoft Azure Blob store. 
+ pub fn build(mut self) -> Result { + if let Some(url) = self.url.take() { + self.parse_url(&url)?; + } + + let container = self.container_name.ok_or(Error::MissingContainerName {})?; + + let static_creds = |credential: AzureCredential| -> AzureCredentialProvider { + Arc::new(StaticCredentialProvider::new(credential)) + }; + + let (is_emulator, storage_url, auth, account) = if self.use_emulator.get()? { + let account_name = self + .account_name + .unwrap_or_else(|| EMULATOR_ACCOUNT.to_string()); + // Allow overriding defaults. Values taken from + // from https://docs.rs/azure_storage/0.2.0/src/azure_storage/core/clients/storage_account_client.rs.html#129-141 + let url = url_from_env("AZURITE_BLOB_STORAGE_URL", "http://127.0.0.1:10000")?; + let account_key = self + .access_key + .unwrap_or_else(|| EMULATOR_ACCOUNT_KEY.to_string()); + + let credential = static_creds(AzureCredential::AccessKey(account_key)); + + self.client_options = self.client_options.with_allow_http(true); + (true, url, credential, account_name) + } else { + let account_name = self.account_name.ok_or(Error::MissingAccount {})?; + let account_url = match self.endpoint { + Some(account_url) => account_url, + None => match self.use_fabric_endpoint.get()? 
{ + true => { + format!("https://{}.blob.fabric.microsoft.com", &account_name) + } + false => format!("https://{}.blob.core.windows.net", &account_name), + }, + }; + + let url = Url::parse(&account_url) + .context(UnableToParseUrlSnafu { url: account_url })?; + + let credential = if let Some(credential) = self.credentials { + credential + } else if let Some(bearer_token) = self.bearer_token { + static_creds(AzureCredential::BearerToken(bearer_token)) + } else if let Some(access_key) = self.access_key { + static_creds(AzureCredential::AccessKey(access_key)) + } else if let (Some(client_id), Some(tenant_id), Some(federated_token_file)) = + (&self.client_id, &self.tenant_id, self.federated_token_file) + { + let client_credential = WorkloadIdentityOAuthProvider::new( + client_id, + federated_token_file, + tenant_id, + self.authority_host, + ); + Arc::new(TokenCredentialProvider::new( + client_credential, + self.client_options.client()?, + self.retry_config.clone(), + )) as _ + } else if let (Some(client_id), Some(client_secret), Some(tenant_id)) = + (&self.client_id, self.client_secret, &self.tenant_id) + { + let client_credential = ClientSecretOAuthProvider::new( + client_id.clone(), + client_secret, + tenant_id, + self.authority_host, + ); + Arc::new(TokenCredentialProvider::new( + client_credential, + self.client_options.client()?, + self.retry_config.clone(), + )) as _ + } else if let Some(query_pairs) = self.sas_query_pairs { + static_creds(AzureCredential::SASToken(query_pairs)) + } else if let Some(sas) = self.sas_key { + static_creds(AzureCredential::SASToken(split_sas(&sas)?)) + } else if self.use_azure_cli.get()? 
{ + Arc::new(AzureCliCredential::new()) as _ + } else { + let msi_credential = ImdsManagedIdentityProvider::new( + self.client_id, + self.object_id, + self.msi_resource_id, + self.msi_endpoint, + ); + Arc::new(TokenCredentialProvider::new( + msi_credential, + self.client_options.metadata_client()?, + self.retry_config.clone(), + )) as _ + }; + (false, url, credential, account_name) + }; + + let config = AzureConfig { + account, + is_emulator, + container, + retry_config: self.retry_config, + client_options: self.client_options, + service: storage_url, + credentials: auth, + }; + + let client = Arc::new(AzureClient::new(config)?); + + Ok(MicrosoftAzure { client }) + } +} + +/// Parses the contents of the environment variable `env_name` as a URL +/// if present, otherwise falls back to default_url +fn url_from_env(env_name: &str, default_url: &str) -> Result { + let url = match std::env::var(env_name) { + Ok(env_value) => { + Url::parse(&env_value).context(UnableToParseEmulatorUrlSnafu { + env_name, + env_value, + })? 
+ } + Err(_) => Url::parse(default_url).expect("Failed to parse default URL"), + }; + Ok(url) +} + +fn split_sas(sas: &str) -> Result, Error> { + let sas = percent_decode_str(sas) + .decode_utf8() + .context(DecodeSasKeySnafu {})?; + let kv_str_pairs = sas + .trim_start_matches('?') + .split('&') + .filter(|s| !s.chars().all(char::is_whitespace)); + let mut pairs = Vec::new(); + for kv_pair_str in kv_str_pairs { + let (k, v) = kv_pair_str + .trim() + .split_once('=') + .ok_or(Error::MissingSasComponent {})?; + pairs.push((k.into(), v.into())) + } + Ok(pairs) +} + +#[cfg(test)] +mod tests { + use super::*; + use std::collections::HashMap; + + #[test] + fn azure_blob_test_urls() { + let mut builder = MicrosoftAzureBuilder::new(); + builder + .parse_url("abfss://file_system@account.dfs.core.windows.net/") + .unwrap(); + assert_eq!(builder.account_name, Some("account".to_string())); + assert_eq!(builder.container_name, Some("file_system".to_string())); + assert!(!builder.use_fabric_endpoint.get().unwrap()); + + let mut builder = MicrosoftAzureBuilder::new(); + builder + .parse_url("abfss://file_system@account.dfs.fabric.microsoft.com/") + .unwrap(); + assert_eq!(builder.account_name, Some("account".to_string())); + assert_eq!(builder.container_name, Some("file_system".to_string())); + assert!(builder.use_fabric_endpoint.get().unwrap()); + + let mut builder = MicrosoftAzureBuilder::new(); + builder.parse_url("abfs://container/path").unwrap(); + assert_eq!(builder.container_name, Some("container".to_string())); + + let mut builder = MicrosoftAzureBuilder::new(); + builder.parse_url("az://container").unwrap(); + assert_eq!(builder.container_name, Some("container".to_string())); + + let mut builder = MicrosoftAzureBuilder::new(); + builder.parse_url("az://container/path").unwrap(); + assert_eq!(builder.container_name, Some("container".to_string())); + + let mut builder = MicrosoftAzureBuilder::new(); + builder + .parse_url("https://account.dfs.core.windows.net/") + 
.unwrap(); + assert_eq!(builder.account_name, Some("account".to_string())); + assert!(!builder.use_fabric_endpoint.get().unwrap()); + + let mut builder = MicrosoftAzureBuilder::new(); + builder + .parse_url("https://account.blob.core.windows.net/") + .unwrap(); + assert_eq!(builder.account_name, Some("account".to_string())); + assert!(!builder.use_fabric_endpoint.get().unwrap()); + + let mut builder = MicrosoftAzureBuilder::new(); + builder + .parse_url("https://account.dfs.fabric.microsoft.com/") + .unwrap(); + assert_eq!(builder.account_name, Some("account".to_string())); + assert_eq!(builder.container_name, None); + assert!(builder.use_fabric_endpoint.get().unwrap()); + + let mut builder = MicrosoftAzureBuilder::new(); + builder + .parse_url("https://account.dfs.fabric.microsoft.com/container") + .unwrap(); + assert_eq!(builder.account_name, Some("account".to_string())); + assert_eq!(builder.container_name.as_deref(), Some("container")); + assert!(builder.use_fabric_endpoint.get().unwrap()); + + let mut builder = MicrosoftAzureBuilder::new(); + builder + .parse_url("https://account.blob.fabric.microsoft.com/") + .unwrap(); + assert_eq!(builder.account_name, Some("account".to_string())); + assert_eq!(builder.container_name, None); + assert!(builder.use_fabric_endpoint.get().unwrap()); + + let mut builder = MicrosoftAzureBuilder::new(); + builder + .parse_url("https://account.blob.fabric.microsoft.com/container") + .unwrap(); + assert_eq!(builder.account_name, Some("account".to_string())); + assert_eq!(builder.container_name.as_deref(), Some("container")); + assert!(builder.use_fabric_endpoint.get().unwrap()); + + let err_cases = [ + "mailto://account.blob.core.windows.net/", + "az://blob.mydomain/", + "abfs://container.foo/path", + "abfss://file_system@account.foo.dfs.core.windows.net/", + "abfss://file_system.bar@account.dfs.core.windows.net/", + "https://blob.mydomain/", + "https://blob.foo.dfs.core.windows.net/", + ]; + let mut builder = 
MicrosoftAzureBuilder::new(); + for case in err_cases { + builder.parse_url(case).unwrap_err(); + } + } + + #[test] + fn azure_test_config_from_map() { + let azure_client_id = "object_store:fake_access_key_id"; + let azure_storage_account_name = "object_store:fake_secret_key"; + let azure_storage_token = "object_store:fake_default_region"; + let options = HashMap::from([ + ("azure_client_id", azure_client_id), + ("azure_storage_account_name", azure_storage_account_name), + ("azure_storage_token", azure_storage_token), + ]); + + let builder = options + .into_iter() + .fold(MicrosoftAzureBuilder::new(), |builder, (key, value)| { + builder.with_config(key.parse().unwrap(), value) + }); + assert_eq!(builder.client_id.unwrap(), azure_client_id); + assert_eq!(builder.account_name.unwrap(), azure_storage_account_name); + assert_eq!(builder.bearer_token.unwrap(), azure_storage_token); + } + + #[test] + fn azure_test_split_sas() { + let raw_sas = "?sv=2021-10-04&st=2023-01-04T17%3A48%3A57Z&se=2023-01-04T18%3A15%3A00Z&sr=c&sp=rcwl&sig=C7%2BZeEOWbrxPA3R0Cw%2Fw1EZz0%2B4KBvQexeKZKe%2BB6h0%3D"; + let expected = vec![ + ("sv".to_string(), "2021-10-04".to_string()), + ("st".to_string(), "2023-01-04T17:48:57Z".to_string()), + ("se".to_string(), "2023-01-04T18:15:00Z".to_string()), + ("sr".to_string(), "c".to_string()), + ("sp".to_string(), "rcwl".to_string()), + ( + "sig".to_string(), + "C7+ZeEOWbrxPA3R0Cw/w1EZz0+4KBvQexeKZKe+B6h0=".to_string(), + ), + ]; + let pairs = split_sas(raw_sas).unwrap(); + assert_eq!(expected, pairs); + } +} diff --git a/object_store/src/azure/mod.rs b/object_store/src/azure/mod.rs index 0e638efc399f..7e1db5bc8c1c 100644 --- a/object_store/src/azure/mod.rs +++ b/object_store/src/azure/mod.rs @@ -30,32 +30,24 @@ use self::client::{BlockId, BlockList}; use crate::{ multipart::{PartId, PutPart, WriteMultiPart}, path::Path, - ClientOptions, GetOptions, GetResult, ListResult, MultipartId, ObjectMeta, - ObjectStore, PutResult, Result, RetryConfig, + GetOptions, 
GetResult, ListResult, MultipartId, ObjectMeta, ObjectStore, PutResult, + Result, }; use async_trait::async_trait; use base64::prelude::BASE64_STANDARD; use base64::Engine; use bytes::Bytes; use futures::stream::BoxStream; -use percent_encoding::percent_decode_str; -use serde::{Deserialize, Serialize}; -use snafu::{OptionExt, ResultExt, Snafu}; -use std::fmt::{Debug, Formatter}; -use std::str::FromStr; +use std::fmt::Debug; use std::sync::Arc; use tokio::io::AsyncWrite; -use url::Url; use crate::client::get::GetClientExt; use crate::client::list::ListClientExt; -use crate::client::{ - ClientConfigKey, CredentialProvider, StaticCredentialProvider, - TokenCredentialProvider, -}; -use crate::config::ConfigValue; +use crate::client::CredentialProvider; pub use credential::authority_hosts; +mod builder; mod client; mod credential; @@ -63,87 +55,11 @@ mod credential; pub type AzureCredentialProvider = Arc>; use crate::client::header::get_etag; +pub use builder::{AzureConfigKey, MicrosoftAzureBuilder}; pub use credential::AzureCredential; const STORE: &str = "MicrosoftAzure"; -/// The well-known account used by Azurite and the legacy Azure Storage Emulator. -/// -const EMULATOR_ACCOUNT: &str = "devstoreaccount1"; - -/// The well-known account key used by Azurite and the legacy Azure Storage Emulator. -/// -const EMULATOR_ACCOUNT_KEY: &str = - "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw=="; - -const MSI_ENDPOINT_ENV_KEY: &str = "IDENTITY_ENDPOINT"; - -/// A specialized `Error` for Azure object store-related errors -#[derive(Debug, Snafu)] -#[allow(missing_docs)] -enum Error { - #[snafu(display("Unable parse source url. 
Url: {}, Error: {}", url, source))] - UnableToParseUrl { - source: url::ParseError, - url: String, - }, - - #[snafu(display( - "Unable parse emulator url {}={}, Error: {}", - env_name, - env_value, - source - ))] - UnableToParseEmulatorUrl { - env_name: String, - env_value: String, - source: url::ParseError, - }, - - #[snafu(display("Account must be specified"))] - MissingAccount {}, - - #[snafu(display("Container name must be specified"))] - MissingContainerName {}, - - #[snafu(display( - "Unknown url scheme cannot be parsed into storage location: {}", - scheme - ))] - UnknownUrlScheme { scheme: String }, - - #[snafu(display("URL did not match any known pattern for scheme: {}", url))] - UrlNotRecognised { url: String }, - - #[snafu(display("Failed parsing an SAS key"))] - DecodeSasKey { source: std::str::Utf8Error }, - - #[snafu(display("Missing component in SAS query pair"))] - MissingSasComponent {}, - - #[snafu(display("Configuration key: '{}' is not known.", key))] - UnknownConfigurationKey { key: String }, - - #[snafu(display("Unable to extract metadata from headers: {}", source))] - Metadata { - source: crate::client::header::Error, - }, -} - -impl From for super::Error { - fn from(source: Error) -> Self { - match source { - Error::UnknownConfigurationKey { key } => { - Self::UnknownConfigurationKey { store: STORE, key } - } - _ => Self::Generic { - store: STORE, - source: Box::new(source), - }, - } - } -} - /// Interface for [Microsoft Azure Blob Storage](https://azure.microsoft.com/en-us/services/storage/blobs/). 
#[derive(Debug)] pub struct MicrosoftAzure { @@ -175,8 +91,11 @@ impl ObjectStore for MicrosoftAzure { .client .put_request(location, Some(bytes), false, &()) .await?; - let e_tag = Some(get_etag(response.headers()).context(MetadataSnafu)?); - Ok(PutResult { e_tag }) + let e_tag = get_etag(response.headers()).map_err(|e| crate::Error::Generic { + store: STORE, + source: Box::new(e), + })?; + Ok(PutResult { e_tag: Some(e_tag) }) } async fn put_multipart( @@ -279,853 +198,6 @@ impl PutPart for AzureMultiPartUpload { } } -/// Configure a connection to Microsoft Azure Blob Storage container using -/// the specified credentials. -/// -/// # Example -/// ``` -/// # let ACCOUNT = "foo"; -/// # let BUCKET_NAME = "foo"; -/// # let ACCESS_KEY = "foo"; -/// # use object_store::azure::MicrosoftAzureBuilder; -/// let azure = MicrosoftAzureBuilder::new() -/// .with_account(ACCOUNT) -/// .with_access_key(ACCESS_KEY) -/// .with_container_name(BUCKET_NAME) -/// .build(); -/// ``` -#[derive(Default, Clone)] -pub struct MicrosoftAzureBuilder { - /// Account name - account_name: Option, - /// Access key - access_key: Option, - /// Container name - container_name: Option, - /// Bearer token - bearer_token: Option, - /// Client id - client_id: Option, - /// Client secret - client_secret: Option, - /// Tenant id - tenant_id: Option, - /// Query pairs for shared access signature authorization - sas_query_pairs: Option>, - /// Shared access signature - sas_key: Option, - /// Authority host - authority_host: Option, - /// Url - url: Option, - /// When set to true, azurite storage emulator has to be used - use_emulator: ConfigValue, - /// Storage endpoint - endpoint: Option, - /// Msi endpoint for acquiring managed identity token - msi_endpoint: Option, - /// Object id for use with managed identity authentication - object_id: Option, - /// Msi resource id for use with managed identity authentication - msi_resource_id: Option, - /// File containing token for Azure AD workload identity 
federation - federated_token_file: Option, - /// When set to true, azure cli has to be used for acquiring access token - use_azure_cli: ConfigValue, - /// Retry config - retry_config: RetryConfig, - /// Client options - client_options: ClientOptions, - /// Credentials - credentials: Option, - /// When set to true, fabric url scheme will be used - /// - /// i.e. https://{account_name}.dfs.fabric.microsoft.com - use_fabric_endpoint: ConfigValue, -} - -/// Configuration keys for [`MicrosoftAzureBuilder`] -/// -/// Configuration via keys can be done via [`MicrosoftAzureBuilder::with_config`] -/// -/// # Example -/// ``` -/// # use object_store::azure::{MicrosoftAzureBuilder, AzureConfigKey}; -/// let builder = MicrosoftAzureBuilder::new() -/// .with_config("azure_client_id".parse().unwrap(), "my-client-id") -/// .with_config(AzureConfigKey::AuthorityId, "my-tenant-id"); -/// ``` -#[derive(PartialEq, Eq, Hash, Clone, Debug, Copy, Deserialize, Serialize)] -#[non_exhaustive] -pub enum AzureConfigKey { - /// The name of the azure storage account - /// - /// Supported keys: - /// - `azure_storage_account_name` - /// - `account_name` - AccountName, - - /// Master key for accessing storage account - /// - /// Supported keys: - /// - `azure_storage_account_key` - /// - `azure_storage_access_key` - /// - `azure_storage_master_key` - /// - `access_key` - /// - `account_key` - /// - `master_key` - AccessKey, - - /// Service principal client id for authorizing requests - /// - /// Supported keys: - /// - `azure_storage_client_id` - /// - `azure_client_id` - /// - `client_id` - ClientId, - - /// Service principal client secret for authorizing requests - /// - /// Supported keys: - /// - `azure_storage_client_secret` - /// - `azure_client_secret` - /// - `client_secret` - ClientSecret, - - /// Tenant id used in oauth flows - /// - /// Supported keys: - /// - `azure_storage_tenant_id` - /// - `azure_storage_authority_id` - /// - `azure_tenant_id` - /// - `azure_authority_id` - /// - 
`tenant_id` - /// - `authority_id` - AuthorityId, - - /// Shared access signature. - /// - /// The signature is expected to be percent-encoded, much like they are provided - /// in the azure storage explorer or azure portal. - /// - /// Supported keys: - /// - `azure_storage_sas_key` - /// - `azure_storage_sas_token` - /// - `sas_key` - /// - `sas_token` - SasKey, - - /// Bearer token - /// - /// Supported keys: - /// - `azure_storage_token` - /// - `bearer_token` - /// - `token` - Token, - - /// Use object store with azurite storage emulator - /// - /// Supported keys: - /// - `azure_storage_use_emulator` - /// - `object_store_use_emulator` - /// - `use_emulator` - UseEmulator, - - /// Override the endpoint used to communicate with blob storage - /// - /// Supported keys: - /// - `azure_storage_endpoint` - /// - `azure_endpoint` - /// - `endpoint` - Endpoint, - - /// Use object store with url scheme account.dfs.fabric.microsoft.com - /// - /// Supported keys: - /// - `azure_use_fabric_endpoint` - /// - `use_fabric_endpoint` - UseFabricEndpoint, - - /// Endpoint to request a imds managed identity token - /// - /// Supported keys: - /// - `azure_msi_endpoint` - /// - `azure_identity_endpoint` - /// - `identity_endpoint` - /// - `msi_endpoint` - MsiEndpoint, - - /// Object id for use with managed identity authentication - /// - /// Supported keys: - /// - `azure_object_id` - /// - `object_id` - ObjectId, - - /// Msi resource id for use with managed identity authentication - /// - /// Supported keys: - /// - `azure_msi_resource_id` - /// - `msi_resource_id` - MsiResourceId, - - /// File containing token for Azure AD workload identity federation - /// - /// Supported keys: - /// - `azure_federated_token_file` - /// - `federated_token_file` - FederatedTokenFile, - - /// Use azure cli for acquiring access token - /// - /// Supported keys: - /// - `azure_use_azure_cli` - /// - `use_azure_cli` - UseAzureCli, - - /// Container name - /// - /// Supported keys: - /// - 
`azure_container_name` - /// - `container_name` - ContainerName, - - /// Client options - Client(ClientConfigKey), -} - -impl AsRef for AzureConfigKey { - fn as_ref(&self) -> &str { - match self { - Self::AccountName => "azure_storage_account_name", - Self::AccessKey => "azure_storage_account_key", - Self::ClientId => "azure_storage_client_id", - Self::ClientSecret => "azure_storage_client_secret", - Self::AuthorityId => "azure_storage_tenant_id", - Self::SasKey => "azure_storage_sas_key", - Self::Token => "azure_storage_token", - Self::UseEmulator => "azure_storage_use_emulator", - Self::UseFabricEndpoint => "azure_use_fabric_endpoint", - Self::Endpoint => "azure_storage_endpoint", - Self::MsiEndpoint => "azure_msi_endpoint", - Self::ObjectId => "azure_object_id", - Self::MsiResourceId => "azure_msi_resource_id", - Self::FederatedTokenFile => "azure_federated_token_file", - Self::UseAzureCli => "azure_use_azure_cli", - Self::ContainerName => "azure_container_name", - Self::Client(key) => key.as_ref(), - } - } -} - -impl FromStr for AzureConfigKey { - type Err = super::Error; - - fn from_str(s: &str) -> Result { - match s { - "azure_storage_account_key" - | "azure_storage_access_key" - | "azure_storage_master_key" - | "master_key" - | "account_key" - | "access_key" => Ok(Self::AccessKey), - "azure_storage_account_name" | "account_name" => Ok(Self::AccountName), - "azure_storage_client_id" | "azure_client_id" | "client_id" => { - Ok(Self::ClientId) - } - "azure_storage_client_secret" | "azure_client_secret" | "client_secret" => { - Ok(Self::ClientSecret) - } - "azure_storage_tenant_id" - | "azure_storage_authority_id" - | "azure_tenant_id" - | "azure_authority_id" - | "tenant_id" - | "authority_id" => Ok(Self::AuthorityId), - "azure_storage_sas_key" - | "azure_storage_sas_token" - | "sas_key" - | "sas_token" => Ok(Self::SasKey), - "azure_storage_token" | "bearer_token" | "token" => Ok(Self::Token), - "azure_storage_use_emulator" | "use_emulator" => 
Ok(Self::UseEmulator), - "azure_storage_endpoint" | "azure_endpoint" | "endpoint" => { - Ok(Self::Endpoint) - } - "azure_msi_endpoint" - | "azure_identity_endpoint" - | "identity_endpoint" - | "msi_endpoint" => Ok(Self::MsiEndpoint), - "azure_object_id" | "object_id" => Ok(Self::ObjectId), - "azure_msi_resource_id" | "msi_resource_id" => Ok(Self::MsiResourceId), - "azure_federated_token_file" | "federated_token_file" => { - Ok(Self::FederatedTokenFile) - } - "azure_use_fabric_endpoint" | "use_fabric_endpoint" => { - Ok(Self::UseFabricEndpoint) - } - "azure_use_azure_cli" | "use_azure_cli" => Ok(Self::UseAzureCli), - "azure_container_name" | "container_name" => Ok(Self::ContainerName), - // Backwards compatibility - "azure_allow_http" => Ok(Self::Client(ClientConfigKey::AllowHttp)), - _ => match s.parse() { - Ok(key) => Ok(Self::Client(key)), - Err(_) => Err(Error::UnknownConfigurationKey { key: s.into() }.into()), - }, - } - } -} - -impl Debug for MicrosoftAzureBuilder { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!( - f, - "MicrosoftAzureBuilder {{ account: {:?}, container_name: {:?} }}", - self.account_name, self.container_name - ) - } -} - -impl MicrosoftAzureBuilder { - /// Create a new [`MicrosoftAzureBuilder`] with default values. - pub fn new() -> Self { - Default::default() - } - - /// Create an instance of [`MicrosoftAzureBuilder`] with values pre-populated from environment variables. 
- /// - /// Variables extracted from environment: - /// * AZURE_STORAGE_ACCOUNT_NAME: storage account name - /// * AZURE_STORAGE_ACCOUNT_KEY: storage account master key - /// * AZURE_STORAGE_ACCESS_KEY: alias for AZURE_STORAGE_ACCOUNT_KEY - /// * AZURE_STORAGE_CLIENT_ID -> client id for service principal authorization - /// * AZURE_STORAGE_CLIENT_SECRET -> client secret for service principal authorization - /// * AZURE_STORAGE_TENANT_ID -> tenant id used in oauth flows - /// # Example - /// ``` - /// use object_store::azure::MicrosoftAzureBuilder; - /// - /// let azure = MicrosoftAzureBuilder::from_env() - /// .with_container_name("foo") - /// .build(); - /// ``` - pub fn from_env() -> Self { - let mut builder = Self::default(); - for (os_key, os_value) in std::env::vars_os() { - if let (Some(key), Some(value)) = (os_key.to_str(), os_value.to_str()) { - if key.starts_with("AZURE_") { - if let Ok(config_key) = key.to_ascii_lowercase().parse() { - builder = builder.with_config(config_key, value); - } - } - } - } - - if let Ok(text) = std::env::var(MSI_ENDPOINT_ENV_KEY) { - builder = builder.with_msi_endpoint(text); - } - - builder - } - - /// Parse available connection info form a well-known storage URL. 
- /// - /// The supported url schemes are: - /// - /// - `abfs[s]:///` (according to [fsspec](https://github.com/fsspec/adlfs)) - /// - `abfs[s]://@.dfs.core.windows.net/` - /// - `abfs[s]://@.dfs.fabric.microsoft.com/` - /// - `az:///` (according to [fsspec](https://github.com/fsspec/adlfs)) - /// - `adl:///` (according to [fsspec](https://github.com/fsspec/adlfs)) - /// - `azure:///` (custom) - /// - `https://.dfs.core.windows.net` - /// - `https://.blob.core.windows.net` - /// - `https://.dfs.fabric.microsoft.com` - /// - `https://.dfs.fabric.microsoft.com/` - /// - `https://.blob.fabric.microsoft.com` - /// - `https://.blob.fabric.microsoft.com/` - /// - /// Note: Settings derived from the URL will override any others set on this builder - /// - /// # Example - /// ``` - /// use object_store::azure::MicrosoftAzureBuilder; - /// - /// let azure = MicrosoftAzureBuilder::from_env() - /// .with_url("abfss://file_system@account.dfs.core.windows.net/") - /// .build(); - /// ``` - pub fn with_url(mut self, url: impl Into) -> Self { - self.url = Some(url.into()); - self - } - - /// Set an option on the builder via a key - value pair. 
- pub fn with_config(mut self, key: AzureConfigKey, value: impl Into) -> Self { - match key { - AzureConfigKey::AccessKey => self.access_key = Some(value.into()), - AzureConfigKey::AccountName => self.account_name = Some(value.into()), - AzureConfigKey::ClientId => self.client_id = Some(value.into()), - AzureConfigKey::ClientSecret => self.client_secret = Some(value.into()), - AzureConfigKey::AuthorityId => self.tenant_id = Some(value.into()), - AzureConfigKey::SasKey => self.sas_key = Some(value.into()), - AzureConfigKey::Token => self.bearer_token = Some(value.into()), - AzureConfigKey::MsiEndpoint => self.msi_endpoint = Some(value.into()), - AzureConfigKey::ObjectId => self.object_id = Some(value.into()), - AzureConfigKey::MsiResourceId => self.msi_resource_id = Some(value.into()), - AzureConfigKey::FederatedTokenFile => { - self.federated_token_file = Some(value.into()) - } - AzureConfigKey::UseAzureCli => self.use_azure_cli.parse(value), - AzureConfigKey::UseEmulator => self.use_emulator.parse(value), - AzureConfigKey::Endpoint => self.endpoint = Some(value.into()), - AzureConfigKey::UseFabricEndpoint => self.use_fabric_endpoint.parse(value), - AzureConfigKey::Client(key) => { - self.client_options = self.client_options.with_config(key, value) - } - AzureConfigKey::ContainerName => self.container_name = Some(value.into()), - }; - self - } - - /// Set an option on the builder via a key - value pair. - #[deprecated(note = "Use with_config")] - pub fn try_with_option( - self, - key: impl AsRef, - value: impl Into, - ) -> Result { - Ok(self.with_config(key.as_ref().parse()?, value)) - } - - /// Hydrate builder from key value pairs - #[deprecated(note = "Use with_config")] - #[allow(deprecated)] - pub fn try_with_options< - I: IntoIterator, impl Into)>, - >( - mut self, - options: I, - ) -> Result { - for (key, value) in options { - self = self.try_with_option(key, value)?; - } - Ok(self) - } - - /// Get config value via a [`AzureConfigKey`]. 
- /// - /// # Example - /// ``` - /// use object_store::azure::{MicrosoftAzureBuilder, AzureConfigKey}; - /// - /// let builder = MicrosoftAzureBuilder::from_env() - /// .with_account("foo"); - /// let account_name = builder.get_config_value(&AzureConfigKey::AccountName).unwrap_or_default(); - /// assert_eq!("foo", &account_name); - /// ``` - pub fn get_config_value(&self, key: &AzureConfigKey) -> Option { - match key { - AzureConfigKey::AccountName => self.account_name.clone(), - AzureConfigKey::AccessKey => self.access_key.clone(), - AzureConfigKey::ClientId => self.client_id.clone(), - AzureConfigKey::ClientSecret => self.client_secret.clone(), - AzureConfigKey::AuthorityId => self.tenant_id.clone(), - AzureConfigKey::SasKey => self.sas_key.clone(), - AzureConfigKey::Token => self.bearer_token.clone(), - AzureConfigKey::UseEmulator => Some(self.use_emulator.to_string()), - AzureConfigKey::UseFabricEndpoint => { - Some(self.use_fabric_endpoint.to_string()) - } - AzureConfigKey::Endpoint => self.endpoint.clone(), - AzureConfigKey::MsiEndpoint => self.msi_endpoint.clone(), - AzureConfigKey::ObjectId => self.object_id.clone(), - AzureConfigKey::MsiResourceId => self.msi_resource_id.clone(), - AzureConfigKey::FederatedTokenFile => self.federated_token_file.clone(), - AzureConfigKey::UseAzureCli => Some(self.use_azure_cli.to_string()), - AzureConfigKey::Client(key) => self.client_options.get_config_value(key), - AzureConfigKey::ContainerName => self.container_name.clone(), - } - } - - /// Sets properties on this builder based on a URL - /// - /// This is a separate member function to allow fallible computation to - /// be deferred until [`Self::build`] which in turn allows deriving [`Clone`] - fn parse_url(&mut self, url: &str) -> Result<()> { - let parsed = Url::parse(url).context(UnableToParseUrlSnafu { url })?; - let host = parsed.host_str().context(UrlNotRecognisedSnafu { url })?; - - let validate = |s: &str| match s.contains('.') { - true => 
Err(UrlNotRecognisedSnafu { url }.build()), - false => Ok(s.to_string()), - }; - - match parsed.scheme() { - "az" | "adl" | "azure" => self.container_name = Some(validate(host)?), - "abfs" | "abfss" => { - // abfs(s) might refer to the fsspec convention abfs:/// - // or the convention for the hadoop driver abfs[s]://@.dfs.core.windows.net/ - if parsed.username().is_empty() { - self.container_name = Some(validate(host)?); - } else if let Some(a) = host.strip_suffix(".dfs.core.windows.net") { - self.container_name = Some(validate(parsed.username())?); - self.account_name = Some(validate(a)?); - } else if let Some(a) = host.strip_suffix(".dfs.fabric.microsoft.com") { - self.container_name = Some(validate(parsed.username())?); - self.account_name = Some(validate(a)?); - self.use_fabric_endpoint = true.into(); - } else { - return Err(UrlNotRecognisedSnafu { url }.build().into()); - } - } - "https" => match host.split_once('.') { - Some((a, "dfs.core.windows.net")) - | Some((a, "blob.core.windows.net")) => { - self.account_name = Some(validate(a)?); - } - Some((a, "dfs.fabric.microsoft.com")) - | Some((a, "blob.fabric.microsoft.com")) => { - self.account_name = Some(validate(a)?); - // Attempt to infer the container name from the URL - // - https://onelake.dfs.fabric.microsoft.com///Files/test.csv - // - https://onelake.dfs.fabric.microsoft.com//.// - // - // See - if let Some(workspace) = parsed.path_segments().unwrap().next() { - if !workspace.is_empty() { - self.container_name = Some(workspace.to_string()) - } - } - self.use_fabric_endpoint = true.into(); - } - _ => return Err(UrlNotRecognisedSnafu { url }.build().into()), - }, - scheme => return Err(UnknownUrlSchemeSnafu { scheme }.build().into()), - } - Ok(()) - } - - /// Set the Azure Account (required) - pub fn with_account(mut self, account: impl Into) -> Self { - self.account_name = Some(account.into()); - self - } - - /// Set the Azure Container Name (required) - pub fn with_container_name(mut self, 
container_name: impl Into) -> Self { - self.container_name = Some(container_name.into()); - self - } - - /// Set the Azure Access Key (required - one of access key, bearer token, or client credentials) - pub fn with_access_key(mut self, access_key: impl Into) -> Self { - self.access_key = Some(access_key.into()); - self - } - - /// Set a static bearer token to be used for authorizing requests - pub fn with_bearer_token_authorization( - mut self, - bearer_token: impl Into, - ) -> Self { - self.bearer_token = Some(bearer_token.into()); - self - } - - /// Set a client secret used for client secret authorization - pub fn with_client_secret_authorization( - mut self, - client_id: impl Into, - client_secret: impl Into, - tenant_id: impl Into, - ) -> Self { - self.client_id = Some(client_id.into()); - self.client_secret = Some(client_secret.into()); - self.tenant_id = Some(tenant_id.into()); - self - } - - /// Sets the client id for use in client secret or k8s federated credential flow - pub fn with_client_id(mut self, client_id: impl Into) -> Self { - self.client_id = Some(client_id.into()); - self - } - - /// Sets the client secret for use in client secret flow - pub fn with_client_secret(mut self, client_secret: impl Into) -> Self { - self.client_secret = Some(client_secret.into()); - self - } - - /// Sets the tenant id for use in client secret or k8s federated credential flow - pub fn with_tenant_id(mut self, tenant_id: impl Into) -> Self { - self.tenant_id = Some(tenant_id.into()); - self - } - - /// Set query pairs appended to the url for shared access signature authorization - pub fn with_sas_authorization( - mut self, - query_pairs: impl Into>, - ) -> Self { - self.sas_query_pairs = Some(query_pairs.into()); - self - } - - /// Set the credential provider overriding any other options - pub fn with_credentials(mut self, credentials: AzureCredentialProvider) -> Self { - self.credentials = Some(credentials); - self - } - - /// Set if the Azure emulator should be used 
(defaults to false) - pub fn with_use_emulator(mut self, use_emulator: bool) -> Self { - self.use_emulator = use_emulator.into(); - self - } - - /// Override the endpoint used to communicate with blob storage - /// - /// Defaults to `https://{account}.blob.core.windows.net` - pub fn with_endpoint(mut self, endpoint: String) -> Self { - self.endpoint = Some(endpoint); - self - } - - /// Set if Microsoft Fabric url scheme should be used (defaults to false) - /// When disabled the url scheme used is `https://{account}.blob.core.windows.net` - /// When enabled the url scheme used is `https://{account}.dfs.fabric.microsoft.com` - /// - /// Note: [`Self::with_endpoint`] will take precedence over this option - pub fn with_use_fabric_endpoint(mut self, use_fabric_endpoint: bool) -> Self { - self.use_fabric_endpoint = use_fabric_endpoint.into(); - self - } - - /// Sets what protocol is allowed. If `allow_http` is : - /// * false (default): Only HTTPS are allowed - /// * true: HTTP and HTTPS are allowed - pub fn with_allow_http(mut self, allow_http: bool) -> Self { - self.client_options = self.client_options.with_allow_http(allow_http); - self - } - - /// Sets an alternative authority host for OAuth based authorization - /// common hosts for azure clouds are defined in [authority_hosts]. 
- /// Defaults to - pub fn with_authority_host(mut self, authority_host: impl Into) -> Self { - self.authority_host = Some(authority_host.into()); - self - } - - /// Set the retry configuration - pub fn with_retry(mut self, retry_config: RetryConfig) -> Self { - self.retry_config = retry_config; - self - } - - /// Set the proxy_url to be used by the underlying client - pub fn with_proxy_url(mut self, proxy_url: impl Into) -> Self { - self.client_options = self.client_options.with_proxy_url(proxy_url); - self - } - - /// Set a trusted proxy CA certificate - pub fn with_proxy_ca_certificate( - mut self, - proxy_ca_certificate: impl Into, - ) -> Self { - self.client_options = self - .client_options - .with_proxy_ca_certificate(proxy_ca_certificate); - self - } - - /// Set a list of hosts to exclude from proxy connections - pub fn with_proxy_excludes(mut self, proxy_excludes: impl Into) -> Self { - self.client_options = self.client_options.with_proxy_excludes(proxy_excludes); - self - } - - /// Sets the client options, overriding any already set - pub fn with_client_options(mut self, options: ClientOptions) -> Self { - self.client_options = options; - self - } - - /// Sets the endpoint for acquiring managed identity token - pub fn with_msi_endpoint(mut self, msi_endpoint: impl Into) -> Self { - self.msi_endpoint = Some(msi_endpoint.into()); - self - } - - /// Sets a file path for acquiring azure federated identity token in k8s - /// - /// requires `client_id` and `tenant_id` to be set - pub fn with_federated_token_file( - mut self, - federated_token_file: impl Into, - ) -> Self { - self.federated_token_file = Some(federated_token_file.into()); - self - } - - /// Set if the Azure Cli should be used for acquiring access token - /// - pub fn with_use_azure_cli(mut self, use_azure_cli: bool) -> Self { - self.use_azure_cli = use_azure_cli.into(); - self - } - - /// Configure a connection to container with given name on Microsoft Azure - /// Blob store. 
- pub fn build(mut self) -> Result { - if let Some(url) = self.url.take() { - self.parse_url(&url)?; - } - - let container = self.container_name.ok_or(Error::MissingContainerName {})?; - - let static_creds = |credential: AzureCredential| -> AzureCredentialProvider { - Arc::new(StaticCredentialProvider::new(credential)) - }; - - let (is_emulator, storage_url, auth, account) = if self.use_emulator.get()? { - let account_name = self - .account_name - .unwrap_or_else(|| EMULATOR_ACCOUNT.to_string()); - // Allow overriding defaults. Values taken from - // from https://docs.rs/azure_storage/0.2.0/src/azure_storage/core/clients/storage_account_client.rs.html#129-141 - let url = url_from_env("AZURITE_BLOB_STORAGE_URL", "http://127.0.0.1:10000")?; - let account_key = self - .access_key - .unwrap_or_else(|| EMULATOR_ACCOUNT_KEY.to_string()); - - let credential = static_creds(AzureCredential::AccessKey(account_key)); - - self.client_options = self.client_options.with_allow_http(true); - (true, url, credential, account_name) - } else { - let account_name = self.account_name.ok_or(Error::MissingAccount {})?; - let account_url = match self.endpoint { - Some(account_url) => account_url, - None => match self.use_fabric_endpoint.get()? 
{ - true => { - format!("https://{}.blob.fabric.microsoft.com", &account_name) - } - false => format!("https://{}.blob.core.windows.net", &account_name), - }, - }; - - let url = Url::parse(&account_url) - .context(UnableToParseUrlSnafu { url: account_url })?; - - let credential = if let Some(credential) = self.credentials { - credential - } else if let Some(bearer_token) = self.bearer_token { - static_creds(AzureCredential::BearerToken(bearer_token)) - } else if let Some(access_key) = self.access_key { - static_creds(AzureCredential::AccessKey(access_key)) - } else if let (Some(client_id), Some(tenant_id), Some(federated_token_file)) = - (&self.client_id, &self.tenant_id, self.federated_token_file) - { - let client_credential = credential::WorkloadIdentityOAuthProvider::new( - client_id, - federated_token_file, - tenant_id, - self.authority_host, - ); - Arc::new(TokenCredentialProvider::new( - client_credential, - self.client_options.client()?, - self.retry_config.clone(), - )) as _ - } else if let (Some(client_id), Some(client_secret), Some(tenant_id)) = - (&self.client_id, self.client_secret, &self.tenant_id) - { - let client_credential = credential::ClientSecretOAuthProvider::new( - client_id.clone(), - client_secret, - tenant_id, - self.authority_host, - ); - Arc::new(TokenCredentialProvider::new( - client_credential, - self.client_options.client()?, - self.retry_config.clone(), - )) as _ - } else if let Some(query_pairs) = self.sas_query_pairs { - static_creds(AzureCredential::SASToken(query_pairs)) - } else if let Some(sas) = self.sas_key { - static_creds(AzureCredential::SASToken(split_sas(&sas)?)) - } else if self.use_azure_cli.get()? 
{ - Arc::new(credential::AzureCliCredential::new()) as _ - } else { - let msi_credential = credential::ImdsManagedIdentityProvider::new( - self.client_id, - self.object_id, - self.msi_resource_id, - self.msi_endpoint, - ); - Arc::new(TokenCredentialProvider::new( - msi_credential, - self.client_options.metadata_client()?, - self.retry_config.clone(), - )) as _ - }; - (false, url, credential, account_name) - }; - - let config = client::AzureConfig { - account, - is_emulator, - container, - retry_config: self.retry_config, - client_options: self.client_options, - service: storage_url, - credentials: auth, - }; - - let client = Arc::new(client::AzureClient::new(config)?); - - Ok(MicrosoftAzure { client }) - } -} - -/// Parses the contents of the environment variable `env_name` as a URL -/// if present, otherwise falls back to default_url -fn url_from_env(env_name: &str, default_url: &str) -> Result { - let url = match std::env::var(env_name) { - Ok(env_value) => { - Url::parse(&env_value).context(UnableToParseEmulatorUrlSnafu { - env_name, - env_value, - })? 
- } - Err(_) => Url::parse(default_url).expect("Failed to parse default URL"), - }; - Ok(url) -} - -fn split_sas(sas: &str) -> Result, Error> { - let sas = percent_decode_str(sas) - .decode_utf8() - .context(DecodeSasKeySnafu {})?; - let kv_str_pairs = sas - .trim_start_matches('?') - .split('&') - .filter(|s| !s.chars().all(char::is_whitespace)); - let mut pairs = Vec::new(); - for kv_pair_str in kv_str_pairs { - let (k, v) = kv_pair_str - .trim() - .split_once('=') - .ok_or(Error::MissingSasComponent {})?; - pairs.push((k.into(), v.into())) - } - Ok(pairs) -} - #[cfg(test)] mod tests { use super::*; @@ -1133,7 +205,6 @@ mod tests { copy_if_not_exists, get_opts, list_uses_directories_correctly, list_with_delimiter, put_get_delete_list_opts, rename_and_copy, stream_get, }; - use std::collections::HashMap; #[tokio::test] async fn azure_blob_test() { @@ -1149,118 +220,6 @@ mod tests { stream_get(&integration).await; } - #[test] - fn azure_blob_test_urls() { - let mut builder = MicrosoftAzureBuilder::new(); - builder - .parse_url("abfss://file_system@account.dfs.core.windows.net/") - .unwrap(); - assert_eq!(builder.account_name, Some("account".to_string())); - assert_eq!(builder.container_name, Some("file_system".to_string())); - assert!(!builder.use_fabric_endpoint.get().unwrap()); - - let mut builder = MicrosoftAzureBuilder::new(); - builder - .parse_url("abfss://file_system@account.dfs.fabric.microsoft.com/") - .unwrap(); - assert_eq!(builder.account_name, Some("account".to_string())); - assert_eq!(builder.container_name, Some("file_system".to_string())); - assert!(builder.use_fabric_endpoint.get().unwrap()); - - let mut builder = MicrosoftAzureBuilder::new(); - builder.parse_url("abfs://container/path").unwrap(); - assert_eq!(builder.container_name, Some("container".to_string())); - - let mut builder = MicrosoftAzureBuilder::new(); - builder.parse_url("az://container").unwrap(); - assert_eq!(builder.container_name, Some("container".to_string())); - - let mut 
builder = MicrosoftAzureBuilder::new(); - builder.parse_url("az://container/path").unwrap(); - assert_eq!(builder.container_name, Some("container".to_string())); - - let mut builder = MicrosoftAzureBuilder::new(); - builder - .parse_url("https://account.dfs.core.windows.net/") - .unwrap(); - assert_eq!(builder.account_name, Some("account".to_string())); - assert!(!builder.use_fabric_endpoint.get().unwrap()); - - let mut builder = MicrosoftAzureBuilder::new(); - builder - .parse_url("https://account.blob.core.windows.net/") - .unwrap(); - assert_eq!(builder.account_name, Some("account".to_string())); - assert!(!builder.use_fabric_endpoint.get().unwrap()); - - let mut builder = MicrosoftAzureBuilder::new(); - builder - .parse_url("https://account.dfs.fabric.microsoft.com/") - .unwrap(); - assert_eq!(builder.account_name, Some("account".to_string())); - assert_eq!(builder.container_name, None); - assert!(builder.use_fabric_endpoint.get().unwrap()); - - let mut builder = MicrosoftAzureBuilder::new(); - builder - .parse_url("https://account.dfs.fabric.microsoft.com/container") - .unwrap(); - assert_eq!(builder.account_name, Some("account".to_string())); - assert_eq!(builder.container_name.as_deref(), Some("container")); - assert!(builder.use_fabric_endpoint.get().unwrap()); - - let mut builder = MicrosoftAzureBuilder::new(); - builder - .parse_url("https://account.blob.fabric.microsoft.com/") - .unwrap(); - assert_eq!(builder.account_name, Some("account".to_string())); - assert_eq!(builder.container_name, None); - assert!(builder.use_fabric_endpoint.get().unwrap()); - - let mut builder = MicrosoftAzureBuilder::new(); - builder - .parse_url("https://account.blob.fabric.microsoft.com/container") - .unwrap(); - assert_eq!(builder.account_name, Some("account".to_string())); - assert_eq!(builder.container_name.as_deref(), Some("container")); - assert!(builder.use_fabric_endpoint.get().unwrap()); - - let err_cases = [ - "mailto://account.blob.core.windows.net/", - 
"az://blob.mydomain/", - "abfs://container.foo/path", - "abfss://file_system@account.foo.dfs.core.windows.net/", - "abfss://file_system.bar@account.dfs.core.windows.net/", - "https://blob.mydomain/", - "https://blob.foo.dfs.core.windows.net/", - ]; - let mut builder = MicrosoftAzureBuilder::new(); - for case in err_cases { - builder.parse_url(case).unwrap_err(); - } - } - - #[test] - fn azure_test_config_from_map() { - let azure_client_id = "object_store:fake_access_key_id"; - let azure_storage_account_name = "object_store:fake_secret_key"; - let azure_storage_token = "object_store:fake_default_region"; - let options = HashMap::from([ - ("azure_client_id", azure_client_id), - ("azure_storage_account_name", azure_storage_account_name), - ("azure_storage_token", azure_storage_token), - ]); - - let builder = options - .into_iter() - .fold(MicrosoftAzureBuilder::new(), |builder, (key, value)| { - builder.with_config(key.parse().unwrap(), value) - }); - assert_eq!(builder.client_id.unwrap(), azure_client_id); - assert_eq!(builder.account_name.unwrap(), azure_storage_account_name); - assert_eq!(builder.bearer_token.unwrap(), azure_storage_token); - } - #[test] fn azure_test_config_get_value() { let azure_client_id = "object_store:fake_access_key_id".to_string(); @@ -1286,22 +245,4 @@ mod tests { azure_storage_token ); } - - #[test] - fn azure_test_split_sas() { - let raw_sas = "?sv=2021-10-04&st=2023-01-04T17%3A48%3A57Z&se=2023-01-04T18%3A15%3A00Z&sr=c&sp=rcwl&sig=C7%2BZeEOWbrxPA3R0Cw%2Fw1EZz0%2B4KBvQexeKZKe%2BB6h0%3D"; - let expected = vec![ - ("sv".to_string(), "2021-10-04".to_string()), - ("st".to_string(), "2023-01-04T17:48:57Z".to_string()), - ("se".to_string(), "2023-01-04T18:15:00Z".to_string()), - ("sr".to_string(), "c".to_string()), - ("sp".to_string(), "rcwl".to_string()), - ( - "sig".to_string(), - "C7+ZeEOWbrxPA3R0Cw/w1EZz0+4KBvQexeKZKe+B6h0=".to_string(), - ), - ]; - let pairs = split_sas(raw_sas).unwrap(); - assert_eq!(expected, pairs); - } } From 
efd4d1900a9d2cadd9393ab8c8b4eac77f6b88b5 Mon Sep 17 00:00:00 2001 From: Raphael Taylor-Davies <1781103+tustvold@users.noreply.github.com> Date: Thu, 19 Oct 2023 13:41:17 +0100 Subject: [PATCH 20/25] Add module links in docs root (#4955) --- object_store/src/lib.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/object_store/src/lib.rs b/object_store/src/lib.rs index 018f0f5e8dec..86313616be1b 100644 --- a/object_store/src/lib.rs +++ b/object_store/src/lib.rs @@ -64,19 +64,19 @@ //! #![cfg_attr( feature = "gcp", - doc = "* `gcp`: [Google Cloud Storage](https://cloud.google.com/storage/) support. See [`GoogleCloudStorageBuilder`](gcp::GoogleCloudStorageBuilder)" + doc = "* [`gcp`]: [Google Cloud Storage](https://cloud.google.com/storage/) support. See [`GoogleCloudStorageBuilder`](gcp::GoogleCloudStorageBuilder)" )] #![cfg_attr( feature = "aws", - doc = "* `aws`: [Amazon S3](https://aws.amazon.com/s3/). See [`AmazonS3Builder`](aws::AmazonS3Builder)" + doc = "* [`aws`]: [Amazon S3](https://aws.amazon.com/s3/). See [`AmazonS3Builder`](aws::AmazonS3Builder)" )] #![cfg_attr( feature = "azure", - doc = "* `azure`: [Azure Blob Storage](https://azure.microsoft.com/en-gb/services/storage/blobs/). See [`MicrosoftAzureBuilder`](azure::MicrosoftAzureBuilder)" + doc = "* [`azure`]: [Azure Blob Storage](https://azure.microsoft.com/en-gb/services/storage/blobs/). See [`MicrosoftAzureBuilder`](azure::MicrosoftAzureBuilder)" )] #![cfg_attr( feature = "http", - doc = "* `http`: [HTTP/WebDAV Storage](https://datatracker.ietf.org/doc/html/rfc2518). See [`HttpBuilder`](http::HttpBuilder)" + doc = "* [`http`]: [HTTP/WebDAV Storage](https://datatracker.ietf.org/doc/html/rfc2518). See [`HttpBuilder`](http::HttpBuilder)" )] //! //! 
# Adapters From f597d3a6874264ebd9cf28a0d07a7fae52df440b Mon Sep 17 00:00:00 2001 From: Raphael Taylor-Davies <1781103+tustvold@users.noreply.github.com> Date: Thu, 19 Oct 2023 13:45:43 +0100 Subject: [PATCH 21/25] Split gcp Module (#4956) * Split out GCP client * Split out builder * RAT --- object_store/src/gcp/builder.rs | 705 ++++++++++++++++++++ object_store/src/gcp/client.rs | 446 +++++++++++++ object_store/src/gcp/mod.rs | 1097 +------------------------------ 3 files changed, 1177 insertions(+), 1071 deletions(-) create mode 100644 object_store/src/gcp/builder.rs create mode 100644 object_store/src/gcp/client.rs diff --git a/object_store/src/gcp/builder.rs b/object_store/src/gcp/builder.rs new file mode 100644 index 000000000000..920ab8b2a9b5 --- /dev/null +++ b/object_store/src/gcp/builder.rs @@ -0,0 +1,705 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use crate::client::TokenCredentialProvider; +use crate::gcp::client::{GoogleCloudStorageClient, GoogleCloudStorageConfig}; +use crate::gcp::credential::{ + ApplicationDefaultCredentials, InstanceCredentialProvider, ServiceAccountCredentials, + DEFAULT_GCS_BASE_URL, +}; +use crate::gcp::{ + credential, GcpCredential, GcpCredentialProvider, GoogleCloudStorage, STORE, +}; +use crate::{ + ClientConfigKey, ClientOptions, Result, RetryConfig, StaticCredentialProvider, +}; +use serde::{Deserialize, Serialize}; +use snafu::{OptionExt, ResultExt, Snafu}; +use std::str::FromStr; +use std::sync::Arc; +use url::Url; + +#[derive(Debug, Snafu)] +enum Error { + #[snafu(display("Missing bucket name"))] + MissingBucketName {}, + + #[snafu(display( + "One of service account path or service account key may be provided." + ))] + ServiceAccountPathAndKeyProvided, + + #[snafu(display("Unable parse source url. Url: {}, Error: {}", url, source))] + UnableToParseUrl { + source: url::ParseError, + url: String, + }, + + #[snafu(display( + "Unknown url scheme cannot be parsed into storage location: {}", + scheme + ))] + UnknownUrlScheme { scheme: String }, + + #[snafu(display("URL did not match any known pattern for scheme: {}", url))] + UrlNotRecognised { url: String }, + + #[snafu(display("Configuration key: '{}' is not known.", key))] + UnknownConfigurationKey { key: String }, + + #[snafu(display("Unable to extract metadata from headers: {}", source))] + Metadata { + source: crate::client::header::Error, + }, + + #[snafu(display("GCP credential error: {}", source))] + Credential { source: credential::Error }, +} + +impl From for crate::Error { + fn from(err: Error) -> Self { + match err { + Error::UnknownConfigurationKey { key } => { + Self::UnknownConfigurationKey { store: STORE, key } + } + _ => Self::Generic { + store: STORE, + source: Box::new(err), + }, + } + } +} + +/// Configure a connection to Google Cloud Storage using the specified +/// credentials. 
+/// +/// # Example +/// ``` +/// # let BUCKET_NAME = "foo"; +/// # let SERVICE_ACCOUNT_PATH = "/tmp/foo.json"; +/// # use object_store::gcp::GoogleCloudStorageBuilder; +/// let gcs = GoogleCloudStorageBuilder::new() +/// .with_service_account_path(SERVICE_ACCOUNT_PATH) +/// .with_bucket_name(BUCKET_NAME) +/// .build(); +/// ``` +#[derive(Debug, Clone)] +pub struct GoogleCloudStorageBuilder { + /// Bucket name + bucket_name: Option, + /// Url + url: Option, + /// Path to the service account file + service_account_path: Option, + /// The serialized service account key + service_account_key: Option, + /// Path to the application credentials file. + application_credentials_path: Option, + /// Retry config + retry_config: RetryConfig, + /// Client options + client_options: ClientOptions, + /// Credentials + credentials: Option, +} + +/// Configuration keys for [`GoogleCloudStorageBuilder`] +/// +/// Configuration via keys can be done via [`GoogleCloudStorageBuilder::with_config`] +/// +/// # Example +/// ``` +/// # use object_store::gcp::{GoogleCloudStorageBuilder, GoogleConfigKey}; +/// let builder = GoogleCloudStorageBuilder::new() +/// .with_config("google_service_account".parse().unwrap(), "my-service-account") +/// .with_config(GoogleConfigKey::Bucket, "my-bucket"); +/// ``` +#[derive(PartialEq, Eq, Hash, Clone, Debug, Copy, Serialize, Deserialize)] +#[non_exhaustive] +pub enum GoogleConfigKey { + /// Path to the service account file + /// + /// Supported keys: + /// - `google_service_account` + /// - `service_account` + /// - `google_service_account_path` + /// - `service_account_path` + ServiceAccount, + + /// The serialized service account key. + /// + /// Supported keys: + /// - `google_service_account_key` + /// - `service_account_key` + ServiceAccountKey, + + /// Bucket name + /// + /// See [`GoogleCloudStorageBuilder::with_bucket_name`] for details. 
+ /// + /// Supported keys: + /// - `google_bucket` + /// - `google_bucket_name` + /// - `bucket` + /// - `bucket_name` + Bucket, + + /// Application credentials path + /// + /// See [`GoogleCloudStorageBuilder::with_application_credentials`]. + ApplicationCredentials, + + /// Client options + Client(ClientConfigKey), +} + +impl AsRef for GoogleConfigKey { + fn as_ref(&self) -> &str { + match self { + Self::ServiceAccount => "google_service_account", + Self::ServiceAccountKey => "google_service_account_key", + Self::Bucket => "google_bucket", + Self::ApplicationCredentials => "google_application_credentials", + Self::Client(key) => key.as_ref(), + } + } +} + +impl FromStr for GoogleConfigKey { + type Err = crate::Error; + + fn from_str(s: &str) -> Result { + match s { + "google_service_account" + | "service_account" + | "google_service_account_path" + | "service_account_path" => Ok(Self::ServiceAccount), + "google_service_account_key" | "service_account_key" => { + Ok(Self::ServiceAccountKey) + } + "google_bucket" | "google_bucket_name" | "bucket" | "bucket_name" => { + Ok(Self::Bucket) + } + "google_application_credentials" => Ok(Self::ApplicationCredentials), + _ => match s.parse() { + Ok(key) => Ok(Self::Client(key)), + Err(_) => Err(Error::UnknownConfigurationKey { key: s.into() }.into()), + }, + } + } +} + +impl Default for GoogleCloudStorageBuilder { + fn default() -> Self { + Self { + bucket_name: None, + service_account_path: None, + service_account_key: None, + application_credentials_path: None, + retry_config: Default::default(), + client_options: ClientOptions::new().with_allow_http(true), + url: None, + credentials: None, + } + } +} + +impl GoogleCloudStorageBuilder { + /// Create a new [`GoogleCloudStorageBuilder`] with default values. + pub fn new() -> Self { + Default::default() + } + + /// Create an instance of [`GoogleCloudStorageBuilder`] with values pre-populated from environment variables. 
+ /// + /// Variables extracted from environment: + /// * GOOGLE_SERVICE_ACCOUNT: location of service account file + /// * GOOGLE_SERVICE_ACCOUNT_PATH: (alias) location of service account file + /// * SERVICE_ACCOUNT: (alias) location of service account file + /// * GOOGLE_SERVICE_ACCOUNT_KEY: JSON serialized service account key + /// * GOOGLE_BUCKET: bucket name + /// * GOOGLE_BUCKET_NAME: (alias) bucket name + /// + /// # Example + /// ``` + /// use object_store::gcp::GoogleCloudStorageBuilder; + /// + /// let gcs = GoogleCloudStorageBuilder::from_env() + /// .with_bucket_name("foo") + /// .build(); + /// ``` + pub fn from_env() -> Self { + let mut builder = Self::default(); + + if let Ok(service_account_path) = std::env::var("SERVICE_ACCOUNT") { + builder.service_account_path = Some(service_account_path); + } + + for (os_key, os_value) in std::env::vars_os() { + if let (Some(key), Some(value)) = (os_key.to_str(), os_value.to_str()) { + if key.starts_with("GOOGLE_") { + if let Ok(config_key) = key.to_ascii_lowercase().parse() { + builder = builder.with_config(config_key, value); + } + } + } + } + + builder + } + + /// Parse available connection info from a well-known storage URL. + /// + /// The supported url schemes are: + /// + /// - `gs://<bucket>/<path>` + /// + /// Note: Settings derived from the URL will override any others set on this builder + /// + /// # Example + /// ``` + /// use object_store::gcp::GoogleCloudStorageBuilder; + /// + /// let gcs = GoogleCloudStorageBuilder::from_env() + /// .with_url("gs://bucket/path") + /// .build(); + /// ``` + pub fn with_url(mut self, url: impl Into) -> Self { + self.url = Some(url.into()); + self + } + + /// Set an option on the builder via a key - value pair.
+ pub fn with_config(mut self, key: GoogleConfigKey, value: impl Into) -> Self { + match key { + GoogleConfigKey::ServiceAccount => { + self.service_account_path = Some(value.into()) + } + GoogleConfigKey::ServiceAccountKey => { + self.service_account_key = Some(value.into()) + } + GoogleConfigKey::Bucket => self.bucket_name = Some(value.into()), + GoogleConfigKey::ApplicationCredentials => { + self.application_credentials_path = Some(value.into()) + } + GoogleConfigKey::Client(key) => { + self.client_options = self.client_options.with_config(key, value) + } + }; + self + } + + /// Set an option on the builder via a key - value pair. + #[deprecated(note = "Use with_config")] + pub fn try_with_option( + self, + key: impl AsRef, + value: impl Into, + ) -> Result { + Ok(self.with_config(key.as_ref().parse()?, value)) + } + + /// Hydrate builder from key value pairs + #[deprecated(note = "Use with_config")] + #[allow(deprecated)] + pub fn try_with_options< + I: IntoIterator, impl Into)>, + >( + mut self, + options: I, + ) -> Result { + for (key, value) in options { + self = self.try_with_option(key, value)?; + } + Ok(self) + } + + /// Get config value via a [`GoogleConfigKey`]. 
+ /// + /// # Example + /// ``` + /// use object_store::gcp::{GoogleCloudStorageBuilder, GoogleConfigKey}; + /// + /// let builder = GoogleCloudStorageBuilder::from_env() + /// .with_service_account_key("foo"); + /// let service_account_key = builder.get_config_value(&GoogleConfigKey::ServiceAccountKey).unwrap_or_default(); + /// assert_eq!("foo", &service_account_key); + /// ``` + pub fn get_config_value(&self, key: &GoogleConfigKey) -> Option { + match key { + GoogleConfigKey::ServiceAccount => self.service_account_path.clone(), + GoogleConfigKey::ServiceAccountKey => self.service_account_key.clone(), + GoogleConfigKey::Bucket => self.bucket_name.clone(), + GoogleConfigKey::ApplicationCredentials => { + self.application_credentials_path.clone() + } + GoogleConfigKey::Client(key) => self.client_options.get_config_value(key), + } + } + + /// Sets properties on this builder based on a URL + /// + /// This is a separate member function to allow fallible computation to + /// be deferred until [`Self::build`] which in turn allows deriving [`Clone`] + fn parse_url(&mut self, url: &str) -> Result<()> { + let parsed = Url::parse(url).context(UnableToParseUrlSnafu { url })?; + let host = parsed.host_str().context(UrlNotRecognisedSnafu { url })?; + + let validate = |s: &str| match s.contains('.') { + true => Err(UrlNotRecognisedSnafu { url }.build()), + false => Ok(s.to_string()), + }; + + match parsed.scheme() { + "gs" => self.bucket_name = Some(validate(host)?), + scheme => return Err(UnknownUrlSchemeSnafu { scheme }.build().into()), + } + Ok(()) + } + + /// Set the bucket name (required) + pub fn with_bucket_name(mut self, bucket_name: impl Into) -> Self { + self.bucket_name = Some(bucket_name.into()); + self + } + + /// Set the path to the service account file. + /// + /// This or [`GoogleCloudStorageBuilder::with_service_account_key`] must be + /// set. + /// + /// Example `"/tmp/gcs.json"`. 
+ /// + /// Example contents of `gcs.json`: + /// + /// ```json + /// { + /// "gcs_base_url": "https://localhost:4443", + /// "disable_oauth": true, + /// "client_email": "", + /// "private_key": "" + /// } + /// ``` + pub fn with_service_account_path( + mut self, + service_account_path: impl Into, + ) -> Self { + self.service_account_path = Some(service_account_path.into()); + self + } + + /// Set the service account key. The service account must be in the JSON + /// format. + /// + /// This or [`GoogleCloudStorageBuilder::with_service_account_path`] must be + /// set. + pub fn with_service_account_key( + mut self, + service_account: impl Into, + ) -> Self { + self.service_account_key = Some(service_account.into()); + self + } + + /// Set the path to the application credentials file. + /// + /// + pub fn with_application_credentials( + mut self, + application_credentials_path: impl Into, + ) -> Self { + self.application_credentials_path = Some(application_credentials_path.into()); + self + } + + /// Set the credential provider overriding any other options + pub fn with_credentials(mut self, credentials: GcpCredentialProvider) -> Self { + self.credentials = Some(credentials); + self + } + + /// Set the retry configuration + pub fn with_retry(mut self, retry_config: RetryConfig) -> Self { + self.retry_config = retry_config; + self + } + + /// Set the proxy_url to be used by the underlying client + pub fn with_proxy_url(mut self, proxy_url: impl Into) -> Self { + self.client_options = self.client_options.with_proxy_url(proxy_url); + self + } + + /// Set a trusted proxy CA certificate + pub fn with_proxy_ca_certificate( + mut self, + proxy_ca_certificate: impl Into, + ) -> Self { + self.client_options = self + .client_options + .with_proxy_ca_certificate(proxy_ca_certificate); + self + } + + /// Set a list of hosts to exclude from proxy connections + pub fn with_proxy_excludes(mut self, proxy_excludes: impl Into) -> Self { + self.client_options = 
self.client_options.with_proxy_excludes(proxy_excludes); + self + } + + /// Sets the client options, overriding any already set + pub fn with_client_options(mut self, options: ClientOptions) -> Self { + self.client_options = options; + self + } + + /// Configure a connection to Google Cloud Storage, returning a + /// new [`GoogleCloudStorage`] and consuming `self` + pub fn build(mut self) -> Result { + if let Some(url) = self.url.take() { + self.parse_url(&url)?; + } + + let bucket_name = self.bucket_name.ok_or(Error::MissingBucketName {})?; + + // First try to initialize from the service account information. + let service_account_credentials = + match (self.service_account_path, self.service_account_key) { + (Some(path), None) => Some( + ServiceAccountCredentials::from_file(path) + .context(CredentialSnafu)?, + ), + (None, Some(key)) => Some( + ServiceAccountCredentials::from_key(&key).context(CredentialSnafu)?, + ), + (None, None) => None, + (Some(_), Some(_)) => { + return Err(Error::ServiceAccountPathAndKeyProvided.into()) + } + }; + + // Then try to initialize from the application credentials file, or the environment. 
+ let application_default_credentials = ApplicationDefaultCredentials::read( + self.application_credentials_path.as_deref(), + )?; + + let disable_oauth = service_account_credentials + .as_ref() + .map(|c| c.disable_oauth) + .unwrap_or(false); + + let gcs_base_url: String = service_account_credentials + .as_ref() + .and_then(|c| c.gcs_base_url.clone()) + .unwrap_or_else(|| DEFAULT_GCS_BASE_URL.to_string()); + + let credentials = if let Some(credentials) = self.credentials { + credentials + } else if disable_oauth { + Arc::new(StaticCredentialProvider::new(GcpCredential { + bearer: "".to_string(), + })) as _ + } else if let Some(credentials) = service_account_credentials { + Arc::new(TokenCredentialProvider::new( + credentials.token_provider()?, + self.client_options.client()?, + self.retry_config.clone(), + )) as _ + } else if let Some(credentials) = application_default_credentials { + match credentials { + ApplicationDefaultCredentials::AuthorizedUser(token) => { + Arc::new(TokenCredentialProvider::new( + token, + self.client_options.client()?, + self.retry_config.clone(), + )) as _ + } + ApplicationDefaultCredentials::ServiceAccount(token) => { + Arc::new(TokenCredentialProvider::new( + token.token_provider()?, + self.client_options.client()?, + self.retry_config.clone(), + )) as _ + } + } + } else { + Arc::new(TokenCredentialProvider::new( + InstanceCredentialProvider::default(), + self.client_options.metadata_client()?, + self.retry_config.clone(), + )) as _ + }; + + let config = GoogleCloudStorageConfig { + base_url: gcs_base_url, + credentials, + bucket_name, + retry_config: self.retry_config, + client_options: self.client_options, + }; + + Ok(GoogleCloudStorage { + client: Arc::new(GoogleCloudStorageClient::new(config)?), + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::collections::HashMap; + use std::io::Write; + use tempfile::NamedTempFile; + + const FAKE_KEY: &str = r#"{"private_key": "private_key", "private_key_id": "private_key_id", 
"client_email":"client_email", "disable_oauth":true}"#; + + #[test] + fn gcs_test_service_account_key_and_path() { + let mut tfile = NamedTempFile::new().unwrap(); + write!(tfile, "{FAKE_KEY}").unwrap(); + let _ = GoogleCloudStorageBuilder::new() + .with_service_account_key(FAKE_KEY) + .with_service_account_path(tfile.path().to_str().unwrap()) + .with_bucket_name("foo") + .build() + .unwrap_err(); + } + + #[test] + fn gcs_test_config_from_map() { + let google_service_account = "object_store:fake_service_account".to_string(); + let google_bucket_name = "object_store:fake_bucket".to_string(); + let options = HashMap::from([ + ("google_service_account", google_service_account.clone()), + ("google_bucket_name", google_bucket_name.clone()), + ]); + + let builder = options + .iter() + .fold(GoogleCloudStorageBuilder::new(), |builder, (key, value)| { + builder.with_config(key.parse().unwrap(), value) + }); + + assert_eq!( + builder.service_account_path.unwrap(), + google_service_account.as_str() + ); + assert_eq!(builder.bucket_name.unwrap(), google_bucket_name.as_str()); + } + + #[test] + fn gcs_test_config_aliases() { + // Service account path + for alias in [ + "google_service_account", + "service_account", + "google_service_account_path", + "service_account_path", + ] { + let builder = GoogleCloudStorageBuilder::new() + .with_config(alias.parse().unwrap(), "/fake/path.json"); + assert_eq!("/fake/path.json", builder.service_account_path.unwrap()); + } + + // Service account key + for alias in ["google_service_account_key", "service_account_key"] { + let builder = GoogleCloudStorageBuilder::new() + .with_config(alias.parse().unwrap(), FAKE_KEY); + assert_eq!(FAKE_KEY, builder.service_account_key.unwrap()); + } + + // Bucket name + for alias in [ + "google_bucket", + "google_bucket_name", + "bucket", + "bucket_name", + ] { + let builder = GoogleCloudStorageBuilder::new() + .with_config(alias.parse().unwrap(), "fake_bucket"); + assert_eq!("fake_bucket", 
builder.bucket_name.unwrap()); + } + } + + #[tokio::test] + async fn gcs_test_proxy_url() { + let mut tfile = NamedTempFile::new().unwrap(); + write!(tfile, "{FAKE_KEY}").unwrap(); + let service_account_path = tfile.path(); + let gcs = GoogleCloudStorageBuilder::new() + .with_service_account_path(service_account_path.to_str().unwrap()) + .with_bucket_name("foo") + .with_proxy_url("https://example.com") + .build(); + assert!(dbg!(gcs).is_ok()); + + let err = GoogleCloudStorageBuilder::new() + .with_service_account_path(service_account_path.to_str().unwrap()) + .with_bucket_name("foo") + .with_proxy_url("asdf://example.com") + .build() + .unwrap_err() + .to_string(); + + assert_eq!( + "Generic HTTP client error: builder error: unknown proxy scheme", + err + ); + } + + #[test] + fn gcs_test_urls() { + let mut builder = GoogleCloudStorageBuilder::new(); + builder.parse_url("gs://bucket/path").unwrap(); + assert_eq!(builder.bucket_name, Some("bucket".to_string())); + + let err_cases = ["mailto://bucket/path", "gs://bucket.mydomain/path"]; + let mut builder = GoogleCloudStorageBuilder::new(); + for case in err_cases { + builder.parse_url(case).unwrap_err(); + } + } + + #[test] + fn gcs_test_service_account_key_only() { + let _ = GoogleCloudStorageBuilder::new() + .with_service_account_key(FAKE_KEY) + .with_bucket_name("foo") + .build() + .unwrap(); + } + + #[test] + fn gcs_test_config_get_value() { + let google_service_account = "object_store:fake_service_account".to_string(); + let google_bucket_name = "object_store:fake_bucket".to_string(); + let builder = GoogleCloudStorageBuilder::new() + .with_config(GoogleConfigKey::ServiceAccount, &google_service_account) + .with_config(GoogleConfigKey::Bucket, &google_bucket_name); + + assert_eq!( + builder + .get_config_value(&GoogleConfigKey::ServiceAccount) + .unwrap(), + google_service_account + ); + assert_eq!( + builder.get_config_value(&GoogleConfigKey::Bucket).unwrap(), + google_bucket_name + ); + } +} diff --git 
a/object_store/src/gcp/client.rs b/object_store/src/gcp/client.rs new file mode 100644 index 000000000000..9141a9da8c5b --- /dev/null +++ b/object_store/src/gcp/client.rs @@ -0,0 +1,446 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::client::get::GetClient; +use crate::client::header::get_etag; +use crate::client::list::ListClient; +use crate::client::list_response::ListResponse; +use crate::client::retry::RetryExt; +use crate::client::GetOptionsExt; +use crate::gcp::{GcpCredential, GcpCredentialProvider, STORE}; +use crate::multipart::PartId; +use crate::path::{Path, DELIMITER}; +use crate::{ClientOptions, GetOptions, ListResult, MultipartId, Result, RetryConfig}; +use async_trait::async_trait; +use bytes::{Buf, Bytes}; +use percent_encoding::{percent_encode, utf8_percent_encode, NON_ALPHANUMERIC}; +use reqwest::{header, Client, Method, Response, StatusCode}; +use serde::Serialize; +use snafu::{ResultExt, Snafu}; +use std::sync::Arc; + +#[derive(Debug, Snafu)] +enum Error { + #[snafu(display("Error performing list request: {}", source))] + ListRequest { source: crate::client::retry::Error }, + + #[snafu(display("Error getting list response body: {}", source))] + ListResponseBody { source: 
reqwest::Error }, + + #[snafu(display("Got invalid list response: {}", source))] + InvalidListResponse { source: quick_xml::de::DeError }, + + #[snafu(display("Error performing get request {}: {}", path, source))] + GetRequest { + source: crate::client::retry::Error, + path: String, + }, + + #[snafu(display("Error performing delete request {}: {}", path, source))] + DeleteRequest { + source: crate::client::retry::Error, + path: String, + }, + + #[snafu(display("Error performing put request {}: {}", path, source))] + PutRequest { + source: crate::client::retry::Error, + path: String, + }, + + #[snafu(display("Error getting put response body: {}", source))] + PutResponseBody { source: reqwest::Error }, + + #[snafu(display("Got invalid put response: {}", source))] + InvalidPutResponse { source: quick_xml::de::DeError }, + + #[snafu(display("Error performing post request {}: {}", path, source))] + PostRequest { + source: crate::client::retry::Error, + path: String, + }, + + #[snafu(display("Unable to extract metadata from headers: {}", source))] + Metadata { + source: crate::client::header::Error, + }, +} + +impl From for crate::Error { + fn from(err: Error) -> Self { + match err { + Error::GetRequest { source, path } + | Error::DeleteRequest { source, path } + | Error::PutRequest { source, path } => source.error(STORE, path), + _ => Self::Generic { + store: STORE, + source: Box::new(err), + }, + } + } +} + +#[derive(Debug)] +pub struct GoogleCloudStorageConfig { + pub base_url: String, + + pub credentials: GcpCredentialProvider, + + pub bucket_name: String, + + pub retry_config: RetryConfig, + + pub client_options: ClientOptions, +} + +#[derive(Debug)] +pub struct GoogleCloudStorageClient { + config: GoogleCloudStorageConfig, + + client: Client, + + bucket_name_encoded: String, + + // TODO: Hook this up in tests + max_list_results: Option, +} + +impl GoogleCloudStorageClient { + pub fn new(config: GoogleCloudStorageConfig) -> Result { + let client = 
config.client_options.client()?; + let bucket_name_encoded = + percent_encode(config.bucket_name.as_bytes(), NON_ALPHANUMERIC).to_string(); + + Ok(Self { + config, + client, + bucket_name_encoded, + max_list_results: None, + }) + } + + pub fn config(&self) -> &GoogleCloudStorageConfig { + &self.config + } + + async fn get_credential(&self) -> Result> { + self.config.credentials.get_credential().await + } + + pub fn object_url(&self, path: &Path) -> String { + let encoded = utf8_percent_encode(path.as_ref(), NON_ALPHANUMERIC); + format!( + "{}/{}/{}", + self.config.base_url, self.bucket_name_encoded, encoded + ) + } + + /// Perform a put request + /// + /// Returns the new ETag + pub async fn put_request( + &self, + path: &Path, + payload: Bytes, + query: &T, + ) -> Result { + let credential = self.get_credential().await?; + let url = self.object_url(path); + + let content_type = self + .config + .client_options + .get_content_type(path) + .unwrap_or("application/octet-stream"); + + let response = self + .client + .request(Method::PUT, url) + .query(query) + .bearer_auth(&credential.bearer) + .header(header::CONTENT_TYPE, content_type) + .header(header::CONTENT_LENGTH, payload.len()) + .body(payload) + .send_retry(&self.config.retry_config) + .await + .context(PutRequestSnafu { + path: path.as_ref(), + })?; + + Ok(get_etag(response.headers()).context(MetadataSnafu)?) 
+ } + + /// Initiate a multi-part upload + pub async fn multipart_initiate(&self, path: &Path) -> Result { + let credential = self.get_credential().await?; + let url = self.object_url(path); + + let content_type = self + .config + .client_options + .get_content_type(path) + .unwrap_or("application/octet-stream"); + + let response = self + .client + .request(Method::POST, &url) + .bearer_auth(&credential.bearer) + .header(header::CONTENT_TYPE, content_type) + .header(header::CONTENT_LENGTH, "0") + .query(&[("uploads", "")]) + .send_retry(&self.config.retry_config) + .await + .context(PutRequestSnafu { + path: path.as_ref(), + })?; + + let data = response.bytes().await.context(PutResponseBodySnafu)?; + let result: InitiateMultipartUploadResult = + quick_xml::de::from_reader(data.as_ref().reader()) + .context(InvalidPutResponseSnafu)?; + + Ok(result.upload_id) + } + + /// Cleanup unused parts + pub async fn multipart_cleanup( + &self, + path: &Path, + multipart_id: &MultipartId, + ) -> Result<()> { + let credential = self.get_credential().await?; + let url = self.object_url(path); + + self.client + .request(Method::DELETE, &url) + .bearer_auth(&credential.bearer) + .header(header::CONTENT_TYPE, "application/octet-stream") + .header(header::CONTENT_LENGTH, "0") + .query(&[("uploadId", multipart_id)]) + .send_retry(&self.config.retry_config) + .await + .context(PutRequestSnafu { + path: path.as_ref(), + })?; + + Ok(()) + } + + pub async fn multipart_complete( + &self, + path: &Path, + multipart_id: &MultipartId, + completed_parts: Vec, + ) -> Result<()> { + let upload_id = multipart_id.clone(); + let url = self.object_url(path); + + let parts = completed_parts + .into_iter() + .enumerate() + .map(|(part_number, part)| MultipartPart { + e_tag: part.content_id, + part_number: part_number + 1, + }) + .collect(); + + let credential = self.get_credential().await?; + let upload_info = CompleteMultipartUpload { parts }; + + let data = quick_xml::se::to_string(&upload_info) + 
.context(InvalidPutResponseSnafu)? + // We cannot disable the escaping that transforms "/" to ""e;" :( + // https://github.com/tafia/quick-xml/issues/362 + // https://github.com/tafia/quick-xml/issues/350 + .replace(""", "\""); + + self.client + .request(Method::POST, &url) + .bearer_auth(&credential.bearer) + .query(&[("uploadId", upload_id)]) + .body(data) + .send_retry(&self.config.retry_config) + .await + .context(PostRequestSnafu { + path: path.as_ref(), + })?; + + Ok(()) + } + + /// Perform a delete request + pub async fn delete_request(&self, path: &Path) -> Result<()> { + let credential = self.get_credential().await?; + let url = self.object_url(path); + + let builder = self.client.request(Method::DELETE, url); + builder + .bearer_auth(&credential.bearer) + .send_retry(&self.config.retry_config) + .await + .context(DeleteRequestSnafu { + path: path.as_ref(), + })?; + + Ok(()) + } + + /// Perform a copy request + pub async fn copy_request( + &self, + from: &Path, + to: &Path, + if_not_exists: bool, + ) -> Result<()> { + let credential = self.get_credential().await?; + let url = self.object_url(to); + + let from = utf8_percent_encode(from.as_ref(), NON_ALPHANUMERIC); + let source = format!("{}/{}", self.bucket_name_encoded, from); + + let mut builder = self + .client + .request(Method::PUT, url) + .header("x-goog-copy-source", source); + + if if_not_exists { + builder = builder.header("x-goog-if-generation-match", 0); + } + + builder + .bearer_auth(&credential.bearer) + // Needed if reqwest is compiled with native-tls instead of rustls-tls + // See https://github.com/apache/arrow-rs/pull/3921 + .header(header::CONTENT_LENGTH, 0) + .send_retry(&self.config.retry_config) + .await + .map_err(|err| match err.status() { + Some(StatusCode::PRECONDITION_FAILED) => crate::Error::AlreadyExists { + source: Box::new(err), + path: to.to_string(), + }, + _ => err.error(STORE, from.to_string()), + })?; + + Ok(()) + } +} + +#[async_trait] +impl GetClient for 
GoogleCloudStorageClient { + const STORE: &'static str = STORE; + + /// Perform a get request + async fn get_request(&self, path: &Path, options: GetOptions) -> Result { + let credential = self.get_credential().await?; + let url = self.object_url(path); + + let method = match options.head { + true => Method::HEAD, + false => Method::GET, + }; + + let mut request = self.client.request(method, url).with_get_options(options); + + if !credential.bearer.is_empty() { + request = request.bearer_auth(&credential.bearer); + } + + let response = request + .send_retry(&self.config.retry_config) + .await + .context(GetRequestSnafu { + path: path.as_ref(), + })?; + + Ok(response) + } +} + +#[async_trait] +impl ListClient for GoogleCloudStorageClient { + /// Perform a list request + async fn list_request( + &self, + prefix: Option<&str>, + delimiter: bool, + page_token: Option<&str>, + offset: Option<&str>, + ) -> Result<(ListResult, Option)> { + assert!(offset.is_none()); // Not yet supported + + let credential = self.get_credential().await?; + let url = format!("{}/{}", self.config.base_url, self.bucket_name_encoded); + + let mut query = Vec::with_capacity(5); + query.push(("list-type", "2")); + if delimiter { + query.push(("delimiter", DELIMITER)) + } + + if let Some(prefix) = &prefix { + query.push(("prefix", prefix)) + } + + if let Some(page_token) = page_token { + query.push(("continuation-token", page_token)) + } + + if let Some(max_results) = &self.max_list_results { + query.push(("max-keys", max_results)) + } + + let response = self + .client + .request(Method::GET, url) + .query(&query) + .bearer_auth(&credential.bearer) + .send_retry(&self.config.retry_config) + .await + .context(ListRequestSnafu)? 
+ .bytes() + .await + .context(ListResponseBodySnafu)?; + + let mut response: ListResponse = quick_xml::de::from_reader(response.reader()) + .context(InvalidListResponseSnafu)?; + + let token = response.next_continuation_token.take(); + Ok((response.try_into()?, token)) + } +} + +#[derive(serde::Deserialize, Debug)] +#[serde(rename_all = "PascalCase")] +struct InitiateMultipartUploadResult { + upload_id: String, +} + +#[derive(serde::Serialize, Debug)] +#[serde(rename_all = "PascalCase", rename(serialize = "Part"))] +struct MultipartPart { + #[serde(rename = "PartNumber")] + part_number: usize, + e_tag: String, +} + +#[derive(serde::Serialize, Debug)] +#[serde(rename_all = "PascalCase")] +struct CompleteMultipartUpload { + #[serde(rename = "Part", default)] + parts: Vec, +} diff --git a/object_store/src/gcp/mod.rs b/object_store/src/gcp/mod.rs index 97755c07c671..7c69d288740c 100644 --- a/object_store/src/gcp/mod.rs +++ b/object_store/src/gcp/mod.rs @@ -29,176 +29,34 @@ //! to abort the upload and drop those unneeded parts. In addition, you may wish to //! consider implementing automatic clean up of unused parts that are older than one //! week. 
-use std::str::FromStr; use std::sync::Arc; -use async_trait::async_trait; -use bytes::{Buf, Bytes}; -use futures::stream::BoxStream; -use percent_encoding::{percent_encode, utf8_percent_encode, NON_ALPHANUMERIC}; -use reqwest::{header, Client, Method, Response, StatusCode}; -use serde::{Deserialize, Serialize}; -use snafu::{OptionExt, ResultExt, Snafu}; -use tokio::io::AsyncWrite; -use url::Url; - -use crate::client::get::{GetClient, GetClientExt}; -use crate::client::list::{ListClient, ListClientExt}; -use crate::client::list_response::ListResponse; -use crate::client::retry::RetryExt; -use crate::client::{ - ClientConfigKey, CredentialProvider, GetOptionsExt, StaticCredentialProvider, - TokenCredentialProvider, -}; +use crate::client::CredentialProvider; use crate::{ multipart::{PartId, PutPart, WriteMultiPart}, - path::{Path, DELIMITER}, - ClientOptions, GetOptions, GetResult, ListResult, MultipartId, ObjectMeta, - ObjectStore, PutResult, Result, RetryConfig, + path::Path, + GetOptions, GetResult, ListResult, MultipartId, ObjectMeta, ObjectStore, PutResult, + Result, }; +use async_trait::async_trait; +use bytes::Bytes; +use client::GoogleCloudStorageClient; +use futures::stream::BoxStream; +use tokio::io::AsyncWrite; -use credential::{InstanceCredentialProvider, ServiceAccountCredentials}; +use crate::client::get::GetClientExt; +use crate::client::list::ListClientExt; +pub use builder::{GoogleCloudStorageBuilder, GoogleConfigKey}; +pub use credential::GcpCredential; +mod builder; +mod client; mod credential; const STORE: &str = "GCS"; /// [`CredentialProvider`] for [`GoogleCloudStorage`] pub type GcpCredentialProvider = Arc>; -use crate::client::header::get_etag; -use crate::gcp::credential::{ApplicationDefaultCredentials, DEFAULT_GCS_BASE_URL}; -pub use credential::GcpCredential; - -#[derive(Debug, Snafu)] -enum Error { - #[snafu(display("Got invalid XML response for {} {}: {}", method, url, source))] - InvalidXMLResponse { - source: quick_xml::de::DeError, - 
method: String, - url: String, - data: Bytes, - }, - - #[snafu(display("Error performing list request: {}", source))] - ListRequest { source: crate::client::retry::Error }, - - #[snafu(display("Error getting list response body: {}", source))] - ListResponseBody { source: reqwest::Error }, - - #[snafu(display("Got invalid list response: {}", source))] - InvalidListResponse { source: quick_xml::de::DeError }, - - #[snafu(display("Error performing get request {}: {}", path, source))] - GetRequest { - source: crate::client::retry::Error, - path: String, - }, - - #[snafu(display("Error getting get response body {}: {}", path, source))] - GetResponseBody { - source: reqwest::Error, - path: String, - }, - - #[snafu(display("Error performing delete request {}: {}", path, source))] - DeleteRequest { - source: crate::client::retry::Error, - path: String, - }, - - #[snafu(display("Error performing put request {}: {}", path, source))] - PutRequest { - source: crate::client::retry::Error, - path: String, - }, - - #[snafu(display("Error getting put response body: {}", source))] - PutResponseBody { source: reqwest::Error }, - - #[snafu(display("Got invalid put response: {}", source))] - InvalidPutResponse { source: quick_xml::de::DeError }, - - #[snafu(display("Error performing post request {}: {}", path, source))] - PostRequest { - source: crate::client::retry::Error, - path: String, - }, - - #[snafu(display("Error decoding object size: {}", source))] - InvalidSize { source: std::num::ParseIntError }, - - #[snafu(display("Missing bucket name"))] - MissingBucketName {}, - - #[snafu(display( - "One of service account path or service account key may be provided." - ))] - ServiceAccountPathAndKeyProvided, - - #[snafu(display("GCP credential error: {}", source))] - Credential { source: credential::Error }, - - #[snafu(display("Unable parse source url. 
Url: {}, Error: {}", url, source))] - UnableToParseUrl { - source: url::ParseError, - url: String, - }, - - #[snafu(display( - "Unknown url scheme cannot be parsed into storage location: {}", - scheme - ))] - UnknownUrlScheme { scheme: String }, - - #[snafu(display("URL did not match any known pattern for scheme: {}", url))] - UrlNotRecognised { url: String }, - - #[snafu(display("Configuration key: '{}' is not known.", key))] - UnknownConfigurationKey { key: String }, - - #[snafu(display("Unable to extract metadata from headers: {}", source))] - Metadata { - source: crate::client::header::Error, - }, -} - -impl From for super::Error { - fn from(err: Error) -> Self { - match err { - Error::GetRequest { source, path } - | Error::DeleteRequest { source, path } - | Error::PutRequest { source, path } => source.error(STORE, path), - Error::UnknownConfigurationKey { key } => { - Self::UnknownConfigurationKey { store: STORE, key } - } - _ => Self::Generic { - store: STORE, - source: Box::new(err), - }, - } - } -} - -#[derive(serde::Deserialize, Debug)] -#[serde(rename_all = "PascalCase")] -struct InitiateMultipartUploadResult { - upload_id: String, -} - -#[derive(serde::Serialize, Debug)] -#[serde(rename_all = "PascalCase", rename(serialize = "Part"))] -struct MultipartPart { - #[serde(rename = "PartNumber")] - part_number: usize, - e_tag: String, -} - -#[derive(serde::Serialize, Debug)] -#[serde(rename_all = "PascalCase")] -struct CompleteMultipartUpload { - #[serde(rename = "Part", default)] - parts: Vec, -} /// Interface for [Google Cloud Storage](https://cloud.google.com/storage/). 
#[derive(Debug)] @@ -208,271 +66,18 @@ pub struct GoogleCloudStorage { impl std::fmt::Display for GoogleCloudStorage { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "GoogleCloudStorage({})", self.client.bucket_name) + write!( + f, + "GoogleCloudStorage({})", + self.client.config().bucket_name + ) } } impl GoogleCloudStorage { /// Returns the [`GcpCredentialProvider`] used by [`GoogleCloudStorage`] pub fn credentials(&self) -> &GcpCredentialProvider { - &self.client.credentials - } -} - -#[derive(Debug)] -struct GoogleCloudStorageClient { - client: Client, - base_url: String, - - credentials: GcpCredentialProvider, - - bucket_name: String, - bucket_name_encoded: String, - - retry_config: RetryConfig, - client_options: ClientOptions, - - // TODO: Hook this up in tests - max_list_results: Option, -} - -impl GoogleCloudStorageClient { - async fn get_credential(&self) -> Result> { - self.credentials.get_credential().await - } - - fn object_url(&self, path: &Path) -> String { - let encoded = utf8_percent_encode(path.as_ref(), NON_ALPHANUMERIC); - format!("{}/{}/{}", self.base_url, self.bucket_name_encoded, encoded) - } - - /// Perform a put request - /// - /// Returns the new ETag - async fn put_request( - &self, - path: &Path, - payload: Bytes, - query: &T, - ) -> Result { - let credential = self.get_credential().await?; - let url = self.object_url(path); - - let content_type = self - .client_options - .get_content_type(path) - .unwrap_or("application/octet-stream"); - - let response = self - .client - .request(Method::PUT, url) - .query(query) - .bearer_auth(&credential.bearer) - .header(header::CONTENT_TYPE, content_type) - .header(header::CONTENT_LENGTH, payload.len()) - .body(payload) - .send_retry(&self.retry_config) - .await - .context(PutRequestSnafu { - path: path.as_ref(), - })?; - - Ok(get_etag(response.headers()).context(MetadataSnafu)?) 
- } - - /// Initiate a multi-part upload - async fn multipart_initiate(&self, path: &Path) -> Result { - let credential = self.get_credential().await?; - let url = format!("{}/{}/{}", self.base_url, self.bucket_name_encoded, path); - - let content_type = self - .client_options - .get_content_type(path) - .unwrap_or("application/octet-stream"); - - let response = self - .client - .request(Method::POST, &url) - .bearer_auth(&credential.bearer) - .header(header::CONTENT_TYPE, content_type) - .header(header::CONTENT_LENGTH, "0") - .query(&[("uploads", "")]) - .send_retry(&self.retry_config) - .await - .context(PutRequestSnafu { - path: path.as_ref(), - })?; - - let data = response.bytes().await.context(PutResponseBodySnafu)?; - let result: InitiateMultipartUploadResult = - quick_xml::de::from_reader(data.as_ref().reader()) - .context(InvalidPutResponseSnafu)?; - - Ok(result.upload_id) - } - - /// Cleanup unused parts - async fn multipart_cleanup( - &self, - path: &str, - multipart_id: &MultipartId, - ) -> Result<()> { - let credential = self.get_credential().await?; - let url = format!("{}/{}/{}", self.base_url, self.bucket_name_encoded, path); - - self.client - .request(Method::DELETE, &url) - .bearer_auth(&credential.bearer) - .header(header::CONTENT_TYPE, "application/octet-stream") - .header(header::CONTENT_LENGTH, "0") - .query(&[("uploadId", multipart_id)]) - .send_retry(&self.retry_config) - .await - .context(PutRequestSnafu { path })?; - - Ok(()) - } - - /// Perform a delete request - async fn delete_request(&self, path: &Path) -> Result<()> { - let credential = self.get_credential().await?; - let url = self.object_url(path); - - let builder = self.client.request(Method::DELETE, url); - builder - .bearer_auth(&credential.bearer) - .send_retry(&self.retry_config) - .await - .context(DeleteRequestSnafu { - path: path.as_ref(), - })?; - - Ok(()) - } - - /// Perform a copy request - async fn copy_request( - &self, - from: &Path, - to: &Path, - if_not_exists: bool, 
- ) -> Result<()> { - let credential = self.get_credential().await?; - let url = self.object_url(to); - - let from = utf8_percent_encode(from.as_ref(), NON_ALPHANUMERIC); - let source = format!("{}/{}", self.bucket_name_encoded, from); - - let mut builder = self - .client - .request(Method::PUT, url) - .header("x-goog-copy-source", source); - - if if_not_exists { - builder = builder.header("x-goog-if-generation-match", 0); - } - - builder - .bearer_auth(&credential.bearer) - // Needed if reqwest is compiled with native-tls instead of rustls-tls - // See https://github.com/apache/arrow-rs/pull/3921 - .header(header::CONTENT_LENGTH, 0) - .send_retry(&self.retry_config) - .await - .map_err(|err| match err.status() { - Some(StatusCode::PRECONDITION_FAILED) => crate::Error::AlreadyExists { - source: Box::new(err), - path: to.to_string(), - }, - _ => err.error(STORE, from.to_string()), - })?; - - Ok(()) - } -} - -#[async_trait] -impl GetClient for GoogleCloudStorageClient { - const STORE: &'static str = STORE; - - /// Perform a get request - async fn get_request(&self, path: &Path, options: GetOptions) -> Result { - let credential = self.get_credential().await?; - let url = self.object_url(path); - - let method = match options.head { - true => Method::HEAD, - false => Method::GET, - }; - - let mut request = self.client.request(method, url).with_get_options(options); - - if !credential.bearer.is_empty() { - request = request.bearer_auth(&credential.bearer); - } - - let response = - request - .send_retry(&self.retry_config) - .await - .context(GetRequestSnafu { - path: path.as_ref(), - })?; - - Ok(response) - } -} - -#[async_trait] -impl ListClient for GoogleCloudStorageClient { - /// Perform a list request - async fn list_request( - &self, - prefix: Option<&str>, - delimiter: bool, - page_token: Option<&str>, - offset: Option<&str>, - ) -> Result<(ListResult, Option)> { - assert!(offset.is_none()); // Not yet supported - - let credential = self.get_credential().await?; - 
let url = format!("{}/{}", self.base_url, self.bucket_name_encoded); - - let mut query = Vec::with_capacity(5); - query.push(("list-type", "2")); - if delimiter { - query.push(("delimiter", DELIMITER)) - } - - if let Some(prefix) = &prefix { - query.push(("prefix", prefix)) - } - - if let Some(page_token) = page_token { - query.push(("continuation-token", page_token)) - } - - if let Some(max_results) = &self.max_list_results { - query.push(("max-keys", max_results)) - } - - let response = self - .client - .request(Method::GET, url) - .query(&query) - .bearer_auth(&credential.bearer) - .send_retry(&self.retry_config) - .await - .context(ListRequestSnafu)? - .bytes() - .await - .context(ListResponseBodySnafu)?; - - let mut response: ListResponse = quick_xml::de::from_reader(response.reader()) - .context(InvalidListResponseSnafu)?; - - let token = response.next_continuation_token.take(); - Ok((response.try_into()?, token)) + &self.client.config().credentials } } @@ -504,41 +109,9 @@ impl PutPart for GCSMultipartUpload { /// Complete a multipart upload async fn complete(&self, completed_parts: Vec) -> Result<()> { - let upload_id = self.multipart_id.clone(); - let url = self.client.object_url(&self.path); - - let parts = completed_parts - .into_iter() - .enumerate() - .map(|(part_number, part)| MultipartPart { - e_tag: part.content_id, - part_number: part_number + 1, - }) - .collect(); - - let credential = self.client.get_credential().await?; - let upload_info = CompleteMultipartUpload { parts }; - - let data = quick_xml::se::to_string(&upload_info) - .context(InvalidPutResponseSnafu)? 
- // We cannot disable the escaping that transforms "/" to ""e;" :( - // https://github.com/tafia/quick-xml/issues/362 - // https://github.com/tafia/quick-xml/issues/350 - .replace(""", "\""); - self.client - .client - .request(Method::POST, &url) - .bearer_auth(&credential.bearer) - .query(&[("uploadId", upload_id)]) - .body(data) - .send_retry(&self.client.retry_config) + .multipart_complete(&self.path, &self.multipart_id, completed_parts) .await - .context(PostRequestSnafu { - path: self.path.as_ref(), - })?; - - Ok(()) } } @@ -570,7 +143,7 @@ impl ObjectStore for GoogleCloudStorage { multipart_id: &MultipartId, ) -> Result<()> { self.client - .multipart_cleanup(location.as_ref(), multipart_id) + .multipart_cleanup(location, multipart_id) .await?; Ok(()) @@ -601,498 +174,16 @@ impl ObjectStore for GoogleCloudStorage { } } -/// Configure a connection to Google Cloud Storage using the specified -/// credentials. -/// -/// # Example -/// ``` -/// # let BUCKET_NAME = "foo"; -/// # let SERVICE_ACCOUNT_PATH = "/tmp/foo.json"; -/// # use object_store::gcp::GoogleCloudStorageBuilder; -/// let gcs = GoogleCloudStorageBuilder::new() -/// .with_service_account_path(SERVICE_ACCOUNT_PATH) -/// .with_bucket_name(BUCKET_NAME) -/// .build(); -/// ``` -#[derive(Debug, Clone)] -pub struct GoogleCloudStorageBuilder { - /// Bucket name - bucket_name: Option, - /// Url - url: Option, - /// Path to the service account file - service_account_path: Option, - /// The serialized service account key - service_account_key: Option, - /// Path to the application credentials file. 
- application_credentials_path: Option, - /// Retry config - retry_config: RetryConfig, - /// Client options - client_options: ClientOptions, - /// Credentials - credentials: Option, -} - -/// Configuration keys for [`GoogleCloudStorageBuilder`] -/// -/// Configuration via keys can be done via [`GoogleCloudStorageBuilder::with_config`] -/// -/// # Example -/// ``` -/// # use object_store::gcp::{GoogleCloudStorageBuilder, GoogleConfigKey}; -/// let builder = GoogleCloudStorageBuilder::new() -/// .with_config("google_service_account".parse().unwrap(), "my-service-account") -/// .with_config(GoogleConfigKey::Bucket, "my-bucket"); -/// ``` -#[derive(PartialEq, Eq, Hash, Clone, Debug, Copy, Serialize, Deserialize)] -#[non_exhaustive] -pub enum GoogleConfigKey { - /// Path to the service account file - /// - /// Supported keys: - /// - `google_service_account` - /// - `service_account` - /// - `google_service_account_path` - /// - `service_account_path` - ServiceAccount, - - /// The serialized service account key. - /// - /// Supported keys: - /// - `google_service_account_key` - /// - `service_account_key` - ServiceAccountKey, - - /// Bucket name - /// - /// See [`GoogleCloudStorageBuilder::with_bucket_name`] for details. - /// - /// Supported keys: - /// - `google_bucket` - /// - `google_bucket_name` - /// - `bucket` - /// - `bucket_name` - Bucket, - - /// Application credentials path - /// - /// See [`GoogleCloudStorageBuilder::with_application_credentials`]. 
- ApplicationCredentials, - - /// Client options - Client(ClientConfigKey), -} - -impl AsRef for GoogleConfigKey { - fn as_ref(&self) -> &str { - match self { - Self::ServiceAccount => "google_service_account", - Self::ServiceAccountKey => "google_service_account_key", - Self::Bucket => "google_bucket", - Self::ApplicationCredentials => "google_application_credentials", - Self::Client(key) => key.as_ref(), - } - } -} - -impl FromStr for GoogleConfigKey { - type Err = super::Error; - - fn from_str(s: &str) -> Result { - match s { - "google_service_account" - | "service_account" - | "google_service_account_path" - | "service_account_path" => Ok(Self::ServiceAccount), - "google_service_account_key" | "service_account_key" => { - Ok(Self::ServiceAccountKey) - } - "google_bucket" | "google_bucket_name" | "bucket" | "bucket_name" => { - Ok(Self::Bucket) - } - "google_application_credentials" => Ok(Self::ApplicationCredentials), - _ => match s.parse() { - Ok(key) => Ok(Self::Client(key)), - Err(_) => Err(Error::UnknownConfigurationKey { key: s.into() }.into()), - }, - } - } -} - -impl Default for GoogleCloudStorageBuilder { - fn default() -> Self { - Self { - bucket_name: None, - service_account_path: None, - service_account_key: None, - application_credentials_path: None, - retry_config: Default::default(), - client_options: ClientOptions::new().with_allow_http(true), - url: None, - credentials: None, - } - } -} - -impl GoogleCloudStorageBuilder { - /// Create a new [`GoogleCloudStorageBuilder`] with default values. - pub fn new() -> Self { - Default::default() - } - - /// Create an instance of [`GoogleCloudStorageBuilder`] with values pre-populated from environment variables. 
- /// - /// Variables extracted from environment: - /// * GOOGLE_SERVICE_ACCOUNT: location of service account file - /// * GOOGLE_SERVICE_ACCOUNT_PATH: (alias) location of service account file - /// * SERVICE_ACCOUNT: (alias) location of service account file - /// * GOOGLE_SERVICE_ACCOUNT_KEY: JSON serialized service account key - /// * GOOGLE_BUCKET: bucket name - /// * GOOGLE_BUCKET_NAME: (alias) bucket name - /// - /// # Example - /// ``` - /// use object_store::gcp::GoogleCloudStorageBuilder; - /// - /// let gcs = GoogleCloudStorageBuilder::from_env() - /// .with_bucket_name("foo") - /// .build(); - /// ``` - pub fn from_env() -> Self { - let mut builder = Self::default(); - - if let Ok(service_account_path) = std::env::var("SERVICE_ACCOUNT") { - builder.service_account_path = Some(service_account_path); - } - - for (os_key, os_value) in std::env::vars_os() { - if let (Some(key), Some(value)) = (os_key.to_str(), os_value.to_str()) { - if key.starts_with("GOOGLE_") { - if let Ok(config_key) = key.to_ascii_lowercase().parse() { - builder = builder.with_config(config_key, value); - } - } - } - } - - builder - } - - /// Parse available connection info form a well-known storage URL. - /// - /// The supported url schemes are: - /// - /// - `gs:///` - /// - /// Note: Settings derived from the URL will override any others set on this builder - /// - /// # Example - /// ``` - /// use object_store::gcp::GoogleCloudStorageBuilder; - /// - /// let gcs = GoogleCloudStorageBuilder::from_env() - /// .with_url("gs://bucket/path") - /// .build(); - /// ``` - pub fn with_url(mut self, url: impl Into) -> Self { - self.url = Some(url.into()); - self - } - - /// Set an option on the builder via a key - value pair. 
- pub fn with_config(mut self, key: GoogleConfigKey, value: impl Into) -> Self { - match key { - GoogleConfigKey::ServiceAccount => { - self.service_account_path = Some(value.into()) - } - GoogleConfigKey::ServiceAccountKey => { - self.service_account_key = Some(value.into()) - } - GoogleConfigKey::Bucket => self.bucket_name = Some(value.into()), - GoogleConfigKey::ApplicationCredentials => { - self.application_credentials_path = Some(value.into()) - } - GoogleConfigKey::Client(key) => { - self.client_options = self.client_options.with_config(key, value) - } - }; - self - } - - /// Set an option on the builder via a key - value pair. - #[deprecated(note = "Use with_config")] - pub fn try_with_option( - self, - key: impl AsRef, - value: impl Into, - ) -> Result { - Ok(self.with_config(key.as_ref().parse()?, value)) - } - - /// Hydrate builder from key value pairs - #[deprecated(note = "Use with_config")] - #[allow(deprecated)] - pub fn try_with_options< - I: IntoIterator, impl Into)>, - >( - mut self, - options: I, - ) -> Result { - for (key, value) in options { - self = self.try_with_option(key, value)?; - } - Ok(self) - } - - /// Get config value via a [`GoogleConfigKey`]. 
- /// - /// # Example - /// ``` - /// use object_store::gcp::{GoogleCloudStorageBuilder, GoogleConfigKey}; - /// - /// let builder = GoogleCloudStorageBuilder::from_env() - /// .with_service_account_key("foo"); - /// let service_account_key = builder.get_config_value(&GoogleConfigKey::ServiceAccountKey).unwrap_or_default(); - /// assert_eq!("foo", &service_account_key); - /// ``` - pub fn get_config_value(&self, key: &GoogleConfigKey) -> Option { - match key { - GoogleConfigKey::ServiceAccount => self.service_account_path.clone(), - GoogleConfigKey::ServiceAccountKey => self.service_account_key.clone(), - GoogleConfigKey::Bucket => self.bucket_name.clone(), - GoogleConfigKey::ApplicationCredentials => { - self.application_credentials_path.clone() - } - GoogleConfigKey::Client(key) => self.client_options.get_config_value(key), - } - } - - /// Sets properties on this builder based on a URL - /// - /// This is a separate member function to allow fallible computation to - /// be deferred until [`Self::build`] which in turn allows deriving [`Clone`] - fn parse_url(&mut self, url: &str) -> Result<()> { - let parsed = Url::parse(url).context(UnableToParseUrlSnafu { url })?; - let host = parsed.host_str().context(UrlNotRecognisedSnafu { url })?; - - let validate = |s: &str| match s.contains('.') { - true => Err(UrlNotRecognisedSnafu { url }.build()), - false => Ok(s.to_string()), - }; - - match parsed.scheme() { - "gs" => self.bucket_name = Some(validate(host)?), - scheme => return Err(UnknownUrlSchemeSnafu { scheme }.build().into()), - } - Ok(()) - } - - /// Set the bucket name (required) - pub fn with_bucket_name(mut self, bucket_name: impl Into) -> Self { - self.bucket_name = Some(bucket_name.into()); - self - } - - /// Set the path to the service account file. - /// - /// This or [`GoogleCloudStorageBuilder::with_service_account_key`] must be - /// set. - /// - /// Example `"/tmp/gcs.json"`. 
- /// - /// Example contents of `gcs.json`: - /// - /// ```json - /// { - /// "gcs_base_url": "https://localhost:4443", - /// "disable_oauth": true, - /// "client_email": "", - /// "private_key": "" - /// } - /// ``` - pub fn with_service_account_path( - mut self, - service_account_path: impl Into, - ) -> Self { - self.service_account_path = Some(service_account_path.into()); - self - } - - /// Set the service account key. The service account must be in the JSON - /// format. - /// - /// This or [`GoogleCloudStorageBuilder::with_service_account_path`] must be - /// set. - pub fn with_service_account_key( - mut self, - service_account: impl Into, - ) -> Self { - self.service_account_key = Some(service_account.into()); - self - } - - /// Set the path to the application credentials file. - /// - /// - pub fn with_application_credentials( - mut self, - application_credentials_path: impl Into, - ) -> Self { - self.application_credentials_path = Some(application_credentials_path.into()); - self - } - - /// Set the credential provider overriding any other options - pub fn with_credentials(mut self, credentials: GcpCredentialProvider) -> Self { - self.credentials = Some(credentials); - self - } - - /// Set the retry configuration - pub fn with_retry(mut self, retry_config: RetryConfig) -> Self { - self.retry_config = retry_config; - self - } - - /// Set the proxy_url to be used by the underlying client - pub fn with_proxy_url(mut self, proxy_url: impl Into) -> Self { - self.client_options = self.client_options.with_proxy_url(proxy_url); - self - } - - /// Set a trusted proxy CA certificate - pub fn with_proxy_ca_certificate( - mut self, - proxy_ca_certificate: impl Into, - ) -> Self { - self.client_options = self - .client_options - .with_proxy_ca_certificate(proxy_ca_certificate); - self - } - - /// Set a list of hosts to exclude from proxy connections - pub fn with_proxy_excludes(mut self, proxy_excludes: impl Into) -> Self { - self.client_options = 
self.client_options.with_proxy_excludes(proxy_excludes); - self - } - - /// Sets the client options, overriding any already set - pub fn with_client_options(mut self, options: ClientOptions) -> Self { - self.client_options = options; - self - } - - /// Configure a connection to Google Cloud Storage, returning a - /// new [`GoogleCloudStorage`] and consuming `self` - pub fn build(mut self) -> Result { - if let Some(url) = self.url.take() { - self.parse_url(&url)?; - } - - let bucket_name = self.bucket_name.ok_or(Error::MissingBucketName {})?; - - let client = self.client_options.client()?; - - // First try to initialize from the service account information. - let service_account_credentials = - match (self.service_account_path, self.service_account_key) { - (Some(path), None) => Some( - ServiceAccountCredentials::from_file(path) - .context(CredentialSnafu)?, - ), - (None, Some(key)) => Some( - ServiceAccountCredentials::from_key(&key).context(CredentialSnafu)?, - ), - (None, None) => None, - (Some(_), Some(_)) => { - return Err(Error::ServiceAccountPathAndKeyProvided.into()) - } - }; - - // Then try to initialize from the application credentials file, or the environment. 
- let application_default_credentials = ApplicationDefaultCredentials::read( - self.application_credentials_path.as_deref(), - )?; - - let disable_oauth = service_account_credentials - .as_ref() - .map(|c| c.disable_oauth) - .unwrap_or(false); - - let gcs_base_url: String = service_account_credentials - .as_ref() - .and_then(|c| c.gcs_base_url.clone()) - .unwrap_or_else(|| DEFAULT_GCS_BASE_URL.to_string()); - - let credentials = if let Some(credentials) = self.credentials { - credentials - } else if disable_oauth { - Arc::new(StaticCredentialProvider::new(GcpCredential { - bearer: "".to_string(), - })) as _ - } else if let Some(credentials) = service_account_credentials { - Arc::new(TokenCredentialProvider::new( - credentials.token_provider()?, - self.client_options.client()?, - self.retry_config.clone(), - )) as _ - } else if let Some(credentials) = application_default_credentials { - match credentials { - ApplicationDefaultCredentials::AuthorizedUser(token) => { - Arc::new(TokenCredentialProvider::new( - token, - self.client_options.client()?, - self.retry_config.clone(), - )) as _ - } - ApplicationDefaultCredentials::ServiceAccount(token) => { - Arc::new(TokenCredentialProvider::new( - token.token_provider()?, - self.client_options.client()?, - self.retry_config.clone(), - )) as _ - } - } - } else { - Arc::new(TokenCredentialProvider::new( - InstanceCredentialProvider::default(), - self.client_options.metadata_client()?, - self.retry_config.clone(), - )) as _ - }; - - let encoded_bucket_name = - percent_encode(bucket_name.as_bytes(), NON_ALPHANUMERIC).to_string(); - - Ok(GoogleCloudStorage { - client: Arc::new(GoogleCloudStorageClient { - client, - base_url: gcs_base_url, - credentials, - bucket_name, - bucket_name_encoded: encoded_bucket_name, - retry_config: self.retry_config, - client_options: self.client_options, - max_list_results: None, - }), - }) - } -} - #[cfg(test)] mod test { + use bytes::Bytes; - use std::collections::HashMap; - use std::io::Write; - 
use tempfile::NamedTempFile; + use credential::DEFAULT_GCS_BASE_URL; use crate::tests::*; use super::*; - const FAKE_KEY: &str = r#"{"private_key": "private_key", "private_key_id": "private_key_id", "client_email":"client_email", "disable_oauth":true}"#; const NON_EXISTENT_NAME: &str = "nonexistentname"; #[tokio::test] @@ -1104,7 +195,7 @@ mod test { list_uses_directories_correctly(&integration).await; list_with_delimiter(&integration).await; rename_and_copy(&integration).await; - if integration.client.base_url == DEFAULT_GCS_BASE_URL { + if integration.client.config().base_url == DEFAULT_GCS_BASE_URL { // Fake GCS server doesn't currently honor ifGenerationMatch // https://github.com/fsouza/fake-gcs-server/issues/994 copy_if_not_exists(&integration).await; @@ -1198,140 +289,4 @@ mod test { err ) } - - #[tokio::test] - async fn gcs_test_proxy_url() { - let mut tfile = NamedTempFile::new().unwrap(); - write!(tfile, "{FAKE_KEY}").unwrap(); - let service_account_path = tfile.path(); - let gcs = GoogleCloudStorageBuilder::new() - .with_service_account_path(service_account_path.to_str().unwrap()) - .with_bucket_name("foo") - .with_proxy_url("https://example.com") - .build(); - assert!(dbg!(gcs).is_ok()); - - let err = GoogleCloudStorageBuilder::new() - .with_service_account_path(service_account_path.to_str().unwrap()) - .with_bucket_name("foo") - .with_proxy_url("asdf://example.com") - .build() - .unwrap_err() - .to_string(); - - assert_eq!( - "Generic HTTP client error: builder error: unknown proxy scheme", - err - ); - } - - #[test] - fn gcs_test_urls() { - let mut builder = GoogleCloudStorageBuilder::new(); - builder.parse_url("gs://bucket/path").unwrap(); - assert_eq!(builder.bucket_name, Some("bucket".to_string())); - - let err_cases = ["mailto://bucket/path", "gs://bucket.mydomain/path"]; - let mut builder = GoogleCloudStorageBuilder::new(); - for case in err_cases { - builder.parse_url(case).unwrap_err(); - } - } - - #[test] - fn 
gcs_test_service_account_key_only() { - let _ = GoogleCloudStorageBuilder::new() - .with_service_account_key(FAKE_KEY) - .with_bucket_name("foo") - .build() - .unwrap(); - } - - #[test] - fn gcs_test_service_account_key_and_path() { - let mut tfile = NamedTempFile::new().unwrap(); - write!(tfile, "{FAKE_KEY}").unwrap(); - let _ = GoogleCloudStorageBuilder::new() - .with_service_account_key(FAKE_KEY) - .with_service_account_path(tfile.path().to_str().unwrap()) - .with_bucket_name("foo") - .build() - .unwrap_err(); - } - - #[test] - fn gcs_test_config_from_map() { - let google_service_account = "object_store:fake_service_account".to_string(); - let google_bucket_name = "object_store:fake_bucket".to_string(); - let options = HashMap::from([ - ("google_service_account", google_service_account.clone()), - ("google_bucket_name", google_bucket_name.clone()), - ]); - - let builder = options - .iter() - .fold(GoogleCloudStorageBuilder::new(), |builder, (key, value)| { - builder.with_config(key.parse().unwrap(), value) - }); - - assert_eq!( - builder.service_account_path.unwrap(), - google_service_account.as_str() - ); - assert_eq!(builder.bucket_name.unwrap(), google_bucket_name.as_str()); - } - - #[test] - fn gcs_test_config_get_value() { - let google_service_account = "object_store:fake_service_account".to_string(); - let google_bucket_name = "object_store:fake_bucket".to_string(); - let builder = GoogleCloudStorageBuilder::new() - .with_config(GoogleConfigKey::ServiceAccount, &google_service_account) - .with_config(GoogleConfigKey::Bucket, &google_bucket_name); - - assert_eq!( - builder - .get_config_value(&GoogleConfigKey::ServiceAccount) - .unwrap(), - google_service_account - ); - assert_eq!( - builder.get_config_value(&GoogleConfigKey::Bucket).unwrap(), - google_bucket_name - ); - } - - #[test] - fn gcs_test_config_aliases() { - // Service account path - for alias in [ - "google_service_account", - "service_account", - "google_service_account_path", - 
"service_account_path", - ] { - let builder = GoogleCloudStorageBuilder::new() - .with_config(alias.parse().unwrap(), "/fake/path.json"); - assert_eq!("/fake/path.json", builder.service_account_path.unwrap()); - } - - // Service account key - for alias in ["google_service_account_key", "service_account_key"] { - let builder = GoogleCloudStorageBuilder::new() - .with_config(alias.parse().unwrap(), FAKE_KEY); - assert_eq!(FAKE_KEY, builder.service_account_key.unwrap()); - } - - // Bucket name - for alias in [ - "google_bucket", - "google_bucket_name", - "bucket", - "bucket_name", - ] { - let builder = GoogleCloudStorageBuilder::new() - .with_config(alias.parse().unwrap(), "fake_bucket"); - assert_eq!("fake_bucket", builder.bucket_name.unwrap()); - } - } } From 7e134f4d277c0b62c27529fc15a4739de3ad0afd Mon Sep 17 00:00:00 2001 From: Raphael Taylor-Davies <1781103+tustvold@users.noreply.github.com> Date: Thu, 19 Oct 2023 17:19:40 +0100 Subject: [PATCH 22/25] Use rustfmt default line width (#4960) * Use rustfmt default line width * Further format --- arrow-arith/src/aggregate.rs | 61 +- arrow-arith/src/arithmetic.rs | 49 +- arrow-arith/src/arity.rs | 22 +- arrow-arith/src/bitwise.rs | 21 +- arrow-arith/src/boolean.rs | 39 +- arrow-arith/src/numeric.rs | 35 +- arrow-arith/src/temporal.rs | 108 +- arrow-array/src/arithmetic.rs | 5 +- arrow-array/src/array/binary_array.rs | 42 +- arrow-array/src/array/boolean_array.rs | 13 +- arrow-array/src/array/byte_array.rs | 18 +- arrow-array/src/array/dictionary_array.rs | 57 +- .../src/array/fixed_size_binary_array.rs | 36 +- .../src/array/fixed_size_list_array.rs | 54 +- arrow-array/src/array/list_array.rs | 80 +- arrow-array/src/array/map_array.rs | 42 +- arrow-array/src/array/mod.rs | 93 +- arrow-array/src/array/primitive_array.rs | 144 +- arrow-array/src/array/run_array.rs | 49 +- arrow-array/src/array/string_array.rs | 46 +- arrow-array/src/array/struct_array.rs | 30 +- arrow-array/src/array/union_array.rs | 16 +- 
arrow-array/src/builder/boolean_builder.rs | 9 +- arrow-array/src/builder/buffer_builder.rs | 10 +- .../src/builder/fixed_size_binary_builder.rs | 21 +- .../src/builder/generic_byte_run_builder.rs | 23 +- .../src/builder/generic_bytes_builder.rs | 11 +- .../generic_bytes_dictionary_builder.rs | 61 +- arrow-array/src/builder/map_builder.rs | 13 +- arrow-array/src/builder/primitive_builder.rs | 25 +- .../builder/primitive_dictionary_builder.rs | 18 +- arrow-array/src/builder/struct_builder.rs | 35 +- arrow-array/src/builder/union_builder.rs | 13 +- arrow-array/src/cast.rs | 23 +- arrow-array/src/delta.rs | 10 +- arrow-array/src/iterator.rs | 10 +- arrow-array/src/lib.rs | 3 +- arrow-array/src/numeric.rs | 56 +- arrow-array/src/record_batch.rs | 70 +- arrow-array/src/run_iterator.rs | 18 +- arrow-array/src/temporal_conversions.rs | 13 +- arrow-array/src/timezone.rs | 14 +- arrow-array/src/types.rs | 32 +- arrow-avro/src/reader/header.rs | 4 +- arrow-avro/src/reader/mod.rs | 4 +- arrow-avro/src/schema.rs | 4 +- arrow-buffer/src/bigint/div.rs | 16 +- arrow-buffer/src/bigint/mod.rs | 31 +- arrow-buffer/src/buffer/boolean.rs | 27 +- arrow-buffer/src/buffer/immutable.rs | 8 +- arrow-buffer/src/buffer/mutable.rs | 11 +- arrow-buffer/src/buffer/null.rs | 5 +- arrow-buffer/src/buffer/offset.rs | 3 +- arrow-buffer/src/buffer/ops.rs | 6 +- arrow-buffer/src/buffer/run.rs | 6 +- arrow-buffer/src/buffer/scalar.rs | 12 +- arrow-buffer/src/builder/boolean.rs | 6 +- arrow-buffer/src/bytes.rs | 6 +- arrow-buffer/src/util/bit_chunk_iterator.rs | 35 +- arrow-buffer/src/util/bit_iterator.rs | 4 +- arrow-buffer/src/util/bit_mask.rs | 38 +- arrow-cast/src/cast.rs | 1170 ++++++----------- arrow-cast/src/display.rs | 62 +- arrow-cast/src/parse.rs | 215 ++- arrow-cast/src/pretty.rs | 50 +- arrow-csv/src/reader/mod.rs | 143 +- arrow-csv/src/reader/records.rs | 19 +- arrow-csv/src/writer.rs | 45 +- arrow-data/src/data.rs | 147 +-- arrow-data/src/decimal.rs | 617 +++++---- 
arrow-data/src/equal/boolean.rs | 11 +- arrow-data/src/equal/fixed_binary.rs | 22 +- arrow-data/src/equal/mod.rs | 52 +- arrow-data/src/equal/primitive.rs | 22 +- arrow-data/src/equal/union.rs | 5 +- arrow-data/src/equal/utils.rs | 6 +- arrow-data/src/transform/list.rs | 15 +- arrow-data/src/transform/mod.rs | 60 +- arrow-data/src/transform/primitive.rs | 5 +- arrow-data/src/transform/utils.rs | 4 +- arrow-data/src/transform/variable_size.rs | 15 +- arrow-flight/examples/flight_sql_server.rs | 40 +- arrow-flight/examples/server.rs | 6 +- arrow-flight/src/client.rs | 19 +- arrow-flight/src/decode.rs | 33 +- arrow-flight/src/encode.rs | 128 +- arrow-flight/src/lib.rs | 5 +- arrow-flight/src/sql/client.rs | 77 +- arrow-flight/src/sql/metadata/db_schemas.rs | 6 +- arrow-flight/src/sql/metadata/sql_info.rs | 19 +- arrow-flight/src/sql/metadata/tables.rs | 8 +- arrow-flight/src/sql/metadata/xdbc_info.rs | 10 +- arrow-flight/src/sql/mod.rs | 5 +- arrow-flight/src/sql/server.rs | 117 +- arrow-flight/src/trailers.rs | 9 +- arrow-flight/src/utils.rs | 23 +- arrow-flight/tests/client.rs | 14 +- arrow-flight/tests/common/server.rs | 26 +- arrow-flight/tests/common/trailers_layer.rs | 4 +- arrow-flight/tests/encode_decode.rs | 41 +- arrow-flight/tests/flight_sql_client_cli.rs | 17 +- arrow-integration-test/src/datatype.rs | 20 +- arrow-integration-test/src/field.rs | 106 +- arrow-integration-test/src/lib.rs | 141 +- arrow-integration-test/src/schema.rs | 27 +- .../src/bin/arrow-json-integration-test.rs | 11 +- .../src/bin/flight-test-integration-client.rs | 3 +- .../auth_basic_proto.rs | 10 +- .../integration_test.rs | 38 +- .../src/flight_client_scenarios/middleware.rs | 3 +- .../auth_basic_proto.rs | 28 +- .../integration_test.rs | 34 +- .../src/flight_server_scenarios/middleware.rs | 9 +- arrow-integration-testing/src/lib.rs | 4 +- arrow-integration-testing/tests/ipc_reader.rs | 10 +- arrow-integration-testing/tests/ipc_writer.rs | 37 +- arrow-ipc/src/compression.rs | 26 +- 
arrow-ipc/src/convert.rs | 84 +- arrow-ipc/src/gen/File.rs | 47 +- arrow-ipc/src/gen/Message.rs | 148 ++- arrow-ipc/src/gen/Schema.rs | 339 ++--- arrow-ipc/src/gen/SparseTensor.rs | 244 ++-- arrow-ipc/src/gen/Tensor.rs | 188 ++- arrow-ipc/src/reader.rs | 151 +-- arrow-ipc/src/writer.rs | 222 ++-- arrow-json/src/reader/list_array.rs | 5 +- arrow-json/src/reader/map_array.rs | 5 +- arrow-json/src/reader/mod.rs | 28 +- arrow-json/src/reader/primitive_array.rs | 11 +- arrow-json/src/reader/schema.rs | 38 +- arrow-json/src/reader/serializer.rs | 18 +- arrow-json/src/reader/string_array.rs | 3 +- arrow-json/src/reader/struct_array.rs | 18 +- arrow-json/src/reader/tape.rs | 26 +- arrow-json/src/reader/timestamp_array.rs | 22 +- arrow-json/src/writer.rs | 96 +- arrow-ord/src/cmp.rs | 40 +- arrow-ord/src/comparison.rs | 621 +++------ arrow-ord/src/ord.rs | 10 +- arrow-ord/src/partition.rs | 4 +- arrow-ord/src/rank.rs | 10 +- arrow-ord/src/sort.rs | 138 +- arrow-row/src/lib.rs | 87 +- arrow-row/src/list.rs | 3 +- arrow-row/src/variable.rs | 4 +- arrow-schema/src/datatype.rs | 76 +- arrow-schema/src/ffi.rs | 28 +- arrow-schema/src/field.rs | 22 +- arrow-schema/src/schema.rs | 38 +- arrow-select/src/concat.rs | 172 +-- arrow-select/src/dictionary.rs | 38 +- arrow-select/src/filter.rs | 79 +- arrow-select/src/interleave.rs | 36 +- arrow-select/src/nullif.rs | 20 +- arrow-select/src/take.rs | 147 +-- arrow-string/src/concat_elements.rs | 18 +- arrow-string/src/length.rs | 16 +- arrow-string/src/like.rs | 27 +- arrow-string/src/predicate.rs | 19 +- arrow-string/src/regexp.rs | 47 +- arrow-string/src/substring.rs | 24 +- arrow/benches/array_data_validate.rs | 3 +- arrow/benches/array_from_vec.rs | 4 +- arrow/benches/bitwise_kernel.rs | 12 +- arrow/benches/buffer_bit_ops.rs | 12 +- arrow/benches/buffer_create.rs | 13 +- arrow/benches/builder.rs | 5 +- arrow/benches/csv_reader.rs | 15 +- arrow/benches/csv_writer.rs | 6 +- arrow/benches/decimal_validate.rs | 4 +- 
arrow/benches/filter_kernels.rs | 3 +- arrow/benches/interleave_kernels.rs | 3 +- arrow/benches/lexsort.rs | 8 +- arrow/benches/primitive_run_accessor.rs | 7 +- arrow/benches/primitive_run_take.rs | 4 +- arrow/benches/row_format.rs | 24 +- arrow/benches/sort_kernel.rs | 6 +- arrow/benches/string_run_builder.rs | 4 +- arrow/benches/string_run_iterator.rs | 4 +- arrow/benches/take_kernels.rs | 4 +- arrow/examples/builders.rs | 9 +- arrow/examples/dynamic_types.rs | 3 +- arrow/src/array/ffi.rs | 19 +- arrow/src/compute/kernels.rs | 4 +- arrow/src/datatypes/mod.rs | 4 +- arrow/src/ffi.rs | 60 +- arrow/src/ffi_stream.rs | 22 +- arrow/src/lib.rs | 3 +- arrow/src/pyarrow.rs | 16 +- arrow/src/tensor.rs | 30 +- arrow/src/util/bench_util.rs | 22 +- arrow/src/util/data_gen.rs | 39 +- arrow/tests/arithmetic.rs | 4 +- arrow/tests/array_cast.rs | 76 +- arrow/tests/array_equal.rs | 114 +- arrow/tests/array_transform.rs | 105 +- arrow/tests/array_validation.rs | 78 +- arrow/tests/csv.rs | 6 +- object_store/src/aws/builder.rs | 67 +- object_store/src/aws/client.rs | 56 +- object_store/src/aws/credential.rs | 39 +- object_store/src/aws/mod.rs | 34 +- object_store/src/aws/resolve.rs | 5 +- object_store/src/azure/builder.rs | 85 +- object_store/src/azure/client.rs | 32 +- object_store/src/azure/credential.rs | 60 +- object_store/src/azure/mod.rs | 16 +- object_store/src/buffered.rs | 48 +- object_store/src/chunked.rs | 9 +- object_store/src/client/backoff.rs | 12 +- object_store/src/client/get.rs | 13 +- object_store/src/client/mock_server.rs | 3 +- object_store/src/client/mod.rs | 37 +- object_store/src/client/retry.rs | 7 +- object_store/src/delimited.rs | 3 +- object_store/src/gcp/builder.rs | 87 +- object_store/src/gcp/client.rs | 20 +- object_store/src/gcp/credential.rs | 12 +- object_store/src/gcp/mod.rs | 9 +- object_store/src/http/client.rs | 15 +- object_store/src/http/mod.rs | 10 +- object_store/src/lib.rs | 70 +- object_store/src/limit.rs | 26 +- object_store/src/local.rs | 
97 +- object_store/src/memory.rs | 20 +- object_store/src/parse.rs | 10 +- object_store/src/path/mod.rs | 30 +- object_store/src/prefix.rs | 20 +- object_store/src/signer.rs | 7 +- object_store/src/throttle.rs | 52 +- object_store/src/util.rs | 19 +- object_store/tests/get_range_file.rs | 11 +- parquet/benches/arrow_reader.rs | 157 +-- parquet/benches/arrow_writer.rs | 5 +- parquet/benches/compression.rs | 9 +- parquet/examples/read_with_rowgroup.rs | 17 +- parquet/src/arrow/arrow_reader/mod.rs | 276 ++-- parquet/src/arrow/arrow_reader/selection.rs | 26 +- parquet/src/arrow/arrow_writer/byte_array.rs | 37 +- parquet/src/arrow/arrow_writer/levels.rs | 130 +- parquet/src/arrow/arrow_writer/mod.rs | 152 +-- parquet/src/arrow/async_reader/metadata.rs | 20 +- parquet/src/arrow/async_reader/mod.rs | 144 +- parquet/src/arrow/async_reader/store.rs | 9 +- parquet/src/arrow/async_writer/mod.rs | 19 +- parquet/src/arrow/buffer/bit_util.rs | 3 +- parquet/src/arrow/buffer/dictionary_buffer.rs | 29 +- parquet/src/arrow/buffer/offset_buffer.rs | 10 +- parquet/src/arrow/decoder/delta_byte_array.rs | 10 +- parquet/src/arrow/decoder/dictionary_index.rs | 15 +- parquet/src/arrow/mod.rs | 10 +- parquet/src/arrow/record_reader/buffer.rs | 7 +- .../arrow/record_reader/definition_levels.rs | 35 +- parquet/src/arrow/record_reader/mod.rs | 17 +- parquet/src/basic.rs | 74 +- parquet/src/bin/parquet-fromcsv.rs | 44 +- parquet/src/bin/parquet-index.rs | 4 +- parquet/src/bin/parquet-layout.rs | 5 +- parquet/src/bin/parquet-read.rs | 3 +- parquet/src/bin/parquet-rewrite.rs | 46 +- parquet/src/bin/parquet-rowcount.rs | 3 +- parquet/src/bin/parquet-show-bloom-filter.rs | 4 +- parquet/src/bloom_filter/mod.rs | 25 +- parquet/src/column/reader.rs | 92 +- parquet/src/column/reader/decoder.rs | 45 +- parquet/src/column/writer/encoder.rs | 7 +- parquet/src/column/writer/mod.rs | 215 +-- parquet/src/data_type.rs | 84 +- parquet/src/file/footer.rs | 3 +- parquet/src/file/metadata.rs | 30 +- 
parquet/src/file/page_encoding_stats.rs | 4 +- parquet/src/file/page_index/index_reader.rs | 18 +- parquet/src/file/properties.rs | 29 +- parquet/src/file/reader.rs | 43 +- parquet/src/file/serialized_reader.rs | 69 +- parquet/src/file/writer.rs | 105 +- parquet/src/record/api.rs | 43 +- parquet/src/record/mod.rs | 3 +- parquet/src/record/reader.rs | 93 +- parquet/src/record/triplet.rs | 62 +- parquet/src/schema/parser.rs | 406 +++--- parquet/src/schema/printer.rs | 66 +- parquet/src/schema/types.rs | 38 +- parquet/src/schema/visitor.rs | 24 +- parquet/src/thrift.rs | 9 +- parquet/tests/arrow_writer_layout.rs | 13 +- parquet_derive/src/parquet_field.rs | 30 +- parquet_derive_test/src/lib.rs | 3 +- rustfmt.toml | 6 - 289 files changed, 4941 insertions(+), 8730 deletions(-) diff --git a/arrow-arith/src/aggregate.rs b/arrow-arith/src/aggregate.rs index 04417c666c85..0dabaa50f5f6 100644 --- a/arrow-arith/src/aggregate.rs +++ b/arrow-arith/src/aggregate.rs @@ -207,15 +207,15 @@ where } let iter = ArrayIter::new(array); - let sum = - iter.into_iter() - .try_fold(T::default_value(), |accumulator, value| { - if let Some(value) = value { - accumulator.add_checked(value) - } else { - Ok(accumulator) - } - })?; + let sum = iter + .into_iter() + .try_fold(T::default_value(), |accumulator, value| { + if let Some(value) = value { + accumulator.add_checked(value) + } else { + Ok(accumulator) + } + })?; Ok(Some(sum)) } @@ -230,11 +230,7 @@ where T: ArrowNumericType, T::Native: ArrowNativeType, { - min_max_array_helper::( - array, - |a, b| (is_nan(*a) & !is_nan(*b)) || a > b, - min, - ) + min_max_array_helper::(array, |a, b| (is_nan(*a) & !is_nan(*b)) || a > b, min) } /// Returns the max of values in the array of `ArrowNumericType` type, or dictionary @@ -244,11 +240,7 @@ where T: ArrowNumericType, T::Native: ArrowNativeType, { - min_max_array_helper::( - array, - |a, b| (!is_nan(*a) & is_nan(*b)) || a < b, - max, - ) + min_max_array_helper::(array, |a, b| (!is_nan(*a) & 
is_nan(*b)) || a < b, max) } fn min_max_array_helper, F, M>( @@ -501,10 +493,7 @@ mod simd { fn init_accumulator_chunk() -> Self::SimdAccumulator; /// Updates the accumulator with the values of one chunk - fn accumulate_chunk_non_null( - accumulator: &mut Self::SimdAccumulator, - chunk: T::Simd, - ); + fn accumulate_chunk_non_null(accumulator: &mut Self::SimdAccumulator, chunk: T::Simd); /// Updates the accumulator with the values of one chunk according to the given vector mask fn accumulate_chunk_nullable( @@ -602,10 +591,7 @@ mod simd { (T::init(T::default_value()), T::mask_init(false)) } - fn accumulate_chunk_non_null( - accumulator: &mut Self::SimdAccumulator, - chunk: T::Simd, - ) { + fn accumulate_chunk_non_null(accumulator: &mut Self::SimdAccumulator, chunk: T::Simd) { let acc_is_nan = !T::eq(accumulator.0, accumulator.0); let is_lt = acc_is_nan | T::lt(chunk, accumulator.0); let first_or_lt = !accumulator.1 | is_lt; @@ -627,10 +613,7 @@ mod simd { accumulator.1 |= vecmask; } - fn accumulate_scalar( - accumulator: &mut Self::ScalarAccumulator, - value: T::Native, - ) { + fn accumulate_scalar(accumulator: &mut Self::ScalarAccumulator, value: T::Native) { if !accumulator.1 { accumulator.0 = value; } else { @@ -690,10 +673,7 @@ mod simd { (T::init(T::default_value()), T::mask_init(false)) } - fn accumulate_chunk_non_null( - accumulator: &mut Self::SimdAccumulator, - chunk: T::Simd, - ) { + fn accumulate_chunk_non_null(accumulator: &mut Self::SimdAccumulator, chunk: T::Simd) { let chunk_is_nan = !T::eq(chunk, chunk); let is_gt = chunk_is_nan | T::gt(chunk, accumulator.0); let first_or_gt = !accumulator.1 | is_gt; @@ -715,10 +695,7 @@ mod simd { accumulator.1 |= vecmask; } - fn accumulate_scalar( - accumulator: &mut Self::ScalarAccumulator, - value: T::Native, - ) { + fn accumulate_scalar(accumulator: &mut Self::ScalarAccumulator, value: T::Native) { if !accumulator.1 { accumulator.0 = value; } else { @@ -1009,8 +986,7 @@ mod tests { #[test] fn 
test_primitive_array_bool_or_with_nulls() { - let a = - BooleanArray::from(vec![None, Some(false), Some(false), None, Some(false)]); + let a = BooleanArray::from(vec![None, Some(false), Some(false), None, Some(false)]); assert!(!bool_or(&a).unwrap()); } @@ -1297,8 +1273,7 @@ mod tests { assert_eq!(Some(false), min_boolean(&a)); assert_eq!(Some(true), max_boolean(&a)); - let a = - BooleanArray::from(vec![Some(false), Some(true), None, Some(false), None]); + let a = BooleanArray::from(vec![Some(false), Some(true), None, Some(false), None]); assert_eq!(Some(false), min_boolean(&a)); assert_eq!(Some(true), max_boolean(&a)); } diff --git a/arrow-arith/src/arithmetic.rs b/arrow-arith/src/arithmetic.rs index 8635ce0ddd80..124614d77f97 100644 --- a/arrow-arith/src/arithmetic.rs +++ b/arrow-arith/src/arithmetic.rs @@ -48,8 +48,7 @@ fn get_fixed_point_info( ))); } - let divisor = - i256::from_i128(10).pow_wrapping((product_scale - required_scale) as u32); + let divisor = i256::from_i128(10).pow_wrapping((product_scale - required_scale) as u32); Ok((precision, product_scale, divisor)) } @@ -78,8 +77,7 @@ pub fn multiply_fixed_point_dyn( let left = left.as_any().downcast_ref::().unwrap(); let right = right.as_any().downcast_ref::().unwrap(); - multiply_fixed_point(left, right, required_scale) - .map(|a| Arc::new(a) as ArrayRef) + multiply_fixed_point(left, right, required_scale).map(|a| Arc::new(a) as ArrayRef) } (_, _) => Err(ArrowError::CastError(format!( "Unsupported data type {}, {}", @@ -113,10 +111,8 @@ pub fn multiply_fixed_point_checked( )?; if required_scale == product_scale { - return try_binary::<_, _, _, Decimal128Type>(left, right, |a, b| { - a.mul_checked(b) - })? - .with_precision_and_scale(precision, required_scale); + return try_binary::<_, _, _, Decimal128Type>(left, right, |a, b| a.mul_checked(b))? 
+ .with_precision_and_scale(precision, required_scale); } try_binary::<_, _, _, Decimal128Type>(left, right, |a, b| { @@ -213,17 +209,16 @@ mod tests { .unwrap(); let err = mul(&a, &b).unwrap_err(); - assert!(err.to_string().contains( - "Overflow happened on: 123456789000000000000000000 * 10000000000000000000" - )); + assert!(err + .to_string() + .contains("Overflow happened on: 123456789000000000000000000 * 10000000000000000000")); // Allow precision loss. let result = multiply_fixed_point_checked(&a, &b, 28).unwrap(); // [1234567890] - let expected = - Decimal128Array::from(vec![12345678900000000000000000000000000000]) - .with_precision_and_scale(38, 28) - .unwrap(); + let expected = Decimal128Array::from(vec![12345678900000000000000000000000000000]) + .with_precision_and_scale(38, 28) + .unwrap(); assert_eq!(&expected, &result); assert_eq!( @@ -233,13 +228,9 @@ mod tests { // Rounding case // [0.000000000000000001, 123456789.555555555555555555, 1.555555555555555555] - let a = Decimal128Array::from(vec![ - 1, - 123456789555555555555555555, - 1555555555555555555, - ]) - .with_precision_and_scale(38, 18) - .unwrap(); + let a = Decimal128Array::from(vec![1, 123456789555555555555555555, 1555555555555555555]) + .with_precision_and_scale(38, 18) + .unwrap(); // [1.555555555555555555, 11.222222222222222222, 0.000000000000000001] let b = Decimal128Array::from(vec![1555555555555555555, 11222222222222222222, 1]) @@ -311,10 +302,9 @@ mod tests { )); let result = multiply_fixed_point(&a, &b, 28).unwrap(); - let expected = - Decimal128Array::from(vec![62946009661555981610246871926660136960]) - .with_precision_and_scale(38, 28) - .unwrap(); + let expected = Decimal128Array::from(vec![62946009661555981610246871926660136960]) + .with_precision_and_scale(38, 28) + .unwrap(); assert_eq!(&expected, &result); } @@ -338,10 +328,9 @@ mod tests { // Avoid overflow by reducing the scale. 
let result = multiply_fixed_point(&a, &b, 28).unwrap(); // [1234567890] - let expected = - Decimal128Array::from(vec![12345678900000000000000000000000000000]) - .with_precision_and_scale(38, 28) - .unwrap(); + let expected = Decimal128Array::from(vec![12345678900000000000000000000000000000]) + .with_precision_and_scale(38, 28) + .unwrap(); assert_eq!(&expected, &result); assert_eq!( diff --git a/arrow-arith/src/arity.rs b/arrow-arith/src/arity.rs index f3118d104536..ff8b82a5d943 100644 --- a/arrow-arith/src/arity.rs +++ b/arrow-arith/src/arity.rs @@ -49,10 +49,7 @@ where } /// See [`PrimitiveArray::try_unary`] -pub fn try_unary( - array: &PrimitiveArray, - op: F, -) -> Result, ArrowError> +pub fn try_unary(array: &PrimitiveArray, op: F) -> Result, ArrowError> where I: ArrowPrimitiveType, O: ArrowPrimitiveType, @@ -86,10 +83,7 @@ where } /// A helper function that applies a fallible unary function to a dictionary array with primitive value type. -fn try_unary_dict( - array: &DictionaryArray, - op: F, -) -> Result +fn try_unary_dict(array: &DictionaryArray, op: F) -> Result where K: ArrowDictionaryKeyType + ArrowNumericType, T: ArrowPrimitiveType, @@ -299,8 +293,7 @@ where try_binary_no_nulls(len, a, b, op) } else { let nulls = - NullBuffer::union(a.logical_nulls().as_ref(), b.logical_nulls().as_ref()) - .unwrap(); + NullBuffer::union(a.logical_nulls().as_ref(), b.logical_nulls().as_ref()).unwrap(); let mut buffer = BufferBuilder::::new(len); buffer.append_n_zeroed(len); @@ -308,8 +301,7 @@ where nulls.try_for_each_valid_idx(|idx| { unsafe { - *slice.get_unchecked_mut(idx) = - op(a.value_unchecked(idx), b.value_unchecked(idx))? + *slice.get_unchecked_mut(idx) = op(a.value_unchecked(idx), b.value_unchecked(idx))? 
}; Ok::<_, ArrowError>(()) })?; @@ -360,8 +352,7 @@ where try_binary_no_nulls_mut(len, a, b, op) } else { let nulls = - NullBuffer::union(a.logical_nulls().as_ref(), b.logical_nulls().as_ref()) - .unwrap(); + NullBuffer::union(a.logical_nulls().as_ref(), b.logical_nulls().as_ref()).unwrap(); let mut builder = a.into_builder()?; @@ -440,8 +431,7 @@ mod tests { #[test] #[allow(deprecated)] fn test_unary_f64_slice() { - let input = - Float64Array::from(vec![Some(5.1f64), None, Some(6.8), None, Some(7.2)]); + let input = Float64Array::from(vec![Some(5.1f64), None, Some(6.8), None, Some(7.2)]); let input_slice = input.slice(1, 4); let result = unary(&input_slice, |n| n.round()); assert_eq!( diff --git a/arrow-arith/src/bitwise.rs b/arrow-arith/src/bitwise.rs index a5dec4638703..c7885952f8ba 100644 --- a/arrow-arith/src/bitwise.rs +++ b/arrow-arith/src/bitwise.rs @@ -212,10 +212,8 @@ mod tests { #[test] fn test_bitwise_shift_left() { let left = UInt64Array::from(vec![Some(1), Some(2), None, Some(4), Some(8)]); - let right = - UInt64Array::from(vec![Some(5), Some(10), Some(8), Some(12), Some(u64::MAX)]); - let expected = - UInt64Array::from(vec![Some(32), Some(2048), None, Some(16384), Some(0)]); + let right = UInt64Array::from(vec![Some(5), Some(10), Some(8), Some(12), Some(u64::MAX)]); + let expected = UInt64Array::from(vec![Some(32), Some(2048), None, Some(16384), Some(0)]); let result = bitwise_shift_left(&left, &right).unwrap(); assert_eq!(expected, result); } @@ -224,18 +222,15 @@ mod tests { fn test_bitwise_shift_left_scalar() { let left = UInt64Array::from(vec![Some(1), Some(2), None, Some(4), Some(8)]); let scalar = 2; - let expected = - UInt64Array::from(vec![Some(4), Some(8), None, Some(16), Some(32)]); + let expected = UInt64Array::from(vec![Some(4), Some(8), None, Some(16), Some(32)]); let result = bitwise_shift_left_scalar(&left, scalar).unwrap(); assert_eq!(expected, result); } #[test] fn test_bitwise_shift_right() { - let left = - 
UInt64Array::from(vec![Some(32), Some(2048), None, Some(16384), Some(3)]); - let right = - UInt64Array::from(vec![Some(5), Some(10), Some(8), Some(12), Some(65)]); + let left = UInt64Array::from(vec![Some(32), Some(2048), None, Some(16384), Some(3)]); + let right = UInt64Array::from(vec![Some(5), Some(10), Some(8), Some(12), Some(65)]); let expected = UInt64Array::from(vec![Some(1), Some(2), None, Some(4), Some(1)]); let result = bitwise_shift_right(&left, &right).unwrap(); assert_eq!(expected, result); @@ -243,11 +238,9 @@ mod tests { #[test] fn test_bitwise_shift_right_scalar() { - let left = - UInt64Array::from(vec![Some(32), Some(2048), None, Some(16384), Some(3)]); + let left = UInt64Array::from(vec![Some(32), Some(2048), None, Some(16384), Some(3)]); let scalar = 2; - let expected = - UInt64Array::from(vec![Some(8), Some(512), None, Some(4096), Some(0)]); + let expected = UInt64Array::from(vec![Some(8), Some(512), None, Some(4096), Some(0)]); let result = bitwise_shift_right_scalar(&left, scalar).unwrap(); assert_eq!(expected, result); } diff --git a/arrow-arith/src/boolean.rs b/arrow-arith/src/boolean.rs index 46e5998208f1..269a36d66c2b 100644 --- a/arrow-arith/src/boolean.rs +++ b/arrow-arith/src/boolean.rs @@ -57,10 +57,7 @@ use arrow_schema::ArrowError; /// # Fails /// /// If the operands have different lengths -pub fn and_kleene( - left: &BooleanArray, - right: &BooleanArray, -) -> Result { +pub fn and_kleene(left: &BooleanArray, right: &BooleanArray) -> Result { if left.len() != right.len() { return Err(ArrowError::ComputeError( "Cannot perform bitwise operation on arrays of different length".to_string(), @@ -155,10 +152,7 @@ pub fn and_kleene( /// # Fails /// /// If the operands have different lengths -pub fn or_kleene( - left: &BooleanArray, - right: &BooleanArray, -) -> Result { +pub fn or_kleene(left: &BooleanArray, right: &BooleanArray) -> Result { if left.len() != right.len() { return Err(ArrowError::ComputeError( "Cannot perform bitwise operation 
on arrays of different length".to_string(), @@ -257,10 +251,7 @@ where /// let and_ab = and(&a, &b).unwrap(); /// assert_eq!(and_ab, BooleanArray::from(vec![Some(false), Some(true), None])); /// ``` -pub fn and( - left: &BooleanArray, - right: &BooleanArray, -) -> Result { +pub fn and(left: &BooleanArray, right: &BooleanArray) -> Result { binary_boolean_kernel(left, right, |a, b| a & b) } @@ -581,8 +572,7 @@ mod tests { let a = a.as_any().downcast_ref::().unwrap(); let c = not(a).unwrap(); - let expected = - BooleanArray::from(vec![Some(false), Some(true), None, Some(false)]); + let expected = BooleanArray::from(vec![Some(false), Some(true), None, Some(false)]); assert_eq!(c, expected); } @@ -631,12 +621,10 @@ mod tests { #[test] fn test_bool_array_and_sliced_same_offset() { let a = BooleanArray::from(vec![ - false, false, false, false, false, false, false, false, false, false, true, - true, + false, false, false, false, false, false, false, false, false, false, true, true, ]); let b = BooleanArray::from(vec![ - false, false, false, false, false, false, false, false, false, true, false, - true, + false, false, false, false, false, false, false, false, false, true, false, true, ]); let a = a.slice(8, 4); @@ -654,12 +642,10 @@ mod tests { #[test] fn test_bool_array_and_sliced_same_offset_mod8() { let a = BooleanArray::from(vec![ - false, false, true, true, false, false, false, false, false, false, false, - false, + false, false, true, true, false, false, false, false, false, false, false, false, ]); let b = BooleanArray::from(vec![ - false, false, false, false, false, false, false, false, false, true, false, - true, + false, false, false, false, false, false, false, false, false, true, false, true, ]); let a = a.slice(0, 4); @@ -677,8 +663,7 @@ mod tests { #[test] fn test_bool_array_and_sliced_offset1() { let a = BooleanArray::from(vec![ - false, false, false, false, false, false, false, false, false, false, true, - true, + false, false, false, false, false, false, 
false, false, false, false, true, true, ]); let b = BooleanArray::from(vec![false, true, false, true]); @@ -696,8 +681,7 @@ mod tests { fn test_bool_array_and_sliced_offset2() { let a = BooleanArray::from(vec![false, false, true, true]); let b = BooleanArray::from(vec![ - false, false, false, false, false, false, false, false, false, true, false, - true, + false, false, false, false, false, false, false, false, false, true, false, true, ]); let b = b.slice(8, 4); @@ -730,8 +714,7 @@ mod tests { let c = and(a, b).unwrap(); - let expected = - BooleanArray::from(vec![Some(false), Some(false), None, Some(true)]); + let expected = BooleanArray::from(vec![Some(false), Some(false), None, Some(true)]); assert_eq!(expected, c); } diff --git a/arrow-arith/src/numeric.rs b/arrow-arith/src/numeric.rs index c47731ed5125..b2c87bba5143 100644 --- a/arrow-arith/src/numeric.rs +++ b/arrow-arith/src/numeric.rs @@ -144,13 +144,13 @@ pub fn neg(array: &dyn Array) -> Result { let a = array .as_primitive::() .try_unary::<_, IntervalMonthDayNanoType, ArrowError>(|x| { - let (months, days, nanos) = IntervalMonthDayNanoType::to_parts(x); - Ok(IntervalMonthDayNanoType::make_value( - months.neg_checked()?, - days.neg_checked()?, - nanos.neg_checked()?, - )) - })?; + let (months, days, nanos) = IntervalMonthDayNanoType::to_parts(x); + Ok(IntervalMonthDayNanoType::make_value( + months.neg_checked()?, + days.neg_checked()?, + nanos.neg_checked()?, + )) + })?; Ok(Arc::new(a)) } t => Err(ArrowError::InvalidArgumentError(format!( @@ -201,11 +201,7 @@ impl Op { } /// Dispatch the given `op` to the appropriate specialized kernel -fn arithmetic_op( - op: Op, - lhs: &dyn Datum, - rhs: &dyn Datum, -) -> Result { +fn arithmetic_op(op: Op, lhs: &dyn Datum, rhs: &dyn Datum) -> Result { use DataType::*; use IntervalUnit::*; use TimeUnit::*; @@ -675,8 +671,7 @@ fn date_op( (Date64, Op::Sub | Op::SubWrapping, Date64) => { let l = l.as_primitive::(); let r = r.as_primitive::(); - let result = - 
try_op_ref!(DurationMillisecondType, l, l_s, r, r_s, l.sub_checked(r)); + let result = try_op_ref!(DurationMillisecondType, l, l_s, r, r_s, l.sub_checked(r)); return Ok(result); } _ => {} @@ -800,8 +795,7 @@ fn decimal_op( let mul_pow = result_scale - s1 + s2; // p1 - s1 + s2 + result_scale - let result_precision = - (mul_pow.saturating_add(*p1 as i8) as u8).min(T::MAX_PRECISION); + let result_precision = (mul_pow.saturating_add(*p1 as i8) as u8).min(T::MAX_PRECISION); let (l_mul, r_mul) = match mul_pow.cmp(&0) { Ordering::Greater => ( @@ -1158,7 +1152,10 @@ mod tests { .with_precision_and_scale(3, -1) .unwrap(); let err = add(&a, &b).unwrap_err().to_string(); - assert_eq!(err, "Compute error: Overflow happened on: 10 * 100000000000000000000000000000000000000"); + assert_eq!( + err, + "Compute error: Overflow happened on: 10 * 100000000000000000000000000000000000000" + ); let b = Decimal128Array::from(vec![0]) .with_precision_and_scale(1, 1) @@ -1199,9 +1196,7 @@ mod tests { "1960-01-30T04:23:20Z", ] .into_iter() - .map(|x| { - T::make_value(DateTime::parse_from_rfc3339(x).unwrap().naive_utc()).unwrap() - }) + .map(|x| T::make_value(DateTime::parse_from_rfc3339(x).unwrap().naive_utc()).unwrap()) .collect(); let a = PrimitiveArray::::new(values, None); diff --git a/arrow-arith/src/temporal.rs b/arrow-arith/src/temporal.rs index 7855b6fc6e46..a9c3de5401c1 100644 --- a/arrow-arith/src/temporal.rs +++ b/arrow-arith/src/temporal.rs @@ -23,9 +23,7 @@ use chrono::{DateTime, Datelike, NaiveDateTime, NaiveTime, Offset, Timelike}; use arrow_array::builder::*; use arrow_array::iterator::ArrayIter; -use arrow_array::temporal_conversions::{ - as_datetime, as_datetime_with_timezone, as_time, -}; +use arrow_array::temporal_conversions::{as_datetime, as_datetime_with_timezone, as_time}; use arrow_array::timezone::Tz; use arrow_array::types::*; use arrow_array::*; @@ -209,12 +207,9 @@ where } DataType::Timestamp(_, Some(tz)) => { let iter = ArrayIter::new(array); - 
extract_component_from_datetime_array::<&PrimitiveArray, T, _>( - iter, - b, - tz, - |t| t.hour() as i32, - ) + extract_component_from_datetime_array::<&PrimitiveArray, T, _>(iter, b, tz, |t| { + t.hour() as i32 + }) } _ => return_compute_error_with!("hour does not support", array.data_type()), } @@ -289,9 +284,7 @@ pub fn num_days_from_monday_dyn(array: &dyn Array) -> Result( - array: &PrimitiveArray, -) -> Result +pub fn num_days_from_monday(array: &PrimitiveArray) -> Result where T: ArrowTemporalType + ArrowNumericType, i64: From, @@ -318,9 +311,7 @@ pub fn num_days_from_sunday_dyn(array: &dyn Array) -> Result( - array: &PrimitiveArray, -) -> Result +pub fn num_days_from_sunday(array: &PrimitiveArray) -> Result where T: ArrowTemporalType + ArrowNumericType, i64: From, @@ -449,11 +440,7 @@ pub fn millisecond_dyn(array: &dyn Array) -> Result { } /// Extracts the time fraction of a given temporal array as an array of integers -fn time_fraction_dyn( - array: &dyn Array, - name: &str, - op: F, -) -> Result +fn time_fraction_dyn(array: &dyn Array, name: &str, op: F) -> Result where F: Fn(NaiveDateTime) -> i32, { @@ -498,14 +485,9 @@ where } DataType::Timestamp(_, Some(tz)) => { let iter = ArrayIter::new(array); - extract_component_from_datetime_array::<_, T, _>(iter, b, tz, |t| { - op(t.naive_local()) - }) + extract_component_from_datetime_array::<_, T, _>(iter, b, tz, |t| op(t.naive_local())) } - _ => return_compute_error_with!( - format!("{name} does not support"), - array.data_type() - ), + _ => return_compute_error_with!(format!("{name} does not support"), array.data_type()), } } @@ -559,8 +541,7 @@ mod tests { #[test] fn test_temporal_array_time64_micro_hour() { - let a: PrimitiveArray = - vec![37800000000, 86339000000].into(); + let a: PrimitiveArray = vec![37800000000, 86339000000].into(); let b = hour(&a).unwrap(); assert_eq!(10, b.value(0)); @@ -623,12 +604,10 @@ mod tests { #[test] fn test_temporal_array_timestamp_quarter_with_timezone() { // 24 * 60 * 60 = 
86400 - let a = TimestampSecondArray::from(vec![86400 * 90]) - .with_timezone("+00:00".to_string()); + let a = TimestampSecondArray::from(vec![86400 * 90]).with_timezone("+00:00".to_string()); let b = quarter(&a).unwrap(); assert_eq!(2, b.value(0)); - let a = TimestampSecondArray::from(vec![86400 * 90]) - .with_timezone("-10:00".to_string()); + let a = TimestampSecondArray::from(vec![86400 * 90]).with_timezone("-10:00".to_string()); let b = quarter(&a).unwrap(); assert_eq!(1, b.value(0)); } @@ -659,12 +638,10 @@ mod tests { #[test] fn test_temporal_array_timestamp_month_with_timezone() { // 24 * 60 * 60 = 86400 - let a = TimestampSecondArray::from(vec![86400 * 31]) - .with_timezone("+00:00".to_string()); + let a = TimestampSecondArray::from(vec![86400 * 31]).with_timezone("+00:00".to_string()); let b = month(&a).unwrap(); assert_eq!(2, b.value(0)); - let a = TimestampSecondArray::from(vec![86400 * 31]) - .with_timezone("-10:00".to_string()); + let a = TimestampSecondArray::from(vec![86400 * 31]).with_timezone("-10:00".to_string()); let b = month(&a).unwrap(); assert_eq!(1, b.value(0)); } @@ -672,12 +649,10 @@ mod tests { #[test] fn test_temporal_array_timestamp_day_with_timezone() { // 24 * 60 * 60 = 86400 - let a = - TimestampSecondArray::from(vec![86400]).with_timezone("+00:00".to_string()); + let a = TimestampSecondArray::from(vec![86400]).with_timezone("+00:00".to_string()); let b = day(&a).unwrap(); assert_eq!(2, b.value(0)); - let a = - TimestampSecondArray::from(vec![86400]).with_timezone("-10:00".to_string()); + let a = TimestampSecondArray::from(vec![86400]).with_timezone("-10:00".to_string()); let b = day(&a).unwrap(); assert_eq!(1, b.value(0)); } @@ -857,8 +832,7 @@ mod tests { #[test] fn test_temporal_array_timestamp_second_with_timezone() { - let a = - TimestampSecondArray::from(vec![10, 20]).with_timezone("+00:00".to_string()); + let a = TimestampSecondArray::from(vec![10, 20]).with_timezone("+00:00".to_string()); let b = second(&a).unwrap(); 
assert_eq!(10, b.value(0)); assert_eq!(20, b.value(1)); @@ -866,8 +840,7 @@ mod tests { #[test] fn test_temporal_array_timestamp_minute_with_timezone() { - let a = - TimestampSecondArray::from(vec![0, 60]).with_timezone("+00:50".to_string()); + let a = TimestampSecondArray::from(vec![0, 60]).with_timezone("+00:50".to_string()); let b = minute(&a).unwrap(); assert_eq!(50, b.value(0)); assert_eq!(51, b.value(1)); @@ -875,48 +848,42 @@ mod tests { #[test] fn test_temporal_array_timestamp_minute_with_negative_timezone() { - let a = - TimestampSecondArray::from(vec![60 * 55]).with_timezone("-00:50".to_string()); + let a = TimestampSecondArray::from(vec![60 * 55]).with_timezone("-00:50".to_string()); let b = minute(&a).unwrap(); assert_eq!(5, b.value(0)); } #[test] fn test_temporal_array_timestamp_hour_with_timezone() { - let a = TimestampSecondArray::from(vec![60 * 60 * 10]) - .with_timezone("+01:00".to_string()); + let a = TimestampSecondArray::from(vec![60 * 60 * 10]).with_timezone("+01:00".to_string()); let b = hour(&a).unwrap(); assert_eq!(11, b.value(0)); } #[test] fn test_temporal_array_timestamp_hour_with_timezone_without_colon() { - let a = TimestampSecondArray::from(vec![60 * 60 * 10]) - .with_timezone("+0100".to_string()); + let a = TimestampSecondArray::from(vec![60 * 60 * 10]).with_timezone("+0100".to_string()); let b = hour(&a).unwrap(); assert_eq!(11, b.value(0)); } #[test] fn test_temporal_array_timestamp_hour_with_timezone_without_minutes() { - let a = TimestampSecondArray::from(vec![60 * 60 * 10]) - .with_timezone("+01".to_string()); + let a = TimestampSecondArray::from(vec![60 * 60 * 10]).with_timezone("+01".to_string()); let b = hour(&a).unwrap(); assert_eq!(11, b.value(0)); } #[test] fn test_temporal_array_timestamp_hour_with_timezone_without_initial_sign() { - let a = TimestampSecondArray::from(vec![60 * 60 * 10]) - .with_timezone("0100".to_string()); + let a = TimestampSecondArray::from(vec![60 * 60 * 10]).with_timezone("0100".to_string()); let err 
= hour(&a).unwrap_err().to_string(); assert!(err.contains("Invalid timezone"), "{}", err); } #[test] fn test_temporal_array_timestamp_hour_with_timezone_with_only_colon() { - let a = TimestampSecondArray::from(vec![60 * 60 * 10]) - .with_timezone("01:00".to_string()); + let a = TimestampSecondArray::from(vec![60 * 60 * 10]).with_timezone("01:00".to_string()); let err = hour(&a).unwrap_err().to_string(); assert!(err.contains("Invalid timezone"), "{}", err); } @@ -960,10 +927,8 @@ mod tests { let b = hour_dyn(&dict).unwrap(); - let expected_dict = DictionaryArray::new( - keys.clone(), - Arc::new(Int32Array::from(vec![11, 21, 7])), - ); + let expected_dict = + DictionaryArray::new(keys.clone(), Arc::new(Int32Array::from(vec![11, 21, 7]))); let expected = Arc::new(expected_dict) as ArrayRef; assert_eq!(&expected, &b); @@ -987,8 +952,7 @@ mod tests { assert_eq!(&expected, &b); assert_eq!(&expected, &b_old); - let b = - time_fraction_dyn(&dict, "nanosecond", |t| t.nanosecond() as i32).unwrap(); + let b = time_fraction_dyn(&dict, "nanosecond", |t| t.nanosecond() as i32).unwrap(); let expected_dict = DictionaryArray::new(keys, Arc::new(Int32Array::from(vec![0, 0, 0, 0, 0]))); @@ -998,8 +962,7 @@ mod tests { #[test] fn test_year_dictionary_array() { - let a: PrimitiveArray = - vec![Some(1514764800000), Some(1550636625000)].into(); + let a: PrimitiveArray = vec![Some(1514764800000), Some(1550636625000)].into(); let keys = Int8Array::from_iter_values([0_i8, 1, 1, 0]); let dict = DictionaryArray::new(keys.clone(), Arc::new(a)); @@ -1018,24 +981,20 @@ mod tests { fn test_quarter_month_dictionary_array() { //1514764800000 -> 2018-01-01 //1566275025000 -> 2019-08-20 - let a: PrimitiveArray = - vec![Some(1514764800000), Some(1566275025000)].into(); + let a: PrimitiveArray = vec![Some(1514764800000), Some(1566275025000)].into(); let keys = Int8Array::from_iter_values([0_i8, 1, 1, 0]); let dict = DictionaryArray::new(keys.clone(), Arc::new(a)); let b = quarter_dyn(&dict).unwrap(); - 
let expected = DictionaryArray::new( - keys.clone(), - Arc::new(Int32Array::from(vec![1, 3, 3, 1])), - ); + let expected = + DictionaryArray::new(keys.clone(), Arc::new(Int32Array::from(vec![1, 3, 3, 1]))); assert_eq!(b.as_ref(), &expected); let b = month_dyn(&dict).unwrap(); - let expected = - DictionaryArray::new(keys, Arc::new(Int32Array::from(vec![1, 8, 8, 1]))); + let expected = DictionaryArray::new(keys, Arc::new(Int32Array::from(vec![1, 8, 8, 1]))); assert_eq!(b.as_ref(), &expected); } @@ -1043,8 +1002,7 @@ mod tests { fn test_num_days_from_monday_sunday_day_doy_week_dictionary_array() { //1514764800000 -> 2018-01-01 (Monday) //1550636625000 -> 2019-02-20 (Wednesday) - let a: PrimitiveArray = - vec![Some(1514764800000), Some(1550636625000)].into(); + let a: PrimitiveArray = vec![Some(1514764800000), Some(1550636625000)].into(); let keys = Int8Array::from(vec![Some(0_i8), Some(1), Some(1), Some(0), None]); let dict = DictionaryArray::new(keys.clone(), Arc::new(a)); diff --git a/arrow-array/src/arithmetic.rs b/arrow-array/src/arithmetic.rs index b0ecef70ee19..c9be39d44144 100644 --- a/arrow-array/src/arithmetic.rs +++ b/arrow-array/src/arithmetic.rs @@ -229,10 +229,7 @@ macro_rules! native_type_op { #[inline] fn pow_checked(self, exp: u32) -> Result { self.checked_pow(exp).ok_or_else(|| { - ArrowError::ComputeError(format!( - "Overflow happened on: {:?} ^ {exp:?}", - self - )) + ArrowError::ComputeError(format!("Overflow happened on: {:?} ^ {exp:?}", self)) }) } diff --git a/arrow-array/src/array/binary_array.rs b/arrow-array/src/array/binary_array.rs index 75880bec30ce..6b18cbc2d9f7 100644 --- a/arrow-array/src/array/binary_array.rs +++ b/arrow-array/src/array/binary_array.rs @@ -16,9 +16,7 @@ // under the License. 
use crate::types::{ByteArrayType, GenericBinaryType}; -use crate::{ - Array, GenericByteArray, GenericListArray, GenericStringArray, OffsetSizeTrait, -}; +use crate::{Array, GenericByteArray, GenericListArray, GenericStringArray, OffsetSizeTrait}; use arrow_data::ArrayData; use arrow_schema::DataType; @@ -102,9 +100,7 @@ impl GenericBinaryArray { } } -impl From>> - for GenericBinaryArray -{ +impl From>> for GenericBinaryArray { fn from(v: Vec>) -> Self { Self::from_opt_vec(v) } @@ -376,9 +372,11 @@ mod tests { .unwrap(); let binary_array1 = GenericBinaryArray::::from(array_data1); - let data_type = GenericListArray::::DATA_TYPE_CONSTRUCTOR(Arc::new( - Field::new("item", DataType::UInt8, false), - )); + let data_type = GenericListArray::::DATA_TYPE_CONSTRUCTOR(Arc::new(Field::new( + "item", + DataType::UInt8, + false, + ))); let array_data2 = ArrayData::builder(data_type) .len(3) @@ -423,9 +421,11 @@ mod tests { let offsets = [0, 5, 8, 15].map(|n| O::from_usize(n).unwrap()); let null_buffer = Buffer::from_slice_ref([0b101]); - let data_type = GenericListArray::::DATA_TYPE_CONSTRUCTOR(Arc::new( - Field::new("item", DataType::UInt8, false), - )); + let data_type = GenericListArray::::DATA_TYPE_CONSTRUCTOR(Arc::new(Field::new( + "item", + DataType::UInt8, + false, + ))); // [None, Some(b"Parquet")] let array_data = ArrayData::builder(data_type) @@ -456,9 +456,7 @@ mod tests { _test_generic_binary_array_from_list_array_with_offset::(); } - fn _test_generic_binary_array_from_list_array_with_child_nulls_failed< - O: OffsetSizeTrait, - >() { + fn _test_generic_binary_array_from_list_array_with_child_nulls_failed() { let values = b"HelloArrow"; let child_data = ArrayData::builder(DataType::UInt8) .len(10) @@ -468,9 +466,11 @@ mod tests { .unwrap(); let offsets = [0, 5, 10].map(|n| O::from_usize(n).unwrap()); - let data_type = GenericListArray::::DATA_TYPE_CONSTRUCTOR(Arc::new( - Field::new("item", DataType::UInt8, true), - )); + let data_type = 
GenericListArray::::DATA_TYPE_CONSTRUCTOR(Arc::new(Field::new( + "item", + DataType::UInt8, + true, + ))); // [None, Some(b"Parquet")] let array_data = ArrayData::builder(data_type) @@ -558,8 +558,7 @@ mod tests { .unwrap(); let offsets: [i32; 4] = [0, 5, 5, 12]; - let data_type = - DataType::List(Arc::new(Field::new("item", DataType::UInt32, false))); + let data_type = DataType::List(Arc::new(Field::new("item", DataType::UInt32, false))); let array_data = ArrayData::builder(data_type) .len(3) .add_buffer(Buffer::from_slice_ref(offsets)) @@ -575,8 +574,7 @@ mod tests { expected = "Trying to access an element at index 4 from a BinaryArray of length 3" )] fn test_binary_array_get_value_index_out_of_bound() { - let values: [u8; 12] = - [104, 101, 108, 108, 111, 112, 97, 114, 113, 117, 101, 116]; + let values: [u8; 12] = [104, 101, 108, 108, 111, 112, 97, 114, 113, 117, 101, 116]; let offsets: [i32; 4] = [0, 5, 5, 12]; let array_data = ArrayData::builder(DataType::Binary) .len(3) diff --git a/arrow-array/src/array/boolean_array.rs b/arrow-array/src/array/boolean_array.rs index 4d19babe3e4b..a778dc92ea35 100644 --- a/arrow-array/src/array/boolean_array.rs +++ b/arrow-array/src/array/boolean_array.rs @@ -238,11 +238,7 @@ impl BooleanArray { /// /// This function panics if left and right are not the same length /// - pub fn from_binary( - left: T, - right: S, - mut op: F, - ) -> Self + pub fn from_binary(left: T, right: S, mut op: F) -> Self where F: FnMut(T::Item, S::Item) -> bool, { @@ -362,8 +358,7 @@ impl From for BooleanArray { 1, "BooleanArray data should contain a single buffer only (values buffer)" ); - let values = - BooleanBuffer::new(data.buffers()[0].clone(), data.offset(), data.len()); + let values = BooleanBuffer::new(data.buffers()[0].clone(), data.offset(), data.len()); Self { values, @@ -591,9 +586,7 @@ mod tests { } #[test] - #[should_panic( - expected = "BooleanArray expected ArrayData with type Boolean got Int32" - )] + #[should_panic(expected = 
"BooleanArray expected ArrayData with type Boolean got Int32")] fn test_from_array_data_validation() { let _ = BooleanArray::from(ArrayData::new_empty(&DataType::Int32)); } diff --git a/arrow-array/src/array/byte_array.rs b/arrow-array/src/array/byte_array.rs index 37d8de931e99..db825bbea97d 100644 --- a/arrow-array/src/array/byte_array.rs +++ b/arrow-array/src/array/byte_array.rs @@ -197,8 +197,7 @@ impl GenericByteArray { let (_, data_len) = iter.size_hint(); let data_len = data_len.expect("Iterator must be sized"); // panic if no upper bound. - let mut offsets = - MutableBuffer::new((data_len + 1) * std::mem::size_of::()); + let mut offsets = MutableBuffer::new((data_len + 1) * std::mem::size_of::()); offsets.push(T::Offset::usize_as(0)); let mut values = MutableBuffer::new(0); @@ -335,8 +334,7 @@ impl GenericByteArray { /// offset and data buffers are not shared by others. pub fn into_builder(self) -> Result, Self> { let len = self.len(); - let value_len = - T::Offset::as_usize(self.value_offsets()[len] - self.value_offsets()[0]); + let value_len = T::Offset::as_usize(self.value_offsets()[len] - self.value_offsets()[0]); let data = self.into_data(); let null_bit_buffer = data.nulls().map(|b| b.inner().sliced()); @@ -578,17 +576,14 @@ mod tests { let nulls = NullBuffer::new_null(3); let err = - StringArray::try_new(offsets.clone(), data.clone(), Some(nulls.clone())) - .unwrap_err(); + StringArray::try_new(offsets.clone(), data.clone(), Some(nulls.clone())).unwrap_err(); assert_eq!(err.to_string(), "Invalid argument error: Incorrect length of null buffer for StringArray, expected 2 got 3"); - let err = - BinaryArray::try_new(offsets.clone(), data.clone(), Some(nulls)).unwrap_err(); + let err = BinaryArray::try_new(offsets.clone(), data.clone(), Some(nulls)).unwrap_err(); assert_eq!(err.to_string(), "Invalid argument error: Incorrect length of null buffer for BinaryArray, expected 2 got 3"); let non_utf8_data = Buffer::from_slice_ref(b"he\xFFloworld"); - let err = 
StringArray::try_new(offsets.clone(), non_utf8_data.clone(), None) - .unwrap_err(); + let err = StringArray::try_new(offsets.clone(), non_utf8_data.clone(), None).unwrap_err(); assert_eq!(err.to_string(), "Invalid argument error: Encountered non UTF-8 data: invalid utf-8 sequence of 1 bytes from index 2"); BinaryArray::new(offsets, non_utf8_data, None); @@ -611,8 +606,7 @@ mod tests { BinaryArray::new(offsets, non_ascii_data.clone(), None); let offsets = OffsetBuffer::new(vec![0, 3, 10].into()); - let err = StringArray::try_new(offsets.clone(), non_ascii_data.clone(), None) - .unwrap_err(); + let err = StringArray::try_new(offsets.clone(), non_ascii_data.clone(), None).unwrap_err(); assert_eq!( err.to_string(), "Invalid argument error: Split UTF-8 codepoint at offset 3" diff --git a/arrow-array/src/array/dictionary_array.rs b/arrow-array/src/array/dictionary_array.rs index 0cb00878929c..1f4d83b1c5d0 100644 --- a/arrow-array/src/array/dictionary_array.rs +++ b/arrow-array/src/array/dictionary_array.rs @@ -286,10 +286,7 @@ impl DictionaryArray { /// # Errors /// /// Returns an error if any `keys[i] >= values.len() || keys[i] < 0` - pub fn try_new( - keys: PrimitiveArray, - values: ArrayRef, - ) -> Result { + pub fn try_new(keys: PrimitiveArray, values: ArrayRef) -> Result { let data_type = DataType::Dictionary( Box::new(keys.data_type().clone()), Box::new(values.data_type().clone()), @@ -298,9 +295,11 @@ impl DictionaryArray { let zero = K::Native::usize_as(0); let values_len = values.len(); - if let Some((idx, v)) = keys.values().iter().enumerate().find(|(idx, v)| { - (v.is_lt(zero) || v.as_usize() >= values_len) && keys.is_valid(*idx) - }) { + if let Some((idx, v)) = + keys.values().iter().enumerate().find(|(idx, v)| { + (v.is_lt(zero) || v.as_usize() >= values_len) && keys.is_valid(*idx) + }) + { return Err(ArrowError::InvalidArgumentError(format!( "Invalid dictionary key {v:?} at index {idx}, expected 0 <= key < {values_len}", ))); @@ -349,8 +348,7 @@ impl 
DictionaryArray { /// /// Panics if `values` is not a [`StringArray`]. pub fn lookup_key(&self, value: &str) -> Option { - let rd_buf: &StringArray = - self.values.as_any().downcast_ref::().unwrap(); + let rd_buf: &StringArray = self.values.as_any().downcast_ref::().unwrap(); (0..rd_buf.len()) .position(|i| rd_buf.value(i) == value) @@ -463,10 +461,8 @@ impl DictionaryArray { /// pub fn with_values(&self, values: ArrayRef) -> Self { assert!(values.len() >= self.values.len()); - let data_type = DataType::Dictionary( - Box::new(K::DATA_TYPE), - Box::new(values.data_type().clone()), - ); + let data_type = + DataType::Dictionary(Box::new(K::DATA_TYPE), Box::new(values.data_type().clone())); Self { data_type, keys: self.keys.clone(), @@ -477,9 +473,7 @@ impl DictionaryArray { /// Returns `PrimitiveDictionaryBuilder` of this dictionary array for mutating /// its keys and values if the underlying data buffer is not shared by others. - pub fn into_primitive_dict_builder( - self, - ) -> Result, Self> + pub fn into_primitive_dict_builder(self) -> Result, Self> where V: ArrowPrimitiveType, { @@ -540,8 +534,7 @@ impl DictionaryArray { V: ArrowPrimitiveType, F: Fn(V::Native) -> V::Native, { - let mut builder: PrimitiveDictionaryBuilder = - self.into_primitive_dict_builder()?; + let mut builder: PrimitiveDictionaryBuilder = self.into_primitive_dict_builder()?; builder .values_slice_mut() .iter_mut() @@ -806,9 +799,7 @@ impl<'a, K: ArrowDictionaryKeyType, V> Clone for TypedDictionaryArray<'a, K, V> impl<'a, K: ArrowDictionaryKeyType, V> Copy for TypedDictionaryArray<'a, K, V> {} -impl<'a, K: ArrowDictionaryKeyType, V> std::fmt::Debug - for TypedDictionaryArray<'a, K, V> -{ +impl<'a, K: ArrowDictionaryKeyType, V> std::fmt::Debug for TypedDictionaryArray<'a, K, V> { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { writeln!(f, "TypedDictionaryArray({:?})", self.dictionary) } @@ -1040,8 +1031,7 @@ mod tests { // Construct a dictionary array from the above two let 
key_type = DataType::Int16; let value_type = DataType::Int8; - let dict_data_type = - DataType::Dictionary(Box::new(key_type), Box::new(value_type)); + let dict_data_type = DataType::Dictionary(Box::new(key_type), Box::new(value_type)); let dict_data = ArrayData::builder(dict_data_type.clone()) .len(3) .add_buffer(keys.clone()) @@ -1079,8 +1069,7 @@ mod tests { #[test] fn test_dictionary_array_fmt_debug() { - let mut builder = - PrimitiveDictionaryBuilder::::with_capacity(3, 2); + let mut builder = PrimitiveDictionaryBuilder::::with_capacity(3, 2); builder.append(12345678).unwrap(); builder.append_null(); builder.append(22345678).unwrap(); @@ -1090,8 +1079,7 @@ mod tests { format!("{array:?}") ); - let mut builder = - PrimitiveDictionaryBuilder::::with_capacity(20, 2); + let mut builder = PrimitiveDictionaryBuilder::::with_capacity(20, 2); for _ in 0..20 { builder.append(1).unwrap(); } @@ -1267,9 +1255,7 @@ mod tests { } #[test] - #[should_panic( - expected = "Invalid dictionary key 3 at index 1, expected 0 <= key < 2" - )] + #[should_panic(expected = "Invalid dictionary key 3 at index 1, expected 0 <= key < 2")] fn test_try_new_index_too_large() { let values: StringArray = [Some("foo"), Some("bar")].into_iter().collect(); // dictionary only has 2 values, so offset 3 is out of bounds @@ -1278,9 +1264,7 @@ mod tests { } #[test] - #[should_panic( - expected = "Invalid dictionary key -100 at index 0, expected 0 <= key < 2" - )] + #[should_panic(expected = "Invalid dictionary key -100 at index 0, expected 0 <= key < 2")] fn test_try_new_index_too_small() { let values: StringArray = [Some("foo"), Some("bar")].into_iter().collect(); let keys: Int32Array = [Some(-100)].into_iter().collect(); @@ -1288,9 +1272,7 @@ mod tests { } #[test] - #[should_panic( - expected = "DictionaryArray's data type must match, expected Int64 got Int32" - )] + #[should_panic(expected = "DictionaryArray's data type must match, expected Int64 got Int32")] fn test_from_array_data_validation() { 
let a = DictionaryArray::::from_iter(["32"]); let _ = DictionaryArray::::from(a.into_data()); @@ -1335,8 +1317,7 @@ mod tests { let boxed: ArrayRef = Arc::new(dict_array); - let col: DictionaryArray = - DictionaryArray::::from(boxed.to_data()); + let col: DictionaryArray = DictionaryArray::::from(boxed.to_data()); let err = col.into_primitive_dict_builder::(); let returned = err.unwrap_err(); diff --git a/arrow-array/src/array/fixed_size_binary_array.rs b/arrow-array/src/array/fixed_size_binary_array.rs index f0b04c203ceb..d89bbd5ad084 100644 --- a/arrow-array/src/array/fixed_size_binary_array.rs +++ b/arrow-array/src/array/fixed_size_binary_array.rs @@ -81,10 +81,7 @@ impl FixedSizeBinaryArray { ) -> Result { let data_type = DataType::FixedSizeBinary(size); let s = size.to_usize().ok_or_else(|| { - ArrowError::InvalidArgumentError(format!( - "Size cannot be negative, got {}", - size - )) + ArrowError::InvalidArgumentError(format!("Size cannot be negative, got {}", size)) })?; let len = values.len() / s; @@ -333,10 +330,7 @@ impl FixedSizeBinaryArray { /// # Errors /// /// Returns error if argument has length zero, or sizes of nested slices don't match. 
- pub fn try_from_sparse_iter_with_size( - mut iter: T, - size: i32, - ) -> Result + pub fn try_from_sparse_iter_with_size(mut iter: T, size: i32) -> Result where T: Iterator>, U: AsRef<[u8]>, @@ -812,8 +806,7 @@ mod tests { let none_option: Option<[u8; 32]> = None; let input_arg = vec![none_option, none_option, none_option]; #[allow(deprecated)] - let arr = - FixedSizeBinaryArray::try_from_sparse_iter(input_arg.into_iter()).unwrap(); + let arr = FixedSizeBinaryArray::try_from_sparse_iter(input_arg.into_iter()).unwrap(); assert_eq!(0, arr.value_length()); assert_eq!(3, arr.len()) } @@ -828,16 +821,12 @@ mod tests { Some(vec![13, 14]), ]; #[allow(deprecated)] - let arr = FixedSizeBinaryArray::try_from_sparse_iter(input_arg.iter().cloned()) - .unwrap(); + let arr = FixedSizeBinaryArray::try_from_sparse_iter(input_arg.iter().cloned()).unwrap(); assert_eq!(2, arr.value_length()); assert_eq!(5, arr.len()); - let arr = FixedSizeBinaryArray::try_from_sparse_iter_with_size( - input_arg.into_iter(), - 2, - ) - .unwrap(); + let arr = + FixedSizeBinaryArray::try_from_sparse_iter_with_size(input_arg.into_iter(), 2).unwrap(); assert_eq!(2, arr.value_length()); assert_eq!(5, arr.len()); } @@ -846,11 +835,8 @@ mod tests { fn test_fixed_size_binary_array_from_sparse_iter_with_size_all_none() { let input_arg = vec![None, None, None, None, None] as Vec>>; - let arr = FixedSizeBinaryArray::try_from_sparse_iter_with_size( - input_arg.into_iter(), - 16, - ) - .unwrap(); + let arr = FixedSizeBinaryArray::try_from_sparse_iter_with_size(input_arg.into_iter(), 16) + .unwrap(); assert_eq!(16, arr.value_length()); assert_eq!(5, arr.len()) } @@ -917,8 +903,7 @@ mod tests { fn fixed_size_binary_array_all_null() { let data = vec![None] as Vec>; let array = - FixedSizeBinaryArray::try_from_sparse_iter_with_size(data.into_iter(), 0) - .unwrap(); + FixedSizeBinaryArray::try_from_sparse_iter_with_size(data.into_iter(), 0).unwrap(); array .into_data() .validate_full() @@ -928,8 +913,7 @@ mod tests { 
#[test] // Test for https://github.com/apache/arrow-rs/issues/1390 fn fixed_size_binary_array_all_null_in_batch_with_schema() { - let schema = - Schema::new(vec![Field::new("a", DataType::FixedSizeBinary(2), true)]); + let schema = Schema::new(vec![Field::new("a", DataType::FixedSizeBinary(2), true)]); let none_option: Option<[u8; 2]> = None; let item = FixedSizeBinaryArray::try_from_sparse_iter_with_size( diff --git a/arrow-array/src/array/fixed_size_list_array.rs b/arrow-array/src/array/fixed_size_list_array.rs index db3ccbe0617b..f8f01516e3d4 100644 --- a/arrow-array/src/array/fixed_size_list_array.rs +++ b/arrow-array/src/array/fixed_size_list_array.rs @@ -130,12 +130,7 @@ impl FixedSizeListArray { /// # Panics /// /// Panics if [`Self::try_new`] returns an error - pub fn new( - field: FieldRef, - size: i32, - values: ArrayRef, - nulls: Option, - ) -> Self { + pub fn new(field: FieldRef, size: i32, values: ArrayRef, nulls: Option) -> Self { Self::try_new(field, size, values, nulls).unwrap() } @@ -154,10 +149,7 @@ impl FixedSizeListArray { nulls: Option, ) -> Result { let s = size.to_usize().ok_or_else(|| { - ArrowError::InvalidArgumentError(format!( - "Size cannot be negative, got {}", - size - )) + ArrowError::InvalidArgumentError(format!("Size cannot be negative, got {}", size)) })?; let len = values.len() / s.max(1); @@ -350,9 +342,8 @@ impl From for FixedSizeListArray { }; let size = value_length as usize; - let values = make_array( - data.child_data()[0].slice(data.offset() * size, data.len() * size), - ); + let values = + make_array(data.child_data()[0].slice(data.offset() * size, data.len() * size)); Self { data_type: data.data_type().clone(), values, @@ -483,10 +474,8 @@ mod tests { .unwrap(); // Construct a list array from the above two - let list_data_type = DataType::FixedSizeList( - Arc::new(Field::new("item", DataType::Int32, false)), - 3, - ); + let list_data_type = + DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Int32, false)), 
3); let list_data = ArrayData::builder(list_data_type.clone()) .len(3) .add_child_data(value_data.clone()) @@ -538,10 +527,8 @@ mod tests { .unwrap(); // Construct a list array from the above two - let list_data_type = DataType::FixedSizeList( - Arc::new(Field::new("item", DataType::Int32, false)), - 3, - ); + let list_data_type = + DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Int32, false)), 3); let list_data = unsafe { ArrayData::builder(list_data_type) .len(3) @@ -569,10 +556,8 @@ mod tests { bit_util::set_bit(&mut null_bits, 4); // Construct a fixed size list array from the above two - let list_data_type = DataType::FixedSizeList( - Arc::new(Field::new("item", DataType::Int32, false)), - 2, - ); + let list_data_type = + DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Int32, false)), 2); let list_data = ArrayData::builder(list_data_type) .len(5) .add_child_data(value_data.clone()) @@ -611,9 +596,7 @@ mod tests { } #[test] - #[should_panic( - expected = "the offset of the new Buffer cannot exceed the existing length" - )] + #[should_panic(expected = "the offset of the new Buffer cannot exceed the existing length")] fn test_fixed_size_list_array_index_out_of_bound() { // Construct a value array let value_data = ArrayData::builder(DataType::Int32) @@ -631,10 +614,8 @@ mod tests { bit_util::set_bit(&mut null_bits, 4); // Construct a fixed size list array from the above two - let list_data_type = DataType::FixedSizeList( - Arc::new(Field::new("item", DataType::Int32, false)), - 2, - ); + let list_data_type = + DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Int32, false)), 2); let list_data = ArrayData::builder(list_data_type) .len(5) .add_child_data(value_data) @@ -668,8 +649,7 @@ mod tests { let list = FixedSizeListArray::new(field.clone(), 4, values.clone(), None); assert_eq!(list.len(), 1); - let err = FixedSizeListArray::try_new(field.clone(), -1, values.clone(), None) - .unwrap_err(); + let err = 
FixedSizeListArray::try_new(field.clone(), -1, values.clone(), None).unwrap_err(); assert_eq!( err.to_string(), "Invalid argument error: Size cannot be negative, got -1" @@ -679,13 +659,11 @@ mod tests { assert_eq!(list.len(), 6); let nulls = NullBuffer::new_null(2); - let err = FixedSizeListArray::try_new(field, 2, values.clone(), Some(nulls)) - .unwrap_err(); + let err = FixedSizeListArray::try_new(field, 2, values.clone(), Some(nulls)).unwrap_err(); assert_eq!(err.to_string(), "Invalid argument error: Incorrect length of null buffer for FixedSizeListArray, expected 3 got 2"); let field = Arc::new(Field::new("item", DataType::Int32, false)); - let err = FixedSizeListArray::try_new(field.clone(), 2, values.clone(), None) - .unwrap_err(); + let err = FixedSizeListArray::try_new(field.clone(), 2, values.clone(), None).unwrap_err(); assert_eq!(err.to_string(), "Invalid argument error: Found unmasked nulls for non-nullable FixedSizeListArray field \"item\""); // Valid as nulls in child masked by parent diff --git a/arrow-array/src/array/list_array.rs b/arrow-array/src/array/list_array.rs index e36d0ac4434f..9758c112a1ef 100644 --- a/arrow-array/src/array/list_array.rs +++ b/arrow-array/src/array/list_array.rs @@ -372,9 +372,8 @@ impl GenericListArray { impl From for GenericListArray { fn from(data: ArrayData) -> Self { - Self::try_new_from_array_data(data).expect( - "Expected infallible creation of GenericListArray from ArrayDataRef failed", - ) + Self::try_new_from_array_data(data) + .expect("Expected infallible creation of GenericListArray from ArrayDataRef failed") } } @@ -391,17 +390,14 @@ impl From> for ArrayDa } } -impl From - for GenericListArray -{ +impl From for GenericListArray { fn from(value: FixedSizeListArray) -> Self { let (field, size) = match value.data_type() { DataType::FixedSizeList(f, size) => (f, *size as usize), _ => unreachable!(), }; - let offsets = - OffsetBuffer::from_lengths(std::iter::repeat(size).take(value.len())); + let offsets = 
OffsetBuffer::from_lengths(std::iter::repeat(size).take(value.len())); Self { data_type: Self::DATA_TYPE_CONSTRUCTOR(field.clone()), @@ -415,9 +411,10 @@ impl From impl GenericListArray { fn try_new_from_array_data(data: ArrayData) -> Result { if data.buffers().len() != 1 { - return Err(ArrowError::InvalidArgumentError( - format!("ListArray data should contain a single buffer only (value offsets), had {}", - data.buffers().len()))); + return Err(ArrowError::InvalidArgumentError(format!( + "ListArray data should contain a single buffer only (value offsets), had {}", + data.buffers().len() + ))); } if data.child_data().len() != 1 { @@ -593,8 +590,7 @@ mod tests { let value_offsets = Buffer::from([]); // Construct a list array from the above two - let list_data_type = - DataType::List(Arc::new(Field::new("item", DataType::Int32, false))); + let list_data_type = DataType::List(Arc::new(Field::new("item", DataType::Int32, false))); let list_data = ArrayData::builder(list_data_type) .len(0) .add_buffer(value_offsets) @@ -620,8 +616,7 @@ mod tests { let value_offsets = Buffer::from_slice_ref([0, 3, 6, 8]); // Construct a list array from the above two - let list_data_type = - DataType::List(Arc::new(Field::new("item", DataType::Int32, false))); + let list_data_type = DataType::List(Arc::new(Field::new("item", DataType::Int32, false))); let list_data = ArrayData::builder(list_data_type.clone()) .len(3) .add_buffer(value_offsets.clone()) @@ -807,8 +802,7 @@ mod tests { bit_util::set_bit(&mut null_bits, 8); // Construct a list array from the above two - let list_data_type = - DataType::List(Arc::new(Field::new("item", DataType::Int32, false))); + let list_data_type = DataType::List(Arc::new(Field::new("item", DataType::Int32, false))); let list_data = ArrayData::builder(list_data_type) .len(9) .add_buffer(value_offsets) @@ -839,8 +833,7 @@ mod tests { } // Check offset and length for each non-null value. 
- let sliced_list_array = - sliced_array.as_any().downcast_ref::().unwrap(); + let sliced_list_array = sliced_array.as_any().downcast_ref::().unwrap(); assert_eq!(2, sliced_list_array.value_offsets()[2]); assert_eq!(2, sliced_list_array.value_length(2)); assert_eq!(4, sliced_list_array.value_offsets()[3]); @@ -951,9 +944,7 @@ mod tests { list_array.value(10); } #[test] - #[should_panic( - expected = "ListArray data should contain a single buffer only (value offsets)" - )] + #[should_panic(expected = "ListArray data should contain a single buffer only (value offsets)")] // Different error messages, so skip for now // https://github.com/apache/arrow-rs/issues/1545 #[cfg(not(feature = "force_validate"))] @@ -964,8 +955,7 @@ mod tests { .add_buffer(Buffer::from_slice_ref([0, 1, 2, 3, 4, 5, 6, 7])) .build_unchecked() }; - let list_data_type = - DataType::List(Arc::new(Field::new("item", DataType::Int32, false))); + let list_data_type = DataType::List(Arc::new(Field::new("item", DataType::Int32, false))); let list_data = unsafe { ArrayData::builder(list_data_type) .len(3) @@ -976,16 +966,13 @@ mod tests { } #[test] - #[should_panic( - expected = "ListArray should contain a single child array (values array)" - )] + #[should_panic(expected = "ListArray should contain a single child array (values array)")] // Different error messages, so skip for now // https://github.com/apache/arrow-rs/issues/1545 #[cfg(not(feature = "force_validate"))] fn test_list_array_invalid_child_array_len() { let value_offsets = Buffer::from_slice_ref([0, 2, 5, 7]); - let list_data_type = - DataType::List(Arc::new(Field::new("item", DataType::Int32, false))); + let list_data_type = DataType::List(Arc::new(Field::new("item", DataType::Int32, false))); let list_data = unsafe { ArrayData::builder(list_data_type) .len(3) @@ -996,9 +983,7 @@ mod tests { } #[test] - #[should_panic( - expected = "[Large]ListArray's datatype must be [Large]ListArray(). 
It is List" - )] + #[should_panic(expected = "[Large]ListArray's datatype must be [Large]ListArray(). It is List")] fn test_from_array_data_validation() { let mut builder = ListBuilder::new(Int32Builder::new()); builder.values().append_value(1); @@ -1017,8 +1002,7 @@ mod tests { let value_offsets = Buffer::from_slice_ref([2, 2, 5, 7]); - let list_data_type = - DataType::List(Arc::new(Field::new("item", DataType::Int32, false))); + let list_data_type = DataType::List(Arc::new(Field::new("item", DataType::Int32, false))); let list_data = ArrayData::builder(list_data_type) .len(3) .add_buffer(value_offsets) @@ -1033,9 +1017,7 @@ mod tests { } #[test] - #[should_panic( - expected = "Memory pointer is not aligned with the specified scalar type" - )] + #[should_panic(expected = "Memory pointer is not aligned with the specified scalar type")] // Different error messages, so skip for now // https://github.com/apache/arrow-rs/issues/1545 #[cfg(not(feature = "force_validate"))] @@ -1051,9 +1033,7 @@ mod tests { } #[test] - #[should_panic( - expected = "Memory pointer is not aligned with the specified scalar type" - )] + #[should_panic(expected = "Memory pointer is not aligned with the specified scalar type")] // Different error messages, so skip for now // https://github.com/apache/arrow-rs/issues/1545 #[cfg(not(feature = "force_validate"))] @@ -1068,8 +1048,7 @@ mod tests { .build_unchecked() }; - let list_data_type = - DataType::List(Arc::new(Field::new("item", DataType::Int32, false))); + let list_data_type = DataType::List(Arc::new(Field::new("item", DataType::Int32, false))); let list_data = unsafe { ArrayData::builder(list_data_type) .add_buffer(buf2) @@ -1187,9 +1166,8 @@ mod tests { let nulls = NullBuffer::new_null(3); let offsets = OffsetBuffer::new(vec![0, 1, 2, 4, 5].into()); - let err = - LargeListArray::try_new(field, offsets.clone(), values.clone(), Some(nulls)) - .unwrap_err(); + let err = LargeListArray::try_new(field, offsets.clone(), values.clone(), 
Some(nulls)) + .unwrap_err(); assert_eq!( err.to_string(), @@ -1197,9 +1175,8 @@ mod tests { ); let field = Arc::new(Field::new("element", DataType::Int64, false)); - let err = - LargeListArray::try_new(field.clone(), offsets.clone(), values.clone(), None) - .unwrap_err(); + let err = LargeListArray::try_new(field.clone(), offsets.clone(), values.clone(), None) + .unwrap_err(); assert_eq!( err.to_string(), @@ -1210,8 +1187,8 @@ mod tests { let values = Int64Array::new(vec![0; 7].into(), Some(nulls)); let values = Arc::new(values); - let err = LargeListArray::try_new(field, offsets.clone(), values.clone(), None) - .unwrap_err(); + let err = + LargeListArray::try_new(field, offsets.clone(), values.clone(), None).unwrap_err(); assert_eq!( err.to_string(), @@ -1222,8 +1199,7 @@ mod tests { LargeListArray::new(field.clone(), offsets.clone(), values, None); let values = Int64Array::new(vec![0; 2].into(), None); - let err = - LargeListArray::try_new(field, offsets, Arc::new(values), None).unwrap_err(); + let err = LargeListArray::try_new(field, offsets, Arc::new(values), None).unwrap_err(); assert_eq!( err.to_string(), diff --git a/arrow-array/src/array/map_array.rs b/arrow-array/src/array/map_array.rs index 77a7b9d4d547..bde7fdd5a953 100644 --- a/arrow-array/src/array/map_array.rs +++ b/arrow-array/src/array/map_array.rs @@ -17,9 +17,7 @@ use crate::array::{get_offsets, print_long_array}; use crate::iterator::MapArrayIter; -use crate::{ - make_array, Array, ArrayAccessor, ArrayRef, ListArray, StringArray, StructArray, -}; +use crate::{make_array, Array, ArrayAccessor, ArrayRef, ListArray, StringArray, StructArray}; use arrow_buffer::{ArrowNativeType, Buffer, NullBuffer, OffsetBuffer, ToByteSlice}; use arrow_data::{ArrayData, ArrayDataBuilder}; use arrow_schema::{ArrowError, DataType, Field, FieldRef}; @@ -264,9 +262,10 @@ impl MapArray { } if data.buffers().len() != 1 { - return Err(ArrowError::InvalidArgumentError( - format!("MapArray data should contain a single buffer 
only (value offsets), had {}", - data.len()))); + return Err(ArrowError::InvalidArgumentError(format!( + "MapArray data should contain a single buffer only (value offsets), had {}", + data.len() + ))); } if data.child_data().len() != 1 { @@ -281,9 +280,9 @@ impl MapArray { if let DataType::Struct(fields) = entries.data_type() { if fields.len() != 2 { return Err(ArrowError::InvalidArgumentError(format!( - "MapArray should contain a struct array with 2 fields, have {} fields", - fields.len() - ))); + "MapArray should contain a struct array with 2 fields, have {} fields", + fields.len() + ))); } } else { return Err(ArrowError::InvalidArgumentError(format!( @@ -576,8 +575,7 @@ mod tests { assert_eq!(2, map_array.value_length(1)); let key_array = Arc::new(Int32Array::from(vec![3, 4, 5])) as ArrayRef; - let value_array = - Arc::new(UInt32Array::from(vec![None, Some(40), None])) as ArrayRef; + let value_array = Arc::new(UInt32Array::from(vec![None, Some(40), None])) as ArrayRef; let struct_array = StructArray::from(vec![(keys_field, key_array), (values_field, value_array)]); assert_eq!( @@ -669,9 +667,7 @@ mod tests { } #[test] - #[should_panic( - expected = "MapArray expected ArrayData with DataType::Map got Dictionary" - )] + #[should_panic(expected = "MapArray expected ArrayData with DataType::Map got Dictionary")] fn test_from_array_data_validation() { // A DictionaryArray has similar buffer layout to a MapArray // but the meaning of the values differs @@ -692,12 +688,9 @@ mod tests { // [[a, b, c], [d, e, f], [g, h]] let entry_offsets = [0, 3, 6, 8]; - let map_array = MapArray::new_from_strings( - keys.clone().into_iter(), - &values_data, - &entry_offsets, - ) - .unwrap(); + let map_array = + MapArray::new_from_strings(keys.clone().into_iter(), &values_data, &entry_offsets) + .unwrap(); assert_eq!( &values_data, @@ -768,9 +761,8 @@ mod tests { "Invalid argument error: Incorrect length of null buffer for MapArray, expected 4 got 3" ); - let err = - 
MapArray::try_new(field, offsets.clone(), entries.slice(0, 2), None, false) - .unwrap_err(); + let err = MapArray::try_new(field, offsets.clone(), entries.slice(0, 2), None, false) + .unwrap_err(); assert_eq!( err.to_string(), @@ -783,9 +775,7 @@ mod tests { .to_string(); assert!( - err.starts_with( - "Invalid argument error: MapArray expected data type Int64 got Struct" - ), + err.starts_with("Invalid argument error: MapArray expected data type Int64 got Struct"), "{err}" ); diff --git a/arrow-array/src/array/mod.rs b/arrow-array/src/array/mod.rs index 9b66826f7584..f19406c1610b 100644 --- a/arrow-array/src/array/mod.rs +++ b/arrow-array/src/array/mod.rs @@ -536,9 +536,7 @@ pub fn make_array(data: ArrayData) -> ArrayRef { DataType::Float64 => Arc::new(Float64Array::from(data)) as ArrayRef, DataType::Date32 => Arc::new(Date32Array::from(data)) as ArrayRef, DataType::Date64 => Arc::new(Date64Array::from(data)) as ArrayRef, - DataType::Time32(TimeUnit::Second) => { - Arc::new(Time32SecondArray::from(data)) as ArrayRef - } + DataType::Time32(TimeUnit::Second) => Arc::new(Time32SecondArray::from(data)) as ArrayRef, DataType::Time32(TimeUnit::Millisecond) => { Arc::new(Time32MillisecondArray::from(data)) as ArrayRef } @@ -583,9 +581,7 @@ pub fn make_array(data: ArrayData) -> ArrayRef { } DataType::Binary => Arc::new(BinaryArray::from(data)) as ArrayRef, DataType::LargeBinary => Arc::new(LargeBinaryArray::from(data)) as ArrayRef, - DataType::FixedSizeBinary(_) => { - Arc::new(FixedSizeBinaryArray::from(data)) as ArrayRef - } + DataType::FixedSizeBinary(_) => Arc::new(FixedSizeBinaryArray::from(data)) as ArrayRef, DataType::Utf8 => Arc::new(StringArray::from(data)) as ArrayRef, DataType::LargeUtf8 => Arc::new(LargeStringArray::from(data)) as ArrayRef, DataType::List(_) => Arc::new(ListArray::from(data)) as ArrayRef, @@ -593,50 +589,24 @@ pub fn make_array(data: ArrayData) -> ArrayRef { DataType::Struct(_) => Arc::new(StructArray::from(data)) as ArrayRef, DataType::Map(_, 
_) => Arc::new(MapArray::from(data)) as ArrayRef, DataType::Union(_, _) => Arc::new(UnionArray::from(data)) as ArrayRef, - DataType::FixedSizeList(_, _) => { - Arc::new(FixedSizeListArray::from(data)) as ArrayRef - } + DataType::FixedSizeList(_, _) => Arc::new(FixedSizeListArray::from(data)) as ArrayRef, DataType::Dictionary(ref key_type, _) => match key_type.as_ref() { - DataType::Int8 => { - Arc::new(DictionaryArray::::from(data)) as ArrayRef - } - DataType::Int16 => { - Arc::new(DictionaryArray::::from(data)) as ArrayRef - } - DataType::Int32 => { - Arc::new(DictionaryArray::::from(data)) as ArrayRef - } - DataType::Int64 => { - Arc::new(DictionaryArray::::from(data)) as ArrayRef - } - DataType::UInt8 => { - Arc::new(DictionaryArray::::from(data)) as ArrayRef - } - DataType::UInt16 => { - Arc::new(DictionaryArray::::from(data)) as ArrayRef - } - DataType::UInt32 => { - Arc::new(DictionaryArray::::from(data)) as ArrayRef - } - DataType::UInt64 => { - Arc::new(DictionaryArray::::from(data)) as ArrayRef - } + DataType::Int8 => Arc::new(DictionaryArray::::from(data)) as ArrayRef, + DataType::Int16 => Arc::new(DictionaryArray::::from(data)) as ArrayRef, + DataType::Int32 => Arc::new(DictionaryArray::::from(data)) as ArrayRef, + DataType::Int64 => Arc::new(DictionaryArray::::from(data)) as ArrayRef, + DataType::UInt8 => Arc::new(DictionaryArray::::from(data)) as ArrayRef, + DataType::UInt16 => Arc::new(DictionaryArray::::from(data)) as ArrayRef, + DataType::UInt32 => Arc::new(DictionaryArray::::from(data)) as ArrayRef, + DataType::UInt64 => Arc::new(DictionaryArray::::from(data)) as ArrayRef, dt => panic!("Unexpected dictionary key type {dt:?}"), }, - DataType::RunEndEncoded(ref run_ends_type, _) => { - match run_ends_type.data_type() { - DataType::Int16 => { - Arc::new(RunArray::::from(data)) as ArrayRef - } - DataType::Int32 => { - Arc::new(RunArray::::from(data)) as ArrayRef - } - DataType::Int64 => { - Arc::new(RunArray::::from(data)) as ArrayRef - } - dt => 
panic!("Unexpected data type for run_ends array {dt:?}"), - } - } + DataType::RunEndEncoded(ref run_ends_type, _) => match run_ends_type.data_type() { + DataType::Int16 => Arc::new(RunArray::::from(data)) as ArrayRef, + DataType::Int32 => Arc::new(RunArray::::from(data)) as ArrayRef, + DataType::Int64 => Arc::new(RunArray::::from(data)) as ArrayRef, + dt => panic!("Unexpected data type for run_ends array {dt:?}"), + }, DataType::Null => Arc::new(NullArray::from(data)) as ArrayRef, DataType::Decimal128(_, _) => Arc::new(Decimal128Array::from(data)) as ArrayRef, DataType::Decimal256(_, _) => Arc::new(Decimal256Array::from(data)) as ArrayRef, @@ -687,11 +657,8 @@ unsafe fn get_offsets(data: &ArrayData) -> OffsetBuffer { match data.is_empty() && data.buffers()[0].is_empty() { true => OffsetBuffer::new_empty(), false => { - let buffer = ScalarBuffer::new( - data.buffers()[0].clone(), - data.offset(), - data.len() + 1, - ); + let buffer = + ScalarBuffer::new(data.buffers()[0].clone(), data.offset(), data.len() + 1); // Safety: // ArrayData is valid unsafe { OffsetBuffer::new_unchecked(buffer) } @@ -700,11 +667,7 @@ unsafe fn get_offsets(data: &ArrayData) -> OffsetBuffer { } /// Helper function for printing potentially long arrays. 
-fn print_long_array( - array: &A, - f: &mut std::fmt::Formatter, - print_item: F, -) -> std::fmt::Result +fn print_long_array(array: &A, f: &mut std::fmt::Formatter, print_item: F) -> std::fmt::Result where A: Array, F: Fn(&A, usize, &mut std::fmt::Formatter) -> std::fmt::Result, @@ -767,8 +730,7 @@ mod tests { #[test] fn test_empty_list_primitive() { - let data_type = - DataType::List(Arc::new(Field::new("item", DataType::Int32, false))); + let data_type = DataType::List(Arc::new(Field::new("item", DataType::Int32, false))); let array = new_empty_array(&data_type); let a = array.as_any().downcast_ref::().unwrap(); assert_eq!(a.len(), 0); @@ -799,8 +761,7 @@ mod tests { fn test_null_struct() { // It is possible to create a null struct containing a non-nullable child // see https://github.com/apache/arrow-rs/pull/3244 for details - let struct_type = - DataType::Struct(vec![Field::new("data", DataType::Int64, false)].into()); + let struct_type = DataType::Struct(vec![Field::new("data", DataType::Int64, false)].into()); let array = new_null_array(&struct_type, 9); let a = array.as_any().downcast_ref::().unwrap(); @@ -827,8 +788,7 @@ mod tests { #[test] fn test_null_list_primitive() { - let data_type = - DataType::List(Arc::new(Field::new("item", DataType::Int32, true))); + let data_type = DataType::List(Arc::new(Field::new("item", DataType::Int32, true))); let array = new_null_array(&data_type, 9); let a = array.as_any().downcast_ref::().unwrap(); assert_eq!(a.len(), 9); @@ -862,8 +822,8 @@ mod tests { #[test] fn test_null_dictionary() { - let values = vec![None, None, None, None, None, None, None, None, None] - as Vec>; + let values = + vec![None, None, None, None, None, None, None, None, None] as Vec>; let array: DictionaryArray = values.into_iter().collect(); let array = Arc::new(array) as ArrayRef; @@ -965,8 +925,7 @@ mod tests { #[test] fn test_memory_size_primitive() { let arr = PrimitiveArray::::from_iter_values(0..128); - let empty = - 
PrimitiveArray::::from(ArrayData::new_empty(arr.data_type())); + let empty = PrimitiveArray::::from(ArrayData::new_empty(arr.data_type())); // subtract empty array to avoid magic numbers for the size of additional fields assert_eq!( diff --git a/arrow-array/src/array/primitive_array.rs b/arrow-array/src/array/primitive_array.rs index 4c07e81468aa..1112acacfcd9 100644 --- a/arrow-array/src/array/primitive_array.rs +++ b/arrow-array/src/array/primitive_array.rs @@ -917,8 +917,8 @@ impl PrimitiveArray { let null_bit_buffer = data.nulls().map(|b| b.inner().sliced()); let element_len = std::mem::size_of::(); - let buffer = data.buffers()[0] - .slice_with_length(data.offset() * element_len, len * element_len); + let buffer = + data.buffers()[0].slice_with_length(data.offset() * element_len, len * element_len); drop(data); @@ -1116,10 +1116,9 @@ impl std::fmt::Debug for PrimitiveArray { }, // if the time zone is invalid, shows NaiveDateTime with an error message Err(_) => match as_datetime::(v) { - Some(datetime) => write!( - f, - "{datetime:?} (Unknown Time Zone '{tz_string}')" - ), + Some(datetime) => { + write!(f, "{datetime:?} (Unknown Time Zone '{tz_string}')") + } None => write!(f, "null"), }, } @@ -1191,25 +1190,19 @@ def_from_for_primitive!(Float64Type, f64); def_from_for_primitive!(Decimal128Type, i128); def_from_for_primitive!(Decimal256Type, i256); -impl From::Native>> - for NativeAdapter -{ +impl From::Native>> for NativeAdapter { fn from(value: Option<::Native>) -> Self { NativeAdapter { native: value } } } -impl From<&Option<::Native>> - for NativeAdapter -{ +impl From<&Option<::Native>> for NativeAdapter { fn from(value: &Option<::Native>) -> Self { NativeAdapter { native: *value } } } -impl>> FromIterator - for PrimitiveArray -{ +impl>> FromIterator for PrimitiveArray { fn from_iter>(iter: I) -> Self { let iter = iter.into_iter(); let (lower, _) = iter.size_hint(); @@ -1265,15 +1258,8 @@ impl PrimitiveArray { let (null, buffer) = 
trusted_len_unzip(iterator); - let data = ArrayData::new_unchecked( - T::DATA_TYPE, - len, - None, - Some(null), - 0, - vec![buffer], - vec![], - ); + let data = + ArrayData::new_unchecked(T::DATA_TYPE, len, None, Some(null), 0, vec![buffer], vec![]); PrimitiveArray::from(data) } } @@ -1294,9 +1280,7 @@ macro_rules! def_numeric_from_vec { } // Constructs a primitive array from a vector. Should only be used for testing. - impl From::Native>>> - for PrimitiveArray<$ty> - { + impl From::Native>>> for PrimitiveArray<$ty> { fn from(data: Vec::Native>>) -> Self { PrimitiveArray::from_iter(data.iter()) } @@ -1392,8 +1376,7 @@ impl From for PrimitiveArray { "PrimitiveArray data should contain a single buffer only (values buffer)" ); - let values = - ScalarBuffer::new(data.buffers()[0].clone(), data.offset(), data.len()); + let values = ScalarBuffer::new(data.buffers()[0].clone(), data.offset(), data.len()); Self { data_type: data.data_type().clone(), values, @@ -1407,11 +1390,7 @@ impl PrimitiveArray { /// specified precision and scale. 
/// /// See [`validate_decimal_precision_and_scale`] - pub fn with_precision_and_scale( - self, - precision: u8, - scale: i8, - ) -> Result { + pub fn with_precision_and_scale(self, precision: u8, scale: i8) -> Result { validate_decimal_precision_and_scale::(precision, scale)?; Ok(Self { data_type: T::TYPE_CONSTRUCTOR(precision, scale), @@ -1575,8 +1554,7 @@ mod tests { // 1: 00:00:00.001 // 37800005: 10:30:00.005 // 86399210: 23:59:59.210 - let arr: PrimitiveArray = - vec![1, 37_800_005, 86_399_210].into(); + let arr: PrimitiveArray = vec![1, 37_800_005, 86_399_210].into(); assert_eq!(3, arr.len()); assert_eq!(0, arr.offset()); assert_eq!(0, arr.null_count()); @@ -1858,11 +1836,7 @@ mod tests { #[test] fn test_timestamp_fmt_debug() { let arr: PrimitiveArray = - TimestampMillisecondArray::from(vec![ - 1546214400000, - 1546214400000, - -1546214400000, - ]); + TimestampMillisecondArray::from(vec![1546214400000, 1546214400000, -1546214400000]); assert_eq!( "PrimitiveArray\n[\n 2018-12-31T00:00:00,\n 2018-12-31T00:00:00,\n 1921-01-02T00:00:00,\n]", format!("{arr:?}") @@ -1872,12 +1846,8 @@ mod tests { #[test] fn test_timestamp_utc_fmt_debug() { let arr: PrimitiveArray = - TimestampMillisecondArray::from(vec![ - 1546214400000, - 1546214400000, - -1546214400000, - ]) - .with_timezone_utc(); + TimestampMillisecondArray::from(vec![1546214400000, 1546214400000, -1546214400000]) + .with_timezone_utc(); assert_eq!( "PrimitiveArray\n[\n 2018-12-31T00:00:00+00:00,\n 2018-12-31T00:00:00+00:00,\n 1921-01-02T00:00:00+00:00,\n]", format!("{arr:?}") @@ -1888,12 +1858,8 @@ mod tests { #[cfg(feature = "chrono-tz")] fn test_timestamp_with_named_tz_fmt_debug() { let arr: PrimitiveArray = - TimestampMillisecondArray::from(vec![ - 1546214400000, - 1546214400000, - -1546214400000, - ]) - .with_timezone("Asia/Taipei".to_string()); + TimestampMillisecondArray::from(vec![1546214400000, 1546214400000, -1546214400000]) + .with_timezone("Asia/Taipei".to_string()); assert_eq!( 
"PrimitiveArray\n[\n 2018-12-31T08:00:00+08:00,\n 2018-12-31T08:00:00+08:00,\n 1921-01-02T08:00:00+08:00,\n]", format!("{:?}", arr) @@ -1904,12 +1870,8 @@ mod tests { #[cfg(not(feature = "chrono-tz"))] fn test_timestamp_with_named_tz_fmt_debug() { let arr: PrimitiveArray = - TimestampMillisecondArray::from(vec![ - 1546214400000, - 1546214400000, - -1546214400000, - ]) - .with_timezone("Asia/Taipei".to_string()); + TimestampMillisecondArray::from(vec![1546214400000, 1546214400000, -1546214400000]) + .with_timezone("Asia/Taipei".to_string()); println!("{arr:?}"); @@ -1922,12 +1884,8 @@ mod tests { #[test] fn test_timestamp_with_fixed_offset_tz_fmt_debug() { let arr: PrimitiveArray = - TimestampMillisecondArray::from(vec![ - 1546214400000, - 1546214400000, - -1546214400000, - ]) - .with_timezone("+08:00".to_string()); + TimestampMillisecondArray::from(vec![1546214400000, 1546214400000, -1546214400000]) + .with_timezone("+08:00".to_string()); assert_eq!( "PrimitiveArray\n[\n 2018-12-31T08:00:00+08:00,\n 2018-12-31T08:00:00+08:00,\n 1921-01-02T08:00:00+08:00,\n]", format!("{arr:?}") @@ -1937,12 +1895,8 @@ mod tests { #[test] fn test_timestamp_with_incorrect_tz_fmt_debug() { let arr: PrimitiveArray = - TimestampMillisecondArray::from(vec![ - 1546214400000, - 1546214400000, - -1546214400000, - ]) - .with_timezone("xxx".to_string()); + TimestampMillisecondArray::from(vec![1546214400000, 1546214400000, -1546214400000]) + .with_timezone("xxx".to_string()); assert_eq!( "PrimitiveArray\n[\n 2018-12-31T00:00:00 (Unknown Time Zone 'xxx'),\n 2018-12-31T00:00:00 (Unknown Time Zone 'xxx'),\n 1921-01-02T00:00:00 (Unknown Time Zone 'xxx'),\n]", format!("{arr:?}") @@ -1952,14 +1906,13 @@ mod tests { #[test] #[cfg(feature = "chrono-tz")] fn test_timestamp_with_tz_with_daylight_saving_fmt_debug() { - let arr: PrimitiveArray = - TimestampMillisecondArray::from(vec![ - 1647161999000, - 1647162000000, - 1667717999000, - 1667718000000, - ]) - .with_timezone("America/Denver".to_string()); + 
let arr: PrimitiveArray = TimestampMillisecondArray::from(vec![ + 1647161999000, + 1647162000000, + 1667717999000, + 1667718000000, + ]) + .with_timezone("America/Denver".to_string()); assert_eq!( "PrimitiveArray\n[\n 2022-03-13T01:59:59-07:00,\n 2022-03-13T03:00:00-06:00,\n 2022-11-06T00:59:59-06:00,\n 2022-11-06T01:00:00-06:00,\n]", format!("{:?}", arr) @@ -1997,8 +1950,7 @@ mod tests { #[test] fn test_timestamp_micros_out_of_range() { // replicate the issue from https://github.com/apache/arrow-datafusion/issues/3832 - let arr: PrimitiveArray = - vec![9065525203050843594].into(); + let arr: PrimitiveArray = vec![9065525203050843594].into(); assert_eq!( "PrimitiveArray\n[\n null,\n]", format!("{arr:?}") @@ -2143,8 +2095,7 @@ mod tests { #[test] fn test_decimal256() { - let values: Vec<_> = - vec![i256::ZERO, i256::ONE, i256::MINUS_ONE, i256::MIN, i256::MAX]; + let values: Vec<_> = vec![i256::ZERO, i256::ONE, i256::MINUS_ONE, i256::MIN, i256::MAX]; let array: PrimitiveArray = PrimitiveArray::from_iter(values.iter().copied()); @@ -2166,8 +2117,8 @@ mod tests { // let val_8887: [u8; 16] = [192, 219, 180, 17, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]; // let val_neg_8887: [u8; 16] = [64, 36, 75, 238, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255]; let values: [u8; 32] = [ - 192, 219, 180, 17, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 64, 36, 75, 238, 253, - 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 192, 219, 180, 17, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 64, 36, 75, 238, 253, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, ]; let array_data = ArrayData::builder(DataType::Decimal128(38, 6)) .len(2) @@ -2232,8 +2183,7 @@ mod tests { #[test] fn test_decimal_from_iter() { - let array: Decimal128Array = - vec![Some(-100), None, Some(101)].into_iter().collect(); + let array: Decimal128Array = vec![Some(-100), None, Some(101)].into_iter().collect(); assert_eq!(array.len(), 3); assert_eq!(array.data_type(), &DataType::Decimal128(38, 10)); 
assert_eq!(-100_i128, array.value(0)); @@ -2343,8 +2293,7 @@ mod tests { #[test] fn test_decimal_array_set_null_if_overflow_with_precision() { - let array = - Decimal128Array::from(vec![Some(123456), Some(123), None, Some(123456)]); + let array = Decimal128Array::from(vec![Some(123456), Some(123), None, Some(123456)]); let result = array.null_if_overflow_precision(5); let expected = Decimal128Array::from(vec![None, Some(123), None, None]); assert_eq!(result, expected); @@ -2361,8 +2310,7 @@ mod tests { let decimal2 = i256::from_i128(56789); builder.append_value(decimal2); - let array: Decimal256Array = - builder.finish().with_precision_and_scale(76, 6).unwrap(); + let array: Decimal256Array = builder.finish().with_precision_and_scale(76, 6).unwrap(); let collected: Vec<_> = array.iter().collect(); assert_eq!(vec![Some(decimal1), None, Some(decimal2)], collected); @@ -2387,8 +2335,7 @@ mod tests { #[test] fn test_from_iter_decimal128array() { - let mut array: Decimal128Array = - vec![Some(-100), None, Some(101)].into_iter().collect(); + let mut array: Decimal128Array = vec![Some(-100), None, Some(101)].into_iter().collect(); array = array.with_precision_and_scale(38, 10).unwrap(); assert_eq!(array.len(), 3); assert_eq!(array.data_type(), &DataType::Decimal128(38, 10)); @@ -2404,13 +2351,11 @@ mod tests { let array = Int32Array::from(vec![1, 2, 3, 4, 5, 6, 7]); let r = array.unary_opt::<_, Int32Type>(|x| (x % 2 != 0).then_some(x)); - let expected = - Int32Array::from(vec![Some(1), None, Some(3), None, Some(5), None, Some(7)]); + let expected = Int32Array::from(vec![Some(1), None, Some(3), None, Some(5), None, Some(7)]); assert_eq!(r, expected); let r = expected.unary_opt::<_, Int32Type>(|x| (x % 3 != 0).then_some(x)); - let expected = - Int32Array::from(vec![Some(1), None, None, None, Some(5), None, Some(7)]); + let expected = Int32Array::from(vec![Some(1), None, None, None, Some(5), None, Some(7)]); assert_eq!(r, expected); } @@ -2513,9 +2458,8 @@ mod tests { 
Int32Array::new(vec![1, 2, 3, 4].into(), None); Int32Array::new(vec![1, 2, 3, 4].into(), Some(NullBuffer::new_null(4))); - let err = - Int32Array::try_new(vec![1, 2, 3, 4].into(), Some(NullBuffer::new_null(3))) - .unwrap_err(); + let err = Int32Array::try_new(vec![1, 2, 3, 4].into(), Some(NullBuffer::new_null(3))) + .unwrap_err(); assert_eq!( err.to_string(), diff --git a/arrow-array/src/array/run_array.rs b/arrow-array/src/array/run_array.rs index ba6986c28463..4877f9f850a3 100644 --- a/arrow-array/src/array/run_array.rs +++ b/arrow-array/src/array/run_array.rs @@ -91,10 +91,7 @@ impl RunArray { /// Attempts to create RunArray using given run_ends (index where a run ends) /// and the values (value of the run). Returns an error if the given data is not compatible /// with RunEndEncoded specification. - pub fn try_new( - run_ends: &PrimitiveArray, - values: &dyn Array, - ) -> Result { + pub fn try_new(run_ends: &PrimitiveArray, values: &dyn Array) -> Result { let run_ends_type = run_ends.data_type().clone(); let values_type = values.data_type().clone(); let ree_array_type = DataType::RunEndEncoded( @@ -182,10 +179,7 @@ impl RunArray { /// scaled well for larger inputs. /// See for more details. #[inline] - pub fn get_physical_indices( - &self, - logical_indices: &[I], - ) -> Result, ArrowError> + pub fn get_physical_indices(&self, logical_indices: &[I]) -> Result, ArrowError> where I: ArrowNativeType, { @@ -211,8 +205,7 @@ impl RunArray { }); // Return early if all the logical indices cannot be converted to physical indices. - let largest_logical_index = - logical_indices[*ordered_indices.last().unwrap()].as_usize(); + let largest_logical_index = logical_indices[*ordered_indices.last().unwrap()].as_usize(); if largest_logical_index >= len { return Err(ArrowError::InvalidArgumentError(format!( "Cannot convert all logical indices to physical indices. 
The logical index cannot be converted is {largest_logical_index}.", @@ -225,8 +218,7 @@ impl RunArray { let mut physical_indices = vec![0; indices_len]; let mut ordered_index = 0_usize; - for (physical_index, run_end) in - self.run_ends.values().iter().enumerate().skip(skip_value) + for (physical_index, run_end) in self.run_ends.values().iter().enumerate().skip(skip_value) { // Get the run end index (relative to offset) of current physical index let run_end_value = run_end.as_usize() - offset; @@ -234,8 +226,7 @@ impl RunArray { // All the `logical_indices` that are less than current run end index // belongs to current physical index. while ordered_index < indices_len - && logical_indices[ordered_indices[ordered_index]].as_usize() - < run_end_value + && logical_indices[ordered_indices[ordered_index]].as_usize() < run_end_value { physical_indices[ordered_indices[ordered_index]] = physical_index; ordered_index += 1; @@ -245,8 +236,7 @@ impl RunArray { // If there are input values >= run_ends.last_value then we'll not be able to convert // all logical indices to physical indices. if ordered_index < logical_indices.len() { - let logical_index = - logical_indices[ordered_indices[ordered_index]].as_usize(); + let logical_index = logical_indices[ordered_indices[ordered_index]].as_usize(); return Err(ArrowError::InvalidArgumentError(format!( "Cannot convert all logical indices to physical indices. The logical index cannot be converted is {logical_index}.", ))); @@ -704,8 +694,7 @@ mod tests { seed.shuffle(&mut rng); } // repeat the items between 1 and 8 times. 
Cap the length for smaller sized arrays - let num = - max_run_length.min(rand::thread_rng().gen_range(1..=max_run_length)); + let num = max_run_length.min(rand::thread_rng().gen_range(1..=max_run_length)); for _ in 0..num { result.push(seed[ix]); } @@ -749,19 +738,16 @@ mod tests { #[test] fn test_run_array() { // Construct a value array - let value_data = PrimitiveArray::::from_iter_values([ - 10_i8, 11, 12, 13, 14, 15, 16, 17, - ]); + let value_data = + PrimitiveArray::::from_iter_values([10_i8, 11, 12, 13, 14, 15, 16, 17]); // Construct a run_ends array: let run_ends_values = [4_i16, 6, 7, 9, 13, 18, 20, 22]; - let run_ends_data = PrimitiveArray::::from_iter_values( - run_ends_values.iter().copied(), - ); + let run_ends_data = + PrimitiveArray::::from_iter_values(run_ends_values.iter().copied()); // Construct a run ends encoded array from the above two - let ree_array = - RunArray::::try_new(&run_ends_data, &value_data).unwrap(); + let ree_array = RunArray::::try_new(&run_ends_data, &value_data).unwrap(); assert_eq!(ree_array.len(), 22); assert_eq!(ree_array.null_count(), 0); @@ -872,8 +858,7 @@ mod tests { let values: StringArray = [Some("foo"), Some("bar"), None, Some("baz")] .into_iter() .collect(); - let run_ends: Int32Array = - [Some(1), Some(2), Some(3), Some(4)].into_iter().collect(); + let run_ends: Int32Array = [Some(1), Some(2), Some(3), Some(4)].into_iter().collect(); let array = RunArray::::try_new(&run_ends, &values).unwrap(); assert_eq!(array.values().data_type(), &DataType::Utf8); @@ -924,7 +909,10 @@ mod tests { let run_ends: Int32Array = [Some(1), None, Some(3)].into_iter().collect(); let actual = RunArray::::try_new(&run_ends, &values); - let expected = ArrowError::InvalidArgumentError("Found null values in run_ends array. The run_ends array should not have null values.".to_string()); + let expected = ArrowError::InvalidArgumentError( + "Found null values in run_ends array. The run_ends array should not have null values." 
+ .to_string(), + ); assert_eq!(expected.to_string(), actual.err().unwrap().to_string()); } @@ -1003,8 +991,7 @@ mod tests { let mut rng = thread_rng(); logical_indices.shuffle(&mut rng); - let physical_indices = - run_array.get_physical_indices(&logical_indices).unwrap(); + let physical_indices = run_array.get_physical_indices(&logical_indices).unwrap(); assert_eq!(logical_indices.len(), physical_indices.len()); diff --git a/arrow-array/src/array/string_array.rs b/arrow-array/src/array/string_array.rs index cac4651f4496..9d266e0ca4b8 100644 --- a/arrow-array/src/array/string_array.rs +++ b/arrow-array/src/array/string_array.rs @@ -59,9 +59,7 @@ impl GenericStringArray { /// Fallibly creates a [`GenericStringArray`] from a [`GenericBinaryArray`] returning /// an error if [`GenericBinaryArray`] contains invalid UTF-8 data - pub fn try_from_binary( - v: GenericBinaryArray, - ) -> Result { + pub fn try_from_binary(v: GenericBinaryArray) -> Result { let (offsets, values, nulls) = v.into_parts(); Self::try_new(offsets, values, nulls) } @@ -83,9 +81,7 @@ impl From> } } -impl From>> - for GenericStringArray -{ +impl From>> for GenericStringArray { fn from(v: Vec>) -> Self { v.into_iter().collect() } @@ -97,9 +93,7 @@ impl From> for GenericStringArray From>> - for GenericStringArray -{ +impl From>> for GenericStringArray { fn from(v: Vec>) -> Self { v.into_iter().collect() } @@ -438,13 +432,11 @@ mod tests { let expected: LargeStringArray = data.clone().into_iter().map(Some).collect(); // Iterator reports too many items - let arr = - LargeStringArray::from_iter_values(BadIterator::new(3, 10, data.clone())); + let arr = LargeStringArray::from_iter_values(BadIterator::new(3, 10, data.clone())); assert_eq!(expected, arr); // Iterator reports too few items - let arr = - LargeStringArray::from_iter_values(BadIterator::new(3, 1, data.clone())); + let arr = LargeStringArray::from_iter_values(BadIterator::new(3, 1, data.clone())); assert_eq!(expected, arr); } @@ -460,9 +452,11 @@ 
mod tests { let offsets = [0, 5, 8, 15].map(|n| O::from_usize(n).unwrap()); let null_buffer = Buffer::from_slice_ref([0b101]); - let data_type = GenericListArray::::DATA_TYPE_CONSTRUCTOR(Arc::new( - Field::new("item", DataType::UInt8, false), - )); + let data_type = GenericListArray::::DATA_TYPE_CONSTRUCTOR(Arc::new(Field::new( + "item", + DataType::UInt8, + false, + ))); // [None, Some("Parquet")] let array_data = ArrayData::builder(data_type) @@ -493,9 +487,7 @@ mod tests { _test_generic_string_array_from_list_array::(); } - fn _test_generic_string_array_from_list_array_with_child_nulls_failed< - O: OffsetSizeTrait, - >() { + fn _test_generic_string_array_from_list_array_with_child_nulls_failed() { let values = b"HelloArrow"; let child_data = ArrayData::builder(DataType::UInt8) .len(10) @@ -508,9 +500,11 @@ mod tests { // It is possible to create a null struct containing a non-nullable child // see https://github.com/apache/arrow-rs/pull/3244 for details - let data_type = GenericListArray::::DATA_TYPE_CONSTRUCTOR(Arc::new( - Field::new("item", DataType::UInt8, true), - )); + let data_type = GenericListArray::::DATA_TYPE_CONSTRUCTOR(Arc::new(Field::new( + "item", + DataType::UInt8, + true, + ))); // [None, Some(b"Parquet")] let array_data = ArrayData::builder(data_type) @@ -544,9 +538,11 @@ mod tests { .unwrap(); let offsets = [0, 2, 3].map(|n| O::from_usize(n).unwrap()); - let data_type = GenericListArray::::DATA_TYPE_CONSTRUCTOR(Arc::new( - Field::new("item", DataType::UInt16, false), - )); + let data_type = GenericListArray::::DATA_TYPE_CONSTRUCTOR(Arc::new(Field::new( + "item", + DataType::UInt16, + false, + ))); let array_data = ArrayData::builder(data_type) .len(2) diff --git a/arrow-array/src/array/struct_array.rs b/arrow-array/src/array/struct_array.rs index 0e586ed1ef96..699da28cf7a3 100644 --- a/arrow-array/src/array/struct_array.rs +++ b/arrow-array/src/array/struct_array.rs @@ -462,9 +462,7 @@ impl Index<&str> for StructArray { mod tests { use 
super::*; - use crate::{ - BooleanArray, Float32Array, Float64Array, Int32Array, Int64Array, StringArray, - }; + use crate::{BooleanArray, Float32Array, Float64Array, Int32Array, Int64Array, StringArray}; use arrow_buffer::ToByteSlice; use std::sync::Arc; @@ -540,12 +538,10 @@ mod tests { None, Some("mark"), ])); - let ints: ArrayRef = - Arc::new(Int32Array::from(vec![Some(1), Some(2), None, Some(4)])); + let ints: ArrayRef = Arc::new(Int32Array::from(vec![Some(1), Some(2), None, Some(4)])); let arr = - StructArray::try_from(vec![("f1", strings.clone()), ("f2", ints.clone())]) - .unwrap(); + StructArray::try_from(vec![("f1", strings.clone()), ("f2", ints.clone())]).unwrap(); let struct_data = arr.into_data(); assert_eq!(4, struct_data.len()); @@ -578,13 +574,11 @@ mod tests { None, // 3 elements, not 4 ])); - let ints: ArrayRef = - Arc::new(Int32Array::from(vec![Some(1), Some(2), None, Some(4)])); + let ints: ArrayRef = Arc::new(Int32Array::from(vec![Some(1), Some(2), None, Some(4)])); - let err = - StructArray::try_from(vec![("f1", strings.clone()), ("f2", ints.clone())]) - .unwrap_err() - .to_string(); + let err = StructArray::try_from(vec![("f1", strings.clone()), ("f2", ints.clone())]) + .unwrap_err() + .to_string(); assert_eq!( err, @@ -599,8 +593,7 @@ mod tests { fn test_struct_array_from_mismatched_types_single() { drop(StructArray::from(vec![( Arc::new(Field::new("b", DataType::Int16, false)), - Arc::new(BooleanArray::from(vec![false, false, true, true])) - as Arc, + Arc::new(BooleanArray::from(vec![false, false, true, true])) as Arc, )])); } @@ -612,8 +605,7 @@ mod tests { drop(StructArray::from(vec![ ( Arc::new(Field::new("b", DataType::Int16, false)), - Arc::new(BooleanArray::from(vec![false, false, true, true])) - as Arc, + Arc::new(BooleanArray::from(vec![false, false, true, true])) as Arc, ), ( Arc::new(Field::new("c", DataType::Utf8, false)), @@ -733,9 +725,7 @@ mod tests { } #[test] - #[should_panic( - expected = "Found unmasked nulls for 
non-nullable StructArray field \\\"c\\\"" - )] + #[should_panic(expected = "Found unmasked nulls for non-nullable StructArray field \\\"c\\\"")] fn test_struct_array_from_mismatched_nullability() { drop(StructArray::from(vec![( Arc::new(Field::new("c", DataType::Int32, false)), diff --git a/arrow-array/src/array/union_array.rs b/arrow-array/src/array/union_array.rs index 74a5f1efa767..94ac0bc879e4 100644 --- a/arrow-array/src/array/union_array.rs +++ b/arrow-array/src/array/union_array.rs @@ -179,8 +179,7 @@ impl UnionArray { if let Some(b) = &value_offsets { if ((type_ids.len()) * 4) != b.len() { return Err(ArrowError::InvalidArgumentError( - "Type Ids and Offsets represent a different number of array slots." - .to_string(), + "Type Ids and Offsets represent a different number of array slots.".to_string(), )); } } @@ -216,9 +215,8 @@ impl UnionArray { // Unsafe Justification: arguments were validated above (and // re-revalidated as part of data().validate() below) - let new_self = unsafe { - Self::new_unchecked(field_type_ids, type_ids, value_offsets, child_arrays) - }; + let new_self = + unsafe { Self::new_unchecked(field_type_ids, type_ids, value_offsets, child_arrays) }; new_self.to_data().validate()?; Ok(new_self) @@ -1059,7 +1057,13 @@ mod tests { let mut builder = UnionBuilder::new_sparse(); builder.append::("a", 1.0).unwrap(); let err = builder.append::("a", 1).unwrap_err().to_string(); - assert!(err.contains("Attempt to write col \"a\" with type Int32 doesn't match existing type Float32"), "{}", err); + assert!( + err.contains( + "Attempt to write col \"a\" with type Int32 doesn't match existing type Float32" + ), + "{}", + err + ); } #[test] diff --git a/arrow-array/src/builder/boolean_builder.rs b/arrow-array/src/builder/boolean_builder.rs index 5f0013269677..7e59d940a50e 100644 --- a/arrow-array/src/builder/boolean_builder.rs +++ b/arrow-array/src/builder/boolean_builder.rs @@ -127,11 +127,7 @@ impl BooleanBuilder { /// /// Returns an error if the 
slices are of different lengths #[inline] - pub fn append_values( - &mut self, - values: &[bool], - is_valid: &[bool], - ) -> Result<(), ArrowError> { + pub fn append_values(&mut self, values: &[bool], is_valid: &[bool]) -> Result<(), ArrowError> { if values.len() != is_valid.len() { Err(ArrowError::InvalidArgumentError( "Value and validity lengths must be equal".to_string(), @@ -250,8 +246,7 @@ mod tests { #[test] fn test_boolean_array_builder_append_slice() { - let arr1 = - BooleanArray::from(vec![Some(true), Some(false), None, None, Some(false)]); + let arr1 = BooleanArray::from(vec![Some(true), Some(false), None, None, Some(false)]); let mut builder = BooleanArray::builder(0); builder.append_slice(&[true, false]); diff --git a/arrow-array/src/builder/buffer_builder.rs b/arrow-array/src/builder/buffer_builder.rs index 01e4c1d4e217..2b66a8187fa9 100644 --- a/arrow-array/src/builder/buffer_builder.rs +++ b/arrow-array/src/builder/buffer_builder.rs @@ -45,11 +45,9 @@ pub type Float32BufferBuilder = BufferBuilder; pub type Float64BufferBuilder = BufferBuilder; /// Buffer builder for 128-bit decimal type. -pub type Decimal128BufferBuilder = - BufferBuilder<::Native>; +pub type Decimal128BufferBuilder = BufferBuilder<::Native>; /// Buffer builder for 256-bit decimal type. -pub type Decimal256BufferBuilder = - BufferBuilder<::Native>; +pub type Decimal256BufferBuilder = BufferBuilder<::Native>; /// Buffer builder for timestamp type of second unit. 
pub type TimestampSecondBufferBuilder = @@ -107,9 +105,7 @@ pub type DurationNanosecondBufferBuilder = #[cfg(test)] mod tests { - use crate::builder::{ - ArrayBuilder, Int32BufferBuilder, Int8Builder, UInt8BufferBuilder, - }; + use crate::builder::{ArrayBuilder, Int32BufferBuilder, Int8Builder, UInt8BufferBuilder}; use crate::Array; #[test] diff --git a/arrow-array/src/builder/fixed_size_binary_builder.rs b/arrow-array/src/builder/fixed_size_binary_builder.rs index 180150e988f3..0a50eb8a50e9 100644 --- a/arrow-array/src/builder/fixed_size_binary_builder.rs +++ b/arrow-array/src/builder/fixed_size_binary_builder.rs @@ -75,7 +75,8 @@ impl FixedSizeBinaryBuilder { pub fn append_value(&mut self, value: impl AsRef<[u8]>) -> Result<(), ArrowError> { if self.value_length != value.as_ref().len() as i32 { Err(ArrowError::InvalidArgumentError( - "Byte slice does not have the same length as FixedSizeBinaryBuilder value lengths".to_string() + "Byte slice does not have the same length as FixedSizeBinaryBuilder value lengths" + .to_string(), )) } else { self.values_builder.append_slice(value.as_ref()); @@ -95,11 +96,10 @@ impl FixedSizeBinaryBuilder { /// Builds the [`FixedSizeBinaryArray`] and reset this builder. 
pub fn finish(&mut self) -> FixedSizeBinaryArray { let array_length = self.len(); - let array_data_builder = - ArrayData::builder(DataType::FixedSizeBinary(self.value_length)) - .add_buffer(self.values_builder.finish()) - .nulls(self.null_buffer_builder.finish()) - .len(array_length); + let array_data_builder = ArrayData::builder(DataType::FixedSizeBinary(self.value_length)) + .add_buffer(self.values_builder.finish()) + .nulls(self.null_buffer_builder.finish()) + .len(array_length); let array_data = unsafe { array_data_builder.build_unchecked() }; FixedSizeBinaryArray::from(array_data) } @@ -108,11 +108,10 @@ impl FixedSizeBinaryBuilder { pub fn finish_cloned(&self) -> FixedSizeBinaryArray { let array_length = self.len(); let values_buffer = Buffer::from_slice_ref(self.values_builder.as_slice()); - let array_data_builder = - ArrayData::builder(DataType::FixedSizeBinary(self.value_length)) - .add_buffer(values_buffer) - .nulls(self.null_buffer_builder.finish_cloned()) - .len(array_length); + let array_data_builder = ArrayData::builder(DataType::FixedSizeBinary(self.value_length)) + .add_buffer(values_buffer) + .nulls(self.null_buffer_builder.finish_cloned()) + .len(array_length); let array_data = unsafe { array_data_builder.build_unchecked() }; FixedSizeBinaryArray::from(array_data) } diff --git a/arrow-array/src/builder/generic_byte_run_builder.rs b/arrow-array/src/builder/generic_byte_run_builder.rs index 41165208de55..3cde76c4a039 100644 --- a/arrow-array/src/builder/generic_byte_run_builder.rs +++ b/arrow-array/src/builder/generic_byte_run_builder.rs @@ -19,10 +19,7 @@ use crate::types::bytes::ByteArrayNativeType; use std::{any::Any, sync::Arc}; use crate::{ - types::{ - BinaryType, ByteArrayType, LargeBinaryType, LargeUtf8Type, RunEndIndexType, - Utf8Type, - }, + types::{BinaryType, ByteArrayType, LargeBinaryType, LargeUtf8Type, RunEndIndexType, Utf8Type}, ArrayRef, ArrowPrimitiveType, RunArray, }; @@ -112,10 +109,7 @@ where pub fn with_capacity(capacity: 
usize, data_capacity: usize) -> Self { Self { run_ends_builder: PrimitiveBuilder::with_capacity(capacity), - values_builder: GenericByteBuilder::::with_capacity( - capacity, - data_capacity, - ), + values_builder: GenericByteBuilder::::with_capacity(capacity, data_capacity), current_value: Vec::new(), has_current_value: false, current_run_end_index: 0, @@ -282,12 +276,13 @@ where } fn run_end_index_as_native(&self) -> R::Native { - R::Native::from_usize(self.current_run_end_index) - .unwrap_or_else(|| panic!( + R::Native::from_usize(self.current_run_end_index).unwrap_or_else(|| { + panic!( "Cannot convert the value {} from `usize` to native form of arrow datatype {}", self.current_run_end_index, R::DATA_TYPE - )) + ) + }) } } @@ -413,8 +408,7 @@ mod tests { // Values are polymorphic and so require a downcast. let av = array.values(); - let ava: &GenericByteArray = - av.as_any().downcast_ref::>().unwrap(); + let ava: &GenericByteArray = av.as_any().downcast_ref::>().unwrap(); assert_eq!(*ava.value(0), *values[0]); assert!(ava.is_null(1)); @@ -459,8 +453,7 @@ mod tests { // Values are polymorphic and so require a downcast. 
let av = array.values(); - let ava: &GenericByteArray = - av.as_any().downcast_ref::>().unwrap(); + let ava: &GenericByteArray = av.as_any().downcast_ref::>().unwrap(); assert_eq!(ava.value(0), values[0]); assert!(ava.is_null(1)); diff --git a/arrow-array/src/builder/generic_bytes_builder.rs b/arrow-array/src/builder/generic_bytes_builder.rs index d84be8c2fca6..2c7ee7a3e448 100644 --- a/arrow-array/src/builder/generic_bytes_builder.rs +++ b/arrow-array/src/builder/generic_bytes_builder.rs @@ -68,12 +68,8 @@ impl GenericByteBuilder { let value_builder = BufferBuilder::::new_from_buffer(value_buffer); let null_buffer_builder = null_buffer - .map(|buffer| { - NullBufferBuilder::new_from_buffer(buffer, offsets_builder.len() - 1) - }) - .unwrap_or_else(|| { - NullBufferBuilder::new_with_len(offsets_builder.len() - 1) - }); + .map(|buffer| NullBufferBuilder::new_from_buffer(buffer, offsets_builder.len() - 1)) + .unwrap_or_else(|| NullBufferBuilder::new_with_len(offsets_builder.len() - 1)); Self { offsets_builder, @@ -84,8 +80,7 @@ impl GenericByteBuilder { #[inline] fn next_offset(&self) -> T::Offset { - T::Offset::from_usize(self.value_builder.len()) - .expect("byte array offset overflow") + T::Offset::from_usize(self.value_builder.len()).expect("byte array offset overflow") } /// Appends a value into the builder. diff --git a/arrow-array/src/builder/generic_bytes_dictionary_builder.rs b/arrow-array/src/builder/generic_bytes_dictionary_builder.rs index 282f423fa6d1..b0c722ae7cda 100644 --- a/arrow-array/src/builder/generic_bytes_dictionary_builder.rs +++ b/arrow-array/src/builder/generic_bytes_dictionary_builder.rs @@ -16,9 +16,7 @@ // under the License. 
use crate::builder::{ArrayBuilder, GenericByteBuilder, PrimitiveBuilder}; -use crate::types::{ - ArrowDictionaryKeyType, ByteArrayType, GenericBinaryType, GenericStringType, -}; +use crate::types::{ArrowDictionaryKeyType, ByteArrayType, GenericBinaryType, GenericStringType}; use crate::{Array, ArrayRef, DictionaryArray, GenericByteArray}; use arrow_buffer::ArrowNativeType; use arrow_schema::{ArrowError, DataType}; @@ -91,10 +89,7 @@ where state: Default::default(), dedup: Default::default(), keys_builder: PrimitiveBuilder::with_capacity(keys_capacity), - values_builder: GenericByteBuilder::::with_capacity( - value_capacity, - data_capacity, - ), + values_builder: GenericByteBuilder::::with_capacity(value_capacity, data_capacity), } } @@ -131,8 +126,7 @@ where let mut dedup = HashMap::with_capacity_and_hasher(dict_len, ()); let values_len = dictionary_values.value_data().len(); - let mut values_builder = - GenericByteBuilder::::with_capacity(dict_len, values_len); + let mut values_builder = GenericByteBuilder::::with_capacity(dict_len, values_len); K::Native::from_usize(dictionary_values.len()) .ok_or(ArrowError::DictionaryKeyOverflowError)?; @@ -214,10 +208,7 @@ where /// value is appended to the values array. /// /// Returns an error if the new index would overflow the key type. - pub fn append( - &mut self, - value: impl AsRef, - ) -> Result { + pub fn append(&mut self, value: impl AsRef) -> Result { let value_native: &T::Native = value.as_ref(); let value_bytes: &[u8] = value_native.as_ref(); @@ -240,8 +231,7 @@ where state.hash_one(get_bytes(storage, *idx)) }); - K::Native::from_usize(idx) - .ok_or(ArrowError::DictionaryKeyOverflowError)? + K::Native::from_usize(idx).ok_or(ArrowError::DictionaryKeyOverflowError)? 
} }; self.keys_builder.append_value(key); @@ -283,8 +273,7 @@ where let values = self.values_builder.finish(); let keys = self.keys_builder.finish(); - let data_type = - DataType::Dictionary(Box::new(K::DATA_TYPE), Box::new(T::DATA_TYPE)); + let data_type = DataType::Dictionary(Box::new(K::DATA_TYPE), Box::new(T::DATA_TYPE)); let builder = keys .into_data() @@ -300,8 +289,7 @@ where let values = self.values_builder.finish_cloned(); let keys = self.keys_builder.finish_cloned(); - let data_type = - DataType::Dictionary(Box::new(K::DATA_TYPE), Box::new(T::DATA_TYPE)); + let data_type = DataType::Dictionary(Box::new(K::DATA_TYPE), Box::new(T::DATA_TYPE)); let builder = keys .into_data() @@ -367,12 +355,10 @@ fn get_bytes(values: &GenericByteBuilder, idx: usize) -> &[ /// assert_eq!(ava.value(1), "def"); /// /// ``` -pub type StringDictionaryBuilder = - GenericByteDictionaryBuilder>; +pub type StringDictionaryBuilder = GenericByteDictionaryBuilder>; /// Builder for [`DictionaryArray`] of [`LargeStringArray`](crate::array::LargeStringArray) -pub type LargeStringDictionaryBuilder = - GenericByteDictionaryBuilder>; +pub type LargeStringDictionaryBuilder = GenericByteDictionaryBuilder>; /// Builder for [`DictionaryArray`] of [`BinaryArray`](crate::array::BinaryArray) /// @@ -407,12 +393,10 @@ pub type LargeStringDictionaryBuilder = /// assert_eq!(ava.value(1), b"def"); /// /// ``` -pub type BinaryDictionaryBuilder = - GenericByteDictionaryBuilder>; +pub type BinaryDictionaryBuilder = GenericByteDictionaryBuilder>; /// Builder for [`DictionaryArray`] of [`LargeBinaryArray`](crate::array::LargeBinaryArray) -pub type LargeBinaryDictionaryBuilder = - GenericByteDictionaryBuilder>; +pub type LargeBinaryDictionaryBuilder = GenericByteDictionaryBuilder>; #[cfg(test)] mod tests { @@ -444,8 +428,7 @@ mod tests { // Values are polymorphic and so require a downcast. 
let av = array.values(); - let ava: &GenericByteArray = - av.as_any().downcast_ref::>().unwrap(); + let ava: &GenericByteArray = av.as_any().downcast_ref::>().unwrap(); assert_eq!(*ava.value(0), *values[0]); assert_eq!(*ava.value(1), *values[1]); @@ -483,8 +466,7 @@ mod tests { // Values are polymorphic and so require a downcast. let av = array.values(); - let ava: &GenericByteArray = - av.as_any().downcast_ref::>().unwrap(); + let ava: &GenericByteArray = av.as_any().downcast_ref::>().unwrap(); assert_eq!(ava.value(0), values[0]); assert_eq!(ava.value(1), values[1]); @@ -542,11 +524,8 @@ mod tests { ::Native: AsRef<::Native>, { let mut builder = - GenericByteDictionaryBuilder::::new_with_dictionary( - 6, - &dictionary, - ) - .unwrap(); + GenericByteDictionaryBuilder::::new_with_dictionary(6, &dictionary) + .unwrap(); builder.append(values[0]).unwrap(); builder.append_null(); builder.append(values[1]).unwrap(); @@ -562,8 +541,7 @@ mod tests { // Values are polymorphic and so require a downcast. 
let av = array.values(); - let ava: &GenericByteArray = - av.as_any().downcast_ref::>().unwrap(); + let ava: &GenericByteArray = av.as_any().downcast_ref::>().unwrap(); assert!(!ava.is_valid(0)); assert_eq!(ava.value(1), values[1]); @@ -597,11 +575,8 @@ mod tests { ::Native: AsRef<::Native>, { let mut builder = - GenericByteDictionaryBuilder::::new_with_dictionary( - 4, - &dictionary, - ) - .unwrap(); + GenericByteDictionaryBuilder::::new_with_dictionary(4, &dictionary) + .unwrap(); builder.append(values[0]).unwrap(); builder.append_null(); builder.append(values[1]).unwrap(); diff --git a/arrow-array/src/builder/map_builder.rs b/arrow-array/src/builder/map_builder.rs index 4e3ec4a7944d..3a5244ed81a0 100644 --- a/arrow-array/src/builder/map_builder.rs +++ b/arrow-array/src/builder/map_builder.rs @@ -86,11 +86,7 @@ impl Default for MapFieldNames { impl MapBuilder { /// Creates a new `MapBuilder` - pub fn new( - field_names: Option, - key_builder: K, - value_builder: V, - ) -> Self { + pub fn new(field_names: Option, key_builder: K, value_builder: V) -> Self { let capacity = key_builder.len(); Self::with_capacity(field_names, key_builder, value_builder, capacity) } @@ -243,12 +239,9 @@ mod tests { use super::*; #[test] - #[should_panic( - expected = "Keys array must have no null values, found 1 null value(s)" - )] + #[should_panic(expected = "Keys array must have no null values, found 1 null value(s)")] fn test_map_builder_with_null_keys_panics() { - let mut builder = - MapBuilder::new(None, StringBuilder::new(), Int32Builder::new()); + let mut builder = MapBuilder::new(None, StringBuilder::new(), Int32Builder::new()); builder.keys().append_null(); builder.values().append_value(42); builder.append(true).unwrap(); diff --git a/arrow-array/src/builder/primitive_builder.rs b/arrow-array/src/builder/primitive_builder.rs index b23d6bba36c4..0aad2dbfce0e 100644 --- a/arrow-array/src/builder/primitive_builder.rs +++ b/arrow-array/src/builder/primitive_builder.rs @@ -161,9 
+161,7 @@ impl PrimitiveBuilder { let values_builder = BufferBuilder::::new_from_buffer(values_buffer); let null_buffer_builder = null_buffer - .map(|buffer| { - NullBufferBuilder::new_from_buffer(buffer, values_builder.len()) - }) + .map(|buffer| NullBufferBuilder::new_from_buffer(buffer, values_builder.len())) .unwrap_or_else(|| NullBufferBuilder::new_with_len(values_builder.len())); Self { @@ -256,10 +254,7 @@ impl PrimitiveBuilder { /// This requires the iterator be a trusted length. This could instead require /// the iterator implement `TrustedLen` once that is stabilized. #[inline] - pub unsafe fn append_trusted_len_iter( - &mut self, - iter: impl IntoIterator, - ) { + pub unsafe fn append_trusted_len_iter(&mut self, iter: impl IntoIterator) { let iter = iter.into_iter(); let len = iter .size_hint() @@ -328,11 +323,7 @@ impl PrimitiveBuilder { impl PrimitiveBuilder

{ /// Sets the precision and scale - pub fn with_precision_and_scale( - self, - precision: u8, - scale: i8, - ) -> Result { + pub fn with_precision_and_scale(self, precision: u8, scale: i8) -> Result { validate_decimal_precision_and_scale::

(precision, scale)?; Ok(Self { data_type: P::TYPE_CONSTRUCTOR(precision, scale), @@ -592,25 +583,21 @@ mod tests { #[test] fn test_primitive_array_builder_with_data_type() { - let mut builder = - Decimal128Builder::new().with_data_type(DataType::Decimal128(1, 2)); + let mut builder = Decimal128Builder::new().with_data_type(DataType::Decimal128(1, 2)); builder.append_value(1); let array = builder.finish(); assert_eq!(array.precision(), 1); assert_eq!(array.scale(), 2); let data_type = DataType::Timestamp(TimeUnit::Nanosecond, Some("+00:00".into())); - let mut builder = - TimestampNanosecondBuilder::new().with_data_type(data_type.clone()); + let mut builder = TimestampNanosecondBuilder::new().with_data_type(data_type.clone()); builder.append_value(1); let array = builder.finish(); assert_eq!(array.data_type(), &data_type); } #[test] - #[should_panic( - expected = "incompatible data type for builder, expected Int32 got Int64" - )] + #[should_panic(expected = "incompatible data type for builder, expected Int32 got Int64")] fn test_invalid_with_data_type() { Int32Builder::new().with_data_type(DataType::Int64); } diff --git a/arrow-array/src/builder/primitive_dictionary_builder.rs b/arrow-array/src/builder/primitive_dictionary_builder.rs index 7323ee57627d..a47b2d30d4f3 100644 --- a/arrow-array/src/builder/primitive_dictionary_builder.rs +++ b/arrow-array/src/builder/primitive_dictionary_builder.rs @@ -221,8 +221,7 @@ where let key = self.values_builder.len(); self.values_builder.append_value(value); vacant.insert(key); - K::Native::from_usize(key) - .ok_or(ArrowError::DictionaryKeyOverflowError)? + K::Native::from_usize(key).ok_or(ArrowError::DictionaryKeyOverflowError)? 
} Entry::Occupied(o) => K::Native::usize_as(*o.get()), }; @@ -266,10 +265,8 @@ where let values = self.values_builder.finish(); let keys = self.keys_builder.finish(); - let data_type = DataType::Dictionary( - Box::new(K::DATA_TYPE), - Box::new(values.data_type().clone()), - ); + let data_type = + DataType::Dictionary(Box::new(K::DATA_TYPE), Box::new(values.data_type().clone())); let builder = keys .into_data() @@ -285,8 +282,7 @@ where let values = self.values_builder.finish_cloned(); let keys = self.keys_builder.finish_cloned(); - let data_type = - DataType::Dictionary(Box::new(K::DATA_TYPE), Box::new(V::DATA_TYPE)); + let data_type = DataType::Dictionary(Box::new(K::DATA_TYPE), Box::new(V::DATA_TYPE)); let builder = keys .into_data() @@ -331,8 +327,7 @@ mod tests { #[test] fn test_primitive_dictionary_builder() { - let mut builder = - PrimitiveDictionaryBuilder::::with_capacity(3, 2); + let mut builder = PrimitiveDictionaryBuilder::::with_capacity(3, 2); builder.append(12345678).unwrap(); builder.append_null(); builder.append(22345678).unwrap(); @@ -384,8 +379,7 @@ mod tests { #[test] fn test_primitive_dictionary_with_builders() { let keys_builder = PrimitiveBuilder::::new(); - let values_builder = - Decimal128Builder::new().with_data_type(DataType::Decimal128(1, 2)); + let values_builder = Decimal128Builder::new().with_data_type(DataType::Decimal128(1, 2)); let mut builder = PrimitiveDictionaryBuilder::::new_from_empty_builders( keys_builder, diff --git a/arrow-array/src/builder/struct_builder.rs b/arrow-array/src/builder/struct_builder.rs index 7aa91dacaa8c..0f40b8a487ae 100644 --- a/arrow-array/src/builder/struct_builder.rs +++ b/arrow-array/src/builder/struct_builder.rs @@ -106,24 +106,18 @@ pub fn make_builder(datatype: &DataType, capacity: usize) -> Box Box::new(Float32Builder::with_capacity(capacity)), DataType::Float64 => Box::new(Float64Builder::with_capacity(capacity)), DataType::Binary => Box::new(BinaryBuilder::with_capacity(capacity, 1024)), - 
DataType::LargeBinary => { - Box::new(LargeBinaryBuilder::with_capacity(capacity, 1024)) - } + DataType::LargeBinary => Box::new(LargeBinaryBuilder::with_capacity(capacity, 1024)), DataType::FixedSizeBinary(len) => { Box::new(FixedSizeBinaryBuilder::with_capacity(capacity, *len)) } DataType::Decimal128(p, s) => Box::new( - Decimal128Builder::with_capacity(capacity) - .with_data_type(DataType::Decimal128(*p, *s)), + Decimal128Builder::with_capacity(capacity).with_data_type(DataType::Decimal128(*p, *s)), ), DataType::Decimal256(p, s) => Box::new( - Decimal256Builder::with_capacity(capacity) - .with_data_type(DataType::Decimal256(*p, *s)), + Decimal256Builder::with_capacity(capacity).with_data_type(DataType::Decimal256(*p, *s)), ), DataType::Utf8 => Box::new(StringBuilder::with_capacity(capacity, 1024)), - DataType::LargeUtf8 => { - Box::new(LargeStringBuilder::with_capacity(capacity, 1024)) - } + DataType::LargeUtf8 => Box::new(LargeStringBuilder::with_capacity(capacity, 1024)), DataType::Date32 => Box::new(Date32Builder::with_capacity(capacity)), DataType::Date64 => Box::new(Date64Builder::with_capacity(capacity)), DataType::Time32(TimeUnit::Second) => { @@ -175,19 +169,14 @@ pub fn make_builder(datatype: &DataType, capacity: usize) -> Box { Box::new(DurationNanosecondBuilder::with_capacity(capacity)) } - DataType::Struct(fields) => { - Box::new(StructBuilder::from_fields(fields.clone(), capacity)) - } + DataType::Struct(fields) => Box::new(StructBuilder::from_fields(fields.clone(), capacity)), t => panic!("Data type {t:?} is not currently supported"), } } impl StructBuilder { /// Creates a new `StructBuilder` - pub fn new( - fields: impl Into, - field_builders: Vec>, - ) -> Self { + pub fn new(fields: impl Into, field_builders: Vec>) -> Self { Self { field_builders, fields: fields.into(), @@ -234,10 +223,7 @@ impl StructBuilder { pub fn finish(&mut self) -> StructArray { self.validate_content(); if self.fields.is_empty() { - return StructArray::new_empty_fields( - 
self.len(), - self.null_buffer_builder.finish(), - ); + return StructArray::new_empty_fields(self.len(), self.null_buffer_builder.finish()); } let arrays = self.field_builders.iter_mut().map(|f| f.finish()).collect(); @@ -524,8 +510,7 @@ mod tests { expected = "Data type List(Field { name: \"item\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) is not currently supported" )] fn test_struct_array_builder_from_schema_unsupported_type() { - let list_type = - DataType::List(Arc::new(Field::new("item", DataType::Int64, true))); + let list_type = DataType::List(Arc::new(Field::new("item", DataType::Int64, true))); let fields = vec![ Field::new("f1", DataType::Int16, false), Field::new("f2", list_type, false), @@ -571,9 +556,7 @@ mod tests { } #[test] - #[should_panic( - expected = "Number of fields is not equal to the number of field_builders." - )] + #[should_panic(expected = "Number of fields is not equal to the number of field_builders.")] fn test_struct_array_builder_unequal_field_field_builders() { let int_builder = Int32Builder::with_capacity(10); diff --git a/arrow-array/src/builder/union_builder.rs b/arrow-array/src/builder/union_builder.rs index f74afb2aa9aa..4f88c9d41b9a 100644 --- a/arrow-array/src/builder/union_builder.rs +++ b/arrow-array/src/builder/union_builder.rs @@ -65,11 +65,7 @@ impl FieldDataValues for BufferBuilder { impl FieldData { /// Creates a new `FieldData`. 
- fn new( - type_id: i8, - data_type: DataType, - capacity: usize, - ) -> Self { + fn new(type_id: i8, data_type: DataType, capacity: usize) -> Self { Self { type_id, data_type, @@ -222,7 +218,12 @@ impl UnionBuilder { let mut field_data = match self.fields.remove(&type_name) { Some(data) => { if data.data_type != T::DATA_TYPE { - return Err(ArrowError::InvalidArgumentError(format!("Attempt to write col \"{}\" with type {} doesn't match existing type {}", type_name, T::DATA_TYPE, data.data_type))); + return Err(ArrowError::InvalidArgumentError(format!( + "Attempt to write col \"{}\" with type {} doesn't match existing type {}", + type_name, + T::DATA_TYPE, + data.data_type + ))); } data } diff --git a/arrow-array/src/cast.rs b/arrow-array/src/cast.rs index b6cda44e8973..2e21f3e7e640 100644 --- a/arrow-array/src/cast.rs +++ b/arrow-array/src/cast.rs @@ -578,9 +578,7 @@ macro_rules! downcast_run_array { /// Force downcast of an [`Array`], such as an [`ArrayRef`] to /// [`GenericListArray`], panicking on failure. -pub fn as_generic_list_array( - arr: &dyn Array, -) -> &GenericListArray { +pub fn as_generic_list_array(arr: &dyn Array) -> &GenericListArray { arr.as_any() .downcast_ref::>() .expect("Unable to downcast to list array") @@ -612,9 +610,7 @@ pub fn as_large_list_array(arr: &dyn Array) -> &LargeListArray { /// Force downcast of an [`Array`], such as an [`ArrayRef`] to /// [`GenericBinaryArray`], panicking on failure. 
#[inline] -pub fn as_generic_binary_array( - arr: &dyn Array, -) -> &GenericBinaryArray { +pub fn as_generic_binary_array(arr: &dyn Array) -> &GenericBinaryArray { arr.as_any() .downcast_ref::>() .expect("Unable to downcast to binary array") @@ -826,8 +822,7 @@ pub trait AsArray: private::Sealed { } /// Downcast this to a [`DictionaryArray`] returning `None` if not possible - fn as_dictionary_opt(&self) - -> Option<&DictionaryArray>; + fn as_dictionary_opt(&self) -> Option<&DictionaryArray>; /// Downcast this to a [`DictionaryArray`] panicking if not possible fn as_dictionary(&self) -> &DictionaryArray { @@ -877,9 +872,7 @@ impl AsArray for dyn Array + '_ { self.as_any().downcast_ref() } - fn as_dictionary_opt( - &self, - ) -> Option<&DictionaryArray> { + fn as_dictionary_opt(&self) -> Option<&DictionaryArray> { self.as_any().downcast_ref() } @@ -926,9 +919,7 @@ impl AsArray for ArrayRef { self.as_any().downcast_ref() } - fn as_dictionary_opt( - &self, - ) -> Option<&DictionaryArray> { + fn as_dictionary_opt(&self) -> Option<&DictionaryArray> { self.as_ref().as_dictionary_opt() } @@ -972,9 +963,7 @@ mod tests { #[test] fn test_decimal256array() { - let a = Decimal256Array::from_iter_values( - [1, 2, 4, 5].into_iter().map(i256::from_i128), - ); + let a = Decimal256Array::from_iter_values([1, 2, 4, 5].into_iter().map(i256::from_i128)); assert!(!as_primitive_array::(&a).is_empty()); } } diff --git a/arrow-array/src/delta.rs b/arrow-array/src/delta.rs index bf9ee5ca685f..d9aa4aa6de5d 100644 --- a/arrow-array/src/delta.rs +++ b/arrow-array/src/delta.rs @@ -55,10 +55,7 @@ pub(crate) fn add_months_datetime( /// Add the given number of days to the given datetime. /// /// Returns `None` when it will result in overflow. 
-pub(crate) fn add_days_datetime( - dt: DateTime, - days: i32, -) -> Option> { +pub(crate) fn add_days_datetime(dt: DateTime, days: i32) -> Option> { match days.cmp(&0) { Ordering::Equal => Some(dt), Ordering::Greater => dt.checked_add_days(Days::new(days as u64)), @@ -83,10 +80,7 @@ pub(crate) fn sub_months_datetime( /// Substract the given number of days to the given datetime. /// /// Returns `None` when it will result in overflow. -pub(crate) fn sub_days_datetime( - dt: DateTime, - days: i32, -) -> Option> { +pub(crate) fn sub_days_datetime(dt: DateTime, days: i32) -> Option> { match days.cmp(&0) { Ordering::Equal => Some(dt), Ordering::Greater => dt.checked_sub_days(Days::new(days as u64)), diff --git a/arrow-array/src/iterator.rs b/arrow-array/src/iterator.rs index a198332ca5b5..3f9cc0d525c1 100644 --- a/arrow-array/src/iterator.rs +++ b/arrow-array/src/iterator.rs @@ -18,8 +18,8 @@ //! Idiomatic iterators for [`Array`](crate::Array) use crate::array::{ - ArrayAccessor, BooleanArray, FixedSizeBinaryArray, GenericBinaryArray, - GenericListArray, GenericStringArray, PrimitiveArray, + ArrayAccessor, BooleanArray, FixedSizeBinaryArray, GenericBinaryArray, GenericListArray, + GenericStringArray, PrimitiveArray, }; use crate::{FixedSizeListArray, MapArray}; use arrow_buffer::NullBuffer; @@ -187,8 +187,7 @@ mod tests { #[test] fn test_string_array_iter_round_trip() { - let array = - StringArray::from(vec![Some("a"), None, Some("aaa"), None, Some("aaaaa")]); + let array = StringArray::from(vec![Some("a"), None, Some("aaa"), None, Some("aaaaa")]); let array = Arc::new(array) as ArrayRef; let array = array.as_any().downcast_ref::().unwrap(); @@ -211,8 +210,7 @@ mod tests { // check if DoubleEndedIterator is implemented let result: StringArray = array.iter().rev().collect(); - let rev_array = - StringArray::from(vec![Some("aaaaa"), None, Some("aaa"), None, Some("a")]); + let rev_array = StringArray::from(vec![Some("aaaaa"), None, Some("aaa"), None, Some("a")]); 
assert_eq!(result, rev_array); // check if ExactSizeIterator is implemented let _ = array.iter().rposition(|opt_b| opt_b == Some("a")); diff --git a/arrow-array/src/lib.rs b/arrow-array/src/lib.rs index afb7ec5e6e44..ef98c5efefb0 100644 --- a/arrow-array/src/lib.rs +++ b/arrow-array/src/lib.rs @@ -182,8 +182,7 @@ pub use array::*; mod record_batch; pub use record_batch::{ - RecordBatch, RecordBatchIterator, RecordBatchOptions, RecordBatchReader, - RecordBatchWriter, + RecordBatch, RecordBatchIterator, RecordBatchOptions, RecordBatchReader, RecordBatchWriter, }; mod arithmetic; diff --git a/arrow-array/src/numeric.rs b/arrow-array/src/numeric.rs index afc0e2c33010..ad7b3eca1dbc 100644 --- a/arrow-array/src/numeric.rs +++ b/arrow-array/src/numeric.rs @@ -179,8 +179,8 @@ macro_rules! make_numeric_type { 16 => { // same general logic as for 8 lanes, extended to 16 bits let vecidx = i32x16::new( - 1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, - 8192, 16384, 32768, + 1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, + 32768, ); let vecmask = i32x16::splat((mask & 0xFFFF) as i32); @@ -194,21 +194,19 @@ macro_rules! make_numeric_type { let tmp = &mut [0_i16; 32]; let vecidx = i32x16::new( - 1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, - 8192, 16384, 32768, + 1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, + 32768, ); let vecmask = i32x16::splat((mask & 0xFFFF) as i32); let vecmask = (vecidx & vecmask).eq(vecidx); - i16x16::from_cast(vecmask) - .write_to_slice_unaligned(&mut tmp[0..16]); + i16x16::from_cast(vecmask).write_to_slice_unaligned(&mut tmp[0..16]); let vecmask = i32x16::splat(((mask >> 16) & 0xFFFF) as i32); let vecmask = (vecidx & vecmask).eq(vecidx); - i16x16::from_cast(vecmask) - .write_to_slice_unaligned(&mut tmp[16..32]); + i16x16::from_cast(vecmask).write_to_slice_unaligned(&mut tmp[16..32]); unsafe { std::mem::transmute(i16x32::from_slice_unaligned(tmp)) } } @@ -218,33 +216,29 @@ macro_rules! 
make_numeric_type { let tmp = &mut [0_i8; 64]; let vecidx = i32x16::new( - 1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, - 8192, 16384, 32768, + 1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, + 32768, ); let vecmask = i32x16::splat((mask & 0xFFFF) as i32); let vecmask = (vecidx & vecmask).eq(vecidx); - i8x16::from_cast(vecmask) - .write_to_slice_unaligned(&mut tmp[0..16]); + i8x16::from_cast(vecmask).write_to_slice_unaligned(&mut tmp[0..16]); let vecmask = i32x16::splat(((mask >> 16) & 0xFFFF) as i32); let vecmask = (vecidx & vecmask).eq(vecidx); - i8x16::from_cast(vecmask) - .write_to_slice_unaligned(&mut tmp[16..32]); + i8x16::from_cast(vecmask).write_to_slice_unaligned(&mut tmp[16..32]); let vecmask = i32x16::splat(((mask >> 32) & 0xFFFF) as i32); let vecmask = (vecidx & vecmask).eq(vecidx); - i8x16::from_cast(vecmask) - .write_to_slice_unaligned(&mut tmp[32..48]); + i8x16::from_cast(vecmask).write_to_slice_unaligned(&mut tmp[32..48]); let vecmask = i32x16::splat(((mask >> 48) & 0xFFFF) as i32); let vecmask = (vecidx & vecmask).eq(vecidx); - i8x16::from_cast(vecmask) - .write_to_slice_unaligned(&mut tmp[48..64]); + i8x16::from_cast(vecmask).write_to_slice_unaligned(&mut tmp[48..64]); unsafe { std::mem::transmute(i8x64::from_slice_unaligned(tmp)) } } @@ -269,11 +263,7 @@ macro_rules! make_numeric_type { /// Selects elements of `a` and `b` using `mask` #[inline] - fn mask_select( - mask: Self::SimdMask, - a: Self::Simd, - b: Self::Simd, - ) -> Self::Simd { + fn mask_select(mask: Self::SimdMask, a: Self::Simd, b: Self::Simd) -> Self::Simd { mask.select(a, b) } @@ -327,10 +317,7 @@ macro_rules! 
make_numeric_type { } #[inline] - fn unary_op Self::Simd>( - a: Self::Simd, - op: F, - ) -> Self::Simd { + fn unary_op Self::Simd>(a: Self::Simd, op: F) -> Self::Simd { op(a) } } @@ -581,8 +568,7 @@ mod tests { let mask = 0b1101; let actual = IntervalMonthDayNanoType::mask_from_u64(mask); let expected = expected_mask!(i128, mask); - let expected = - m128x4::from_cast(i128x4::from_slice_unaligned(expected.as_slice())); + let expected = m128x4::from_cast(i128x4::from_slice_unaligned(expected.as_slice())); assert_eq!(expected, actual); } @@ -612,8 +598,7 @@ mod tests { let mask = 0b10101010_10101010; let actual = Float32Type::mask_from_u64(mask); let expected = expected_mask!(i32, mask); - let expected = - m32x16::from_cast(i32x16::from_slice_unaligned(expected.as_slice())); + let expected = m32x16::from_cast(i32x16::from_slice_unaligned(expected.as_slice())); assert_eq!(expected, actual); } @@ -623,8 +608,7 @@ mod tests { let mask = 0b01010101_01010101; let actual = Int32Type::mask_from_u64(mask); let expected = expected_mask!(i32, mask); - let expected = - m32x16::from_cast(i32x16::from_slice_unaligned(expected.as_slice())); + let expected = m32x16::from_cast(i32x16::from_slice_unaligned(expected.as_slice())); assert_eq!(expected, actual); } @@ -635,16 +619,14 @@ mod tests { let actual = UInt16Type::mask_from_u64(mask); let expected = expected_mask!(i16, mask); dbg!(&expected); - let expected = - m16x32::from_cast(i16x32::from_slice_unaligned(expected.as_slice())); + let expected = m16x32::from_cast(i16x32::from_slice_unaligned(expected.as_slice())); assert_eq!(expected, actual); } #[test] fn test_mask_i8() { - let mask = - 0b01010101_01010101_10101010_10101010_01010101_01010101_10101010_10101010; + let mask = 0b01010101_01010101_10101010_10101010_01010101_01010101_10101010_10101010; let actual = Int8Type::mask_from_u64(mask); let expected = expected_mask!(i8, mask); let expected = m8x64::from_cast(i8x64::from_slice_unaligned(expected.as_slice())); diff --git 
a/arrow-array/src/record_batch.rs b/arrow-array/src/record_batch.rs index 27804447fba6..1f3e1df847a8 100644 --- a/arrow-array/src/record_batch.rs +++ b/arrow-array/src/record_batch.rs @@ -107,10 +107,7 @@ impl RecordBatch { /// vec![Arc::new(id_array)] /// ).unwrap(); /// ``` - pub fn try_new( - schema: SchemaRef, - columns: Vec, - ) -> Result { + pub fn try_new(schema: SchemaRef, columns: Vec) -> Result { let options = RecordBatchOptions::new(); Self::try_new_impl(schema, columns, &options) } @@ -179,9 +176,7 @@ impl RecordBatch { // check that all columns have the same row count if columns.iter().any(|c| c.len() != row_count) { let err = match options.row_count { - Some(_) => { - "all columns in a record batch must have the specified row count" - } + Some(_) => "all columns in a record batch must have the specified row count", None => "all columns in a record batch must have the same length", }; return Err(ArrowError::InvalidArgumentError(err.to_string())); @@ -190,9 +185,7 @@ impl RecordBatch { // function for comparing column type and field type // return true if 2 types are not matched let type_not_match = if options.match_field_names { - |(_, (col_type, field_type)): &(usize, (&DataType, &DataType))| { - col_type != field_type - } + |(_, (col_type, field_type)): &(usize, (&DataType, &DataType))| col_type != field_type } else { |(_, (col_type, field_type)): &(usize, (&DataType, &DataType))| { !col_type.equals_datatype(field_type) @@ -484,7 +477,11 @@ impl From for RecordBatch { fn from(value: StructArray) -> Self { let row_count = value.len(); let (fields, columns, nulls) = value.into_parts(); - assert_eq!(nulls.map(|n| n.null_count()).unwrap_or_default(), 0, "Cannot convert nullable StructArray to RecordBatch, see StructArray documentation"); + assert_eq!( + nulls.map(|n| n.null_count()).unwrap_or_default(), + 0, + "Cannot convert nullable StructArray to RecordBatch, see StructArray documentation" + ); RecordBatch { schema: Arc::new(Schema::new(fields)), @@ 
-588,9 +585,7 @@ where #[cfg(test)] mod tests { use super::*; - use crate::{ - BooleanArray, Int32Array, Int64Array, Int8Array, ListArray, StringArray, - }; + use crate::{BooleanArray, Int32Array, Int64Array, Int8Array, ListArray, StringArray}; use arrow_buffer::{Buffer, ToByteSlice}; use arrow_data::{ArrayData, ArrayDataBuilder}; use arrow_schema::Fields; @@ -606,8 +601,7 @@ mod tests { let b = StringArray::from(vec!["a", "b", "c", "d", "e"]); let record_batch = - RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]) - .unwrap(); + RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]).unwrap(); check_batch(record_batch, 5) } @@ -622,8 +616,7 @@ mod tests { let b = StringArray::from(vec!["a", "b", "c", "d", "e"]); let record_batch = - RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]) - .unwrap(); + RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]).unwrap(); assert_eq!(record_batch.get_array_memory_size(), 364); } @@ -649,8 +642,7 @@ mod tests { let b = StringArray::from(vec!["a", "b", "c", "d", "e", "f", "h", "i"]); let record_batch = - RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]) - .unwrap(); + RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]).unwrap(); let offset = 2; let length = 5; @@ -699,8 +691,8 @@ mod tests { ])); let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e"])); - let record_batch = RecordBatch::try_from_iter(vec![("a", a), ("b", b)]) - .expect("valid conversion"); + let record_batch = + RecordBatch::try_from_iter(vec![("a", a), ("b", b)]).expect("valid conversion"); let expected_schema = Schema::new(vec![ Field::new("a", DataType::Int32, true), @@ -716,11 +708,9 @@ mod tests { let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e"])); // Note there are no nulls in a or b, but we specify that b is nullable - let record_batch = RecordBatch::try_from_iter_with_nullable(vec![ - ("a", a, 
false), - ("b", b, true), - ]) - .expect("valid conversion"); + let record_batch = + RecordBatch::try_from_iter_with_nullable(vec![("a", a, false), ("b", b, true)]) + .expect("valid conversion"); let expected_schema = Schema::new(vec![ Field::new("a", DataType::Int32, false), @@ -792,8 +782,7 @@ mod tests { let a = Int32Array::from(vec![1, 2, 3, 4, 5]); let b = Int32Array::from(vec![1, 2, 3, 4, 5]); - let batch = - RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]); + let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]); assert!(batch.is_err()); } @@ -863,11 +852,8 @@ mod tests { Field::new("id", DataType::Int32, false), Field::new("val", DataType::Int32, false), ]); - let record_batch = RecordBatch::try_new( - Arc::new(schema1), - vec![id_arr.clone(), val_arr.clone()], - ) - .unwrap(); + let record_batch = + RecordBatch::try_new(Arc::new(schema1), vec![id_arr.clone(), val_arr.clone()]).unwrap(); assert_eq!(record_batch["id"].as_ref(), id_arr.as_ref()); assert_eq!(record_batch["val"].as_ref(), val_arr.as_ref()); @@ -1005,15 +991,12 @@ mod tests { let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c"])); let c: ArrayRef = Arc::new(StringArray::from(vec!["d", "e", "f"])); - let record_batch = RecordBatch::try_from_iter(vec![ - ("a", a.clone()), - ("b", b.clone()), - ("c", c.clone()), - ]) - .expect("valid conversion"); + let record_batch = + RecordBatch::try_from_iter(vec![("a", a.clone()), ("b", b.clone()), ("c", c.clone())]) + .expect("valid conversion"); - let expected = RecordBatch::try_from_iter(vec![("a", a), ("c", c)]) - .expect("valid conversion"); + let expected = + RecordBatch::try_from_iter(vec![("a", a), ("c", c)]).expect("valid conversion"); assert_eq!(expected, record_batch.project(&[0, 2]).unwrap()); } @@ -1049,8 +1032,7 @@ mod tests { let options = RecordBatchOptions::new().with_row_count(Some(10)); - let ok = - RecordBatch::try_new_with_options(schema.clone(), vec![], &options).unwrap(); + 
let ok = RecordBatch::try_new_with_options(schema.clone(), vec![], &options).unwrap(); assert_eq!(ok.num_rows(), 10); let a = ok.slice(2, 5); diff --git a/arrow-array/src/run_iterator.rs b/arrow-array/src/run_iterator.rs index 489aabf4756a..7a98fccb73b5 100644 --- a/arrow-array/src/run_iterator.rs +++ b/arrow-array/src/run_iterator.rs @@ -86,8 +86,7 @@ where // If current logical index is greater than current run end index then increment // the physical index. let run_ends = self.array.run_ends().values(); - if self.current_front_logical >= run_ends[self.current_front_physical].as_usize() - { + if self.current_front_logical >= run_ends[self.current_front_physical].as_usize() { // As the run_ends is expected to be strictly increasing, there // should be at least one logical entry in one physical entry. Because of this // reason the next value can be accessed by incrementing physical index once. @@ -136,8 +135,7 @@ where let run_ends = self.array.run_ends().values(); if self.current_back_physical > 0 - && self.current_back_logical - < run_ends[self.current_back_physical - 1].as_usize() + && self.current_back_logical < run_ends[self.current_back_physical - 1].as_usize() { // As the run_ends is expected to be strictly increasing, there // should be at least one logical entry in one physical entry. Because of this @@ -211,8 +209,7 @@ mod tests { seed.shuffle(&mut rng); } // repeat the items between 1 and 8 times. 
Cap the length for smaller sized arrays - let num = - max_run_length.min(rand::thread_rng().gen_range(1..=max_run_length)); + let num = max_run_length.min(rand::thread_rng().gen_range(1..=max_run_length)); for _ in 0..num { result.push(seed[ix]); } @@ -285,8 +282,7 @@ mod tests { for logical_len in logical_lengths { let input_array = build_input_array(logical_len); - let mut run_array_builder = - PrimitiveRunBuilder::::new(); + let mut run_array_builder = PrimitiveRunBuilder::::new(); run_array_builder.extend(input_array.iter().copied()); let run_array = run_array_builder.finish(); let typed_array = run_array.downcast::().unwrap(); @@ -327,8 +323,7 @@ mod tests { }) .collect(); - let result_asref: Vec> = - result.iter().map(|f| f.as_deref()).collect(); + let result_asref: Vec> = result.iter().map(|f| f.as_deref()).collect(); let expected_vec = vec![ Some("abb"), @@ -364,8 +359,7 @@ mod tests { // Iterate on sliced typed run array let actual: Vec> = sliced_typed_run_array.into_iter().collect(); - let expected: Vec> = - input_array.iter().take(slice_len).copied().collect(); + let expected: Vec> = input_array.iter().take(slice_len).copied().collect(); assert_eq!(expected, actual); // test for offset = total_len - slice_len, length = slice_len diff --git a/arrow-array/src/temporal_conversions.rs b/arrow-array/src/temporal_conversions.rs index f1f3f36d3c61..e0edcc9bc182 100644 --- a/arrow-array/src/temporal_conversions.rs +++ b/arrow-array/src/temporal_conversions.rs @@ -20,9 +20,7 @@ use crate::timezone::Tz; use crate::ArrowPrimitiveType; use arrow_schema::{DataType, TimeUnit}; -use chrono::{ - DateTime, Duration, NaiveDate, NaiveDateTime, NaiveTime, TimeZone, Timelike, Utc, -}; +use chrono::{DateTime, Duration, NaiveDate, NaiveDateTime, NaiveTime, TimeZone, Timelike, Utc}; /// Number of seconds in a day pub const SECONDS_IN_DAY: i64 = 86_400; @@ -221,10 +219,7 @@ pub fn as_datetime(v: i64) -> Option { } /// Converts an [`ArrowPrimitiveType`] to [`DateTime`] -pub fn 
as_datetime_with_timezone( - v: i64, - tz: Tz, -) -> Option> { +pub fn as_datetime_with_timezone(v: i64, tz: Tz) -> Option> { let naive = as_datetime::(v)?; Some(Utc.from_utc_datetime(&naive).with_timezone(&tz)) } @@ -274,8 +269,8 @@ pub fn as_duration(v: i64) -> Option { #[cfg(test)] mod tests { use crate::temporal_conversions::{ - date64_to_datetime, split_second, timestamp_ms_to_datetime, - timestamp_ns_to_datetime, timestamp_us_to_datetime, NANOSECONDS, + date64_to_datetime, split_second, timestamp_ms_to_datetime, timestamp_ns_to_datetime, + timestamp_us_to_datetime, NANOSECONDS, }; use chrono::NaiveDateTime; diff --git a/arrow-array/src/timezone.rs b/arrow-array/src/timezone.rs index f56189c46512..dc91886f34c5 100644 --- a/arrow-array/src/timezone.rs +++ b/arrow-array/src/timezone.rs @@ -38,8 +38,8 @@ fn parse_fixed_offset(tz: &str) -> Option { if values.iter().any(|x| *x > 9) { return None; } - let secs = (values[0] * 10 + values[1]) as i32 * 60 * 60 - + (values[2] * 10 + values[3]) as i32 * 60; + let secs = + (values[0] * 10 + values[1]) as i32 * 60 * 60 + (values[2] * 10 + values[3]) as i32 * 60; match bytes[0] { b'+' => FixedOffset::east_opt(secs), @@ -122,10 +122,7 @@ mod private { }) } - fn offset_from_local_datetime( - &self, - local: &NaiveDateTime, - ) -> LocalResult { + fn offset_from_local_datetime(&self, local: &NaiveDateTime) -> LocalResult { tz!(self, tz, { tz.offset_from_local_datetime(local).map(|x| TzOffset { tz: *self, @@ -285,10 +282,7 @@ mod private { self.0.offset_from_local_date(local).map(TzOffset) } - fn offset_from_local_datetime( - &self, - local: &NaiveDateTime, - ) -> LocalResult { + fn offset_from_local_datetime(&self, local: &NaiveDateTime) -> LocalResult { self.0.offset_from_local_datetime(local).map(TzOffset) } diff --git a/arrow-array/src/types.rs b/arrow-array/src/types.rs index 7988fe9f6690..16d0e822d052 100644 --- a/arrow-array/src/types.rs +++ b/arrow-array/src/types.rs @@ -18,8 +18,7 @@ //! 
Zero-sized types used to parameterize generic array implementations use crate::delta::{ - add_days_datetime, add_months_datetime, shift_months, sub_days_datetime, - sub_months_datetime, + add_days_datetime, add_months_datetime, shift_months, sub_days_datetime, sub_months_datetime, }; use crate::temporal_conversions::as_datetime_with_timezone; use crate::timezone::Tz; @@ -27,9 +26,8 @@ use crate::{ArrowNativeTypeOp, OffsetSizeTrait}; use arrow_buffer::{i256, Buffer, OffsetBuffer}; use arrow_data::decimal::{validate_decimal256_precision, validate_decimal_precision}; use arrow_schema::{ - ArrowError, DataType, IntervalUnit, TimeUnit, DECIMAL128_MAX_PRECISION, - DECIMAL128_MAX_SCALE, DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE, - DECIMAL_DEFAULT_SCALE, + ArrowError, DataType, IntervalUnit, TimeUnit, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE, + DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE, DECIMAL_DEFAULT_SCALE, }; use chrono::{Duration, NaiveDate, NaiveDateTime}; use half::f16; @@ -875,9 +873,7 @@ impl IntervalDayTimeType { /// /// * `i` - The IntervalDayTimeType to convert #[inline] - pub fn to_parts( - i: ::Native, - ) -> (i32, i32) { + pub fn to_parts(i: ::Native) -> (i32, i32) { let days = (i >> 32) as i32; let ms = i as i32; (days, ms) @@ -1221,10 +1217,7 @@ pub trait DecimalType: fn format_decimal(value: Self::Native, precision: u8, scale: i8) -> String; /// Validates that `value` contains no more than `precision` decimal digits - fn validate_decimal_precision( - value: Self::Native, - precision: u8, - ) -> Result<(), ArrowError>; + fn validate_decimal_precision(value: Self::Native, precision: u8) -> Result<(), ArrowError>; } /// Validate that `precision` and `scale` are valid for `T` @@ -1400,10 +1393,7 @@ pub trait ByteArrayType: 'static + Send + Sync + bytes::ByteArrayTypeSealed { const DATA_TYPE: DataType; /// Verifies that every consecutive pair of `offsets` denotes a valid slice of `values` - fn validate( - offsets: &OffsetBuffer, - values: 
&Buffer, - ) -> Result<(), ArrowError>; + fn validate(offsets: &OffsetBuffer, values: &Buffer) -> Result<(), ArrowError>; } /// [`ByteArrayType`] for string arrays @@ -1422,10 +1412,7 @@ impl ByteArrayType for GenericStringType { DataType::Utf8 }; - fn validate( - offsets: &OffsetBuffer, - values: &Buffer, - ) -> Result<(), ArrowError> { + fn validate(offsets: &OffsetBuffer, values: &Buffer) -> Result<(), ArrowError> { // Verify that the slice as a whole is valid UTF-8 let validated = std::str::from_utf8(values).map_err(|e| { ArrowError::InvalidArgumentError(format!("Encountered non UTF-8 data: {e}")) @@ -1471,10 +1458,7 @@ impl ByteArrayType for GenericBinaryType { DataType::Binary }; - fn validate( - offsets: &OffsetBuffer, - values: &Buffer, - ) -> Result<(), ArrowError> { + fn validate(offsets: &OffsetBuffer, values: &Buffer) -> Result<(), ArrowError> { // offsets are guaranteed to be monotonically increasing and non-empty let max_offset = offsets.last().unwrap().as_usize(); if values.len() < max_offset { diff --git a/arrow-avro/src/reader/header.rs b/arrow-avro/src/reader/header.rs index 2d443175a7aa..00e85b39be73 100644 --- a/arrow-avro/src/reader/header.rs +++ b/arrow-avro/src/reader/header.rs @@ -133,9 +133,7 @@ impl HeaderDecoder { let remaining = &MAGIC[MAGIC.len() - self.bytes_remaining..]; let to_decode = buf.len().min(remaining.len()); if !buf.starts_with(&remaining[..to_decode]) { - return Err(ArrowError::ParseError( - "Incorrect avro magic".to_string(), - )); + return Err(ArrowError::ParseError("Incorrect avro magic".to_string())); } self.bytes_remaining -= to_decode; buf = &buf[to_decode..]; diff --git a/arrow-avro/src/reader/mod.rs b/arrow-avro/src/reader/mod.rs index 91e2dbf9835b..7769bbbc4998 100644 --- a/arrow-avro/src/reader/mod.rs +++ b/arrow-avro/src/reader/mod.rs @@ -50,9 +50,7 @@ fn read_header(mut reader: R) -> Result { } /// Return an iterator of [`Block`] from the provided [`BufRead`] -fn read_blocks( - mut reader: R, -) -> impl 
Iterator> { +fn read_blocks(mut reader: R) -> impl Iterator> { let mut decoder = BlockDecoder::default(); let mut try_next = move || { diff --git a/arrow-avro/src/schema.rs b/arrow-avro/src/schema.rs index 839ba65bd5fc..17b82cf861b7 100644 --- a/arrow-avro/src/schema.rs +++ b/arrow-avro/src/schema.rs @@ -335,9 +335,7 @@ mod tests { Field { name: "value", doc: None, - r#type: Schema::TypeName(TypeName::Primitive( - PrimitiveType::Long - )), + r#type: Schema::TypeName(TypeName::Primitive(PrimitiveType::Long)), default: None, }, Field { diff --git a/arrow-buffer/src/bigint/div.rs b/arrow-buffer/src/bigint/div.rs index ba530ffcc6c8..e1b2ed4f8aa5 100644 --- a/arrow-buffer/src/bigint/div.rs +++ b/arrow-buffer/src/bigint/div.rs @@ -26,10 +26,7 @@ /// # Panics /// /// Panics if divisor is zero -pub fn div_rem( - numerator: &[u64; N], - divisor: &[u64; N], -) -> ([u64; N], [u64; N]) { +pub fn div_rem(numerator: &[u64; N], divisor: &[u64; N]) -> ([u64; N], [u64; N]) { let numerator_bits = bits(numerator); let divisor_bits = bits(divisor); assert_ne!(divisor_bits, 0, "division by zero"); @@ -61,10 +58,7 @@ fn bits(arr: &[u64]) -> usize { } /// Division of numerator by a u64 divisor -fn div_rem_small( - numerator: &[u64; N], - divisor: u64, -) -> ([u64; N], [u64; N]) { +fn div_rem_small(numerator: &[u64; N], divisor: u64) -> ([u64; N], [u64; N]) { let mut rem = 0u64; let mut numerator = *numerator; numerator.iter_mut().rev().for_each(|d| { @@ -227,11 +221,7 @@ fn sub_assign(a: &mut [u64], b: &[u64]) -> bool { } /// Converts an overflowing binary operation on scalars to one on slices -fn binop_slice( - a: &mut [u64], - b: &[u64], - binop: impl Fn(u64, u64) -> (u64, bool) + Copy, -) -> bool { +fn binop_slice(a: &mut [u64], b: &[u64], binop: impl Fn(u64, u64) -> (u64, bool) + Copy) -> bool { let mut c = false; a.iter_mut().zip(b.iter()).for_each(|(x, y)| { let (res1, overflow1) = y.overflowing_add(u64::from(c)); diff --git a/arrow-buffer/src/bigint/mod.rs 
b/arrow-buffer/src/bigint/mod.rs index d064663bf63a..afbb3a31df12 100644 --- a/arrow-buffer/src/bigint/mod.rs +++ b/arrow-buffer/src/bigint/mod.rs @@ -310,9 +310,7 @@ impl i256 { (Self::from_le_bytes(bytes), false) } Ordering::Equal => (Self::from_le_bytes(v_bytes.try_into().unwrap()), false), - Ordering::Greater => { - (Self::from_le_bytes(v_bytes[..32].try_into().unwrap()), true) - } + Ordering::Greater => (Self::from_le_bytes(v_bytes[..32].try_into().unwrap()), true), } } @@ -357,8 +355,7 @@ impl i256 { #[inline] pub fn checked_add(self, other: Self) -> Option { let r = self.wrapping_add(other); - ((other.is_negative() && r < self) || (!other.is_negative() && r >= self)) - .then_some(r) + ((other.is_negative() && r < self) || (!other.is_negative() && r >= self)).then_some(r) } /// Performs wrapping subtraction @@ -373,8 +370,7 @@ impl i256 { #[inline] pub fn checked_sub(self, other: Self) -> Option { let r = self.wrapping_sub(other); - ((other.is_negative() && r > self) || (!other.is_negative() && r <= self)) - .then_some(r) + ((other.is_negative() && r > self) || (!other.is_negative() && r <= self)).then_some(r) } /// Performs wrapping multiplication @@ -591,9 +587,7 @@ impl i256 { /// Temporary workaround due to lack of stable const array slicing /// See -const fn split_array( - vals: [u8; N], -) -> ([u8; M], [u8; M]) { +const fn split_array(vals: [u8; N]) -> ([u8; M], [u8; M]) { let mut a = [0; M]; let mut b = [0; M]; let mut i = 0; @@ -915,8 +909,7 @@ mod tests { // Addition let actual = il.wrapping_add(ir); - let (expected, overflow) = - i256::from_bigint_with_overflow(bl.clone() + br.clone()); + let (expected, overflow) = i256::from_bigint_with_overflow(bl.clone() + br.clone()); assert_eq!(actual, expected); let checked = il.checked_add(ir); @@ -927,8 +920,7 @@ mod tests { // Subtraction let actual = il.wrapping_sub(ir); - let (expected, overflow) = - i256::from_bigint_with_overflow(bl.clone() - br.clone()); + let (expected, overflow) = 
i256::from_bigint_with_overflow(bl.clone() - br.clone()); assert_eq!(actual.to_string(), expected.to_string()); let checked = il.checked_sub(ir); @@ -939,8 +931,7 @@ mod tests { // Multiplication let actual = il.wrapping_mul(ir); - let (expected, overflow) = - i256::from_bigint_with_overflow(bl.clone() * br.clone()); + let (expected, overflow) = i256::from_bigint_with_overflow(bl.clone() * br.clone()); assert_eq!(actual.to_string(), expected.to_string()); let checked = il.checked_mul(ir); @@ -996,8 +987,7 @@ mod tests { // Exponentiation for exp in vec![0, 1, 2, 3, 8, 100].into_iter() { let actual = il.wrapping_pow(exp); - let (expected, overflow) = - i256::from_bigint_with_overflow(bl.clone().pow(exp)); + let (expected, overflow) = i256::from_bigint_with_overflow(bl.clone().pow(exp)); assert_eq!(actual.to_string(), expected.to_string()); let checked = il.checked_pow(exp); @@ -1212,7 +1202,10 @@ mod tests { ("000000000000000000000000000000000000000", Some(i256::ZERO)), ("0000000000000000000000000000000000000000-11", None), ("11-1111111111111111111111111111111111111", None), - ("115792089237316195423570985008687907853269984665640564039457584007913129639936", None) + ( + "115792089237316195423570985008687907853269984665640564039457584007913129639936", + None, + ), ]; for (case, expected) in cases { assert_eq!(i256::from_string(case), expected) diff --git a/arrow-buffer/src/buffer/boolean.rs b/arrow-buffer/src/buffer/boolean.rs index 577c716e4bea..c651edcad92e 100644 --- a/arrow-buffer/src/buffer/boolean.rs +++ b/arrow-buffer/src/buffer/boolean.rs @@ -223,13 +223,7 @@ impl BitAnd<&BooleanBuffer> for &BooleanBuffer { fn bitand(self, rhs: &BooleanBuffer) -> Self::Output { assert_eq!(self.len, rhs.len); BooleanBuffer { - buffer: buffer_bin_and( - &self.buffer, - self.offset, - &rhs.buffer, - rhs.offset, - self.len, - ), + buffer: buffer_bin_and(&self.buffer, self.offset, &rhs.buffer, rhs.offset, self.len), offset: 0, len: self.len, } @@ -242,13 +236,7 @@ impl 
BitOr<&BooleanBuffer> for &BooleanBuffer { fn bitor(self, rhs: &BooleanBuffer) -> Self::Output { assert_eq!(self.len, rhs.len); BooleanBuffer { - buffer: buffer_bin_or( - &self.buffer, - self.offset, - &rhs.buffer, - rhs.offset, - self.len, - ), + buffer: buffer_bin_or(&self.buffer, self.offset, &rhs.buffer, rhs.offset, self.len), offset: 0, len: self.len, } @@ -261,13 +249,7 @@ impl BitXor<&BooleanBuffer> for &BooleanBuffer { fn bitxor(self, rhs: &BooleanBuffer) -> Self::Output { assert_eq!(self.len, rhs.len); BooleanBuffer { - buffer: buffer_bin_xor( - &self.buffer, - self.offset, - &rhs.buffer, - rhs.offset, - self.len, - ), + buffer: buffer_bin_xor(&self.buffer, self.offset, &rhs.buffer, rhs.offset, self.len), offset: 0, len: self.len, } @@ -428,8 +410,7 @@ mod tests { let buf = Buffer::from(&[0, 1, 1, 0, 0]); let boolean_buf = &BooleanBuffer::new(buf, offset, len); - let expected = - BooleanBuffer::new(Buffer::from(&[255, 254, 254, 255, 255]), offset, len); + let expected = BooleanBuffer::new(Buffer::from(&[255, 254, 254, 255, 255]), offset, len); assert_eq!(!boolean_buf, expected); } } diff --git a/arrow-buffer/src/buffer/immutable.rs b/arrow-buffer/src/buffer/immutable.rs index bda6dfc5cdee..05530eed9b08 100644 --- a/arrow-buffer/src/buffer/immutable.rs +++ b/arrow-buffer/src/buffer/immutable.rs @@ -523,9 +523,7 @@ mod tests { } #[test] - #[should_panic( - expected = "the offset of the new Buffer cannot exceed the existing length" - )] + #[should_panic(expected = "the offset of the new Buffer cannot exceed the existing length")] fn test_slice_offset_out_of_bound() { let buf = Buffer::from(&[2, 4, 6, 8, 10]); buf.slice(6); @@ -688,9 +686,7 @@ mod tests { } #[test] - #[should_panic( - expected = "the offset of the new Buffer cannot exceed the existing length" - )] + #[should_panic(expected = "the offset of the new Buffer cannot exceed the existing length")] fn slice_overflow() { let buffer = Buffer::from(MutableBuffer::from_len_zeroed(12)); 
buffer.slice_with_length(2, usize::MAX); diff --git a/arrow-buffer/src/buffer/mutable.rs b/arrow-buffer/src/buffer/mutable.rs index 2c56f9a5b270..69c986cc1056 100644 --- a/arrow-buffer/src/buffer/mutable.rs +++ b/arrow-buffer/src/buffer/mutable.rs @@ -334,9 +334,7 @@ impl MutableBuffer { #[inline] pub(super) fn into_buffer(self) -> Buffer { - let bytes = unsafe { - Bytes::new(self.data, self.len, Deallocation::Standard(self.layout)) - }; + let bytes = unsafe { Bytes::new(self.data, self.len, Deallocation::Standard(self.layout)) }; std::mem::forget(self); Buffer::from_bytes(bytes) } @@ -351,8 +349,7 @@ impl MutableBuffer { // SAFETY // ArrowNativeType is trivially transmutable, is sealed to prevent potentially incorrect // implementation outside this crate, and this method checks alignment - let (prefix, offsets, suffix) = - unsafe { self.as_slice_mut().align_to_mut::() }; + let (prefix, offsets, suffix) = unsafe { self.as_slice_mut().align_to_mut::() }; assert!(prefix.is_empty() && suffix.is_empty()); offsets } @@ -604,9 +601,7 @@ impl MutableBuffer { // we can't specialize `extend` for `TrustedLen` like `Vec` does. // 2. `from_trusted_len_iter_bool` is faster. #[inline] - pub unsafe fn from_trusted_len_iter_bool>( - mut iterator: I, - ) -> Self { + pub unsafe fn from_trusted_len_iter_bool>(mut iterator: I) -> Self { let (_, upper) = iterator.size_hint(); let len = upper.expect("from_trusted_len_iter requires an upper limit"); diff --git a/arrow-buffer/src/buffer/null.rs b/arrow-buffer/src/buffer/null.rs index e0c7d9ef8f49..c79aef398059 100644 --- a/arrow-buffer/src/buffer/null.rs +++ b/arrow-buffer/src/buffer/null.rs @@ -71,10 +71,7 @@ impl NullBuffer { /// This is commonly used by binary operations where the result is NULL if either /// of the input values is NULL. 
Handling the null mask separately in this way /// can yield significant performance improvements over an iterator approach - pub fn union( - lhs: Option<&NullBuffer>, - rhs: Option<&NullBuffer>, - ) -> Option { + pub fn union(lhs: Option<&NullBuffer>, rhs: Option<&NullBuffer>) -> Option { match (lhs, rhs) { (Some(lhs), Some(rhs)) => Some(Self::new(lhs.inner() & rhs.inner())), (Some(n), None) | (None, Some(n)) => Some(n.clone()), diff --git a/arrow-buffer/src/buffer/offset.rs b/arrow-buffer/src/buffer/offset.rs index a6f2f7f6cfae..652d30c3b0ab 100644 --- a/arrow-buffer/src/buffer/offset.rs +++ b/arrow-buffer/src/buffer/offset.rs @@ -219,8 +219,7 @@ mod tests { assert_eq!(buffer.as_ref(), &[0, 2, 8, 11, 18, 20]); let half_max = i32::MAX / 2; - let buffer = - OffsetBuffer::::from_lengths([half_max as usize, half_max as usize]); + let buffer = OffsetBuffer::::from_lengths([half_max as usize, half_max as usize]); assert_eq!(buffer.as_ref(), &[0, half_max, half_max * 2]); } diff --git a/arrow-buffer/src/buffer/ops.rs b/arrow-buffer/src/buffer/ops.rs index eccff6280dd8..ca00e41bea21 100644 --- a/arrow-buffer/src/buffer/ops.rs +++ b/arrow-buffer/src/buffer/ops.rs @@ -184,10 +184,6 @@ pub fn buffer_bin_xor( /// Apply a bitwise not to one input and return the result as a Buffer. /// The input is treated as a bitmap, meaning that offset and length are specified in number of bits. 
-pub fn buffer_unary_not( - left: &Buffer, - offset_in_bits: usize, - len_in_bits: usize, -) -> Buffer { +pub fn buffer_unary_not(left: &Buffer, offset_in_bits: usize, len_in_bits: usize) -> Buffer { bitwise_unary_op_helper(left, offset_in_bits, len_in_bits, |a| !a) } diff --git a/arrow-buffer/src/buffer/run.rs b/arrow-buffer/src/buffer/run.rs index 29c0f3dfd949..3dbbe344a025 100644 --- a/arrow-buffer/src/buffer/run.rs +++ b/arrow-buffer/src/buffer/run.rs @@ -110,11 +110,7 @@ where /// /// - `buffer` must contain strictly increasing values greater than zero /// - The last value of `buffer` must be greater than or equal to `offset + len` - pub unsafe fn new_unchecked( - run_ends: ScalarBuffer, - offset: usize, - len: usize, - ) -> Self { + pub unsafe fn new_unchecked(run_ends: ScalarBuffer, offset: usize, len: usize) -> Self { Self { run_ends, offset, diff --git a/arrow-buffer/src/buffer/scalar.rs b/arrow-buffer/src/buffer/scalar.rs index 276e635e825c..3b75d5384046 100644 --- a/arrow-buffer/src/buffer/scalar.rs +++ b/arrow-buffer/src/buffer/scalar.rs @@ -221,9 +221,7 @@ mod tests { } #[test] - #[should_panic( - expected = "Memory pointer is not aligned with the specified scalar type" - )] + #[should_panic(expected = "Memory pointer is not aligned with the specified scalar type")] fn test_unaligned() { let expected = [0_i32, 1, 2]; let buffer = Buffer::from_iter(expected.iter().cloned()); @@ -232,18 +230,14 @@ mod tests { } #[test] - #[should_panic( - expected = "the offset of the new Buffer cannot exceed the existing length" - )] + #[should_panic(expected = "the offset of the new Buffer cannot exceed the existing length")] fn test_length_out_of_bounds() { let buffer = Buffer::from_iter([0_i32, 1, 2]); ScalarBuffer::::new(buffer, 1, 3); } #[test] - #[should_panic( - expected = "the offset of the new Buffer cannot exceed the existing length" - )] + #[should_panic(expected = "the offset of the new Buffer cannot exceed the existing length")] fn 
test_offset_out_of_bounds() { let buffer = Buffer::from_iter([0_i32, 1, 2]); ScalarBuffer::::new(buffer, 4, 0); diff --git a/arrow-buffer/src/builder/boolean.rs b/arrow-buffer/src/builder/boolean.rs index f0e7f0f13670..ca178ae5ce4e 100644 --- a/arrow-buffer/src/builder/boolean.rs +++ b/arrow-buffer/src/builder/boolean.rs @@ -154,14 +154,12 @@ impl BooleanBufferBuilder { if cur_remainder != 0 { // Pad last byte with 1s - *self.buffer.as_slice_mut().last_mut().unwrap() |= - !((1 << cur_remainder) - 1) + *self.buffer.as_slice_mut().last_mut().unwrap() |= !((1 << cur_remainder) - 1) } self.buffer.resize(new_len_bytes, 0xFF); if new_remainder != 0 { // Clear remaining bits - *self.buffer.as_slice_mut().last_mut().unwrap() &= - (1 << new_remainder) - 1 + *self.buffer.as_slice_mut().last_mut().unwrap() &= (1 << new_remainder) - 1 } self.len = new_len; } diff --git a/arrow-buffer/src/bytes.rs b/arrow-buffer/src/bytes.rs index 8f5019d5a4cc..81860b604868 100644 --- a/arrow-buffer/src/bytes.rs +++ b/arrow-buffer/src/bytes.rs @@ -60,11 +60,7 @@ impl Bytes { /// This function is unsafe as there is no guarantee that the given pointer is valid for `len` /// bytes. If the `ptr` and `capacity` come from a `Buffer`, then this is guaranteed. 
#[inline] - pub(crate) unsafe fn new( - ptr: NonNull, - len: usize, - deallocation: Deallocation, - ) -> Bytes { + pub(crate) unsafe fn new(ptr: NonNull, len: usize, deallocation: Deallocation) -> Bytes { Bytes { ptr, len, diff --git a/arrow-buffer/src/util/bit_chunk_iterator.rs b/arrow-buffer/src/util/bit_chunk_iterator.rs index 6830acae94a1..9e4fb8268dff 100644 --- a/arrow-buffer/src/util/bit_chunk_iterator.rs +++ b/arrow-buffer/src/util/bit_chunk_iterator.rs @@ -60,8 +60,7 @@ impl<'a> UnalignedBitChunk<'a> { // If less than 8 bytes, read into prefix if buffer.len() <= 8 { - let (suffix_mask, trailing_padding) = - compute_suffix_mask(len, offset_padding); + let (suffix_mask, trailing_padding) = compute_suffix_mask(len, offset_padding); let prefix = read_u64(buffer) & suffix_mask & prefix_mask; return Self { @@ -75,8 +74,7 @@ impl<'a> UnalignedBitChunk<'a> { // If less than 16 bytes, read into prefix and suffix if buffer.len() <= 16 { - let (suffix_mask, trailing_padding) = - compute_suffix_mask(len, offset_padding); + let (suffix_mask, trailing_padding) = compute_suffix_mask(len, offset_padding); let prefix = read_u64(&buffer[..8]) & prefix_mask; let suffix = read_u64(&buffer[8..]) & suffix_mask; @@ -167,10 +165,7 @@ impl<'a> UnalignedBitChunk<'a> { } pub type UnalignedBitChunkIterator<'a> = std::iter::Chain< - std::iter::Chain< - std::option::IntoIter, - std::iter::Cloned>, - >, + std::iter::Chain, std::iter::Cloned>>, std::option::IntoIter, >; @@ -338,9 +333,8 @@ impl Iterator for BitChunkIterator<'_> { } else { // the constructor ensures that bit_offset is in 0..8 // that means we need to read at most one additional byte to fill in the high bits - let next = unsafe { - std::ptr::read_unaligned(raw_data.add(index + 1) as *const u8) as u64 - }; + let next = + unsafe { std::ptr::read_unaligned(raw_data.add(index + 1) as *const u8) as u64 }; (current >> bit_offset) | (next << (64 - bit_offset)) }; @@ -387,8 +381,8 @@ mod tests { #[test] fn test_iter_unaligned() { 
let input: &[u8] = &[ - 0b00000000, 0b00000001, 0b00000010, 0b00000100, 0b00001000, 0b00010000, - 0b00100000, 0b01000000, 0b11111111, + 0b00000000, 0b00000001, 0b00000010, 0b00000100, 0b00001000, 0b00010000, 0b00100000, + 0b01000000, 0b11111111, ]; let buffer: Buffer = Buffer::from(input); @@ -408,8 +402,8 @@ mod tests { #[test] fn test_iter_unaligned_remainder_1_byte() { let input: &[u8] = &[ - 0b00000000, 0b00000001, 0b00000010, 0b00000100, 0b00001000, 0b00010000, - 0b00100000, 0b01000000, 0b11111111, + 0b00000000, 0b00000001, 0b00000010, 0b00000100, 0b00001000, 0b00010000, 0b00100000, + 0b01000000, 0b11111111, ]; let buffer: Buffer = Buffer::from(input); @@ -442,8 +436,8 @@ mod tests { #[test] fn test_iter_unaligned_remainder_bits_large() { let input: &[u8] = &[ - 0b11111111, 0b00000000, 0b11111111, 0b00000000, 0b11111111, 0b00000000, - 0b11111111, 0b00000000, 0b11111111, + 0b11111111, 0b00000000, 0b11111111, 0b00000000, 0b11111111, 0b00000000, 0b11111111, + 0b00000000, 0b11111111, ]; let buffer: Buffer = Buffer::from(input); @@ -637,11 +631,8 @@ mod tests { let max_truncate = 128.min(mask_len - offset); let truncate = rng.gen::().checked_rem(max_truncate).unwrap_or(0); - let unaligned = UnalignedBitChunk::new( - buffer.as_slice(), - offset, - mask_len - offset - truncate, - ); + let unaligned = + UnalignedBitChunk::new(buffer.as_slice(), offset, mask_len - offset - truncate); let bool_slice = &bools[offset..mask_len - truncate]; diff --git a/arrow-buffer/src/util/bit_iterator.rs b/arrow-buffer/src/util/bit_iterator.rs index 4e24ccdabec0..df40a8fbaccb 100644 --- a/arrow-buffer/src/util/bit_iterator.rs +++ b/arrow-buffer/src/util/bit_iterator.rs @@ -276,8 +276,8 @@ mod tests { assert_eq!( actual, &[ - false, true, false, false, true, false, true, false, false, false, false, - false, true, false + false, true, false, false, true, false, true, false, false, false, false, false, + true, false ] ); diff --git a/arrow-buffer/src/util/bit_mask.rs 
b/arrow-buffer/src/util/bit_mask.rs index 2af24b782632..8f81cb7d0469 100644 --- a/arrow-buffer/src/util/bit_mask.rs +++ b/arrow-buffer/src/util/bit_mask.rs @@ -42,8 +42,7 @@ pub fn set_bits( let chunks = BitChunks::new(data, offset_read + bits_to_align, len - bits_to_align); chunks.iter().for_each(|chunk| { null_count += chunk.count_zeros(); - write_data[write_byte_index..write_byte_index + 8] - .copy_from_slice(&chunk.to_le_bytes()); + write_data[write_byte_index..write_byte_index + 8].copy_from_slice(&chunk.to_le_bytes()); write_byte_index += 8; }); @@ -70,8 +69,8 @@ mod tests { fn test_set_bits_aligned() { let mut destination: Vec = vec![0, 0, 0, 0, 0, 0, 0, 0, 0, 0]; let source: &[u8] = &[ - 0b11100111, 0b10100101, 0b10011001, 0b11011011, 0b11101011, 0b11000011, - 0b11100111, 0b10100101, + 0b11100111, 0b10100101, 0b10011001, 0b11011011, 0b11101011, 0b11000011, 0b11100111, + 0b10100101, ]; let destination_offset = 8; @@ -80,8 +79,8 @@ mod tests { let len = 64; let expected_data: &[u8] = &[ - 0, 0b11100111, 0b10100101, 0b10011001, 0b11011011, 0b11101011, 0b11000011, - 0b11100111, 0b10100101, 0, + 0, 0b11100111, 0b10100101, 0b10011001, 0b11011011, 0b11101011, 0b11000011, 0b11100111, + 0b10100101, 0, ]; let expected_null_count = 24; let result = set_bits( @@ -100,8 +99,8 @@ mod tests { fn test_set_bits_unaligned_destination_start() { let mut destination: Vec = vec![0, 0, 0, 0, 0, 0, 0, 0, 0, 0]; let source: &[u8] = &[ - 0b11100111, 0b10100101, 0b10011001, 0b11011011, 0b11101011, 0b11000011, - 0b11100111, 0b10100101, + 0b11100111, 0b10100101, 0b10011001, 0b11011011, 0b11101011, 0b11000011, 0b11100111, + 0b10100101, ]; let destination_offset = 3; @@ -110,8 +109,8 @@ mod tests { let len = 64; let expected_data: &[u8] = &[ - 0b00111000, 0b00101111, 0b11001101, 0b11011100, 0b01011110, 0b00011111, - 0b00111110, 0b00101111, 0b00000101, 0b00000000, + 0b00111000, 0b00101111, 0b11001101, 0b11011100, 0b01011110, 0b00011111, 0b00111110, + 0b00101111, 0b00000101, 0b00000000, ]; 
let expected_null_count = 24; let result = set_bits( @@ -130,8 +129,8 @@ mod tests { fn test_set_bits_unaligned_destination_end() { let mut destination: Vec = vec![0, 0, 0, 0, 0, 0, 0, 0, 0, 0]; let source: &[u8] = &[ - 0b11100111, 0b10100101, 0b10011001, 0b11011011, 0b11101011, 0b11000011, - 0b11100111, 0b10100101, + 0b11100111, 0b10100101, 0b10011001, 0b11011011, 0b11101011, 0b11000011, 0b11100111, + 0b10100101, ]; let destination_offset = 8; @@ -140,8 +139,8 @@ mod tests { let len = 62; let expected_data: &[u8] = &[ - 0, 0b11100111, 0b10100101, 0b10011001, 0b11011011, 0b11101011, 0b11000011, - 0b11100111, 0b00100101, 0, + 0, 0b11100111, 0b10100101, 0b10011001, 0b11011011, 0b11101011, 0b11000011, 0b11100111, + 0b00100101, 0, ]; let expected_null_count = 23; let result = set_bits( @@ -160,9 +159,9 @@ mod tests { fn test_set_bits_unaligned() { let mut destination: Vec = vec![0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]; let source: &[u8] = &[ - 0b11100111, 0b10100101, 0b10011001, 0b11011011, 0b11101011, 0b11000011, - 0b11100111, 0b10100101, 0b10011001, 0b11011011, 0b11101011, 0b11000011, - 0b11100111, 0b10100101, 0b10011001, 0b11011011, 0b11101011, 0b11000011, + 0b11100111, 0b10100101, 0b10011001, 0b11011011, 0b11101011, 0b11000011, 0b11100111, + 0b10100101, 0b10011001, 0b11011011, 0b11101011, 0b11000011, 0b11100111, 0b10100101, + 0b10011001, 0b11011011, 0b11101011, 0b11000011, ]; let destination_offset = 3; @@ -171,9 +170,8 @@ mod tests { let len = 95; let expected_data: &[u8] = &[ - 0b01111000, 0b01101001, 0b11100110, 0b11110110, 0b11111010, 0b11110000, - 0b01111001, 0b01101001, 0b11100110, 0b11110110, 0b11111010, 0b11110000, - 0b00000001, + 0b01111000, 0b01101001, 0b11100110, 0b11110110, 0b11111010, 0b11110000, 0b01111001, + 0b01101001, 0b11100110, 0b11110110, 0b11111010, 0b11110000, 0b00000001, ]; let expected_null_count = 35; let result = set_bits( diff --git a/arrow-cast/src/cast.rs b/arrow-cast/src/cast.rs index 54c500f1ac41..97307f076f34 100644 --- 
a/arrow-cast/src/cast.rs +++ b/arrow-cast/src/cast.rs @@ -46,9 +46,7 @@ use crate::parse::{ parse_interval_day_time, parse_interval_month_day_nano, parse_interval_year_month, string_to_datetime, Parser, }; -use arrow_array::{ - builder::*, cast::*, temporal_conversions::*, timezone::Tz, types::*, *, -}; +use arrow_array::{builder::*, cast::*, temporal_conversions::*, timezone::Tz, types::*, *}; use arrow_buffer::{i256, ArrowNativeType, Buffer, OffsetBuffer}; use arrow_data::ArrayData; use arrow_schema::*; @@ -365,9 +363,10 @@ where if cast_options.safe { array .unary_opt::<_, Decimal128Type>(|v| { - (mul * v.as_()).round().to_i128().filter(|v| { - Decimal128Type::validate_decimal_precision(*v, precision).is_ok() - }) + (mul * v.as_()) + .round() + .to_i128() + .filter(|v| Decimal128Type::validate_decimal_precision(*v, precision).is_ok()) }) .with_precision_and_scale(precision, scale) .map(|a| Arc::new(a) as ArrayRef) @@ -387,8 +386,7 @@ where )) }) .and_then(|v| { - Decimal128Type::validate_decimal_precision(v, precision) - .map(|_| v) + Decimal128Type::validate_decimal_precision(v, precision).map(|_| v) }) })? .with_precision_and_scale(precision, scale) @@ -410,9 +408,8 @@ where if cast_options.safe { array .unary_opt::<_, Decimal256Type>(|v| { - i256::from_f64((v.as_() * mul).round()).filter(|v| { - Decimal256Type::validate_decimal_precision(*v, precision).is_ok() - }) + i256::from_f64((v.as_() * mul).round()) + .filter(|v| Decimal256Type::validate_decimal_precision(*v, precision).is_ok()) }) .with_precision_and_scale(precision, scale) .map(|a| Arc::new(a) as ArrayRef) @@ -430,8 +427,7 @@ where )) }) .and_then(|v| { - Decimal256Type::validate_decimal_precision(v, precision) - .map(|_| v) + Decimal256Type::validate_decimal_precision(v, precision).map(|_| v) }) })? 
.with_precision_and_scale(precision, scale) @@ -493,7 +489,10 @@ fn cast_month_day_nano_to_duration>( .map(|v| { v.map(|v| match v >> 64 { 0 => Ok((v as i64) / scale), - _ => Err(ArrowError::ComputeError("Cannot convert interval containing non-zero months or days to duration".to_string())) + _ => Err(ArrowError::ComputeError( + "Cannot convert interval containing non-zero months or days to duration" + .to_string(), + )), }) .transpose() }) @@ -559,10 +558,7 @@ fn cast_duration_to_interval>( } /// Cast the primitive array using [`PrimitiveArray::reinterpret_cast`] -fn cast_reinterpret_arrays< - I: ArrowPrimitiveType, - O: ArrowPrimitiveType, ->( +fn cast_reinterpret_arrays>( array: &dyn Array, ) -> Result { Ok(Arc::new(array.as_primitive::().reinterpret_cast::())) @@ -613,14 +609,13 @@ where } else { let v = array.value(i).div_checked(div)?; - let value = - ::from::(v).ok_or_else(|| { - ArrowError::CastError(format!( - "value of {:?} is out of range {}", - v, - T::DATA_TYPE - )) - })?; + let value = ::from::(v).ok_or_else(|| { + ArrowError::CastError(format!( + "value of {:?} is out of range {}", + v, + T::DATA_TYPE + )) + })?; value_builder.append_value(value); } @@ -780,9 +775,7 @@ pub fn cast_with_options( "Casting from type {from_type:?} to dictionary type {to_type:?} not supported", ))), }, - (List(_), List(ref to)) => { - cast_list_inner::(array, to, to_type, cast_options) - } + (List(_), List(ref to)) => cast_list_inner::(array, to, to_type, cast_options), (LargeList(_), LargeList(ref to)) => { cast_list_inner::(array, to, to_type, cast_options) } @@ -919,16 +912,12 @@ pub fn cast_with_options( *scale, cast_options, ), - Float32 => { - cast_decimal_to_float::(array, |x| { - (x as f64 / 10_f64.powi(*scale as i32)) as f32 - }) - } - Float64 => { - cast_decimal_to_float::(array, |x| { - x as f64 / 10_f64.powi(*scale as i32) - }) - } + Float32 => cast_decimal_to_float::(array, |x| { + (x as f64 / 10_f64.powi(*scale as i32)) as f32 + }), + Float64 => 
cast_decimal_to_float::(array, |x| { + x as f64 / 10_f64.powi(*scale as i32) + }), Utf8 => value_to_string::(array, cast_options), LargeUtf8 => value_to_string::(array, cast_options), Null => Ok(new_null_array(to_type, array.len())), @@ -988,16 +977,12 @@ pub fn cast_with_options( *scale, cast_options, ), - Float32 => { - cast_decimal_to_float::(array, |x| { - (x.to_f64().unwrap() / 10_f64.powi(*scale as i32)) as f32 - }) - } - Float64 => { - cast_decimal_to_float::(array, |x| { - x.to_f64().unwrap() / 10_f64.powi(*scale as i32) - }) - } + Float32 => cast_decimal_to_float::(array, |x| { + (x.to_f64().unwrap() / 10_f64.powi(*scale as i32)) as f32 + }), + Float64 => cast_decimal_to_float::(array, |x| { + x.to_f64().unwrap() / 10_f64.powi(*scale as i32) + }), Utf8 => value_to_string::(array, cast_options), LargeUtf8 => value_to_string::(array, cast_options), Null => Ok(new_null_array(to_type, array.len())), @@ -1239,25 +1224,35 @@ pub fn cast_with_options( Float64 => parse_string::(array, cast_options), Date32 => parse_string::(array, cast_options), Date64 => parse_string::(array, cast_options), - Binary => Ok(Arc::new(BinaryArray::from(array.as_string::().clone()))), + Binary => Ok(Arc::new(BinaryArray::from( + array.as_string::().clone(), + ))), LargeBinary => { let binary = BinaryArray::from(array.as_string::().clone()); cast_byte_container::(&binary) } LargeUtf8 => cast_byte_container::(array), Time32(TimeUnit::Second) => parse_string::(array, cast_options), - Time32(TimeUnit::Millisecond) => parse_string::(array, cast_options), - Time64(TimeUnit::Microsecond) => parse_string::(array, cast_options), - Time64(TimeUnit::Nanosecond) => parse_string::(array, cast_options), - Timestamp(TimeUnit::Second, to_tz) => { - cast_string_to_timestamp::(array, to_tz, cast_options) + Time32(TimeUnit::Millisecond) => { + parse_string::(array, cast_options) } - Timestamp(TimeUnit::Millisecond, to_tz) => { - cast_string_to_timestamp::(array, to_tz, cast_options) + 
Time64(TimeUnit::Microsecond) => { + parse_string::(array, cast_options) } - Timestamp(TimeUnit::Microsecond, to_tz) => { - cast_string_to_timestamp::(array, to_tz, cast_options) + Time64(TimeUnit::Nanosecond) => { + parse_string::(array, cast_options) } + Timestamp(TimeUnit::Second, to_tz) => { + cast_string_to_timestamp::(array, to_tz, cast_options) + } + Timestamp(TimeUnit::Millisecond, to_tz) => cast_string_to_timestamp::< + i32, + TimestampMillisecondType, + >(array, to_tz, cast_options), + Timestamp(TimeUnit::Microsecond, to_tz) => cast_string_to_timestamp::< + i32, + TimestampMicrosecondType, + >(array, to_tz, cast_options), Timestamp(TimeUnit::Nanosecond, to_tz) => { cast_string_to_timestamp::(array, to_tz, cast_options) } @@ -1289,26 +1284,33 @@ pub fn cast_with_options( Date64 => parse_string::(array, cast_options), Utf8 => cast_byte_container::(array), Binary => { - let large_binary = - LargeBinaryArray::from(array.as_string::().clone()); + let large_binary = LargeBinaryArray::from(array.as_string::().clone()); cast_byte_container::(&large_binary) } LargeBinary => Ok(Arc::new(LargeBinaryArray::from( array.as_string::().clone(), ))), Time32(TimeUnit::Second) => parse_string::(array, cast_options), - Time32(TimeUnit::Millisecond) => parse_string::(array, cast_options), - Time64(TimeUnit::Microsecond) => parse_string::(array, cast_options), - Time64(TimeUnit::Nanosecond) => parse_string::(array, cast_options), - Timestamp(TimeUnit::Second, to_tz) => { - cast_string_to_timestamp::(array, to_tz, cast_options) + Time32(TimeUnit::Millisecond) => { + parse_string::(array, cast_options) } - Timestamp(TimeUnit::Millisecond, to_tz) => { - cast_string_to_timestamp::(array, to_tz, cast_options) + Time64(TimeUnit::Microsecond) => { + parse_string::(array, cast_options) } - Timestamp(TimeUnit::Microsecond, to_tz) => { - cast_string_to_timestamp::(array, to_tz, cast_options) + Time64(TimeUnit::Nanosecond) => { + parse_string::(array, cast_options) } + 
Timestamp(TimeUnit::Second, to_tz) => { + cast_string_to_timestamp::(array, to_tz, cast_options) + } + Timestamp(TimeUnit::Millisecond, to_tz) => cast_string_to_timestamp::< + i64, + TimestampMillisecondType, + >(array, to_tz, cast_options), + Timestamp(TimeUnit::Microsecond, to_tz) => cast_string_to_timestamp::< + i64, + TimestampMicrosecondType, + >(array, to_tz, cast_options), Timestamp(TimeUnit::Nanosecond, to_tz) => { cast_string_to_timestamp::(array, to_tz, cast_options) } @@ -1331,9 +1333,7 @@ pub fn cast_with_options( let array = cast_binary_to_string::(array, cast_options)?; cast_byte_container::(array.as_ref()) } - LargeBinary => { - cast_byte_container::(array) - } + LargeBinary => cast_byte_container::(array), FixedSizeBinary(size) => { cast_binary_to_fixed_size_binary::(array, *size, cast_options) } @@ -1357,278 +1357,117 @@ pub fn cast_with_options( }, (FixedSizeBinary(size), _) => match to_type { Binary => cast_fixed_size_binary_to_binary::(array, *size), - LargeBinary => - cast_fixed_size_binary_to_binary::(array, *size), + LargeBinary => cast_fixed_size_binary_to_binary::(array, *size), _ => Err(ArrowError::CastError(format!( "Casting from {from_type:?} to {to_type:?} not supported", ))), }, - (from_type, LargeUtf8) if from_type.is_primitive() => value_to_string::(array, cast_options), - (from_type, Utf8) if from_type.is_primitive() => value_to_string::(array, cast_options), - // start numeric casts - (UInt8, UInt16) => { - cast_numeric_arrays::(array, cast_options) - } - (UInt8, UInt32) => { - cast_numeric_arrays::(array, cast_options) + (from_type, LargeUtf8) if from_type.is_primitive() => { + value_to_string::(array, cast_options) } - (UInt8, UInt64) => { - cast_numeric_arrays::(array, cast_options) + (from_type, Utf8) if from_type.is_primitive() => { + value_to_string::(array, cast_options) } + // start numeric casts + (UInt8, UInt16) => cast_numeric_arrays::(array, cast_options), + (UInt8, UInt32) => cast_numeric_arrays::(array, cast_options), 
+ (UInt8, UInt64) => cast_numeric_arrays::(array, cast_options), (UInt8, Int8) => cast_numeric_arrays::(array, cast_options), - (UInt8, Int16) => { - cast_numeric_arrays::(array, cast_options) - } - (UInt8, Int32) => { - cast_numeric_arrays::(array, cast_options) - } - (UInt8, Int64) => { - cast_numeric_arrays::(array, cast_options) - } - (UInt8, Float32) => { - cast_numeric_arrays::(array, cast_options) - } - (UInt8, Float64) => { - cast_numeric_arrays::(array, cast_options) - } - - (UInt16, UInt8) => { - cast_numeric_arrays::(array, cast_options) - } - (UInt16, UInt32) => { - cast_numeric_arrays::(array, cast_options) - } - (UInt16, UInt64) => { - cast_numeric_arrays::(array, cast_options) - } - (UInt16, Int8) => { - cast_numeric_arrays::(array, cast_options) - } - (UInt16, Int16) => { - cast_numeric_arrays::(array, cast_options) - } - (UInt16, Int32) => { - cast_numeric_arrays::(array, cast_options) - } - (UInt16, Int64) => { - cast_numeric_arrays::(array, cast_options) - } - (UInt16, Float32) => { - cast_numeric_arrays::(array, cast_options) - } - (UInt16, Float64) => { - cast_numeric_arrays::(array, cast_options) - } - - (UInt32, UInt8) => { - cast_numeric_arrays::(array, cast_options) - } - (UInt32, UInt16) => { - cast_numeric_arrays::(array, cast_options) - } - (UInt32, UInt64) => { - cast_numeric_arrays::(array, cast_options) - } - (UInt32, Int8) => { - cast_numeric_arrays::(array, cast_options) - } - (UInt32, Int16) => { - cast_numeric_arrays::(array, cast_options) - } - (UInt32, Int32) => { - cast_numeric_arrays::(array, cast_options) - } - (UInt32, Int64) => { - cast_numeric_arrays::(array, cast_options) - } - (UInt32, Float32) => { - cast_numeric_arrays::(array, cast_options) - } - (UInt32, Float64) => { - cast_numeric_arrays::(array, cast_options) - } - - (UInt64, UInt8) => { - cast_numeric_arrays::(array, cast_options) - } - (UInt64, UInt16) => { - cast_numeric_arrays::(array, cast_options) - } - (UInt64, UInt32) => { - cast_numeric_arrays::(array, 
cast_options) - } - (UInt64, Int8) => { - cast_numeric_arrays::(array, cast_options) - } - (UInt64, Int16) => { - cast_numeric_arrays::(array, cast_options) - } - (UInt64, Int32) => { - cast_numeric_arrays::(array, cast_options) - } - (UInt64, Int64) => { - cast_numeric_arrays::(array, cast_options) - } - (UInt64, Float32) => { - cast_numeric_arrays::(array, cast_options) - } - (UInt64, Float64) => { - cast_numeric_arrays::(array, cast_options) - } + (UInt8, Int16) => cast_numeric_arrays::(array, cast_options), + (UInt8, Int32) => cast_numeric_arrays::(array, cast_options), + (UInt8, Int64) => cast_numeric_arrays::(array, cast_options), + (UInt8, Float32) => cast_numeric_arrays::(array, cast_options), + (UInt8, Float64) => cast_numeric_arrays::(array, cast_options), + + (UInt16, UInt8) => cast_numeric_arrays::(array, cast_options), + (UInt16, UInt32) => cast_numeric_arrays::(array, cast_options), + (UInt16, UInt64) => cast_numeric_arrays::(array, cast_options), + (UInt16, Int8) => cast_numeric_arrays::(array, cast_options), + (UInt16, Int16) => cast_numeric_arrays::(array, cast_options), + (UInt16, Int32) => cast_numeric_arrays::(array, cast_options), + (UInt16, Int64) => cast_numeric_arrays::(array, cast_options), + (UInt16, Float32) => cast_numeric_arrays::(array, cast_options), + (UInt16, Float64) => cast_numeric_arrays::(array, cast_options), + + (UInt32, UInt8) => cast_numeric_arrays::(array, cast_options), + (UInt32, UInt16) => cast_numeric_arrays::(array, cast_options), + (UInt32, UInt64) => cast_numeric_arrays::(array, cast_options), + (UInt32, Int8) => cast_numeric_arrays::(array, cast_options), + (UInt32, Int16) => cast_numeric_arrays::(array, cast_options), + (UInt32, Int32) => cast_numeric_arrays::(array, cast_options), + (UInt32, Int64) => cast_numeric_arrays::(array, cast_options), + (UInt32, Float32) => cast_numeric_arrays::(array, cast_options), + (UInt32, Float64) => cast_numeric_arrays::(array, cast_options), + + (UInt64, UInt8) => 
cast_numeric_arrays::(array, cast_options), + (UInt64, UInt16) => cast_numeric_arrays::(array, cast_options), + (UInt64, UInt32) => cast_numeric_arrays::(array, cast_options), + (UInt64, Int8) => cast_numeric_arrays::(array, cast_options), + (UInt64, Int16) => cast_numeric_arrays::(array, cast_options), + (UInt64, Int32) => cast_numeric_arrays::(array, cast_options), + (UInt64, Int64) => cast_numeric_arrays::(array, cast_options), + (UInt64, Float32) => cast_numeric_arrays::(array, cast_options), + (UInt64, Float64) => cast_numeric_arrays::(array, cast_options), (Int8, UInt8) => cast_numeric_arrays::(array, cast_options), - (Int8, UInt16) => { - cast_numeric_arrays::(array, cast_options) - } - (Int8, UInt32) => { - cast_numeric_arrays::(array, cast_options) - } - (Int8, UInt64) => { - cast_numeric_arrays::(array, cast_options) - } + (Int8, UInt16) => cast_numeric_arrays::(array, cast_options), + (Int8, UInt32) => cast_numeric_arrays::(array, cast_options), + (Int8, UInt64) => cast_numeric_arrays::(array, cast_options), (Int8, Int16) => cast_numeric_arrays::(array, cast_options), (Int8, Int32) => cast_numeric_arrays::(array, cast_options), (Int8, Int64) => cast_numeric_arrays::(array, cast_options), - (Int8, Float32) => { - cast_numeric_arrays::(array, cast_options) - } - (Int8, Float64) => { - cast_numeric_arrays::(array, cast_options) - } + (Int8, Float32) => cast_numeric_arrays::(array, cast_options), + (Int8, Float64) => cast_numeric_arrays::(array, cast_options), - (Int16, UInt8) => { - cast_numeric_arrays::(array, cast_options) - } - (Int16, UInt16) => { - cast_numeric_arrays::(array, cast_options) - } - (Int16, UInt32) => { - cast_numeric_arrays::(array, cast_options) - } - (Int16, UInt64) => { - cast_numeric_arrays::(array, cast_options) - } + (Int16, UInt8) => cast_numeric_arrays::(array, cast_options), + (Int16, UInt16) => cast_numeric_arrays::(array, cast_options), + (Int16, UInt32) => cast_numeric_arrays::(array, cast_options), + (Int16, UInt64) => 
cast_numeric_arrays::(array, cast_options), (Int16, Int8) => cast_numeric_arrays::(array, cast_options), - (Int16, Int32) => { - cast_numeric_arrays::(array, cast_options) - } - (Int16, Int64) => { - cast_numeric_arrays::(array, cast_options) - } - (Int16, Float32) => { - cast_numeric_arrays::(array, cast_options) - } - (Int16, Float64) => { - cast_numeric_arrays::(array, cast_options) - } - - (Int32, UInt8) => { - cast_numeric_arrays::(array, cast_options) - } - (Int32, UInt16) => { - cast_numeric_arrays::(array, cast_options) - } - (Int32, UInt32) => { - cast_numeric_arrays::(array, cast_options) - } - (Int32, UInt64) => { - cast_numeric_arrays::(array, cast_options) - } + (Int16, Int32) => cast_numeric_arrays::(array, cast_options), + (Int16, Int64) => cast_numeric_arrays::(array, cast_options), + (Int16, Float32) => cast_numeric_arrays::(array, cast_options), + (Int16, Float64) => cast_numeric_arrays::(array, cast_options), + + (Int32, UInt8) => cast_numeric_arrays::(array, cast_options), + (Int32, UInt16) => cast_numeric_arrays::(array, cast_options), + (Int32, UInt32) => cast_numeric_arrays::(array, cast_options), + (Int32, UInt64) => cast_numeric_arrays::(array, cast_options), (Int32, Int8) => cast_numeric_arrays::(array, cast_options), - (Int32, Int16) => { - cast_numeric_arrays::(array, cast_options) - } - (Int32, Int64) => { - cast_numeric_arrays::(array, cast_options) - } - (Int32, Float32) => { - cast_numeric_arrays::(array, cast_options) - } - (Int32, Float64) => { - cast_numeric_arrays::(array, cast_options) - } - - (Int64, UInt8) => { - cast_numeric_arrays::(array, cast_options) - } - (Int64, UInt16) => { - cast_numeric_arrays::(array, cast_options) - } - (Int64, UInt32) => { - cast_numeric_arrays::(array, cast_options) - } - (Int64, UInt64) => { - cast_numeric_arrays::(array, cast_options) - } + (Int32, Int16) => cast_numeric_arrays::(array, cast_options), + (Int32, Int64) => cast_numeric_arrays::(array, cast_options), + (Int32, Float32) => 
cast_numeric_arrays::(array, cast_options), + (Int32, Float64) => cast_numeric_arrays::(array, cast_options), + + (Int64, UInt8) => cast_numeric_arrays::(array, cast_options), + (Int64, UInt16) => cast_numeric_arrays::(array, cast_options), + (Int64, UInt32) => cast_numeric_arrays::(array, cast_options), + (Int64, UInt64) => cast_numeric_arrays::(array, cast_options), (Int64, Int8) => cast_numeric_arrays::(array, cast_options), - (Int64, Int16) => { - cast_numeric_arrays::(array, cast_options) - } - (Int64, Int32) => { - cast_numeric_arrays::(array, cast_options) - } - (Int64, Float32) => { - cast_numeric_arrays::(array, cast_options) - } - (Int64, Float64) => { - cast_numeric_arrays::(array, cast_options) - } - - (Float32, UInt8) => { - cast_numeric_arrays::(array, cast_options) - } - (Float32, UInt16) => { - cast_numeric_arrays::(array, cast_options) - } - (Float32, UInt32) => { - cast_numeric_arrays::(array, cast_options) - } - (Float32, UInt64) => { - cast_numeric_arrays::(array, cast_options) - } - (Float32, Int8) => { - cast_numeric_arrays::(array, cast_options) - } - (Float32, Int16) => { - cast_numeric_arrays::(array, cast_options) - } - (Float32, Int32) => { - cast_numeric_arrays::(array, cast_options) - } - (Float32, Int64) => { - cast_numeric_arrays::(array, cast_options) - } - (Float32, Float64) => { - cast_numeric_arrays::(array, cast_options) - } - - (Float64, UInt8) => { - cast_numeric_arrays::(array, cast_options) - } - (Float64, UInt16) => { - cast_numeric_arrays::(array, cast_options) - } - (Float64, UInt32) => { - cast_numeric_arrays::(array, cast_options) - } - (Float64, UInt64) => { - cast_numeric_arrays::(array, cast_options) - } - (Float64, Int8) => { - cast_numeric_arrays::(array, cast_options) - } - (Float64, Int16) => { - cast_numeric_arrays::(array, cast_options) - } - (Float64, Int32) => { - cast_numeric_arrays::(array, cast_options) - } - (Float64, Int64) => { - cast_numeric_arrays::(array, cast_options) - } - (Float64, Float32) => { - 
cast_numeric_arrays::(array, cast_options) - } + (Int64, Int16) => cast_numeric_arrays::(array, cast_options), + (Int64, Int32) => cast_numeric_arrays::(array, cast_options), + (Int64, Float32) => cast_numeric_arrays::(array, cast_options), + (Int64, Float64) => cast_numeric_arrays::(array, cast_options), + + (Float32, UInt8) => cast_numeric_arrays::(array, cast_options), + (Float32, UInt16) => cast_numeric_arrays::(array, cast_options), + (Float32, UInt32) => cast_numeric_arrays::(array, cast_options), + (Float32, UInt64) => cast_numeric_arrays::(array, cast_options), + (Float32, Int8) => cast_numeric_arrays::(array, cast_options), + (Float32, Int16) => cast_numeric_arrays::(array, cast_options), + (Float32, Int32) => cast_numeric_arrays::(array, cast_options), + (Float32, Int64) => cast_numeric_arrays::(array, cast_options), + (Float32, Float64) => cast_numeric_arrays::(array, cast_options), + + (Float64, UInt8) => cast_numeric_arrays::(array, cast_options), + (Float64, UInt16) => cast_numeric_arrays::(array, cast_options), + (Float64, UInt32) => cast_numeric_arrays::(array, cast_options), + (Float64, UInt64) => cast_numeric_arrays::(array, cast_options), + (Float64, Int8) => cast_numeric_arrays::(array, cast_options), + (Float64, Int16) => cast_numeric_arrays::(array, cast_options), + (Float64, Int32) => cast_numeric_arrays::(array, cast_options), + (Float64, Int64) => cast_numeric_arrays::(array, cast_options), + (Float64, Float32) => cast_numeric_arrays::(array, cast_options), // end numeric casts // temporal casts @@ -1684,71 +1523,77 @@ pub fn cast_with_options( cast_reinterpret_arrays::(array) } (Date32, Date64) => Ok(Arc::new( - array.as_primitive::() + array + .as_primitive::() .unary::<_, Date64Type>(|x| x as i64 * MILLISECONDS_IN_DAY), )), (Date64, Date32) => Ok(Arc::new( - array.as_primitive::() + array + .as_primitive::() .unary::<_, Date32Type>(|x| (x / MILLISECONDS_IN_DAY) as i32), )), (Time32(TimeUnit::Second), Time32(TimeUnit::Millisecond)) => 
Ok(Arc::new( - array.as_primitive::() + array + .as_primitive::() .unary::<_, Time32MillisecondType>(|x| x * MILLISECONDS as i32), )), (Time32(TimeUnit::Second), Time64(TimeUnit::Microsecond)) => Ok(Arc::new( - array.as_primitive::() + array + .as_primitive::() .unary::<_, Time64MicrosecondType>(|x| x as i64 * MICROSECONDS), )), (Time32(TimeUnit::Second), Time64(TimeUnit::Nanosecond)) => Ok(Arc::new( - array.as_primitive::() + array + .as_primitive::() .unary::<_, Time64NanosecondType>(|x| x as i64 * NANOSECONDS), )), (Time32(TimeUnit::Millisecond), Time32(TimeUnit::Second)) => Ok(Arc::new( - array.as_primitive::() + array + .as_primitive::() .unary::<_, Time32SecondType>(|x| x / MILLISECONDS as i32), )), (Time32(TimeUnit::Millisecond), Time64(TimeUnit::Microsecond)) => Ok(Arc::new( - array.as_primitive::() - .unary::<_, Time64MicrosecondType>(|x| { - x as i64 * (MICROSECONDS / MILLISECONDS) - }), + array + .as_primitive::() + .unary::<_, Time64MicrosecondType>(|x| x as i64 * (MICROSECONDS / MILLISECONDS)), )), (Time32(TimeUnit::Millisecond), Time64(TimeUnit::Nanosecond)) => Ok(Arc::new( - array.as_primitive::() - .unary::<_, Time64NanosecondType>(|x| { - x as i64 * (MICROSECONDS / NANOSECONDS) - }), + array + .as_primitive::() + .unary::<_, Time64NanosecondType>(|x| x as i64 * (MICROSECONDS / NANOSECONDS)), )), (Time64(TimeUnit::Microsecond), Time32(TimeUnit::Second)) => Ok(Arc::new( - array.as_primitive::() + array + .as_primitive::() .unary::<_, Time32SecondType>(|x| (x / MICROSECONDS) as i32), )), (Time64(TimeUnit::Microsecond), Time32(TimeUnit::Millisecond)) => Ok(Arc::new( - array.as_primitive::() - .unary::<_, Time32MillisecondType>(|x| { - (x / (MICROSECONDS / MILLISECONDS)) as i32 - }), + array + .as_primitive::() + .unary::<_, Time32MillisecondType>(|x| (x / (MICROSECONDS / MILLISECONDS)) as i32), )), (Time64(TimeUnit::Microsecond), Time64(TimeUnit::Nanosecond)) => Ok(Arc::new( - array.as_primitive::() + array + .as_primitive::() .unary::<_, 
Time64NanosecondType>(|x| x * (NANOSECONDS / MICROSECONDS)), )), (Time64(TimeUnit::Nanosecond), Time32(TimeUnit::Second)) => Ok(Arc::new( - array.as_primitive::() + array + .as_primitive::() .unary::<_, Time32SecondType>(|x| (x / NANOSECONDS) as i32), )), (Time64(TimeUnit::Nanosecond), Time32(TimeUnit::Millisecond)) => Ok(Arc::new( - array.as_primitive::() - .unary::<_, Time32MillisecondType>(|x| { - (x / (NANOSECONDS / MILLISECONDS)) as i32 - }), + array + .as_primitive::() + .unary::<_, Time32MillisecondType>(|x| (x / (NANOSECONDS / MILLISECONDS)) as i32), )), (Time64(TimeUnit::Nanosecond), Time64(TimeUnit::Microsecond)) => Ok(Arc::new( - array.as_primitive::() + array + .as_primitive::() .unary::<_, Time64MicrosecondType>(|x| x / (NANOSECONDS / MICROSECONDS)), )), @@ -1803,39 +1648,29 @@ pub fn cast_with_options( (None, Some(to_tz)) => { let to_tz: Tz = to_tz.parse()?; match to_unit { - TimeUnit::Second => { - adjust_timestamp_to_timezone::( - converted, - &to_tz, - cast_options, - )? - } - TimeUnit::Millisecond => { - adjust_timestamp_to_timezone::( - converted, - &to_tz, - cast_options, - )? - } - TimeUnit::Microsecond => { - adjust_timestamp_to_timezone::( - converted, - &to_tz, - cast_options, - )? - } - TimeUnit::Nanosecond => { - adjust_timestamp_to_timezone::( - converted, - &to_tz, - cast_options, - )? 
- } + TimeUnit::Second => adjust_timestamp_to_timezone::( + converted, + &to_tz, + cast_options, + )?, + TimeUnit::Millisecond => adjust_timestamp_to_timezone::< + TimestampMillisecondType, + >( + converted, &to_tz, cast_options + )?, + TimeUnit::Microsecond => adjust_timestamp_to_timezone::< + TimestampMicrosecondType, + >( + converted, &to_tz, cast_options + )?, + TimeUnit::Nanosecond => adjust_timestamp_to_timezone::< + TimestampNanosecondType, + >( + converted, &to_tz, cast_options + )?, } } - _ => { - converted - } + _ => converted, }; Ok(make_timestamp_array( &adjusted, @@ -1854,45 +1689,43 @@ pub fn cast_with_options( if time_array.is_null(i) { b.append_null(); } else { - b.append_value(num::integer::div_floor::(time_array.value(i), from_size) as i32); + b.append_value( + num::integer::div_floor::(time_array.value(i), from_size) as i32, + ); } } Ok(Arc::new(b.finish()) as ArrayRef) } - (Timestamp(TimeUnit::Second, _), Date64) => Ok(Arc::new( - match cast_options.safe { - true => { - // change error to None - array.as_primitive::() - .unary_opt::<_, Date64Type>(|x| { - x.checked_mul(MILLISECONDS) - }) - } - false => { - array.as_primitive::().try_unary::<_, Date64Type, _>( - |x| { - x.mul_checked(MILLISECONDS) - }, - )? 
- } - }, - )), + (Timestamp(TimeUnit::Second, _), Date64) => Ok(Arc::new(match cast_options.safe { + true => { + // change error to None + array + .as_primitive::() + .unary_opt::<_, Date64Type>(|x| x.checked_mul(MILLISECONDS)) + } + false => array + .as_primitive::() + .try_unary::<_, Date64Type, _>(|x| x.mul_checked(MILLISECONDS))?, + })), (Timestamp(TimeUnit::Millisecond, _), Date64) => { cast_reinterpret_arrays::(array) } (Timestamp(TimeUnit::Microsecond, _), Date64) => Ok(Arc::new( - array.as_primitive::() + array + .as_primitive::() .unary::<_, Date64Type>(|x| x / (MICROSECONDS / MILLISECONDS)), )), (Timestamp(TimeUnit::Nanosecond, _), Date64) => Ok(Arc::new( - array.as_primitive::() + array + .as_primitive::() .unary::<_, Date64Type>(|x| x / (NANOSECONDS / MILLISECONDS)), )), (Timestamp(TimeUnit::Second, tz), Time64(TimeUnit::Microsecond)) => { let tz = tz.as_ref().map(|tz| tz.parse()).transpose()?; Ok(Arc::new( - array.as_primitive::() + array + .as_primitive::() .try_unary::<_, Time64MicrosecondType, ArrowError>(|x| { Ok(time_to_time64us(as_time_res_with_timezone::< TimestampSecondType, @@ -1903,7 +1736,8 @@ pub fn cast_with_options( (Timestamp(TimeUnit::Second, tz), Time64(TimeUnit::Nanosecond)) => { let tz = tz.as_ref().map(|tz| tz.parse()).transpose()?; Ok(Arc::new( - array.as_primitive::() + array + .as_primitive::() .try_unary::<_, Time64NanosecondType, ArrowError>(|x| { Ok(time_to_time64ns(as_time_res_with_timezone::< TimestampSecondType, @@ -1914,7 +1748,8 @@ pub fn cast_with_options( (Timestamp(TimeUnit::Millisecond, tz), Time64(TimeUnit::Microsecond)) => { let tz = tz.as_ref().map(|tz| tz.parse()).transpose()?; Ok(Arc::new( - array.as_primitive::() + array + .as_primitive::() .try_unary::<_, Time64MicrosecondType, ArrowError>(|x| { Ok(time_to_time64us(as_time_res_with_timezone::< TimestampMillisecondType, @@ -1925,7 +1760,8 @@ pub fn cast_with_options( (Timestamp(TimeUnit::Millisecond, tz), Time64(TimeUnit::Nanosecond)) => { let tz = 
tz.as_ref().map(|tz| tz.parse()).transpose()?; Ok(Arc::new( - array.as_primitive::() + array + .as_primitive::() .try_unary::<_, Time64NanosecondType, ArrowError>(|x| { Ok(time_to_time64ns(as_time_res_with_timezone::< TimestampMillisecondType, @@ -1936,7 +1772,8 @@ pub fn cast_with_options( (Timestamp(TimeUnit::Microsecond, tz), Time64(TimeUnit::Microsecond)) => { let tz = tz.as_ref().map(|tz| tz.parse()).transpose()?; Ok(Arc::new( - array.as_primitive::() + array + .as_primitive::() .try_unary::<_, Time64MicrosecondType, ArrowError>(|x| { Ok(time_to_time64us(as_time_res_with_timezone::< TimestampMicrosecondType, @@ -1947,7 +1784,8 @@ pub fn cast_with_options( (Timestamp(TimeUnit::Microsecond, tz), Time64(TimeUnit::Nanosecond)) => { let tz = tz.as_ref().map(|tz| tz.parse()).transpose()?; Ok(Arc::new( - array.as_primitive::() + array + .as_primitive::() .try_unary::<_, Time64NanosecondType, ArrowError>(|x| { Ok(time_to_time64ns(as_time_res_with_timezone::< TimestampMicrosecondType, @@ -1958,7 +1796,8 @@ pub fn cast_with_options( (Timestamp(TimeUnit::Nanosecond, tz), Time64(TimeUnit::Microsecond)) => { let tz = tz.as_ref().map(|tz| tz.parse()).transpose()?; Ok(Arc::new( - array.as_primitive::() + array + .as_primitive::() .try_unary::<_, Time64MicrosecondType, ArrowError>(|x| { Ok(time_to_time64us(as_time_res_with_timezone::< TimestampNanosecondType, @@ -1969,7 +1808,8 @@ pub fn cast_with_options( (Timestamp(TimeUnit::Nanosecond, tz), Time64(TimeUnit::Nanosecond)) => { let tz = tz.as_ref().map(|tz| tz.parse()).transpose()?; Ok(Arc::new( - array.as_primitive::() + array + .as_primitive::() .try_unary::<_, Time64NanosecondType, ArrowError>(|x| { Ok(time_to_time64ns(as_time_res_with_timezone::< TimestampNanosecondType, @@ -1980,7 +1820,8 @@ pub fn cast_with_options( (Timestamp(TimeUnit::Second, tz), Time32(TimeUnit::Second)) => { let tz = tz.as_ref().map(|tz| tz.parse()).transpose()?; Ok(Arc::new( - array.as_primitive::() + array + .as_primitive::() .try_unary::<_, 
Time32SecondType, ArrowError>(|x| { Ok(time_to_time32s(as_time_res_with_timezone::< TimestampSecondType, @@ -1991,7 +1832,8 @@ pub fn cast_with_options( (Timestamp(TimeUnit::Second, tz), Time32(TimeUnit::Millisecond)) => { let tz = tz.as_ref().map(|tz| tz.parse()).transpose()?; Ok(Arc::new( - array.as_primitive::() + array + .as_primitive::() .try_unary::<_, Time32MillisecondType, ArrowError>(|x| { Ok(time_to_time32ms(as_time_res_with_timezone::< TimestampSecondType, @@ -2002,7 +1844,8 @@ pub fn cast_with_options( (Timestamp(TimeUnit::Millisecond, tz), Time32(TimeUnit::Second)) => { let tz = tz.as_ref().map(|tz| tz.parse()).transpose()?; Ok(Arc::new( - array.as_primitive::() + array + .as_primitive::() .try_unary::<_, Time32SecondType, ArrowError>(|x| { Ok(time_to_time32s(as_time_res_with_timezone::< TimestampMillisecondType, @@ -2013,7 +1856,8 @@ pub fn cast_with_options( (Timestamp(TimeUnit::Millisecond, tz), Time32(TimeUnit::Millisecond)) => { let tz = tz.as_ref().map(|tz| tz.parse()).transpose()?; Ok(Arc::new( - array.as_primitive::() + array + .as_primitive::() .try_unary::<_, Time32MillisecondType, ArrowError>(|x| { Ok(time_to_time32ms(as_time_res_with_timezone::< TimestampMillisecondType, @@ -2024,7 +1868,8 @@ pub fn cast_with_options( (Timestamp(TimeUnit::Microsecond, tz), Time32(TimeUnit::Second)) => { let tz = tz.as_ref().map(|tz| tz.parse()).transpose()?; Ok(Arc::new( - array.as_primitive::() + array + .as_primitive::() .try_unary::<_, Time32SecondType, ArrowError>(|x| { Ok(time_to_time32s(as_time_res_with_timezone::< TimestampMicrosecondType, @@ -2035,7 +1880,8 @@ pub fn cast_with_options( (Timestamp(TimeUnit::Microsecond, tz), Time32(TimeUnit::Millisecond)) => { let tz = tz.as_ref().map(|tz| tz.parse()).transpose()?; Ok(Arc::new( - array.as_primitive::() + array + .as_primitive::() .try_unary::<_, Time32MillisecondType, ArrowError>(|x| { Ok(time_to_time32ms(as_time_res_with_timezone::< TimestampMicrosecondType, @@ -2046,7 +1892,8 @@ pub fn 
cast_with_options( (Timestamp(TimeUnit::Nanosecond, tz), Time32(TimeUnit::Second)) => { let tz = tz.as_ref().map(|tz| tz.parse()).transpose()?; Ok(Arc::new( - array.as_primitive::() + array + .as_primitive::() .try_unary::<_, Time32SecondType, ArrowError>(|x| { Ok(time_to_time32s(as_time_res_with_timezone::< TimestampNanosecondType, @@ -2057,7 +1904,8 @@ pub fn cast_with_options( (Timestamp(TimeUnit::Nanosecond, tz), Time32(TimeUnit::Millisecond)) => { let tz = tz.as_ref().map(|tz| tz.parse()).transpose()?; Ok(Arc::new( - array.as_primitive::() + array + .as_primitive::() .try_unary::<_, Time32MillisecondType, ArrowError>(|x| { Ok(time_to_time32ms(as_time_res_with_timezone::< TimestampNanosecondType, @@ -2067,38 +1915,41 @@ pub fn cast_with_options( } (Date64, Timestamp(TimeUnit::Second, None)) => Ok(Arc::new( - array.as_primitive::() + array + .as_primitive::() .unary::<_, TimestampSecondType>(|x| x / MILLISECONDS), )), (Date64, Timestamp(TimeUnit::Millisecond, None)) => { cast_reinterpret_arrays::(array) } (Date64, Timestamp(TimeUnit::Microsecond, None)) => Ok(Arc::new( - array.as_primitive::().unary::<_, TimestampMicrosecondType>( - |x| x * (MICROSECONDS / MILLISECONDS), - ), + array + .as_primitive::() + .unary::<_, TimestampMicrosecondType>(|x| x * (MICROSECONDS / MILLISECONDS)), )), (Date64, Timestamp(TimeUnit::Nanosecond, None)) => Ok(Arc::new( - array.as_primitive::().unary::<_, TimestampNanosecondType>( - |x| x * (NANOSECONDS / MILLISECONDS), - ), + array + .as_primitive::() + .unary::<_, TimestampNanosecondType>(|x| x * (NANOSECONDS / MILLISECONDS)), )), (Date32, Timestamp(TimeUnit::Second, None)) => Ok(Arc::new( - array.as_primitive::() + array + .as_primitive::() .unary::<_, TimestampSecondType>(|x| (x as i64) * SECONDS_IN_DAY), )), (Date32, Timestamp(TimeUnit::Millisecond, None)) => Ok(Arc::new( - array.as_primitive::().unary::<_, TimestampMillisecondType>( - |x| (x as i64) * MILLISECONDS_IN_DAY, - ), + array + .as_primitive::() + .unary::<_, 
TimestampMillisecondType>(|x| (x as i64) * MILLISECONDS_IN_DAY), )), (Date32, Timestamp(TimeUnit::Microsecond, None)) => Ok(Arc::new( - array.as_primitive::().unary::<_, TimestampMicrosecondType>( - |x| (x as i64) * MICROSECONDS_IN_DAY, - ), + array + .as_primitive::() + .unary::<_, TimestampMicrosecondType>(|x| (x as i64) * MICROSECONDS_IN_DAY), )), (Date32, Timestamp(TimeUnit::Nanosecond, None)) => Ok(Arc::new( - array.as_primitive::() + array + .as_primitive::() .unary::<_, TimestampNanosecondType>(|x| (x as i64) * NANOSECONDS_IN_DAY), )), (Int64, Duration(TimeUnit::Second)) => { @@ -2416,9 +2267,7 @@ where // Natural cast between numeric types // If the value of T can't be casted to R, will throw error -fn try_numeric_cast( - from: &PrimitiveArray, -) -> Result, ArrowError> +fn try_numeric_cast(from: &PrimitiveArray) -> Result, ArrowError> where T: ArrowPrimitiveType, R: ArrowPrimitiveType, @@ -2519,11 +2368,7 @@ fn cast_string_to_timestamp( Ok(Arc::new(out.with_timezone_opt(to_tz.clone()))) } -fn cast_string_to_timestamp_impl< - O: OffsetSizeTrait, - T: ArrowTimestampType, - Tz: TimeZone, ->( +fn cast_string_to_timestamp_impl( array: &GenericStringArray, tz: &Tz, cast_options: &CastOptions, @@ -2680,9 +2525,7 @@ fn adjust_timestamp_to_timezone( } else { array.try_unary::<_, Int64Type, _>(|o| { adjust(o).ok_or_else(|| { - ArrowError::CastError( - "Cannot cast timezone to different timezone".to_string(), - ) + ArrowError::CastError("Cannot cast timezone to different timezone".to_string()) }) })? 
}; @@ -2706,11 +2549,10 @@ where .iter() .map(|value| match value { Some(value) => match value.to_ascii_lowercase().trim() { - "t" | "tr" | "tru" | "true" | "y" | "ye" | "yes" | "on" | "1" => { - Ok(Some(true)) + "t" | "tr" | "tru" | "true" | "y" | "ye" | "yes" | "on" | "1" => Ok(Some(true)), + "f" | "fa" | "fal" | "fals" | "false" | "n" | "no" | "of" | "off" | "0" => { + Ok(Some(false)) } - "f" | "fa" | "fal" | "fals" | "false" | "n" | "no" | "of" | "off" - | "0" => Ok(Some(false)), invalid_value => match cast_options.safe { true => Ok(None), false => Err(ArrowError::CastError(format!( @@ -2748,13 +2590,10 @@ where // Adjust decimal based on scale let number_decimals = if decimals.len() > scale { let decimal_number = i256::from_string(decimals).ok_or_else(|| { - ArrowError::InvalidArgumentError(format!( - "Cannot parse decimal format: {value_str}" - )) + ArrowError::InvalidArgumentError(format!("Cannot parse decimal format: {value_str}")) })?; - let div = - i256::from_i128(10_i128).pow_checked((decimals.len() - scale) as u32)?; + let div = i256::from_i128(10_i128).pow_checked((decimals.len() - scale) as u32)?; let half = div.div_wrapping(i256::from_i128(2)); let half_neg = half.neg_wrapping(); @@ -2776,9 +2615,7 @@ where "Cannot parse decimal format: {value_str}" )) }) - .map(|v| { - v.mul_wrapping(i256::from_i128(10_i128).pow_wrapping(scale as u32)) - })? + .map(|v| v.mul_wrapping(i256::from_i128(10_i128).pow_wrapping(scale as u32)))? 
} else { i256::ZERO }; @@ -2800,11 +2637,7 @@ where })?; T::Native::from_decimal(value).ok_or_else(|| { - ArrowError::InvalidArgumentError(format!( - "Cannot convert {} to {}", - value_str, - T::PREFIX - )) + ArrowError::InvalidArgumentError(format!("Cannot convert {} to {}", value_str, T::PREFIX)) }) } @@ -2848,9 +2681,7 @@ where T::DATA_TYPE, )) }) - .and_then(|v| { - T::validate_decimal_precision(v, precision).map(|_| v) - }) + .and_then(|v| T::validate_decimal_precision(v, precision).map(|_| v)) }) .transpose() }) @@ -2907,8 +2738,7 @@ fn cast_numeric_to_bool(from: &dyn Array) -> Result where FROM: ArrowPrimitiveType, { - numeric_to_bool_cast::(from.as_primitive::()) - .map(|to| Arc::new(to) as ArrayRef) + numeric_to_bool_cast::(from.as_primitive::()).map(|to| Arc::new(to) as ArrayRef) } fn numeric_to_bool_cast(from: &PrimitiveArray) -> Result @@ -2947,10 +2777,7 @@ where ))) } -fn bool_to_numeric_cast( - from: &BooleanArray, - _cast_options: &CastOptions, -) -> PrimitiveArray +fn bool_to_numeric_cast(from: &BooleanArray, _cast_options: &CastOptions) -> PrimitiveArray where T: ArrowPrimitiveType, T::Native: num::NumCast, @@ -2998,8 +2825,7 @@ fn dictionary_cast( Arc::new(PrimitiveArray::::from(dict_array.keys().to_data())); let values_array = dict_array.values(); let cast_keys = cast_with_options(&keys_array, to_index_type, cast_options)?; - let cast_values = - cast_with_options(values_array, to_value_type, cast_options)?; + let cast_values = cast_with_options(values_array, to_value_type, cast_options)?; // Failure to cast keys (because they don't fit in the // target type) results in NULL values; @@ -3071,66 +2897,24 @@ fn cast_to_dictionary( use DataType::*; match *dict_value_type { - Int8 => pack_numeric_to_dictionary::( - array, - dict_value_type, - cast_options, - ), - Int16 => pack_numeric_to_dictionary::( - array, - dict_value_type, - cast_options, - ), - Int32 => pack_numeric_to_dictionary::( - array, - dict_value_type, - cast_options, - ), - Int64 => 
pack_numeric_to_dictionary::( - array, - dict_value_type, - cast_options, - ), - UInt8 => pack_numeric_to_dictionary::( - array, - dict_value_type, - cast_options, - ), - UInt16 => pack_numeric_to_dictionary::( - array, - dict_value_type, - cast_options, - ), - UInt32 => pack_numeric_to_dictionary::( - array, - dict_value_type, - cast_options, - ), - UInt64 => pack_numeric_to_dictionary::( - array, - dict_value_type, - cast_options, - ), - Decimal128(_, _) => pack_numeric_to_dictionary::( - array, - dict_value_type, - cast_options, - ), - Decimal256(_, _) => pack_numeric_to_dictionary::( - array, - dict_value_type, - cast_options, - ), - Utf8 => pack_byte_to_dictionary::>(array, cast_options), - LargeUtf8 => { - pack_byte_to_dictionary::>(array, cast_options) - } - Binary => { - pack_byte_to_dictionary::>(array, cast_options) - } - LargeBinary => { - pack_byte_to_dictionary::>(array, cast_options) + Int8 => pack_numeric_to_dictionary::(array, dict_value_type, cast_options), + Int16 => pack_numeric_to_dictionary::(array, dict_value_type, cast_options), + Int32 => pack_numeric_to_dictionary::(array, dict_value_type, cast_options), + Int64 => pack_numeric_to_dictionary::(array, dict_value_type, cast_options), + UInt8 => pack_numeric_to_dictionary::(array, dict_value_type, cast_options), + UInt16 => pack_numeric_to_dictionary::(array, dict_value_type, cast_options), + UInt32 => pack_numeric_to_dictionary::(array, dict_value_type, cast_options), + UInt64 => pack_numeric_to_dictionary::(array, dict_value_type, cast_options), + Decimal128(_, _) => { + pack_numeric_to_dictionary::(array, dict_value_type, cast_options) + } + Decimal256(_, _) => { + pack_numeric_to_dictionary::(array, dict_value_type, cast_options) } + Utf8 => pack_byte_to_dictionary::>(array, cast_options), + LargeUtf8 => pack_byte_to_dictionary::>(array, cast_options), + Binary => pack_byte_to_dictionary::>(array, cast_options), + LargeBinary => pack_byte_to_dictionary::>(array, cast_options), _ => 
Err(ArrowError::CastError(format!( "Unsupported output type for dictionary packing: {dict_value_type:?}" ))), @@ -3152,8 +2936,7 @@ where let cast_values = cast_with_options(array, dict_value_type, cast_options)?; let values = cast_values.as_primitive::(); - let mut b = - PrimitiveDictionaryBuilder::::with_capacity(values.len(), values.len()); + let mut b = PrimitiveDictionaryBuilder::::with_capacity(values.len(), values.len()); // copy each element one at a time for i in 0..values.len() { @@ -3181,8 +2964,7 @@ where .as_any() .downcast_ref::>() .unwrap(); - let mut b = - GenericByteDictionaryBuilder::::with_capacity(values.len(), 1024, 1024); + let mut b = GenericByteDictionaryBuilder::::with_capacity(values.len(), 1024, 1024); // copy each element one at a time for i in 0..values.len() { @@ -3216,8 +2998,7 @@ fn cast_list_inner( ) -> Result { let data = array.to_data(); let underlying_array = make_array(data.child_data()[0].clone()); - let cast_array = - cast_with_options(underlying_array.as_ref(), to.data_type(), cast_options)?; + let cast_array = cast_with_options(underlying_array.as_ref(), to.data_type(), cast_options)?; let builder = data .into_builder() .data_type(to_type.clone()) @@ -3246,10 +3027,8 @@ fn cast_binary_to_string( Err(e) => match cast_options.safe { true => { // Fallback to slow method to convert invalid sequences to nulls - let mut builder = GenericStringBuilder::::with_capacity( - array.len(), - array.value_data().len(), - ); + let mut builder = + GenericStringBuilder::::with_capacity(array.len(), array.value_data().len()); let iter = array .iter() @@ -3344,8 +3123,8 @@ where offsets .iter() .try_for_each::<_, Result<_, ArrowError>>(|offset| { - let offset = <::Offset as NumCast>::from(*offset) - .ok_or_else(|| { + let offset = + <::Offset as NumCast>::from(*offset).ok_or_else(|| { ArrowError::ComputeError(format!( "{}{} array too large to cast to {}{} array", FROM::Offset::PREFIX, @@ -3374,9 +3153,7 @@ where 
Ok(Arc::new(GenericByteArray::::from(array_data))) } -fn cast_fixed_size_list_to_list( - array: &dyn Array, -) -> Result +fn cast_fixed_size_list_to_list(array: &dyn Array) -> Result where OffsetSize: OffsetSizeTrait, { @@ -3457,8 +3234,8 @@ mod tests { macro_rules! generate_cast_test_case { ($INPUT_ARRAY: expr, $OUTPUT_TYPE_ARRAY: ident, $OUTPUT_TYPE: expr, $OUTPUT_VALUES: expr) => { - let output = $OUTPUT_TYPE_ARRAY::from($OUTPUT_VALUES) - .with_data_type($OUTPUT_TYPE.clone()); + let output = + $OUTPUT_TYPE_ARRAY::from($OUTPUT_VALUES).with_data_type($OUTPUT_TYPE.clone()); // assert cast type let input_array_type = $INPUT_ARRAY.data_type(); @@ -3471,8 +3248,7 @@ mod tests { safe: false, format_options: FormatOptions::default(), }; - let result = - cast_with_options($INPUT_ARRAY, $OUTPUT_TYPE, &cast_option).unwrap(); + let result = cast_with_options($INPUT_ARRAY, $OUTPUT_TYPE, &cast_option).unwrap(); assert_eq!($OUTPUT_TYPE, result.data_type()); assert_eq!(result.as_ref(), &output); }; @@ -3806,8 +3582,7 @@ mod tests { #[test] fn test_cast_decimal_to_numeric() { - let value_array: Vec> = - vec![Some(125), Some(225), Some(325), None, Some(525)]; + let value_array: Vec> = vec![Some(125), Some(225), Some(325), None, Some(525)]; let array = create_decimal_array(value_array, 38, 2).unwrap(); // u8 generate_cast_test_case!( @@ -4619,8 +4394,7 @@ mod tests { #[test] fn test_cast_i32_to_list_f64_nullable_sliced() { - let array = - Int32Array::from(vec![Some(5), None, Some(7), Some(8), None, Some(10)]); + let array = Int32Array::from(vec![Some(5), None, Some(7), Some(8), None, Some(10)]); let array = array.slice(2, 4); let b = cast( &array, @@ -4670,9 +4444,8 @@ mod tests { Ok(_) => panic!("expected error"), Err(e) => { assert!( - e.to_string().contains( - "Cast error: Cannot cast string 'seven' to value of Int32 type", - ), + e.to_string() + .contains("Cast error: Cannot cast string 'seven' to value of Int32 type",), "Error: {e}" ) } @@ -4683,8 +4456,7 @@ mod tests { fn 
test_cast_utf8_to_bool() { let strings = StringArray::from(vec!["true", "false", "invalid", " Y ", ""]); let casted = cast(&strings, &DataType::Boolean).unwrap(); - let expected = - BooleanArray::from(vec![Some(true), Some(false), None, Some(true), None]); + let expected = BooleanArray::from(vec![Some(true), Some(false), None, Some(true), None]); assert_eq!(*as_boolean_array(&casted), expected); } @@ -4702,9 +4474,9 @@ mod tests { match casted { Ok(_) => panic!("expected error"), Err(e) => { - assert!(e.to_string().contains( - "Cast error: Cannot cast value 'invalid' to value of Boolean type" - )) + assert!(e + .to_string() + .contains("Cast error: Cannot cast value 'invalid' to value of Boolean type")) } } } @@ -4750,9 +4522,7 @@ mod tests { } #[test] - #[should_panic( - expected = "Casting from Int32 to Timestamp(Microsecond, None) not supported" - )] + #[should_panic(expected = "Casting from Int32 to Timestamp(Microsecond, None) not supported")] fn test_cast_int32_to_timestamp() { let array = Int32Array::from(vec![Some(2), Some(10), None]); cast(&array, &DataType::Timestamp(TimeUnit::Microsecond, None)).unwrap(); @@ -4760,15 +4530,13 @@ mod tests { #[test] fn test_cast_list_i32_to_list_u16() { - let value_data = - Int32Array::from(vec![0, 0, 0, -1, -2, -1, 2, 100000000]).into_data(); + let value_data = Int32Array::from(vec![0, 0, 0, -1, -2, -1, 2, 100000000]).into_data(); let value_offsets = Buffer::from_slice_ref([0, 3, 6, 8]); // Construct a list array from the above two // [[0,0,0], [-1, -2, -1], [2, 100000000]] - let list_data_type = - DataType::List(Arc::new(Field::new("item", DataType::Int32, true))); + let list_data_type = DataType::List(Arc::new(Field::new("item", DataType::Int32, true))); let list_data = ArrayData::builder(list_data_type) .len(3) .add_buffer(value_offsets) @@ -4812,19 +4580,15 @@ mod tests { } #[test] - #[should_panic( - expected = "Casting from Int32 to Timestamp(Microsecond, None) not supported" - )] + #[should_panic(expected = 
"Casting from Int32 to Timestamp(Microsecond, None) not supported")] fn test_cast_list_i32_to_list_timestamp() { // Construct a value array - let value_data = - Int32Array::from(vec![0, 0, 0, -1, -2, -1, 2, 8, 100000000]).into_data(); + let value_data = Int32Array::from(vec![0, 0, 0, -1, -2, -1, 2, 8, 100000000]).into_data(); let value_offsets = Buffer::from_slice_ref([0, 3, 6, 9]); // Construct a list array from the above two - let list_data_type = - DataType::List(Arc::new(Field::new("item", DataType::Int32, true))); + let list_data_type = DataType::List(Arc::new(Field::new("item", DataType::Int32, true))); let list_data = ArrayData::builder(list_data_type) .len(3) .add_buffer(value_offsets) @@ -4969,7 +4733,10 @@ mod tests { format_options: FormatOptions::default(), }; let err = cast_with_options(array, &to_type, &options).unwrap_err(); - assert_eq!(err.to_string(), "Cast error: Cannot cast string 'Not a valid date' to value of Date32 type"); + assert_eq!( + err.to_string(), + "Cast error: Cannot cast string 'Not a valid date' to value of Date32 type" + ); } } @@ -5126,14 +4893,16 @@ mod tests { format_options: FormatOptions::default(), }; let err = cast_with_options(array, &to_type, &options).unwrap_err(); - assert_eq!(err.to_string(), "Cast error: Cannot cast string 'Not a valid date' to value of Date64 type"); + assert_eq!( + err.to_string(), + "Cast error: Cannot cast string 'Not a valid date' to value of Date64 type" + ); } } macro_rules! 
test_safe_string_to_interval { ($data_vec:expr, $interval_unit:expr, $array_ty:ty, $expect_vec:expr) => { - let source_string_array = - Arc::new(StringArray::from($data_vec.clone())) as ArrayRef; + let source_string_array = Arc::new(StringArray::from($data_vec.clone())) as ArrayRef; let options = CastOptions { safe: true, @@ -5427,12 +5196,9 @@ mod tests { #[test] fn test_cast_timestamp_to_date32() { - let array = TimestampMillisecondArray::from(vec![ - Some(864000000005), - Some(1545696000001), - None, - ]) - .with_timezone("UTC".to_string()); + let array = + TimestampMillisecondArray::from(vec![Some(864000000005), Some(1545696000001), None]) + .with_timezone("UTC".to_string()); let b = cast(&array, &DataType::Date32).unwrap(); let c = b.as_primitive::(); assert_eq!(10000, c.value(0)); @@ -5442,19 +5208,15 @@ mod tests { #[test] fn test_cast_timestamp_to_date64() { - let array = TimestampMillisecondArray::from(vec![ - Some(864000000005), - Some(1545696000001), - None, - ]); + let array = + TimestampMillisecondArray::from(vec![Some(864000000005), Some(1545696000001), None]); let b = cast(&array, &DataType::Date64).unwrap(); let c = b.as_primitive::(); assert_eq!(864000000005, c.value(0)); assert_eq!(1545696000001, c.value(1)); assert!(c.is_null(2)); - let array = - TimestampSecondArray::from(vec![Some(864000000005), Some(1545696000001)]); + let array = TimestampSecondArray::from(vec![Some(864000000005), Some(1545696000001)]); let b = cast(&array, &DataType::Date64).unwrap(); let c = b.as_primitive::(); assert_eq!(864000000005000, c.value(0)); @@ -5506,9 +5268,8 @@ mod tests { assert!(c.is_null(2)); // test timestamp microseconds - let a = - TimestampMicrosecondArray::from(vec![Some(86405000000), Some(1000000), None]) - .with_timezone("+01:00".to_string()); + let a = TimestampMicrosecondArray::from(vec![Some(86405000000), Some(1000000), None]) + .with_timezone("+01:00".to_string()); let array = Arc::new(a) as ArrayRef; let b = cast(&array, 
&DataType::Time64(TimeUnit::Microsecond)).unwrap(); let c = b.as_primitive::(); @@ -5522,12 +5283,8 @@ mod tests { assert!(c.is_null(2)); // test timestamp nanoseconds - let a = TimestampNanosecondArray::from(vec![ - Some(86405000000000), - Some(1000000000), - None, - ]) - .with_timezone("+01:00".to_string()); + let a = TimestampNanosecondArray::from(vec![Some(86405000000000), Some(1000000000), None]) + .with_timezone("+01:00".to_string()); let array = Arc::new(a) as ArrayRef; let b = cast(&array, &DataType::Time64(TimeUnit::Microsecond)).unwrap(); let c = b.as_primitive::(); @@ -5541,8 +5298,8 @@ mod tests { assert!(c.is_null(2)); // test overflow - let a = TimestampSecondArray::from(vec![Some(i64::MAX)]) - .with_timezone("+01:00".to_string()); + let a = + TimestampSecondArray::from(vec![Some(i64::MAX)]).with_timezone("+01:00".to_string()); let array = Arc::new(a) as ArrayRef; let b = cast(&array, &DataType::Time64(TimeUnit::Microsecond)); assert!(b.is_err()); @@ -5585,9 +5342,8 @@ mod tests { assert!(c.is_null(2)); // test timestamp microseconds - let a = - TimestampMicrosecondArray::from(vec![Some(86405000000), Some(1000000), None]) - .with_timezone("+01:00".to_string()); + let a = TimestampMicrosecondArray::from(vec![Some(86405000000), Some(1000000), None]) + .with_timezone("+01:00".to_string()); let array = Arc::new(a) as ArrayRef; let b = cast(&array, &DataType::Time32(TimeUnit::Second)).unwrap(); let c = b.as_primitive::(); @@ -5601,12 +5357,8 @@ mod tests { assert!(c.is_null(2)); // test timestamp nanoseconds - let a = TimestampNanosecondArray::from(vec![ - Some(86405000000000), - Some(1000000000), - None, - ]) - .with_timezone("+01:00".to_string()); + let a = TimestampNanosecondArray::from(vec![Some(86405000000000), Some(1000000000), None]) + .with_timezone("+01:00".to_string()); let array = Arc::new(a) as ArrayRef; let b = cast(&array, &DataType::Time32(TimeUnit::Second)).unwrap(); let c = b.as_primitive::(); @@ -5620,8 +5372,8 @@ mod tests { 
assert!(c.is_null(2)); // test overflow - let a = TimestampSecondArray::from(vec![Some(i64::MAX)]) - .with_timezone("+01:00".to_string()); + let a = + TimestampSecondArray::from(vec![Some(i64::MAX)]).with_timezone("+01:00".to_string()); let array = Arc::new(a) as ArrayRef; let b = cast(&array, &DataType::Time32(TimeUnit::Second)); assert!(b.is_err()); @@ -5708,8 +5460,7 @@ mod tests { #[test] fn test_cast_date64_to_timestamp() { - let array = - Date64Array::from(vec![Some(864000000005), Some(1545696000001), None]); + let array = Date64Array::from(vec![Some(864000000005), Some(1545696000001), None]); let b = cast(&array, &DataType::Timestamp(TimeUnit::Second, None)).unwrap(); let c = b.as_primitive::(); assert_eq!(864000000, c.value(0)); @@ -5719,8 +5470,7 @@ mod tests { #[test] fn test_cast_date64_to_timestamp_ms() { - let array = - Date64Array::from(vec![Some(864000000005), Some(1545696000001), None]); + let array = Date64Array::from(vec![Some(864000000005), Some(1545696000001), None]); let b = cast(&array, &DataType::Timestamp(TimeUnit::Millisecond, None)).unwrap(); let c = b .as_any() @@ -5733,8 +5483,7 @@ mod tests { #[test] fn test_cast_date64_to_timestamp_us() { - let array = - Date64Array::from(vec![Some(864000000005), Some(1545696000001), None]); + let array = Date64Array::from(vec![Some(864000000005), Some(1545696000001), None]); let b = cast(&array, &DataType::Timestamp(TimeUnit::Microsecond, None)).unwrap(); let c = b .as_any() @@ -5747,8 +5496,7 @@ mod tests { #[test] fn test_cast_date64_to_timestamp_ns() { - let array = - Date64Array::from(vec![Some(864000000005), Some(1545696000001), None]); + let array = Date64Array::from(vec![Some(864000000005), Some(1545696000001), None]); let b = cast(&array, &DataType::Timestamp(TimeUnit::Nanosecond, None)).unwrap(); let c = b .as_any() @@ -5761,12 +5509,9 @@ mod tests { #[test] fn test_cast_timestamp_to_i64() { - let array = TimestampMillisecondArray::from(vec![ - Some(864000000005), - Some(1545696000001), - 
None, - ]) - .with_timezone("UTC".to_string()); + let array = + TimestampMillisecondArray::from(vec![Some(864000000005), Some(1545696000001), None]) + .with_timezone("UTC".to_string()); let b = cast(&array, &DataType::Int64).unwrap(); let c = b.as_primitive::(); assert_eq!(&DataType::Int64, c.data_type()); @@ -5798,11 +5543,8 @@ mod tests { #[test] fn test_cast_timestamp_to_strings() { // "2018-12-25T00:00:02.001", "1997-05-19T00:00:03.005", None - let array = TimestampMillisecondArray::from(vec![ - Some(864000003005), - Some(1545696002001), - None, - ]); + let array = + TimestampMillisecondArray::from(vec![Some(864000003005), Some(1545696002001), None]); let out = cast(&array, &DataType::Utf8).unwrap(); let out = out .as_any() @@ -5846,13 +5588,9 @@ mod tests { .with_timestamp_tz_format(Some(ts_format)), }; // "2018-12-25T00:00:02.001", "1997-05-19T00:00:03.005", None - let array_without_tz = TimestampMillisecondArray::from(vec![ - Some(864000003005), - Some(1545696002001), - None, - ]); - let out = - cast_with_options(&array_without_tz, &DataType::Utf8, &cast_options).unwrap(); + let array_without_tz = + TimestampMillisecondArray::from(vec![Some(864000003005), Some(1545696002001), None]); + let out = cast_with_options(&array_without_tz, &DataType::Utf8, &cast_options).unwrap(); let out = out .as_any() .downcast_ref::() @@ -5868,8 +5606,7 @@ mod tests { ] ); let out = - cast_with_options(&array_without_tz, &DataType::LargeUtf8, &cast_options) - .unwrap(); + cast_with_options(&array_without_tz, &DataType::LargeUtf8, &cast_options).unwrap(); let out = out .as_any() .downcast_ref::() @@ -5885,14 +5622,10 @@ mod tests { ] ); - let array_with_tz = TimestampMillisecondArray::from(vec![ - Some(864000003005), - Some(1545696002001), - None, - ]) - .with_timezone(tz.to_string()); - let out = - cast_with_options(&array_with_tz, &DataType::Utf8, &cast_options).unwrap(); + let array_with_tz = + TimestampMillisecondArray::from(vec![Some(864000003005), Some(1545696002001), 
None]) + .with_timezone(tz.to_string()); + let out = cast_with_options(&array_with_tz, &DataType::Utf8, &cast_options).unwrap(); let out = out .as_any() .downcast_ref::() @@ -5907,8 +5640,7 @@ mod tests { None ] ); - let out = cast_with_options(&array_with_tz, &DataType::LargeUtf8, &cast_options) - .unwrap(); + let out = cast_with_options(&array_with_tz, &DataType::LargeUtf8, &cast_options).unwrap(); let out = out .as_any() .downcast_ref::() @@ -5927,11 +5659,8 @@ mod tests { #[test] fn test_cast_between_timestamps() { - let array = TimestampMillisecondArray::from(vec![ - Some(864000003005), - Some(1545696002001), - None, - ]); + let array = + TimestampMillisecondArray::from(vec![Some(864000003005), Some(1545696002001), None]); let b = cast(&array, &DataType::Timestamp(TimeUnit::Second, None)).unwrap(); let c = b.as_primitive::(); assert_eq!(864000003, c.value(0)); @@ -6335,8 +6064,7 @@ mod tests { ]; let u64_array: ArrayRef = Arc::new(UInt64Array::from(u64_values)); - let f64_expected = - vec![0.0, 255.0, 65535.0, 4294967295.0, 18446744073709552000.0]; + let f64_expected = vec![0.0, 255.0, 65535.0, 4294967295.0, 18446744073709552000.0]; assert_eq!( f64_expected, get_cast_values::(&u64_array, &DataType::Float64) @@ -6345,8 +6073,7 @@ mod tests { .collect::>() ); - let f32_expected = - vec![0.0, 255.0, 65535.0, 4294967300.0, 18446744000000000000.0]; + let f32_expected = vec![0.0, 255.0, 65535.0, 4294967300.0, 18446744000000000000.0]; assert_eq!( f32_expected, get_cast_values::(&u64_array, &DataType::Float32) @@ -6379,8 +6106,7 @@ mod tests { get_cast_values::(&u64_array, &DataType::Int8) ); - let u64_expected = - vec!["0", "255", "65535", "4294967295", "18446744073709551615"]; + let u64_expected = vec!["0", "255", "65535", "4294967295", "18446744073709551615"]; assert_eq!( u64_expected, get_cast_values::(&u64_array, &DataType::UInt64) @@ -6811,15 +6537,13 @@ mod tests { get_cast_values::(&i32_array, &DataType::Int8) ); - let u64_expected = - vec!["null", "null", 
"null", "0", "127", "32767", "2147483647"]; + let u64_expected = vec!["null", "null", "null", "0", "127", "32767", "2147483647"]; assert_eq!( u64_expected, get_cast_values::(&i32_array, &DataType::UInt64) ); - let u32_expected = - vec!["null", "null", "null", "0", "127", "32767", "2147483647"]; + let u32_expected = vec!["null", "null", "null", "0", "127", "32767", "2147483647"]; assert_eq!( u32_expected, get_cast_values::(&i32_array, &DataType::UInt32) @@ -6855,8 +6579,7 @@ mod tests { #[test] fn test_cast_from_int16() { - let i16_values: Vec = - vec![i16::MIN, i8::MIN as i16, 0, i8::MAX as i16, i16::MAX]; + let i16_values: Vec = vec![i16::MIN, i8::MIN as i16, 0, i8::MAX as i16, i16::MAX]; let i16_array: ArrayRef = Arc::new(Int16Array::from(i16_values)); let f64_expected = vec!["-32768.0", "-128.0", "0.0", "127.0", "32767.0"]; @@ -7197,8 +6920,7 @@ mod tests { fn test_cast_string_array_to_dict() { use DataType::*; - let array = Arc::new(StringArray::from(vec![Some("one"), None, Some("three")])) - as ArrayRef; + let array = Arc::new(StringArray::from(vec![Some("one"), None, Some("three")])) as ArrayRef; let expected = vec!["one", "null", "three"]; @@ -7297,16 +7019,12 @@ mod tests { cast_from_null_to_other(&data_type); // Cast null from and to list - let data_type = - DataType::List(Arc::new(Field::new("item", DataType::Int32, true))); + let data_type = DataType::List(Arc::new(Field::new("item", DataType::Int32, true))); cast_from_null_to_other(&data_type); - let data_type = - DataType::LargeList(Arc::new(Field::new("item", DataType::Int32, true))); + let data_type = DataType::LargeList(Arc::new(Field::new("item", DataType::Int32, true))); cast_from_null_to_other(&data_type); - let data_type = DataType::FixedSizeList( - Arc::new(Field::new("item", DataType::Int32, true)), - 4, - ); + let data_type = + DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Int32, true)), 4); cast_from_null_to_other(&data_type); // Cast null from and to dictionary @@ -7317,8 
+7035,7 @@ mod tests { cast_from_null_to_other(&data_type); // Cast null from and to struct - let data_type = - DataType::Struct(vec![Field::new("data", DataType::Int64, false)].into()); + let data_type = DataType::Struct(vec![Field::new("data", DataType::Int64, false)].into()); cast_from_null_to_other(&data_type); } @@ -7511,8 +7228,7 @@ mod tests { let value_offsets = Buffer::from_slice_ref([0, 3, 6, 8]); // Construct a list array from the above two - let list_data_type = - DataType::List(Arc::new(Field::new("item", DataType::Int32, true))); + let list_data_type = DataType::List(Arc::new(Field::new("item", DataType::Int32, true))); let list_data = ArrayData::builder(list_data_type) .len(3) .add_buffer(value_offsets) @@ -7554,10 +7270,8 @@ mod tests { .build() .unwrap(); - let list_data_type = DataType::FixedSizeList( - Arc::new(Field::new("item", DataType::Int32, true)), - 4, - ); + let list_data_type = + DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Int32, true)), 4); let list_data = ArrayData::builder(list_data_type) .len(2) .add_child_data(value_data) @@ -7574,10 +7288,8 @@ mod tests { .build() .unwrap(); - let list_data_type = DataType::FixedSizeList( - Arc::new(Field::new("item", DataType::Int64, true)), - 4, - ); + let list_data_type = + DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Int64, true)), 4); let list_data = ArrayData::builder(list_data_type) .len(2) .add_child_data(value_data) @@ -7618,8 +7330,7 @@ mod tests { let value_offsets = Buffer::from_slice_ref([0, 3, 6, 8]); let value_data = str_array.into_data(); - let list_data_type = - DataType::List(Arc::new(Field::new("item", DataType::Utf8, true))); + let list_data_type = DataType::List(Arc::new(Field::new("item", DataType::Utf8, true))); let list_data = ArrayData::builder(list_data_type) .len(3) .add_buffer(value_offsets) @@ -7958,12 +7669,7 @@ mod tests { let array = vec![Some(123)]; let input_decimal_array = create_decimal_array(array, 10, -1).unwrap(); let array = 
Arc::new(input_decimal_array) as ArrayRef; - generate_cast_test_case!( - &array, - Decimal128Array, - &output_type, - vec![Some(12_i128),] - ); + generate_cast_test_case!(&array, Decimal128Array, &output_type, vec![Some(12_i128),]); let casted_array = cast(&array, &output_type).unwrap(); let decimal_arr = casted_array.as_primitive::(); @@ -7973,12 +7679,7 @@ mod tests { let array = vec![Some(125)]; let input_decimal_array = create_decimal_array(array, 10, -1).unwrap(); let array = Arc::new(input_decimal_array) as ArrayRef; - generate_cast_test_case!( - &array, - Decimal128Array, - &output_type, - vec![Some(13_i128),] - ); + generate_cast_test_case!(&array, Decimal128Array, &output_type, vec![Some(13_i128),]); let casted_array = cast(&array, &output_type).unwrap(); let decimal_arr = casted_array.as_primitive::(); @@ -8220,9 +7921,9 @@ mod tests { let str_array = StringArray::from(vec![". 0.123"]); let array = Arc::new(str_array) as ArrayRef; let casted_err = cast_with_options(&array, &output_type, &option).unwrap_err(); - assert!(casted_err.to_string().contains( - "Cannot cast string '. 0.123' to value of Decimal128(38, 10) type" - )); + assert!(casted_err + .to_string() + .contains("Cannot cast string '. 
0.123' to value of Decimal128(38, 10) type")); } fn test_cast_string_to_decimal128_overflow(overflow_array: ArrayRef) { @@ -8499,9 +8200,8 @@ mod tests { let tz = tz.as_ref().parse().unwrap(); - let as_tz = |v: i64| { - as_datetime_with_timezone::(v, tz).unwrap() - }; + let as_tz = + |v: i64| as_datetime_with_timezone::(v, tz).unwrap(); let as_utc = |v: &i64| as_tz(*v).naive_utc().to_string(); let as_local = |v: &i64| as_tz(*v).naive_local().to_string(); @@ -8611,8 +8311,7 @@ mod tests { None, ]; - let array256: Vec> = - array128.iter().map(|v| v.map(i256::from_i128)).collect(); + let array256: Vec> = array128.iter().map(|v| v.map(i256::from_i128)).collect(); test_decimal_to_string::( DataType::Utf8, @@ -8701,11 +8400,9 @@ mod tests { fn test_cast_from_duration_to_interval() { // from duration second to interval month day nano let array = vec![1234567]; - let casted_array = cast_from_duration_to_interval::( - array, - &CastOptions::default(), - ) - .unwrap(); + let casted_array = + cast_from_duration_to_interval::(array, &CastOptions::default()) + .unwrap(); assert_eq!( casted_array.data_type(), &DataType::Interval(IntervalUnit::MonthDayNano) @@ -8824,10 +8521,7 @@ mod tests { .as_any() .downcast_ref::>() .ok_or_else(|| { - ArrowError::ComputeError(format!( - "Failed to downcast to {}", - T::DATA_TYPE - )) + ArrowError::ComputeError(format!("Failed to downcast to {}", T::DATA_TYPE)) }) .cloned() } @@ -8865,8 +8559,7 @@ mod tests { cast_from_interval_to_duration(&array, &nullable).unwrap(); assert!(!casted_array.is_valid(0)); - let res = - cast_from_interval_to_duration::(&array, &fallible); + let res = cast_from_interval_to_duration::(&array, &fallible); assert!(res.is_err()); // from interval month day nano to duration microsecond @@ -8877,8 +8570,7 @@ mod tests { let array = vec![i128::MAX].into(); let casted_array = - cast_from_interval_to_duration::(&array, &nullable) - .unwrap(); + cast_from_interval_to_duration::(&array, &nullable).unwrap(); 
assert!(!casted_array.is_valid(0)); let casted_array = @@ -8909,8 +8601,7 @@ mod tests { ] .into(); let casted_array = - cast_from_interval_to_duration::(&array, &nullable) - .unwrap(); + cast_from_interval_to_duration::(&array, &nullable).unwrap(); assert!(!casted_array.is_valid(0)); assert!(!casted_array.is_valid(1)); assert!(!casted_array.is_valid(2)); @@ -8979,11 +8670,9 @@ mod tests { fn test_cast_from_interval_day_time_to_interval_month_day_nano() { // from interval day time to interval month day nano let array = vec![123]; - let casted_array = cast_from_interval_day_time_to_interval_month_day_nano( - array, - &CastOptions::default(), - ) - .unwrap(); + let casted_array = + cast_from_interval_day_time_to_interval_month_day_nano(array, &CastOptions::default()) + .unwrap(); assert_eq!( casted_array.data_type(), &DataType::Interval(IntervalUnit::MonthDayNano) @@ -9017,8 +8706,7 @@ mod tests { .map(|ts| ts / 1_000_000) .collect::>(); - let array = - TimestampMillisecondArray::from(ts_array).with_timezone("UTC".to_string()); + let array = TimestampMillisecondArray::from(ts_array).with_timezone("UTC".to_string()); let casted_array = cast(&array, &DataType::Date32).unwrap(); let date_array = casted_array.as_primitive::(); let casted_array = cast(&date_array, &DataType::Utf8).unwrap(); diff --git a/arrow-cast/src/display.rs b/arrow-cast/src/display.rs index 246135e114bc..28c29c94bbdb 100644 --- a/arrow-cast/src/display.rs +++ b/arrow-cast/src/display.rs @@ -129,10 +129,7 @@ impl<'a> FormatOptions<'a> { } /// Overrides the format used for [`DataType::Timestamp`] columns with a timezone - pub const fn with_timestamp_tz_format( - self, - timestamp_tz_format: Option<&'a str>, - ) -> Self { + pub const fn with_timestamp_tz_format(self, timestamp_tz_format: Option<&'a str>) -> Self { Self { timestamp_tz_format, ..self @@ -173,9 +170,7 @@ impl<'a> ValueFormatter<'a> { match self.formatter.format.write(self.idx, s) { Ok(_) => Ok(()), Err(FormatError::Arrow(e)) => Err(e), - 
Err(FormatError::Format(_)) => { - Err(ArrowError::CastError("Format error".to_string())) - } + Err(FormatError::Format(_)) => Err(ArrowError::CastError("Format error".to_string())), } } @@ -260,10 +255,7 @@ impl<'a> ArrayFormatter<'a> { /// Returns an [`ArrayFormatter`] that can be used to format `array` /// /// This returns an error if an array of the given data type cannot be formatted - pub fn try_new( - array: &'a dyn Array, - options: &FormatOptions<'a>, - ) -> Result { + pub fn try_new(array: &'a dyn Array, options: &FormatOptions<'a>) -> Result { Ok(Self { format: make_formatter(array, options)?, safe: options.safe, @@ -472,9 +464,7 @@ fn write_timestamp( let date = Utc.from_utc_datetime(&naive).with_timezone(&tz); match format { Some(s) => write!(f, "{}", date.format(s))?, - None => { - write!(f, "{}", date.to_rfc3339_opts(SecondsFormat::AutoSi, true))? - } + None => write!(f, "{}", date.to_rfc3339_opts(SecondsFormat::AutoSi, true))?, } } None => match format { @@ -526,19 +516,11 @@ macro_rules! temporal_display { impl<'a> DisplayIndexState<'a> for &'a PrimitiveArray<$t> { type State = TimeFormat<'a>; - fn prepare( - &self, - options: &FormatOptions<'a>, - ) -> Result { + fn prepare(&self, options: &FormatOptions<'a>) -> Result { Ok(options.$format) } - fn write( - &self, - fmt: &Self::State, - idx: usize, - f: &mut dyn Write, - ) -> FormatResult { + fn write(&self, fmt: &Self::State, idx: usize, f: &mut dyn Write) -> FormatResult { let value = self.value(idx); let naive = $convert(value as _).ok_or_else(|| { ArrowError::CastError(format!( @@ -575,19 +557,11 @@ macro_rules! 
duration_display { impl<'a> DisplayIndexState<'a> for &'a PrimitiveArray<$t> { type State = DurationFormat; - fn prepare( - &self, - options: &FormatOptions<'a>, - ) -> Result { + fn prepare(&self, options: &FormatOptions<'a>) -> Result { Ok(options.duration_format) } - fn write( - &self, - fmt: &Self::State, - idx: usize, - f: &mut dyn Write, - ) -> FormatResult { + fn write(&self, fmt: &Self::State, idx: usize, f: &mut dyn Write) -> FormatResult { let v = self.value(idx); match fmt { DurationFormat::ISO8601 => write!(f, "{}", $convert(v))?, @@ -704,8 +678,7 @@ impl<'a> DisplayIndex for &'a PrimitiveArray { fn write(&self, idx: usize, f: &mut dyn Write) -> FormatResult { let value: u128 = self.value(idx) as u128; - let months_part: i32 = - ((value & 0xFFFFFFFF000000000000000000000000) >> 96) as i32; + let months_part: i32 = ((value & 0xFFFFFFFF000000000000000000000000) >> 96) as i32; let days_part: i32 = ((value & 0xFFFFFFFF0000000000000000) >> 64) as i32; let nanoseconds_part: i64 = (value & 0xFFFFFFFFFFFFFFFF) as i64; @@ -937,10 +910,7 @@ impl<'a> DisplayIndexState<'a> for &'a UnionArray { /// suitable for converting large arrays or record batches. 
/// /// Please see [`ArrayFormatter`] for a more performant interface -pub fn array_value_to_string( - column: &dyn Array, - row: usize, -) -> Result { +pub fn array_value_to_string(column: &dyn Array, row: usize) -> Result { let options = FormatOptions::default().with_display_error(true); let formatter = ArrayFormatter::try_new(column, &options)?; Ok(formatter.value(row).to_string()) @@ -986,12 +956,9 @@ mod tests { // [[a, b, c], [d, e, f], [g, h]] let entry_offsets = [0, 3, 6, 8]; - let map_array = MapArray::new_from_strings( - keys.clone().into_iter(), - &values_data, - &entry_offsets, - ) - .unwrap(); + let map_array = + MapArray::new_from_strings(keys.clone().into_iter(), &values_data, &entry_offsets) + .unwrap(); assert_eq!( "{d: 30, e: 40, f: 50}", array_value_to_string(&map_array, 1).unwrap() @@ -1006,8 +973,7 @@ mod tests { #[test] fn test_array_value_to_string_duration() { let iso_fmt = FormatOptions::new(); - let pretty_fmt = - FormatOptions::new().with_duration_format(DurationFormat::Pretty); + let pretty_fmt = FormatOptions::new().with_duration_format(DurationFormat::Pretty); let array = DurationNanosecondArray::from(vec![ 1, diff --git a/arrow-cast/src/parse.rs b/arrow-cast/src/parse.rs index 3806f0adc5d6..f01b2b4c0d63 100644 --- a/arrow-cast/src/parse.rs +++ b/arrow-cast/src/parse.rs @@ -64,10 +64,7 @@ impl TimestampParser { /// Parses a date of the form `1997-01-31` fn date(&self) -> Option { - if self.mask & 0b1111111111 != 0b1101101111 - || !self.test(4, b'-') - || !self.test(7, b'-') - { + if self.mask & 0b1111111111 != 0b1101101111 || !self.test(4, b'-') || !self.test(7, b'-') { return None; } @@ -173,13 +170,9 @@ impl TimestampParser { /// * "2023-01-01 04:05:06.789 PST", /// /// [IANA timezones]: https://www.iana.org/time-zones -pub fn string_to_datetime( - timezone: &T, - s: &str, -) -> Result, ArrowError> { - let err = |ctx: &str| { - ArrowError::ParseError(format!("Error parsing timestamp from '{s}': {ctx}")) - }; +pub fn 
string_to_datetime(timezone: &T, s: &str) -> Result, ArrowError> { + let err = + |ctx: &str| ArrowError::ParseError(format!("Error parsing timestamp from '{s}': {ctx}")); let bytes = s.as_bytes(); if bytes.len() < 10 { @@ -300,9 +293,8 @@ fn to_timestamp_nanos(dt: NaiveDateTime) -> Result { /// This function does not support parsing strings with a timezone /// or offset specified, as it considers only time since midnight. pub fn string_to_time_nanoseconds(s: &str) -> Result { - let nt = string_to_time(s).ok_or_else(|| { - ArrowError::ParseError(format!("Failed to parse \'{s}\' as time")) - })?; + let nt = string_to_time(s) + .ok_or_else(|| ArrowError::ParseError(format!("Failed to parse \'{s}\' as time")))?; Ok(nt.num_seconds_from_midnight() as i64 * 1_000_000_000 + nt.nanosecond() as i64) } @@ -313,12 +305,8 @@ fn string_to_time(s: &str) -> Option { } let (am, bytes) = match bytes.get(bytes.len() - 3..) { - Some(b" AM" | b" am" | b" Am" | b" aM") => { - (Some(true), &bytes[..bytes.len() - 3]) - } - Some(b" PM" | b" pm" | b" pM" | b" Pm") => { - (Some(false), &bytes[..bytes.len() - 3]) - } + Some(b" AM" | b" am" | b" Am" | b" aM") => (Some(true), &bytes[..bytes.len() - 3]), + Some(b" PM" | b" pm" | b" pM" | b" Pm") => (Some(false), &bytes[..bytes.len() - 3]), _ => (None, bytes), }; @@ -501,10 +489,7 @@ impl Parser for Time64NanosecondType { fn parse_formatted(string: &str, format: &str) -> Option { let nt = NaiveTime::parse_from_str(string, format).ok()?; - Some( - nt.num_seconds_from_midnight() as i64 * 1_000_000_000 - + nt.nanosecond() as i64, - ) + Some(nt.num_seconds_from_midnight() as i64 * 1_000_000_000 + nt.nanosecond() as i64) } } @@ -519,10 +504,7 @@ impl Parser for Time64MicrosecondType { fn parse_formatted(string: &str, format: &str) -> Option { let nt = NaiveTime::parse_from_str(string, format).ok()?; - Some( - nt.num_seconds_from_midnight() as i64 * 1_000_000 - + nt.nanosecond() as i64 / 1_000, - ) + Some(nt.num_seconds_from_midnight() as i64 * 
1_000_000 + nt.nanosecond() as i64 / 1_000) } } @@ -537,10 +519,7 @@ impl Parser for Time32MillisecondType { fn parse_formatted(string: &str, format: &str) -> Option { let nt = NaiveTime::parse_from_str(string, format).ok()?; - Some( - nt.num_seconds_from_midnight() as i32 * 1_000 - + nt.nanosecond() as i32 / 1_000_000, - ) + Some(nt.num_seconds_from_midnight() as i32 * 1_000 + nt.nanosecond() as i32 / 1_000_000) } } @@ -555,10 +534,7 @@ impl Parser for Time32SecondType { fn parse_formatted(string: &str, format: &str) -> Option { let nt = NaiveTime::parse_from_str(string, format).ok()?; - Some( - nt.num_seconds_from_midnight() as i32 - + nt.nanosecond() as i32 / 1_000_000_000, - ) + Some(nt.num_seconds_from_midnight() as i32 + nt.nanosecond() as i32 / 1_000_000_000) } } @@ -615,10 +591,8 @@ fn parse_date(string: &str) -> Option { _ => return None, }; - let year = digits[0] as u16 * 1000 - + digits[1] as u16 * 100 - + digits[2] as u16 * 10 - + digits[3] as u16; + let year = + digits[0] as u16 * 1000 + digits[1] as u16 * 100 + digits[2] as u16 * 10 + digits[3] as u16; NaiveDate::from_ymd_opt(year as _, month as _, day as _) } @@ -728,8 +702,7 @@ pub fn parse_decimal( fractionals += 1; digits += 1; result = result.mul_wrapping(base); - result = - result.add_wrapping(T::Native::usize_as((b - b'0') as usize)); + result = result.add_wrapping(T::Native::usize_as((b - b'0') as usize)); } // Fail on "." @@ -771,9 +744,11 @@ pub fn parse_interval_year_month( let config = IntervalParseConfig::new(IntervalUnit::Year); let interval = Interval::parse(value, &config)?; - let months = interval.to_year_months().map_err(|_| ArrowError::CastError(format!( + let months = interval.to_year_months().map_err(|_| { + ArrowError::CastError(format!( "Cannot cast {value} to IntervalYearMonth. Only year and month fields are allowed." 
- )))?; + )) + })?; Ok(IntervalYearMonthType::make_value(0, months)) } @@ -888,21 +863,16 @@ impl FromStr for IntervalAmount { Ok(0) } else { integer.parse::().map_err(|_| { - ArrowError::ParseError(format!( - "Failed to parse {s} as interval amount" - )) + ArrowError::ParseError(format!("Failed to parse {s} as interval amount")) }) }?; let frac_unscaled = frac.parse::().map_err(|_| { - ArrowError::ParseError(format!( - "Failed to parse {s} as interval amount" - )) + ArrowError::ParseError(format!("Failed to parse {s} as interval amount")) })?; // scale fractional part by interval precision - let frac = - frac_unscaled * 10_i64.pow(INTERVAL_PRECISION - frac.len() as u32); + let frac = frac_unscaled * 10_i64.pow(INTERVAL_PRECISION - frac.len() as u32); // propagate the sign of the integer part to the fractional part let frac = if integer < 0 || explicit_neg { @@ -915,9 +885,9 @@ impl FromStr for IntervalAmount { Ok(result) } - Some((_, frac)) if frac.starts_with('-') => Err(ArrowError::ParseError( - format!("Failed to parse {s} as interval amount"), - )), + Some((_, frac)) if frac.starts_with('-') => Err(ArrowError::ParseError(format!( + "Failed to parse {s} as interval amount" + ))), Some((_, frac)) if frac.len() > INTERVAL_PRECISION as usize => { Err(ArrowError::ParseError(format!( "{s} exceeds the precision available for interval amount" @@ -925,9 +895,7 @@ impl FromStr for IntervalAmount { } Some(_) | None => { let integer = s.parse::().map_err(|_| { - ArrowError::ParseError(format!( - "Failed to parse {s} as interval amount" - )) + ArrowError::ParseError(format!("Failed to parse {s} as interval amount")) })?; let result = Self { integer, frac: 0 }; @@ -1005,25 +973,20 @@ impl Interval { /// e.g. INTERVAL '0.5 MONTH' = 15 days, INTERVAL '1.5 MONTH' = 1 month 15 days /// e.g. 
INTERVAL '0.5 DAY' = 12 hours, INTERVAL '1.5 DAY' = 1 day 12 hours /// [Postgres reference](https://www.postgresql.org/docs/15/datatype-datetime.html#DATATYPE-INTERVAL-INPUT:~:text=Field%20values%20can,fractional%20on%20output.) - fn add( - &self, - amount: IntervalAmount, - unit: IntervalUnit, - ) -> Result { + fn add(&self, amount: IntervalAmount, unit: IntervalUnit) -> Result { let result = match unit { IntervalUnit::Century => { let months_int = amount.integer.mul_checked(100)?.mul_checked(12)?; let month_frac = amount.frac * 12 / 10_i64.pow(INTERVAL_PRECISION - 2); - let months = - months_int - .add_checked(month_frac)? - .try_into() - .map_err(|_| { - ArrowError::ParseError(format!( - "Unable to represent {} centuries as months in a signed 32-bit integer", - &amount.integer - )) - })?; + let months = months_int + .add_checked(month_frac)? + .try_into() + .map_err(|_| { + ArrowError::ParseError(format!( + "Unable to represent {} centuries as months in a signed 32-bit integer", + &amount.integer + )) + })?; Self::new(self.months.add_checked(months)?, self.days, self.nanos) } @@ -1031,32 +994,30 @@ impl Interval { let months_int = amount.integer.mul_checked(10)?.mul_checked(12)?; let month_frac = amount.frac * 12 / 10_i64.pow(INTERVAL_PRECISION - 1); - let months = - months_int - .add_checked(month_frac)? - .try_into() - .map_err(|_| { - ArrowError::ParseError(format!( - "Unable to represent {} decades as months in a signed 32-bit integer", - &amount.integer - )) - })?; + let months = months_int + .add_checked(month_frac)? + .try_into() + .map_err(|_| { + ArrowError::ParseError(format!( + "Unable to represent {} decades as months in a signed 32-bit integer", + &amount.integer + )) + })?; Self::new(self.months.add_checked(months)?, self.days, self.nanos) } IntervalUnit::Year => { let months_int = amount.integer.mul_checked(12)?; let month_frac = amount.frac * 12 / 10_i64.pow(INTERVAL_PRECISION); - let months = - months_int - .add_checked(month_frac)? 
- .try_into() - .map_err(|_| { - ArrowError::ParseError(format!( - "Unable to represent {} years as months in a signed 32-bit integer", - &amount.integer - )) - })?; + let months = months_int + .add_checked(month_frac)? + .try_into() + .map_err(|_| { + ArrowError::ParseError(format!( + "Unable to represent {} years as months in a signed 32-bit integer", + &amount.integer + )) + })?; Self::new(self.months.add_checked(months)?, self.days, self.nanos) } @@ -1090,8 +1051,7 @@ impl Interval { )) })?; - let nanos = - amount.frac * 7 * 24 * 6 * 6 / 10_i64.pow(INTERVAL_PRECISION - 11); + let nanos = amount.frac * 7 * 24 * 6 * 6 / 10_i64.pow(INTERVAL_PRECISION - 11); Self::new( self.months, @@ -1107,8 +1067,7 @@ impl Interval { )) })?; - let nanos = - amount.frac * 24 * 6 * 6 / 10_i64.pow(INTERVAL_PRECISION - 11); + let nanos = amount.frac * 24 * 6 * 6 / 10_i64.pow(INTERVAL_PRECISION - 11); Self::new( self.months, @@ -1118,8 +1077,7 @@ impl Interval { } IntervalUnit::Hour => { let nanos_int = amount.integer.mul_checked(NANOS_PER_HOUR)?; - let nanos_frac = - amount.frac * 6 * 6 / 10_i64.pow(INTERVAL_PRECISION - 11); + let nanos_frac = amount.frac * 6 * 6 / 10_i64.pow(INTERVAL_PRECISION - 11); let nanos = nanos_int.add_checked(nanos_frac)?; Interval::new(self.months, self.days, self.nanos.add_checked(nanos)?) 
@@ -1398,8 +1356,7 @@ mod tests { "2030-12-04T17:11:10.123456", ]; for case in cases { - let chrono = - NaiveDateTime::parse_from_str(case, "%Y-%m-%dT%H:%M:%S%.f").unwrap(); + let chrono = NaiveDateTime::parse_from_str(case, "%Y-%m-%dT%H:%M:%S%.f").unwrap(); let custom = string_to_datetime(&Utc, case).unwrap(); assert_eq!(chrono, custom.naive_utc()) } @@ -1431,8 +1388,7 @@ mod tests { ]; for (s, ctx) in cases { - let expected = - format!("Parser error: Error parsing timestamp from '{s}': {ctx}"); + let expected = format!("Parser error: Error parsing timestamp from '{s}': {ctx}"); let actual = string_to_datetime(&Utc, s).unwrap_err().to_string(); assert_eq!(actual, expected) } @@ -1497,8 +1453,7 @@ mod tests { assert_eq!(local, "2020-09-08 15:42:29"); let dt = - NaiveDateTime::parse_from_str("2020-09-08T13:42:29Z", "%Y-%m-%dT%H:%M:%SZ") - .unwrap(); + NaiveDateTime::parse_from_str("2020-09-08T13:42:29Z", "%Y-%m-%dT%H:%M:%SZ").unwrap(); let local: Tz = "+08:00".parse().unwrap(); // Parsed as offset from UTC @@ -1629,10 +1584,7 @@ mod tests { // custom format assert_eq!( - Time64NanosecondType::parse_formatted( - "02 - 10 - 01 - .1234567", - "%H - %M - %S - %.f" - ), + Time64NanosecondType::parse_formatted("02 - 10 - 01 - .1234567", "%H - %M - %S - %.f"), Some(7_801_123_456_700) ); } @@ -1709,10 +1661,7 @@ mod tests { // custom format assert_eq!( - Time64MicrosecondType::parse_formatted( - "02 - 10 - 01 - .1234", - "%H - %M - %S - %.f" - ), + Time64MicrosecondType::parse_formatted("02 - 10 - 01 - .1234", "%H - %M - %S - %.f"), Some(7_801_123_400) ); } @@ -1759,10 +1708,7 @@ mod tests { // custom format assert_eq!( - Time32MillisecondType::parse_formatted( - "02 - 10 - 01 - .1", - "%H - %M - %S - %.f" - ), + Time32MillisecondType::parse_formatted("02 - 10 - 01 - .1", "%H - %M - %S - %.f"), Some(7_801_100) ); } @@ -2005,8 +1951,19 @@ mod tests { ); assert_eq!( - Interval::new(-13i32, -8i32, -NANOS_PER_HOUR - NANOS_PER_MINUTE - NANOS_PER_SECOND - (1.11_f64 * 
NANOS_PER_MILLIS as f64) as i64), - Interval::parse("-1 year -1 month -1 week -1 day -1 hour -1 minute -1 second -1.11 millisecond", &config).unwrap(), + Interval::new( + -13i32, + -8i32, + -NANOS_PER_HOUR + - NANOS_PER_MINUTE + - NANOS_PER_SECOND + - (1.11_f64 * NANOS_PER_MILLIS as f64) as i64 + ), + Interval::parse( + "-1 year -1 month -1 week -1 day -1 hour -1 minute -1 second -1.11 millisecond", + &config + ) + .unwrap(), ); } @@ -2280,22 +2237,34 @@ mod tests { let edge_tests_256 = [ ( "9999999999999999999999999999999999999999999999999999999999999999999999999999", -i256::from_string("9999999999999999999999999999999999999999999999999999999999999999999999999999").unwrap(), + i256::from_string( + "9999999999999999999999999999999999999999999999999999999999999999999999999999", + ) + .unwrap(), 0, ), ( "999999999999999999999999999999999999999999999999999999999999999999999999.9999", - i256::from_string("9999999999999999999999999999999999999999999999999999999999999999999999999999").unwrap(), + i256::from_string( + "9999999999999999999999999999999999999999999999999999999999999999999999999999", + ) + .unwrap(), 4, ), ( "99999999999999999999999999999999999999999999999999.99999999999999999999999999", - i256::from_string("9999999999999999999999999999999999999999999999999999999999999999999999999999").unwrap(), + i256::from_string( + "9999999999999999999999999999999999999999999999999999999999999999999999999999", + ) + .unwrap(), 26, ), ( "99999999999999999999999999999999999999999999999999", - i256::from_string("9999999999999999999999999999999999999999999999999900000000000000000000000000").unwrap(), + i256::from_string( + "9999999999999999999999999999999999999999999999999900000000000000000000000000", + ) + .unwrap(), 26, ), ]; diff --git a/arrow-cast/src/pretty.rs b/arrow-cast/src/pretty.rs index 59a9f9d605e2..550afa9f739d 100644 --- a/arrow-cast/src/pretty.rs +++ b/arrow-cast/src/pretty.rs @@ -25,9 +25,7 @@ use comfy_table::{Cell, Table}; use std::fmt::Display; /// Create a 
visual representation of record batches -pub fn pretty_format_batches( - results: &[RecordBatch], -) -> Result { +pub fn pretty_format_batches(results: &[RecordBatch]) -> Result { let options = FormatOptions::default().with_display_error(true); pretty_format_batches_with_options(results, &options) } @@ -70,10 +68,7 @@ pub fn print_columns(col_name: &str, results: &[ArrayRef]) -> Result<(), ArrowEr } /// Convert a series of record batches into a table -fn create_table( - results: &[RecordBatch], - options: &FormatOptions, -) -> Result { +fn create_table(results: &[RecordBatch], options: &FormatOptions) -> Result { let mut table = Table::new(); table.load_preset("||--+-++| ++++++"); @@ -209,8 +204,8 @@ mod tests { let table = pretty_format_columns("a", &columns).unwrap().to_string(); let expected = vec![ - "+---+", "| a |", "+---+", "| a |", "| b |", "| |", "| d |", "| e |", - "| |", "| g |", "+---+", + "+---+", "| a |", "+---+", "| a |", "| b |", "| |", "| d |", "| e |", "| |", + "| g |", "+---+", ]; let actual: Vec<&str> = table.lines().collect(); @@ -289,10 +284,8 @@ mod tests { #[test] fn test_pretty_format_fixed_size_list() { // define a schema. 
- let field_type = DataType::FixedSizeList( - Arc::new(Field::new("item", DataType::Int32, true)), - 3, - ); + let field_type = + DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Int32, true)), 3); let schema = Arc::new(Schema::new(vec![Field::new("d1", field_type, true)])); let keys_builder = Int32Array::builder(3); @@ -383,10 +376,7 @@ mod tests { }; } - fn timestamp_batch( - timezone: &str, - value: T::Native, - ) -> RecordBatch { + fn timestamp_batch(timezone: &str, value: T::Native) -> RecordBatch { let mut builder = PrimitiveBuilder::::with_capacity(10); builder.append_value(value); builder.append_null(); @@ -621,8 +611,8 @@ mod tests { let table = pretty_format_batches(&[batch]).unwrap().to_string(); let expected = vec![ - "+------+", "| f |", "+------+", "| 101 |", "| |", "| 200 |", - "| 3040 |", "+------+", + "+------+", "| f |", "+------+", "| 101 |", "| |", "| 200 |", "| 3040 |", + "+------+", ]; let actual: Vec<&str> = table.lines().collect(); @@ -660,16 +650,14 @@ mod tests { )), Arc::new(StructArray::from(vec![( Arc::new(Field::new("c121", DataType::Utf8, false)), - Arc::new(StringArray::from(vec![Some("e"), Some("f"), Some("g")])) - as ArrayRef, + Arc::new(StringArray::from(vec![Some("e"), Some("f"), Some("g")])) as ArrayRef, )])) as ArrayRef, ), ]); let c2 = StringArray::from(vec![Some("a"), Some("b"), Some("c")]); let batch = - RecordBatch::try_new(Arc::new(schema), vec![Arc::new(c1), Arc::new(c2)]) - .unwrap(); + RecordBatch::try_new(Arc::new(schema), vec![Arc::new(c1), Arc::new(c2)]).unwrap(); let table = pretty_format_batches(&[batch]).unwrap().to_string(); let expected = vec![ @@ -705,8 +693,7 @@ mod tests { UnionMode::Dense, )]); - let batch = - RecordBatch::try_new(Arc::new(schema), vec![Arc::new(union)]).unwrap(); + let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(union)]).unwrap(); let table = pretty_format_batches(&[batch]).unwrap().to_string(); let actual: Vec<&str> = table.lines().collect(); let expected = 
vec![ @@ -742,8 +729,7 @@ mod tests { UnionMode::Sparse, )]); - let batch = - RecordBatch::try_new(Arc::new(schema), vec![Arc::new(union)]).unwrap(); + let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(union)]).unwrap(); let table = pretty_format_batches(&[batch]).unwrap().to_string(); let actual: Vec<&str> = table.lines().collect(); let expected = vec![ @@ -799,8 +785,7 @@ mod tests { UnionMode::Sparse, )]); - let batch = - RecordBatch::try_new(Arc::new(schema), vec![Arc::new(outer)]).unwrap(); + let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(outer)]).unwrap(); let table = pretty_format_batches(&[batch]).unwrap().to_string(); let actual: Vec<&str> = table.lines().collect(); let expected = vec![ @@ -882,8 +867,7 @@ mod tests { let table = pretty_format_batches(&[batch]).unwrap().to_string(); let expected = vec![ - "+------+", "| f16 |", "+------+", "| NaN |", "| 4 |", "| -inf |", - "+------+", + "+------+", "| f16 |", "+------+", "| NaN |", "| 4 |", "| -inf |", "+------+", ]; let actual: Vec<&str> = table.lines().collect(); @@ -986,9 +970,7 @@ mod tests { fn test_format_options() { let options = FormatOptions::default().with_null("null"); let array = Int32Array::from(vec![Some(1), Some(2), None, Some(3), Some(4)]); - let batch = - RecordBatch::try_from_iter([("my_column_name", Arc::new(array) as _)]) - .unwrap(); + let batch = RecordBatch::try_from_iter([("my_column_name", Arc::new(array) as _)]).unwrap(); let column = pretty_format_columns_with_options( "my_column_name", diff --git a/arrow-csv/src/reader/mod.rs b/arrow-csv/src/reader/mod.rs index a194b35ffa46..83c8965fdf8a 100644 --- a/arrow-csv/src/reader/mod.rs +++ b/arrow-csv/src/reader/mod.rs @@ -292,8 +292,7 @@ impl Format { let header_length = headers.len(); // keep track of inferred field types - let mut column_types: Vec = - vec![Default::default(); header_length]; + let mut column_types: Vec = vec![Default::default(); header_length]; let mut records_count = 0; @@ -307,9 
+306,7 @@ impl Format { // Note since we may be looking at a sample of the data, we make the safe assumption that // they could be nullable - for (i, column_type) in - column_types.iter_mut().enumerate().take(header_length) - { + for (i, column_type) in column_types.iter_mut().enumerate().take(header_length) { if let Some(string) = record.get(i) { if !self.null_regex.is_null(string) { column_type.update(string) @@ -606,8 +603,7 @@ impl Decoder { return Ok(bytes); } - let to_read = - self.batch_size.min(self.end - self.line_number) - self.record_decoder.len(); + let to_read = self.batch_size.min(self.end - self.line_number) - self.record_decoder.len(); let (_, bytes) = self.record_decoder.decode(buf, to_read)?; Ok(bytes) } @@ -662,29 +658,23 @@ fn parse( let i = *i; let field = &fields[i]; match field.data_type() { - DataType::Boolean => { - build_boolean_array(line_number, rows, i, null_regex) - } - DataType::Decimal128(precision, scale) => { - build_decimal_array::( - line_number, - rows, - i, - *precision, - *scale, - null_regex, - ) - } - DataType::Decimal256(precision, scale) => { - build_decimal_array::( - line_number, - rows, - i, - *precision, - *scale, - null_regex, - ) - } + DataType::Boolean => build_boolean_array(line_number, rows, i, null_regex), + DataType::Decimal128(precision, scale) => build_decimal_array::( + line_number, + rows, + i, + *precision, + *scale, + null_regex, + ), + DataType::Decimal256(precision, scale) => build_decimal_array::( + line_number, + rows, + i, + *precision, + *scale, + null_regex, + ), DataType::Int8 => { build_primitive_array::(line_number, rows, i, null_regex) } @@ -721,34 +711,17 @@ fn parse( DataType::Date64 => { build_primitive_array::(line_number, rows, i, null_regex) } - DataType::Time32(TimeUnit::Second) => build_primitive_array::< - Time32SecondType, - >( - line_number, rows, i, null_regex - ), + DataType::Time32(TimeUnit::Second) => { + build_primitive_array::(line_number, rows, i, null_regex) + } 
DataType::Time32(TimeUnit::Millisecond) => { - build_primitive_array::( - line_number, - rows, - i, - null_regex, - ) + build_primitive_array::(line_number, rows, i, null_regex) } DataType::Time64(TimeUnit::Microsecond) => { - build_primitive_array::( - line_number, - rows, - i, - null_regex, - ) + build_primitive_array::(line_number, rows, i, null_regex) } DataType::Time64(TimeUnit::Nanosecond) => { - build_primitive_array::( - line_number, - rows, - i, - null_regex, - ) + build_primitive_array::(line_number, rows, i, null_regex) } DataType::Timestamp(TimeUnit::Second, tz) => { build_timestamp_array::( @@ -786,9 +759,7 @@ fn parse( null_regex, ) } - DataType::Null => { - Ok(Arc::new(NullArray::builder(rows.len()).finish()) as ArrayRef) - } + DataType::Null => Ok(Arc::new(NullArray::builder(rows.len()).finish()) as ArrayRef), DataType::Utf8 => Ok(Arc::new( rows.iter() .map(|row| { @@ -853,8 +824,7 @@ fn parse( }) .collect(); - let projected_fields: Fields = - projection.iter().map(|i| fields[*i].clone()).collect(); + let projected_fields: Fields = projection.iter().map(|i| fields[*i].clone()).collect(); let projected_schema = Arc::new(match metadata { None => Schema::new(projected_fields), @@ -898,8 +868,7 @@ fn build_decimal_array( // append null decimal_builder.append_null(); } else { - let decimal_value: Result = - parse_decimal::(s, precision, scale); + let decimal_value: Result = parse_decimal::(s, precision, scale); match decimal_value { Ok(v) => { decimal_builder.append_value(v); @@ -957,22 +926,10 @@ fn build_timestamp_array( Ok(Arc::new(match timezone { Some(timezone) => { let tz: Tz = timezone.parse()?; - build_timestamp_array_impl::( - line_number, - rows, - col_idx, - &tz, - null_regex, - )? - .with_timezone(timezone) + build_timestamp_array_impl::(line_number, rows, col_idx, &tz, null_regex)? 
+ .with_timezone(timezone) } - None => build_timestamp_array_impl::( - line_number, - rows, - col_idx, - &Utc, - null_regex, - )?, + None => build_timestamp_array_impl::(line_number, rows, col_idx, &Utc, null_regex)?, })) } @@ -1169,10 +1126,7 @@ impl ReaderBuilder { } /// Create a new `BufReader` from a buffered reader - pub fn build_buffered( - self, - reader: R, - ) -> Result, ArrowError> { + pub fn build_buffered(self, reader: R) -> Result, ArrowError> { Ok(BufReader { reader, decoder: self.build_decoder(), @@ -1318,8 +1272,7 @@ mod tests { Field::new("lng", DataType::Float64, false), ]); - let file_with_headers = - File::open("test/data/uk_cities_with_headers.csv").unwrap(); + let file_with_headers = File::open("test/data/uk_cities_with_headers.csv").unwrap(); let file_without_headers = File::open("test/data/uk_cities.csv").unwrap(); let both_files = file_with_headers .chain(Cursor::new("\n".to_string())) @@ -1642,8 +1595,7 @@ mod tests { schema.field(5).data_type() ); - let names: Vec<&str> = - schema.fields().iter().map(|x| x.name().as_str()).collect(); + let names: Vec<&str> = schema.fields().iter().map(|x| x.name().as_str()).collect(); assert_eq!( names, vec![ @@ -1819,16 +1771,11 @@ mod tests { -2203932304000 ); assert_eq!( - Date64Type::parse_formatted("1900-02-28 12:34:56", "%Y-%m-%d %H:%M:%S") - .unwrap(), + Date64Type::parse_formatted("1900-02-28 12:34:56", "%Y-%m-%d %H:%M:%S").unwrap(), -2203932304000 ); assert_eq!( - Date64Type::parse_formatted( - "1900-02-28 12:34:56+0030", - "%Y-%m-%d %H:%M:%S%z" - ) - .unwrap(), + Date64Type::parse_formatted("1900-02-28 12:34:56+0030", "%Y-%m-%d %H:%M:%S%z").unwrap(), -2203932304000 - (30 * 60 * 1000) ); } @@ -1865,10 +1812,7 @@ mod tests { #[test] fn test_parse_timestamp() { - test_parse_timestamp_impl::( - None, - &[0, 0, -7_200_000_000_000], - ); + test_parse_timestamp_impl::(None, &[0, 0, -7_200_000_000_000]); test_parse_timestamp_impl::( Some("+00:00".into()), &[0, 0, -7_200_000_000_000], @@ -1885,10 +1829,7 
@@ mod tests { Some("-03".into()), &[10_800_000, 0, -7_200_000], ); - test_parse_timestamp_impl::( - Some("-03".into()), - &[10_800, 0, -7_200], - ); + test_parse_timestamp_impl::(Some("-03".into()), &[10_800, 0, -7_200]); } #[test] @@ -2227,10 +2168,8 @@ mod tests { expected_rows ); - let buffered = std::io::BufReader::with_capacity( - capacity, - File::open(path).unwrap(), - ); + let buffered = + std::io::BufReader::with_capacity(capacity, File::open(path).unwrap()); let reader = ReaderBuilder::new(schema.clone()) .with_batch_size(batch_size) diff --git a/arrow-csv/src/reader/records.rs b/arrow-csv/src/reader/records.rs index a59d02e0e2d8..877cfb3ee653 100644 --- a/arrow-csv/src/reader/records.rs +++ b/arrow-csv/src/reader/records.rs @@ -76,11 +76,7 @@ impl RecordDecoder { /// Decodes records from `input` returning the number of records and bytes read /// /// Note: this expects to be called with an empty `input` to signal EOF - pub fn decode( - &mut self, - input: &[u8], - to_read: usize, - ) -> Result<(usize, usize), ArrowError> { + pub fn decode(&mut self, input: &[u8], to_read: usize) -> Result<(usize, usize), ArrowError> { if to_read == 0 { return Ok((0, 0)); } @@ -124,11 +120,17 @@ impl RecordDecoder { // Need to allocate more capacity ReadRecordResult::OutputFull => break, ReadRecordResult::OutputEndsFull => { - return Err(ArrowError::CsvError(format!("incorrect number of fields for line {}, expected {} got more than {}", self.line_number, self.num_columns, self.current_field))); + return Err(ArrowError::CsvError(format!( + "incorrect number of fields for line {}, expected {} got more than {}", + self.line_number, self.num_columns, self.current_field + ))); } ReadRecordResult::Record => { if self.current_field != self.num_columns { - return Err(ArrowError::CsvError(format!("incorrect number of fields for line {}, expected {} got {}", self.line_number, self.num_columns, self.current_field))); + return Err(ArrowError::CsvError(format!( + "incorrect number of 
fields for line {}, expected {} got {}", + self.line_number, self.num_columns, self.current_field + ))); } read += 1; self.current_field = 0; @@ -334,8 +336,7 @@ mod tests { let mut decoder = RecordDecoder::new(Reader::new(), 2); let err = decoder.decode(csv.as_bytes(), 4).unwrap_err().to_string(); - let expected = - "Csv error: incorrect number of fields for line 3, expected 2 got 1"; + let expected = "Csv error: incorrect number of fields for line 3, expected 2 got 1"; assert_eq!(err, expected); diff --git a/arrow-csv/src/writer.rs b/arrow-csv/src/writer.rs index 1ca956e2c73f..0bb76e536e67 100644 --- a/arrow-csv/src/writer.rs +++ b/arrow-csv/src/writer.rs @@ -389,18 +389,12 @@ mod tests { "consectetur adipiscing elit", "sed do eiusmod tempor", ]); - let c2 = PrimitiveArray::::from(vec![ - Some(123.564532), - None, - Some(-556132.25), - ]); + let c2 = + PrimitiveArray::::from(vec![Some(123.564532), None, Some(-556132.25)]); let c3 = PrimitiveArray::::from(vec![3, 2, 1]); let c4 = BooleanArray::from(vec![Some(true), Some(false), None]); - let c5 = TimestampMillisecondArray::from(vec![ - None, - Some(1555584887378), - Some(1555555555555), - ]); + let c5 = + TimestampMillisecondArray::from(vec![None, Some(1555584887378), Some(1555555555555)]); let c6 = Time32SecondArray::from(vec![1234, 24680, 85563]); let c7: DictionaryArray = vec!["cupcakes", "cupcakes", "foo"].into_iter().collect(); @@ -451,13 +445,11 @@ sed do eiusmod tempor,-556132.25,1,,2019-04-18T02:45:55.555,23:46:03,foo Field::new("c2", DataType::Decimal256(76, 6), true), ]); - let mut c1_builder = - Decimal128Builder::new().with_data_type(DataType::Decimal128(38, 6)); + let mut c1_builder = Decimal128Builder::new().with_data_type(DataType::Decimal128(38, 6)); c1_builder.extend(vec![Some(-3335724), Some(2179404), None, Some(290472)]); let c1 = c1_builder.finish(); - let mut c2_builder = - Decimal256Builder::new().with_data_type(DataType::Decimal256(76, 6)); + let mut c2_builder = 
Decimal256Builder::new().with_data_type(DataType::Decimal256(76, 6)); c2_builder.extend(vec![ Some(i256::from_i128(-3335724)), Some(i256::from_i128(2179404)), @@ -467,8 +459,7 @@ sed do eiusmod tempor,-556132.25,1,,2019-04-18T02:45:55.555,23:46:03,foo let c2 = c2_builder.finish(); let batch = - RecordBatch::try_new(Arc::new(schema), vec![Arc::new(c1), Arc::new(c2)]) - .unwrap(); + RecordBatch::try_new(Arc::new(schema), vec![Arc::new(c1), Arc::new(c2)]).unwrap(); let mut file = tempfile::tempfile().unwrap(); @@ -512,11 +503,8 @@ sed do eiusmod tempor,-556132.25,1,,2019-04-18T02:45:55.555,23:46:03,foo "consectetur adipiscing elit", "sed do eiusmod tempor", ]); - let c2 = PrimitiveArray::::from(vec![ - Some(123.564532), - None, - Some(-556132.25), - ]); + let c2 = + PrimitiveArray::::from(vec![Some(123.564532), None, Some(-556132.25)]); let c3 = PrimitiveArray::::from(vec![3, 2, 1]); let c4 = BooleanArray::from(vec![Some(true), Some(false), None]); let c6 = Time32SecondArray::from(vec![1234, 24680, 85563]); @@ -629,8 +617,7 @@ sed do eiusmod tempor,-556132.25,1,,2019-04-18T02:45:55.555,23:46:03,foo let c0 = UInt32Array::from(vec![Some(123), Some(234)]); let c1 = Date64Array::from(vec![Some(1926632005177), Some(1926632005177685347)]); let batch = - RecordBatch::try_new(Arc::new(schema), vec![Arc::new(c0), Arc::new(c1)]) - .unwrap(); + RecordBatch::try_new(Arc::new(schema), vec![Arc::new(c0), Arc::new(c1)]).unwrap(); let mut file = tempfile::tempfile().unwrap(); let mut writer = Writer::new(&mut file); @@ -656,15 +643,9 @@ sed do eiusmod tempor,-556132.25,1,,2019-04-18T02:45:55.555,23:46:03,foo Field::new("c4", DataType::Time32(TimeUnit::Second), false), ]); - let c1 = TimestampMillisecondArray::from(vec![ - Some(1555584887378), - Some(1635577147000), - ]) - .with_timezone("+00:00".to_string()); - let c2 = TimestampMillisecondArray::from(vec![ - Some(1555584887378), - Some(1635577147000), - ]); + let c1 = TimestampMillisecondArray::from(vec![Some(1555584887378), 
Some(1635577147000)]) + .with_timezone("+00:00".to_string()); + let c2 = TimestampMillisecondArray::from(vec![Some(1555584887378), Some(1635577147000)]); let c3 = Date32Array::from(vec![3, 2]); let c4 = Time32SecondArray::from(vec![1234, 24680]); diff --git a/arrow-data/src/data.rs b/arrow-data/src/data.rs index 5f87dddd4217..10c53c549e2b 100644 --- a/arrow-data/src/data.rs +++ b/arrow-data/src/data.rs @@ -42,9 +42,7 @@ pub(crate) fn contains_nulls( ) -> bool { match null_bit_buffer { Some(buffer) => { - match BitSliceIterator::new(buffer.validity(), buffer.offset() + offset, len) - .next() - { + match BitSliceIterator::new(buffer.validity(), buffer.offset() + offset, len).next() { Some((start, end)) => start != 0 || end != len, None => len != 0, // No non-null values } @@ -130,9 +128,9 @@ pub(crate) fn new_buffers(data_type: &DataType, capacity: usize) -> [MutableBuff MutableBuffer::new(capacity * k.primitive_width().unwrap()), empty_buffer, ], - DataType::FixedSizeList(_, _) - | DataType::Struct(_) - | DataType::RunEndEncoded(_, _) => [empty_buffer, MutableBuffer::new(0)], + DataType::FixedSizeList(_, _) | DataType::Struct(_) | DataType::RunEndEncoded(_, _) => { + [empty_buffer, MutableBuffer::new(0)] + } DataType::Decimal128(_, _) | DataType::Decimal256(_, _) => [ MutableBuffer::new(capacity * mem::size_of::()), empty_buffer, @@ -159,10 +157,9 @@ pub(crate) fn into_buffers( ) -> Vec { match data_type { DataType::Null | DataType::Struct(_) | DataType::FixedSizeList(_, _) => vec![], - DataType::Utf8 - | DataType::Binary - | DataType::LargeUtf8 - | DataType::LargeBinary => vec![buffer1.into(), buffer2.into()], + DataType::Utf8 | DataType::Binary | DataType::LargeUtf8 | DataType::LargeBinary => { + vec![buffer1.into(), buffer2.into()] + } DataType::Union(_, mode) => { match mode { // Based on Union's DataTypeLayout @@ -452,12 +449,11 @@ impl ArrayData { for spec in layout.buffers.iter() { match spec { BufferSpec::FixedWidth { byte_width, .. 
} => { - let buffer_size = - self.len.checked_mul(*byte_width).ok_or_else(|| { - ArrowError::ComputeError( - "Integer overflow computing buffer size".to_string(), - ) - })?; + let buffer_size = self.len.checked_mul(*byte_width).ok_or_else(|| { + ArrowError::ComputeError( + "Integer overflow computing buffer size".to_string(), + ) + })?; result += buffer_size; } BufferSpec::VariableWidth => { @@ -590,9 +586,7 @@ impl ArrayData { DataType::LargeBinary | DataType::LargeUtf8 => { (vec![zeroed((len + 1) * 8), zeroed(0)], vec![], true) } - DataType::FixedSizeBinary(i) => { - (vec![zeroed(*i as usize * len)], vec![], true) - } + DataType::FixedSizeBinary(i) => (vec![zeroed(*i as usize * len)], vec![], true), DataType::List(f) | DataType::Map(f, _) => ( vec![zeroed((len + 1) * 4)], vec![ArrayData::new_empty(f.data_type())], @@ -749,9 +743,7 @@ impl ArrayData { ))); } - for (i, (buffer, spec)) in - self.buffers.iter().zip(layout.buffers.iter()).enumerate() - { + for (i, (buffer, spec)) in self.buffers.iter().zip(layout.buffers.iter()).enumerate() { match spec { BufferSpec::FixedWidth { byte_width, @@ -999,10 +991,8 @@ impl ArrayData { } DataType::RunEndEncoded(run_ends_field, values_field) => { self.validate_num_child_data(2)?; - let run_ends_data = - self.get_valid_child_data(0, run_ends_field.data_type())?; - let values_data = - self.get_valid_child_data(1, values_field.data_type())?; + let run_ends_data = self.get_valid_child_data(0, run_ends_field.data_type())?; + let values_data = self.get_valid_child_data(1, values_field.data_type())?; if run_ends_data.len != values_data.len { return Err(ArrowError::InvalidArgumentError(format!( "The run_ends array length should be the same as values array length. 
Run_ends array length is {}, values array length is {}", @@ -1022,9 +1012,7 @@ impl ArrayData { for (i, (_, field)) in fields.iter().enumerate() { let field_data = self.get_valid_child_data(i, field.data_type())?; - if mode == &UnionMode::Sparse - && field_data.len < (self.len + self.offset) - { + if mode == &UnionMode::Sparse && field_data.len < (self.len + self.offset) { return Err(ArrowError::InvalidArgumentError(format!( "Sparse union child array #{} has length smaller than expected for union array ({} < {})", i, field_data.len, self.len + self.offset @@ -1083,14 +1071,14 @@ impl ArrayData { i: usize, expected_type: &DataType, ) -> Result<&ArrayData, ArrowError> { - let values_data = self.child_data - .get(i) - .ok_or_else(|| { - ArrowError::InvalidArgumentError(format!( - "{} did not have enough child arrays. Expected at least {} but had only {}", - self.data_type, i+1, self.child_data.len() - )) - })?; + let values_data = self.child_data.get(i).ok_or_else(|| { + ArrowError::InvalidArgumentError(format!( + "{} did not have enough child arrays. 
Expected at least {} but had only {}", + self.data_type, + i + 1, + self.child_data.len() + )) + })?; if expected_type != &values_data.data_type { return Err(ArrowError::InvalidArgumentError(format!( @@ -1160,7 +1148,8 @@ impl ArrayData { if actual != nulls.null_count() { return Err(ArrowError::InvalidArgumentError(format!( "null_count value ({}) doesn't match actual number of nulls in array ({})", - nulls.null_count(), actual + nulls.null_count(), + actual ))); } } @@ -1209,23 +1198,22 @@ impl ArrayData { ) -> Result<(), ArrowError> { let mask = match mask { Some(mask) => mask, - None => return match child.null_count() { - 0 => Ok(()), - _ => Err(ArrowError::InvalidArgumentError(format!( - "non-nullable child of type {} contains nulls not present in parent {}", - child.data_type, - self.data_type - ))), - }, + None => { + return match child.null_count() { + 0 => Ok(()), + _ => Err(ArrowError::InvalidArgumentError(format!( + "non-nullable child of type {} contains nulls not present in parent {}", + child.data_type, self.data_type + ))), + } + } }; match child.nulls() { - Some(nulls) if !mask.contains(nulls) => { - Err(ArrowError::InvalidArgumentError(format!( - "non-nullable child of type {} contains nulls not present in parent", - child.data_type - ))) - } + Some(nulls) if !mask.contains(nulls) => Err(ArrowError::InvalidArgumentError(format!( + "non-nullable child of type {} contains nulls not present in parent", + child.data_type + ))), _ => Ok(()), } } @@ -1240,9 +1228,7 @@ impl ArrayData { DataType::Utf8 => self.validate_utf8::(), DataType::LargeUtf8 => self.validate_utf8::(), DataType::Binary => self.validate_offsets_full::(self.buffers[1].len()), - DataType::LargeBinary => { - self.validate_offsets_full::(self.buffers[1].len()) - } + DataType::LargeBinary => self.validate_offsets_full::(self.buffers[1].len()), DataType::List(_) | DataType::Map(_, _) => { let child = &self.child_data[0]; self.validate_offsets_full::(child.len) @@ -1300,11 +1286,7 @@ impl 
ArrayData { /// /// For example, the offsets buffer contained `[1, 2, 4]`, this /// function would call `validate([1,2])`, and `validate([2,4])` - fn validate_each_offset( - &self, - offset_limit: usize, - validate: V, - ) -> Result<(), ArrowError> + fn validate_each_offset(&self, offset_limit: usize, validate: V) -> Result<(), ArrowError> where T: ArrowNativeType + TryInto + num::Num + std::fmt::Display, V: Fn(usize, Range) -> Result<(), ArrowError>, @@ -1358,32 +1340,26 @@ impl ArrayData { let values_buffer = &self.buffers[1].as_slice(); if let Ok(values_str) = std::str::from_utf8(values_buffer) { // Validate Offsets are correct - self.validate_each_offset::( - values_buffer.len(), - |string_index, range| { - if !values_str.is_char_boundary(range.start) - || !values_str.is_char_boundary(range.end) - { - return Err(ArrowError::InvalidArgumentError(format!( - "incomplete utf-8 byte sequence from index {string_index}" - ))); - } - Ok(()) - }, - ) + self.validate_each_offset::(values_buffer.len(), |string_index, range| { + if !values_str.is_char_boundary(range.start) + || !values_str.is_char_boundary(range.end) + { + return Err(ArrowError::InvalidArgumentError(format!( + "incomplete utf-8 byte sequence from index {string_index}" + ))); + } + Ok(()) + }) } else { // find specific offset that failed utf8 validation - self.validate_each_offset::( - values_buffer.len(), - |string_index, range| { - std::str::from_utf8(&values_buffer[range.clone()]).map_err(|e| { - ArrowError::InvalidArgumentError(format!( - "Invalid UTF8 sequence at string index {string_index} ({range:?}): {e}" - )) - })?; - Ok(()) - }, - ) + self.validate_each_offset::(values_buffer.len(), |string_index, range| { + std::str::from_utf8(&values_buffer[range.clone()]).map_err(|e| { + ArrowError::InvalidArgumentError(format!( + "Invalid UTF8 sequence at string index {string_index} ({range:?}): {e}" + )) + })?; + Ok(()) + }) } } @@ -1414,8 +1390,7 @@ impl ArrayData { assert!(buffer.len() / mem::size_of::() >= 
required_len); // Justification: buffer size was validated above - let indexes: &[T] = - &buffer.typed_data::()[self.offset..self.offset + self.len]; + let indexes: &[T] = &buffer.typed_data::()[self.offset..self.offset + self.len]; indexes.iter().enumerate().try_for_each(|(i, &dict_index)| { // Do not check the value is null (value can be arbitrary) diff --git a/arrow-data/src/decimal.rs b/arrow-data/src/decimal.rs index f74ab880d478..74279bfb9af1 100644 --- a/arrow-data/src/decimal.rs +++ b/arrow-data/src/decimal.rs @@ -19,8 +19,8 @@ use arrow_buffer::i256; use arrow_schema::ArrowError; pub use arrow_schema::{ - DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE, DECIMAL256_MAX_PRECISION, - DECIMAL256_MAX_SCALE, DECIMAL_DEFAULT_SCALE, + DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE, DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE, + DECIMAL_DEFAULT_SCALE, }; // MAX decimal256 value of little-endian format for each precision. @@ -28,308 +28,308 @@ pub use arrow_schema::{ // is encoded to the 32-byte width format of little-endian. 
pub(crate) const MAX_DECIMAL_BYTES_FOR_LARGER_EACH_PRECISION: [i256; 76] = [ i256::from_le_bytes([ - 9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, + 9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, ]), i256::from_le_bytes([ - 99, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, + 99, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, ]), i256::from_le_bytes([ - 231, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, + 231, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, ]), i256::from_le_bytes([ - 15, 39, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, + 15, 39, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, ]), i256::from_le_bytes([ - 159, 134, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, + 159, 134, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, ]), i256::from_le_bytes([ - 63, 66, 15, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, + 63, 66, 15, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, ]), i256::from_le_bytes([ - 127, 150, 152, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, + 127, 150, 152, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, ]), i256::from_le_bytes([ - 255, 224, 245, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, + 255, 224, 245, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, ]), i256::from_le_bytes([ - 255, 201, 154, 59, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, + 255, 201, 154, 59, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, ]), i256::from_le_bytes([ - 255, 227, 11, 84, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, + 255, 227, 11, 84, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, ]), i256::from_le_bytes([ - 255, 231, 118, 72, 23, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, + 255, 231, 118, 72, 23, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, ]), i256::from_le_bytes([ - 255, 15, 165, 212, 232, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, + 255, 15, 165, 212, 232, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, ]), i256::from_le_bytes([ - 255, 159, 114, 78, 24, 9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, + 255, 159, 114, 78, 24, 9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, ]), i256::from_le_bytes([ - 255, 63, 122, 16, 243, 90, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, + 255, 63, 122, 16, 243, 90, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, ]), i256::from_le_bytes([ - 255, 127, 198, 164, 126, 141, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, + 255, 127, 198, 164, 126, 141, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, ]), i256::from_le_bytes([ - 255, 255, 192, 111, 242, 134, 35, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, + 255, 255, 192, 111, 242, 134, 35, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, ]), i256::from_le_bytes([ - 255, 255, 137, 93, 120, 69, 99, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 
0, 0, 0, 0, 0, 0, + 255, 255, 137, 93, 120, 69, 99, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, ]), i256::from_le_bytes([ - 255, 255, 99, 167, 179, 182, 224, 13, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 255, 255, 99, 167, 179, 182, 224, 13, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, ]), i256::from_le_bytes([ - 255, 255, 231, 137, 4, 35, 199, 138, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, + 255, 255, 231, 137, 4, 35, 199, 138, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, ]), i256::from_le_bytes([ - 255, 255, 15, 99, 45, 94, 199, 107, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, + 255, 255, 15, 99, 45, 94, 199, 107, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, ]), i256::from_le_bytes([ - 255, 255, 159, 222, 197, 173, 201, 53, 54, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 255, 255, 159, 222, 197, 173, 201, 53, 54, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, ]), i256::from_le_bytes([ - 255, 255, 63, 178, 186, 201, 224, 25, 30, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 255, 255, 63, 178, 186, 201, 224, 25, 30, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, ]), i256::from_le_bytes([ - 255, 255, 127, 246, 74, 225, 199, 2, 45, 21, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 255, 255, 127, 246, 74, 225, 199, 2, 45, 21, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, ]), i256::from_le_bytes([ - 255, 255, 255, 160, 237, 204, 206, 27, 194, 211, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 255, 255, 255, 160, 237, 204, 206, 27, 194, 211, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, ]), i256::from_le_bytes([ - 255, 255, 255, 73, 72, 1, 20, 22, 149, 69, 8, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, + 255, 255, 255, 73, 72, 1, 20, 22, 149, 69, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, ]), i256::from_le_bytes([ - 255, 255, 255, 227, 210, 12, 200, 220, 210, 183, 82, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 255, 255, 255, 227, 210, 12, 200, 220, 210, 183, 82, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, ]), i256::from_le_bytes([ - 255, 255, 255, 231, 60, 128, 208, 159, 60, 46, 59, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 255, 255, 255, 231, 60, 128, 208, 159, 60, 46, 59, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, ]), i256::from_le_bytes([ - 255, 255, 255, 15, 97, 2, 37, 62, 94, 206, 79, 32, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 255, 255, 255, 15, 97, 2, 37, 62, 94, 206, 79, 32, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, ]), i256::from_le_bytes([ - 255, 255, 255, 159, 202, 23, 114, 109, 174, 15, 30, 67, 1, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 255, 255, 255, 159, 202, 23, 114, 109, 174, 15, 30, 67, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, ]), i256::from_le_bytes([ - 255, 255, 255, 63, 234, 237, 116, 70, 208, 156, 44, 159, 12, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 255, 255, 255, 63, 234, 237, 116, 70, 208, 156, 44, 159, 12, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, ]), i256::from_le_bytes([ - 255, 255, 255, 127, 38, 75, 145, 192, 34, 32, 190, 55, 126, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 255, 255, 255, 127, 38, 75, 145, 192, 34, 32, 190, 55, 126, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, ]), i256::from_le_bytes([ - 255, 255, 255, 255, 128, 239, 172, 133, 91, 65, 109, 45, 238, 4, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 255, 255, 255, 255, 128, 239, 172, 133, 91, 65, 109, 45, 238, 4, 0, 0, 0, 0, 0, 0, 0, 0, 
0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, ]), i256::from_le_bytes([ - 255, 255, 255, 255, 9, 91, 193, 56, 147, 141, 68, 198, 77, 49, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 255, 255, 255, 255, 9, 91, 193, 56, 147, 141, 68, 198, 77, 49, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, ]), i256::from_le_bytes([ - 255, 255, 255, 255, 99, 142, 141, 55, 192, 135, 173, 190, 9, 237, 1, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 255, 255, 255, 255, 99, 142, 141, 55, 192, 135, 173, 190, 9, 237, 1, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ]), i256::from_le_bytes([ - 255, 255, 255, 255, 231, 143, 135, 43, 130, 77, 199, 114, 97, 66, 19, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 255, 255, 255, 255, 231, 143, 135, 43, 130, 77, 199, 114, 97, 66, 19, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ]), i256::from_le_bytes([ - 255, 255, 255, 255, 15, 159, 75, 179, 21, 7, 201, 123, 206, 151, 192, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 255, 255, 255, 255, 15, 159, 75, 179, 21, 7, 201, 123, 206, 151, 192, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ]), i256::from_le_bytes([ - 255, 255, 255, 255, 159, 54, 244, 0, 217, 70, 218, 213, 16, 238, 133, 7, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 255, 255, 255, 255, 159, 54, 244, 0, 217, 70, 218, 213, 16, 238, 133, 7, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ]), i256::from_le_bytes([ - 255, 255, 255, 255, 63, 34, 138, 9, 122, 196, 134, 90, 168, 76, 59, 75, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 255, 255, 255, 255, 63, 34, 138, 9, 122, 196, 134, 90, 168, 76, 59, 75, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ]), i256::from_le_bytes([ - 255, 255, 255, 255, 127, 86, 101, 95, 196, 172, 67, 137, 147, 254, 80, 240, 2, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 255, 255, 255, 255, 127, 86, 101, 95, 196, 172, 67, 137, 147, 254, 80, 240, 2, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ]), i256::from_le_bytes([ - 255, 255, 
255, 255, 255, 96, 245, 185, 171, 191, 164, 92, 195, 241, 41, 99, 29, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 255, 255, 255, 255, 255, 96, 245, 185, 171, 191, 164, 92, 195, 241, 41, 99, 29, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ]), i256::from_le_bytes([ - 255, 255, 255, 255, 255, 201, 149, 67, 181, 124, 111, 158, 161, 113, 163, 223, - 37, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 255, 255, 255, 255, 255, 201, 149, 67, 181, 124, 111, 158, 161, 113, 163, 223, 37, 1, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ]), i256::from_le_bytes([ - 255, 255, 255, 255, 255, 227, 217, 163, 20, 223, 90, 48, 80, 112, 98, 188, 122, - 11, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 255, 255, 255, 255, 255, 227, 217, 163, 20, 223, 90, 48, 80, 112, 98, 188, 122, 11, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ]), i256::from_le_bytes([ - 255, 255, 255, 255, 255, 231, 130, 102, 206, 182, 140, 227, 33, 99, 216, 91, 203, - 114, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 255, 255, 255, 255, 255, 231, 130, 102, 206, 182, 140, 227, 33, 99, 216, 91, 203, 114, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ]), i256::from_le_bytes([ - 255, 255, 255, 255, 255, 15, 29, 1, 16, 36, 127, 227, 82, 223, 115, 150, 241, - 123, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 255, 255, 255, 255, 255, 15, 29, 1, 16, 36, 127, 227, 82, 223, 115, 150, 241, 123, 4, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ]), i256::from_le_bytes([ - 255, 255, 255, 255, 255, 159, 34, 11, 160, 104, 247, 226, 60, 185, 134, 224, 111, - 215, 44, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 255, 255, 255, 255, 255, 159, 34, 11, 160, 104, 247, 226, 60, 185, 134, 224, 111, 215, 44, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ]), i256::from_le_bytes([ - 255, 255, 255, 255, 255, 63, 90, 111, 64, 22, 170, 221, 96, 60, 67, 197, 94, 106, - 192, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 255, 255, 255, 255, 255, 63, 90, 111, 64, 22, 170, 221, 96, 60, 67, 197, 94, 106, 192, 1, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ]), 
i256::from_le_bytes([ - 255, 255, 255, 255, 255, 127, 134, 89, 132, 222, 164, 168, 200, 91, 160, 180, - 179, 39, 132, 17, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 255, 255, 255, 255, 255, 127, 134, 89, 132, 222, 164, 168, 200, 91, 160, 180, 179, 39, 132, + 17, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ]), i256::from_le_bytes([ - 255, 255, 255, 255, 255, 255, 64, 127, 43, 177, 112, 150, 214, 149, 67, 14, 5, - 141, 41, 175, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 255, 255, 255, 255, 255, 255, 64, 127, 43, 177, 112, 150, 214, 149, 67, 14, 5, 141, 41, + 175, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ]), i256::from_le_bytes([ - 255, 255, 255, 255, 255, 255, 137, 248, 178, 235, 102, 224, 97, 218, 163, 142, - 50, 130, 159, 215, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 255, 255, 255, 255, 255, 255, 137, 248, 178, 235, 102, 224, 97, 218, 163, 142, 50, 130, + 159, 215, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ]), i256::from_le_bytes([ - 255, 255, 255, 255, 255, 255, 99, 181, 253, 52, 5, 196, 210, 135, 102, 146, 249, - 21, 59, 108, 68, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 255, 255, 255, 255, 255, 255, 99, 181, 253, 52, 5, 196, 210, 135, 102, 146, 249, 21, 59, + 108, 68, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ]), i256::from_le_bytes([ - 255, 255, 255, 255, 255, 255, 231, 21, 233, 17, 52, 168, 59, 78, 1, 184, 191, - 219, 78, 58, 172, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 255, 255, 255, 255, 255, 255, 231, 21, 233, 17, 52, 168, 59, 78, 1, 184, 191, 219, 78, 58, + 172, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ]), i256::from_le_bytes([ - 255, 255, 255, 255, 255, 255, 15, 219, 26, 179, 8, 146, 84, 14, 13, 48, 125, 149, - 20, 71, 186, 26, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 255, 255, 255, 255, 255, 255, 15, 219, 26, 179, 8, 146, 84, 14, 13, 48, 125, 149, 20, 71, + 186, 26, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ]), i256::from_le_bytes([ - 255, 255, 255, 255, 255, 255, 159, 142, 12, 255, 86, 180, 77, 143, 130, 224, 227, - 214, 205, 198, 70, 11, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 255, 255, 255, 255, 255, 255, 159, 142, 12, 255, 86, 180, 77, 
143, 130, 224, 227, 214, 205, + 198, 70, 11, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, ]), i256::from_le_bytes([ - 255, 255, 255, 255, 255, 255, 63, 146, 125, 246, 101, 11, 9, 153, 25, 197, 230, - 100, 10, 196, 195, 112, 10, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 255, 255, 255, 255, 255, 255, 63, 146, 125, 246, 101, 11, 9, 153, 25, 197, 230, 100, 10, + 196, 195, 112, 10, 0, 0, 0, 0, 0, 0, 0, 0, 0, ]), i256::from_le_bytes([ - 255, 255, 255, 255, 255, 255, 127, 182, 231, 160, 251, 113, 90, 250, 255, 178, 3, - 241, 103, 168, 165, 103, 104, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 255, 255, 255, 255, 255, 255, 127, 182, 231, 160, 251, 113, 90, 250, 255, 178, 3, 241, 103, + 168, 165, 103, 104, 0, 0, 0, 0, 0, 0, 0, 0, 0, ]), i256::from_le_bytes([ - 255, 255, 255, 255, 255, 255, 255, 32, 13, 73, 212, 115, 136, 199, 255, 253, 36, - 106, 15, 148, 120, 12, 20, 4, 0, 0, 0, 0, 0, 0, 0, 0, + 255, 255, 255, 255, 255, 255, 255, 32, 13, 73, 212, 115, 136, 199, 255, 253, 36, 106, 15, + 148, 120, 12, 20, 4, 0, 0, 0, 0, 0, 0, 0, 0, ]), i256::from_le_bytes([ - 255, 255, 255, 255, 255, 255, 255, 73, 131, 218, 74, 134, 84, 203, 253, 235, 113, - 37, 154, 200, 181, 124, 200, 40, 0, 0, 0, 0, 0, 0, 0, 0, + 255, 255, 255, 255, 255, 255, 255, 73, 131, 218, 74, 134, 84, 203, 253, 235, 113, 37, 154, + 200, 181, 124, 200, 40, 0, 0, 0, 0, 0, 0, 0, 0, ]), i256::from_le_bytes([ - 255, 255, 255, 255, 255, 255, 255, 227, 32, 137, 236, 62, 77, 241, 233, 55, 115, - 118, 5, 214, 25, 223, 212, 151, 1, 0, 0, 0, 0, 0, 0, 0, + 255, 255, 255, 255, 255, 255, 255, 227, 32, 137, 236, 62, 77, 241, 233, 55, 115, 118, 5, + 214, 25, 223, 212, 151, 1, 0, 0, 0, 0, 0, 0, 0, ]), i256::from_le_bytes([ - 255, 255, 255, 255, 255, 255, 255, 231, 72, 91, 61, 117, 4, 109, 35, 47, 128, - 160, 54, 92, 2, 183, 80, 238, 15, 0, 0, 0, 0, 0, 0, 0, + 255, 255, 255, 255, 255, 255, 255, 231, 72, 91, 61, 117, 4, 109, 35, 47, 128, 160, 54, 92, + 2, 183, 80, 238, 15, 0, 0, 0, 0, 0, 0, 0, ]), i256::from_le_bytes([ - 255, 255, 255, 255, 255, 255, 255, 15, 217, 144, 101, 
148, 44, 66, 98, 215, 1, - 69, 34, 154, 23, 38, 39, 79, 159, 0, 0, 0, 0, 0, 0, 0, + 255, 255, 255, 255, 255, 255, 255, 15, 217, 144, 101, 148, 44, 66, 98, 215, 1, 69, 34, 154, + 23, 38, 39, 79, 159, 0, 0, 0, 0, 0, 0, 0, ]), i256::from_le_bytes([ - 255, 255, 255, 255, 255, 255, 255, 159, 122, 168, 247, 203, 189, 149, 214, 105, - 18, 178, 86, 5, 236, 124, 135, 23, 57, 6, 0, 0, 0, 0, 0, 0, + 255, 255, 255, 255, 255, 255, 255, 159, 122, 168, 247, 203, 189, 149, 214, 105, 18, 178, + 86, 5, 236, 124, 135, 23, 57, 6, 0, 0, 0, 0, 0, 0, ]), i256::from_le_bytes([ - 255, 255, 255, 255, 255, 255, 255, 63, 202, 148, 172, 247, 105, 217, 97, 34, 184, - 244, 98, 53, 56, 225, 74, 235, 58, 62, 0, 0, 0, 0, 0, 0, + 255, 255, 255, 255, 255, 255, 255, 63, 202, 148, 172, 247, 105, 217, 97, 34, 184, 244, 98, + 53, 56, 225, 74, 235, 58, 62, 0, 0, 0, 0, 0, 0, ]), i256::from_le_bytes([ - 255, 255, 255, 255, 255, 255, 255, 127, 230, 207, 189, 172, 35, 126, 210, 87, 49, - 143, 221, 21, 50, 204, 236, 48, 77, 110, 2, 0, 0, 0, 0, 0, + 255, 255, 255, 255, 255, 255, 255, 127, 230, 207, 189, 172, 35, 126, 210, 87, 49, 143, 221, + 21, 50, 204, 236, 48, 77, 110, 2, 0, 0, 0, 0, 0, ]), i256::from_le_bytes([ - 255, 255, 255, 255, 255, 255, 255, 255, 0, 31, 106, 191, 100, 237, 56, 110, 237, - 151, 167, 218, 244, 249, 63, 233, 3, 79, 24, 0, 0, 0, 0, 0, + 255, 255, 255, 255, 255, 255, 255, 255, 0, 31, 106, 191, 100, 237, 56, 110, 237, 151, 167, + 218, 244, 249, 63, 233, 3, 79, 24, 0, 0, 0, 0, 0, ]), i256::from_le_bytes([ - 255, 255, 255, 255, 255, 255, 255, 255, 9, 54, 37, 122, 239, 69, 57, 78, 70, 239, - 139, 138, 144, 195, 127, 28, 39, 22, 243, 0, 0, 0, 0, 0, + 255, 255, 255, 255, 255, 255, 255, 255, 9, 54, 37, 122, 239, 69, 57, 78, 70, 239, 139, 138, + 144, 195, 127, 28, 39, 22, 243, 0, 0, 0, 0, 0, ]), i256::from_le_bytes([ - 255, 255, 255, 255, 255, 255, 255, 255, 99, 28, 116, 197, 90, 187, 60, 14, 191, - 88, 119, 105, 165, 163, 253, 28, 135, 221, 126, 9, 0, 0, 0, 0, + 255, 255, 255, 255, 255, 255, 255, 
255, 99, 28, 116, 197, 90, 187, 60, 14, 191, 88, 119, + 105, 165, 163, 253, 28, 135, 221, 126, 9, 0, 0, 0, 0, ]), i256::from_le_bytes([ - 255, 255, 255, 255, 255, 255, 255, 255, 231, 27, 137, 182, 139, 81, 95, 142, 118, - 119, 169, 30, 118, 100, 232, 33, 71, 167, 244, 94, 0, 0, 0, 0, + 255, 255, 255, 255, 255, 255, 255, 255, 231, 27, 137, 182, 139, 81, 95, 142, 118, 119, 169, + 30, 118, 100, 232, 33, 71, 167, 244, 94, 0, 0, 0, 0, ]), i256::from_le_bytes([ - 255, 255, 255, 255, 255, 255, 255, 255, 15, 23, 91, 33, 117, 47, 185, 143, 161, - 170, 158, 50, 157, 236, 19, 83, 199, 136, 142, 181, 3, 0, 0, 0, + 255, 255, 255, 255, 255, 255, 255, 255, 15, 23, 91, 33, 117, 47, 185, 143, 161, 170, 158, + 50, 157, 236, 19, 83, 199, 136, 142, 181, 3, 0, 0, 0, ]), i256::from_le_bytes([ - 255, 255, 255, 255, 255, 255, 255, 255, 159, 230, 142, 77, 147, 218, 59, 157, 79, - 170, 50, 250, 35, 62, 199, 62, 201, 87, 145, 23, 37, 0, 0, 0, + 255, 255, 255, 255, 255, 255, 255, 255, 159, 230, 142, 77, 147, 218, 59, 157, 79, 170, 50, + 250, 35, 62, 199, 62, 201, 87, 145, 23, 37, 0, 0, 0, ]), i256::from_le_bytes([ - 255, 255, 255, 255, 255, 255, 255, 255, 63, 2, 149, 7, 193, 137, 86, 36, 28, 167, - 250, 197, 103, 109, 200, 115, 220, 109, 173, 235, 114, 1, 0, 0, + 255, 255, 255, 255, 255, 255, 255, 255, 63, 2, 149, 7, 193, 137, 86, 36, 28, 167, 250, 197, + 103, 109, 200, 115, 220, 109, 173, 235, 114, 1, 0, 0, ]), i256::from_le_bytes([ - 255, 255, 255, 255, 255, 255, 255, 255, 127, 22, 210, 75, 138, 97, 97, 107, 25, - 135, 202, 187, 13, 70, 212, 133, 156, 74, 198, 52, 125, 14, 0, 0, + 255, 255, 255, 255, 255, 255, 255, 255, 127, 22, 210, 75, 138, 97, 97, 107, 25, 135, 202, + 187, 13, 70, 212, 133, 156, 74, 198, 52, 125, 14, 0, 0, ]), i256::from_le_bytes([ - 255, 255, 255, 255, 255, 255, 255, 255, 255, 224, 52, 246, 102, 207, 205, 49, - 254, 70, 233, 85, 137, 188, 74, 58, 29, 234, 190, 15, 228, 144, 0, 0, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 224, 52, 246, 102, 207, 205, 49, 254, 70, 
233, + 85, 137, 188, 74, 58, 29, 234, 190, 15, 228, 144, 0, 0, ]), i256::from_le_bytes([ - 255, 255, 255, 255, 255, 255, 255, 255, 255, 201, 16, 158, 5, 26, 10, 242, 237, - 197, 28, 91, 93, 93, 235, 70, 36, 37, 117, 157, 232, 168, 5, 0, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 201, 16, 158, 5, 26, 10, 242, 237, 197, 28, + 91, 93, 93, 235, 70, 36, 37, 117, 157, 232, 168, 5, 0, ]), i256::from_le_bytes([ - 255, 255, 255, 255, 255, 255, 255, 255, 255, 227, 167, 44, 56, 4, 101, 116, 75, - 187, 31, 143, 165, 165, 49, 197, 106, 115, 147, 38, 22, 153, 56, 0, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 227, 167, 44, 56, 4, 101, 116, 75, 187, 31, + 143, 165, 165, 49, 197, 106, 115, 147, 38, 22, 153, 56, 0, ]), i256::from_le_bytes([ - 255, 255, 255, 255, 255, 255, 255, 255, 255, 231, 142, 190, 49, 42, 242, 139, - 242, 80, 61, 151, 119, 120, 240, 179, 43, 130, 194, 129, 221, 250, 53, 2, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 231, 142, 190, 49, 42, 242, 139, 242, 80, 61, + 151, 119, 120, 240, 179, 43, 130, 194, 129, 221, 250, 53, 2, ]), i256::from_le_bytes([ - 255, 255, 255, 255, 255, 255, 255, 255, 255, 15, 149, 113, 241, 165, 117, 119, - 121, 41, 101, 232, 171, 180, 100, 7, 181, 21, 153, 17, 167, 204, 27, 22, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 15, 149, 113, 241, 165, 117, 119, 121, 41, + 101, 232, 171, 180, 100, 7, 181, 21, 153, 17, 167, 204, 27, 22, ]), ]; @@ -338,308 +338,308 @@ pub(crate) const MAX_DECIMAL_BYTES_FOR_LARGER_EACH_PRECISION: [i256; 76] = [ // is encoded to the 76-byte width format of little-endian. 
pub(crate) const MIN_DECIMAL_BYTES_FOR_LARGER_EACH_PRECISION: [i256; 76] = [ i256::from_le_bytes([ - 247, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, - 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 247, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, ]), i256::from_le_bytes([ - 157, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, - 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 157, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, ]), i256::from_le_bytes([ - 25, 252, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, - 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 25, 252, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, ]), i256::from_le_bytes([ - 241, 216, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, - 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 241, 216, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, ]), i256::from_le_bytes([ - 97, 121, 254, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, - 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 97, 121, 254, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, ]), i256::from_le_bytes([ - 193, 189, 240, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, - 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 
255, 255, 255, 255, 255, 255, + 193, 189, 240, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, ]), i256::from_le_bytes([ - 129, 105, 103, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, - 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 129, 105, 103, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, ]), i256::from_le_bytes([ - 1, 31, 10, 250, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, - 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 1, 31, 10, 250, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, ]), i256::from_le_bytes([ - 1, 54, 101, 196, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, - 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 1, 54, 101, 196, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, ]), i256::from_le_bytes([ - 1, 28, 244, 171, 253, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, - 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 1, 28, 244, 171, 253, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, ]), i256::from_le_bytes([ - 1, 24, 137, 183, 232, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, - 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 1, 24, 137, 183, 232, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, ]), i256::from_le_bytes([ - 1, 240, 90, 43, 23, 255, 255, 
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, - 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 1, 240, 90, 43, 23, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, ]), i256::from_le_bytes([ - 1, 96, 141, 177, 231, 246, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, - 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 1, 96, 141, 177, 231, 246, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, ]), i256::from_le_bytes([ - 1, 192, 133, 239, 12, 165, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, - 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 1, 192, 133, 239, 12, 165, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, ]), i256::from_le_bytes([ - 1, 128, 57, 91, 129, 114, 252, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, - 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 1, 128, 57, 91, 129, 114, 252, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, ]), i256::from_le_bytes([ - 1, 0, 63, 144, 13, 121, 220, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, - 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 1, 0, 63, 144, 13, 121, 220, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, ]), i256::from_le_bytes([ - 1, 0, 118, 162, 135, 186, 156, 254, 255, 255, 255, 255, 255, 255, 255, 255, 255, - 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 1, 0, 118, 162, 135, 186, 156, 254, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 
255, 255, 255, 255, ]), i256::from_le_bytes([ - 1, 0, 156, 88, 76, 73, 31, 242, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, - 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 1, 0, 156, 88, 76, 73, 31, 242, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, ]), i256::from_le_bytes([ - 1, 0, 24, 118, 251, 220, 56, 117, 255, 255, 255, 255, 255, 255, 255, 255, 255, - 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 1, 0, 24, 118, 251, 220, 56, 117, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, ]), i256::from_le_bytes([ - 1, 0, 240, 156, 210, 161, 56, 148, 250, 255, 255, 255, 255, 255, 255, 255, 255, - 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 1, 0, 240, 156, 210, 161, 56, 148, 250, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, ]), i256::from_le_bytes([ - 1, 0, 96, 33, 58, 82, 54, 202, 201, 255, 255, 255, 255, 255, 255, 255, 255, 255, - 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 1, 0, 96, 33, 58, 82, 54, 202, 201, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, ]), i256::from_le_bytes([ - 1, 0, 192, 77, 69, 54, 31, 230, 225, 253, 255, 255, 255, 255, 255, 255, 255, 255, - 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 1, 0, 192, 77, 69, 54, 31, 230, 225, 253, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, ]), i256::from_le_bytes([ - 1, 0, 128, 9, 181, 30, 56, 253, 210, 234, 255, 255, 255, 255, 255, 255, 255, 255, - 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 1, 0, 128, 9, 181, 30, 56, 253, 210, 234, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, ]), i256::from_le_bytes([ - 1, 0, 0, 95, 18, 51, 49, 228, 61, 44, 255, 255, 255, 255, 255, 255, 255, 255, - 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 1, 0, 0, 95, 18, 51, 49, 228, 61, 44, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, ]), i256::from_le_bytes([ - 1, 0, 0, 182, 183, 254, 235, 233, 106, 186, 247, 255, 255, 255, 255, 255, 255, - 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 1, 0, 0, 182, 183, 254, 235, 233, 106, 186, 247, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, ]), i256::from_le_bytes([ - 1, 0, 0, 28, 45, 243, 55, 35, 45, 72, 173, 255, 255, 255, 255, 255, 255, 255, - 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 1, 0, 0, 28, 45, 243, 55, 35, 45, 72, 173, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, ]), i256::from_le_bytes([ - 1, 0, 0, 24, 195, 127, 47, 96, 195, 209, 196, 252, 255, 255, 255, 255, 255, 255, - 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 1, 0, 0, 24, 195, 127, 47, 96, 195, 209, 196, 252, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, ]), i256::from_le_bytes([ - 1, 0, 0, 240, 158, 253, 218, 193, 161, 49, 176, 223, 255, 255, 255, 255, 255, - 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 1, 0, 0, 240, 158, 253, 218, 193, 161, 49, 176, 223, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, ]), i256::from_le_bytes([ - 1, 0, 0, 96, 53, 232, 141, 146, 81, 240, 225, 188, 254, 255, 255, 255, 255, 255, - 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 1, 0, 0, 96, 53, 232, 141, 146, 81, 240, 225, 188, 254, 255, 255, 255, 255, 
255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, ]), i256::from_le_bytes([ - 1, 0, 0, 192, 21, 18, 139, 185, 47, 99, 211, 96, 243, 255, 255, 255, 255, 255, - 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 1, 0, 0, 192, 21, 18, 139, 185, 47, 99, 211, 96, 243, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, ]), i256::from_le_bytes([ - 1, 0, 0, 128, 217, 180, 110, 63, 221, 223, 65, 200, 129, 255, 255, 255, 255, 255, - 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 1, 0, 0, 128, 217, 180, 110, 63, 221, 223, 65, 200, 129, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, ]), i256::from_le_bytes([ - 1, 0, 0, 0, 127, 16, 83, 122, 164, 190, 146, 210, 17, 251, 255, 255, 255, 255, - 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 1, 0, 0, 0, 127, 16, 83, 122, 164, 190, 146, 210, 17, 251, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, ]), i256::from_le_bytes([ - 1, 0, 0, 0, 246, 164, 62, 199, 108, 114, 187, 57, 178, 206, 255, 255, 255, 255, - 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 1, 0, 0, 0, 246, 164, 62, 199, 108, 114, 187, 57, 178, 206, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, ]), i256::from_le_bytes([ - 1, 0, 0, 0, 156, 113, 114, 200, 63, 120, 82, 65, 246, 18, 254, 255, 255, 255, - 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 1, 0, 0, 0, 156, 113, 114, 200, 63, 120, 82, 65, 246, 18, 254, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, ]), i256::from_le_bytes([ - 1, 0, 0, 0, 24, 112, 120, 212, 125, 178, 56, 141, 158, 189, 236, 255, 255, 255, - 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 1, 0, 0, 0, 24, 112, 120, 212, 125, 178, 56, 141, 158, 189, 236, 255, 255, 
255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, ]), i256::from_le_bytes([ - 1, 0, 0, 0, 240, 96, 180, 76, 234, 248, 54, 132, 49, 104, 63, 255, 255, 255, 255, - 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 1, 0, 0, 0, 240, 96, 180, 76, 234, 248, 54, 132, 49, 104, 63, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, ]), i256::from_le_bytes([ - 1, 0, 0, 0, 96, 201, 11, 255, 38, 185, 37, 42, 239, 17, 122, 248, 255, 255, 255, - 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 1, 0, 0, 0, 96, 201, 11, 255, 38, 185, 37, 42, 239, 17, 122, 248, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, ]), i256::from_le_bytes([ - 1, 0, 0, 0, 192, 221, 117, 246, 133, 59, 121, 165, 87, 179, 196, 180, 255, 255, - 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 1, 0, 0, 0, 192, 221, 117, 246, 133, 59, 121, 165, 87, 179, 196, 180, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, ]), i256::from_le_bytes([ - 1, 0, 0, 0, 128, 169, 154, 160, 59, 83, 188, 118, 108, 1, 175, 15, 253, 255, 255, - 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 1, 0, 0, 0, 128, 169, 154, 160, 59, 83, 188, 118, 108, 1, 175, 15, 253, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, ]), i256::from_le_bytes([ - 1, 0, 0, 0, 0, 159, 10, 70, 84, 64, 91, 163, 60, 14, 214, 156, 226, 255, 255, - 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 1, 0, 0, 0, 0, 159, 10, 70, 84, 64, 91, 163, 60, 14, 214, 156, 226, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, ]), i256::from_le_bytes([ - 1, 0, 0, 0, 0, 54, 106, 188, 74, 131, 144, 97, 94, 142, 92, 32, 218, 254, 255, - 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 1, 0, 0, 0, 0, 54, 106, 188, 74, 131, 144, 97, 94, 142, 92, 32, 218, 254, 255, 255, 255, + 255, 255, 255, 255, 
255, 255, 255, 255, 255, 255, 255, ]), i256::from_le_bytes([ - 1, 0, 0, 0, 0, 28, 38, 92, 235, 32, 165, 207, 175, 143, 157, 67, 133, 244, 255, - 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 1, 0, 0, 0, 0, 28, 38, 92, 235, 32, 165, 207, 175, 143, 157, 67, 133, 244, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, ]), i256::from_le_bytes([ - 1, 0, 0, 0, 0, 24, 125, 153, 49, 73, 115, 28, 222, 156, 39, 164, 52, 141, 255, - 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 1, 0, 0, 0, 0, 24, 125, 153, 49, 73, 115, 28, 222, 156, 39, 164, 52, 141, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, ]), i256::from_le_bytes([ - 1, 0, 0, 0, 0, 240, 226, 254, 239, 219, 128, 28, 173, 32, 140, 105, 14, 132, 251, - 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 1, 0, 0, 0, 0, 240, 226, 254, 239, 219, 128, 28, 173, 32, 140, 105, 14, 132, 251, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, ]), i256::from_le_bytes([ - 1, 0, 0, 0, 0, 96, 221, 244, 95, 151, 8, 29, 195, 70, 121, 31, 144, 40, 211, 255, - 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 1, 0, 0, 0, 0, 96, 221, 244, 95, 151, 8, 29, 195, 70, 121, 31, 144, 40, 211, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, ]), i256::from_le_bytes([ - 1, 0, 0, 0, 0, 192, 165, 144, 191, 233, 85, 34, 159, 195, 188, 58, 161, 149, 63, - 254, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 1, 0, 0, 0, 0, 192, 165, 144, 191, 233, 85, 34, 159, 195, 188, 58, 161, 149, 63, 254, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, ]), i256::from_le_bytes([ - 1, 0, 0, 0, 0, 128, 121, 166, 123, 33, 91, 87, 55, 164, 95, 75, 76, 216, 123, - 238, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 1, 0, 0, 0, 0, 128, 121, 166, 123, 33, 91, 87, 55, 164, 95, 75, 76, 216, 123, 238, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, ]), i256::from_le_bytes([ 
- 1, 0, 0, 0, 0, 0, 191, 128, 212, 78, 143, 105, 41, 106, 188, 241, 250, 114, 214, - 80, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 1, 0, 0, 0, 0, 0, 191, 128, 212, 78, 143, 105, 41, 106, 188, 241, 250, 114, 214, 80, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, ]), i256::from_le_bytes([ - 1, 0, 0, 0, 0, 0, 118, 7, 77, 20, 153, 31, 158, 37, 92, 113, 205, 125, 96, 40, - 249, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 1, 0, 0, 0, 0, 0, 118, 7, 77, 20, 153, 31, 158, 37, 92, 113, 205, 125, 96, 40, 249, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, ]), i256::from_le_bytes([ - 1, 0, 0, 0, 0, 0, 156, 74, 2, 203, 250, 59, 45, 120, 153, 109, 6, 234, 196, 147, - 187, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 1, 0, 0, 0, 0, 0, 156, 74, 2, 203, 250, 59, 45, 120, 153, 109, 6, 234, 196, 147, 187, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, ]), i256::from_le_bytes([ - 1, 0, 0, 0, 0, 0, 24, 234, 22, 238, 203, 87, 196, 177, 254, 71, 64, 36, 177, 197, - 83, 253, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 1, 0, 0, 0, 0, 0, 24, 234, 22, 238, 203, 87, 196, 177, 254, 71, 64, 36, 177, 197, 83, 253, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, ]), i256::from_le_bytes([ - 1, 0, 0, 0, 0, 0, 240, 36, 229, 76, 247, 109, 171, 241, 242, 207, 130, 106, 235, - 184, 69, 229, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 1, 0, 0, 0, 0, 0, 240, 36, 229, 76, 247, 109, 171, 241, 242, 207, 130, 106, 235, 184, 69, + 229, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, ]), i256::from_le_bytes([ - 1, 0, 0, 0, 0, 0, 96, 113, 243, 0, 169, 75, 178, 112, 125, 31, 28, 41, 50, 57, - 185, 244, 254, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 1, 0, 0, 0, 0, 0, 96, 113, 243, 0, 169, 75, 178, 112, 125, 31, 28, 41, 50, 57, 185, 244, + 254, 255, 255, 255, 255, 255, 255, 255, 255, 255, ]), i256::from_le_bytes([ - 1, 0, 0, 0, 0, 0, 192, 109, 130, 9, 154, 244, 246, 102, 230, 58, 25, 155, 245, - 59, 60, 143, 
245, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 1, 0, 0, 0, 0, 0, 192, 109, 130, 9, 154, 244, 246, 102, 230, 58, 25, 155, 245, 59, 60, 143, + 245, 255, 255, 255, 255, 255, 255, 255, 255, 255, ]), i256::from_le_bytes([ - 1, 0, 0, 0, 0, 0, 128, 73, 24, 95, 4, 142, 165, 5, 0, 77, 252, 14, 152, 87, 90, - 152, 151, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 1, 0, 0, 0, 0, 0, 128, 73, 24, 95, 4, 142, 165, 5, 0, 77, 252, 14, 152, 87, 90, 152, 151, + 255, 255, 255, 255, 255, 255, 255, 255, 255, ]), i256::from_le_bytes([ - 1, 0, 0, 0, 0, 0, 0, 223, 242, 182, 43, 140, 119, 56, 0, 2, 219, 149, 240, 107, - 135, 243, 235, 251, 255, 255, 255, 255, 255, 255, 255, 255, + 1, 0, 0, 0, 0, 0, 0, 223, 242, 182, 43, 140, 119, 56, 0, 2, 219, 149, 240, 107, 135, 243, + 235, 251, 255, 255, 255, 255, 255, 255, 255, 255, ]), i256::from_le_bytes([ - 1, 0, 0, 0, 0, 0, 0, 182, 124, 37, 181, 121, 171, 52, 2, 20, 142, 218, 101, 55, - 74, 131, 55, 215, 255, 255, 255, 255, 255, 255, 255, 255, + 1, 0, 0, 0, 0, 0, 0, 182, 124, 37, 181, 121, 171, 52, 2, 20, 142, 218, 101, 55, 74, 131, + 55, 215, 255, 255, 255, 255, 255, 255, 255, 255, ]), i256::from_le_bytes([ - 1, 0, 0, 0, 0, 0, 0, 28, 223, 118, 19, 193, 178, 14, 22, 200, 140, 137, 250, 41, - 230, 32, 43, 104, 254, 255, 255, 255, 255, 255, 255, 255, + 1, 0, 0, 0, 0, 0, 0, 28, 223, 118, 19, 193, 178, 14, 22, 200, 140, 137, 250, 41, 230, 32, + 43, 104, 254, 255, 255, 255, 255, 255, 255, 255, ]), i256::from_le_bytes([ - 1, 0, 0, 0, 0, 0, 0, 24, 183, 164, 194, 138, 251, 146, 220, 208, 127, 95, 201, - 163, 253, 72, 175, 17, 240, 255, 255, 255, 255, 255, 255, 255, + 1, 0, 0, 0, 0, 0, 0, 24, 183, 164, 194, 138, 251, 146, 220, 208, 127, 95, 201, 163, 253, + 72, 175, 17, 240, 255, 255, 255, 255, 255, 255, 255, ]), i256::from_le_bytes([ - 1, 0, 0, 0, 0, 0, 0, 240, 38, 111, 154, 107, 211, 189, 157, 40, 254, 186, 221, - 101, 232, 217, 216, 176, 96, 255, 255, 255, 255, 255, 255, 255, + 1, 0, 0, 0, 0, 0, 0, 240, 38, 111, 154, 107, 211, 189, 157, 40, 254, 
186, 221, 101, 232, + 217, 216, 176, 96, 255, 255, 255, 255, 255, 255, 255, ]), i256::from_le_bytes([ - 1, 0, 0, 0, 0, 0, 0, 96, 133, 87, 8, 52, 66, 106, 41, 150, 237, 77, 169, 250, 19, - 131, 120, 232, 198, 249, 255, 255, 255, 255, 255, 255, + 1, 0, 0, 0, 0, 0, 0, 96, 133, 87, 8, 52, 66, 106, 41, 150, 237, 77, 169, 250, 19, 131, 120, + 232, 198, 249, 255, 255, 255, 255, 255, 255, ]), i256::from_le_bytes([ - 1, 0, 0, 0, 0, 0, 0, 192, 53, 107, 83, 8, 150, 38, 158, 221, 71, 11, 157, 202, - 199, 30, 181, 20, 197, 193, 255, 255, 255, 255, 255, 255, + 1, 0, 0, 0, 0, 0, 0, 192, 53, 107, 83, 8, 150, 38, 158, 221, 71, 11, 157, 202, 199, 30, + 181, 20, 197, 193, 255, 255, 255, 255, 255, 255, ]), i256::from_le_bytes([ - 1, 0, 0, 0, 0, 0, 0, 128, 25, 48, 66, 83, 220, 129, 45, 168, 206, 112, 34, 234, - 205, 51, 19, 207, 178, 145, 253, 255, 255, 255, 255, 255, + 1, 0, 0, 0, 0, 0, 0, 128, 25, 48, 66, 83, 220, 129, 45, 168, 206, 112, 34, 234, 205, 51, + 19, 207, 178, 145, 253, 255, 255, 255, 255, 255, ]), i256::from_le_bytes([ - 1, 0, 0, 0, 0, 0, 0, 0, 255, 224, 149, 64, 155, 18, 199, 145, 18, 104, 88, 37, - 11, 6, 192, 22, 252, 176, 231, 255, 255, 255, 255, 255, + 1, 0, 0, 0, 0, 0, 0, 0, 255, 224, 149, 64, 155, 18, 199, 145, 18, 104, 88, 37, 11, 6, 192, + 22, 252, 176, 231, 255, 255, 255, 255, 255, ]), i256::from_le_bytes([ - 1, 0, 0, 0, 0, 0, 0, 0, 246, 201, 218, 133, 16, 186, 198, 177, 185, 16, 116, 117, - 111, 60, 128, 227, 216, 233, 12, 255, 255, 255, 255, 255, + 1, 0, 0, 0, 0, 0, 0, 0, 246, 201, 218, 133, 16, 186, 198, 177, 185, 16, 116, 117, 111, 60, + 128, 227, 216, 233, 12, 255, 255, 255, 255, 255, ]), i256::from_le_bytes([ - 1, 0, 0, 0, 0, 0, 0, 0, 156, 227, 139, 58, 165, 68, 195, 241, 64, 167, 136, 150, - 90, 92, 2, 227, 120, 34, 129, 246, 255, 255, 255, 255, + 1, 0, 0, 0, 0, 0, 0, 0, 156, 227, 139, 58, 165, 68, 195, 241, 64, 167, 136, 150, 90, 92, 2, + 227, 120, 34, 129, 246, 255, 255, 255, 255, ]), i256::from_le_bytes([ - 1, 0, 0, 0, 0, 0, 0, 0, 24, 228, 118, 73, 
116, 174, 160, 113, 137, 136, 86, 225, - 137, 155, 23, 222, 184, 88, 11, 161, 255, 255, 255, 255, + 1, 0, 0, 0, 0, 0, 0, 0, 24, 228, 118, 73, 116, 174, 160, 113, 137, 136, 86, 225, 137, 155, + 23, 222, 184, 88, 11, 161, 255, 255, 255, 255, ]), i256::from_le_bytes([ - 1, 0, 0, 0, 0, 0, 0, 0, 240, 232, 164, 222, 138, 208, 70, 112, 94, 85, 97, 205, - 98, 19, 236, 172, 56, 119, 113, 74, 252, 255, 255, 255, + 1, 0, 0, 0, 0, 0, 0, 0, 240, 232, 164, 222, 138, 208, 70, 112, 94, 85, 97, 205, 98, 19, + 236, 172, 56, 119, 113, 74, 252, 255, 255, 255, ]), i256::from_le_bytes([ - 1, 0, 0, 0, 0, 0, 0, 0, 96, 25, 113, 178, 108, 37, 196, 98, 176, 85, 205, 5, 220, - 193, 56, 193, 54, 168, 110, 232, 218, 255, 255, 255, + 1, 0, 0, 0, 0, 0, 0, 0, 96, 25, 113, 178, 108, 37, 196, 98, 176, 85, 205, 5, 220, 193, 56, + 193, 54, 168, 110, 232, 218, 255, 255, 255, ]), i256::from_le_bytes([ - 1, 0, 0, 0, 0, 0, 0, 0, 192, 253, 106, 248, 62, 118, 169, 219, 227, 88, 5, 58, - 152, 146, 55, 140, 35, 146, 82, 20, 141, 254, 255, 255, + 1, 0, 0, 0, 0, 0, 0, 0, 192, 253, 106, 248, 62, 118, 169, 219, 227, 88, 5, 58, 152, 146, + 55, 140, 35, 146, 82, 20, 141, 254, 255, 255, ]), i256::from_le_bytes([ - 1, 0, 0, 0, 0, 0, 0, 0, 128, 233, 45, 180, 117, 158, 158, 148, 230, 120, 53, 68, - 242, 185, 43, 122, 99, 181, 57, 203, 130, 241, 255, 255, + 1, 0, 0, 0, 0, 0, 0, 0, 128, 233, 45, 180, 117, 158, 158, 148, 230, 120, 53, 68, 242, 185, + 43, 122, 99, 181, 57, 203, 130, 241, 255, 255, ]), i256::from_le_bytes([ - 1, 0, 0, 0, 0, 0, 0, 0, 0, 31, 203, 9, 153, 48, 50, 206, 1, 185, 22, 170, 118, - 67, 181, 197, 226, 21, 65, 240, 27, 111, 255, 255, + 1, 0, 0, 0, 0, 0, 0, 0, 0, 31, 203, 9, 153, 48, 50, 206, 1, 185, 22, 170, 118, 67, 181, + 197, 226, 21, 65, 240, 27, 111, 255, 255, ]), i256::from_le_bytes([ - 1, 0, 0, 0, 0, 0, 0, 0, 0, 54, 239, 97, 250, 229, 245, 13, 18, 58, 227, 164, 162, - 162, 20, 185, 219, 218, 138, 98, 23, 87, 250, 255, + 1, 0, 0, 0, 0, 0, 0, 0, 0, 54, 239, 97, 250, 229, 245, 13, 18, 58, 227, 164, 
162, 162, 20, + 185, 219, 218, 138, 98, 23, 87, 250, 255, ]), i256::from_le_bytes([ - 1, 0, 0, 0, 0, 0, 0, 0, 0, 28, 88, 211, 199, 251, 154, 139, 180, 68, 224, 112, - 90, 90, 206, 58, 149, 140, 108, 217, 233, 102, 199, 255, + 1, 0, 0, 0, 0, 0, 0, 0, 0, 28, 88, 211, 199, 251, 154, 139, 180, 68, 224, 112, 90, 90, 206, + 58, 149, 140, 108, 217, 233, 102, 199, 255, ]), i256::from_le_bytes([ - 1, 0, 0, 0, 0, 0, 0, 0, 0, 24, 113, 65, 206, 213, 13, 116, 13, 175, 194, 104, - 136, 135, 15, 76, 212, 125, 61, 126, 34, 5, 202, 253, + 1, 0, 0, 0, 0, 0, 0, 0, 0, 24, 113, 65, 206, 213, 13, 116, 13, 175, 194, 104, 136, 135, 15, + 76, 212, 125, 61, 126, 34, 5, 202, 253, ]), i256::from_le_bytes([ - 1, 0, 0, 0, 0, 0, 0, 0, 0, 240, 106, 142, 14, 90, 138, 136, 134, 214, 154, 23, - 84, 75, 155, 248, 74, 234, 102, 238, 88, 51, 228, 233, + 1, 0, 0, 0, 0, 0, 0, 0, 0, 240, 106, 142, 14, 90, 138, 136, 134, 214, 154, 23, 84, 75, 155, + 248, 74, 234, 102, 238, 88, 51, 228, 233, ]), ]; @@ -758,10 +758,7 @@ pub fn validate_decimal_precision(value: i128, precision: u8) -> Result<(), Arro /// Validates that the specified `i256` of value can be properly /// interpreted as a Decimal256 number with precision `precision` #[inline] -pub fn validate_decimal256_precision( - value: i256, - precision: u8, -) -> Result<(), ArrowError> { +pub fn validate_decimal256_precision(value: i256, precision: u8) -> Result<(), ArrowError> { if precision > DECIMAL256_MAX_PRECISION { return Err(ArrowError::InvalidArgumentError(format!( "Max precision of a Decimal256 is {DECIMAL256_MAX_PRECISION}, but got {precision}", diff --git a/arrow-data/src/equal/boolean.rs b/arrow-data/src/equal/boolean.rs index a20ca5ac0bd7..addae936f118 100644 --- a/arrow-data/src/equal/boolean.rs +++ b/arrow-data/src/equal/boolean.rs @@ -78,11 +78,10 @@ pub(super) fn boolean_equal( // get a ref of the null buffer bytes, to use in testing for nullness let lhs_nulls = lhs.nulls().unwrap(); - BitIndexIterator::new(lhs_nulls.validity(), lhs_start + 
lhs_nulls.offset(), len) - .all(|i| { - let lhs_pos = lhs_start + lhs.offset() + i; - let rhs_pos = rhs_start + rhs.offset() + i; - get_bit(lhs_values, lhs_pos) == get_bit(rhs_values, rhs_pos) - }) + BitIndexIterator::new(lhs_nulls.validity(), lhs_start + lhs_nulls.offset(), len).all(|i| { + let lhs_pos = lhs_start + lhs.offset() + i; + let rhs_pos = rhs_start + rhs.offset() + i; + get_bit(lhs_values, lhs_pos) == get_bit(rhs_values, rhs_pos) + }) } } diff --git a/arrow-data/src/equal/fixed_binary.rs b/arrow-data/src/equal/fixed_binary.rs index 40dacdddd3a0..0778d77e2fdd 100644 --- a/arrow-data/src/equal/fixed_binary.rs +++ b/arrow-data/src/equal/fixed_binary.rs @@ -75,20 +75,15 @@ pub(super) fn fixed_binary_equal( }) } else { let lhs_nulls = lhs.nulls().unwrap(); - let lhs_slices_iter = BitSliceIterator::new( - lhs_nulls.validity(), - lhs_start + lhs_nulls.offset(), - len, - ); + let lhs_slices_iter = + BitSliceIterator::new(lhs_nulls.validity(), lhs_start + lhs_nulls.offset(), len); let rhs_nulls = rhs.nulls().unwrap(); - let rhs_slices_iter = BitSliceIterator::new( - rhs_nulls.validity(), - rhs_start + rhs_nulls.offset(), - len, - ); + let rhs_slices_iter = + BitSliceIterator::new(rhs_nulls.validity(), rhs_start + rhs_nulls.offset(), len); - lhs_slices_iter.zip(rhs_slices_iter).all( - |((l_start, l_end), (r_start, r_end))| { + lhs_slices_iter + .zip(rhs_slices_iter) + .all(|((l_start, l_end), (r_start, r_end))| { l_start == r_start && l_end == r_end && equal_len( @@ -98,8 +93,7 @@ pub(super) fn fixed_binary_equal( (rhs_start + r_start) * size, (l_end - l_start) * size, ) - }, - ) + }) } } } diff --git a/arrow-data/src/equal/mod.rs b/arrow-data/src/equal/mod.rs index fbc868d3f5c4..b279546474a0 100644 --- a/arrow-data/src/equal/mod.rs +++ b/arrow-data/src/equal/mod.rs @@ -76,24 +76,16 @@ fn equal_values( DataType::Int64 => primitive_equal::(lhs, rhs, lhs_start, rhs_start, len), DataType::Float32 => primitive_equal::(lhs, rhs, lhs_start, rhs_start, len), 
DataType::Float64 => primitive_equal::(lhs, rhs, lhs_start, rhs_start, len), - DataType::Decimal128(_, _) => { - primitive_equal::(lhs, rhs, lhs_start, rhs_start, len) - } - DataType::Decimal256(_, _) => { - primitive_equal::(lhs, rhs, lhs_start, rhs_start, len) - } - DataType::Date32 - | DataType::Time32(_) - | DataType::Interval(IntervalUnit::YearMonth) => { + DataType::Decimal128(_, _) => primitive_equal::(lhs, rhs, lhs_start, rhs_start, len), + DataType::Decimal256(_, _) => primitive_equal::(lhs, rhs, lhs_start, rhs_start, len), + DataType::Date32 | DataType::Time32(_) | DataType::Interval(IntervalUnit::YearMonth) => { primitive_equal::(lhs, rhs, lhs_start, rhs_start, len) } DataType::Date64 | DataType::Interval(IntervalUnit::DayTime) | DataType::Time64(_) | DataType::Timestamp(_, _) - | DataType::Duration(_) => { - primitive_equal::(lhs, rhs, lhs_start, rhs_start, len) - } + | DataType::Duration(_) => primitive_equal::(lhs, rhs, lhs_start, rhs_start, len), DataType::Interval(IntervalUnit::MonthDayNano) => { primitive_equal::(lhs, rhs, lhs_start, rhs_start, len) } @@ -103,39 +95,21 @@ fn equal_values( DataType::LargeUtf8 | DataType::LargeBinary => { variable_sized_equal::(lhs, rhs, lhs_start, rhs_start, len) } - DataType::FixedSizeBinary(_) => { - fixed_binary_equal(lhs, rhs, lhs_start, rhs_start, len) - } + DataType::FixedSizeBinary(_) => fixed_binary_equal(lhs, rhs, lhs_start, rhs_start, len), DataType::List(_) => list_equal::(lhs, rhs, lhs_start, rhs_start, len), DataType::LargeList(_) => list_equal::(lhs, rhs, lhs_start, rhs_start, len), - DataType::FixedSizeList(_, _) => { - fixed_list_equal(lhs, rhs, lhs_start, rhs_start, len) - } + DataType::FixedSizeList(_, _) => fixed_list_equal(lhs, rhs, lhs_start, rhs_start, len), DataType::Struct(_) => struct_equal(lhs, rhs, lhs_start, rhs_start, len), DataType::Union(_, _) => union_equal(lhs, rhs, lhs_start, rhs_start, len), DataType::Dictionary(data_type, _) => match data_type.as_ref() { DataType::Int8 => 
dictionary_equal::(lhs, rhs, lhs_start, rhs_start, len), - DataType::Int16 => { - dictionary_equal::(lhs, rhs, lhs_start, rhs_start, len) - } - DataType::Int32 => { - dictionary_equal::(lhs, rhs, lhs_start, rhs_start, len) - } - DataType::Int64 => { - dictionary_equal::(lhs, rhs, lhs_start, rhs_start, len) - } - DataType::UInt8 => { - dictionary_equal::(lhs, rhs, lhs_start, rhs_start, len) - } - DataType::UInt16 => { - dictionary_equal::(lhs, rhs, lhs_start, rhs_start, len) - } - DataType::UInt32 => { - dictionary_equal::(lhs, rhs, lhs_start, rhs_start, len) - } - DataType::UInt64 => { - dictionary_equal::(lhs, rhs, lhs_start, rhs_start, len) - } + DataType::Int16 => dictionary_equal::(lhs, rhs, lhs_start, rhs_start, len), + DataType::Int32 => dictionary_equal::(lhs, rhs, lhs_start, rhs_start, len), + DataType::Int64 => dictionary_equal::(lhs, rhs, lhs_start, rhs_start, len), + DataType::UInt8 => dictionary_equal::(lhs, rhs, lhs_start, rhs_start, len), + DataType::UInt16 => dictionary_equal::(lhs, rhs, lhs_start, rhs_start, len), + DataType::UInt32 => dictionary_equal::(lhs, rhs, lhs_start, rhs_start, len), + DataType::UInt64 => dictionary_equal::(lhs, rhs, lhs_start, rhs_start, len), _ => unreachable!(), }, DataType::Float16 => primitive_equal::(lhs, rhs, lhs_start, rhs_start, len), diff --git a/arrow-data/src/equal/primitive.rs b/arrow-data/src/equal/primitive.rs index 7b3cbc9eb949..e92fdd2ba23b 100644 --- a/arrow-data/src/equal/primitive.rs +++ b/arrow-data/src/equal/primitive.rs @@ -73,20 +73,15 @@ pub(super) fn primitive_equal( }) } else { let lhs_nulls = lhs.nulls().unwrap(); - let lhs_slices_iter = BitSliceIterator::new( - lhs_nulls.validity(), - lhs_start + lhs_nulls.offset(), - len, - ); + let lhs_slices_iter = + BitSliceIterator::new(lhs_nulls.validity(), lhs_start + lhs_nulls.offset(), len); let rhs_nulls = rhs.nulls().unwrap(); - let rhs_slices_iter = BitSliceIterator::new( - rhs_nulls.validity(), - rhs_start + rhs_nulls.offset(), - len, - ); + let 
rhs_slices_iter = + BitSliceIterator::new(rhs_nulls.validity(), rhs_start + rhs_nulls.offset(), len); - lhs_slices_iter.zip(rhs_slices_iter).all( - |((l_start, l_end), (r_start, r_end))| { + lhs_slices_iter + .zip(rhs_slices_iter) + .all(|((l_start, l_end), (r_start, r_end))| { l_start == r_start && l_end == r_end && equal_len( @@ -96,8 +91,7 @@ pub(super) fn primitive_equal( (rhs_start + r_start) * byte_width, (l_end - l_start) * byte_width, ) - }, - ) + }) } } } diff --git a/arrow-data/src/equal/union.rs b/arrow-data/src/equal/union.rs index 5869afc30dbe..62de276e507f 100644 --- a/arrow-data/src/equal/union.rs +++ b/arrow-data/src/equal/union.rs @@ -116,10 +116,7 @@ pub(super) fn union_equal( rhs_fields, ) } - ( - DataType::Union(_, UnionMode::Sparse), - DataType::Union(_, UnionMode::Sparse), - ) => { + (DataType::Union(_, UnionMode::Sparse), DataType::Union(_, UnionMode::Sparse)) => { lhs_type_id_range == rhs_type_id_range && equal_sparse(lhs, rhs, lhs_start, rhs_start, len) } diff --git a/arrow-data/src/equal/utils.rs b/arrow-data/src/equal/utils.rs index fa6211542550..cc81943756d2 100644 --- a/arrow-data/src/equal/utils.rs +++ b/arrow-data/src/equal/utils.rs @@ -73,11 +73,9 @@ pub(super) fn base_equal(lhs: &ArrayData, rhs: &ArrayData) -> bool { let r_value_field = r_fields.get(1).unwrap(); // We don't enforce the equality of field names - let data_type_equal = l_key_field.data_type() - == r_key_field.data_type() + let data_type_equal = l_key_field.data_type() == r_key_field.data_type() && l_value_field.data_type() == r_value_field.data_type(); - let nullability_equal = l_key_field.is_nullable() - == r_key_field.is_nullable() + let nullability_equal = l_key_field.is_nullable() == r_key_field.is_nullable() && l_value_field.is_nullable() == r_value_field.is_nullable(); let metadata_equal = l_key_field.metadata() == r_key_field.metadata() && l_value_field.metadata() == r_value_field.metadata(); diff --git a/arrow-data/src/transform/list.rs 
b/arrow-data/src/transform/list.rs index 9d5d8330cb1e..d9a1c62a8e8e 100644 --- a/arrow-data/src/transform/list.rs +++ b/arrow-data/src/transform/list.rs @@ -23,9 +23,7 @@ use crate::ArrayData; use arrow_buffer::ArrowNativeType; use num::{CheckedAdd, Integer}; -pub(super) fn build_extend( - array: &ArrayData, -) -> Extend { +pub(super) fn build_extend(array: &ArrayData) -> Extend { let offsets = array.buffer::(0); Box::new( move |mutable: &mut _MutableArrayData, index: usize, start: usize, len: usize| { @@ -35,11 +33,7 @@ pub(super) fn build_extend( let last_offset: T = unsafe { get_last_offset(offset_buffer) }; // offsets - extend_offsets::( - offset_buffer, - last_offset, - &offsets[start..start + len + 1], - ); + extend_offsets::(offset_buffer, last_offset, &offsets[start..start + len + 1]); mutable.child_data[0].extend( index, @@ -50,10 +44,7 @@ pub(super) fn build_extend( ) } -pub(super) fn extend_nulls( - mutable: &mut _MutableArrayData, - len: usize, -) { +pub(super) fn extend_nulls(mutable: &mut _MutableArrayData, len: usize) { let offset_buffer = &mut mutable.buffer1; // this is safe due to how offset is built. See details on `get_last_offset` diff --git a/arrow-data/src/transform/mod.rs b/arrow-data/src/transform/mod.rs index f4b2b46d1723..af25e9c7e3dc 100644 --- a/arrow-data/src/transform/mod.rs +++ b/arrow-data/src/transform/mod.rs @@ -173,11 +173,7 @@ impl<'a> std::fmt::Debug for MutableArrayData<'a> { /// Builds an extend that adds `offset` to the source primitive /// Additionally validates that `max` fits into the /// the underlying primitive returning None if not -fn build_extend_dictionary( - array: &ArrayData, - offset: usize, - max: usize, -) -> Option { +fn build_extend_dictionary(array: &ArrayData, offset: usize, max: usize) -> Option { macro_rules! 
validate_and_build { ($dt: ty) => {{ let _: $dt = max.try_into().ok()?; @@ -215,27 +211,19 @@ fn build_extend(array: &ArrayData) -> Extend { DataType::Int64 => primitive::build_extend::(array), DataType::Float32 => primitive::build_extend::(array), DataType::Float64 => primitive::build_extend::(array), - DataType::Date32 - | DataType::Time32(_) - | DataType::Interval(IntervalUnit::YearMonth) => { + DataType::Date32 | DataType::Time32(_) | DataType::Interval(IntervalUnit::YearMonth) => { primitive::build_extend::(array) } DataType::Date64 | DataType::Time64(_) | DataType::Timestamp(_, _) | DataType::Duration(_) - | DataType::Interval(IntervalUnit::DayTime) => { - primitive::build_extend::(array) - } - DataType::Interval(IntervalUnit::MonthDayNano) => { - primitive::build_extend::(array) - } + | DataType::Interval(IntervalUnit::DayTime) => primitive::build_extend::(array), + DataType::Interval(IntervalUnit::MonthDayNano) => primitive::build_extend::(array), DataType::Decimal128(_, _) => primitive::build_extend::(array), DataType::Decimal256(_, _) => primitive::build_extend::(array), DataType::Utf8 | DataType::Binary => variable_size::build_extend::(array), - DataType::LargeUtf8 | DataType::LargeBinary => { - variable_size::build_extend::(array) - } + DataType::LargeUtf8 | DataType::LargeBinary => variable_size::build_extend::(array), DataType::Map(_, _) | DataType::List(_) => list::build_extend::(array), DataType::LargeList(_) => list::build_extend::(array), DataType::Dictionary(_, _) => unreachable!("should use build_extend_dictionary"), @@ -265,9 +253,9 @@ fn build_extend_nulls(data_type: &DataType) -> ExtendNulls { DataType::Int64 => primitive::extend_nulls::, DataType::Float32 => primitive::extend_nulls::, DataType::Float64 => primitive::extend_nulls::, - DataType::Date32 - | DataType::Time32(_) - | DataType::Interval(IntervalUnit::YearMonth) => primitive::extend_nulls::, + DataType::Date32 | DataType::Time32(_) | DataType::Interval(IntervalUnit::YearMonth) => { 
+ primitive::extend_nulls:: + } DataType::Date64 | DataType::Time64(_) | DataType::Timestamp(_, _) @@ -380,10 +368,7 @@ impl<'a> MutableArrayData<'a> { array_capacity = *capacity; preallocate_offset_and_binary_buffer::(*capacity, *value_cap) } - ( - DataType::Utf8 | DataType::Binary, - Capacities::Binary(capacity, Some(value_cap)), - ) => { + (DataType::Utf8 | DataType::Binary, Capacities::Binary(capacity, Some(value_cap))) => { array_capacity = *capacity; preallocate_offset_and_binary_buffer::(*capacity, *value_cap) } @@ -391,10 +376,7 @@ impl<'a> MutableArrayData<'a> { array_capacity = *capacity; new_buffers(data_type, *capacity) } - ( - DataType::List(_) | DataType::LargeList(_), - Capacities::List(capacity, _), - ) => { + (DataType::List(_) | DataType::LargeList(_), Capacities::List(capacity, _)) => { array_capacity = *capacity; new_buffers(data_type, *capacity) } @@ -435,16 +417,15 @@ impl<'a> MutableArrayData<'a> { .map(|array| &array.child_data()[0]) .collect::>(); - let capacities = if let Capacities::List(capacity, ref child_capacities) = - capacities - { - child_capacities - .clone() - .map(|c| *c) - .unwrap_or(Capacities::Array(capacity)) - } else { - Capacities::Array(array_capacity) - }; + let capacities = + if let Capacities::List(capacity, ref child_capacities) = capacities { + child_capacities + .clone() + .map(|c| *c) + .unwrap_or(Capacities::Array(capacity)) + } else { + Capacities::Array(array_capacity) + }; vec![MutableArrayData::with_capacities( children, use_nulls, capacities, @@ -546,8 +527,7 @@ impl<'a> MutableArrayData<'a> { .collect(); let capacity = lengths.iter().sum(); - let mut mutable = - MutableArrayData::new(dictionaries, false, capacity); + let mut mutable = MutableArrayData::new(dictionaries, false, capacity); for (i, len) in lengths.iter().enumerate() { mutable.extend(i, 0, *len) diff --git a/arrow-data/src/transform/primitive.rs b/arrow-data/src/transform/primitive.rs index b5c826438bfc..627dc00de1df 100644 --- 
a/arrow-data/src/transform/primitive.rs +++ b/arrow-data/src/transform/primitive.rs @@ -47,9 +47,6 @@ where ) } -pub(super) fn extend_nulls( - mutable: &mut _MutableArrayData, - len: usize, -) { +pub(super) fn extend_nulls(mutable: &mut _MutableArrayData, len: usize) { mutable.buffer1.extend_zeros(len * size_of::()); } diff --git a/arrow-data/src/transform/utils.rs b/arrow-data/src/transform/utils.rs index 17bb87e88a5c..5407f68e0d0c 100644 --- a/arrow-data/src/transform/utils.rs +++ b/arrow-data/src/transform/utils.rs @@ -45,9 +45,7 @@ pub(super) fn extend_offsets( } #[inline] -pub(super) unsafe fn get_last_offset( - offset_buffer: &MutableBuffer, -) -> T { +pub(super) unsafe fn get_last_offset(offset_buffer: &MutableBuffer) -> T { // JUSTIFICATION // Benefit // 20% performance improvement extend of variable sized arrays (see bench `mutable_array`) diff --git a/arrow-data/src/transform/variable_size.rs b/arrow-data/src/transform/variable_size.rs index 597a8b2b6645..fa1592d973ed 100644 --- a/arrow-data/src/transform/variable_size.rs +++ b/arrow-data/src/transform/variable_size.rs @@ -39,9 +39,7 @@ fn extend_offset_values>( buffer.extend_from_slice(new_values); } -pub(super) fn build_extend< - T: ArrowNativeType + Integer + CheckedAdd + AsPrimitive, ->( +pub(super) fn build_extend>( array: &ArrayData, ) -> Extend { let offsets = array.buffer::(0); @@ -54,21 +52,14 @@ pub(super) fn build_extend< // this is safe due to how offset is built. 
See details on `get_last_offset` let last_offset = unsafe { get_last_offset(offset_buffer) }; - extend_offsets::( - offset_buffer, - last_offset, - &offsets[start..start + len + 1], - ); + extend_offsets::(offset_buffer, last_offset, &offsets[start..start + len + 1]); // values extend_offset_values::(values_buffer, offsets, values, start, len); }, ) } -pub(super) fn extend_nulls( - mutable: &mut _MutableArrayData, - len: usize, -) { +pub(super) fn extend_nulls(mutable: &mut _MutableArrayData, len: usize) { let offset_buffer = &mut mutable.buffer1; // this is safe due to how offset is built. See details on `get_last_offset` diff --git a/arrow-flight/examples/flight_sql_server.rs b/arrow-flight/examples/flight_sql_server.rs index 013f7e7788f8..bd94d3c499ca 100644 --- a/arrow-flight/examples/flight_sql_server.rs +++ b/arrow-flight/examples/flight_sql_server.rs @@ -32,28 +32,26 @@ use arrow_array::builder::StringBuilder; use arrow_array::{ArrayRef, RecordBatch}; use arrow_flight::encode::FlightDataEncoderBuilder; use arrow_flight::sql::metadata::{ - SqlInfoData, SqlInfoDataBuilder, XdbcTypeInfo, XdbcTypeInfoData, - XdbcTypeInfoDataBuilder, + SqlInfoData, SqlInfoDataBuilder, XdbcTypeInfo, XdbcTypeInfoData, XdbcTypeInfoDataBuilder, }; use arrow_flight::sql::{ server::FlightSqlService, ActionBeginSavepointRequest, ActionBeginSavepointResult, - ActionBeginTransactionRequest, ActionBeginTransactionResult, - ActionCancelQueryRequest, ActionCancelQueryResult, - ActionClosePreparedStatementRequest, ActionCreatePreparedStatementRequest, - ActionCreatePreparedStatementResult, ActionCreatePreparedSubstraitPlanRequest, - ActionEndSavepointRequest, ActionEndTransactionRequest, Any, CommandGetCatalogs, - CommandGetCrossReference, CommandGetDbSchemas, CommandGetExportedKeys, - CommandGetImportedKeys, CommandGetPrimaryKeys, CommandGetSqlInfo, - CommandGetTableTypes, CommandGetTables, CommandGetXdbcTypeInfo, + ActionBeginTransactionRequest, ActionBeginTransactionResult, 
ActionCancelQueryRequest, + ActionCancelQueryResult, ActionClosePreparedStatementRequest, + ActionCreatePreparedStatementRequest, ActionCreatePreparedStatementResult, + ActionCreatePreparedSubstraitPlanRequest, ActionEndSavepointRequest, + ActionEndTransactionRequest, Any, CommandGetCatalogs, CommandGetCrossReference, + CommandGetDbSchemas, CommandGetExportedKeys, CommandGetImportedKeys, CommandGetPrimaryKeys, + CommandGetSqlInfo, CommandGetTableTypes, CommandGetTables, CommandGetXdbcTypeInfo, CommandPreparedStatementQuery, CommandPreparedStatementUpdate, CommandStatementQuery, - CommandStatementSubstraitPlan, CommandStatementUpdate, Nullable, ProstMessageExt, - Searchable, SqlInfo, TicketStatementQuery, XdbcDataType, + CommandStatementSubstraitPlan, CommandStatementUpdate, Nullable, ProstMessageExt, Searchable, + SqlInfo, TicketStatementQuery, XdbcDataType, }; use arrow_flight::utils::batches_to_flight_data; use arrow_flight::{ - flight_service_server::FlightService, flight_service_server::FlightServiceServer, - Action, FlightData, FlightDescriptor, FlightEndpoint, FlightInfo, HandshakeRequest, - HandshakeResponse, IpcMessage, Location, SchemaAsIpc, Ticket, + flight_service_server::FlightService, flight_service_server::FlightServiceServer, Action, + FlightData, FlightDescriptor, FlightEndpoint, FlightInfo, HandshakeRequest, HandshakeResponse, + IpcMessage, Location, SchemaAsIpc, Ticket, }; use arrow_ipc::writer::IpcWriteOptions; use arrow_schema::{ArrowError, DataType, Field, Schema}; @@ -167,8 +165,7 @@ impl FlightSqlService for FlightSqlServiceImpl { let bytes = BASE64_STANDARD .decode(base64) .map_err(|e| status!("authorization not decodable", e))?; - let str = String::from_utf8(bytes) - .map_err(|e| status!("authorization not parsable", e))?; + let str = String::from_utf8(bytes).map_err(|e| status!("authorization not parsable", e))?; let parts: Vec<_> = str.split(':').collect(); let (user, pass) = match parts.as_slice() { [user, pass] => (user, pass), @@ 
-195,8 +192,7 @@ impl FlightSqlService for FlightSqlServiceImpl { _message: Any, ) -> Result::DoGetStream>, Status> { self.check_token(&request)?; - let batch = - Self::fake_result().map_err(|e| status!("Could not fake a result", e))?; + let batch = Self::fake_result().map_err(|e| status!("Could not fake a result", e))?; let schema = batch.schema(); let batches = vec![batch]; let flight_data = batches_to_flight_data(schema.as_ref(), batches) @@ -238,8 +234,7 @@ impl FlightSqlService for FlightSqlServiceImpl { self.check_token(&request)?; let handle = std::str::from_utf8(&cmd.prepared_statement_handle) .map_err(|e| status!("Unable to parse handle", e))?; - let batch = - Self::fake_result().map_err(|e| status!("Could not fake a result", e))?; + let batch = Self::fake_result().map_err(|e| status!("Could not fake a result", e))?; let schema = (*batch.schema()).clone(); let num_rows = batch.num_rows(); let num_bytes = batch.get_array_memory_size(); @@ -736,8 +731,7 @@ async fn main() -> Result<(), Box> { if std::env::var("USE_TLS").ok().is_some() { let cert = std::fs::read_to_string("arrow-flight/examples/data/server.pem")?; let key = std::fs::read_to_string("arrow-flight/examples/data/server.key")?; - let client_ca = - std::fs::read_to_string("arrow-flight/examples/data/client_ca.pem")?; + let client_ca = std::fs::read_to_string("arrow-flight/examples/data/client_ca.pem")?; let tls_config = ServerTlsConfig::new() .identity(Identity::from_pem(&cert, &key)) diff --git a/arrow-flight/examples/server.rs b/arrow-flight/examples/server.rs index 1ed21acef9b8..85ac4ca1384c 100644 --- a/arrow-flight/examples/server.rs +++ b/arrow-flight/examples/server.rs @@ -20,9 +20,9 @@ use tonic::transport::Server; use tonic::{Request, Response, Status, Streaming}; use arrow_flight::{ - flight_service_server::FlightService, flight_service_server::FlightServiceServer, - Action, ActionType, Criteria, Empty, FlightData, FlightDescriptor, FlightInfo, - HandshakeRequest, HandshakeResponse, 
PutResult, SchemaResult, Ticket, + flight_service_server::FlightService, flight_service_server::FlightServiceServer, Action, + ActionType, Criteria, Empty, FlightData, FlightDescriptor, FlightInfo, HandshakeRequest, + HandshakeResponse, PutResult, SchemaResult, Ticket, }; #[derive(Clone)] diff --git a/arrow-flight/src/client.rs b/arrow-flight/src/client.rs index 8793f7834bfb..a264012c82ec 100644 --- a/arrow-flight/src/client.rs +++ b/arrow-flight/src/client.rs @@ -249,10 +249,7 @@ impl FlightClient { /// .expect("error fetching data"); /// # } /// ``` - pub async fn get_flight_info( - &mut self, - descriptor: FlightDescriptor, - ) -> Result { + pub async fn get_flight_info(&mut self, descriptor: FlightDescriptor) -> Result { let request = self.make_request(descriptor); let response = self.inner.get_flight_info(request).await?.into_inner(); @@ -452,10 +449,7 @@ impl FlightClient { /// .expect("error making request"); /// # } /// ``` - pub async fn get_schema( - &mut self, - flight_descriptor: FlightDescriptor, - ) -> Result { + pub async fn get_schema(&mut self, flight_descriptor: FlightDescriptor) -> Result { let request = self.make_request(flight_descriptor); let schema_result = self.inner.get_schema(request).await?.into_inner(); @@ -488,9 +482,7 @@ impl FlightClient { /// .expect("error gathering actions"); /// # } /// ``` - pub async fn list_actions( - &mut self, - ) -> Result>> { + pub async fn list_actions(&mut self) -> Result>> { let request = self.make_request(Empty {}); let action_stream = self @@ -528,10 +520,7 @@ impl FlightClient { /// .expect("error gathering action results"); /// # } /// ``` - pub async fn do_action( - &mut self, - action: Action, - ) -> Result>> { + pub async fn do_action(&mut self, action: Action) -> Result>> { let request = self.make_request(action); let result_stream = self diff --git a/arrow-flight/src/decode.rs b/arrow-flight/src/decode.rs index dfcdd260602c..95bbe2b46bb2 100644 --- a/arrow-flight/src/decode.rs +++ 
b/arrow-flight/src/decode.rs @@ -21,9 +21,7 @@ use arrow_buffer::Buffer; use arrow_schema::{Schema, SchemaRef}; use bytes::Bytes; use futures::{ready, stream::BoxStream, Stream, StreamExt}; -use std::{ - collections::HashMap, convert::TryFrom, fmt::Debug, pin::Pin, sync::Arc, task::Poll, -}; +use std::{collections::HashMap, convert::TryFrom, fmt::Debug, pin::Pin, sync::Arc, task::Poll}; use tonic::metadata::MetadataMap; use crate::error::{FlightError, Result}; @@ -270,16 +268,14 @@ impl FlightDataDecoder { /// state as necessary. fn extract_message(&mut self, data: FlightData) -> Result> { use arrow_ipc::MessageHeader; - let message = arrow_ipc::root_as_message(&data.data_header[..]).map_err(|e| { - FlightError::DecodeError(format!("Error decoding root message: {e}")) - })?; + let message = arrow_ipc::root_as_message(&data.data_header[..]) + .map_err(|e| FlightError::DecodeError(format!("Error decoding root message: {e}")))?; match message.header_type() { MessageHeader::NONE => Ok(Some(DecodedFlightData::new_none(data))), MessageHeader::Schema => { - let schema = Schema::try_from(&data).map_err(|e| { - FlightError::DecodeError(format!("Error decoding schema: {e}")) - })?; + let schema = Schema::try_from(&data) + .map_err(|e| FlightError::DecodeError(format!("Error decoding schema: {e}")))?; let schema = Arc::new(schema); let dictionaries_by_field = HashMap::new(); @@ -300,12 +296,11 @@ impl FlightDataDecoder { }; let buffer = Buffer::from_bytes(data.data_body.into()); - let dictionary_batch = - message.header_as_dictionary_batch().ok_or_else(|| { - FlightError::protocol( - "Could not get dictionary batch from DictionaryBatch message", - ) - })?; + let dictionary_batch = message.header_as_dictionary_batch().ok_or_else(|| { + FlightError::protocol( + "Could not get dictionary batch from DictionaryBatch message", + ) + })?; arrow_ipc::reader::read_dictionary( &buffer, @@ -315,9 +310,7 @@ impl FlightDataDecoder { &message.version(), ) .map_err(|e| { - 
FlightError::DecodeError(format!( - "Error decoding ipc dictionary: {e}" - )) + FlightError::DecodeError(format!("Error decoding ipc dictionary: {e}")) })?; // Updated internal state, but no decoded message @@ -338,9 +331,7 @@ impl FlightDataDecoder { &state.dictionaries_by_field, ) .map_err(|e| { - FlightError::DecodeError(format!( - "Error decoding ipc RecordBatch: {e}" - )) + FlightError::DecodeError(format!("Error decoding ipc RecordBatch: {e}")) })?; Ok(Some(DecodedFlightData::new_record_batch(data, batch))) diff --git a/arrow-flight/src/encode.rs b/arrow-flight/src/encode.rs index 9ae7f1637982..e6ef9994d487 100644 --- a/arrow-flight/src/encode.rs +++ b/arrow-flight/src/encode.rs @@ -159,10 +159,7 @@ impl FlightDataEncoderBuilder { } /// Set [`DictionaryHandling`] for encoder - pub fn with_dictionary_handling( - mut self, - dictionary_handling: DictionaryHandling, - ) -> Self { + pub fn with_dictionary_handling(mut self, dictionary_handling: DictionaryHandling) -> Self { self.dictionary_handling = dictionary_handling; self } @@ -191,10 +188,7 @@ impl FlightDataEncoderBuilder { } /// Specify a flight descriptor in the first FlightData message. 
- pub fn with_flight_descriptor( - mut self, - descriptor: Option, - ) -> Self { + pub fn with_flight_descriptor(mut self, descriptor: Option) -> Self { self.descriptor = descriptor; self } @@ -334,8 +328,7 @@ impl FlightDataEncoder { let batch = prepare_batch_for_flight(&batch, schema, send_dictionaries)?; for batch in split_batch_for_grpc_response(batch, self.max_flight_data_size) { - let (flight_dictionaries, flight_batch) = - self.encoder.encode_batch(&batch)?; + let (flight_dictionaries, flight_batch) = self.encoder.encode_batch(&batch)?; self.queue_messages(flight_dictionaries); self.queue_message(flight_batch); @@ -460,9 +453,8 @@ fn split_batch_for_grpc_response( .map(|col| col.get_buffer_memory_size()) .sum::(); - let n_batches = (size / max_flight_data_size - + usize::from(size % max_flight_data_size != 0)) - .max(1); + let n_batches = + (size / max_flight_data_size + usize::from(size % max_flight_data_size != 0)).max(1); let rows_per_batch = (batch.num_rows() / n_batches).max(1); let mut out = Vec::with_capacity(n_batches + 1); @@ -505,18 +497,12 @@ impl FlightIpcEncoder { /// Convert a `RecordBatch` to a Vec of `FlightData` representing /// dictionaries and a `FlightData` representing the batch - fn encode_batch( - &mut self, - batch: &RecordBatch, - ) -> Result<(Vec, FlightData)> { - let (encoded_dictionaries, encoded_batch) = self.data_gen.encoded_batch( - batch, - &mut self.dictionary_tracker, - &self.options, - )?; - - let flight_dictionaries = - encoded_dictionaries.into_iter().map(Into::into).collect(); + fn encode_batch(&mut self, batch: &RecordBatch) -> Result<(Vec, FlightData)> { + let (encoded_dictionaries, encoded_batch) = + self.data_gen + .encoded_batch(batch, &mut self.dictionary_tracker, &self.options)?; + + let flight_dictionaries = encoded_dictionaries.into_iter().map(Into::into).collect(); let flight_batch = encoded_batch.into(); Ok((flight_dictionaries, flight_batch)) @@ -553,9 +539,7 @@ fn prepare_batch_for_flight( /// but does 
enable sending DictionaryArray's via Flight. fn hydrate_dictionary(array: &ArrayRef, send_dictionaries: bool) -> Result { let arr = match array.data_type() { - DataType::Dictionary(_, value) if !send_dictionaries => { - arrow_cast::cast(array, value)? - } + DataType::Dictionary(_, value) if !send_dictionaries => arrow_cast::cast(array, value)?, _ => Arc::clone(array), }; Ok(arr) @@ -586,11 +570,9 @@ mod tests { let (_, baseline_flight_batch) = make_flight_data(&batch, &options); let big_batch = batch.slice(0, batch.num_rows() - 1); - let optimized_big_batch = - prepare_batch_for_flight(&big_batch, Arc::clone(&schema), false) - .expect("failed to optimize"); - let (_, optimized_big_flight_batch) = - make_flight_data(&optimized_big_batch, &options); + let optimized_big_batch = prepare_batch_for_flight(&big_batch, Arc::clone(&schema), false) + .expect("failed to optimize"); + let (_, optimized_big_flight_batch) = make_flight_data(&optimized_big_batch, &options); assert_eq!( baseline_flight_batch.data_body.len(), @@ -601,12 +583,10 @@ mod tests { let optimized_small_batch = prepare_batch_for_flight(&small_batch, Arc::clone(&schema), false) .expect("failed to optimize"); - let (_, optimized_small_flight_batch) = - make_flight_data(&optimized_small_batch, &options); + let (_, optimized_small_flight_batch) = make_flight_data(&optimized_small_batch, &options); assert!( - baseline_flight_batch.data_body.len() - > optimized_small_flight_batch.data_body.len() + baseline_flight_batch.data_body.len() > optimized_small_flight_batch.data_body.len() ); } @@ -620,11 +600,10 @@ mod tests { false, )])); let batch = RecordBatch::try_new(schema, vec![Arc::new(arr)]).unwrap(); - let encoder = FlightDataEncoderBuilder::default() - .build(futures::stream::once(async { Ok(batch) })); + let encoder = + FlightDataEncoderBuilder::default().build(futures::stream::once(async { Ok(batch) })); let mut decoder = FlightDataDecoder::new(encoder); - let expected_schema = - 
Schema::new(vec![Field::new("dict", DataType::Utf8, false)]); + let expected_schema = Schema::new(vec![Field::new("dict", DataType::Utf8, false)]); let expected_schema = Arc::new(expected_schema); while let Some(decoded) = decoder.next().await { let decoded = decoded.unwrap(); @@ -656,10 +635,8 @@ mod tests { Arc::new(vec!["a", "a", "b"].into_iter().collect()); let arr_two: Arc> = Arc::new(vec!["b", "a", "c"].into_iter().collect()); - let batch_one = - RecordBatch::try_new(schema.clone(), vec![arr_one.clone()]).unwrap(); - let batch_two = - RecordBatch::try_new(schema.clone(), vec![arr_two.clone()]).unwrap(); + let batch_one = RecordBatch::try_new(schema.clone(), vec![arr_one.clone()]).unwrap(); + let batch_two = RecordBatch::try_new(schema.clone(), vec![arr_two.clone()]).unwrap(); let encoder = FlightDataEncoderBuilder::default() .with_dictionary_handling(DictionaryHandling::Resend) @@ -675,10 +652,9 @@ mod tests { DecodedPayload::RecordBatch(b) => { assert_eq!(b.schema(), schema); - let actual_array = - Arc::new(downcast_array::>( - b.column_by_name("dict").unwrap(), - )); + let actual_array = Arc::new(downcast_array::>( + b.column_by_name("dict").unwrap(), + )); assert_eq!(actual_array, expected_array); @@ -690,10 +666,9 @@ mod tests { #[test] fn test_schema_metadata_encoded() { - let schema = - Schema::new(vec![Field::new("data", DataType::Int32, false)]).with_metadata( - HashMap::from([("some_key".to_owned(), "some_value".to_owned())]), - ); + let schema = Schema::new(vec![Field::new("data", DataType::Int32, false)]).with_metadata( + HashMap::from([("some_key".to_owned(), "some_value".to_owned())]), + ); let got = prepare_schema_for_flight(&schema, false); assert!(got.metadata().contains_key("some_key")); @@ -708,8 +683,7 @@ mod tests { ) .expect("cannot create record batch"); - prepare_batch_for_flight(&batch, batch.schema(), false) - .expect("failed to optimize"); + prepare_batch_for_flight(&batch, batch.schema(), false).expect("failed to optimize"); } pub 
fn make_flight_data( @@ -723,8 +697,7 @@ mod tests { .encoded_batch(batch, &mut dictionary_tracker, options) .expect("DictionaryTracker configured above to not error on replacement"); - let flight_dictionaries = - encoded_dictionaries.into_iter().map(Into::into).collect(); + let flight_dictionaries = encoded_dictionaries.into_iter().map(Into::into).collect(); let flight_batch = encoded_batch.into(); (flight_dictionaries, flight_batch) @@ -745,8 +718,7 @@ mod tests { // split once let n_rows = max_flight_data_size + 1; assert!(n_rows % 2 == 1, "should be an odd number"); - let c = - UInt8Array::from((0..n_rows).map(|i| (i % 256) as u8).collect::>()); + let c = UInt8Array::from((0..n_rows).map(|i| (i % 256) as u8).collect::>()); let batch = RecordBatch::try_from_iter(vec![("a", Arc::new(c) as ArrayRef)]) .expect("cannot create record batch"); let split = split_batch_for_grpc_response(batch.clone(), max_flight_data_size); @@ -793,8 +765,7 @@ mod tests { let input_rows = batch.num_rows(); - let split = - split_batch_for_grpc_response(batch.clone(), max_flight_data_size_bytes); + let split = split_batch_for_grpc_response(batch.clone(), max_flight_data_size_bytes); let sizes: Vec<_> = split.iter().map(|batch| batch.num_rows()).collect(); let output_rows: usize = sizes.iter().sum(); @@ -807,8 +778,7 @@ mod tests { #[tokio::test] async fn flight_data_size_even() { - let s1 = - StringArray::from_iter_values(std::iter::repeat(".10 bytes.").take(1024)); + let s1 = StringArray::from_iter_values(std::iter::repeat(".10 bytes.").take(1024)); let i1 = Int16Array::from_iter_values(0..1024); let s2 = StringArray::from_iter_values(std::iter::repeat("6bytes").take(1024)); let i2 = Int64Array::from_iter_values(0..1024); @@ -828,8 +798,7 @@ mod tests { async fn flight_data_size_uneven_variable_lengths() { // each row has a longer string than the last with increasing lengths 0 --> 1024 let array = StringArray::from_iter_values((0..1024).map(|i| "*".repeat(i))); - let batch = - 
RecordBatch::try_from_iter(vec![("data", Arc::new(array) as _)]).unwrap(); + let batch = RecordBatch::try_from_iter(vec![("data", Arc::new(array) as _)]).unwrap(); // overage is much higher than ideal // https://github.com/apache/arrow-rs/issues/3478 @@ -883,8 +852,7 @@ mod tests { }) .collect(); - let batch = - RecordBatch::try_from_iter(vec![("a1", Arc::new(array) as _)]).unwrap(); + let batch = RecordBatch::try_from_iter(vec![("a1", Arc::new(array) as _)]).unwrap(); verify_encoded_split(batch, 160).await; } @@ -894,11 +862,9 @@ mod tests { // large dictionary (all distinct values ==> 1024 entries in dictionary) let values: Vec<_> = (1..1024).map(|i| "**".repeat(i)).collect(); - let array: DictionaryArray = - values.iter().map(|s| Some(s.as_str())).collect(); + let array: DictionaryArray = values.iter().map(|s| Some(s.as_str())).collect(); - let batch = - RecordBatch::try_from_iter(vec![("a1", Arc::new(array) as _)]).unwrap(); + let batch = RecordBatch::try_from_iter(vec![("a1", Arc::new(array) as _)]).unwrap(); // overage is much higher than ideal // https://github.com/apache/arrow-rs/issues/3478 @@ -912,8 +878,7 @@ mod tests { let keys = Int32Array::from_iter_values((0..3000).map(|i| (3000 - i) % 1024)); let array = DictionaryArray::new(keys, Arc::new(values)); - let batch = - RecordBatch::try_from_iter(vec![("a1", Arc::new(array) as _)]).unwrap(); + let batch = RecordBatch::try_from_iter(vec![("a1", Arc::new(array) as _)]).unwrap(); // overage is much higher than ideal // https://github.com/apache/arrow-rs/issues/3478 @@ -929,12 +894,9 @@ mod tests { // medium cardinality let values3: Vec<_> = (1..1024).map(|i| "**".repeat(i % 100)).collect(); - let array1: DictionaryArray = - values1.iter().map(|s| Some(s.as_str())).collect(); - let array2: DictionaryArray = - values2.iter().map(|s| Some(s.as_str())).collect(); - let array3: DictionaryArray = - values3.iter().map(|s| Some(s.as_str())).collect(); + let array1: DictionaryArray = values1.iter().map(|s| 
Some(s.as_str())).collect(); + let array2: DictionaryArray = values2.iter().map(|s| Some(s.as_str())).collect(); + let array3: DictionaryArray = values3.iter().map(|s| Some(s.as_str())).collect(); let batch = RecordBatch::try_from_iter(vec![ ("a1", Arc::new(array1) as _), @@ -954,17 +916,13 @@ mod tests { .flight_descriptor .as_ref() .map(|descriptor| { - let path_len: usize = - descriptor.path.iter().map(|p| p.as_bytes().len()).sum(); + let path_len: usize = descriptor.path.iter().map(|p| p.as_bytes().len()).sum(); std::mem::size_of_val(descriptor) + descriptor.cmd.len() + path_len }) .unwrap_or(0); - flight_descriptor_size - + d.app_metadata.len() - + d.data_body.len() - + d.data_header.len() + flight_descriptor_size + d.app_metadata.len() + d.data_body.len() + d.data_header.len() } /// Coverage for diff --git a/arrow-flight/src/lib.rs b/arrow-flight/src/lib.rs index 3035f109c685..8d05f658703a 100644 --- a/arrow-flight/src/lib.rs +++ b/arrow-flight/src/lib.rs @@ -133,10 +133,7 @@ pub struct IpcMessage(pub Bytes); // Useful conversion functions -fn flight_schema_as_encoded_data( - arrow_schema: &Schema, - options: &IpcWriteOptions, -) -> EncodedData { +fn flight_schema_as_encoded_data(arrow_schema: &Schema, options: &IpcWriteOptions) -> EncodedData { let data_gen = writer::IpcDataGenerator::default(); data_gen.schema_to_bytes(arrow_schema, options) } diff --git a/arrow-flight/src/sql/client.rs b/arrow-flight/src/sql/client.rs index 7685813ff844..133df5b044cf 100644 --- a/arrow-flight/src/sql/client.rs +++ b/arrow-flight/src/sql/client.rs @@ -31,17 +31,16 @@ use crate::flight_service_client::FlightServiceClient; use crate::sql::server::{CLOSE_PREPARED_STATEMENT, CREATE_PREPARED_STATEMENT}; use crate::sql::{ ActionClosePreparedStatementRequest, ActionCreatePreparedStatementRequest, - ActionCreatePreparedStatementResult, Any, CommandGetCatalogs, - CommandGetCrossReference, CommandGetDbSchemas, CommandGetExportedKeys, - CommandGetImportedKeys, CommandGetPrimaryKeys, 
CommandGetSqlInfo, - CommandGetTableTypes, CommandGetTables, CommandGetXdbcTypeInfo, + ActionCreatePreparedStatementResult, Any, CommandGetCatalogs, CommandGetCrossReference, + CommandGetDbSchemas, CommandGetExportedKeys, CommandGetImportedKeys, CommandGetPrimaryKeys, + CommandGetSqlInfo, CommandGetTableTypes, CommandGetTables, CommandGetXdbcTypeInfo, CommandPreparedStatementQuery, CommandPreparedStatementUpdate, CommandStatementQuery, CommandStatementUpdate, DoPutUpdateResult, ProstMessageExt, SqlInfo, }; use crate::trailers::extract_lazy_trailers; use crate::{ - Action, FlightData, FlightDescriptor, FlightInfo, HandshakeRequest, - HandshakeResponse, IpcMessage, PutResult, Ticket, + Action, FlightData, FlightDescriptor, FlightInfo, HandshakeRequest, HandshakeResponse, + IpcMessage, PutResult, Ticket, }; use arrow_array::RecordBatch; use arrow_buffer::Buffer; @@ -134,11 +133,7 @@ impl FlightSqlServiceClient { /// Perform a `handshake` with the server, passing credentials and establishing a session /// Returns arbitrary auth/handshake info binary blob - pub async fn handshake( - &mut self, - username: &str, - password: &str, - ) -> Result { + pub async fn handshake(&mut self, username: &str, password: &str) -> Result { let cmd = HandshakeRequest { protocol_version: 0, payload: Default::default(), @@ -156,9 +151,9 @@ impl FlightSqlServiceClient { .await .map_err(|e| ArrowError::IpcError(format!("Can't handshake {e}")))?; if let Some(auth) = resp.metadata().get("authorization") { - let auth = auth.to_str().map_err(|_| { - ArrowError::ParseError("Can't read auth header".to_string()) - })?; + let auth = auth + .to_str() + .map_err(|_| ArrowError::ParseError("Can't read auth header".to_string()))?; let bearer = "Bearer "; if !auth.starts_with(bearer) { Err(ArrowError::ParseError("Invalid auth header!".to_string()))?; @@ -166,10 +161,11 @@ impl FlightSqlServiceClient { let auth = auth[bearer.len()..].to_string(); self.token = Some(auth); } - let responses: Vec = - 
resp.into_inner().try_collect().await.map_err(|_| { - ArrowError::ParseError("Can't collect responses".to_string()) - })?; + let responses: Vec = resp + .into_inner() + .try_collect() + .await + .map_err(|_| ArrowError::ParseError("Can't collect responses".to_string()))?; let resp = match responses.as_slice() { [resp] => resp.payload.clone(), [] => Bytes::new(), @@ -209,8 +205,7 @@ impl FlightSqlServiceClient { .await .map_err(status_to_arrow_error)? .unwrap(); - let any = - Any::decode(&*result.app_metadata).map_err(decode_error_to_arrow_error)?; + let any = Any::decode(&*result.app_metadata).map_err(decode_error_to_arrow_error)?; let result: DoPutUpdateResult = any.unpack()?.unwrap(); Ok(result.record_count) } @@ -405,17 +400,13 @@ impl FlightSqlServiceClient { ArrowError::ParseError(format!("Cannot convert header key \"{k}\": {e}")) })?; let v = v.parse().map_err(|e| { - ArrowError::ParseError(format!( - "Cannot convert header value \"{v}\": {e}" - )) + ArrowError::ParseError(format!("Cannot convert header value \"{v}\": {e}")) })?; req.metadata_mut().insert(k, v); } if let Some(token) = &self.token { let val = format!("Bearer {token}").parse().map_err(|e| { - ArrowError::ParseError(format!( - "Cannot convert token to header value: {e}" - )) + ArrowError::ParseError(format!("Cannot convert token to header value: {e}")) })?; req.metadata_mut().insert("authorization", val); } @@ -484,8 +475,7 @@ impl PreparedStatement { .await .map_err(status_to_arrow_error)? .unwrap(); - let any = - Any::decode(&*result.app_metadata).map_err(decode_error_to_arrow_error)?; + let any = Any::decode(&*result.app_metadata).map_err(decode_error_to_arrow_error)?; let result: DoPutUpdateResult = any.unpack()?.unwrap(); Ok(result.record_count) } @@ -501,10 +491,7 @@ impl PreparedStatement { } /// Set a RecordBatch that contains the parameters that will be bind. 
- pub fn set_parameters( - &mut self, - parameter_binding: RecordBatch, - ) -> Result<(), ArrowError> { + pub fn set_parameters(&mut self, parameter_binding: RecordBatch) -> Result<(), ArrowError> { self.parameter_binding = Some(parameter_binding); Ok(()) } @@ -580,19 +567,16 @@ pub fn arrow_data_from_flight_data( flight_data: FlightData, arrow_schema_ref: &SchemaRef, ) -> Result { - let ipc_message = root_as_message(&flight_data.data_header[..]).map_err(|err| { - ArrowError::ParseError(format!("Unable to get root as message: {err:?}")) - })?; + let ipc_message = root_as_message(&flight_data.data_header[..]) + .map_err(|err| ArrowError::ParseError(format!("Unable to get root as message: {err:?}")))?; match ipc_message.header_type() { MessageHeader::RecordBatch => { - let ipc_record_batch = - ipc_message.header_as_record_batch().ok_or_else(|| { - ArrowError::ComputeError( - "Unable to convert flight data header to a record batch" - .to_string(), - ) - })?; + let ipc_record_batch = ipc_message.header_as_record_batch().ok_or_else(|| { + ArrowError::ComputeError( + "Unable to convert flight data header to a record batch".to_string(), + ) + })?; let dictionaries_by_field = HashMap::new(); let record_batch = read_record_batch( @@ -618,13 +602,11 @@ pub fn arrow_data_from_flight_data( MessageHeader::DictionaryBatch => { let _ = ipc_message.header_as_dictionary_batch().ok_or_else(|| { ArrowError::ComputeError( - "Unable to convert flight data header to a dictionary batch" - .to_string(), + "Unable to convert flight data header to a dictionary batch".to_string(), ) })?; Err(ArrowError::NotYetImplemented( - "no idea on how to convert an ipc dictionary batch to an arrow type" - .to_string(), + "no idea on how to convert an ipc dictionary batch to an arrow type".to_string(), )) } MessageHeader::Tensor => { @@ -644,8 +626,7 @@ pub fn arrow_data_from_flight_data( ) })?; Err(ArrowError::NotYetImplemented( - "no idea on how to convert an ipc sparse tensor to an arrow type" - 
.to_string(), + "no idea on how to convert an ipc sparse tensor to an arrow type".to_string(), )) } _ => Err(ArrowError::ComputeError(format!( diff --git a/arrow-flight/src/sql/metadata/db_schemas.rs b/arrow-flight/src/sql/metadata/db_schemas.rs index 642802b058d5..303d11cd74ca 100644 --- a/arrow-flight/src/sql/metadata/db_schemas.rs +++ b/arrow-flight/src/sql/metadata/db_schemas.rs @@ -95,11 +95,7 @@ impl GetDbSchemasBuilder { /// Append a row /// /// In case the catalog should be considered as empty, pass in an empty string '""'. - pub fn append( - &mut self, - catalog_name: impl AsRef, - schema_name: impl AsRef, - ) { + pub fn append(&mut self, catalog_name: impl AsRef, schema_name: impl AsRef) { self.catalog_name.append_value(catalog_name); self.db_schema_name.append_value(schema_name); } diff --git a/arrow-flight/src/sql/metadata/sql_info.rs b/arrow-flight/src/sql/metadata/sql_info.rs index 88c97227814d..d4584f4a6827 100644 --- a/arrow-flight/src/sql/metadata/sql_info.rs +++ b/arrow-flight/src/sql/metadata/sql_info.rs @@ -30,8 +30,8 @@ use std::sync::Arc; use arrow_arith::boolean::or; use arrow_array::array::{Array, UInt32Array, UnionArray}; use arrow_array::builder::{ - ArrayBuilder, BooleanBuilder, Int32Builder, Int64Builder, Int8Builder, ListBuilder, - MapBuilder, StringBuilder, UInt32Builder, + ArrayBuilder, BooleanBuilder, Int32Builder, Int64Builder, Int8Builder, ListBuilder, MapBuilder, + StringBuilder, UInt32Builder, }; use arrow_array::{RecordBatch, Scalar}; use arrow_data::ArrayData; @@ -184,11 +184,7 @@ static UNION_TYPE: Lazy = Lazy::new(|| { Field::new("keys", DataType::Int32, false), Field::new( "values", - DataType::List(Arc::new(Field::new( - "item", - DataType::Int32, - true, - ))), + DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), true, ), ])), @@ -420,10 +416,7 @@ pub struct SqlInfoData { impl SqlInfoData { /// Return a [`RecordBatch`] containing only the requested `u32`, if any /// from [`CommandGetSqlInfo`] - pub fn 
record_batch( - &self, - info: impl IntoIterator, - ) -> Result { + pub fn record_batch(&self, info: impl IntoIterator) -> Result { let arr = self.batch.column(0); let type_filter = info .into_iter() @@ -493,9 +486,7 @@ mod tests { use super::SqlInfoDataBuilder; use crate::sql::metadata::tests::assert_batches_eq; - use crate::sql::{ - SqlInfo, SqlNullOrdering, SqlSupportedTransaction, SqlSupportsConvert, - }; + use crate::sql::{SqlInfo, SqlNullOrdering, SqlSupportedTransaction, SqlSupportsConvert}; #[test] fn test_sql_infos() { diff --git a/arrow-flight/src/sql/metadata/tables.rs b/arrow-flight/src/sql/metadata/tables.rs index 00502a76db53..7ffb76fa1d5f 100644 --- a/arrow-flight/src/sql/metadata/tables.rs +++ b/arrow-flight/src/sql/metadata/tables.rs @@ -329,12 +329,12 @@ mod tests { "b_catalog", ])) as ArrayRef, Arc::new(StringArray::from(vec![ - "a_schema", "a_schema", "b_schema", "b_schema", "a_schema", - "a_schema", "b_schema", "b_schema", + "a_schema", "a_schema", "b_schema", "b_schema", "a_schema", "a_schema", + "b_schema", "b_schema", ])) as ArrayRef, Arc::new(StringArray::from(vec![ - "a_table", "b_table", "a_table", "b_table", "a_table", "a_table", - "b_table", "b_table", + "a_table", "b_table", "a_table", "b_table", "a_table", "a_table", "b_table", + "b_table", ])) as ArrayRef, Arc::new(StringArray::from(vec![ "TABLE", "TABLE", "TABLE", "TABLE", "TABLE", "VIEW", "TABLE", "VIEW", diff --git a/arrow-flight/src/sql/metadata/xdbc_info.rs b/arrow-flight/src/sql/metadata/xdbc_info.rs index 8212c847a4fa..2e635d3037bc 100644 --- a/arrow-flight/src/sql/metadata/xdbc_info.rs +++ b/arrow-flight/src/sql/metadata/xdbc_info.rs @@ -36,9 +36,7 @@ use once_cell::sync::Lazy; use super::lexsort_to_indices; use crate::error::*; -use crate::sql::{ - CommandGetXdbcTypeInfo, Nullable, Searchable, XdbcDataType, XdbcDatetimeSubcode, -}; +use crate::sql::{CommandGetXdbcTypeInfo, Nullable, Searchable, XdbcDataType, XdbcDatetimeSubcode}; /// Data structure representing type 
information for xdbc types. #[derive(Debug, Clone, Default)] @@ -201,8 +199,7 @@ impl XdbcTypeInfoDataBuilder { minimum_scale_builder.append_option(info.minimum_scale); maximum_scale_builder.append_option(info.maximum_scale); sql_data_type_builder.append_value(info.sql_data_type as i32); - datetime_subcode_builder - .append_option(info.datetime_subcode.map(|code| code as i32)); + datetime_subcode_builder.append_option(info.datetime_subcode.map(|code| code as i32)); num_prec_radix_builder.append_option(info.num_prec_radix); interval_precision_builder.append_option(info.interval_precision); }); @@ -215,8 +212,7 @@ impl XdbcTypeInfoDataBuilder { let (field, offsets, values, nulls) = create_params_builder.finish().into_parts(); // Re-defined the field to be non-nullable let new_field = Arc::new(field.as_ref().clone().with_nullable(false)); - let create_params = - Arc::new(ListArray::new(new_field, offsets, values, nulls)) as ArrayRef; + let create_params = Arc::new(ListArray::new(new_field, offsets, values, nulls)) as ArrayRef; let nullable = Arc::new(nullable_builder.finish()); let case_sensitive = Arc::new(case_sensitive_builder.finish()); let searchable = Arc::new(searchable_builder.finish()); diff --git a/arrow-flight/src/sql/mod.rs b/arrow-flight/src/sql/mod.rs index 4042ce8efc46..97645ae7840d 100644 --- a/arrow-flight/src/sql/mod.rs +++ b/arrow-flight/src/sql/mod.rs @@ -295,9 +295,8 @@ impl Any { if !self.is::() { return Ok(None); } - let m = Message::decode(&*self.value).map_err(|err| { - ArrowError::ParseError(format!("Unable to decode Any value: {err}")) - })?; + let m = Message::decode(&*self.value) + .map_err(|err| ArrowError::ParseError(format!("Unable to decode Any value: {err}")))?; Ok(Some(m)) } diff --git a/arrow-flight/src/sql/server.rs b/arrow-flight/src/sql/server.rs index a158ed77f54d..14ab7d81b4f3 100644 --- a/arrow-flight/src/sql/server.rs +++ b/arrow-flight/src/sql/server.rs @@ -24,23 +24,21 @@ use prost::Message; use tonic::{Request, Response, 
Status, Streaming}; use super::{ - ActionBeginSavepointRequest, ActionBeginSavepointResult, - ActionBeginTransactionRequest, ActionBeginTransactionResult, - ActionCancelQueryRequest, ActionCancelQueryResult, + ActionBeginSavepointRequest, ActionBeginSavepointResult, ActionBeginTransactionRequest, + ActionBeginTransactionResult, ActionCancelQueryRequest, ActionCancelQueryResult, ActionClosePreparedStatementRequest, ActionCreatePreparedStatementRequest, ActionCreatePreparedStatementResult, ActionCreatePreparedSubstraitPlanRequest, - ActionEndSavepointRequest, ActionEndTransactionRequest, Any, Command, - CommandGetCatalogs, CommandGetCrossReference, CommandGetDbSchemas, - CommandGetExportedKeys, CommandGetImportedKeys, CommandGetPrimaryKeys, - CommandGetSqlInfo, CommandGetTableTypes, CommandGetTables, CommandGetXdbcTypeInfo, - CommandPreparedStatementQuery, CommandPreparedStatementUpdate, CommandStatementQuery, - CommandStatementSubstraitPlan, CommandStatementUpdate, DoPutUpdateResult, - ProstMessageExt, SqlInfo, TicketStatementQuery, + ActionEndSavepointRequest, ActionEndTransactionRequest, Any, Command, CommandGetCatalogs, + CommandGetCrossReference, CommandGetDbSchemas, CommandGetExportedKeys, CommandGetImportedKeys, + CommandGetPrimaryKeys, CommandGetSqlInfo, CommandGetTableTypes, CommandGetTables, + CommandGetXdbcTypeInfo, CommandPreparedStatementQuery, CommandPreparedStatementUpdate, + CommandStatementQuery, CommandStatementSubstraitPlan, CommandStatementUpdate, + DoPutUpdateResult, ProstMessageExt, SqlInfo, TicketStatementQuery, }; use crate::{ - flight_service_server::FlightService, Action, ActionType, Criteria, Empty, - FlightData, FlightDescriptor, FlightInfo, HandshakeRequest, HandshakeResponse, - PutResult, SchemaResult, Ticket, + flight_service_server::FlightService, Action, ActionType, Criteria, Empty, FlightData, + FlightDescriptor, FlightInfo, HandshakeRequest, HandshakeResponse, PutResult, SchemaResult, + Ticket, }; pub(crate) static 
CREATE_PREPARED_STATEMENT: &str = "CreatePreparedStatement"; @@ -549,13 +547,10 @@ where Pin> + Send + 'static>>; type ListFlightsStream = Pin> + Send + 'static>>; - type DoGetStream = - Pin> + Send + 'static>>; - type DoPutStream = - Pin> + Send + 'static>>; - type DoActionStream = Pin< - Box> + Send + 'static>, - >; + type DoGetStream = Pin> + Send + 'static>>; + type DoPutStream = Pin> + Send + 'static>>; + type DoActionStream = + Pin> + Send + 'static>>; type ListActionsStream = Pin> + Send + 'static>>; type DoExchangeStream = @@ -580,8 +575,7 @@ where &self, request: Request, ) -> Result, Status> { - let message = - Any::decode(&*request.get_ref().cmd).map_err(decode_error_to_status)?; + let message = Any::decode(&*request.get_ref().cmd).map_err(decode_error_to_status)?; match Command::try_from(message).map_err(arrow_error_to_status)? { Command::CommandStatementQuery(token) => { @@ -600,9 +594,7 @@ where Command::CommandGetDbSchemas(token) => { return self.get_flight_info_schemas(token, request).await } - Command::CommandGetTables(token) => { - self.get_flight_info_tables(token, request).await - } + Command::CommandGetTables(token) => self.get_flight_info_tables(token, request).await, Command::CommandGetTableTypes(token) => { self.get_flight_info_table_types(token, request).await } @@ -642,31 +634,21 @@ where &self, request: Request, ) -> Result, Status> { - let msg: Any = Message::decode(&*request.get_ref().ticket) - .map_err(decode_error_to_status)?; + let msg: Any = + Message::decode(&*request.get_ref().ticket).map_err(decode_error_to_status)?; match Command::try_from(msg).map_err(arrow_error_to_status)? 
{ - Command::TicketStatementQuery(command) => { - self.do_get_statement(command, request).await - } + Command::TicketStatementQuery(command) => self.do_get_statement(command, request).await, Command::CommandPreparedStatementQuery(command) => { self.do_get_prepared_statement(command, request).await } - Command::CommandGetCatalogs(command) => { - self.do_get_catalogs(command, request).await - } - Command::CommandGetDbSchemas(command) => { - self.do_get_schemas(command, request).await - } - Command::CommandGetTables(command) => { - self.do_get_tables(command, request).await - } + Command::CommandGetCatalogs(command) => self.do_get_catalogs(command, request).await, + Command::CommandGetDbSchemas(command) => self.do_get_schemas(command, request).await, + Command::CommandGetTables(command) => self.do_get_tables(command, request).await, Command::CommandGetTableTypes(command) => { self.do_get_table_types(command, request).await } - Command::CommandGetSqlInfo(command) => { - self.do_get_sql_info(command, request).await - } + Command::CommandGetSqlInfo(command) => self.do_get_sql_info(command, request).await, Command::CommandGetPrimaryKeys(command) => { self.do_get_primary_keys(command, request).await } @@ -699,8 +681,8 @@ where let mut request = request.map(PeekableFlightDataStream::new); let cmd = Pin::new(request.get_mut()).peek().await.unwrap().clone()?; - let message = Any::decode(&*cmd.flight_descriptor.unwrap().cmd) - .map_err(decode_error_to_status)?; + let message = + Any::decode(&*cmd.flight_descriptor.unwrap().cmd).map_err(decode_error_to_status)?; match Command::try_from(message).map_err(arrow_error_to_status)? 
{ Command::CommandStatementUpdate(command) => { let record_count = self.do_put_statement_update(command, request).await?; @@ -755,11 +737,10 @@ where }; let create_prepared_substrait_plan_action_type = ActionType { r#type: CREATE_PREPARED_SUBSTRAIT_PLAN.to_string(), - description: - "Creates a reusable prepared substrait plan resource on the server.\n + description: "Creates a reusable prepared substrait plan resource on the server.\n Request Message: ActionCreatePreparedSubstraitPlanRequest\n Response Message: ActionCreatePreparedStatementResult" - .into(), + .into(), }; let begin_transaction_action_type = ActionType { r#type: BEGIN_TRANSACTION.to_string(), @@ -820,8 +801,7 @@ where request: Request, ) -> Result, Status> { if request.get_ref().r#type == CREATE_PREPARED_STATEMENT { - let any = - Any::decode(&*request.get_ref().body).map_err(decode_error_to_status)?; + let any = Any::decode(&*request.get_ref().body).map_err(decode_error_to_status)?; let cmd: ActionCreatePreparedStatementRequest = any .unpack() @@ -839,8 +819,7 @@ where })]); return Ok(Response::new(Box::pin(output))); } else if request.get_ref().r#type == CLOSE_PREPARED_STATEMENT { - let any = - Any::decode(&*request.get_ref().body).map_err(decode_error_to_status)?; + let any = Any::decode(&*request.get_ref().body).map_err(decode_error_to_status)?; let cmd: ActionClosePreparedStatementRequest = any .unpack() @@ -854,8 +833,7 @@ where .await?; return Ok(Response::new(Box::pin(futures::stream::empty()))); } else if request.get_ref().r#type == CREATE_PREPARED_SUBSTRAIT_PLAN { - let any = - Any::decode(&*request.get_ref().body).map_err(decode_error_to_status)?; + let any = Any::decode(&*request.get_ref().body).map_err(decode_error_to_status)?; let cmd: ActionCreatePreparedSubstraitPlanRequest = any .unpack() @@ -869,47 +847,38 @@ where .await?; return Ok(Response::new(Box::pin(futures::stream::empty()))); } else if request.get_ref().r#type == BEGIN_TRANSACTION { - let any = - 
Any::decode(&*request.get_ref().body).map_err(decode_error_to_status)?; + let any = Any::decode(&*request.get_ref().body).map_err(decode_error_to_status)?; let cmd: ActionBeginTransactionRequest = any .unpack() .map_err(arrow_error_to_status)? .ok_or_else(|| { - Status::invalid_argument( - "Unable to unpack ActionBeginTransactionRequest.", - ) - })?; + Status::invalid_argument("Unable to unpack ActionBeginTransactionRequest.") + })?; let stmt = self.do_action_begin_transaction(cmd, request).await?; let output = futures::stream::iter(vec![Ok(super::super::gen::Result { body: stmt.as_any().encode_to_vec().into(), })]); return Ok(Response::new(Box::pin(output))); } else if request.get_ref().r#type == END_TRANSACTION { - let any = - Any::decode(&*request.get_ref().body).map_err(decode_error_to_status)?; + let any = Any::decode(&*request.get_ref().body).map_err(decode_error_to_status)?; let cmd: ActionEndTransactionRequest = any .unpack() .map_err(arrow_error_to_status)? .ok_or_else(|| { - Status::invalid_argument( - "Unable to unpack ActionEndTransactionRequest.", - ) + Status::invalid_argument("Unable to unpack ActionEndTransactionRequest.") })?; self.do_action_end_transaction(cmd, request).await?; return Ok(Response::new(Box::pin(futures::stream::empty()))); } else if request.get_ref().r#type == BEGIN_SAVEPOINT { - let any = - Any::decode(&*request.get_ref().body).map_err(decode_error_to_status)?; + let any = Any::decode(&*request.get_ref().body).map_err(decode_error_to_status)?; let cmd: ActionBeginSavepointRequest = any .unpack() .map_err(arrow_error_to_status)? 
.ok_or_else(|| { - Status::invalid_argument( - "Unable to unpack ActionBeginSavepointRequest.", - ) + Status::invalid_argument("Unable to unpack ActionBeginSavepointRequest.") })?; let stmt = self.do_action_begin_savepoint(cmd, request).await?; let output = futures::stream::iter(vec![Ok(super::super::gen::Result { @@ -917,22 +886,18 @@ where })]); return Ok(Response::new(Box::pin(output))); } else if request.get_ref().r#type == END_SAVEPOINT { - let any = - Any::decode(&*request.get_ref().body).map_err(decode_error_to_status)?; + let any = Any::decode(&*request.get_ref().body).map_err(decode_error_to_status)?; let cmd: ActionEndSavepointRequest = any .unpack() .map_err(arrow_error_to_status)? .ok_or_else(|| { - Status::invalid_argument( - "Unable to unpack ActionEndSavepointRequest.", - ) + Status::invalid_argument("Unable to unpack ActionEndSavepointRequest.") })?; self.do_action_end_savepoint(cmd, request).await?; return Ok(Response::new(Box::pin(futures::stream::empty()))); } else if request.get_ref().r#type == CANCEL_QUERY { - let any = - Any::decode(&*request.get_ref().body).map_err(decode_error_to_status)?; + let any = Any::decode(&*request.get_ref().body).map_err(decode_error_to_status)?; let cmd: ActionCancelQueryRequest = any .unpack() diff --git a/arrow-flight/src/trailers.rs b/arrow-flight/src/trailers.rs index d652542da779..73136379d69f 100644 --- a/arrow-flight/src/trailers.rs +++ b/arrow-flight/src/trailers.rs @@ -28,9 +28,7 @@ use tonic::{metadata::MetadataMap, Status, Streaming}; /// /// Note that [`LazyTrailers`] has inner mutability and will only hold actual data after [`ExtractTrailersStream`] is /// fully consumed (dropping it is not required though). 
-pub fn extract_lazy_trailers( - s: Streaming, -) -> (ExtractTrailersStream, LazyTrailers) { +pub fn extract_lazy_trailers(s: Streaming) -> (ExtractTrailersStream, LazyTrailers) { let trailers: SharedTrailers = Default::default(); let stream = ExtractTrailersStream { inner: s, @@ -54,10 +52,7 @@ pub struct ExtractTrailersStream { impl Stream for ExtractTrailersStream { type Item = Result; - fn poll_next( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let res = ready!(self.inner.poll_next_unpin(cx)); if res.is_none() { diff --git a/arrow-flight/src/utils.rs b/arrow-flight/src/utils.rs index 145626b6608f..b75d61d200cb 100644 --- a/arrow-flight/src/utils.rs +++ b/arrow-flight/src/utils.rs @@ -52,26 +52,23 @@ pub fn flight_data_from_arrow_batch( } /// Convert a slice of wire protocol `FlightData`s into a vector of `RecordBatch`es -pub fn flight_data_to_batches( - flight_data: &[FlightData], -) -> Result, ArrowError> { +pub fn flight_data_to_batches(flight_data: &[FlightData]) -> Result, ArrowError> { let schema = flight_data.get(0).ok_or_else(|| { ArrowError::CastError("Need at least one FlightData for schema".to_string()) })?; let message = root_as_message(&schema.data_header[..]) .map_err(|_| ArrowError::CastError("Cannot get root as message".to_string()))?; - let ipc_schema: arrow_ipc::Schema = message.header_as_schema().ok_or_else(|| { - ArrowError::CastError("Cannot get header as Schema".to_string()) - })?; + let ipc_schema: arrow_ipc::Schema = message + .header_as_schema() + .ok_or_else(|| ArrowError::CastError("Cannot get header as Schema".to_string()))?; let schema = fb_to_schema(ipc_schema); let schema = Arc::new(schema); let mut batches = vec![]; let dictionaries_by_id = HashMap::new(); for datum in flight_data[1..].iter() { - let batch = - flight_data_to_arrow_batch(datum, schema.clone(), &dictionaries_by_id)?; + let batch = flight_data_to_arrow_batch(datum, 
schema.clone(), &dictionaries_by_id)?; batches.push(batch); } Ok(batches) @@ -84,9 +81,8 @@ pub fn flight_data_to_arrow_batch( dictionaries_by_id: &HashMap, ) -> Result { // check that the data_header is a record batch message - let message = arrow_ipc::root_as_message(&data.data_header[..]).map_err(|err| { - ArrowError::ParseError(format!("Unable to get root as message: {err:?}")) - })?; + let message = arrow_ipc::root_as_message(&data.data_header[..]) + .map_err(|err| ArrowError::ParseError(format!("Unable to get root as message: {err:?}")))?; message .header_as_record_batch() @@ -124,10 +120,7 @@ pub fn flight_schema_from_arrow_schema( since = "4.4.0", note = "Use From trait, e.g.: SchemaAsIpc::new(schema, options).into()" )] -pub fn flight_data_from_arrow_schema( - schema: &Schema, - options: &IpcWriteOptions, -) -> FlightData { +pub fn flight_data_from_arrow_schema(schema: &Schema, options: &IpcWriteOptions) -> FlightData { SchemaAsIpc::new(schema, options).into() } diff --git a/arrow-flight/tests/client.rs b/arrow-flight/tests/client.rs index 1b9891e121fa..3ad9ee7a45ca 100644 --- a/arrow-flight/tests/client.rs +++ b/arrow-flight/tests/client.rs @@ -23,9 +23,9 @@ mod common { } use arrow_array::{RecordBatch, UInt64Array}; use arrow_flight::{ - decode::FlightRecordBatchStream, encode::FlightDataEncoderBuilder, - error::FlightError, Action, ActionType, Criteria, Empty, FlightClient, FlightData, - FlightDescriptor, FlightInfo, HandshakeRequest, HandshakeResponse, PutResult, Ticket, + decode::FlightRecordBatchStream, encode::FlightDataEncoderBuilder, error::FlightError, Action, + ActionType, Criteria, Empty, FlightClient, FlightData, FlightDescriptor, FlightInfo, + HandshakeRequest, HandshakeResponse, PutResult, Ticket, }; use arrow_schema::{DataType, Field, Schema}; use bytes::Bytes; @@ -271,8 +271,7 @@ async fn test_do_put() { }, ]; - test_server - .set_do_put_response(expected_response.clone().into_iter().map(Ok).collect()); + 
test_server.set_do_put_response(expected_response.clone().into_iter().map(Ok).collect()); let input_stream = futures::stream::iter(input_flight_data.clone()).map(Ok); @@ -446,9 +445,8 @@ async fn test_do_exchange() { let input_flight_data = test_flight_data().await; let output_flight_data = test_flight_data2().await; - test_server.set_do_exchange_response( - output_flight_data.clone().into_iter().map(Ok).collect(), - ); + test_server + .set_do_exchange_response(output_flight_data.clone().into_iter().map(Ok).collect()); let response_stream = client .do_exchange(futures::stream::iter(input_flight_data.clone())) diff --git a/arrow-flight/tests/common/server.rs b/arrow-flight/tests/common/server.rs index c575d12bbf52..8b162d398c4b 100644 --- a/arrow-flight/tests/common/server.rs +++ b/arrow-flight/tests/common/server.rs @@ -174,10 +174,7 @@ impl TestFlightServer { } /// Specify the response returned from the next call to `do_action` - pub fn set_do_action_response( - &self, - response: Vec>, - ) { + pub fn set_do_action_response(&self, response: Vec>) { let mut state = self.state.lock().expect("mutex not poisoned"); state.do_action_response.replace(response); } @@ -278,9 +275,10 @@ impl FlightService for TestFlightServer { let mut state = self.state.lock().expect("mutex not poisoned"); state.handshake_request = Some(handshake_request); - let response = state.handshake_response.take().unwrap_or_else(|| { - Err(Status::internal("No handshake response configured")) - })?; + let response = state + .handshake_response + .take() + .unwrap_or_else(|| Err(Status::internal("No handshake response configured")))?; // turn into a streaming response let output = futures::stream::iter(std::iter::once(Ok(response))); @@ -313,9 +311,10 @@ impl FlightService for TestFlightServer { self.save_metadata(&request); let mut state = self.state.lock().expect("mutex not poisoned"); state.get_flight_info_request = Some(request.into_inner()); - let response = 
state.get_flight_info_response.take().unwrap_or_else(|| { - Err(Status::internal("No get_flight_info response configured")) - })?; + let response = state + .get_flight_info_response + .take() + .unwrap_or_else(|| Err(Status::internal("No get_flight_info response configured")))?; Ok(Response::new(response)) } @@ -326,9 +325,10 @@ impl FlightService for TestFlightServer { self.save_metadata(&request); let mut state = self.state.lock().expect("mutex not poisoned"); state.get_schema_request = Some(request.into_inner()); - let schema = state.get_schema_response.take().unwrap_or_else(|| { - Err(Status::internal("No get_schema response configured")) - })?; + let schema = state + .get_schema_response + .take() + .unwrap_or_else(|| Err(Status::internal("No get_schema response configured")))?; // encode the schema let options = arrow_ipc::writer::IpcWriteOptions::default(); diff --git a/arrow-flight/tests/common/trailers_layer.rs b/arrow-flight/tests/common/trailers_layer.rs index 9e6be0dcf0da..b2ab74f7d925 100644 --- a/arrow-flight/tests/common/trailers_layer.rs +++ b/arrow-flight/tests/common/trailers_layer.rs @@ -81,9 +81,7 @@ where ready!(self.as_mut().project().inner.poll(cx)); match result { - Ok(response) => { - Poll::Ready(Ok(response.map(|body| WrappedBody { inner: body }))) - } + Ok(response) => Poll::Ready(Ok(response.map(|body| WrappedBody { inner: body }))), Err(e) => Poll::Ready(Err(e)), } } diff --git a/arrow-flight/tests/encode_decode.rs b/arrow-flight/tests/encode_decode.rs index 71bcf4e0521a..f4741d743e57 100644 --- a/arrow-flight/tests/encode_decode.rs +++ b/arrow-flight/tests/encode_decode.rs @@ -195,8 +195,7 @@ async fn test_app_metadata() { let encode_stream = encoder.build(input_batch_stream); // use lower level stream to get access to app metadata - let decode_stream = - FlightRecordBatchStream::new_from_flight_data(encode_stream).into_inner(); + let decode_stream = FlightRecordBatchStream::new_from_flight_data(encode_stream).into_inner(); let mut 
messages: Vec<_> = decode_stream.try_collect().await.expect("encode fails"); @@ -225,8 +224,7 @@ async fn test_max_message_size() { let encode_stream = encoder.build(input_batch_stream); // use lower level stream to get access to app metadata - let decode_stream = - FlightRecordBatchStream::new_from_flight_data(encode_stream).into_inner(); + let decode_stream = FlightRecordBatchStream::new_from_flight_data(encode_stream).into_inner(); let messages: Vec<_> = decode_stream.try_collect().await.expect("encode fails"); @@ -254,8 +252,8 @@ async fn test_max_message_size_fuzz() { ]; for max_message_size_bytes in [10, 1024, 2048, 6400, 3211212] { - let encoder = FlightDataEncoderBuilder::default() - .with_max_flight_data_size(max_message_size_bytes); + let encoder = + FlightDataEncoderBuilder::default().with_max_flight_data_size(max_message_size_bytes); let input_batch_stream = futures::stream::iter(input.clone()).map(Ok); @@ -299,10 +297,10 @@ async fn test_chained_streams_batch_decoder() { let batch2 = make_dictionary_batch(3); // Model sending two flight streams back to back, with different schemas - let encode_stream1 = FlightDataEncoderBuilder::default() - .build(futures::stream::iter(vec![Ok(batch1.clone())])); - let encode_stream2 = FlightDataEncoderBuilder::default() - .build(futures::stream::iter(vec![Ok(batch2.clone())])); + let encode_stream1 = + FlightDataEncoderBuilder::default().build(futures::stream::iter(vec![Ok(batch1.clone())])); + let encode_stream2 = + FlightDataEncoderBuilder::default().build(futures::stream::iter(vec![Ok(batch2.clone())])); // append the two streams (so they will have two different schema messages) let encode_stream = encode_stream1.chain(encode_stream2); @@ -324,10 +322,10 @@ async fn test_chained_streams_data_decoder() { let batch2 = make_dictionary_batch(3); // Model sending two flight streams back to back, with different schemas - let encode_stream1 = FlightDataEncoderBuilder::default() - 
.build(futures::stream::iter(vec![Ok(batch1.clone())])); - let encode_stream2 = FlightDataEncoderBuilder::default() - .build(futures::stream::iter(vec![Ok(batch2.clone())])); + let encode_stream1 = + FlightDataEncoderBuilder::default().build(futures::stream::iter(vec![Ok(batch1.clone())])); + let encode_stream2 = + FlightDataEncoderBuilder::default().build(futures::stream::iter(vec![Ok(batch2.clone())])); // append the two streams (so they will have two different schema messages) let encode_stream = encode_stream1.chain(encode_stream2); @@ -335,8 +333,7 @@ async fn test_chained_streams_data_decoder() { // lower level decode stream can handle multiple schema messages let decode_stream = FlightDataDecoder::new(encode_stream); - let decoded_data: Vec<_> = - decode_stream.try_collect().await.expect("encode / decode"); + let decoded_data: Vec<_> = decode_stream.try_collect().await.expect("encode / decode"); println!("decoded data: {decoded_data:#?}"); @@ -425,8 +422,7 @@ fn make_primitive_batch(num_rows: usize) -> RecordBatch { }) .collect(); - RecordBatch::try_from_iter(vec![("i", Arc::new(i) as ArrayRef), ("f", Arc::new(f))]) - .unwrap() + RecordBatch::try_from_iter(vec![("i", Arc::new(i) as ArrayRef), ("f", Arc::new(f))]).unwrap() } /// Make a dictionary batch for testing @@ -459,8 +455,7 @@ fn make_dictionary_batch(num_rows: usize) -> RecordBatch { /// match the input. 
async fn roundtrip(input: Vec) { let expected_output = input.clone(); - roundtrip_with_encoder(FlightDataEncoderBuilder::default(), input, expected_output) - .await + roundtrip_with_encoder(FlightDataEncoderBuilder::default(), input, expected_output).await } /// Encodes input as a FlightData stream, and then decodes it using @@ -475,8 +470,7 @@ async fn roundtrip_dictionary(input: Vec) { .iter() .map(|batch| prepare_batch_for_flight(batch, schema.clone()).unwrap()) .collect(); - roundtrip_with_encoder(FlightDataEncoderBuilder::default(), input, expected_output) - .await + roundtrip_with_encoder(FlightDataEncoderBuilder::default(), input, expected_output).await } async fn roundtrip_with_encoder( @@ -491,8 +485,7 @@ async fn roundtrip_with_encoder( let encode_stream = encoder.build(input_batch_stream); let decode_stream = FlightRecordBatchStream::new_from_flight_data(encode_stream); - let output_batches: Vec<_> = - decode_stream.try_collect().await.expect("encode / decode"); + let output_batches: Vec<_> = decode_stream.try_collect().await.expect("encode / decode"); // remove any empty batches from input as they are not transmitted let expected_batches: Vec<_> = expected_batches diff --git a/arrow-flight/tests/flight_sql_client_cli.rs b/arrow-flight/tests/flight_sql_client_cli.rs index 221e776218c3..a28080450bc2 100644 --- a/arrow-flight/tests/flight_sql_client_cli.rs +++ b/arrow-flight/tests/flight_sql_client_cli.rs @@ -23,18 +23,16 @@ use arrow_flight::{ flight_service_server::{FlightService, FlightServiceServer}, sql::{ server::{FlightSqlService, PeekableFlightDataStream}, - ActionBeginSavepointRequest, ActionBeginSavepointResult, - ActionBeginTransactionRequest, ActionBeginTransactionResult, - ActionCancelQueryRequest, ActionCancelQueryResult, + ActionBeginSavepointRequest, ActionBeginSavepointResult, ActionBeginTransactionRequest, + ActionBeginTransactionResult, ActionCancelQueryRequest, ActionCancelQueryResult, ActionClosePreparedStatementRequest, 
ActionCreatePreparedStatementRequest, ActionCreatePreparedStatementResult, ActionCreatePreparedSubstraitPlanRequest, ActionEndSavepointRequest, ActionEndTransactionRequest, Any, CommandGetCatalogs, CommandGetCrossReference, CommandGetDbSchemas, CommandGetExportedKeys, - CommandGetImportedKeys, CommandGetPrimaryKeys, CommandGetSqlInfo, - CommandGetTableTypes, CommandGetTables, CommandGetXdbcTypeInfo, - CommandPreparedStatementQuery, CommandPreparedStatementUpdate, - CommandStatementQuery, CommandStatementSubstraitPlan, CommandStatementUpdate, - ProstMessageExt, SqlInfo, TicketStatementQuery, + CommandGetImportedKeys, CommandGetPrimaryKeys, CommandGetSqlInfo, CommandGetTableTypes, + CommandGetTables, CommandGetXdbcTypeInfo, CommandPreparedStatementQuery, + CommandPreparedStatementUpdate, CommandStatementQuery, CommandStatementSubstraitPlan, + CommandStatementUpdate, ProstMessageExt, SqlInfo, TicketStatementQuery, }, utils::batches_to_flight_data, Action, FlightData, FlightDescriptor, FlightEndpoint, FlightInfo, HandshakeRequest, @@ -168,8 +166,7 @@ impl FlightSqlServiceImpl { RecordBatch::try_new(Arc::new(schema), cols) } - fn create_fake_prepared_stmt( - ) -> Result { + fn create_fake_prepared_stmt() -> Result { let handle = PREPARED_STATEMENT_HANDLE.to_string(); let schema = Schema::new(vec![ Field::new("field_string", DataType::Utf8, false), diff --git a/arrow-integration-test/src/datatype.rs b/arrow-integration-test/src/datatype.rs index 47bacc7cc74b..42ac71fbbd7e 100644 --- a/arrow-integration-test/src/datatype.rs +++ b/arrow-integration-test/src/datatype.rs @@ -124,26 +124,16 @@ pub fn data_type_from_json(json: &serde_json::Value) -> Result { } Some(s) if s == "duration" => match map.get("unit") { Some(p) if p == "SECOND" => Ok(DataType::Duration(TimeUnit::Second)), - Some(p) if p == "MILLISECOND" => { - Ok(DataType::Duration(TimeUnit::Millisecond)) - } - Some(p) if p == "MICROSECOND" => { - Ok(DataType::Duration(TimeUnit::Microsecond)) - } - Some(p) if p == 
"NANOSECOND" => { - Ok(DataType::Duration(TimeUnit::Nanosecond)) - } + Some(p) if p == "MILLISECOND" => Ok(DataType::Duration(TimeUnit::Millisecond)), + Some(p) if p == "MICROSECOND" => Ok(DataType::Duration(TimeUnit::Microsecond)), + Some(p) if p == "NANOSECOND" => Ok(DataType::Duration(TimeUnit::Nanosecond)), _ => Err(ArrowError::ParseError( "time unit missing or invalid".to_string(), )), }, Some(s) if s == "interval" => match map.get("unit") { - Some(p) if p == "DAY_TIME" => { - Ok(DataType::Interval(IntervalUnit::DayTime)) - } - Some(p) if p == "YEAR_MONTH" => { - Ok(DataType::Interval(IntervalUnit::YearMonth)) - } + Some(p) if p == "DAY_TIME" => Ok(DataType::Interval(IntervalUnit::DayTime)), + Some(p) if p == "YEAR_MONTH" => Ok(DataType::Interval(IntervalUnit::YearMonth)), Some(p) if p == "MONTH_DAY_NANO" => { Ok(DataType::Interval(IntervalUnit::MonthDayNano)) } diff --git a/arrow-integration-test/src/field.rs b/arrow-integration-test/src/field.rs index f59314ca02db..32edc4165938 100644 --- a/arrow-integration-test/src/field.rs +++ b/arrow-integration-test/src/field.rs @@ -63,18 +63,17 @@ pub fn field_from_json(json: &serde_json::Value) -> Result { "Field 'metadata' must have exact two entries for each key-value map".to_string(), )); } - if let (Some(k), Some(v)) = - (map.get("key"), map.get("value")) - { - if let (Some(k_str), Some(v_str)) = - (k.as_str(), v.as_str()) - { + if let (Some(k), Some(v)) = (map.get("key"), map.get("value")) { + if let (Some(k_str), Some(v_str)) = (k.as_str(), v.as_str()) { res.insert( k_str.to_string().clone(), v_str.to_string().clone(), ); } else { - return Err(ArrowError::ParseError("Field 'metadata' must have map value of string type".to_string())); + return Err(ArrowError::ParseError( + "Field 'metadata' must have map value of string type" + .to_string(), + )); } } else { return Err(ArrowError::ParseError("Field 'metadata' lacks map keys named \"key\" or \"value\"".to_string())); @@ -115,46 +114,47 @@ pub fn 
field_from_json(json: &serde_json::Value) -> Result { // if data_type is a struct or list, get its children let data_type = match data_type { - DataType::List(_) - | DataType::LargeList(_) - | DataType::FixedSizeList(_, _) => match map.get("children") { - Some(Value::Array(values)) => { - if values.len() != 1 { + DataType::List(_) | DataType::LargeList(_) | DataType::FixedSizeList(_, _) => { + match map.get("children") { + Some(Value::Array(values)) => { + if values.len() != 1 { + return Err(ArrowError::ParseError( + "Field 'children' must have one element for a list data type" + .to_string(), + )); + } + match data_type { + DataType::List(_) => { + DataType::List(Arc::new(field_from_json(&values[0])?)) + } + DataType::LargeList(_) => { + DataType::LargeList(Arc::new(field_from_json(&values[0])?)) + } + DataType::FixedSizeList(_, int) => DataType::FixedSizeList( + Arc::new(field_from_json(&values[0])?), + int, + ), + _ => unreachable!( + "Data type should be a list, largelist or fixedsizelist" + ), + } + } + Some(_) => { return Err(ArrowError::ParseError( - "Field 'children' must have one element for a list data type".to_string(), - )); + "Field 'children' must be an array".to_string(), + )) } - match data_type { - DataType::List(_) => { - DataType::List(Arc::new(field_from_json(&values[0])?)) - } - DataType::LargeList(_) => DataType::LargeList(Arc::new( - field_from_json(&values[0])?, - )), - DataType::FixedSizeList(_, int) => DataType::FixedSizeList( - Arc::new(field_from_json(&values[0])?), - int, - ), - _ => unreachable!( - "Data type should be a list, largelist or fixedsizelist" - ), + None => { + return Err(ArrowError::ParseError( + "Field missing 'children' attribute".to_string(), + )); } } - Some(_) => { - return Err(ArrowError::ParseError( - "Field 'children' must be an array".to_string(), - )) - } - None => { - return Err(ArrowError::ParseError( - "Field missing 'children' attribute".to_string(), - )); - } - }, + } DataType::Struct(_) => match 
map.get("children") { - Some(Value::Array(values)) => DataType::Struct( - values.iter().map(field_from_json).collect::>()?, - ), + Some(Value::Array(values)) => { + DataType::Struct(values.iter().map(field_from_json).collect::>()?) + } Some(_) => { return Err(ArrowError::ParseError( "Field 'children' must be an array".to_string(), @@ -175,17 +175,16 @@ pub fn field_from_json(json: &serde_json::Value) -> Result { DataType::Struct(map_fields) if map_fields.len() == 2 => { DataType::Map(Arc::new(child), keys_sorted) } - t => { - return Err(ArrowError::ParseError( - format!("Map children should be a struct with 2 fields, found {t:?}") - )) + t => { + return Err(ArrowError::ParseError(format!( + "Map children should be a struct with 2 fields, found {t:?}" + ))) } } } Some(_) => { return Err(ArrowError::ParseError( - "Field 'children' must be an array with 1 element" - .to_string(), + "Field 'children' must be an array with 1 element".to_string(), )) } None => { @@ -200,9 +199,7 @@ pub fn field_from_json(json: &serde_json::Value) -> Result { let fields = fields .iter() .zip(values) - .map(|((id, _), value)| { - Ok((id, Arc::new(field_from_json(value)?))) - }) + .map(|((id, _), value)| Ok((id, Arc::new(field_from_json(value)?)))) .collect::>()?; DataType::Union(fields, mode) @@ -255,8 +252,7 @@ pub fn field_from_json(json: &serde_json::Value) -> Result { _ => data_type, }; - let mut field = - Field::new_dict(name, data_type, nullable, dict_id, dict_is_ordered); + let mut field = Field::new_dict(name, data_type, nullable, dict_id, dict_is_ordered); field.set_metadata(metadata); Ok(field) } @@ -269,9 +265,7 @@ pub fn field_from_json(json: &serde_json::Value) -> Result { /// Generate a JSON representation of the `Field`. 
pub fn field_to_json(field: &Field) -> serde_json::Value { let children: Vec = match field.data_type() { - DataType::Struct(fields) => { - fields.iter().map(|x| field_to_json(x.as_ref())).collect() - } + DataType::Struct(fields) => fields.iter().map(|x| field_to_json(x.as_ref())).collect(), DataType::List(field) | DataType::LargeList(field) | DataType::FixedSizeList(field, _) diff --git a/arrow-integration-test/src/lib.rs b/arrow-integration-test/src/lib.rs index 07b69bffd07d..7b797aa07061 100644 --- a/arrow-integration-test/src/lib.rs +++ b/arrow-integration-test/src/lib.rs @@ -261,9 +261,7 @@ impl ArrowJsonField { true } Err(e) => { - eprintln!( - "Encountered error while converting JSON field to Arrow field: {e:?}" - ); + eprintln!("Encountered error while converting JSON field to Arrow field: {e:?}"); false } } @@ -273,8 +271,8 @@ impl ArrowJsonField { /// TODO: convert to use an Into fn to_arrow_field(&self) -> Result { // a bit regressive, but we have to convert the field to JSON in order to convert it - let field = serde_json::to_value(self) - .map_err(|error| ArrowError::JsonError(error.to_string()))?; + let field = + serde_json::to_value(self).map_err(|error| ArrowError::JsonError(error.to_string()))?; field_from_json(&field) } } @@ -389,12 +387,9 @@ pub fn array_from_json( match is_valid { 1 => b.append_value(match value { Value::Number(n) => n.as_i64().unwrap(), - Value::String(s) => { - s.parse().expect("Unable to parse string as i64") - } + Value::String(s) => s.parse().expect("Unable to parse string as i64"), Value::Object(ref map) - if map.contains_key("days") - && map.contains_key("milliseconds") => + if map.contains_key("days") && map.contains_key("milliseconds") => { match field.data_type() { DataType::Interval(IntervalUnit::DayTime) => { @@ -404,23 +399,19 @@ pub fn array_from_json( match (days, milliseconds) { (Value::Number(d), Value::Number(m)) => { let mut bytes = [0_u8; 8]; - let m = (m.as_i64().unwrap() as i32) - .to_le_bytes(); - let d = 
(d.as_i64().unwrap() as i32) - .to_le_bytes(); + let m = (m.as_i64().unwrap() as i32).to_le_bytes(); + let d = (d.as_i64().unwrap() as i32).to_le_bytes(); let c = [d, m].concat(); bytes.copy_from_slice(c.as_slice()); i64::from_le_bytes(bytes) } - _ => panic!( - "Unable to parse {value:?} as interval daytime" - ), + _ => { + panic!("Unable to parse {value:?} as interval daytime") + } } } - _ => panic!( - "Unable to parse {value:?} as interval daytime" - ), + _ => panic!("Unable to parse {value:?} as interval daytime"), } } _ => panic!("Unable to parse {value:?} as number"), @@ -499,9 +490,7 @@ pub fn array_from_json( .expect("Unable to parse string as u64"), ) } else if value.is_number() { - b.append_value( - value.as_u64().expect("Unable to read number as u64"), - ) + b.append_value(value.as_u64().expect("Unable to read number as u64")) } else { panic!("Unable to parse value {value:?} as u64") } @@ -535,11 +524,10 @@ pub fn array_from_json( let months = months.as_i64().unwrap() as i32; let days = days.as_i64().unwrap() as i32; let nanoseconds = nanoseconds.as_i64().unwrap(); - let months_days_ns: i128 = ((nanoseconds as i128) - & 0xFFFFFFFFFFFFFFFF) - << 64 - | ((days as i128) & 0xFFFFFFFF) << 32 - | ((months as i128) & 0xFFFFFFFF); + let months_days_ns: i128 = + ((nanoseconds as i128) & 0xFFFFFFFFFFFFFFFF) << 64 + | ((days as i128) & 0xFFFFFFFF) << 32 + | ((months as i128) & 0xFFFFFFFF); months_days_ns } (_, _, _) => { @@ -678,11 +666,8 @@ pub fn array_from_json( DataType::List(child_field) => { let null_buf = create_null_buf(&json_col); let children = json_col.children.clone().unwrap(); - let child_array = array_from_json( - child_field, - children.get(0).unwrap().clone(), - dictionaries, - )?; + let child_array = + array_from_json(child_field, children.get(0).unwrap().clone(), dictionaries)?; let offsets: Vec = json_col .offset .unwrap() @@ -702,11 +687,8 @@ pub fn array_from_json( DataType::LargeList(child_field) => { let null_buf = create_null_buf(&json_col); 
let children = json_col.children.clone().unwrap(); - let child_array = array_from_json( - child_field, - children.get(0).unwrap().clone(), - dictionaries, - )?; + let child_array = + array_from_json(child_field, children.get(0).unwrap().clone(), dictionaries)?; let offsets: Vec = json_col .offset .unwrap() @@ -729,11 +711,8 @@ pub fn array_from_json( } DataType::FixedSizeList(child_field, _) => { let children = json_col.children.clone().unwrap(); - let child_array = array_from_json( - child_field, - children.get(0).unwrap().clone(), - dictionaries, - )?; + let child_array = + array_from_json(child_field, children.get(0).unwrap().clone(), dictionaries)?; let null_buf = create_null_buf(&json_col); let list_data = ArrayData::builder(field.data_type().clone()) .len(json_col.count) @@ -760,9 +739,7 @@ pub fn array_from_json( } DataType::Dictionary(key_type, value_type) => { let dict_id = field.dict_id().ok_or_else(|| { - ArrowError::JsonError(format!( - "Unable to find dict_id for field {field:?}" - )) + ArrowError::JsonError(format!("Unable to find dict_id for field {field:?}")) })?; // find dictionary let dictionary = dictionaries @@ -823,8 +800,7 @@ pub fn array_from_json( } else { [255_u8; 32] }; - bytes[0..integer_bytes.len()] - .copy_from_slice(integer_bytes.as_slice()); + bytes[0..integer_bytes.len()].copy_from_slice(integer_bytes.as_slice()); b.append_value(i256::from_le_bytes(bytes)); } _ => b.append_null(), @@ -837,11 +813,8 @@ pub fn array_from_json( DataType::Map(child_field, _) => { let null_buf = create_null_buf(&json_col); let children = json_col.children.clone().unwrap(); - let child_array = array_from_json( - child_field, - children.get(0).unwrap().clone(), - dictionaries, - )?; + let child_array = + array_from_json(child_field, children.get(0).unwrap().clone(), dictionaries)?; let offsets: Vec = json_col .offset .unwrap() @@ -946,9 +919,7 @@ pub fn dictionary_array_from_json( .unwrap(); let array = match dict_key { - DataType::Int8 => { - 
Arc::new(Int8DictionaryArray::from(dict_data)) as ArrayRef - } + DataType::Int8 => Arc::new(Int8DictionaryArray::from(dict_data)) as ArrayRef, DataType::Int16 => Arc::new(Int16DictionaryArray::from(dict_data)), DataType::Int32 => Arc::new(Int32DictionaryArray::from(dict_data)), DataType::Int64 => Arc::new(Int64DictionaryArray::from(dict_data)), @@ -1099,11 +1070,7 @@ mod tests { Field::new("c3", DataType::Utf8, true), Field::new( "c4", - DataType::List(Arc::new(Field::new( - "custom_item", - DataType::Int32, - false, - ))), + DataType::List(Arc::new(Field::new("custom_item", DataType::Int32, false))), true, ), ]); @@ -1199,10 +1166,8 @@ mod tests { ), ]); - let bools_with_metadata_map = - BooleanArray::from(vec![Some(true), None, Some(false)]); - let bools_with_metadata_vec = - BooleanArray::from(vec![Some(true), None, Some(false)]); + let bools_with_metadata_map = BooleanArray::from(vec![Some(true), None, Some(false)]); + let bools_with_metadata_vec = BooleanArray::from(vec![Some(true), None, Some(false)]); let bools = BooleanArray::from(vec![Some(true), None, Some(false)]); let int8s = Int8Array::from(vec![Some(1), None, Some(3)]); let int16s = Int16Array::from(vec![Some(1), None, Some(3)]); @@ -1220,39 +1185,24 @@ mod tests { Some(29923997007884), Some(30612271819236), ]); - let time_secs = - Time32SecondArray::from(vec![Some(27974), Some(78592), Some(43207)]); - let time_millis = Time32MillisecondArray::from(vec![ - Some(6613125), - Some(74667230), - Some(52260079), - ]); - let time_micros = - Time64MicrosecondArray::from(vec![Some(62522958593), None, None]); - let time_nanos = Time64NanosecondArray::from(vec![ - Some(73380123595985), - None, - Some(16584393546415), - ]); + let time_secs = Time32SecondArray::from(vec![Some(27974), Some(78592), Some(43207)]); + let time_millis = + Time32MillisecondArray::from(vec![Some(6613125), Some(74667230), Some(52260079)]); + let time_micros = Time64MicrosecondArray::from(vec![Some(62522958593), None, None]); + let 
time_nanos = + Time64NanosecondArray::from(vec![Some(73380123595985), None, Some(16584393546415)]); let ts_secs = TimestampSecondArray::from(vec![None, Some(193438817552), None]); - let ts_millis = TimestampMillisecondArray::from(vec![ - None, - Some(38606916383008), - Some(58113709376587), - ]); + let ts_millis = + TimestampMillisecondArray::from(vec![None, Some(38606916383008), Some(58113709376587)]); let ts_micros = TimestampMicrosecondArray::from(vec![None, None, None]); - let ts_nanos = - TimestampNanosecondArray::from(vec![None, None, Some(-6473623571954960143)]); + let ts_nanos = TimestampNanosecondArray::from(vec![None, None, Some(-6473623571954960143)]); let ts_secs_tz = TimestampSecondArray::from(vec![None, Some(193438817552), None]) .with_timezone_opt(secs_tz); - let ts_millis_tz = TimestampMillisecondArray::from(vec![ - None, - Some(38606916383008), - Some(58113709376587), - ]) - .with_timezone_opt(millis_tz); - let ts_micros_tz = TimestampMicrosecondArray::from(vec![None, None, None]) - .with_timezone_opt(micros_tz); + let ts_millis_tz = + TimestampMillisecondArray::from(vec![None, Some(38606916383008), Some(58113709376587)]) + .with_timezone_opt(millis_tz); + let ts_micros_tz = + TimestampMicrosecondArray::from(vec![None, None, None]).with_timezone_opt(micros_tz); let ts_nanos_tz = TimestampNanosecondArray::from(vec![None, None, Some(-6473623571954960143)]) .with_timezone_opt(nanos_tz); @@ -1260,8 +1210,7 @@ mod tests { let value_data = Int32Array::from(vec![None, Some(2), None, None]); let value_offsets = Buffer::from_slice_ref([0, 3, 4, 4]); - let list_data_type = - DataType::List(Arc::new(Field::new("item", DataType::Int32, true))); + let list_data_type = DataType::List(Arc::new(Field::new("item", DataType::Int32, true))); let list_data = ArrayData::builder(list_data_type) .len(3) .add_buffer(value_offsets) diff --git a/arrow-integration-test/src/schema.rs b/arrow-integration-test/src/schema.rs index 6e143c2838d9..b5f6c5e86b38 100644 --- 
a/arrow-integration-test/src/schema.rs +++ b/arrow-integration-test/src/schema.rs @@ -65,11 +65,9 @@ fn from_metadata(json: &serde_json::Value) -> Result> { match json { Value::Array(_) => { let mut hashmap = HashMap::new(); - let values: Vec = serde_json::from_value(json.clone()) - .map_err(|_| { - ArrowError::JsonError( - "Unable to parse object into key-value pair".to_string(), - ) + let values: Vec = + serde_json::from_value(json.clone()).map_err(|_| { + ArrowError::JsonError("Unable to parse object into key-value pair".to_string()) })?; for meta in values { hashmap.insert(meta.key.clone(), meta.value); @@ -110,11 +108,10 @@ mod tests { #[test] fn schema_json() { // Add some custom metadata - let metadata: HashMap = - [("Key".to_string(), "Value".to_string())] - .iter() - .cloned() - .collect(); + let metadata: HashMap = [("Key".to_string(), "Value".to_string())] + .iter() + .cloned() + .collect(); let schema = Schema::new_with_metadata( vec![ @@ -140,10 +137,7 @@ mod tests { ), Field::new( "c17", - DataType::Timestamp( - TimeUnit::Microsecond, - Some("Africa/Johannesburg".into()), - ), + DataType::Timestamp(TimeUnit::Microsecond, Some("Africa/Johannesburg".into())), false, ), Field::new( @@ -197,10 +191,7 @@ mod tests { Field::new("c32", DataType::Duration(TimeUnit::Nanosecond), false), Field::new_dict( "c33", - DataType::Dictionary( - Box::new(DataType::Int32), - Box::new(DataType::Utf8), - ), + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), true, 123, true, diff --git a/arrow-integration-testing/src/bin/arrow-json-integration-test.rs b/arrow-integration-testing/src/bin/arrow-json-integration-test.rs index db5df8b58a6f..187d987a5a0a 100644 --- a/arrow-integration-testing/src/bin/arrow-json-integration-test.rs +++ b/arrow-integration-testing/src/bin/arrow-json-integration-test.rs @@ -121,11 +121,8 @@ fn canonicalize_schema(schema: &Schema) -> Schema { DataType::Map(child_field, sorted) => match child_field.data_type() { 
DataType::Struct(fields) if fields.len() == 2 => { let first_field = fields.get(0).unwrap(); - let key_field = Arc::new(Field::new( - "key", - first_field.data_type().clone(), - false, - )); + let key_field = + Arc::new(Field::new("key", first_field.data_type().clone(), false)); let second_field = fields.get(1).unwrap(); let value_field = Arc::new(Field::new( "value", @@ -143,9 +140,7 @@ fn canonicalize_schema(schema: &Schema) -> Schema { field.is_nullable(), )) } - _ => panic!( - "The child field of Map type should be Struct type with 2 fields." - ), + _ => panic!("The child field of Map type should be Struct type with 2 fields."), }, _ => field.clone(), }) diff --git a/arrow-integration-testing/src/bin/flight-test-integration-client.rs b/arrow-integration-testing/src/bin/flight-test-integration-client.rs index d46b4fac759e..b8bbb952837b 100644 --- a/arrow-integration-testing/src/bin/flight-test-integration-client.rs +++ b/arrow-integration-testing/src/bin/flight-test-integration-client.rs @@ -62,8 +62,7 @@ async fn main() -> Result { } None => { let path = args.path.expect("No path is given"); - flight_client_scenarios::integration_test::run_scenario(&host, port, &path) - .await?; + flight_client_scenarios::integration_test::run_scenario(&host, port, &path).await?; } } diff --git a/arrow-integration-testing/src/flight_client_scenarios/auth_basic_proto.rs b/arrow-integration-testing/src/flight_client_scenarios/auth_basic_proto.rs index 9f66abf50106..376e31e15553 100644 --- a/arrow-integration-testing/src/flight_client_scenarios/auth_basic_proto.rs +++ b/arrow-integration-testing/src/flight_client_scenarios/auth_basic_proto.rs @@ -17,9 +17,7 @@ use crate::{AUTH_PASSWORD, AUTH_USERNAME}; -use arrow_flight::{ - flight_service_client::FlightServiceClient, BasicAuth, HandshakeRequest, -}; +use arrow_flight::{flight_service_client::FlightServiceClient, BasicAuth, HandshakeRequest}; use futures::{stream, StreamExt}; use prost::Message; use 
tonic::{metadata::MetadataValue, Request, Status}; @@ -78,11 +76,7 @@ pub async fn run_scenario(host: &str, port: u16) -> Result { Ok(()) } -async fn authenticate( - client: &mut Client, - username: &str, - password: &str, -) -> Result { +async fn authenticate(client: &mut Client, username: &str, password: &str) -> Result { let auth = BasicAuth { username: username.into(), password: password.into(), diff --git a/arrow-integration-testing/src/flight_client_scenarios/integration_test.rs b/arrow-integration-testing/src/flight_client_scenarios/integration_test.rs index a55c2dec0580..81cc4bbe8ed2 100644 --- a/arrow-integration-testing/src/flight_client_scenarios/integration_test.rs +++ b/arrow-integration-testing/src/flight_client_scenarios/integration_test.rs @@ -27,8 +27,7 @@ use arrow::{ }; use arrow_flight::{ flight_descriptor::DescriptorType, flight_service_client::FlightServiceClient, - utils::flight_data_to_arrow_batch, FlightData, FlightDescriptor, Location, - SchemaAsIpc, Ticket, + utils::flight_data_to_arrow_batch, FlightData, FlightDescriptor, Location, SchemaAsIpc, Ticket, }; use futures::{channel::mpsc, sink::SinkExt, stream, StreamExt}; use tonic::{Request, Streaming}; @@ -203,19 +202,16 @@ async fn consume_flight_location( let mut dictionaries_by_id = HashMap::new(); for (counter, expected_batch) in expected_data.iter().enumerate() { - let data = receive_batch_flight_data( - &mut resp, - actual_schema.clone(), - &mut dictionaries_by_id, - ) - .await - .unwrap_or_else(|| { - panic!( - "Got fewer batches than expected, received so far: {} expected: {}", - counter, - expected_data.len(), - ) - }); + let data = + receive_batch_flight_data(&mut resp, actual_schema.clone(), &mut dictionaries_by_id) + .await + .unwrap_or_else(|| { + panic!( + "Got fewer batches than expected, received so far: {} expected: {}", + counter, + expected_data.len(), + ) + }); let metadata = counter.to_string().into_bytes(); assert_eq!(metadata, data.app_metadata); @@ -250,8 +246,8 @@ 
async fn consume_flight_location( async fn receive_schema_flight_data(resp: &mut Streaming) -> Option { let data = resp.next().await?.ok()?; - let message = arrow::ipc::root_as_message(&data.data_header[..]) - .expect("Error parsing message"); + let message = + arrow::ipc::root_as_message(&data.data_header[..]).expect("Error parsing message"); // message header is a Schema, so read it let ipc_schema: ipc::Schema = message @@ -268,8 +264,8 @@ async fn receive_batch_flight_data( dictionaries_by_id: &mut HashMap, ) -> Option { let mut data = resp.next().await?.ok()?; - let mut message = arrow::ipc::root_as_message(&data.data_header[..]) - .expect("Error parsing first message"); + let mut message = + arrow::ipc::root_as_message(&data.data_header[..]).expect("Error parsing first message"); while message.header_type() == ipc::MessageHeader::DictionaryBatch { reader::read_dictionary( @@ -284,8 +280,8 @@ async fn receive_batch_flight_data( .expect("Error reading dictionary"); data = resp.next().await?.ok()?; - message = arrow::ipc::root_as_message(&data.data_header[..]) - .expect("Error parsing message"); + message = + arrow::ipc::root_as_message(&data.data_header[..]).expect("Error parsing message"); } Some(data) diff --git a/arrow-integration-testing/src/flight_client_scenarios/middleware.rs b/arrow-integration-testing/src/flight_client_scenarios/middleware.rs index 773919ff72af..3b71edf446a3 100644 --- a/arrow-integration-testing/src/flight_client_scenarios/middleware.rs +++ b/arrow-integration-testing/src/flight_client_scenarios/middleware.rs @@ -16,8 +16,7 @@ // under the License. 
use arrow_flight::{ - flight_descriptor::DescriptorType, flight_service_client::FlightServiceClient, - FlightDescriptor, + flight_descriptor::DescriptorType, flight_service_client::FlightServiceClient, FlightDescriptor, }; use prost::bytes::Bytes; use tonic::{Request, Status}; diff --git a/arrow-integration-testing/src/flight_server_scenarios/auth_basic_proto.rs b/arrow-integration-testing/src/flight_server_scenarios/auth_basic_proto.rs index 72d47b1391ee..ff4fc12f2523 100644 --- a/arrow-integration-testing/src/flight_server_scenarios/auth_basic_proto.rs +++ b/arrow-integration-testing/src/flight_server_scenarios/auth_basic_proto.rs @@ -19,15 +19,13 @@ use std::pin::Pin; use std::sync::Arc; use arrow_flight::{ - flight_service_server::FlightService, flight_service_server::FlightServiceServer, - Action, ActionType, BasicAuth, Criteria, Empty, FlightData, FlightDescriptor, - FlightInfo, HandshakeRequest, HandshakeResponse, PutResult, SchemaResult, Ticket, + flight_service_server::FlightService, flight_service_server::FlightServiceServer, Action, + ActionType, BasicAuth, Criteria, Empty, FlightData, FlightDescriptor, FlightInfo, + HandshakeRequest, HandshakeResponse, PutResult, SchemaResult, Ticket, }; use futures::{channel::mpsc, sink::SinkExt, Stream, StreamExt}; use tokio::sync::Mutex; -use tonic::{ - metadata::MetadataMap, transport::Server, Request, Response, Status, Streaming, -}; +use tonic::{metadata::MetadataMap, transport::Server, Request, Response, Status, Streaming}; type TonicStream = Pin + Send + Sync + 'static>>; type Error = Box; @@ -63,10 +61,7 @@ pub struct AuthBasicProtoScenarioImpl { } impl AuthBasicProtoScenarioImpl { - async fn check_auth( - &self, - metadata: &MetadataMap, - ) -> Result { + async fn check_auth(&self, metadata: &MetadataMap) -> Result { let token = metadata .get_bin("auth-token-bin") .and_then(|v| v.to_bytes().ok()) @@ -74,10 +69,7 @@ impl AuthBasicProtoScenarioImpl { self.is_valid(token).await } - async fn is_valid( - &self, - 
token: Option, - ) -> Result { + async fn is_valid(&self, token: Option) -> Result { match token { Some(t) if t == *self.username => Ok(GrpcServerCallContext { peer_identity: self.username.to_string(), @@ -142,12 +134,10 @@ impl FlightService for AuthBasicProtoScenarioImpl { let req = req.expect("Error reading handshake request"); let HandshakeRequest { payload, .. } = req; - let auth = BasicAuth::decode(&*payload) - .expect("Error parsing handshake request"); + let auth = + BasicAuth::decode(&*payload).expect("Error parsing handshake request"); - let resp = if *auth.username == *username - && *auth.password == *password - { + let resp = if *auth.username == *username && *auth.password == *password { Ok(HandshakeResponse { payload: username.as_bytes().to_vec().into(), ..HandshakeResponse::default() diff --git a/arrow-integration-testing/src/flight_server_scenarios/integration_test.rs b/arrow-integration-testing/src/flight_server_scenarios/integration_test.rs index e2c4cb5d88f3..2011031e921a 100644 --- a/arrow-integration-testing/src/flight_server_scenarios/integration_test.rs +++ b/arrow-integration-testing/src/flight_server_scenarios/integration_test.rs @@ -30,9 +30,9 @@ use arrow::{ }; use arrow_flight::{ flight_descriptor::DescriptorType, flight_service_server::FlightService, - flight_service_server::FlightServiceServer, Action, ActionType, Criteria, Empty, - FlightData, FlightDescriptor, FlightEndpoint, FlightInfo, HandshakeRequest, - HandshakeResponse, IpcMessage, PutResult, SchemaAsIpc, SchemaResult, Ticket, + flight_service_server::FlightServiceServer, Action, ActionType, Criteria, Empty, FlightData, + FlightDescriptor, FlightEndpoint, FlightInfo, HandshakeRequest, HandshakeResponse, IpcMessage, + PutResult, SchemaAsIpc, SchemaResult, Ticket, }; use futures::{channel::mpsc, sink::SinkExt, Stream, StreamExt}; use std::convert::TryInto; @@ -113,8 +113,7 @@ impl FlightService for FlightServiceImpl { let options = arrow::ipc::writer::IpcWriteOptions::default(); 
- let schema = - std::iter::once(Ok(SchemaAsIpc::new(&flight.schema, &options).into())); + let schema = std::iter::once(Ok(SchemaAsIpc::new(&flight.schema, &options).into())); let batches = flight .chunks @@ -126,12 +125,9 @@ impl FlightService for FlightServiceImpl { let (encoded_dictionaries, encoded_batch) = data_gen .encoded_batch(batch, &mut dictionary_tracker, &options) - .expect( - "DictionaryTracker configured above to not error on replacement", - ); + .expect("DictionaryTracker configured above to not error on replacement"); - let dictionary_flight_data = - encoded_dictionaries.into_iter().map(Into::into); + let dictionary_flight_data = encoded_dictionaries.into_iter().map(Into::into); let mut batch_flight_data: FlightData = encoded_batch.into(); // Only the record batch's FlightData gets app_metadata @@ -182,8 +178,7 @@ impl FlightService for FlightServiceImpl { let endpoint = self.endpoint_from_path(&path[0]); - let total_records: usize = - flight.chunks.iter().map(|chunk| chunk.num_rows()).sum(); + let total_records: usize = flight.chunks.iter().map(|chunk| chunk.num_rows()).sum(); let options = arrow::ipc::writer::IpcWriteOptions::default(); let message = SchemaAsIpc::new(&flight.schema, &options) @@ -224,8 +219,7 @@ impl FlightService for FlightServiceImpl { .clone() .ok_or_else(|| Status::invalid_argument("Must have a descriptor"))?; - if descriptor.r#type != DescriptorType::Path as i32 || descriptor.path.is_empty() - { + if descriptor.r#type != DescriptorType::Path as i32 || descriptor.path.is_empty() { return Err(Status::invalid_argument("Must specify a path")); } @@ -297,9 +291,9 @@ async fn record_batch_from_message( schema_ref: SchemaRef, dictionaries_by_id: &HashMap, ) -> Result { - let ipc_batch = message.header_as_record_batch().ok_or_else(|| { - Status::internal("Could not parse message header as record batch") - })?; + let ipc_batch = message + .header_as_record_batch() + .ok_or_else(|| Status::internal("Could not parse message header as 
record batch"))?; let arrow_batch_result = reader::read_record_batch( data_body, @@ -320,9 +314,9 @@ async fn dictionary_from_message( schema_ref: SchemaRef, dictionaries_by_id: &mut HashMap, ) -> Result<(), Status> { - let ipc_batch = message.header_as_dictionary_batch().ok_or_else(|| { - Status::internal("Could not parse message header as dictionary batch") - })?; + let ipc_batch = message + .header_as_dictionary_batch() + .ok_or_else(|| Status::internal("Could not parse message header as dictionary batch"))?; let dictionary_batch_result = reader::read_dictionary( data_body, diff --git a/arrow-integration-testing/src/flight_server_scenarios/middleware.rs b/arrow-integration-testing/src/flight_server_scenarios/middleware.rs index 9b1c84b57119..68d871b528a6 100644 --- a/arrow-integration-testing/src/flight_server_scenarios/middleware.rs +++ b/arrow-integration-testing/src/flight_server_scenarios/middleware.rs @@ -19,9 +19,9 @@ use std::pin::Pin; use arrow_flight::{ flight_descriptor::DescriptorType, flight_service_server::FlightService, - flight_service_server::FlightServiceServer, Action, ActionType, Criteria, Empty, - FlightData, FlightDescriptor, FlightInfo, HandshakeRequest, HandshakeResponse, - PutResult, SchemaResult, Ticket, + flight_service_server::FlightServiceServer, Action, ActionType, Criteria, Empty, FlightData, + FlightDescriptor, FlightInfo, HandshakeRequest, HandshakeResponse, PutResult, SchemaResult, + Ticket, }; use futures::Stream; use tonic::{transport::Server, Request, Response, Status, Streaming}; @@ -93,8 +93,7 @@ impl FlightService for MiddlewareScenarioImpl { let descriptor = request.into_inner(); - if descriptor.r#type == DescriptorType::Cmd as i32 - && descriptor.cmd.as_ref() == b"success" + if descriptor.r#type == DescriptorType::Cmd as i32 && descriptor.cmd.as_ref() == b"success" { // Return a fake location - the test doesn't read it let endpoint = super::endpoint("foo", "grpc+tcp://localhost:10010"); diff --git 
a/arrow-integration-testing/src/lib.rs b/arrow-integration-testing/src/lib.rs index fe0cc68a4205..2d76be3495c8 100644 --- a/arrow-integration-testing/src/lib.rs +++ b/arrow-integration-testing/src/lib.rs @@ -56,8 +56,8 @@ pub fn read_json_file(json_name: &str) -> Result { .as_array() .expect("Unable to get dictionaries as array") { - let json_dict: ArrowJsonDictionaryBatch = serde_json::from_value(d.clone()) - .expect("Unable to get dictionary from JSON"); + let json_dict: ArrowJsonDictionaryBatch = + serde_json::from_value(d.clone()).expect("Unable to get dictionary from JSON"); // TODO: convert to a concrete Arrow type dictionaries.insert(json_dict.id, json_dict); } diff --git a/arrow-integration-testing/tests/ipc_reader.rs b/arrow-integration-testing/tests/ipc_reader.rs index 696ab6e6053a..11b8fa84534e 100644 --- a/arrow-integration-testing/tests/ipc_reader.rs +++ b/arrow-integration-testing/tests/ipc_reader.rs @@ -63,9 +63,7 @@ fn read_1_0_0_bigendian_decimal_should_panic() { } #[test] -#[should_panic( - expected = "Last offset 687865856 of Utf8 is larger than values length 41" -)] +#[should_panic(expected = "Last offset 687865856 of Utf8 is larger than values length 41")] fn read_1_0_0_bigendian_dictionary_should_panic() { // The offsets are not translated for big-endian files // https://github.com/apache/arrow-rs/issues/859 @@ -160,8 +158,7 @@ fn read_2_0_0_compression() { /// Verification json file /// `arrow-ipc-stream/integration//.json.gz fn verify_arrow_file(testdata: &str, version: &str, path: &str) { - let filename = - format!("{testdata}/arrow-ipc-stream/integration/{version}/{path}.arrow_file"); + let filename = format!("{testdata}/arrow-ipc-stream/integration/{version}/{path}.arrow_file"); println!("Verifying {filename}"); // Compare contents to the expected output format in JSON @@ -197,8 +194,7 @@ fn verify_arrow_file(testdata: &str, version: &str, path: &str) { /// Verification json file /// `arrow-ipc-stream/integration//.json.gz fn 
verify_arrow_stream(testdata: &str, version: &str, path: &str) { - let filename = - format!("{testdata}/arrow-ipc-stream/integration/{version}/{path}.stream"); + let filename = format!("{testdata}/arrow-ipc-stream/integration/{version}/{path}.stream"); println!("Verifying {filename}"); // Compare contents to the expected output format in JSON diff --git a/arrow-integration-testing/tests/ipc_writer.rs b/arrow-integration-testing/tests/ipc_writer.rs index 11707d935540..d780eb2ee0b5 100644 --- a/arrow-integration-testing/tests/ipc_writer.rs +++ b/arrow-integration-testing/tests/ipc_writer.rs @@ -113,12 +113,7 @@ fn write_2_0_0_compression() { for options in &all_options { println!("Using options {options:?}"); roundtrip_arrow_file_with_options(&testdata, version, path, options.clone()); - roundtrip_arrow_stream_with_options( - &testdata, - version, - path, - options.clone(), - ); + roundtrip_arrow_stream_with_options(&testdata, version, path, options.clone()); } }); } @@ -143,8 +138,7 @@ fn roundtrip_arrow_file_with_options( path: &str, options: IpcWriteOptions, ) { - let filename = - format!("{testdata}/arrow-ipc-stream/integration/{version}/{path}.arrow_file"); + let filename = format!("{testdata}/arrow-ipc-stream/integration/{version}/{path}.arrow_file"); println!("Verifying {filename}"); let mut tempfile = tempfile::tempfile().unwrap(); @@ -156,12 +150,8 @@ fn roundtrip_arrow_file_with_options( // read and rewrite the file to a temp location { - let mut writer = FileWriter::try_new_with_options( - &mut tempfile, - &reader.schema(), - options, - ) - .unwrap(); + let mut writer = + FileWriter::try_new_with_options(&mut tempfile, &reader.schema(), options).unwrap(); while let Some(Ok(batch)) = reader.next() { writer.write(&batch).unwrap(); } @@ -207,12 +197,7 @@ fn roundtrip_arrow_file_with_options( /// Verification json file /// `arrow-ipc-stream/integration//.json.gz fn roundtrip_arrow_stream(testdata: &str, version: &str, path: &str) { - 
roundtrip_arrow_stream_with_options( - testdata, - version, - path, - IpcWriteOptions::default(), - ) + roundtrip_arrow_stream_with_options(testdata, version, path, IpcWriteOptions::default()) } fn roundtrip_arrow_stream_with_options( @@ -221,8 +206,7 @@ fn roundtrip_arrow_stream_with_options( path: &str, options: IpcWriteOptions, ) { - let filename = - format!("{testdata}/arrow-ipc-stream/integration/{version}/{path}.stream"); + let filename = format!("{testdata}/arrow-ipc-stream/integration/{version}/{path}.stream"); println!("Verifying {filename}"); let mut tempfile = tempfile::tempfile().unwrap(); @@ -234,12 +218,9 @@ fn roundtrip_arrow_stream_with_options( // read and rewrite the file to a temp location { - let mut writer = StreamWriter::try_new_with_options( - &mut tempfile, - &reader.schema(), - options, - ) - .unwrap(); + let mut writer = + StreamWriter::try_new_with_options(&mut tempfile, &reader.schema(), options) + .unwrap(); while let Some(Ok(batch)) = reader.next() { writer.write(&batch).unwrap(); } diff --git a/arrow-ipc/src/compression.rs b/arrow-ipc/src/compression.rs index fafc2c5c9b6d..0d8b7b4c1bd4 100644 --- a/arrow-ipc/src/compression.rs +++ b/arrow-ipc/src/compression.rs @@ -90,10 +90,7 @@ impl CompressionCodec { /// [8 bytes]: uncompressed length /// [remaining bytes]: compressed data stream /// ``` - pub(crate) fn decompress_to_buffer( - &self, - input: &Buffer, - ) -> Result { + pub(crate) fn decompress_to_buffer(&self, input: &Buffer) -> Result { // read the first 8 bytes to determine if the data is // compressed let decompressed_length = read_uncompressed_size(input); @@ -127,11 +124,7 @@ impl CompressionCodec { /// Decompress the data in input buffer and write to output buffer /// using the specified compression - fn decompress( - &self, - input: &[u8], - decompressed_size: usize, - ) -> Result, ArrowError> { + fn decompress(&self, input: &[u8], decompressed_size: usize) -> Result, ArrowError> { let ret = match self { 
CompressionCodec::Lz4Frame => decompress_lz4(input, decompressed_size)?, CompressionCodec::Zstd => decompress_zstd(input, decompressed_size)?, @@ -175,10 +168,7 @@ fn decompress_lz4(input: &[u8], decompressed_size: usize) -> Result, Arr #[cfg(not(feature = "lz4"))] #[allow(clippy::ptr_arg)] -fn decompress_lz4( - _input: &[u8], - _decompressed_size: usize, -) -> Result, ArrowError> { +fn decompress_lz4(_input: &[u8], _decompressed_size: usize) -> Result, ArrowError> { Err(ArrowError::InvalidArgumentError( "lz4 IPC decompression requires the lz4 feature".to_string(), )) @@ -202,10 +192,7 @@ fn compress_zstd(_input: &[u8], _output: &mut Vec) -> Result<(), ArrowError> } #[cfg(feature = "zstd")] -fn decompress_zstd( - input: &[u8], - decompressed_size: usize, -) -> Result, ArrowError> { +fn decompress_zstd(input: &[u8], decompressed_size: usize) -> Result, ArrowError> { use std::io::Read; let mut output = Vec::with_capacity(decompressed_size); zstd::Decoder::with_buffer(input)?.read_to_end(&mut output)?; @@ -214,10 +201,7 @@ fn decompress_zstd( #[cfg(not(feature = "zstd"))] #[allow(clippy::ptr_arg)] -fn decompress_zstd( - _input: &[u8], - _decompressed_size: usize, -) -> Result, ArrowError> { +fn decompress_zstd(_input: &[u8], _decompressed_size: usize) -> Result, ArrowError> { Err(ArrowError::InvalidArgumentError( "zstd IPC decompression requires the zstd feature".to_string(), )) diff --git a/arrow-ipc/src/convert.rs b/arrow-ipc/src/convert.rs index a78ccde6e169..b290a09acf5d 100644 --- a/arrow-ipc/src/convert.rs +++ b/arrow-ipc/src/convert.rs @@ -18,9 +18,7 @@ //! 
Utilities for converting between IPC types and native Arrow types use arrow_schema::*; -use flatbuffers::{ - FlatBufferBuilder, ForwardsUOffset, UnionWIPOffset, Vector, WIPOffset, -}; +use flatbuffers::{FlatBufferBuilder, ForwardsUOffset, UnionWIPOffset, Vector, WIPOffset}; use std::collections::HashMap; use std::sync::Arc; @@ -186,16 +184,11 @@ pub fn try_schema_from_ipc_buffer(buffer: &[u8]) -> Result { // buffer 0 }; - let msg = - size_prefixed_root_as_message(&buffer[begin_offset..]).map_err(|err| { - ArrowError::ParseError(format!( - "Unable to convert flight info to a message: {err}" - )) - })?; + let msg = size_prefixed_root_as_message(&buffer[begin_offset..]).map_err(|err| { + ArrowError::ParseError(format!("Unable to convert flight info to a message: {err}")) + })?; let ipc_schema = msg.header_as_schema().ok_or_else(|| { - ArrowError::ParseError( - "Unable to convert flight info to a schema".to_string(), - ) + ArrowError::ParseError("Unable to convert flight info to a schema".to_string()) })?; Ok(fb_to_schema(ipc_schema)) } else { @@ -277,15 +270,9 @@ pub(crate) fn get_data_type(field: crate::Field, may_be_dictionary: bool) -> Dat let time = field.type_as_time().unwrap(); match (time.bitWidth(), time.unit()) { (32, crate::TimeUnit::SECOND) => DataType::Time32(TimeUnit::Second), - (32, crate::TimeUnit::MILLISECOND) => { - DataType::Time32(TimeUnit::Millisecond) - } - (64, crate::TimeUnit::MICROSECOND) => { - DataType::Time64(TimeUnit::Microsecond) - } - (64, crate::TimeUnit::NANOSECOND) => { - DataType::Time64(TimeUnit::Nanosecond) - } + (32, crate::TimeUnit::MILLISECOND) => DataType::Time32(TimeUnit::Millisecond), + (64, crate::TimeUnit::MICROSECOND) => DataType::Time64(TimeUnit::Microsecond), + (64, crate::TimeUnit::NANOSECOND) => DataType::Time64(TimeUnit::Nanosecond), z => panic!( "Time type with bit width of {} and unit of {:?} not supported", z.0, z.1 @@ -296,30 +283,22 @@ pub(crate) fn get_data_type(field: crate::Field, may_be_dictionary: bool) -> 
Dat let timestamp = field.type_as_timestamp().unwrap(); let timezone: Option<_> = timestamp.timezone().map(|tz| tz.into()); match timestamp.unit() { - crate::TimeUnit::SECOND => { - DataType::Timestamp(TimeUnit::Second, timezone) - } + crate::TimeUnit::SECOND => DataType::Timestamp(TimeUnit::Second, timezone), crate::TimeUnit::MILLISECOND => { DataType::Timestamp(TimeUnit::Millisecond, timezone) } crate::TimeUnit::MICROSECOND => { DataType::Timestamp(TimeUnit::Microsecond, timezone) } - crate::TimeUnit::NANOSECOND => { - DataType::Timestamp(TimeUnit::Nanosecond, timezone) - } + crate::TimeUnit::NANOSECOND => DataType::Timestamp(TimeUnit::Nanosecond, timezone), z => panic!("Timestamp type with unit of {z:?} not supported"), } } crate::Type::Interval => { let interval = field.type_as_interval().unwrap(); match interval.unit() { - crate::IntervalUnit::YEAR_MONTH => { - DataType::Interval(IntervalUnit::YearMonth) - } - crate::IntervalUnit::DAY_TIME => { - DataType::Interval(IntervalUnit::DayTime) - } + crate::IntervalUnit::YEAR_MONTH => DataType::Interval(IntervalUnit::YearMonth), + crate::IntervalUnit::DAY_TIME => DataType::Interval(IntervalUnit::DayTime), crate::IntervalUnit::MONTH_DAY_NANO => { DataType::Interval(IntervalUnit::MonthDayNano) } @@ -775,8 +754,8 @@ pub(crate) fn get_fb_field_type<'a>( UnionMode::Dense => crate::UnionMode::Dense, }; - let fbb_type_ids = fbb - .create_vector(&fields.iter().map(|(t, _)| t as i32).collect::>()); + let fbb_type_ids = + fbb.create_vector(&fields.iter().map(|(t, _)| t as i32).collect::>()); let mut builder = crate::UnionBuilder::new(fbb); builder.add_mode(union_mode); builder.add_typeIds(fbb_type_ids); @@ -872,10 +851,7 @@ mod tests { ), Field::new( "timestamp[us]", - DataType::Timestamp( - TimeUnit::Microsecond, - Some("Africa/Johannesburg".into()), - ), + DataType::Timestamp(TimeUnit::Microsecond, Some("Africa/Johannesburg".into())), false, ), Field::new( @@ -900,11 +876,7 @@ mod tests { ), Field::new("utf8", 
DataType::Utf8, false), Field::new("binary", DataType::Binary, false), - Field::new_list( - "list[u8]", - Field::new("item", DataType::UInt8, false), - true, - ), + Field::new_list("list[u8]", Field::new("item", DataType::UInt8, false), true), Field::new_list( "list[struct]", Field::new_struct( @@ -1013,20 +985,14 @@ mod tests { ), Field::new_dict( "dictionary", - DataType::Dictionary( - Box::new(DataType::Int32), - Box::new(DataType::Utf8), - ), + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), true, 123, true, ), Field::new_dict( "dictionary", - DataType::Dictionary( - Box::new(DataType::UInt8), - Box::new(DataType::UInt32), - ), + DataType::Dictionary(Box::new(DataType::UInt8), Box::new(DataType::UInt32)), true, 123, true, @@ -1056,20 +1022,18 @@ mod tests { // # stripping continuation & length prefix & suffix bytes to get only schema bytes // [x for x in sink.getvalue().to_pybytes()][8:-8] let bytes: Vec = vec![ - 16, 0, 0, 0, 0, 0, 10, 0, 12, 0, 6, 0, 5, 0, 8, 0, 10, 0, 0, 0, 0, 1, 4, 0, - 12, 0, 0, 0, 8, 0, 8, 0, 0, 0, 4, 0, 8, 0, 0, 0, 4, 0, 0, 0, 1, 0, 0, 0, 20, - 0, 0, 0, 16, 0, 20, 0, 8, 0, 0, 0, 7, 0, 12, 0, 0, 0, 16, 0, 16, 0, 0, 0, 0, - 0, 0, 2, 16, 0, 0, 0, 32, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 6, 0, 0, 0, 102, - 105, 101, 108, 100, 49, 0, 0, 0, 0, 6, 0, 8, 0, 4, 0, 6, 0, 0, 0, 32, 0, 0, - 0, + 16, 0, 0, 0, 0, 0, 10, 0, 12, 0, 6, 0, 5, 0, 8, 0, 10, 0, 0, 0, 0, 1, 4, 0, 12, 0, 0, + 0, 8, 0, 8, 0, 0, 0, 4, 0, 8, 0, 0, 0, 4, 0, 0, 0, 1, 0, 0, 0, 20, 0, 0, 0, 16, 0, 20, + 0, 8, 0, 0, 0, 7, 0, 12, 0, 0, 0, 16, 0, 16, 0, 0, 0, 0, 0, 0, 2, 16, 0, 0, 0, 32, 0, + 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 6, 0, 0, 0, 102, 105, 101, 108, 100, 49, 0, 0, 0, 0, 6, + 0, 8, 0, 4, 0, 6, 0, 0, 0, 32, 0, 0, 0, ]; let ipc = crate::root_as_message(&bytes).unwrap(); let schema = ipc.header_as_schema().unwrap(); // generate same message with Rust let data_gen = crate::writer::IpcDataGenerator::default(); - let arrow_schema = - 
Schema::new(vec![Field::new("field1", DataType::UInt32, false)]); + let arrow_schema = Schema::new(vec![Field::new("field1", DataType::UInt32, false)]); let bytes = data_gen .schema_to_bytes(&arrow_schema, &crate::writer::IpcWriteOptions::default()) .ipc_message; diff --git a/arrow-ipc/src/gen/File.rs b/arrow-ipc/src/gen/File.rs index 0e9427813788..c0c2fb183237 100644 --- a/arrow-ipc/src/gen/File.rs +++ b/arrow-ipc/src/gen/File.rs @@ -61,10 +61,7 @@ impl<'b> flatbuffers::Push for Block { type Output = Block; #[inline] unsafe fn push(&self, dst: &mut [u8], _written_len: usize) { - let src = ::core::slice::from_raw_parts( - self as *const Block as *const u8, - Self::size(), - ); + let src = ::core::slice::from_raw_parts(self as *const Block as *const u8, Self::size()); dst.copy_from_slice(src); } } @@ -307,11 +304,7 @@ impl flatbuffers::Verifiable for Footer<'_> { use flatbuffers::Verifiable; v.visit_table(pos)? .visit_field::("version", Self::VT_VERSION, false)? - .visit_field::>( - "schema", - Self::VT_SCHEMA, - false, - )? + .visit_field::>("schema", Self::VT_SCHEMA, false)? 
.visit_field::>>( "dictionaries", Self::VT_DICTIONARIES, @@ -335,9 +328,7 @@ pub struct FooterArgs<'a> { pub dictionaries: Option>>, pub recordBatches: Option>>, pub custom_metadata: Option< - flatbuffers::WIPOffset< - flatbuffers::Vector<'a, flatbuffers::ForwardsUOffset>>, - >, + flatbuffers::WIPOffset>>>, >, } impl<'a> Default for FooterArgs<'a> { @@ -360,39 +351,29 @@ pub struct FooterBuilder<'a: 'b, 'b> { impl<'a: 'b, 'b> FooterBuilder<'a, 'b> { #[inline] pub fn add_version(&mut self, version: MetadataVersion) { - self.fbb_.push_slot::( - Footer::VT_VERSION, - version, - MetadataVersion::V1, - ); + self.fbb_ + .push_slot::(Footer::VT_VERSION, version, MetadataVersion::V1); } #[inline] pub fn add_schema(&mut self, schema: flatbuffers::WIPOffset>) { self.fbb_ - .push_slot_always::>( - Footer::VT_SCHEMA, - schema, - ); + .push_slot_always::>(Footer::VT_SCHEMA, schema); } #[inline] pub fn add_dictionaries( &mut self, dictionaries: flatbuffers::WIPOffset>, ) { - self.fbb_.push_slot_always::>( - Footer::VT_DICTIONARIES, - dictionaries, - ); + self.fbb_ + .push_slot_always::>(Footer::VT_DICTIONARIES, dictionaries); } #[inline] pub fn add_recordBatches( &mut self, recordBatches: flatbuffers::WIPOffset>, ) { - self.fbb_.push_slot_always::>( - Footer::VT_RECORDBATCHES, - recordBatches, - ); + self.fbb_ + .push_slot_always::>(Footer::VT_RECORDBATCHES, recordBatches); } #[inline] pub fn add_custom_metadata( @@ -407,9 +388,7 @@ impl<'a: 'b, 'b> FooterBuilder<'a, 'b> { ); } #[inline] - pub fn new( - _fbb: &'b mut flatbuffers::FlatBufferBuilder<'a>, - ) -> FooterBuilder<'a, 'b> { + pub fn new(_fbb: &'b mut flatbuffers::FlatBufferBuilder<'a>) -> FooterBuilder<'a, 'b> { let start = _fbb.start_table(); FooterBuilder { fbb_: _fbb, @@ -451,9 +430,7 @@ pub fn root_as_footer(buf: &[u8]) -> Result Result { +pub fn size_prefixed_root_as_footer(buf: &[u8]) -> Result { flatbuffers::size_prefixed_root::