Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added option to use custom TLS certificates instead of ACME #17

Merged
merged 11 commits into from
Oct 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,15 @@ applications. To enable this, add the `--tls` flag when deploying an instance:
kamal-proxy deploy service1 --target web-1:3000 --host app1.example.com --tls


### Custom TLS certificate

If you obtained your TLS certificate manually, manage your own certificate authority,
or need to install a Cloudflare origin certificate, you can specify the paths to
your certificate file and the corresponding private key:

kamal-proxy deploy service1 --target web-1:3000 --host app1.example.com --tls --tls-certificate-path cert.pem --tls-private-key-path key.pem


## Specifying `run` options with environment variables

In some environments, like when running a Docker container, it can be convenient
Expand Down
3 changes: 3 additions & 0 deletions internal/cmd/deploy.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ func newDeployCommand() *deployCommand {

deployCommand.cmd.Flags().BoolVar(&deployCommand.args.ServiceOptions.TLSEnabled, "tls", false, "Configure TLS for this target (requires a non-empty host)")
deployCommand.cmd.Flags().BoolVar(&deployCommand.tlsStaging, "tls-staging", false, "Use Let's Encrypt staging environment for certificate provisioning")
deployCommand.cmd.Flags().StringVar(&deployCommand.args.ServiceOptions.TLSCertificatePath, "tls-certificate-path", "", "Configure custom TLS certificate path (PEM format)")
deployCommand.cmd.Flags().StringVar(&deployCommand.args.ServiceOptions.TLSPrivateKeyPath, "tls-private-key-path", "", "Configure custom TLS private key path (PEM format)")

deployCommand.cmd.Flags().DurationVar(&deployCommand.args.DeployTimeout, "deploy-timeout", server.DefaultDeployTimeout, "Maximum time to wait for the new target to become healthy")
deployCommand.cmd.Flags().DurationVar(&deployCommand.args.DrainTimeout, "drain-timeout", server.DefaultDrainTimeout, "Maximum time to allow existing connections to drain before removing old target")
Expand All @@ -53,6 +55,7 @@ func newDeployCommand() *deployCommand {
deployCommand.cmd.Flags().BoolVar(&deployCommand.args.TargetOptions.ForwardHeaders, "forward-headers", false, "Forward X-Forwarded headers to target (default false if TLS enabled; otherwise true)")

deployCommand.cmd.MarkFlagRequired("target")
deployCommand.cmd.MarkFlagsRequiredTogether("tls-certificate-path", "tls-private-key-path")

return deployCommand
}
Expand Down
61 changes: 61 additions & 0 deletions internal/server/cert.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
package server

import (
"crypto/tls"
"log/slog"
"net/http"
"sync"
)

// CertManager is the certificate source a Service uses when TLS is enabled.
// Service.createCertManager returns either a *StaticCertManager or an
// *autocert.Manager behind this interface.
type CertManager interface {
// GetCertificate returns the certificate to present for the given
// ClientHello, or an error if none can be provided.
// NOTE(review): the signature matches tls.Config.GetCertificate, which is
// presumably where it gets wired in — confirm at the call site.
GetCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, error)
// HTTPHandler gives the manager a chance to wrap the plain-HTTP handler.
// StaticCertManager returns handler unchanged.
// NOTE(review): for autocert.Manager this is presumably where ACME HTTP
// challenges are answered — confirm against the autocert documentation.
HTTPHandler(handler http.Handler) http.Handler
}

// StaticCertManager is a certificate manager that serves a single TLS
// certificate read from a PEM certificate file and its private key file on
// disk. The parsed certificate is kept in cert once it has been loaded.
type StaticCertManager struct {
	tlsCertificateFilePath string
	tlsPrivateKeyFilePath  string

	lock sync.RWMutex     // guards cert
	cert *tls.Certificate // nil until the key pair has been loaded
}

// NewStaticCertManager returns a StaticCertManager for the given PEM
// certificate and private key paths. Nothing is read from disk here; the
// files are loaded by GetCertificate.
func NewStaticCertManager(tlsCertificateFilePath, tlsPrivateKeyFilePath string) *StaticCertManager {
	m := new(StaticCertManager)
	m.tlsCertificateFilePath = tlsCertificateFilePath
	m.tlsPrivateKeyFilePath = tlsPrivateKeyFilePath
	return m
}

func (m *StaticCertManager) GetCertificate(*tls.ClientHelloInfo) (*tls.Certificate, error) {
m.lock.RLock()
if m.cert != nil {
defer m.lock.RUnlock()
return m.cert, nil
}
m.lock.RUnlock()

m.lock.Lock()
defer m.lock.Unlock()
if m.cert != nil { // Double-check locking
return m.cert, nil
}

slog.Info(
"Loading custom TLS certificate",
"tls-certificate-path", m.tlsCertificateFilePath,
"tls-private-key-path", m.tlsPrivateKeyFilePath,
)

cert, err := tls.LoadX509KeyPair(m.tlsCertificateFilePath, m.tlsPrivateKeyFilePath)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will be called for every TLS handshake, which means the certificate files will be loaded multiple times. Better to load the certificate when the StaticCertManager is established, and reuse it for the lifetime of the deployment.

That would also let us catch a loading error early on, and fail the deployment at that point, which would be safer.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That would require more changes, as we right now do not expect createCertManager() to ever return an error.

Another option would be to cache the certificate on first use:

diff --git a/internal/server/cert.go b/internal/server/cert.go
index bb1580d..afaf740 100644
--- a/internal/server/cert.go
+++ b/internal/server/cert.go
@@ -9,9 +9,11 @@ type CertManager interface {
 	GetCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, error)
 }
 
+// StaticCertManager is a certificate manager that loads certificates from disk.
 type StaticCertManager struct {
 	tlsCertificateFilePath string
 	tlsPrivateKeyFilePath  string
+	cert                   *tls.Certificate
 }
 
 func NewStaticCertManager(tlsCertificateFilePath, tlsPrivateKeyFilePath string) *StaticCertManager {
@@ -22,6 +24,10 @@ func NewStaticCertManager(tlsCertificateFilePath, tlsPrivateKeyFilePath string)
 }
 
 func (m *StaticCertManager) GetCertificate(*tls.ClientHelloInfo) (*tls.Certificate, error) {
+	if m.cert != nil {
+		return m.cert, nil
+	}
+
 	slog.Info(
 		"Loading custom TLS certificate",
 		"tls-certificate-path", m.tlsCertificateFilePath,
@@ -32,6 +38,7 @@ func (m *StaticCertManager) GetCertificate(*tls.ClientHelloInfo) (*tls.Certifica
 	if err != nil {
 		return nil, err
 	}
+	m.cert = &cert
 
-	return &cert, nil
+	return m.cert, nil
 }

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Applied the patch and added a test. Let me know what you think

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thinking about this, might need a mutex since there might be multiple procs loading certs at the same time.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Confirmed, will work on a fix:

	manager := NewStaticCertManager(certPath, keyPath)
	go func() {
		manager.GetCertificate(&tls.ClientHelloInfo{})
	}()
	cert, err := manager.GetCertificate(&tls.ClientHelloInfo{})

Output of go test ./... -race:

WARNING: DATA RACE
Write at 0x00c0001e8e00 by goroutine 23:
  github.com/basecamp/kamal-proxy/internal/server.(*StaticCertManager).GetCertificate()
      /Users/dmytro/work/github/kamal-proxy/internal/server/cert.go:41 +0x254
  github.com/basecamp/kamal-proxy/internal/server.TestCertificateLoadingRaceCondition.func1()
      /Users/dmytro/work/github/kamal-proxy/internal/server/cert_test.go:49 +0x7c

Previous write at 0x00c0001e8e00 by goroutine 22:
  github.com/basecamp/kamal-proxy/internal/server.(*StaticCertManager).GetCertificate()
      /Users/dmytro/work/github/kamal-proxy/internal/server/cert.go:41 +0x254
  github.com/basecamp/kamal-proxy/internal/server.TestCertificateLoadingRaceCondition()
      /Users/dmytro/work/github/kamal-proxy/internal/server/cert_test.go:51 +0x270
  testing.tRunner()
      /Users/dmytro/go/pkg/mod/golang.org/toolchain@v0.0.1-go1.23.1.darwin-arm64/src/testing/testing.go:1690 +0x184
  testing.(*T).Run.gowrap1()
      /Users/dmytro/go/pkg/mod/golang.org/toolchain@v0.0.1-go1.23.1.darwin-arm64/src/testing/testing.go:1743 +0x40

Copy link
Contributor Author

@kpumuk kpumuk Sep 24, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Introduced an RWMutex in f1df35f to address it

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Another option to avoid the mutex would be to attempt the file load in the NewStaticCertManager, and then save both the error and the response in the struct. GetCertificate would just return both, avoiding a delayed load, the need for mutex, etc.

This design might limit the options for (potential) future changes, like reloading the certificate on file change.

if err != nil {
return nil, err
}
m.cert = &cert

return m.cert, nil
}

// HTTPHandler returns handler unchanged. StaticCertManager adds no HTTP-level
// behavior of its own; the method exists so the type provides the HTTPHandler
// method declared by CertManager.
func (m *StaticCertManager) HTTPHandler(handler http.Handler) http.Handler {
return handler
}
105 changes: 105 additions & 0 deletions internal/server/cert_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
package server

import (
"crypto/tls"
"os"
"path"
"testing"

"github.com/stretchr/testify/require"
)

const certPem = `-----BEGIN CERTIFICATE-----
MIIBhTCCASugAwIBAgIQIRi6zePL6mKjOipn+dNuaTAKBggqhkjOPQQDAjASMRAw
DgYDVQQKEwdBY21lIENvMB4XDTE3MTAyMDE5NDMwNloXDTE4MTAyMDE5NDMwNlow
EjEQMA4GA1UEChMHQWNtZSBDbzBZMBMGByqGSM49AgEGCCqGSM49AwEHA0IABD0d
7VNhbWvZLWPuj/RtHFjvtJBEwOkhbN/BnnE8rnZR8+sbwnc/KhCk3FhnpHZnQz7B
5aETbbIgmuvewdjvSBSjYzBhMA4GA1UdDwEB/wQEAwICpDATBgNVHSUEDDAKBggr
BgEFBQcDATAPBgNVHRMBAf8EBTADAQH/MCkGA1UdEQQiMCCCDmxvY2FsaG9zdDo1
NDUzgg4xMjcuMC4wLjE6NTQ1MzAKBggqhkjOPQQDAgNIADBFAiEA2zpJEPQyz6/l
Wf86aX6PepsntZv2GYlA5UpabfT2EZICICpJ5h/iI+i341gBmLiAFQOyTDT+/wQc
6MF9+Yw1Yy0t
-----END CERTIFICATE-----`

const keyPem = `-----BEGIN EC PRIVATE KEY-----
MHcCAQEEIIrYSSNQFaA2Hwf1duRSxKtLYX5CB04fSeQ6tF1aY/PuoAoGCCqGSM49
AwEHoUQDQgAEPR3tU2Fta9ktY+6P9G0cWO+0kETA6SFs38GecTyudlHz6xvCdz8q
EKTcWGekdmdDPsHloRNtsiCa697B2O9IFA==
-----END EC PRIVATE KEY-----`

// TestCertificateLoading verifies that a valid certificate/key pair on disk is
// loaded and returned without error.
func TestCertificateLoading(t *testing.T) {
	certFile, keyFile, err := prepareTestCertificateFiles(t)
	require.NoError(t, err)

	hello := &tls.ClientHelloInfo{}
	cert, err := NewStaticCertManager(certFile, keyFile).GetCertificate(hello)

	require.NoError(t, err)
	require.NotNil(t, cert)
}

// TestCertificateLoadingRaceCondition calls GetCertificate from two goroutines
// at once so that `go test -race` can detect unsynchronized access to the
// cached certificate.
func TestCertificateLoadingRaceCondition(t *testing.T) {
	certPath, keyPath, err := prepareTestCertificateFiles(t)
	require.NoError(t, err)

	manager := NewStaticCertManager(certPath, keyPath)

	// The require helpers call t.FailNow, which must only run on the test
	// goroutine. The concurrent call therefore reports its error through a
	// channel instead of asserting directly. The buffer guarantees the
	// goroutine can always finish, even if the test bails out early.
	concurrentErr := make(chan error, 1)
	go func() {
		_, err := manager.GetCertificate(&tls.ClientHelloInfo{})
		concurrentErr <- err
	}()

	cert, err := manager.GetCertificate(&tls.ClientHelloInfo{})

	// Receiving here also joins the goroutine, so the test cannot return
	// before the concurrent call has actually run.
	require.NoError(t, <-concurrentErr)
	require.NoError(t, err)
	require.NotNil(t, cert)
}

// TestCachesLoadedCertificate verifies that once the certificate has been
// loaded, later calls are served from memory: the files are deleted after the
// first call, so a second read from disk would fail.
func TestCachesLoadedCertificate(t *testing.T) {
	certPath, keyPath, err := prepareTestCertificateFiles(t)
	require.NoError(t, err)

	manager := NewStaticCertManager(certPath, keyPath)
	cert1, err := manager.GetCertificate(&tls.ClientHelloInfo{})
	require.NoError(t, err)
	require.NotNil(t, cert1)

	require.NoError(t, os.Remove(certPath))
	require.NoError(t, os.Remove(keyPath))

	cert2, err := manager.GetCertificate(&tls.ClientHelloInfo{})
	// The second call must succeed without touching the (now missing) files.
	require.NoError(t, err)
	require.Equal(t, cert1, cert2)
}

// TestErrorWhenFileDoesNotExist verifies that GetCertificate reports an error
// and returns no certificate when the configured files are missing.
func TestErrorWhenFileDoesNotExist(t *testing.T) {
	manager := NewStaticCertManager("testdata/cert.pem", "testdata/key.pem")
	cert, err := manager.GetCertificate(&tls.ClientHelloInfo{})

	// Match on the sentinel error rather than the message text, which differs
	// between operating systems.
	require.ErrorIs(t, err, os.ErrNotExist)
	require.Nil(t, cert)
}

// TestErrorWhenKeyFormatIsInvalid verifies that loading fails with a PEM
// parsing error when the files do not contain what their role expects.
func TestErrorWhenKeyFormatIsInvalid(t *testing.T) {
	certFile, keyFile, err := prepareTestCertificateFiles(t)
	require.NoError(t, err)

	// The arguments are deliberately swapped: the key file is passed as the
	// certificate and the certificate file as the key.
	swapped := NewStaticCertManager(keyFile, certFile)
	cert, err := swapped.GetCertificate(&tls.ClientHelloInfo{})

	require.ErrorContains(t, err, "failed to find certificate PEM data in certificate input")
	require.Nil(t, cert)
}

// prepareTestCertificateFiles writes the certPem and keyPem fixtures into a
// fresh per-test temporary directory and returns the certificate path, the
// key path, and the first write error encountered, if any.
func prepareTestCertificateFiles(t *testing.T) (string, string, error) {
	t.Helper()

	tmp := t.TempDir()
	certFile := path.Join(tmp, "example-cert.pem")
	keyFile := path.Join(tmp, "example-key.pem")

	// Written in a fixed order: certificate first, then key.
	fixtures := []struct{ file, pem string }{
		{certFile, certPem},
		{keyFile, keyPem},
	}
	for _, f := range fixtures {
		if err := os.WriteFile(f.file, []byte(f.pem), 0644); err != nil {
			return "", "", err
		}
	}

	return certFile, keyFile, nil
}
18 changes: 12 additions & 6 deletions internal/server/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,12 @@ type HealthCheckConfig struct {
}

type ServiceOptions struct {
TLSEnabled bool `json:"tls_enabled"`
ACMEDirectory string `json:"acme_directory"`
ACMECachePath string `json:"acme_cache_path"`
ErrorPagePath string `json:"error_page_path"`
TLSEnabled bool `json:"tls_enabled"`
TLSCertificatePath string `json:"tls_certificate_path"`
TLSPrivateKeyPath string `json:"tls_private_key_path"`
ACMEDirectory string `json:"acme_directory"`
ACMECachePath string `json:"acme_cache_path"`
ErrorPagePath string `json:"error_page_path"`
}

func (so ServiceOptions) ScopedCachePath() string {
Expand All @@ -90,7 +92,7 @@ type Service struct {

pauseController *PauseController
rolloutController *RolloutController
certManager *autocert.Manager
certManager CertManager
middleware http.Handler
}

Expand Down Expand Up @@ -284,11 +286,15 @@ func (s *Service) initialize() {
s.middleware = s.createMiddleware()
}

func (s *Service) createCertManager() *autocert.Manager {
func (s *Service) createCertManager() CertManager {
if !s.options.TLSEnabled {
return nil
}

if s.options.TLSCertificatePath != "" && s.options.TLSPrivateKeyPath != "" {
return NewStaticCertManager(s.options.TLSCertificatePath, s.options.TLSPrivateKeyPath)
}

return &autocert.Manager{
Prompt: autocert.AcceptTOS,
Cache: autocert.DirCache(s.options.ScopedCachePath()),
Expand Down
15 changes: 15 additions & 0 deletions internal/server/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,21 @@ func TestService_RedirectToHTTPWhenTLSRequired(t *testing.T) {
require.Equal(t, http.StatusOK, w.Result().StatusCode)
}

func TestService_UseStaticTLSCertificateWhenConfigured(t *testing.T) {
service := testCreateService(
t,
[]string{"example.com"},
ServiceOptions{
TLSEnabled: true,
TLSCertificatePath: "cert.pem",
TLSPrivateKeyPath: "key.pem",
},
defaultTargetOptions,
)

require.IsType(t, &StaticCertManager{}, service.certManager)
}

func TestService_RejectTLSRequestsWhenNotConfigured(t *testing.T) {
service := testCreateService(t, defaultEmptyHosts, defaultServiceOptions, defaultTargetOptions)

Expand Down