From 1e2e7210551e9426d025f1498078bd6ac30df5e7 Mon Sep 17 00:00:00 2001 From: JamesMurkin Date: Mon, 10 Jul 2023 21:16:46 +0100 Subject: [PATCH] Implement tls into our grpc servers (#2655) * Implement tls into our grpc servers * Fix imports * Update helm charts to support tls * Remove test config * Add smarter cert refreshing * Formatting * write tests * Run test * Small refactor * Formatting * Fix tests ignoring err * Add log line on refresh * Add log line on refresh * format * rewording * rewording --- config/armada/config.yaml | 2 + config/binoculars/config.yaml | 4 +- config/jobservice/config.yaml | 2 + config/scheduler/config.yaml | 2 + deployment/armada/templates/deployment.yaml | 10 + deployment/armada/templates/ingress.yaml | 5 + deployment/armada/templates/ingressrest.yaml | 4 + deployment/armada/values.yaml | 5 + .../binoculars/templates/deployment.yaml | 10 + deployment/binoculars/templates/ingress.yaml | 5 + .../binoculars/templates/ingressrest.yaml | 4 + deployment/binoculars/values.yaml | 7 +- deployment/jobservice/templates/ingress.yaml | 5 + .../jobservice/templates/statefulset.yaml | 10 + deployment/jobservice/values.yaml | 7 +- .../templates/scheduler-ingress.yaml | 5 + .../templates/scheduler-statefulset.yaml | 10 + deployment/scheduler/values.yaml | 4 + internal/armada/server.go | 2 +- internal/binoculars/server.go | 2 +- internal/common/certs/cached_certificate.go | 121 +++++++++++ .../common/certs/cached_certificate_test.go | 194 ++++++++++++++++++ internal/common/grpc/configuration/types.go | 11 +- internal/common/grpc/grpc.go | 33 ++- internal/jobservice/application.go | 1 + internal/lookout/application.go | 1 + internal/scheduler/schedulerapp.go | 2 +- 27 files changed, 455 insertions(+), 13 deletions(-) create mode 100644 internal/common/certs/cached_certificate.go create mode 100644 internal/common/certs/cached_certificate_test.go diff --git a/config/armada/config.yaml b/config/armada/config.yaml index d13f36e7e4e..bc3cf4d55eb 100644 --- a/config/armada/config.yaml +++ b/config/armada/config.yaml @@ -19,6 +19,8 @@ grpc: keepaliveEnforcementPolicy: minTime: 10s permitWithoutStream: true + tls: + enabled: false redis: addrs: - redis:6379 diff --git a/config/binoculars/config.yaml b/config/binoculars/config.yaml index 6bbe73f87c9..33c17f2af88 100644 --- a/config/binoculars/config.yaml +++ b/config/binoculars/config.yaml @@ -1,7 +1,7 @@ grpcPort: 50051 httpPort: 8080 metricsPort: 9000 -corsAllowedOrigins: +corsAllowedOrigins: - http://localhost:3000 - http://localhost:8080 cordon: @@ -24,3 +24,5 @@ grpc: keepaliveEnforcementPolicy: minTime: 5m permitWithoutStream: false + tls: + enabled: false diff --git a/config/jobservice/config.yaml b/config/jobservice/config.yaml index 969e62b5ba6..58a6bd8e72d 100644 --- a/config/jobservice/config.yaml +++ b/config/jobservice/config.yaml @@ -28,6 +28,8 @@ grpc: keepaliveEnforcementPolicy: minTime: 5m permitWithoutStream: false + tls: + enabled: false # gRPC connection pool to armada server configuration. grpcPool: initialConnections: 5 diff --git a/config/scheduler/config.yaml b/config/scheduler/config.yaml index e77da47ff15..fb6856a35f7 100644 --- a/config/scheduler/config.yaml +++ b/config/scheduler/config.yaml @@ -50,6 +50,8 @@ grpc: keepaliveEnforcementPolicy: minTime: 10s permitWithoutStream: true + tls: + enabled: false scheduling: executorTimeout: 10m enableAssertions: true diff --git a/deployment/armada/templates/deployment.yaml b/deployment/armada/templates/deployment.yaml index b0e7d30280e..459a6587554 100644 --- a/deployment/armada/templates/deployment.yaml +++ b/deployment/armada/templates/deployment.yaml @@ -76,6 +76,11 @@ spec: mountPath: "/pulsar/ca" readOnly: true {{- end }} + {{- if .Values.applicationConfig.grpc.tls.enabled }} + - name: tls-certs + mountPath: /certs + readOnly: true + {{- end }} {{- if .Values.additionalVolumeMounts }} {{- toYaml .Values.additionalVolumeMounts | nindent 12 -}} {{- end }} @@ -129,6 +134,11 @@ spec: - key: ca.crt path: ca.crt {{- end }} + {{- if .Values.applicationConfig.grpc.tls.enabled }} + - name: tls-certs + secret: + secretName: armada-service-tls + {{- end }} {{- if .Values.additionalVolumes }} {{- toYaml .Values.additionalVolumes | nindent 8 }} {{- end }} diff --git a/deployment/armada/templates/ingress.yaml b/deployment/armada/templates/ingress.yaml index a9c380719bb..a68a43ae575 100644 --- a/deployment/armada/templates/ingress.yaml +++ b/deployment/armada/templates/ingress.yaml @@ -7,7 +7,12 @@ metadata: annotations: kubernetes.io/ingress.class: {{ required "A value is required for .Values.ingressClass" .Values.ingressClass }} nginx.ingress.kubernetes.io/ssl-redirect: "true" + {{- if .Values.applicationConfig.grpc.tls.enabled }} + nginx.ingress.kubernetes.io/backend-protocol: "GRPCS" + nginx.ingress.kubernetes.io/ssl-passthrough: "true" + {{- else }} nginx.ingress.kubernetes.io/backend-protocol: "GRPC" + {{- end }} certmanager.k8s.io/cluster-issuer: {{ required "A value is required for .Values.clusterIssuer" .Values.clusterIssuer }} cert-manager.io/cluster-issuer: {{ required "A value is required for .Values.clusterIssuer" .Values.clusterIssuer }} {{- if .Values.ingress.annotations }} diff --git a/deployment/armada/templates/ingressrest.yaml b/deployment/armada/templates/ingressrest.yaml index 927c00c12ea..de1e0ec6006 100644 --- a/deployment/armada/templates/ingressrest.yaml +++ b/deployment/armada/templates/ingressrest.yaml @@ -7,6 +7,10 @@ metadata: annotations: kubernetes.io/ingress.class: {{ required "A value is required for .Values.ingressClass" .Values.ingressClass }} nginx.ingress.kubernetes.io/ssl-redirect: "true" + {{- if .Values.applicationConfig.grpc.tls.enabled }} + nginx.ingress.kubernetes.io/backend-protocol: "HTTPS" + nginx.ingress.kubernetes.io/ssl-passthrough: "true" + {{- end }} certmanager.k8s.io/cluster-issuer: {{ required "A value is required for .Values.clusterIssuer" .Values.clusterIssuer }} cert-manager.io/cluster-issuer: {{ required "A value is required for .Values.clusterIssuer" .Values.clusterIssuer }} nginx.ingress.kubernetes.io/rewrite-target: /$2 diff --git a/deployment/armada/values.yaml b/deployment/armada/values.yaml index a19b0c3eec3..384bdac3131 100644 --- a/deployment/armada/values.yaml +++ b/deployment/armada/values.yaml @@ -84,6 +84,11 @@ serviceAccount: {} applicationConfig: # -- Armada Server gRPC port grpcPort: 50051 + grpc: + tls: + enabled: false + certPath: /certs/tls.crt + keyPath: /certs/tls.key # -- Armada Server REST port httpPort: 8080 pulsar: diff --git a/deployment/binoculars/templates/deployment.yaml b/deployment/binoculars/templates/deployment.yaml index f07f097fb53..85f9f9b99c3 100644 --- a/deployment/binoculars/templates/deployment.yaml +++ b/deployment/binoculars/templates/deployment.yaml @@ -59,6 +59,11 @@ spec: mountPath: /config/application_config.yaml subPath: {{ include "binoculars.config.filename" . }} readOnly: true + {{- if .Values.applicationConfig.grpc.tls.enabled }} + - name: tls-certs + mountPath: /certs + readOnly: true + {{- end }} {{- if .Values.additionalVolumeMounts }} {{- toYaml .Values.additionalVolumeMounts | nindent 12 -}} {{- end }} @@ -94,6 +99,11 @@ spec: - name: user-config secret: secretName: {{ include "binoculars.config.name" . }} + {{- if .Values.applicationConfig.grpc.tls.enabled }} + - name: tls-certs + secret: + secretName: binoculars-service-tls + {{- end }} {{- if .Values.additionalVolumes }} {{- toYaml .Values.additionalVolumes | nindent 8 }} {{- end }} diff --git a/deployment/binoculars/templates/ingress.yaml b/deployment/binoculars/templates/ingress.yaml index 13b71aa97d2..61763ab06bc 100644 --- a/deployment/binoculars/templates/ingress.yaml +++ b/deployment/binoculars/templates/ingress.yaml @@ -6,7 +6,12 @@ metadata: annotations: kubernetes.io/ingress.class: {{ required "A value is required for .Values.ingressClass" .Values.ingressClass }} nginx.ingress.kubernetes.io/ssl-redirect: "true" + {{- if .Values.applicationConfig.grpc.tls.enabled }} + nginx.ingress.kubernetes.io/backend-protocol: "GRPCS" + nginx.ingress.kubernetes.io/ssl-passthrough: "true" + {{- else }} nginx.ingress.kubernetes.io/backend-protocol: "GRPC" + {{- end }} certmanager.k8s.io/cluster-issuer: {{ required "A value is required for .Values.clusterIssuer" .Values.clusterIssuer }} cert-manager.io/cluster-issuer: {{ required "A value is required for .Values.clusterIssuer" .Values.clusterIssuer }} {{- if .Values.ingress.annotations }} diff --git a/deployment/binoculars/templates/ingressrest.yaml b/deployment/binoculars/templates/ingressrest.yaml index e1ebe1cfa3b..bb3f1b2b26f 100644 --- a/deployment/binoculars/templates/ingressrest.yaml +++ b/deployment/binoculars/templates/ingressrest.yaml @@ -6,6 +6,10 @@ metadata: annotations: kubernetes.io/ingress.class: {{ required "A value is required for .Values.ingressClass" .Values.ingressClass }} nginx.ingress.kubernetes.io/ssl-redirect: "true" + {{- if .Values.applicationConfig.grpc.tls.enabled }} + nginx.ingress.kubernetes.io/backend-protocol: "HTTPS" + nginx.ingress.kubernetes.io/ssl-passthrough: "true" + {{- end }} certmanager.k8s.io/cluster-issuer: {{ required "A value is required for .Values.clusterIssuer" .Values.clusterIssuer }} cert-manager.io/cluster-issuer: {{ required "A value is required for .Values.clusterIssuer" .Values.clusterIssuer }} nginx.ingress.kubernetes.io/rewrite-target: /$2 diff --git a/deployment/binoculars/values.yaml b/deployment/binoculars/values.yaml index e390d409750..95d5b09c957 100644 --- a/deployment/binoculars/values.yaml +++ b/deployment/binoculars/values.yaml @@ -9,7 +9,7 @@ resources: memory: 512Mi cpu: 200m # -- Tolerations -tolerations: [] +tolerations: [] additionalLabels: {} additionalClusterRoleBindings: [] additionalVolumeMounts: [] @@ -32,5 +32,10 @@ serviceAccount: null applicationConfig: grpcPort: 50051 + grpc: + tls: + enabled: false + certPath: /certs/tls.crt + keyPath: /certs/tls.key httpPort: 8080 metricsPort: 9000 diff --git a/deployment/jobservice/templates/ingress.yaml b/deployment/jobservice/templates/ingress.yaml index cf3c74061f1..cd698029ddd 100644 --- a/deployment/jobservice/templates/ingress.yaml +++ b/deployment/jobservice/templates/ingress.yaml @@ -6,7 +6,12 @@ metadata: annotations: kubernetes.io/ingress.class: {{ required "A value is required for .Values.ingressClass" .Values.ingressClass }} nginx.ingress.kubernetes.io/ssl-redirect: "true" + {{- if .Values.applicationConfig.grpc.tls.enabled }} + nginx.ingress.kubernetes.io/backend-protocol: "GRPCS" + nginx.ingress.kubernetes.io/ssl-passthrough: "true" + {{- else }} nginx.ingress.kubernetes.io/backend-protocol: "GRPC" + {{- end }} certmanager.k8s.io/cluster-issuer: {{ required "A value is required for .Values.clusterIssuer" .Values.clusterIssuer }} cert-manager.io/cluster-issuer: {{ required "A value is required for .Values.clusterIssuer" .Values.clusterIssuer }} {{- if .Values.ingress.annotations }} diff --git a/deployment/jobservice/templates/statefulset.yaml b/deployment/jobservice/templates/statefulset.yaml index fc67feefddb..2519f1b69c0 100644 --- a/deployment/jobservice/templates/statefulset.yaml +++ b/deployment/jobservice/templates/statefulset.yaml @@ -49,6 +49,11 @@ spec: mountPath: /config/application_config.yaml subPath: {{ include "jobservice.config.filename" . }} readOnly: true + {{- if .Values.applicationConfig.grpc.tls.enabled }} + - name: tls-certs + mountPath: /certs + readOnly: true + {{- end }} {{- if .Values.additionalVolumeMounts }} {{- toYaml .Values.additionalVolumeMounts | nindent 12 -}} {{- end }} @@ -71,6 +76,11 @@ spec: - name: user-config secret: secretName: {{ include "jobservice.config.name" . }} + {{- if .Values.applicationConfig.grpc.tls.enabled }} + - name: tls-certs + secret: + secretName: jobservice-service-tls + {{- end }} {{- if .Values.additionalVolumes }} {{- toYaml .Values.additionalVolumes | nindent 8 }} {{- end }} diff --git a/deployment/jobservice/values.yaml b/deployment/jobservice/values.yaml index 56d801d7f6c..27ce94010b2 100644 --- a/deployment/jobservice/values.yaml +++ b/deployment/jobservice/values.yaml @@ -9,7 +9,7 @@ resources: memory: 512Mi cpu: 200m # -- Tolerations -tolerations: [] +tolerations: [] additionalLabels: {} terminationGracePeriodSeconds: 30 replicas: 1 @@ -30,3 +30,8 @@ serviceAccount: null applicationConfig: grpcPort: 60063 + grpc: + tls: + enabled: false + certPath: /certs/tls.crt + keyPath: /certs/tls.key diff --git a/deployment/scheduler/templates/scheduler-ingress.yaml b/deployment/scheduler/templates/scheduler-ingress.yaml index c7b1b48781f..21f20ae632d 100644 --- a/deployment/scheduler/templates/scheduler-ingress.yaml +++ b/deployment/scheduler/templates/scheduler-ingress.yaml @@ -6,7 +6,12 @@ metadata: annotations: kubernetes.io/ingress.class: {{ required "A value is required for .Values.scheduler.ingressClass" .Values.scheduler.ingressClass }} nginx.ingress.kubernetes.io/ssl-redirect: "true" + {{- if .Values.scheduler.applicationConfig.grpc.tls.enabled }} + nginx.ingress.kubernetes.io/backend-protocol: "GRPCS" + nginx.ingress.kubernetes.io/ssl-passthrough: "true" + {{- else }} nginx.ingress.kubernetes.io/backend-protocol: "GRPC" + {{- end }} certmanager.k8s.io/cluster-issuer: {{ required "A value is required for .Values.scheduler.clusterIssuer" .Values.scheduler.clusterIssuer }} cert-manager.io/cluster-issuer: {{ required "A value is required for .Values.scheduler.clusterIssuer" .Values.scheduler.clusterIssuer }} {{- if .Values.scheduler.ingress.annotations }} diff --git a/deployment/scheduler/templates/scheduler-statefulset.yaml b/deployment/scheduler/templates/scheduler-statefulset.yaml index b2293104843..08c3d7edf1c 100644 --- a/deployment/scheduler/templates/scheduler-statefulset.yaml +++ b/deployment/scheduler/templates/scheduler-statefulset.yaml @@ -83,6 +83,11 @@ spec: mountPath: "/pulsar/ca" readOnly: true {{- end }} + {{- if .Values.scheduler.applicationConfig.grpc.tls.enabled }} + - name: tls-certs + mountPath: /certs + readOnly: true + {{- end }} {{- if .Values.scheduler.additionalVolumeMounts }} {{- toYaml .Values.scheduler.additionalVolumeMounts | nindent 12 -}} {{- end }} @@ -108,6 +113,11 @@ spec: - {{ include "armada-scheduler.name" . }} topologyKey: kubernetes.io/hostname volumes: + {{- if .Values.scheduler.applicationConfig.grpc.tls.enabled }} + - name: tls-certs + secret: + secretName: armada-scheduler-service-tls + {{- end}} - name: user-config secret: secretName: {{ include "armada-scheduler.config.name" . }} diff --git a/deployment/scheduler/values.yaml b/deployment/scheduler/values.yaml index d6ead90ea66..606f8355e28 100644 --- a/deployment/scheduler/values.yaml +++ b/deployment/scheduler/values.yaml @@ -17,6 +17,10 @@ scheduler: applicationConfig: grpc: port: 50051 + tls: + enabled: false + certPath: /certs/tls.crt + keyPath: /certs/tls.key metrics: port: 9001 http: diff --git a/internal/armada/server.go b/internal/armada/server.go index a61f9402c20..64ae58a0eb3 100644 --- a/internal/armada/server.go +++ b/internal/armada/server.go @@ -77,7 +77,7 @@ func Serve(ctx context.Context, config *configuration.ArmadaConfig, healthChecks if err != nil { return err } - grpcServer := grpcCommon.CreateGrpcServer(config.Grpc.KeepaliveParams, config.Grpc.KeepaliveEnforcementPolicy, authServices) + grpcServer := grpcCommon.CreateGrpcServer(config.Grpc.KeepaliveParams, config.Grpc.KeepaliveEnforcementPolicy, authServices, config.Grpc.Tls) // Shut down grpcServer if the context is cancelled. // Give the server 5 seconds to shut down gracefully. diff --git a/internal/binoculars/server.go b/internal/binoculars/server.go index d2db416fb85..07ded516fa9 100644 --- a/internal/binoculars/server.go +++ b/internal/binoculars/server.go @@ -37,7 +37,7 @@ func StartUp(config *configuration.BinocularsConfig) (func(), *sync.WaitGroup) { os.Exit(-1) } - grpcServer := grpcCommon.CreateGrpcServer(config.Grpc.KeepaliveParams, config.Grpc.KeepaliveEnforcementPolicy, authServices) + grpcServer := grpcCommon.CreateGrpcServer(config.Grpc.KeepaliveParams, config.Grpc.KeepaliveEnforcementPolicy, authServices, config.Grpc.Tls) permissionsChecker := authorization.NewPrincipalPermissionChecker( config.Auth.PermissionGroupMapping, diff --git a/internal/common/certs/cached_certificate.go b/internal/common/certs/cached_certificate.go new file mode 100644 index 00000000000..2588d0f5b50 --- /dev/null +++ b/internal/common/certs/cached_certificate.go @@ -0,0 +1,121 @@ +package certs + +import ( + "context" + "crypto/tls" + "os" + "sync" + "time" + + log "github.com/sirupsen/logrus" +) + +type CachedCertificateService struct { + certPath string + keyPath string + + fileInfoLock sync.Mutex + certFileInfo os.FileInfo + keyFileInfo os.FileInfo + + certificateLock sync.Mutex + certificate *tls.Certificate + + refreshInterval time.Duration +} + +func NewCachedCertificateService(certPath string, keyPath string, refreshInternal time.Duration) *CachedCertificateService { + cert := &CachedCertificateService{ + certPath: certPath, + keyPath: keyPath, + certificateLock: sync.Mutex{}, + fileInfoLock: sync.Mutex{}, + refreshInterval: refreshInternal, + } + // Initialise the certificate + err := cert.refresh() + if err != nil { + panic(err) + } + return cert +} + +func (c *CachedCertificateService) GetCertificate() *tls.Certificate { + c.certificateLock.Lock() + defer c.certificateLock.Unlock() + return c.certificate +} + +func (c *CachedCertificateService) updateCertificate(certificate *tls.Certificate) { + c.certificateLock.Lock() + defer c.certificateLock.Unlock() + c.certificate = certificate +} + +func (c *CachedCertificateService) Run(ctx context.Context) { + ticker := time.NewTicker(c.refreshInterval) + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + err := c.refresh() + if err != nil { + log.WithError(err).Errorf("failed refreshing certificate from files cert: %s key: %s", c.certPath, c.keyPath) + } + } + } +} + +func (c *CachedCertificateService) refresh() error { + updatedCertFileInfo, err := os.Stat(c.certPath) + if err != nil { + return err + } + + updatedKeyFileInfo, err := os.Stat(c.keyPath) + if err != nil { + return err + } + + modified := false + + if c.certFileInfo == nil || updatedCertFileInfo.ModTime().After(c.certFileInfo.ModTime()) { + modified = true + } + + if c.keyFileInfo == nil || updatedKeyFileInfo.ModTime().After(c.keyFileInfo.ModTime()) { + modified = true + } + + if modified { + log.Infof("refreshing certificate from files cert: %s key: %s", c.certPath, c.keyPath) + certFileData, err := os.ReadFile(c.certPath) + if err != nil { + return err + } + + keyFileData, err := os.ReadFile(c.keyPath) + if err != nil { + return err + } + + cert, err := tls.X509KeyPair(certFileData, keyFileData) + if err != nil { + return err + } + + c.updateData(updatedCertFileInfo, updatedKeyFileInfo, &cert) + } + + return nil +} + +func (c *CachedCertificateService) updateData(certFileInfo os.FileInfo, keyFileInfo os.FileInfo, newCert *tls.Certificate) { + c.fileInfoLock.Lock() + defer c.fileInfoLock.Unlock() + c.certFileInfo = certFileInfo + c.keyFileInfo = keyFileInfo + + c.updateCertificate(newCert) +} diff --git a/internal/common/certs/cached_certificate_test.go b/internal/common/certs/cached_certificate_test.go new file mode 100644 index 00000000000..651794a543a --- /dev/null +++ b/internal/common/certs/cached_certificate_test.go @@ -0,0 +1,194 @@ +package certs + +import ( + "bytes" + "context" + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "encoding/pem" + "math/big" + "net" + "os" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +const ( + certFilePath = "testdata/tls.crt" + keyFilePath = "testdata/tls.key" +) + +func TestCachedCertificateService_LoadsCertificateOnStartup(t *testing.T) { + defer cleanup() + cert, certData, keyData := createCerts(t) + writeCerts(t, certData, keyData) + + cachedCertService := NewCachedCertificateService(certFilePath, keyFilePath, time.Second) + + result := cachedCertService.GetCertificate() + + assert.Equal(t, cert, result) +} + +func TestCachedCertificateService_PanicIfInitialLoadFails(t *testing.T) { + defer cleanup() + + assert.Panics(t, func() { NewCachedCertificateService(certFilePath, keyFilePath, time.Second) }) +} + +func TestCachedCertificateService_ReloadsCert_IfFileOnDiskChanges(t *testing.T) { + defer cleanup() + cert, certData, keyData := createCerts(t) + writeCerts(t, certData, keyData) + cachedCertService := NewCachedCertificateService(certFilePath, keyFilePath, time.Second) + + assert.Equal(t, cert, cachedCertService.GetCertificate()) + + newCert, certData, keyData := createCerts(t) + + // Update files on disk + writeCerts(t, certData, keyData) + // Certificate won't change until refresh is called + assert.NotEqual(t, newCert, cachedCertService.GetCertificate()) + + err := cachedCertService.refresh() + assert.NoError(t, err) + assert.Equal(t, newCert, cachedCertService.GetCertificate()) +} + +func TestCachedCertificateService_HandlesPartialUpdates(t *testing.T) { + defer cleanup() + originalCert, certData, keyData := createCerts(t) + writeCerts(t, certData, keyData) + cachedCertService := NewCachedCertificateService(certFilePath, keyFilePath, time.Second) + + assert.Equal(t, originalCert, cachedCertService.GetCertificate()) + + newCert, certData, keyData := createCerts(t) + + // Update only 1 file on disk - which leaves the representation on disk in an invalid state + writeCerts(t, certData, nil) + err := cachedCertService.refresh() + assert.Error(t, err) + + // Certificate provided should not change, as there is no valid new cert yet + assert.Equal(t, originalCert, cachedCertService.GetCertificate()) + + // Update the other file, so now files on disk are now both updated and consistent + writeCerts(t, nil, keyData) + err = cachedCertService.refresh() + assert.NoError(t, err) + + assert.Equal(t, newCert, cachedCertService.GetCertificate()) +} + +func TestCachedCertificateService_ReloadsCertPeriodically_WhenUsingRun(t *testing.T) { + defer cleanup() + cert, certData, keyData := createCerts(t) + writeCerts(t, certData, keyData) + cachedCertService := NewCachedCertificateService(certFilePath, keyFilePath, time.Second) + assert.Equal(t, cert, cachedCertService.GetCertificate()) + + go func() { + cachedCertService.Run(context.Background()) + }() + + newCert, certData, keyData := createCerts(t) + writeCerts(t, certData, keyData) + time.Sleep(time.Second * 2) + assert.Equal(t, newCert, cachedCertService.GetCertificate()) +} + +func writeCerts(t *testing.T, certData *bytes.Buffer, keyData *bytes.Buffer) { + if certData != nil { + err := os.WriteFile(certFilePath, certData.Bytes(), 0o644) + require.NoError(t, err) + } + + if keyData != nil { + err := os.WriteFile(keyFilePath, keyData.Bytes(), 0o644) + require.NoError(t, err) + } +} + +func cleanup() { + os.Remove(certFilePath) + os.Remove(keyFilePath) +} + +func createCerts(t *testing.T) (*tls.Certificate, *bytes.Buffer, *bytes.Buffer) { + // set up our CA certificate + ca := &x509.Certificate{ + SerialNumber: big.NewInt(2019), + NotBefore: time.Now(), + NotAfter: time.Now().AddDate(10, 0, 0), + IsCA: true, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, + BasicConstraintsValid: true, + } + + // create our private and public key + caPrivKey, err := rsa.GenerateKey(rand.Reader, 1024) + require.NoError(t, err) + + // create the CA + caBytes, err := x509.CreateCertificate(rand.Reader, ca, ca, &caPrivKey.PublicKey, caPrivKey) + require.NoError(t, err) + + // pem encode + caPEM := new(bytes.Buffer) + err = pem.Encode(caPEM, &pem.Block{ + Type: "CERTIFICATE", + Bytes: caBytes, + }) + require.NoError(t, err) + + caPrivKeyPEM := new(bytes.Buffer) + err = pem.Encode(caPrivKeyPEM, &pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: x509.MarshalPKCS1PrivateKey(caPrivKey), + }) + require.NoError(t, err) + + // set up our server certificate + cert := &x509.Certificate{ + SerialNumber: big.NewInt(2019), + IPAddresses: []net.IP{net.IPv4(127, 0, 0, 1), net.IPv6loopback}, + NotBefore: time.Now(), + NotAfter: time.Now().AddDate(10, 0, 0), + SubjectKeyId: []byte{1, 2, 3, 4, 6}, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, + KeyUsage: x509.KeyUsageDigitalSignature, + } + + certPrivKey, err := rsa.GenerateKey(rand.Reader, 1024) + require.NoError(t, err) + + certBytes, err := x509.CreateCertificate(rand.Reader, cert, ca, &certPrivKey.PublicKey, caPrivKey) + require.NoError(t, err) + + certPEM := new(bytes.Buffer) + err = pem.Encode(certPEM, &pem.Block{ + Type: "CERTIFICATE", + Bytes: certBytes, + }) + require.NoError(t, err) + + certPrivKeyPEM := new(bytes.Buffer) + err = pem.Encode(certPrivKeyPEM, &pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: x509.MarshalPKCS1PrivateKey(certPrivKey), + }) + require.NoError(t, err) + + certificate, err := tls.X509KeyPair(certPEM.Bytes(), certPrivKeyPEM.Bytes()) + require.NoError(t, err) + + return &certificate, certPEM, certPrivKeyPEM +} diff --git a/internal/common/grpc/configuration/types.go b/internal/common/grpc/configuration/types.go index 489403ad667..8b4f1857300 100644 --- a/internal/common/grpc/configuration/types.go +++ b/internal/common/grpc/configuration/types.go @@ -1,14 +1,23 @@ package configuration -import "google.golang.org/grpc/keepalive" +import ( + "google.golang.org/grpc/keepalive" +) type GrpcConfig struct { Port int `validate:"required"` KeepaliveParams keepalive.ServerParameters KeepaliveEnforcementPolicy keepalive.EnforcementPolicy + Tls TlsConfig } type GrpcPoolConfig struct { InitialConnections int Capacity int } + +type TlsConfig struct { + Enabled bool + KeyPath string + CertPath string +} diff --git a/internal/common/grpc/grpc.go b/internal/common/grpc/grpc.go index 89ed83e398f..5f73c3801c0 100644 --- a/internal/common/grpc/grpc.go +++ b/internal/common/grpc/grpc.go @@ -2,6 +2,7 @@ package grpc import ( "context" + "crypto/tls" "fmt" "net" "runtime/debug" @@ -17,13 +18,15 @@ import ( log "github.com/sirupsen/logrus" "google.golang.org/grpc" "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials" _ "google.golang.org/grpc/encoding/gzip" "google.golang.org/grpc/keepalive" "google.golang.org/grpc/status" "github.com/armadaproject/armada/internal/common/armadaerrors" "github.com/armadaproject/armada/internal/common/auth/authorization" - "github.com/armadaproject/armada/internal/common/logging" + "github.com/armadaproject/armada/internal/common/certs" + "github.com/armadaproject/armada/internal/common/grpc/configuration" "github.com/armadaproject/armada/internal/common/requestid" ) @@ -33,6 +36,7 @@ func CreateGrpcServer( keepaliveParams keepalive.ServerParameters, keepaliveEnforcementPolicy keepalive.EnforcementPolicy, authServices []authorization.AuthService, + tlsConfig configuration.TlsConfig, ) *grpc.Server { // Logging, authentication, etc. are implemented via gRPC interceptors // (i.e., via functions that are called before handling the actual request). @@ -57,14 +61,12 @@ func CreateGrpcServer( requestid.UnaryServerInterceptor(false), armadaerrors.UnaryServerInterceptor(2000), grpc_logrus.UnaryServerInterceptor(messageDefault), - logging.UnaryServerInterceptor(), ) streamInterceptors = append(streamInterceptors, grpc_ctxtags.StreamServerInterceptor(tagsExtractor), requestid.StreamServerInterceptor(false), armadaerrors.StreamServerInterceptor(2000), grpc_logrus.StreamServerInterceptor(messageDefault), - logging.StreamServerInterceptor(), ) // Authentication @@ -79,13 +81,32 @@ func CreateGrpcServer( unaryInterceptors = append(unaryInterceptors, grpc_prometheus.UnaryServerInterceptor) streamInterceptors = append(streamInterceptors, grpc_prometheus.StreamServerInterceptor) - // Interceptors are registered at server creation - return grpc.NewServer( + serverOptions := []grpc.ServerOption{ grpc.KeepaliveParams(keepaliveParams), grpc.KeepaliveEnforcementPolicy(keepaliveEnforcementPolicy), grpc.StreamInterceptor(grpc_middleware.ChainStreamServer(streamInterceptors...)), grpc.UnaryInterceptor(grpc_middleware.ChainUnaryServer(unaryInterceptors...)), - ) + } + + if tlsConfig.Enabled { + cachedCertificateService := certs.NewCachedCertificateService(tlsConfig.CertPath, tlsConfig.KeyPath, time.Minute) + go func() { + cachedCertificateService.Run(context.Background()) + }() + tlsCreds := credentials.NewTLS(&tls.Config{ + GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { + cert := cachedCertificateService.GetCertificate() + if cert == nil { + return nil, fmt.Errorf("unexpectedly received nil from certificate cache") + } + return cert, nil + }, + }) + serverOptions = append(serverOptions, grpc.Creds(tlsCreds)) + } + + // Interceptors are registered at server creation + return grpc.NewServer(serverOptions...) } // TODO We don't need this function. Just do this at the caller. diff --git a/internal/jobservice/application.go b/internal/jobservice/application.go index c5c172d6ba1..11efc3db63a 100644 --- a/internal/jobservice/application.go +++ b/internal/jobservice/application.go @@ -93,6 +93,7 @@ func (a *App) StartUp(ctx context.Context, config *configuration.JobServiceConfi config.Grpc.KeepaliveParams, config.Grpc.KeepaliveEnforcementPolicy, []authorization.AuthService{&authorization.AnonymousAuthService{}}, + config.Grpc.Tls, ) err, sqlJobRepo, dbCallbackFn := repository.NewSQLJobService(config, log) diff --git a/internal/lookout/application.go b/internal/lookout/application.go index 31da0fd8e8f..4bf68458faa 100644 --- a/internal/lookout/application.go +++ b/internal/lookout/application.go @@ -33,6 +33,7 @@ func StartUp(config configuration.LookoutConfiguration, healthChecks *health.Mul config.Grpc.KeepaliveParams, config.Grpc.KeepaliveEnforcementPolicy, []authorization.AuthService{&authorization.AnonymousAuthService{}}, + config.Grpc.Tls, ) db, err := postgres.Open(config.Postgres) diff --git a/internal/scheduler/schedulerapp.go b/internal/scheduler/schedulerapp.go index bcdd896ff00..83ed80abde5 100644 --- a/internal/scheduler/schedulerapp.go +++ b/internal/scheduler/schedulerapp.go @@ -125,7 +125,7 @@ func Run(config schedulerconfig.Configuration) error { if err != nil { return errors.WithMessage(err, "error creating auth services") } - grpcServer := grpcCommon.CreateGrpcServer(config.Grpc.KeepaliveParams, config.Grpc.KeepaliveEnforcementPolicy, authServices) + grpcServer := grpcCommon.CreateGrpcServer(config.Grpc.KeepaliveParams, config.Grpc.KeepaliveEnforcementPolicy, authServices, config.Grpc.Tls) defer grpcServer.GracefulStop() lis, err := net.Listen("tcp", fmt.Sprintf(":%d", config.Grpc.Port)) if err != nil {