Skip to content

Commit

Permalink
wrappers/transit: support context cancelation (#259)
Browse files Browse the repository at this point in the history
This makes the transit client respect context cancelation,
which is a critical feature of any I/O API.
  • Loading branch information
johanbrandhorst authored May 10, 2024
1 parent 933ad6c commit 05c77e8
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 14 deletions.
8 changes: 4 additions & 4 deletions wrappers/transit/transit.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,8 @@ func (s *Wrapper) KeyId(_ context.Context) (string, error) {
}

// Encrypt is used to encrypt using Vault's Transit engine
func (s *Wrapper) Encrypt(_ context.Context, plaintext []byte, _ ...wrapping.Option) (*wrapping.BlobInfo, error) {
ciphertext, err := s.client.Encrypt(plaintext)
func (s *Wrapper) Encrypt(ctx context.Context, plaintext []byte, _ ...wrapping.Option) (*wrapping.BlobInfo, error) {
ciphertext, err := s.client.Encrypt(ctx, plaintext)
if err != nil {
return nil, err
}
Expand All @@ -103,8 +103,8 @@ func (s *Wrapper) Encrypt(_ context.Context, plaintext []byte, _ ...wrapping.Opt
}

// Decrypt is used to decrypt the ciphertext
func (s *Wrapper) Decrypt(_ context.Context, in *wrapping.BlobInfo, _ ...wrapping.Option) ([]byte, error) {
plaintext, err := s.client.Decrypt(in.Ciphertext)
func (s *Wrapper) Decrypt(ctx context.Context, in *wrapping.BlobInfo, _ ...wrapping.Option) ([]byte, error) {
plaintext, err := s.client.Decrypt(ctx, in.Ciphertext)
if err != nil {
return nil, err
}
Expand Down
13 changes: 7 additions & 6 deletions wrappers/transit/transit_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package transit

import (
"context"
"encoding/base64"
"errors"
"fmt"
Expand All @@ -29,8 +30,8 @@ const (

type transitClientEncryptor interface {
Close()
Encrypt(plaintext []byte) (ciphertext []byte, err error)
Decrypt(ciphertext []byte) (plaintext []byte, err error)
Encrypt(ctx context.Context, plaintext []byte) (ciphertext []byte, err error)
Decrypt(ctx context.Context, ciphertext []byte) (plaintext []byte, err error)
}

type TransitClient struct {
Expand Down Expand Up @@ -197,10 +198,10 @@ func (c *TransitClient) Close() {
}
}

func (c *TransitClient) Encrypt(plaintext []byte) ([]byte, error) {
func (c *TransitClient) Encrypt(ctx context.Context, plaintext []byte) ([]byte, error) {
encPlaintext := base64.StdEncoding.EncodeToString(plaintext)
path := path.Join(c.mountPath, "encrypt", c.keyName)
secret, err := c.client.Logical().Write(path, map[string]interface{}{
secret, err := c.client.Logical().WriteWithContext(ctx, path, map[string]interface{}{
"plaintext": encPlaintext,
})
if err != nil {
Expand All @@ -224,9 +225,9 @@ func (c *TransitClient) Encrypt(plaintext []byte) ([]byte, error) {
return []byte(ctStr), nil
}

func (c *TransitClient) Decrypt(ciphertext []byte) ([]byte, error) {
func (c *TransitClient) Decrypt(ctx context.Context, ciphertext []byte) ([]byte, error) {
path := path.Join(c.mountPath, "decrypt", c.keyName)
secret, err := c.client.Logical().Write(path, map[string]interface{}{
secret, err := c.client.Logical().WriteWithContext(ctx, path, map[string]interface{}{
"ciphertext": string(ciphertext),
})
if err != nil {
Expand Down
56 changes: 52 additions & 4 deletions wrappers/transit/transit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,16 +42,16 @@ func newTestTransitClient(keyID string) *testTransitClient {

func (m *testTransitClient) Close() {}

func (m *testTransitClient) Encrypt(plaintext []byte) ([]byte, error) {
v, err := m.wrap.Encrypt(context.Background(), plaintext, nil)
func (m *testTransitClient) Encrypt(ctx context.Context, plaintext []byte) ([]byte, error) {
v, err := m.wrap.Encrypt(ctx, plaintext, nil)
if err != nil {
return nil, err
}

return []byte(fmt.Sprintf("v1:%s:%s", m.keyID, string(v.Ciphertext))), nil
}

func (m *testTransitClient) Decrypt(ciphertext []byte) ([]byte, error) {
func (m *testTransitClient) Decrypt(ctx context.Context, ciphertext []byte) ([]byte, error) {
splitKey := strings.Split(string(ciphertext), ":")
if len(splitKey) != 3 {
return nil, errors.New("invalid ciphertext returned")
Expand All @@ -60,7 +60,7 @@ func (m *testTransitClient) Decrypt(ciphertext []byte) ([]byte, error) {
data := &wrapping.BlobInfo{
Ciphertext: []byte(splitKey[2]),
}
v, err := m.wrap.Decrypt(context.Background(), data, nil)
v, err := m.wrap.Decrypt(ctx, data, nil)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -343,3 +343,51 @@ func TestSetConfig(t *testing.T) {
})
}
}

func TestContextCancellation(t *testing.T) {
t.Parallel()
t.Run("Encrypt stops when the context is cancelled", func(t *testing.T) {
t.Parallel()
_, require := assert.New(t), require.New(t)
w := NewWrapper()
_, err := w.SetConfig(
context.Background(),
WithAddress(testWithAddress),
WithToken(testWithToken),
WithMountPath(testWithMountPath),
WithKeyName(testWithKeyName),
WithNamespace(testWithNamespace),
WithKeyIdPrefix("test/"),
)
require.NoError(err)
testPt := []byte("test-plaintext")
canceledCtx, cancel := context.WithCancel(context.Background())
cancel()
_, err = w.Encrypt(canceledCtx, testPt)
require.Error(err)
require.ErrorIs(err, context.Canceled)
})
t.Run("Decrypt stops when the context is cancelled", func(t *testing.T) {
t.Parallel()
_, require := assert.New(t), require.New(t)
w := NewWrapper()
_, err := w.SetConfig(
context.Background(),
WithAddress(testWithAddress),
WithToken(testWithToken),
WithMountPath(testWithMountPath),
WithKeyName(testWithKeyName),
WithNamespace(testWithNamespace),
WithKeyIdPrefix("test/"),
)
require.NoError(err)
testPt := []byte("test-plaintext")
blob, err := w.Encrypt(context.Background(), testPt)
require.NoError(err)
canceledCtx, cancel := context.WithCancel(context.Background())
cancel()
_, err = w.Decrypt(canceledCtx, blob)
require.Error(err)
require.ErrorIs(err, context.Canceled)
})
}

0 comments on commit 05c77e8

Please sign in to comment.