From 3c3530849ddc13ccb69feade4cd9cef9523e4d97 Mon Sep 17 00:00:00 2001 From: Brent Gardner Date: Tue, 3 Dec 2024 10:10:59 -0700 Subject: [PATCH] squash --- .github/workflows/object_store.yml | 1 + object_store/src/aws/client.rs | 32 ++++++++++++---- object_store/src/aws/mod.rs | 60 ++++++++++++++++++++++++++++++ object_store/src/client/s3.rs | 27 ++++++++++++-- 4 files changed, 109 insertions(+), 11 deletions(-) diff --git a/.github/workflows/object_store.yml b/.github/workflows/object_store.yml index bdbfc0bec4bb..2e84bf050f0e 100644 --- a/.github/workflows/object_store.yml +++ b/.github/workflows/object_store.yml @@ -141,6 +141,7 @@ jobs: echo "LOCALSTACK_CONTAINER=$(docker run -d -p 4566:4566 localstack/localstack:3.3.0)" >> $GITHUB_ENV echo "EC2_METADATA_CONTAINER=$(docker run -d -p 1338:1338 amazon/amazon-ec2-metadata-mock:v1.9.2 --imdsv2)" >> $GITHUB_ENV aws --endpoint-url=http://localhost:4566 s3 mb s3://test-bucket + aws --endpoint-url=http://localhost:4566 s3api create-bucket --bucket test-object-lock --object-lock-enabled-for-bucket aws --endpoint-url=http://localhost:4566 dynamodb create-table --table-name test-table --key-schema AttributeName=path,KeyType=HASH AttributeName=etag,KeyType=RANGE --attribute-definitions AttributeName=path,AttributeType=S AttributeName=etag,AttributeType=S --provisioned-throughput ReadCapacityUnits=5,WriteCapacityUnits=5 KMS_KEY=$(aws --endpoint-url=http://localhost:4566 kms create-key --description "test key") diff --git a/object_store/src/aws/client.rs b/object_store/src/aws/client.rs index ab4da86f504b..faa9d36fd6a9 100644 --- a/object_store/src/aws/client.rs +++ b/object_store/src/aws/client.rs @@ -29,7 +29,7 @@ use crate::client::list::ListClient; use crate::client::retry::RetryExt; use crate::client::s3::{ CompleteMultipartUpload, CompleteMultipartUploadResult, InitiateMultipartUploadResult, - ListResponse, + ListResponse, PartMetadata, }; use crate::client::GetOptionsExt; use crate::multipart::PartId; @@ -62,6 +62,7 @@ use std::sync::Arc; const VERSION_HEADER: &str = "x-amz-version-id"; const SHA256_CHECKSUM: &str = "x-amz-checksum-sha256"; const USER_DEFINED_METADATA_HEADER_PREFIX: &str = "x-amz-meta-"; +const ALGORITHM: &str = "x-amz-checksum-algorithm"; /// A specialized `Error` for object store-related errors #[derive(Debug, Snafu)] @@ -349,10 +350,9 @@ impl<'a> Request<'a> { let payload_sha256 = sha256.finish(); if let Some(Checksum::SHA256) = self.config.checksum { - self.builder = self.builder.header( - "x-amz-checksum-sha256", - BASE64_STANDARD.encode(payload_sha256), - ); + self.builder = self + .builder + .header(SHA256_CHECKSUM, BASE64_STANDARD.encode(payload_sha256)); } self.payload_sha256 = Some(payload_sha256); } @@ -534,8 +534,11 @@ impl S3Client { location: &Path, opts: PutMultipartOpts, ) -> Result { - let response = self - .request(Method::POST, location) + let mut reqquest = self.request(Method::POST, location); + if let Some(algorithm) = self.config.checksum { + reqquest = reqquest.header(ALGORITHM, &algorithm.to_string().to_uppercase()); + } + let response = reqquest .query(&[("uploads", "")]) .with_encryption_headers() .with_attributes(opts.attributes) @@ -569,8 +572,21 @@ impl S3Client { .idempotent(true) .send() .await?; + let checksum = response + .headers() + .get(SHA256_CHECKSUM) + .and_then(|v| v.to_str().ok()) + .map(|v| v.to_string()); + + let e_tag = get_etag(response.headers()).context(MetadataSnafu)?; + + let content_id = if self.config.checksum == Some(Checksum::SHA256) { + let meta = PartMetadata { e_tag, checksum }; + quick_xml::se::to_string(&meta).unwrap() + } else { + e_tag + }; - let content_id = get_etag(response.headers()).context(MetadataSnafu)?; Ok(PartId { content_id }) } diff --git a/object_store/src/aws/mod.rs b/object_store/src/aws/mod.rs index f5204a5365ed..caf42150cd29 100644 --- a/object_store/src/aws/mod.rs +++ b/object_store/src/aws/mod.rs @@ -416,6 +416,66 @@ mod tests { const NON_EXISTENT_NAME: &str = "nonexistentname"; + #[tokio::test] + async fn write_multipart_file_with_signature() { + maybe_skip_integration!(); + + let store = AmazonS3Builder::from_env() + .with_checksum_algorithm(Checksum::SHA256) + .build() + .unwrap(); + + let str = "test.bin"; + let path = Path::parse(str).unwrap(); + let opts = PutMultipartOpts::default(); + let mut upload = store.put_multipart_opts(&path, opts).await.unwrap(); + + upload + .put_part(PutPayload::from(vec![0u8; 10_000_000])) + .await + .unwrap(); + upload + .put_part(PutPayload::from(vec![0u8; 5_000_000])) + .await + .unwrap(); + + let res = upload.complete().await.unwrap(); + assert!(res.e_tag.is_some(), "Should have valid etag"); + + store.delete(&path).await.unwrap(); + } + + #[tokio::test] + async fn write_multipart_file_with_signature_object_lock() { + maybe_skip_integration!(); + + let bucket = "test-object-lock"; + let store = AmazonS3Builder::from_env() + .with_bucket_name(bucket) + .with_checksum_algorithm(Checksum::SHA256) + .build() + .unwrap(); + + let str = "test.bin"; + let path = Path::parse(str).unwrap(); + let opts = PutMultipartOpts::default(); + let mut upload = store.put_multipart_opts(&path, opts).await.unwrap(); + + upload + .put_part(PutPayload::from(vec![0u8; 10_000_000])) + .await + .unwrap(); + upload + .put_part(PutPayload::from(vec![0u8; 5_000_000])) + .await + .unwrap(); + + let res = upload.complete().await.unwrap(); + assert!(res.e_tag.is_some(), "Should have valid etag"); + + store.delete(&path).await.unwrap(); + } + #[tokio::test] async fn s3_test() { maybe_skip_integration!(); diff --git a/object_store/src/client/s3.rs b/object_store/src/client/s3.rs index 61237dc4beab..e091086aabb3 100644 --- a/object_store/src/client/s3.rs +++ b/object_store/src/client/s3.rs @@ -98,14 +98,32 @@ pub struct CompleteMultipartUpload { pub part: Vec, } +#[derive(Serialize, Deserialize)] +pub(crate) struct PartMetadata { + pub e_tag: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub checksum: Option, +} + impl From> for CompleteMultipartUpload { fn from(value: Vec) -> Self { let part = value .into_iter() .enumerate() - .map(|(part_number, part)| MultipartPart { - e_tag: part.content_id, - part_number: part_number + 1, + .map(|(part_idx, part)| { + let md = match quick_xml::de::from_str::(&part.content_id) { + Ok(md) => md, + // fallback to old way + Err(_) => PartMetadata { + e_tag: part.content_id.clone(), + checksum: None, + }, + }; + MultipartPart { + e_tag: md.e_tag, + part_number: part_idx + 1, + checksum_sha256: md.checksum, + } }) .collect(); Self { part } @@ -118,6 +136,9 @@ pub struct MultipartPart { pub e_tag: String, #[serde(rename = "PartNumber")] pub part_number: usize, + #[serde(rename = "ChecksumSHA256")] + #[serde(skip_serializing_if = "Option::is_none")] + pub checksum_sha256: Option, } #[derive(Debug, Deserialize)]