From 02b1817e00732b62283c8838fa28f233d3ad6697 Mon Sep 17 00:00:00 2001 From: JamesMurkin Date: Mon, 10 Jul 2023 13:39:54 +0100 Subject: [PATCH] write tests --- internal/armada/server.go | 6 +- internal/binoculars/server.go | 6 +- .../cached_certificate.go | 10 +- .../common/certs/cached_certificate_test.go | 180 ++++++++++++++++++ internal/common/grpc/grpc.go | 4 +- internal/jobservice/application.go | 6 +- internal/lookout/application.go | 6 +- internal/scheduler/schedulerapp.go | 6 +- 8 files changed, 204 insertions(+), 20 deletions(-) rename internal/common/{fileutils => certs}/cached_certificate.go (92%) create mode 100644 internal/common/certs/cached_certificate_test.go diff --git a/internal/armada/server.go b/internal/armada/server.go index f37366e658b..70e8176e314 100644 --- a/internal/armada/server.go +++ b/internal/armada/server.go @@ -7,6 +7,7 @@ import ( "time" "github.com/apache/pulsar-client-go/pulsar" + "github.com/armadaproject/armada/internal/common/certs" "github.com/go-redis/redis" "github.com/google/uuid" grpc_prometheus "github.com/grpc-ecosystem/go-grpc-prometheus" @@ -25,7 +26,6 @@ import ( "github.com/armadaproject/armada/internal/common/auth" "github.com/armadaproject/armada/internal/common/auth/authorization" "github.com/armadaproject/armada/internal/common/database" - "github.com/armadaproject/armada/internal/common/fileutils" grpcCommon "github.com/armadaproject/armada/internal/common/grpc" "github.com/armadaproject/armada/internal/common/health" commonmetrics "github.com/armadaproject/armada/internal/common/metrics" @@ -78,9 +78,9 @@ func Serve(ctx context.Context, config *configuration.ArmadaConfig, healthChecks if err != nil { return err } - var cachedCertificateService *fileutils.CachedCertificateService + var cachedCertificateService *certs.CachedCertificateService if config.Grpc.Tls.Enabled { - cachedCertificateService = fileutils.NewCachedCertificateService(config.Grpc.Tls.CertPath, config.Grpc.Tls.KeyPath) + cachedCertificateService = certs.NewCachedCertificateService(config.Grpc.Tls.CertPath, config.Grpc.Tls.KeyPath) services = append(services, func() error { return cachedCertificateService.Run(ctx) }) diff --git a/internal/binoculars/server.go b/internal/binoculars/server.go index b367b770d57..62c10321afb 100644 --- a/internal/binoculars/server.go +++ b/internal/binoculars/server.go @@ -5,6 +5,7 @@ import ( "os" "sync" + "github.com/armadaproject/armada/internal/common/certs" grpc_prometheus "github.com/grpc-ecosystem/go-grpc-prometheus" log "github.com/sirupsen/logrus" @@ -14,7 +15,6 @@ import ( "github.com/armadaproject/armada/internal/common/auth" "github.com/armadaproject/armada/internal/common/auth/authorization" "github.com/armadaproject/armada/internal/common/cluster" - "github.com/armadaproject/armada/internal/common/fileutils" grpcCommon "github.com/armadaproject/armada/internal/common/grpc" "github.com/armadaproject/armada/pkg/api/binoculars" ) @@ -39,9 +39,9 @@ func StartUp(config *configuration.BinocularsConfig) (func(), *sync.WaitGroup) { os.Exit(-1) } - var cachedCertificateService *fileutils.CachedCertificateService + var cachedCertificateService *certs.CachedCertificateService if config.Grpc.Tls.Enabled { - cachedCertificateService = fileutils.NewCachedCertificateService(config.Grpc.Tls.CertPath, config.Grpc.Tls.KeyPath) + cachedCertificateService = certs.NewCachedCertificateService(config.Grpc.Tls.CertPath, config.Grpc.Tls.KeyPath) go func() { err := func() error { return cachedCertificateService.Run(context.Background()) diff --git a/internal/common/fileutils/cached_certificate.go b/internal/common/certs/cached_certificate.go similarity index 92% rename from internal/common/fileutils/cached_certificate.go rename to internal/common/certs/cached_certificate.go index c02cd9b7c19..41de764d307 100644 --- a/internal/common/fileutils/cached_certificate.go +++ b/internal/common/certs/cached_certificate.go @@ -1,4 +1,4 @@ -package fileutils +package certs import ( "context" @@ -20,6 +20,8 @@ type CachedCertificateService struct { certificateLock sync.Mutex certificate *tls.Certificate + + refreshInterval time.Duration } func NewCachedCertificateService(certPath string, keyPath string) *CachedCertificateService { @@ -27,6 +29,8 @@ func NewCachedCertificateService(certPath string, keyPath string) *CachedCertifi certPath: certPath, keyPath: keyPath, certificateLock: sync.Mutex{}, + fileInfoLock: sync.Mutex{}, + refreshInterval: time.Minute, } // Initialise the certificate err := cert.refresh() @@ -49,7 +53,7 @@ func (c *CachedCertificateService) updateCertificate(certificate *tls.Certificat } func (c *CachedCertificateService) Run(ctx context.Context) error { - ticker := time.NewTicker(1 * time.Minute) + ticker := time.NewTicker(c.refreshInterval) for { select { case <-ctx.Done(): @@ -106,7 +110,7 @@ func (c *CachedCertificateService) refresh() error { func (c *CachedCertificateService) updateData(certFileInfo os.FileInfo, keyFileInfo os.FileInfo, newCert *tls.Certificate) { c.fileInfoLock.Lock() - defer c.certificateLock.Lock() + defer c.fileInfoLock.Unlock() c.certFileInfo = certFileInfo c.keyFileInfo = keyFileInfo diff --git a/internal/common/certs/cached_certificate_test.go b/internal/common/certs/cached_certificate_test.go new file mode 100644 index 00000000000..97aa8393257 --- /dev/null +++ b/internal/common/certs/cached_certificate_test.go @@ -0,0 +1,180 @@ +package certs + +import ( + "bytes" + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "encoding/pem" + "math/big" + "net" + "os" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +const certFilePath = "testdata/tls.crt" +const keyFilePath = "testdata/tls.key" + +func TestCachedCertificateService_LoadsCertificateOnStartup(t *testing.T) { + defer cleanup() + cert, certData, keyData := createCerts() + writeCerts(t, certData, keyData) + + cachedCertService := NewCachedCertificateService(certFilePath, keyFilePath) + + result := cachedCertService.GetCertificate() + + assert.Equal(t, cert, result) +} + +func TestCachedCertificateService_PanicIfInitialLoadFails(t *testing.T) { + defer cleanup() + + assert.Panics(t, func() { NewCachedCertificateService(certFilePath, keyFilePath) }) +} + +func TestCachedCertificateService_ReloadsCert_IfFileOnDiskChanges(t *testing.T) { + defer cleanup() + cert, certData, keyData := createCerts() + writeCerts(t, certData, keyData) + cachedCertService := NewCachedCertificateService(certFilePath, keyFilePath) + + assert.Equal(t, cert, cachedCertService.GetCertificate()) + + newCert, certData, keyData := createCerts() + + // Update files on disk + writeCerts(t, certData, keyData) + // Certificate won't change until refresh is called + assert.NotEqual(t, newCert, cachedCertService.GetCertificate()) + + err := cachedCertService.refresh() + assert.NoError(t, err) + assert.Equal(t, newCert, cachedCertService.GetCertificate()) +} + +func TestCachedCertificateService_HandlesPartialUpdates(t *testing.T) { + defer cleanup() + originalCert, certData, keyData := createCerts() + writeCerts(t, certData, keyData) + cachedCertService := NewCachedCertificateService(certFilePath, keyFilePath) + + assert.Equal(t, originalCert, cachedCertService.GetCertificate()) + + newCert, certData, keyData := createCerts() + + // Update only 1 file on disk - which leaves the representation on disk in an invalid state + writeCerts(t, certData, nil) + err := cachedCertService.refresh() + assert.Error(t, err) + + // Certificate provided should not change, as there is no valid new cert yet + assert.Equal(t, originalCert, cachedCertService.GetCertificate()) + + // Update the other file, so now files on disk are now both updated and consistent + writeCerts(t, nil, keyData) + err = cachedCertService.refresh() + assert.NoError(t, err) + + assert.Equal(t, newCert, cachedCertService.GetCertificate()) +} + +func writeCerts(t *testing.T, certData *bytes.Buffer, keyData *bytes.Buffer) { + if certData != nil { + err := os.WriteFile(certFilePath, certData.Bytes(), 0644) + require.NoError(t, err) + } + + if keyData != nil { + err := os.WriteFile(keyFilePath, keyData.Bytes(), 0644) + require.NoError(t, err) + } +} + +func cleanup() { + os.Remove(certFilePath) + os.Remove(keyFilePath) +} + +func createCerts() (*tls.Certificate, *bytes.Buffer, *bytes.Buffer) { + // set up our CA certificate + ca := &x509.Certificate{ + SerialNumber: big.NewInt(2019), + NotBefore: time.Now(), + NotAfter: time.Now().AddDate(10, 0, 0), + IsCA: true, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, + BasicConstraintsValid: true, + } + + // create our private and public key + caPrivKey, err := rsa.GenerateKey(rand.Reader, 1024) + if err != nil { + panic(err) + } + + // create the CA + caBytes, err := x509.CreateCertificate(rand.Reader, ca, ca, &caPrivKey.PublicKey, caPrivKey) + if err != nil { + panic(err) + } + + // pem encode + caPEM := new(bytes.Buffer) + pem.Encode(caPEM, &pem.Block{ + Type: "CERTIFICATE", + Bytes: caBytes, + }) + + caPrivKeyPEM := new(bytes.Buffer) + pem.Encode(caPrivKeyPEM, &pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: x509.MarshalPKCS1PrivateKey(caPrivKey), + }) + + // set up our server certificate + cert := &x509.Certificate{ + SerialNumber: big.NewInt(2019), + IPAddresses: []net.IP{net.IPv4(127, 0, 0, 1), net.IPv6loopback}, + NotBefore: time.Now(), + NotAfter: time.Now().AddDate(10, 0, 0), + SubjectKeyId: []byte{1, 2, 3, 4, 6}, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, + KeyUsage: x509.KeyUsageDigitalSignature, + } + + certPrivKey, err := rsa.GenerateKey(rand.Reader, 1024) + if err != nil { + panic(err) + } + + certBytes, err := x509.CreateCertificate(rand.Reader, cert, ca, &certPrivKey.PublicKey, caPrivKey) + if err != nil { + panic(err) + } + + certPEM := new(bytes.Buffer) + pem.Encode(certPEM, &pem.Block{ + Type: "CERTIFICATE", + Bytes: certBytes, + }) + + certPrivKeyPEM := new(bytes.Buffer) + pem.Encode(certPrivKeyPEM, &pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: x509.MarshalPKCS1PrivateKey(certPrivKey), + }) + + certificate, err := tls.X509KeyPair(certPEM.Bytes(), certPrivKeyPEM.Bytes()) + if err != nil { + panic(err) + } + + return &certificate, certPEM, certPrivKeyPEM +} diff --git a/internal/common/grpc/grpc.go b/internal/common/grpc/grpc.go index 28bd61bf6d6..5f9564ae181 100644 --- a/internal/common/grpc/grpc.go +++ b/internal/common/grpc/grpc.go @@ -9,6 +9,7 @@ import ( "sync" "time" + "github.com/armadaproject/armada/internal/common/certs" grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" grpc_auth "github.com/grpc-ecosystem/go-grpc-middleware/auth" grpc_logrus "github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus" @@ -25,7 +26,6 @@ import ( "github.com/armadaproject/armada/internal/common/armadaerrors" "github.com/armadaproject/armada/internal/common/auth/authorization" - "github.com/armadaproject/armada/internal/common/fileutils" "github.com/armadaproject/armada/internal/common/requestid" ) @@ -35,7 +35,7 @@ func CreateGrpcServer( keepaliveParams keepalive.ServerParameters, keepaliveEnforcementPolicy keepalive.EnforcementPolicy, authServices []authorization.AuthService, - tlsCertService *fileutils.CachedCertificateService, + tlsCertService *certs.CachedCertificateService, ) *grpc.Server { // Logging, authentication, etc. are implemented via gRPC interceptors // (i.e., via functions that are called before handling the actual request). diff --git a/internal/jobservice/application.go b/internal/jobservice/application.go index 686bce1e05a..a71bbced7c2 100644 --- a/internal/jobservice/application.go +++ b/internal/jobservice/application.go @@ -7,12 +7,12 @@ import ( "os" "time" + "github.com/armadaproject/armada/internal/common/certs" log "github.com/sirupsen/logrus" "golang.org/x/sync/errgroup" "google.golang.org/grpc" "github.com/armadaproject/armada/internal/common/auth/authorization" - "github.com/armadaproject/armada/internal/common/fileutils" grpcCommon "github.com/armadaproject/armada/internal/common/grpc" grpcconfig "github.com/armadaproject/armada/internal/common/grpc/configuration" "github.com/armadaproject/armada/internal/common/grpc/grpcpool" @@ -90,9 +90,9 @@ func (a *App) StartUp(ctx context.Context, config *configuration.JobServiceConfi } log := log.WithField("JobService", "Startup") - var cachedCertificateService *fileutils.CachedCertificateService + var cachedCertificateService *certs.CachedCertificateService if config.Grpc.Tls.Enabled { - cachedCertificateService = fileutils.NewCachedCertificateService(config.Grpc.Tls.CertPath, config.Grpc.Tls.KeyPath) + cachedCertificateService = certs.NewCachedCertificateService(config.Grpc.Tls.CertPath, config.Grpc.Tls.KeyPath) g.Go(func() error { return cachedCertificateService.Run(ctx) }) diff --git a/internal/lookout/application.go b/internal/lookout/application.go index b8b66d48960..859cc70aa8f 100644 --- a/internal/lookout/application.go +++ b/internal/lookout/application.go @@ -4,12 +4,12 @@ import ( "context" "sync" + "github.com/armadaproject/armada/internal/common/certs" "github.com/doug-martin/goqu/v9" grpc_prometheus "github.com/grpc-ecosystem/go-grpc-prometheus" log "github.com/sirupsen/logrus" "github.com/armadaproject/armada/internal/common/auth/authorization" - "github.com/armadaproject/armada/internal/common/fileutils" "github.com/armadaproject/armada/internal/common/grpc" "github.com/armadaproject/armada/internal/common/health" "github.com/armadaproject/armada/internal/common/util" @@ -31,9 +31,9 @@ func StartUp(config configuration.LookoutConfiguration, healthChecks *health.Mul wg := &sync.WaitGroup{} wg.Add(1) - var cachedCertificateService *fileutils.CachedCertificateService + var cachedCertificateService *certs.CachedCertificateService if config.Grpc.Tls.Enabled { - cachedCertificateService = fileutils.NewCachedCertificateService(config.Grpc.Tls.CertPath, config.Grpc.Tls.KeyPath) + cachedCertificateService = certs.NewCachedCertificateService(config.Grpc.Tls.CertPath, config.Grpc.Tls.KeyPath) go func() { err := func() error { return cachedCertificateService.Run(context.Background()) diff --git a/internal/scheduler/schedulerapp.go b/internal/scheduler/schedulerapp.go index e00a4a7beeb..9f3f8b78d26 100644 --- a/internal/scheduler/schedulerapp.go +++ b/internal/scheduler/schedulerapp.go @@ -8,6 +8,7 @@ import ( "time" "github.com/apache/pulsar-client-go/pulsar" + "github.com/armadaproject/armada/internal/common/certs" "github.com/go-redis/redis" "github.com/google/uuid" "github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus/ctxlogrus" @@ -23,7 +24,6 @@ import ( "github.com/armadaproject/armada/internal/common/app" "github.com/armadaproject/armada/internal/common/auth" dbcommon "github.com/armadaproject/armada/internal/common/database" - "github.com/armadaproject/armada/internal/common/fileutils" grpcCommon "github.com/armadaproject/armada/internal/common/grpc" "github.com/armadaproject/armada/internal/common/health" "github.com/armadaproject/armada/internal/common/pulsarutils" @@ -126,9 +126,9 @@ func Run(config schedulerconfig.Configuration) error { if err != nil { return errors.WithMessage(err, "error creating auth services") } - var cachedCertificateService *fileutils.CachedCertificateService + var cachedCertificateService *certs.CachedCertificateService if config.Grpc.Tls.Enabled { - cachedCertificateService = fileutils.NewCachedCertificateService(config.Grpc.Tls.CertPath, config.Grpc.Tls.KeyPath) + cachedCertificateService = certs.NewCachedCertificateService(config.Grpc.Tls.CertPath, config.Grpc.Tls.KeyPath) services = append(services, func() error { return cachedCertificateService.Run(ctx) })