diff --git a/go.mod b/go.mod index b5e35d0..d95e908 100644 --- a/go.mod +++ b/go.mod @@ -58,4 +58,4 @@ require ( google.golang.org/grpc v1.46.0 // indirect ) -replace github.com/flyteorg/stow => github.com/ddl-rliu/stow v0.0.16 +replace github.com/flyteorg/stow => github.com/ddl-rliu/stow v0.0.17 diff --git a/s3/config.go b/s3/config.go index b0e67c2..88b9c50 100644 --- a/s3/config.go +++ b/s3/config.go @@ -51,7 +51,7 @@ const ( // This feature is useful for s3-compatible blob stores -- ie minio. ConfigV2Signing = "v2_signing" - // ConfigV2Signing is an optional config value for extra arguments passed to S3 upload, + // ConfigExtraArgs is an optional config value for extra arguments passed to S3 upload, // a string representing a JSON object of key/value pairs // This feature is useful for setting server-side encryption headers. ConfigExtraArgs = "extra_args" diff --git a/s3/container.go b/s3/container.go index 6a11e4d..0dd0946 100644 --- a/s3/container.go +++ b/s3/container.go @@ -56,22 +56,19 @@ func (c *container) PreSignRequest(ctx context.Context, clientMethod stow.Client Key: aws.String(id), ContentMD5: contentMD5, } - log.Printf("bucket: %s // %s", c.name, id) - log.Printf("extra args: %s", c.extraArgs) // First, try to set SSE using stow.config var extraArgs S3ExtraArgs json.Unmarshal([]byte(c.extraArgs), &extraArgs) - log.Printf("extra args: %s // %s", extraArgs.ServerSideEncryption, extraArgs.SSEKMSKeyId) if extraArgs.ServerSideEncryption == "" { // As backup, try to set SSE using s3.GetBucketEncryption if bucketEncrypted, sseAlgortihm, encryptionKey := getKmsMasterKeyId(c.client, c.name); bucketEncrypted { - log.Printf("sse: %s // %s", sseAlgortihm, encryptionKey) extraArgs.ServerSideEncryption, extraArgs.SSEKMSKeyId = sseAlgortihm, encryptionKey } } + // SSE info goes in headers, so that a valid signature is generated switch extraArgs.ServerSideEncryption { case s3.ServerSideEncryptionAes256: params.ServerSideEncryption = aws.String(extraArgs.ServerSideEncryption) @@ -83,6 +80,20 @@ func (c *container) PreSignRequest(ctx context.Context, clientMethod stow.Client } req, _ = c.client.PutObjectRequest(params) + q := req.HTTPRequest.URL.Query() + + // SSE info also goes in query string, so that the pre-signed URL is self-contained + // i.e. works without headers + switch extraArgs.ServerSideEncryption { + case s3.ServerSideEncryptionAes256: + q.Add("x-amz-server-side-encryption", extraArgs.ServerSideEncryption) + case s3.ServerSideEncryptionAwsKms: + q.Add("x-amz-server-side-encryption", extraArgs.ServerSideEncryption) + if extraArgs.SSEKMSKeyId != "" { + q.Add("x-amz-server-side-encryption-aws-kms-key-id", extraArgs.SSEKMSKeyId) + } + } + req.HTTPRequest.URL.RawQuery = q.Encode() default: return "", fmt.Errorf("unsupported client method [%v]", clientMethod.String()) } diff --git a/s3/stow_test.go b/s3/stow_test.go index 6e7a670..39f397b 100644 --- a/s3/stow_test.go +++ b/s3/stow_test.go @@ -74,6 +74,42 @@ func TestPreSignedURL(t *testing.T) { assert.NotEmpty(t, res) } +func TestPreSignedURLSSEBucket(t *testing.T) { + is := is.New(t) + accessKeyId := os.Getenv("S3ACCESSKEYID") + secretKey := os.Getenv("S3SECRETKEY") + region := os.Getenv("S3REGION") + + if accessKeyId == "" || secretKey == "" || region == "" { + t.Skip("skipping test because missing one or more of S3ACCESSKEYID S3SECRETKEY S3REGION") + } + + config := stow.ConfigMap{ + "access_key_id": accessKeyId, + "secret_key": secretKey, + "region": region, + "extra_args": "{\"ServerSideEncryption\": \"aws:kms\", \"SSEKMSKeyId\": \"kmsId\"}", + } + + location, err := stow.Dial("s3", config) + is.NoErr(err) + + container, err := location.Container("flyte-demo") + if err != nil { + t.Skip(err) + } + ctx := context.Background() + res, err := container.PreSignRequest(ctx, stow.ClientMethodPut, "blah/bloh/fileon", stow.PresignRequestParams{ + ExpiresIn: time.Hour, + }) + + is.NoErr(err) + t.Log(res) + assert.NotEmpty(t, res) + assert.Contains(t, res, "x-amz-server-side-encryption=aws%3Akms") + assert.Contains(t, res, "x-amz-server-side-encryption-aws-kms-key-id=kmsId") +} + func TestEtagCleanup(t *testing.T) { etagValue := "9c51403a2255f766891a1382288dece4" permutations := []string{