Skip to content

Commit

Permalink
review comments
Browse files Browse the repository at this point in the history
Signed-off-by: Nilekh Chaudhari <1626598+nilekhc@users.noreply.github.com>
  • Loading branch information
nilekhc committed Apr 18, 2023
1 parent 08b47a1 commit b319d0d
Show file tree
Hide file tree
Showing 13 changed files with 262 additions and 173 deletions.
4 changes: 2 additions & 2 deletions .pipelines/templates/e2e-kind-template.yml
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,8 @@ jobs:
- group: kubernetes-kms
strategy:
matrix:
kmsv2_kind_v1_27_0:
KUBERNETES_VERSION: v1.27.0
kmsv2_kind_v1_27_1:
KUBERNETES_VERSION: v1.27.1
steps:
- task: GoTool@0
inputs:
Expand Down
6 changes: 3 additions & 3 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ export DOCKER_BUILDKIT

# Testing var
KIND_VERSION ?= 0.18.0
KUBERNETES_VERSION ?= v1.27.0
KUBERNETES_VERSION ?= v1.27.1
BATS_VERSION ?= 1.4.1

## --------------------------------------
Expand Down Expand Up @@ -130,12 +130,12 @@ install-soak-prerequisites: e2e-install-prerequisites
e2e-setup-kind: setup-local-registry
./scripts/setup-kind-cluster.sh &
./scripts/connect-registry.sh &
sleep 90s
wait

e2e-kmsv2-setup-kind: setup-local-registry
./scripts/setup-kmsv2-kind-cluster.sh &
./scripts/connect-registry.sh &
sleep 90s
wait

.PHONY: setup-local-registry
setup-local-registry:
Expand Down
84 changes: 24 additions & 60 deletions cmd/server/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,12 @@ import (
)

var (
listenAddr = flag.String("listen-addr", "unix:///opt/azurekms.socket", "gRPC listen address")
kmsV2ListenAddr = flag.String("kms-v2-listen-addr", "unix:///opt/azurekmsv2.socket", "gRPC listen address")
keyvaultName = flag.String("keyvault-name", "", "Azure Key Vault name")
keyName = flag.String("key-name", "", "Azure Key Vault KMS key name")
keyVersion = flag.String("key-version", "", "Azure Key Vault KMS key version")
managedHSM = flag.Bool("managed-hsm", false, "Azure Key Vault Managed HSM. Refer to https://docs.microsoft.com/en-us/azure/key-vault/managed-hsm/overview for more details.")
logFormatJSON = flag.Bool("log-format-json", false, "set log formatter to json")
listenAddr = flag.String("listen-addr", "unix:///opt/azurekms.socket", "gRPC listen address")
keyvaultName = flag.String("keyvault-name", "", "Azure Key Vault name")
keyName = flag.String("key-name", "", "Azure Key Vault KMS key name")
keyVersion = flag.String("key-version", "", "Azure Key Vault KMS key version")
managedHSM = flag.Bool("managed-hsm", false, "Azure Key Vault Managed HSM. Refer to https://docs.microsoft.com/en-us/azure/key-vault/managed-hsm/overview for more details.")
logFormatJSON = flag.Bool("log-format-json", false, "set log formatter to json")
// TODO remove this flag in future release.
_ = flag.String("configFilePath", "/etc/kubernetes/azure.json", "[DEPRECATED] Path for Azure Cloud Provider config file")
configFilePath = flag.String("config-file-path", "/etc/kubernetes/azure.json", "Path for Azure Cloud Provider config file")
Expand Down Expand Up @@ -112,38 +111,41 @@ func main() {
os.Exit(1)
}

s := grpc.NewServer(utils.GetDefaultGRPCServerOption()...)
opts := []grpc.ServerOption{
grpc.UnaryInterceptor(utils.UnaryServerInterceptor),
}

// register kms v1 server
s := grpc.NewServer(opts...)
pb.RegisterKeyManagementServiceServer(s, kmsServer)

// register kms v2 server
kmsV2Server, err := plugin.NewV2Server(pc)
if err != nil {
klog.ErrorS(err, "failed to create kms V2 server")
os.Exit(1)
}
kmsv2.RegisterKeyManagementServiceServer(s, kmsV2Server)

klog.InfoS("Listening for connections", "addr", listener.Addr().String())
go func() {
if err := s.Serve(listener); err != nil {
klog.ErrorS(err, "failed to serve kms v2 server")
klog.ErrorS(err, "failed to serve kms server")
}
}()

// initialize kms v2 server
unixSocketPathForV2, kmsV2Server, err := setupKMSV2(pc)
if err != nil {
klog.ErrorS(err, "failed to create kms v2 server")
os.Exit(1)
}

// Health check for KMS v1 and v2
klog.InfoS("Starting healthz server", "addr", listener.Addr().String())
// Health check for kms v1 and v2
healthz := &plugin.HealthZ{
KMSServer: kmsServer,
KMSV2Server: kmsV2Server,
HealthCheckURL: &url.URL{
Host: net.JoinHostPort("", strconv.FormatUint(uint64(*healthzPort), 10)),
Path: *healthzPath,
},
UnixSocketPath: listener.Addr().String(),
UnixSocketPathForV2: unixSocketPathForV2,
RPCTimeout: *healthzTimeout,
UnixSocketPath: listener.Addr().String(),
RPCTimeout: *healthzTimeout,
}
go healthz.Serve()
klog.InfoS("Healthz server started", "addr", listener.Addr().String())

<-ctx.Done()
// gracefully stop the grpc server
Expand All @@ -170,41 +172,3 @@ func withShutdownSignal(ctx context.Context) context.Context {
}()
return nctx
}

func setupKMSV2(pluginConfig *plugin.Config) (string, *plugin.KeyManagementServiceV2Server, error) {
kmsV2Server, err := plugin.NewV2Server(pluginConfig)
if err != nil {
klog.ErrorS(err, "failed to create kms V2 server")
return "", nil, err
}

// Initialize and run the GRPC server
proto, addr, err := utils.ParseEndpoint(*kmsV2ListenAddr)
if err != nil {
klog.ErrorS(err, "failed to parse kms v2 unix domain socket endpoint")
return "", nil, err
}
if err := os.Remove(addr); err != nil && !os.IsNotExist(err) {
klog.ErrorS(err, "failed to remove socket file", "addr", addr)
return "", nil, err
}

listener, err := net.Listen(proto, addr)
if err != nil {
klog.ErrorS(err, "failed to listen", "addr", addr, "proto", proto)
return "", nil, err
}

klog.InfoS("Starting GRPC server", "addr", listener.Addr().String())
s := grpc.NewServer(utils.GetDefaultGRPCServerOption()...)

kmsv2.RegisterKeyManagementServiceServer(s, kmsV2Server)
klog.InfoS("Listening for connections", "addr", listener.Addr().String())
go func() {
if err := s.Serve(listener); err != nil {
klog.ErrorS(err, "failed to serve kms v2 server")
}
}()

return listener.Addr().String(), kmsV2Server, nil
}
40 changes: 14 additions & 26 deletions pkg/plugin/healthz.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,13 @@ const (
healthCheckPlainText = "healthcheck"
)

// HealthZ is the health check server for the KMS plugin
type HealthZ struct {
KMSServer *KeyManagementServiceServer
KMSV2Server *KeyManagementServiceV2Server
HealthCheckURL *url.URL
UnixSocketPath string
UnixSocketPathForV2 string
RPCTimeout time.Duration
KMSServer *KeyManagementServiceServer
KMSV2Server *KeyManagementServiceV2Server
HealthCheckURL *url.URL
UnixSocketPath string
RPCTimeout time.Duration
}

// Serve creates the http handler for serving health requests
Expand All @@ -57,25 +57,18 @@ func (h *HealthZ) ServeHTTP(w http.ResponseWriter, _ *http.Request) {
ctx, cancel := context.WithTimeout(context.Background(), h.RPCTimeout)
defer cancel()

conn, err := h.dialUnixSocket("v1")
conn, err := h.dialUnixSocket()
if err != nil {
http.Error(w, err.Error(), http.StatusServiceUnavailable)
return
}
defer conn.Close()

v2Conn, err := h.dialUnixSocket("v2")
if err != nil {
http.Error(w, err.Error(), http.StatusServiceUnavailable)
return
}
defer v2Conn.Close()

// create the kms client for v1
kmsClient := pb.NewKeyManagementServiceClient(conn)

// create the kms client for v2
kmsV2Client := kmsv2.NewKeyManagementServiceClient(v2Conn)
kmsV2Client := kmsv2.NewKeyManagementServiceClient(conn)

// check version response against KMS-Plugin's gRPC endpoint.
err = h.checkRPC(ctx, kmsClient, kmsV2Client)
Expand Down Expand Up @@ -123,9 +116,10 @@ func (h *HealthZ) ServeHTTP(w http.ResponseWriter, _ *http.Request) {
}

v2DecryptResponse, err := h.KMSV2Server.Decrypt(ctx, &kmsv2.DecryptRequest{
Ciphertext: v2EncryptResponse.Ciphertext,
KeyId: v2EncryptResponse.KeyId,
Uid: uid, // passing the same uid to track roundtrip encrypt/decrypt calls
Ciphertext: v2EncryptResponse.Ciphertext,
KeyId: v2EncryptResponse.KeyId,
Uid: uid, // passing the same uid to track roundtrip encrypt/decrypt calls
Annotations: v2EncryptResponse.Annotations,
})
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
Expand Down Expand Up @@ -170,15 +164,9 @@ func (h *HealthZ) checkRPC(
return nil
}

func (h *HealthZ) dialUnixSocket(apiVersion string) (*grpc.ClientConn, error) {
socketPath := h.UnixSocketPath

if apiVersion == "v2" {
socketPath = h.UnixSocketPathForV2
}

func (h *HealthZ) dialUnixSocket() (*grpc.ClientConn, error) {
return grpc.Dial(
socketPath,
h.UnixSocketPath,
grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithContextDialer(func(ctx context.Context, target string) (net.Conn, error) {
return (&net.Dialer{}).DialContext(ctx, "unix", target)
Expand Down
61 changes: 22 additions & 39 deletions pkg/plugin/healthz_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,15 +69,7 @@ func TestServe(t *testing.T) {
socketPath := fmt.Sprintf("%s/kms.sock", getTempTestDir(t))
defer os.Remove(socketPath)

socketPathForV2 := fmt.Sprintf("%s/kmsv2.sock", getTempTestDir(t))
defer os.Remove(socketPathForV2)

fakeKMSServer, mockKVClient, err := setupFakeKMSServer(socketPath)
if err != nil {
t.Fatalf("failed to create fake kms server, err: %+v", err)
}

fakeKMSV2Server, mockKVClient, err := setupFakeKMSV2Server(socketPathForV2)
fakeKMSServer, fakeKMSV2Server, mockKVClient, err := setupFakeKMSServer(socketPath)
if err != nil {
t.Fatalf("failed to create fake kms server, err: %+v", err)
}
Expand All @@ -86,11 +78,10 @@ func TestServe(t *testing.T) {
mockKVClient.SetDecryptResponse([]byte(test.setDecryptResponse), test.setDecryptError)

healthz := &HealthZ{
KMSServer: fakeKMSServer,
KMSV2Server: fakeKMSV2Server,
UnixSocketPath: socketPath,
UnixSocketPathForV2: socketPathForV2,
RPCTimeout: 20 * time.Second,
KMSServer: fakeKMSServer,
KMSV2Server: fakeKMSV2Server,
UnixSocketPath: socketPath,
RPCTimeout: 20 * time.Second,
HealthCheckURL: &url.URL{
Scheme: "http",
Host: net.JoinHostPort("localhost", "8080"),
Expand All @@ -116,29 +107,25 @@ func TestCheckRPC(t *testing.T) {
socketPath := fmt.Sprintf("%s/kms.sock", getTempTestDir(t))
defer os.Remove(socketPath)

fakeKMSServer, _, err := setupFakeKMSServer(socketPath)
fakeKMSServer, fakeKMSV2Server, _, err := setupFakeKMSServer(socketPath)
if err != nil {
t.Fatalf("failed to create fake kms server, err: %+v", err)
}
healthz := &HealthZ{
KMSServer: fakeKMSServer,
KMSV2Server: fakeKMSV2Server,
UnixSocketPath: socketPath,
}

conn, err := healthz.dialUnixSocket("v1")
if err != nil {
t.Fatalf("failed to create connection, err: %+v", err)
}

v2Conn, err := healthz.dialUnixSocket("v2")
conn, err := healthz.dialUnixSocket()
if err != nil {
t.Fatalf("failed to create connection, err: %+v", err)
}

err = healthz.checkRPC(
context.TODO(),
pb.NewKeyManagementServiceClient(conn),
kmsv2.NewKeyManagementServiceClient(v2Conn),
kmsv2.NewKeyManagementServiceClient(conn),
)
if err != nil {
t.Fatalf("expected err to be nil, got: %+v", err)
Expand All @@ -153,38 +140,34 @@ func getTempTestDir(t *testing.T) string {
return tmpDir
}

func setupFakeKMSServer(socketPath string) (*KeyManagementServiceServer, *mockkeyvault.KeyVaultClient, error) {
func setupFakeKMSServer(socketPath string) (
*KeyManagementServiceServer,
*KeyManagementServiceV2Server,
*mockkeyvault.KeyVaultClient,
error,
) {
listener, err := net.Listen("unix", socketPath)
if err != nil {
return nil, nil, err
return nil, nil, nil, err
}

kvClient := &mockkeyvault.KeyVaultClient{}
fakeKMSServer := &KeyManagementServiceServer{
kvClient: kvClient,
reporter: metrics.NewStatsReporter(),
}
s := grpc.NewServer()
pb.RegisterKeyManagementServiceServer(s, fakeKMSServer)
go s.Serve(listener)

return fakeKMSServer, kvClient, nil
}

func setupFakeKMSV2Server(socketPath string) (*KeyManagementServiceV2Server, *mockkeyvault.KeyVaultClient, error) {
listener, err := net.Listen("unix", socketPath)
if err != nil {
return nil, nil, err
}
kvClient := &mockkeyvault.KeyVaultClient{}
fakeKMSServer := &KeyManagementServiceV2Server{
fakeKMSV2Server := &KeyManagementServiceV2Server{
kvClient: kvClient,
reporter: metrics.NewStatsReporter(),
}

s := grpc.NewServer()
kmsv2.RegisterKeyManagementServiceServer(s, fakeKMSServer)
pb.RegisterKeyManagementServiceServer(s, fakeKMSServer)
kmsv2.RegisterKeyManagementServiceServer(s, fakeKMSV2Server)
go s.Serve(listener)

return fakeKMSServer, kvClient, nil
return fakeKMSServer, fakeKMSV2Server, kvClient, nil
}

func doHealthCheck(t *testing.T, url string) (int, []byte) {
Expand Down
Loading

0 comments on commit b319d0d

Please sign in to comment.