diff --git a/cmd/server/main.go b/cmd/server/main.go index 14307730..4bdf5292 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -16,6 +16,7 @@ import ( "syscall" "time" + "github.com/Azure/kubernetes-kms/pkg/config" "github.com/Azure/kubernetes-kms/pkg/metrics" "github.com/Azure/kubernetes-kms/pkg/plugin" "github.com/Azure/kubernetes-kms/pkg/utils" @@ -92,6 +93,27 @@ func main() { ConfigFilePath: *configFilePath, } + azureConfig, err := config.GetAzureConfig(pluginConfig.ConfigFilePath) + if err != nil { + klog.ErrorS(err, "failed to get azure config") + os.Exit(1) + } + + kvClient, err := plugin.NewKeyVaultClient( + azureConfig, + pluginConfig.KeyVaultName, + pluginConfig.KeyName, + pluginConfig.KeyVersion, + pluginConfig.ProxyMode, + pluginConfig.ProxyAddress, + pluginConfig.ProxyPort, + pluginConfig.ManagedHSM, + ) + if err != nil { + klog.ErrorS(err, "failed to create key vault client") + os.Exit(1) + } + // Initialize and run the GRPC server proto, addr, err := utils.ParseEndpoint(*listenAddr) if err != nil { @@ -116,7 +138,7 @@ func main() { s := grpc.NewServer(opts...) // register kms v1 server - kmsV1Server, err := plugin.NewKMSv1Server(pluginConfig) + kmsV1Server, err := plugin.NewKMSv1Server(kvClient) if err != nil { klog.ErrorS(err, "failed to create server") os.Exit(1) @@ -124,7 +146,7 @@ func main() { kmsv1.RegisterKeyManagementServiceServer(s, kmsV1Server) // register kms v2 server - kmsV2Server, err := plugin.NewKMSv2Server(pluginConfig) + kmsV2Server, err := plugin.NewKMSv2Server(kvClient) if err != nil { klog.ErrorS(err, "failed to create kms V2 server") os.Exit(1) diff --git a/go.mod b/go.mod index 307907ef..a0a82242 100644 --- a/go.mod +++ b/go.mod @@ -14,7 +14,6 @@ require ( google.golang.org/grpc v1.54.0 gopkg.in/yaml.v3 v3.0.1 k8s.io/apimachinery v0.27.1 - k8s.io/apiserver v0.25.8 k8s.io/component-base v0.25.8 k8s.io/klog/v2 v2.90.1 k8s.io/kms v0.27.1 diff --git a/go.sum b/go.sum index 13932ecd..be54f043 100644 --- a/go.sum +++ b/go.sum @@ -800,8 +800,6 @@ honnef.co/go/tools v0.0.1-2020.1.3/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9 honnef.co/go/tools v0.0.1-2020.1.4/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k= k8s.io/apimachinery v0.27.1 h1:EGuZiLI95UQQcClhanryclaQE6xjg1Bts6/L3cD7zyc= k8s.io/apimachinery v0.27.1/go.mod h1:5ikh59fK3AJ287GUvpUsryoMFtH9zj/ARfWCo3AyXTM= -k8s.io/apiserver v0.25.8 h1:ZTYdLdouAu8D6h9QavMaQZiAV+EfWK87VGdOyb6RZMQ= -k8s.io/apiserver v0.25.8/go.mod h1:IJ1r0vqXxwa+3QbrxAHWqdmoGZnVDDMzWtIK9ju3maI= k8s.io/component-base v0.25.8 h1:lQ5Ouw7lupdpXn5slRjAeHnlMK/aAEbPf9jjSWbOD3c= k8s.io/component-base v0.25.8/go.mod h1:MkC9Lz4fXoGOgB2WhFBU4zjiviIEeJS3sVhTxX9vt6s= k8s.io/klog/v2 v2.90.1 h1:m4bYOKall2MmOiRaR1J+We67Do7vm9KiQVlT96lnHUw= diff --git a/pkg/plugin/healthz.go b/pkg/plugin/healthz.go index 061b9cb6..c81f4fe3 100644 --- a/pkg/plugin/healthz.go +++ b/pkg/plugin/healthz.go @@ -99,8 +99,9 @@ func (h *HealthZ) ServeHTTP(w http.ResponseWriter, _ *http.Request) { return } - // v2 checks - uid := string(uuid.NewUUID()) + // v2 checks. + // appending a string to UUID allows us to differentiate the UUIDs generated by us from those generated by the API server. + uid := "local-healthz-check-" + string(uuid.NewUUID()) v2EncryptResponse, err := h.KMSv2Server.Encrypt( ctx, diff --git a/pkg/plugin/keyvault.go b/pkg/plugin/keyvault.go index f816cc68..fb703fd9 100644 --- a/pkg/plugin/keyvault.go +++ b/pkg/plugin/keyvault.go @@ -36,7 +36,6 @@ const ( keyvaultRegionAnnotationKey = "x-ms-keyvault-region.azure.akv.io" versionAnnotationKey = "version.azure.akv.io" algorithmAnnotationKey = "algorithm.azure.akv.io" - keyIDAnnotationKey = "key-id.azure.akv.io" dateAnnotationValue = "Date" requestIDAnnotationValue = "X-Ms-Request-Id" keyvaultRegionAnnotationValue = "X-Ms-Keyvault-Region" @@ -46,10 +45,11 @@ const ( type Client interface { Encrypt(ctx context.Context, cipher []byte) (*service.EncryptResponse, error) Decrypt(ctx context.Context, plain []byte) ([]byte, error) - ValidateAnnotations(annotations map[string][]byte) error + ValidateAnnotations(annotations map[string][]byte, keyID string) error } -type keyVaultClient struct { +// KeyVaultClient is a client for interacting with Keyvault. +type KeyVaultClient struct { baseClient kv.BaseClient config *config.AzureConfig vaultName string @@ -62,14 +62,14 @@ type keyVaultClient struct { } // NewKeyVaultClient returns a new key vault client to use for kms operations. -func newKeyVaultClient( +func NewKeyVaultClient( config *config.AzureConfig, vaultName, keyName, keyVersion string, proxyMode bool, proxyAddress string, proxyPort int, managedHSM bool, -) (*keyVaultClient, error) { +) (*KeyVaultClient, error) { // Sanitize vaultName, keyName, keyVersion. (https://github.com/Azure/kubernetes-kms/issues/85) vaultName = utils.SanitizeString(vaultName) keyName = utils.SanitizeString(keyName) @@ -115,7 +115,7 @@ func newKeyVaultClient( klog.InfoS("using kms key for encrypt/decrypt", "vaultURL", *vaultURL, "keyName", keyName, "keyVersion", keyVersion) - client := &keyVaultClient{ + client := &KeyVaultClient{ baseClient: kvClient, config: config, vaultName: vaultName, @@ -123,14 +123,15 @@ func newKeyVaultClient( keyVersion: keyVersion, vaultURL: *vaultURL, azureEnvironment: env, - keyID: fmt.Sprintf("%x", sha256.Sum256([]byte(fmt.Sprintf("%s/keys/%s/%s", *vaultURL, keyName, keyVersion)))), + keyID: fmt.Sprintf("%x", sha256.Sum256([]byte(fmt.Sprintf("%skeys/%s/%s", *vaultURL, keyName, keyVersion)))), algorithm: kv.RSA15, } return client, nil } -func (kvc *keyVaultClient) Encrypt(ctx context.Context, cipher []byte) (*service.EncryptResponse, error) { - value := base64.RawURLEncoding.EncodeToString(cipher) +// Encrypt encrypts the given plain text using the keyvault key. +func (kvc *KeyVaultClient) Encrypt(ctx context.Context, plain []byte) (*service.EncryptResponse, error) { + value := base64.RawURLEncoding.EncodeToString(plain) params := kv.KeyOperationsParameters{ Algorithm: kvc.algorithm, @@ -147,7 +148,6 @@ func (kvc *keyVaultClient) Encrypt(ctx context.Context, cipher []byte) (*service keyvaultRegionAnnotationKey: []byte(result.Header.Get(keyvaultRegionAnnotationValue)), versionAnnotationKey: []byte(encryptionResponseVersion), algorithmAnnotationKey: []byte(kvc.algorithm), - keyIDAnnotationKey: []byte(kvc.keyID), } return &service.EncryptResponse{ @@ -157,8 +157,9 @@ func (kvc *keyVaultClient) Encrypt(ctx context.Context, cipher []byte) (*service }, nil } -func (kvc *keyVaultClient) Decrypt(ctx context.Context, plain []byte) ([]byte, error) { - value := string(plain) +// Decrypt decrypts the given cipher text using the keyvault key. +func (kvc *KeyVaultClient) Decrypt(ctx context.Context, cipher []byte) ([]byte, error) { + value := string(cipher) params := kv.KeyOperationsParameters{ Algorithm: kvc.algorithm, @@ -177,15 +178,14 @@ func (kvc *keyVaultClient) Decrypt(ctx context.Context, plain []byte) ([]byte, e } // ValidateAnnotations validates following annotations before decryption: -// - Key ID. // - Algorithm. // - Version. -func (kvc *keyVaultClient) ValidateAnnotations(annotations map[string][]byte) error { +// It also validates keyID that the API server checks. +func (kvc *KeyVaultClient) ValidateAnnotations(annotations map[string][]byte, keyID string) error { if len(annotations) == 0 { return fmt.Errorf("invalid annotations, annotations cannot be empty") } - keyID := string(annotations[keyIDAnnotationKey]) if keyID != kvc.keyID { return fmt.Errorf( "key id %s does not match expected key id %s used for encryption", diff --git a/pkg/plugin/keyvault_test.go b/pkg/plugin/keyvault_test.go index cedb5210..a4896074 100644 --- a/pkg/plugin/keyvault_test.go +++ b/pkg/plugin/keyvault_test.go @@ -68,7 +68,7 @@ func TestNewKeyVaultClientError(t *testing.T) { for _, test := range tests { t.Run(test.desc, func(t *testing.T) { - if _, err := newKeyVaultClient(test.config, test.vaultName, test.keyName, test.keyVersion, test.proxyMode, test.proxyAddress, test.proxyPort, test.managedHSM); err == nil { + if _, err := NewKeyVaultClient(test.config, test.vaultName, test.keyName, test.keyVersion, test.proxyMode, test.proxyAddress, test.proxyPort, test.managedHSM); err == nil { t.Fatalf("newKeyVaultClient() expected error, got nil") } }) @@ -131,7 +131,7 @@ func TestNewKeyVaultClient(t *testing.T) { for _, test := range tests { t.Run(test.desc, func(t *testing.T) { - kvClient, err := newKeyVaultClient(test.config, test.vaultName, test.keyName, test.keyVersion, test.proxyMode, test.proxyAddress, test.proxyPort, test.managedHSM) + kvClient, err := NewKeyVaultClient(test.config, test.vaultName, test.keyName, test.keyVersion, test.proxyMode, test.proxyAddress, test.proxyPort, test.managedHSM) if err != nil { t.Fatalf("newKeyVaultClient() failed with error: %v", err) } diff --git a/pkg/plugin/kms_v2_server.go b/pkg/plugin/kms_v2_server.go index 927ed78f..96e99ae7 100644 --- a/pkg/plugin/kms_v2_server.go +++ b/pkg/plugin/kms_v2_server.go @@ -10,7 +10,6 @@ import ( "fmt" "time" - "github.com/Azure/kubernetes-kms/pkg/config" "github.com/Azure/kubernetes-kms/pkg/metrics" "github.com/Azure/kubernetes-kms/pkg/version" @@ -25,26 +24,7 @@ type KeyManagementServiceV2Server struct { } // NewKMSv2Server creates an instance of the KMS Service Server with v2 apis. -func NewKMSv2Server(pluginConfig *Config) (*KeyManagementServiceV2Server, error) { - cfg, err := config.GetAzureConfig(pluginConfig.ConfigFilePath) - if err != nil { - return nil, err - } - - kvClient, err := newKeyVaultClient( - cfg, - pluginConfig.KeyVaultName, - pluginConfig.KeyName, - pluginConfig.KeyVersion, - pluginConfig.ProxyMode, - pluginConfig.ProxyAddress, - pluginConfig.ProxyPort, - pluginConfig.ManagedHSM, - ) - if err != nil { - return nil, err - } - +func NewKMSv2Server(kvClient *KeyVaultClient) (*KeyManagementServiceV2Server, error) { return &KeyManagementServiceV2Server{ kvClient: kvClient, reporter: metrics.NewStatsReporter(), @@ -135,7 +115,7 @@ func (s *KeyManagementServiceV2Server) Decrypt(ctx context.Context, request *kms }() klog.V(2).InfoS("decrypt request started", "uid", request.Uid) - err = s.kvClient.ValidateAnnotations(request.Annotations) + err = s.kvClient.ValidateAnnotations(request.Annotations, request.KeyId) if err != nil { klog.ErrorS(err, "failed to decrypt", "uid", request.Uid) return &kmsv2.DecryptResponse{}, err diff --git a/pkg/plugin/kms_v2_server_test.go b/pkg/plugin/kms_v2_server_test.go index bd23ee10..d6393f3c 100644 --- a/pkg/plugin/kms_v2_server_test.go +++ b/pkg/plugin/kms_v2_server_test.go @@ -92,7 +92,6 @@ func TestV2Decrypt(t *testing.T) { output: []byte{}, err: fmt.Errorf("key id \"invalid-key-id\" does not match expected key id \"mock-key-id\" used for encryption"), annotations: map[string][]byte{ - keyIDAnnotationKey: []byte("invalid-key-id"), algorithmAnnotationKey: []byte(keyvault.RSA15), versionAnnotationKey: []byte("1"), }, @@ -103,7 +102,6 @@ func TestV2Decrypt(t *testing.T) { output: []byte{}, err: fmt.Errorf("algorithm \"insecure-algorithm\" does not match expected algorithm \"RSA1_5\" used for encryption"), annotations: map[string][]byte{ - keyIDAnnotationKey: []byte("mock-key-id"), algorithmAnnotationKey: []byte("insecure-algorithm"), versionAnnotationKey: []byte("1"), }, @@ -114,7 +112,6 @@ func TestV2Decrypt(t *testing.T) { output: []byte{}, err: fmt.Errorf("version \"10\" does not match expected version \"1\" used for encryption"), annotations: map[string][]byte{ - keyIDAnnotationKey: []byte("mock-key-id"), algorithmAnnotationKey: []byte(keyvault.RSA15), versionAnnotationKey: []byte("10"), }, @@ -125,7 +122,6 @@ func TestV2Decrypt(t *testing.T) { output: []byte{}, err: fmt.Errorf("failed to decrypt"), annotations: map[string][]byte{ - keyIDAnnotationKey: []byte("mock-key-id"), algorithmAnnotationKey: []byte(keyvault.RSA15), versionAnnotationKey: []byte("1"), }, @@ -136,7 +132,6 @@ func TestV2Decrypt(t *testing.T) { output: []byte("foo"), err: nil, annotations: map[string][]byte{ - keyIDAnnotationKey: []byte("mock-key-id"), algorithmAnnotationKey: []byte(keyvault.RSA15), versionAnnotationKey: []byte("1"), }, @@ -159,6 +154,7 @@ func TestV2Decrypt(t *testing.T) { out, err := kmsV2Server.Decrypt(context.TODO(), &kmsv2.DecryptRequest{ Ciphertext: test.input, Annotations: test.annotations, + KeyId: "mock-key-id", }) if err != nil && (err.Error() != test.err.Error()) { t.Fatalf("expected err: %v, got: %v", test.err, err) diff --git a/pkg/plugin/mock_keyvault/keyvault_mock.go b/pkg/plugin/mock_keyvault/keyvault_mock.go index e8d35631..d1080def 100644 --- a/pkg/plugin/mock_keyvault/keyvault_mock.go +++ b/pkg/plugin/mock_keyvault/keyvault_mock.go @@ -59,13 +59,13 @@ func (kvc *KeyVaultClient) SetDecryptResponse(decryptOut []byte, err error) { kvc.decryptErr = err } -func (kvc *KeyVaultClient) ValidateAnnotations(annotations map[string][]byte) error { +func (kvc *KeyVaultClient) ValidateAnnotations(annotations map[string][]byte, keyID string) error { if len(annotations) == 0 { return fmt.Errorf("invalid annotations, annotations cannot be empty") } // validate key id - if string(annotations["key-id.azure.akv.io"]) != kvc.KeyID { + if keyID != kvc.KeyID { return fmt.Errorf( "key id %q does not match expected key id %q used for encryption", string(annotations["key-id.azure.akv.io"]), diff --git a/pkg/plugin/server.go b/pkg/plugin/server.go index c72b254e..cff93737 100644 --- a/pkg/plugin/server.go +++ b/pkg/plugin/server.go @@ -9,7 +9,6 @@ import ( "context" "time" - "github.com/Azure/kubernetes-kms/pkg/config" "github.com/Azure/kubernetes-kms/pkg/metrics" "github.com/Azure/kubernetes-kms/pkg/version" @@ -36,15 +35,7 @@ type Config struct { } // NewKMSv1Server creates an instance of the KMS Service Server. -func NewKMSv1Server(pc *Config) (*KeyManagementServiceServer, error) { - cfg, err := config.GetAzureConfig(pc.ConfigFilePath) - if err != nil { - return nil, err - } - kvClient, err := newKeyVaultClient(cfg, pc.KeyVaultName, pc.KeyName, pc.KeyVersion, pc.ProxyMode, pc.ProxyAddress, pc.ProxyPort, pc.ManagedHSM) - if err != nil { - return nil, err - } +func NewKMSv1Server(kvClient *KeyVaultClient) (*KeyManagementServiceServer, error) { return &KeyManagementServiceServer{ kvClient: kvClient, reporter: metrics.NewStatsReporter(), diff --git a/tests/client/client_test.go b/tests/client/client_test.go index fb28fb3b..90ebc8cc 100644 --- a/tests/client/client_test.go +++ b/tests/client/client_test.go @@ -9,7 +9,7 @@ import ( "golang.org/x/net/context" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" - k8spb "k8s.io/apiserver/pkg/storage/value/encrypt/envelope/v1beta1" + kmsv1 "k8s.io/kms/apis/v1beta1" ) const ( @@ -19,7 +19,7 @@ const ( ) var ( - client k8spb.KeyManagementServiceClient + client kmsv1.KeyManagementServiceClient connection *grpc.ClientConn err error ) @@ -30,7 +30,7 @@ func setupTestCase(t *testing.T) func(t *testing.T) { if err != nil { fmt.Printf("%s", err) } - client = k8spb.NewKeyManagementServiceClient(connection) + client = kmsv1.NewKeyManagementServiceClient(connection) return func(t *testing.T) { t.Log("teardown test case") connection.Close() @@ -54,13 +54,13 @@ func TestEncryptDecrypt(t *testing.T) { for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { - encryptRequest := k8spb.EncryptRequest{Version: version, Plain: tc.want} + encryptRequest := kmsv1.EncryptRequest{Version: version, Plain: tc.want} encryptResponse, err := client.Encrypt(context.Background(), &encryptRequest) if err != nil { t.Fatalf("encrypt request failed with error: %+v", err) } - decryptRequest := k8spb.DecryptRequest{Version: version, Cipher: encryptResponse.Cipher} + decryptRequest := kmsv1.DecryptRequest{Version: version, Cipher: encryptResponse.Cipher} decryptResponse, err := client.Decrypt(context.Background(), &decryptRequest) if !bytes.Equal(decryptResponse.Plain, tc.want) { t.Fatalf("Expected secret, but got %s - %v", string(decryptResponse.Plain), err) @@ -85,7 +85,7 @@ func TestVersion(t *testing.T) { for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { - request := &k8spb.VersionRequest{Version: tc.want} + request := &kmsv1.VersionRequest{Version: tc.want} response, err := client.Version(context.Background(), request) if err != nil { t.Fatalf("failed get version from remote KMS provider: %v", err) diff --git a/tests/e2e/testkmsv2.bats b/tests/e2e/testkmsv2.bats index 2575a17e..e81f091b 100644 --- a/tests/e2e/testkmsv2.bats +++ b/tests/e2e/testkmsv2.bats @@ -30,6 +30,23 @@ export ETCD_KEY=/etc/kubernetes/pki/etcd/server.key assert_success } +@test "check encryption count" { + # The expected_encryption_count value is set to 2 because the Kind creates bootstrap secret in addition to the one we create above for testing purposes. + local expected_encyption_count="2" + local metrics=$(kubectl get --raw /metrics) + encyption_count=$(echo "${metrics}" | grep -oP 'apiserver_envelope_encryption_key_id_hash_total\{[^\}]*transformation_type="to_storage"[^\}]*\}\s+\K\d+') + [[ "${encyption_count}" == "${expected_encyption_count}" ]] +} + +@test "check keyID hash used for encrypt/decrypt" { + # expected_hash value is computed based on key used in CI. + # this needs to be updated when we rotate the key used in CI. + local expected_hash="cbda52be2f8c13d323a3b17c4679118a60b91d29454305e02ee485185b6e386f" + local metrics=$(kubectl get --raw /metrics) + hash=$(echo "${metrics}" | grep -oP 'sha256:\K[a-f0-9]+' | head -n 1) + [[ "${hash}" == "${expected_hash}" ]] +} + @test "check if metrics endpoint works" { local curl_pod_name=curl-$(openssl rand -hex 5) kubectl run ${curl_pod_name} --image=curlimages/curl:7.75.0 --labels="test=metrics_test" -- tail -f /dev/null