diff --git a/internal/common/certs/cached_certificate.go b/internal/common/certs/cached_certificate.go index 41de764d307..4e64dbc3bb5 100644 --- a/internal/common/certs/cached_certificate.go +++ b/internal/common/certs/cached_certificate.go @@ -24,13 +24,13 @@ type CachedCertificateService struct { refreshInterval time.Duration } -func NewCachedCertificateService(certPath string, keyPath string) *CachedCertificateService { +func NewCachedCertificateService(certPath string, keyPath string, refreshInternal time.Duration) *CachedCertificateService { cert := &CachedCertificateService{ certPath: certPath, keyPath: keyPath, certificateLock: sync.Mutex{}, fileInfoLock: sync.Mutex{}, - refreshInterval: time.Minute, + refreshInterval: refreshInternal, } // Initialise the certificate err := cert.refresh() @@ -60,7 +60,9 @@ func (c *CachedCertificateService) Run(ctx context.Context) error { return nil case <-ticker.C: err := c.refresh() - log.WithError(err).Errorf("failed refreshing tls cert for key %s cert %s", c.keyPath, c.certPath) + if err != nil { + log.WithError(err).Errorf("failed refreshing tls cert for key %s cert %s", c.keyPath, c.certPath) + } } } } diff --git a/internal/common/certs/cached_certificate_test.go b/internal/common/certs/cached_certificate_test.go index 97aa8393257..87b54fae9ee 100644 --- a/internal/common/certs/cached_certificate_test.go +++ b/internal/common/certs/cached_certificate_test.go @@ -2,6 +2,7 @@ package certs import ( "bytes" + "context" "crypto/rand" "crypto/rsa" "crypto/tls" @@ -25,7 +26,7 @@ func TestCachedCertificateService_LoadsCertificateOnStartup(t *testing.T) { cert, certData, keyData := createCerts() writeCerts(t, certData, keyData) - cachedCertService := NewCachedCertificateService(certFilePath, keyFilePath) + cachedCertService := NewCachedCertificateService(certFilePath, keyFilePath, time.Second) result := cachedCertService.GetCertificate() @@ -35,14 +36,14 @@ func TestCachedCertificateService_LoadsCertificateOnStartup(t *testing.T) { func TestCachedCertificateService_PanicIfInitialLoadFails(t *testing.T) { defer cleanup() - assert.Panics(t, func() { NewCachedCertificateService(certFilePath, keyFilePath) }) + assert.Panics(t, func() { NewCachedCertificateService(certFilePath, keyFilePath, time.Second) }) } func TestCachedCertificateService_ReloadsCert_IfFileOnDiskChanges(t *testing.T) { defer cleanup() cert, certData, keyData := createCerts() writeCerts(t, certData, keyData) - cachedCertService := NewCachedCertificateService(certFilePath, keyFilePath) + cachedCertService := NewCachedCertificateService(certFilePath, keyFilePath, time.Second) assert.Equal(t, cert, cachedCertService.GetCertificate()) @@ -62,10 +63,10 @@ func TestCachedCertificateService_HandlesPartialUpdates(t *testing.T) { defer cleanup() originalCert, certData, keyData := createCerts() writeCerts(t, certData, keyData) - cachedCertService := NewCachedCertificateService(certFilePath, keyFilePath) + cachedCertService := NewCachedCertificateService(certFilePath, keyFilePath, time.Second) assert.Equal(t, originalCert, cachedCertService.GetCertificate()) - + newCert, certData, keyData := createCerts() // Update only 1 file on disk - which leaves the representation on disk in an invalid state @@ -84,6 +85,24 @@ func TestCachedCertificateService_HandlesPartialUpdates(t *testing.T) { assert.Equal(t, newCert, cachedCertService.GetCertificate()) } +func TestCachedCertificateService_ReloadsCertPeriodically_WhenUsingRun(t *testing.T) { + defer cleanup() + cert, certData, keyData := createCerts() + writeCerts(t, certData, keyData) + cachedCertService := NewCachedCertificateService(certFilePath, keyFilePath, time.Second) + assert.Equal(t, cert, cachedCertService.GetCertificate()) + + go func() { + err := cachedCertService.Run(context.Background()) + require.NoError(t, err) + }() + + newCert, certData, keyData := createCerts() + writeCerts(t, certData, keyData) + time.Sleep(time.Second * 2) + assert.Equal(t, newCert, cachedCertService.GetCertificate()) +} + func writeCerts(t *testing.T, certData *bytes.Buffer, keyData *bytes.Buffer) { if certData != nil { err := os.WriteFile(certFilePath, certData.Bytes(), 0644)