Skip to content

Commit

Permalink
write tests
Browse files Browse the repository at this point in the history
  • Loading branch information
JamesMurkin committed Jul 10, 2023
1 parent 787feb5 commit 02b1817
Show file tree
Hide file tree
Showing 8 changed files with 204 additions and 20 deletions.
6 changes: 3 additions & 3 deletions internal/armada/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"time"

"github.com/apache/pulsar-client-go/pulsar"
"github.com/armadaproject/armada/internal/common/certs"
"github.com/go-redis/redis"
"github.com/google/uuid"
grpc_prometheus "github.com/grpc-ecosystem/go-grpc-prometheus"
Expand All @@ -25,7 +26,6 @@ import (
"github.com/armadaproject/armada/internal/common/auth"
"github.com/armadaproject/armada/internal/common/auth/authorization"
"github.com/armadaproject/armada/internal/common/database"
"github.com/armadaproject/armada/internal/common/fileutils"
grpcCommon "github.com/armadaproject/armada/internal/common/grpc"
"github.com/armadaproject/armada/internal/common/health"
commonmetrics "github.com/armadaproject/armada/internal/common/metrics"
Expand Down Expand Up @@ -78,9 +78,9 @@ func Serve(ctx context.Context, config *configuration.ArmadaConfig, healthChecks
if err != nil {
return err
}
var cachedCertificateService *fileutils.CachedCertificateService
var cachedCertificateService *certs.CachedCertificateService
if config.Grpc.Tls.Enabled {
cachedCertificateService = fileutils.NewCachedCertificateService(config.Grpc.Tls.CertPath, config.Grpc.Tls.KeyPath)
cachedCertificateService = certs.NewCachedCertificateService(config.Grpc.Tls.CertPath, config.Grpc.Tls.KeyPath)
services = append(services, func() error {
return cachedCertificateService.Run(ctx)
})
Expand Down
6 changes: 3 additions & 3 deletions internal/binoculars/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"os"
"sync"

"github.com/armadaproject/armada/internal/common/certs"
grpc_prometheus "github.com/grpc-ecosystem/go-grpc-prometheus"
log "github.com/sirupsen/logrus"

Expand All @@ -14,7 +15,6 @@ import (
"github.com/armadaproject/armada/internal/common/auth"
"github.com/armadaproject/armada/internal/common/auth/authorization"
"github.com/armadaproject/armada/internal/common/cluster"
"github.com/armadaproject/armada/internal/common/fileutils"
grpcCommon "github.com/armadaproject/armada/internal/common/grpc"
"github.com/armadaproject/armada/pkg/api/binoculars"
)
Expand All @@ -39,9 +39,9 @@ func StartUp(config *configuration.BinocularsConfig) (func(), *sync.WaitGroup) {
os.Exit(-1)
}

var cachedCertificateService *fileutils.CachedCertificateService
var cachedCertificateService *certs.CachedCertificateService
if config.Grpc.Tls.Enabled {
cachedCertificateService = fileutils.NewCachedCertificateService(config.Grpc.Tls.CertPath, config.Grpc.Tls.KeyPath)
cachedCertificateService = certs.NewCachedCertificateService(config.Grpc.Tls.CertPath, config.Grpc.Tls.KeyPath)
go func() {
err := func() error {
return cachedCertificateService.Run(context.Background())
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package fileutils
package certs

import (
"context"
Expand All @@ -20,13 +20,17 @@ type CachedCertificateService struct {

certificateLock sync.Mutex
certificate *tls.Certificate

refreshInterval time.Duration
}

func NewCachedCertificateService(certPath string, keyPath string) *CachedCertificateService {
cert := &CachedCertificateService{
certPath: certPath,
keyPath: keyPath,
certificateLock: sync.Mutex{},
fileInfoLock: sync.Mutex{},
refreshInterval: time.Minute,
}
// Initialise the certificate
err := cert.refresh()
Expand All @@ -49,7 +53,7 @@ func (c *CachedCertificateService) updateCertificate(certificate *tls.Certificat
}

func (c *CachedCertificateService) Run(ctx context.Context) error {
ticker := time.NewTicker(1 * time.Minute)
ticker := time.NewTicker(c.refreshInterval)
for {
select {
case <-ctx.Done():
Expand Down Expand Up @@ -106,7 +110,7 @@ func (c *CachedCertificateService) refresh() error {

func (c *CachedCertificateService) updateData(certFileInfo os.FileInfo, keyFileInfo os.FileInfo, newCert *tls.Certificate) {
c.fileInfoLock.Lock()
defer c.certificateLock.Lock()
defer c.fileInfoLock.Unlock()
c.certFileInfo = certFileInfo
c.keyFileInfo = keyFileInfo

Expand Down
180 changes: 180 additions & 0 deletions internal/common/certs/cached_certificate_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
package certs

import (
"bytes"
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"encoding/pem"
"math/big"
"net"
"os"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

const certFilePath = "testdata/tls.crt"
const keyFilePath = "testdata/tls.key"

func TestCachedCertificateService_LoadsCertificateOnStartup(t *testing.T) {
defer cleanup()
cert, certData, keyData := createCerts()
writeCerts(t, certData, keyData)

cachedCertService := NewCachedCertificateService(certFilePath, keyFilePath)

result := cachedCertService.GetCertificate()

assert.Equal(t, cert, result)
}

func TestCachedCertificateService_PanicIfInitialLoadFails(t *testing.T) {
defer cleanup()

assert.Panics(t, func() { NewCachedCertificateService(certFilePath, keyFilePath) })
}

func TestCachedCertificateService_ReloadsCert_IfFileOnDiskChanges(t *testing.T) {
defer cleanup()
cert, certData, keyData := createCerts()
writeCerts(t, certData, keyData)
cachedCertService := NewCachedCertificateService(certFilePath, keyFilePath)

assert.Equal(t, cert, cachedCertService.GetCertificate())

newCert, certData, keyData := createCerts()

// Update files on disk
writeCerts(t, certData, keyData)
// Certificate won't change until refresh is called
assert.NotEqual(t, newCert, cachedCertService.GetCertificate())

err := cachedCertService.refresh()
assert.NoError(t, err)
assert.Equal(t, newCert, cachedCertService.GetCertificate())
}

func TestCachedCertificateService_HandlesPartialUpdates(t *testing.T) {
defer cleanup()
originalCert, certData, keyData := createCerts()
writeCerts(t, certData, keyData)
cachedCertService := NewCachedCertificateService(certFilePath, keyFilePath)

assert.Equal(t, originalCert, cachedCertService.GetCertificate())

newCert, certData, keyData := createCerts()

// Update only 1 file on disk - which leaves the representation on disk in an invalid state
writeCerts(t, certData, nil)
err := cachedCertService.refresh()
assert.Error(t, err)

// Certificate provided should not change, as there is no valid new cert yet
assert.Equal(t, originalCert, cachedCertService.GetCertificate())

// Update the other file, so now files on disk are now both updated and consistent
writeCerts(t, nil, keyData)
err = cachedCertService.refresh()
assert.NoError(t, err)

assert.Equal(t, newCert, cachedCertService.GetCertificate())
}

func writeCerts(t *testing.T, certData *bytes.Buffer, keyData *bytes.Buffer) {
if certData != nil {
err := os.WriteFile(certFilePath, certData.Bytes(), 0644)
require.NoError(t, err)
}

if keyData != nil {
err := os.WriteFile(keyFilePath, keyData.Bytes(), 0644)
require.NoError(t, err)
}
}

func cleanup() {
os.Remove(certFilePath)
os.Remove(keyFilePath)
}

func createCerts() (*tls.Certificate, *bytes.Buffer, *bytes.Buffer) {
// set up our CA certificate
ca := &x509.Certificate{
SerialNumber: big.NewInt(2019),
NotBefore: time.Now(),
NotAfter: time.Now().AddDate(10, 0, 0),
IsCA: true,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
BasicConstraintsValid: true,
}

// create our private and public key
caPrivKey, err := rsa.GenerateKey(rand.Reader, 1024)
if err != nil {
panic(err)
}

// create the CA
caBytes, err := x509.CreateCertificate(rand.Reader, ca, ca, &caPrivKey.PublicKey, caPrivKey)
if err != nil {
panic(err)
}

// pem encode
caPEM := new(bytes.Buffer)
pem.Encode(caPEM, &pem.Block{
Type: "CERTIFICATE",
Bytes: caBytes,
})

caPrivKeyPEM := new(bytes.Buffer)
pem.Encode(caPrivKeyPEM, &pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(caPrivKey),
})

// set up our server certificate
cert := &x509.Certificate{
SerialNumber: big.NewInt(2019),
IPAddresses: []net.IP{net.IPv4(127, 0, 0, 1), net.IPv6loopback},
NotBefore: time.Now(),
NotAfter: time.Now().AddDate(10, 0, 0),
SubjectKeyId: []byte{1, 2, 3, 4, 6},
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
KeyUsage: x509.KeyUsageDigitalSignature,
}

certPrivKey, err := rsa.GenerateKey(rand.Reader, 1024)
if err != nil {
panic(err)
}

certBytes, err := x509.CreateCertificate(rand.Reader, cert, ca, &certPrivKey.PublicKey, caPrivKey)
if err != nil {
panic(err)
}

certPEM := new(bytes.Buffer)
pem.Encode(certPEM, &pem.Block{
Type: "CERTIFICATE",
Bytes: certBytes,
})

certPrivKeyPEM := new(bytes.Buffer)
pem.Encode(certPrivKeyPEM, &pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(certPrivKey),
})

certificate, err := tls.X509KeyPair(certPEM.Bytes(), certPrivKeyPEM.Bytes())
if err != nil {
panic(err)
}

return &certificate, certPEM, certPrivKeyPEM
}
4 changes: 2 additions & 2 deletions internal/common/grpc/grpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"sync"
"time"

"github.com/armadaproject/armada/internal/common/certs"
grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
grpc_auth "github.com/grpc-ecosystem/go-grpc-middleware/auth"
grpc_logrus "github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus"
Expand All @@ -25,7 +26,6 @@ import (

"github.com/armadaproject/armada/internal/common/armadaerrors"
"github.com/armadaproject/armada/internal/common/auth/authorization"
"github.com/armadaproject/armada/internal/common/fileutils"
"github.com/armadaproject/armada/internal/common/requestid"
)

Expand All @@ -35,7 +35,7 @@ func CreateGrpcServer(
keepaliveParams keepalive.ServerParameters,
keepaliveEnforcementPolicy keepalive.EnforcementPolicy,
authServices []authorization.AuthService,
tlsCertService *fileutils.CachedCertificateService,
tlsCertService *certs.CachedCertificateService,
) *grpc.Server {
// Logging, authentication, etc. are implemented via gRPC interceptors
// (i.e., via functions that are called before handling the actual request).
Expand Down
6 changes: 3 additions & 3 deletions internal/jobservice/application.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@ import (
"os"
"time"

"github.com/armadaproject/armada/internal/common/certs"
log "github.com/sirupsen/logrus"
"golang.org/x/sync/errgroup"
"google.golang.org/grpc"

"github.com/armadaproject/armada/internal/common/auth/authorization"
"github.com/armadaproject/armada/internal/common/fileutils"
grpcCommon "github.com/armadaproject/armada/internal/common/grpc"
grpcconfig "github.com/armadaproject/armada/internal/common/grpc/configuration"
"github.com/armadaproject/armada/internal/common/grpc/grpcpool"
Expand Down Expand Up @@ -90,9 +90,9 @@ func (a *App) StartUp(ctx context.Context, config *configuration.JobServiceConfi
}

log := log.WithField("JobService", "Startup")
var cachedCertificateService *fileutils.CachedCertificateService
var cachedCertificateService *certs.CachedCertificateService
if config.Grpc.Tls.Enabled {
cachedCertificateService = fileutils.NewCachedCertificateService(config.Grpc.Tls.CertPath, config.Grpc.Tls.KeyPath)
cachedCertificateService = certs.NewCachedCertificateService(config.Grpc.Tls.CertPath, config.Grpc.Tls.KeyPath)
g.Go(func() error {
return cachedCertificateService.Run(ctx)
})
Expand Down
6 changes: 3 additions & 3 deletions internal/lookout/application.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@ import (
"context"
"sync"

"github.com/armadaproject/armada/internal/common/certs"
"github.com/doug-martin/goqu/v9"
grpc_prometheus "github.com/grpc-ecosystem/go-grpc-prometheus"
log "github.com/sirupsen/logrus"

"github.com/armadaproject/armada/internal/common/auth/authorization"
"github.com/armadaproject/armada/internal/common/fileutils"
"github.com/armadaproject/armada/internal/common/grpc"
"github.com/armadaproject/armada/internal/common/health"
"github.com/armadaproject/armada/internal/common/util"
Expand All @@ -31,9 +31,9 @@ func StartUp(config configuration.LookoutConfiguration, healthChecks *health.Mul
wg := &sync.WaitGroup{}
wg.Add(1)

var cachedCertificateService *fileutils.CachedCertificateService
var cachedCertificateService *certs.CachedCertificateService
if config.Grpc.Tls.Enabled {
cachedCertificateService = fileutils.NewCachedCertificateService(config.Grpc.Tls.CertPath, config.Grpc.Tls.KeyPath)
cachedCertificateService = certs.NewCachedCertificateService(config.Grpc.Tls.CertPath, config.Grpc.Tls.KeyPath)
go func() {
err := func() error {
return cachedCertificateService.Run(context.Background())
Expand Down
6 changes: 3 additions & 3 deletions internal/scheduler/schedulerapp.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"time"

"github.com/apache/pulsar-client-go/pulsar"
"github.com/armadaproject/armada/internal/common/certs"
"github.com/go-redis/redis"
"github.com/google/uuid"
"github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus/ctxlogrus"
Expand All @@ -23,7 +24,6 @@ import (
"github.com/armadaproject/armada/internal/common/app"
"github.com/armadaproject/armada/internal/common/auth"
dbcommon "github.com/armadaproject/armada/internal/common/database"
"github.com/armadaproject/armada/internal/common/fileutils"
grpcCommon "github.com/armadaproject/armada/internal/common/grpc"
"github.com/armadaproject/armada/internal/common/health"
"github.com/armadaproject/armada/internal/common/pulsarutils"
Expand Down Expand Up @@ -126,9 +126,9 @@ func Run(config schedulerconfig.Configuration) error {
if err != nil {
return errors.WithMessage(err, "error creating auth services")
}
var cachedCertificateService *fileutils.CachedCertificateService
var cachedCertificateService *certs.CachedCertificateService
if config.Grpc.Tls.Enabled {
cachedCertificateService = fileutils.NewCachedCertificateService(config.Grpc.Tls.CertPath, config.Grpc.Tls.KeyPath)
cachedCertificateService = certs.NewCachedCertificateService(config.Grpc.Tls.CertPath, config.Grpc.Tls.KeyPath)
services = append(services, func() error {
return cachedCertificateService.Run(ctx)
})
Expand Down

0 comments on commit 02b1817

Please sign in to comment.