Skip to content

Commit

Permalink
review comments
Browse files Browse the repository at this point in the history
Signed-off-by: Nilekh Chaudhari <1626598+nilekhc@users.noreply.github.com>
  • Loading branch information
nilekhc committed May 8, 2023
1 parent f589a39 commit a209f44
Show file tree
Hide file tree
Showing 12 changed files with 73 additions and 69 deletions.
26 changes: 24 additions & 2 deletions cmd/server/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"syscall"
"time"

"github.com/Azure/kubernetes-kms/pkg/config"
"github.com/Azure/kubernetes-kms/pkg/metrics"
"github.com/Azure/kubernetes-kms/pkg/plugin"
"github.com/Azure/kubernetes-kms/pkg/utils"
Expand Down Expand Up @@ -92,6 +93,27 @@ func main() {
ConfigFilePath: *configFilePath,
}

azureConfig, err := config.GetAzureConfig(pluginConfig.ConfigFilePath)
if err != nil {
klog.ErrorS(err, "failed to get azure config")
os.Exit(1)
}

kvClient, err := plugin.NewKeyVaultClient(
azureConfig,
pluginConfig.KeyVaultName,
pluginConfig.KeyName,
pluginConfig.KeyVersion,
pluginConfig.ProxyMode,
pluginConfig.ProxyAddress,
pluginConfig.ProxyPort,
pluginConfig.ManagedHSM,
)
if err != nil {
klog.ErrorS(err, "failed to create key vault client")
os.Exit(1)
}

// Initialize and run the GRPC server
proto, addr, err := utils.ParseEndpoint(*listenAddr)
if err != nil {
Expand All @@ -116,15 +138,15 @@ func main() {
s := grpc.NewServer(opts...)

// register kms v1 server
kmsV1Server, err := plugin.NewKMSv1Server(pluginConfig)
kmsV1Server, err := plugin.NewKMSv1Server(kvClient)
if err != nil {
klog.ErrorS(err, "failed to create server")
os.Exit(1)
}
kmsv1.RegisterKeyManagementServiceServer(s, kmsV1Server)

// register kms v2 server
kmsV2Server, err := plugin.NewKMSv2Server(pluginConfig)
kmsV2Server, err := plugin.NewKMSv2Server(kvClient)
if err != nil {
klog.ErrorS(err, "failed to create kms V2 server")
os.Exit(1)
Expand Down
1 change: 0 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ require (
google.golang.org/grpc v1.54.0
gopkg.in/yaml.v3 v3.0.1
k8s.io/apimachinery v0.27.1
k8s.io/apiserver v0.25.8
k8s.io/component-base v0.25.8
k8s.io/klog/v2 v2.90.1
k8s.io/kms v0.27.1
Expand Down
2 changes: 0 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -800,8 +800,6 @@ honnef.co/go/tools v0.0.1-2020.1.3/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9
honnef.co/go/tools v0.0.1-2020.1.4/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k=
k8s.io/apimachinery v0.27.1 h1:EGuZiLI95UQQcClhanryclaQE6xjg1Bts6/L3cD7zyc=
k8s.io/apimachinery v0.27.1/go.mod h1:5ikh59fK3AJ287GUvpUsryoMFtH9zj/ARfWCo3AyXTM=
k8s.io/apiserver v0.25.8 h1:ZTYdLdouAu8D6h9QavMaQZiAV+EfWK87VGdOyb6RZMQ=
k8s.io/apiserver v0.25.8/go.mod h1:IJ1r0vqXxwa+3QbrxAHWqdmoGZnVDDMzWtIK9ju3maI=
k8s.io/component-base v0.25.8 h1:lQ5Ouw7lupdpXn5slRjAeHnlMK/aAEbPf9jjSWbOD3c=
k8s.io/component-base v0.25.8/go.mod h1:MkC9Lz4fXoGOgB2WhFBU4zjiviIEeJS3sVhTxX9vt6s=
k8s.io/klog/v2 v2.90.1 h1:m4bYOKall2MmOiRaR1J+We67Do7vm9KiQVlT96lnHUw=
Expand Down
5 changes: 3 additions & 2 deletions pkg/plugin/healthz.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,9 @@ func (h *HealthZ) ServeHTTP(w http.ResponseWriter, _ *http.Request) {
return
}

// v2 checks
uid := string(uuid.NewUUID())
// v2 checks.
// appending a string to UUID allows us to differentiate the UUIDs generated by us from those generated by the API server.
uid := "local-healthz-check-" + string(uuid.NewUUID())

v2EncryptResponse, err := h.KMSv2Server.Encrypt(
ctx,
Expand Down
30 changes: 15 additions & 15 deletions pkg/plugin/keyvault.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ const (
keyvaultRegionAnnotationKey = "x-ms-keyvault-region.azure.akv.io"
versionAnnotationKey = "version.azure.akv.io"
algorithmAnnotationKey = "algorithm.azure.akv.io"
keyIDAnnotationKey = "key-id.azure.akv.io"
dateAnnotationValue = "Date"
requestIDAnnotationValue = "X-Ms-Request-Id"
keyvaultRegionAnnotationValue = "X-Ms-Keyvault-Region"
Expand All @@ -46,10 +45,11 @@ const (
type Client interface {
Encrypt(ctx context.Context, cipher []byte) (*service.EncryptResponse, error)
Decrypt(ctx context.Context, plain []byte) ([]byte, error)
ValidateAnnotations(annotations map[string][]byte) error
ValidateAnnotations(annotations map[string][]byte, keyID string) error
}

type keyVaultClient struct {
// KeyVaultClient is a client for interacting with Keyvault.
type KeyVaultClient struct {
baseClient kv.BaseClient
config *config.AzureConfig
vaultName string
Expand All @@ -62,14 +62,14 @@ type keyVaultClient struct {
}

// NewKeyVaultClient returns a new key vault client to use for kms operations.
func newKeyVaultClient(
func NewKeyVaultClient(
config *config.AzureConfig,
vaultName, keyName, keyVersion string,
proxyMode bool,
proxyAddress string,
proxyPort int,
managedHSM bool,
) (*keyVaultClient, error) {
) (*KeyVaultClient, error) {
// Sanitize vaultName, keyName, keyVersion. (https://github.com/Azure/kubernetes-kms/issues/85)
vaultName = utils.SanitizeString(vaultName)
keyName = utils.SanitizeString(keyName)
Expand Down Expand Up @@ -115,22 +115,23 @@ func newKeyVaultClient(

klog.InfoS("using kms key for encrypt/decrypt", "vaultURL", *vaultURL, "keyName", keyName, "keyVersion", keyVersion)

client := &keyVaultClient{
client := &KeyVaultClient{
baseClient: kvClient,
config: config,
vaultName: vaultName,
keyName: keyName,
keyVersion: keyVersion,
vaultURL: *vaultURL,
azureEnvironment: env,
keyID: fmt.Sprintf("%x", sha256.Sum256([]byte(fmt.Sprintf("%s/keys/%s/%s", *vaultURL, keyName, keyVersion)))),
keyID: fmt.Sprintf("%x", sha256.Sum256([]byte(fmt.Sprintf("%skeys/%s/%s", *vaultURL, keyName, keyVersion)))),
algorithm: kv.RSA15,
}
return client, nil
}

func (kvc *keyVaultClient) Encrypt(ctx context.Context, cipher []byte) (*service.EncryptResponse, error) {
value := base64.RawURLEncoding.EncodeToString(cipher)
// Encrypt encrypts the given plain text using the keyvault key.
func (kvc *KeyVaultClient) Encrypt(ctx context.Context, plain []byte) (*service.EncryptResponse, error) {
value := base64.RawURLEncoding.EncodeToString(plain)

params := kv.KeyOperationsParameters{
Algorithm: kvc.algorithm,
Expand All @@ -147,7 +148,6 @@ func (kvc *keyVaultClient) Encrypt(ctx context.Context, cipher []byte) (*service
keyvaultRegionAnnotationKey: []byte(result.Header.Get(keyvaultRegionAnnotationValue)),
versionAnnotationKey: []byte(encryptionResponseVersion),
algorithmAnnotationKey: []byte(kvc.algorithm),
keyIDAnnotationKey: []byte(kvc.keyID),
}

return &service.EncryptResponse{
Expand All @@ -157,8 +157,9 @@ func (kvc *keyVaultClient) Encrypt(ctx context.Context, cipher []byte) (*service
}, nil
}

func (kvc *keyVaultClient) Decrypt(ctx context.Context, plain []byte) ([]byte, error) {
value := string(plain)
// Decrypt decrypts the given cipher text using the keyvault key.
func (kvc *KeyVaultClient) Decrypt(ctx context.Context, cipher []byte) ([]byte, error) {
value := string(cipher)

params := kv.KeyOperationsParameters{
Algorithm: kvc.algorithm,
Expand All @@ -177,15 +178,14 @@ func (kvc *keyVaultClient) Decrypt(ctx context.Context, plain []byte) ([]byte, e
}

// ValidateAnnotations validates following annotations before decryption:
// - Key ID.
// - Algorithm.
// - Version.
func (kvc *keyVaultClient) ValidateAnnotations(annotations map[string][]byte) error {
// It also validates keyID that the API server checks.
func (kvc *KeyVaultClient) ValidateAnnotations(annotations map[string][]byte, keyID string) error {
if len(annotations) == 0 {
return fmt.Errorf("invalid annotations, annotations cannot be empty")
}

keyID := string(annotations[keyIDAnnotationKey])
if keyID != kvc.keyID {
return fmt.Errorf(
"key id %s does not match expected key id %s used for encryption",
Expand Down
4 changes: 2 additions & 2 deletions pkg/plugin/keyvault_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ func TestNewKeyVaultClientError(t *testing.T) {

for _, test := range tests {
t.Run(test.desc, func(t *testing.T) {
if _, err := newKeyVaultClient(test.config, test.vaultName, test.keyName, test.keyVersion, test.proxyMode, test.proxyAddress, test.proxyPort, test.managedHSM); err == nil {
if _, err := NewKeyVaultClient(test.config, test.vaultName, test.keyName, test.keyVersion, test.proxyMode, test.proxyAddress, test.proxyPort, test.managedHSM); err == nil {
t.Fatalf("newKeyVaultClient() expected error, got nil")
}
})
Expand Down Expand Up @@ -131,7 +131,7 @@ func TestNewKeyVaultClient(t *testing.T) {

for _, test := range tests {
t.Run(test.desc, func(t *testing.T) {
kvClient, err := newKeyVaultClient(test.config, test.vaultName, test.keyName, test.keyVersion, test.proxyMode, test.proxyAddress, test.proxyPort, test.managedHSM)
kvClient, err := NewKeyVaultClient(test.config, test.vaultName, test.keyName, test.keyVersion, test.proxyMode, test.proxyAddress, test.proxyPort, test.managedHSM)
if err != nil {
t.Fatalf("newKeyVaultClient() failed with error: %v", err)
}
Expand Down
24 changes: 2 additions & 22 deletions pkg/plugin/kms_v2_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
"fmt"
"time"

"github.com/Azure/kubernetes-kms/pkg/config"
"github.com/Azure/kubernetes-kms/pkg/metrics"
"github.com/Azure/kubernetes-kms/pkg/version"

Expand All @@ -25,26 +24,7 @@ type KeyManagementServiceV2Server struct {
}

// NewKMSv2Server creates an instance of the KMS Service Server with v2 apis.
func NewKMSv2Server(pluginConfig *Config) (*KeyManagementServiceV2Server, error) {
cfg, err := config.GetAzureConfig(pluginConfig.ConfigFilePath)
if err != nil {
return nil, err
}

kvClient, err := newKeyVaultClient(
cfg,
pluginConfig.KeyVaultName,
pluginConfig.KeyName,
pluginConfig.KeyVersion,
pluginConfig.ProxyMode,
pluginConfig.ProxyAddress,
pluginConfig.ProxyPort,
pluginConfig.ManagedHSM,
)
if err != nil {
return nil, err
}

func NewKMSv2Server(kvClient *KeyVaultClient) (*KeyManagementServiceV2Server, error) {
return &KeyManagementServiceV2Server{
kvClient: kvClient,
reporter: metrics.NewStatsReporter(),
Expand Down Expand Up @@ -135,7 +115,7 @@ func (s *KeyManagementServiceV2Server) Decrypt(ctx context.Context, request *kms
}()

klog.V(2).InfoS("decrypt request started", "uid", request.Uid)
err = s.kvClient.ValidateAnnotations(request.Annotations)
err = s.kvClient.ValidateAnnotations(request.Annotations, request.KeyId)
if err != nil {
klog.ErrorS(err, "failed to decrypt", "uid", request.Uid)
return &kmsv2.DecryptResponse{}, err
Expand Down
6 changes: 1 addition & 5 deletions pkg/plugin/kms_v2_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,6 @@ func TestV2Decrypt(t *testing.T) {
output: []byte{},
err: fmt.Errorf("key id \"invalid-key-id\" does not match expected key id \"mock-key-id\" used for encryption"),
annotations: map[string][]byte{
keyIDAnnotationKey: []byte("invalid-key-id"),
algorithmAnnotationKey: []byte(keyvault.RSA15),
versionAnnotationKey: []byte("1"),
},
Expand All @@ -103,7 +102,6 @@ func TestV2Decrypt(t *testing.T) {
output: []byte{},
err: fmt.Errorf("algorithm \"insecure-algorithm\" does not match expected algorithm \"RSA1_5\" used for encryption"),
annotations: map[string][]byte{
keyIDAnnotationKey: []byte("mock-key-id"),
algorithmAnnotationKey: []byte("insecure-algorithm"),
versionAnnotationKey: []byte("1"),
},
Expand All @@ -114,7 +112,6 @@ func TestV2Decrypt(t *testing.T) {
output: []byte{},
err: fmt.Errorf("version \"10\" does not match expected version \"1\" used for encryption"),
annotations: map[string][]byte{
keyIDAnnotationKey: []byte("mock-key-id"),
algorithmAnnotationKey: []byte(keyvault.RSA15),
versionAnnotationKey: []byte("10"),
},
Expand All @@ -125,7 +122,6 @@ func TestV2Decrypt(t *testing.T) {
output: []byte{},
err: fmt.Errorf("failed to decrypt"),
annotations: map[string][]byte{
keyIDAnnotationKey: []byte("mock-key-id"),
algorithmAnnotationKey: []byte(keyvault.RSA15),
versionAnnotationKey: []byte("1"),
},
Expand All @@ -136,7 +132,6 @@ func TestV2Decrypt(t *testing.T) {
output: []byte("foo"),
err: nil,
annotations: map[string][]byte{
keyIDAnnotationKey: []byte("mock-key-id"),
algorithmAnnotationKey: []byte(keyvault.RSA15),
versionAnnotationKey: []byte("1"),
},
Expand All @@ -159,6 +154,7 @@ func TestV2Decrypt(t *testing.T) {
out, err := kmsV2Server.Decrypt(context.TODO(), &kmsv2.DecryptRequest{
Ciphertext: test.input,
Annotations: test.annotations,
KeyId: "mock-key-id",
})
if err != nil && (err.Error() != test.err.Error()) {
t.Fatalf("expected err: %v, got: %v", test.err, err)
Expand Down
4 changes: 2 additions & 2 deletions pkg/plugin/mock_keyvault/keyvault_mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,13 +59,13 @@ func (kvc *KeyVaultClient) SetDecryptResponse(decryptOut []byte, err error) {
kvc.decryptErr = err
}

func (kvc *KeyVaultClient) ValidateAnnotations(annotations map[string][]byte) error {
func (kvc *KeyVaultClient) ValidateAnnotations(annotations map[string][]byte, keyID string) error {
if len(annotations) == 0 {
return fmt.Errorf("invalid annotations, annotations cannot be empty")
}

// validate key id
if string(annotations["key-id.azure.akv.io"]) != kvc.KeyID {
if keyID != kvc.KeyID {
return fmt.Errorf(
"key id %q does not match expected key id %q used for encryption",
string(annotations["key-id.azure.akv.io"]),
Expand Down
11 changes: 1 addition & 10 deletions pkg/plugin/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (
"context"
"time"

"github.com/Azure/kubernetes-kms/pkg/config"
"github.com/Azure/kubernetes-kms/pkg/metrics"
"github.com/Azure/kubernetes-kms/pkg/version"

Expand All @@ -36,15 +35,7 @@ type Config struct {
}

// NewKMSv1Server creates an instance of the KMS Service Server.
func NewKMSv1Server(pc *Config) (*KeyManagementServiceServer, error) {
cfg, err := config.GetAzureConfig(pc.ConfigFilePath)
if err != nil {
return nil, err
}
kvClient, err := newKeyVaultClient(cfg, pc.KeyVaultName, pc.KeyName, pc.KeyVersion, pc.ProxyMode, pc.ProxyAddress, pc.ProxyPort, pc.ManagedHSM)
if err != nil {
return nil, err
}
func NewKMSv1Server(kvClient *KeyVaultClient) (*KeyManagementServiceServer, error) {
return &KeyManagementServiceServer{
kvClient: kvClient,
reporter: metrics.NewStatsReporter(),
Expand Down
12 changes: 6 additions & 6 deletions tests/client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
"golang.org/x/net/context"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
k8spb "k8s.io/apiserver/pkg/storage/value/encrypt/envelope/v1beta1"
kmsv1 "k8s.io/kms/apis/v1beta1"
)

const (
Expand All @@ -19,7 +19,7 @@ const (
)

var (
client k8spb.KeyManagementServiceClient
client kmsv1.KeyManagementServiceClient
connection *grpc.ClientConn
err error
)
Expand All @@ -30,7 +30,7 @@ func setupTestCase(t *testing.T) func(t *testing.T) {
if err != nil {
fmt.Printf("%s", err)
}
client = k8spb.NewKeyManagementServiceClient(connection)
client = kmsv1.NewKeyManagementServiceClient(connection)
return func(t *testing.T) {
t.Log("teardown test case")
connection.Close()
Expand All @@ -54,13 +54,13 @@ func TestEncryptDecrypt(t *testing.T) {

for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
encryptRequest := k8spb.EncryptRequest{Version: version, Plain: tc.want}
encryptRequest := kmsv1.EncryptRequest{Version: version, Plain: tc.want}
encryptResponse, err := client.Encrypt(context.Background(), &encryptRequest)
if err != nil {
t.Fatalf("encrypt request failed with error: %+v", err)
}

decryptRequest := k8spb.DecryptRequest{Version: version, Cipher: encryptResponse.Cipher}
decryptRequest := kmsv1.DecryptRequest{Version: version, Cipher: encryptResponse.Cipher}
decryptResponse, err := client.Decrypt(context.Background(), &decryptRequest)
if !bytes.Equal(decryptResponse.Plain, tc.want) {
t.Fatalf("Expected secret, but got %s - %v", string(decryptResponse.Plain), err)
Expand All @@ -85,7 +85,7 @@ func TestVersion(t *testing.T) {

for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
request := &k8spb.VersionRequest{Version: tc.want}
request := &kmsv1.VersionRequest{Version: tc.want}
response, err := client.Version(context.Background(), request)
if err != nil {
t.Fatalf("failed get version from remote KMS provider: %v", err)
Expand Down
Loading

0 comments on commit a209f44

Please sign in to comment.