diff --git a/wrappers/transit/transit.go b/wrappers/transit/transit.go index fdef5aa..bd9c44b 100644 --- a/wrappers/transit/transit.go +++ b/wrappers/transit/transit.go @@ -80,8 +80,8 @@ func (s *Wrapper) KeyId(_ context.Context) (string, error) { } // Encrypt is used to encrypt using Vault's Transit engine -func (s *Wrapper) Encrypt(_ context.Context, plaintext []byte, _ ...wrapping.Option) (*wrapping.BlobInfo, error) { - ciphertext, err := s.client.Encrypt(plaintext) +func (s *Wrapper) Encrypt(ctx context.Context, plaintext []byte, _ ...wrapping.Option) (*wrapping.BlobInfo, error) { + ciphertext, err := s.client.Encrypt(ctx, plaintext) if err != nil { return nil, err } @@ -103,8 +103,8 @@ func (s *Wrapper) Encrypt(_ context.Context, plaintext []byte, _ ...wrapping.Opt } // Decrypt is used to decrypt the ciphertext -func (s *Wrapper) Decrypt(_ context.Context, in *wrapping.BlobInfo, _ ...wrapping.Option) ([]byte, error) { - plaintext, err := s.client.Decrypt(in.Ciphertext) +func (s *Wrapper) Decrypt(ctx context.Context, in *wrapping.BlobInfo, _ ...wrapping.Option) ([]byte, error) { + plaintext, err := s.client.Decrypt(ctx, in.Ciphertext) if err != nil { return nil, err } diff --git a/wrappers/transit/transit_client.go b/wrappers/transit/transit_client.go index 1192083..c14b3a3 100644 --- a/wrappers/transit/transit_client.go +++ b/wrappers/transit/transit_client.go @@ -4,6 +4,7 @@ package transit import ( + "context" "encoding/base64" "errors" "fmt" @@ -29,8 +30,8 @@ const ( type transitClientEncryptor interface { Close() - Encrypt(plaintext []byte) (ciphertext []byte, err error) - Decrypt(ciphertext []byte) (plaintext []byte, err error) + Encrypt(ctx context.Context, plaintext []byte) (ciphertext []byte, err error) + Decrypt(ctx context.Context, ciphertext []byte) (plaintext []byte, err error) } type TransitClient struct { @@ -197,10 +198,10 @@ func (c *TransitClient) Close() { } } -func (c *TransitClient) Encrypt(plaintext []byte) ([]byte, error) { +func (c *TransitClient) Encrypt(ctx context.Context, plaintext []byte) ([]byte, error) { encPlaintext := base64.StdEncoding.EncodeToString(plaintext) path := path.Join(c.mountPath, "encrypt", c.keyName) - secret, err := c.client.Logical().Write(path, map[string]interface{}{ + secret, err := c.client.Logical().WriteWithContext(ctx, path, map[string]interface{}{ "plaintext": encPlaintext, }) if err != nil { @@ -224,9 +225,9 @@ func (c *TransitClient) Encrypt(plaintext []byte) ([]byte, error) { return []byte(ctStr), nil } -func (c *TransitClient) Decrypt(ciphertext []byte) ([]byte, error) { +func (c *TransitClient) Decrypt(ctx context.Context, ciphertext []byte) ([]byte, error) { path := path.Join(c.mountPath, "decrypt", c.keyName) - secret, err := c.client.Logical().Write(path, map[string]interface{}{ + secret, err := c.client.Logical().WriteWithContext(ctx, path, map[string]interface{}{ "ciphertext": string(ciphertext), }) if err != nil { diff --git a/wrappers/transit/transit_test.go b/wrappers/transit/transit_test.go index ed95f8b..96f14dd 100644 --- a/wrappers/transit/transit_test.go +++ b/wrappers/transit/transit_test.go @@ -42,8 +42,8 @@ func newTestTransitClient(keyID string) *testTransitClient { func (m *testTransitClient) Close() {} -func (m *testTransitClient) Encrypt(plaintext []byte) ([]byte, error) { - v, err := m.wrap.Encrypt(context.Background(), plaintext, nil) +func (m *testTransitClient) Encrypt(ctx context.Context, plaintext []byte) ([]byte, error) { + v, err := m.wrap.Encrypt(ctx, plaintext, nil) if err != nil { return nil, err } @@ -51,7 +51,7 @@ func (m *testTransitClient) Encrypt(plaintext []byte) ([]byte, error) { return []byte(fmt.Sprintf("v1:%s:%s", m.keyID, string(v.Ciphertext))), nil } -func (m *testTransitClient) Decrypt(ciphertext []byte) ([]byte, error) { +func (m *testTransitClient) Decrypt(ctx context.Context, ciphertext []byte) ([]byte, error) { splitKey := strings.Split(string(ciphertext), ":") if len(splitKey) != 3 { return nil, errors.New("invalid ciphertext returned") @@ -60,7 +60,7 @@ func (m *testTransitClient) Decrypt(ciphertext []byte) ([]byte, error) { data := &wrapping.BlobInfo{ Ciphertext: []byte(splitKey[2]), } - v, err := m.wrap.Decrypt(context.Background(), data, nil) + v, err := m.wrap.Decrypt(ctx, data, nil) if err != nil { return nil, err } @@ -343,3 +343,51 @@ func TestSetConfig(t *testing.T) { }) } } + +func TestContextCancellation(t *testing.T) { + t.Parallel() + t.Run("Encrypt stops when the context is cancelled", func(t *testing.T) { + t.Parallel() + _, require := assert.New(t), require.New(t) + w := NewWrapper() + _, err := w.SetConfig( + context.Background(), + WithAddress(testWithAddress), + WithToken(testWithToken), + WithMountPath(testWithMountPath), + WithKeyName(testWithKeyName), + WithNamespace(testWithNamespace), + WithKeyIdPrefix("test/"), + ) + require.NoError(err) + testPt := []byte("test-plaintext") + canceledCtx, cancel := context.WithCancel(context.Background()) + cancel() + _, err = w.Encrypt(canceledCtx, testPt) + require.Error(err) + require.ErrorIs(err, context.Canceled) + }) + t.Run("Decrypt stops when the context is cancelled", func(t *testing.T) { + t.Parallel() + _, require := assert.New(t), require.New(t) + w := NewWrapper() + _, err := w.SetConfig( + context.Background(), + WithAddress(testWithAddress), + WithToken(testWithToken), + WithMountPath(testWithMountPath), + WithKeyName(testWithKeyName), + WithNamespace(testWithNamespace), + WithKeyIdPrefix("test/"), + ) + require.NoError(err) + testPt := []byte("test-plaintext") + blob, err := w.Encrypt(context.Background(), testPt) + require.NoError(err) + canceledCtx, cancel := context.WithCancel(context.Background()) + cancel() + _, err = w.Decrypt(canceledCtx, blob) + require.Error(err) + require.ErrorIs(err, context.Canceled) + }) +}