diff --git a/README.md b/README.md index c4374877..37298732 100644 --- a/README.md +++ b/README.md @@ -162,6 +162,7 @@ tls-client-key-file | REDIS_EXPORTER_TLS_CLIENT_KEY_FILE | Name tls-client-cert-file | REDIS_EXPORTER_TLS_CLIENT_CERT_FILE | Name the client cert file (including full path) if the server requires TLS client authentication tls-server-key-file | REDIS_EXPORTER_TLS_SERVER_KEY_FILE | Name of the server key file (including full path) if the web interface and telemetry should use TLS tls-server-cert-file | REDIS_EXPORTER_TLS_SERVER_CERT_FILE | Name of the server certificate file (including full path) if the web interface and telemetry should use TLS +tls-server-ca-cert-file | REDIS_EXPORTER_TLS_SERVER_CA_CERT_FILE | Name of the CA certificate file (including full path) if the web interface and telemetry should require TLS client authentication tls-ca-cert-file | REDIS_EXPORTER_TLS_CA_CERT_FILE | Name of the CA certificate file (including full path) if the server requires TLS client authentication set-client-name | REDIS_EXPORTER_SET_CLIENT_NAME | Whether to set client name to redis_exporter, defaults to true. check-key-groups | REDIS_EXPORTER_CHECK_KEY_GROUPS | Comma separated list of [LUA regexes](https://www.lua.org/pil/20.1.html) for classifying keys into groups. The regexes are applied in specified order to individual keys, and the group name is generated by concatenating all capture groups of the first regex that matches a key. A key will be tracked under the `unclassified` group if none of the specified regexes matches it. diff --git a/exporter/tls.go b/exporter/tls.go index 258ba636..d1857a87 100644 --- a/exporter/tls.go +++ b/exporter/tls.go @@ -23,19 +23,40 @@ func (e *Exporter) CreateClientTLSConfig() (*tls.Config, error) { } if e.options.CaCertFile != "" { - log.Debugf("Load CA cert: %s", e.options.CaCertFile) - caCert, err := ioutil.ReadFile(e.options.CaCertFile) + certificates, err := LoadCAFile(e.options.CaCertFile) if err != nil { return nil, err } - certificates := x509.NewCertPool() - certificates.AppendCertsFromPEM(caCert) tlsConfig.RootCAs = certificates } return &tlsConfig, nil } +// CreateServerTLSConfig verifies configured files and return a prepared tls.Config +func (e *Exporter) CreateServerTLSConfig(certFile, keyFile, caCertFile string) (*tls.Config, error) { + // Verify that the initial key pair is accepted + _, err := LoadKeyPair(certFile, keyFile) + if err != nil { + return nil, err + } + + tlsConfig := tls.Config{ + GetCertificate: GetServerCertificateFunc(certFile, keyFile), + } + + if caCertFile != "" { + // Verify that the initial CA file is accepted when configured + _, err := LoadCAFile(caCertFile) + if err != nil { + return nil, err + } + tlsConfig.GetConfigForClient = GetConfigForClientFunc(certFile, keyFile, caCertFile) + } + + return &tlsConfig, nil +} + // GetServerCertificateFunc returns a function for tls.Config.GetCertificate func GetServerCertificateFunc(certFile, keyFile string) func(*tls.ClientHelloInfo) (*tls.Certificate, error) { return func(*tls.ClientHelloInfo) (*tls.Certificate, error) { @@ -43,6 +64,23 @@ func GetServerCertificateFunc(certFile, keyFile string) func(*tls.ClientHelloInf } } +// GetConfigForClientFunc returns a function for tls.Config.GetConfigForClient +func GetConfigForClientFunc(certFile, keyFile, caCertFile string) func(*tls.ClientHelloInfo) (*tls.Config, error) { + return func(*tls.ClientHelloInfo) (*tls.Config, error) { + certificates, err := LoadCAFile(caCertFile) + if err != nil { + return nil, err + } + + tlsConfig := tls.Config{ + ClientAuth: tls.RequireAndVerifyClientCert, + ClientCAs: certificates, + GetCertificate: GetServerCertificateFunc(certFile, keyFile), + } + return &tlsConfig, nil + } +} + // LoadKeyPair reads and parses a public/private key pair from a pair of files. // The files must contain PEM encoded data. func LoadKeyPair(certFile, keyFile string) (*tls.Certificate, error) { @@ -53,3 +91,16 @@ func LoadKeyPair(certFile, keyFile string) (*tls.Certificate, error) { } return &cert, nil } + +// LoadCAFile reads and parses CA certificates from a file into a pool. +// The file must contain PEM encoded data. +func LoadCAFile(caFile string) (*x509.CertPool, error) { + log.Debugf("Load CA cert file: %s", caFile) + pemCerts, err := ioutil.ReadFile(caFile) + if err != nil { + return nil, err + } + pool := x509.NewCertPool() + pool.AppendCertsFromPEM(pemCerts) + return pool, nil +} diff --git a/exporter/tls_test.go b/exporter/tls_test.go index 6b989618..cdeabf65 100644 --- a/exporter/tls_test.go +++ b/exporter/tls_test.go @@ -39,6 +39,34 @@ func TestCreateClientTLSConfig(t *testing.T) { } } +func TestCreateServerTLSConfig(t *testing.T) { + e := getTestExporter() + + // positive tests + _, err := e.CreateServerTLSConfig("../contrib/tls/redis.crt", "../contrib/tls/redis.key", "") + if err != nil { + t.Errorf("CreateServerTLSConfig() err: %s", err) + } + _, err = e.CreateServerTLSConfig("../contrib/tls/redis.crt", "../contrib/tls/redis.key", "../contrib/tls/ca.crt") + if err != nil { + t.Errorf("CreateServerTLSConfig() err: %s", err) + } + + // negative tests + _, err = e.CreateServerTLSConfig("/nonexisting/file", "/nonexisting/file", "") + if err == nil { + t.Errorf("Expected CreateServerTLSConfig() to fail") + } + _, err = e.CreateServerTLSConfig("/nonexisting/file", "/nonexisting/file", "/nonexisting/file") + if err == nil { + t.Errorf("Expected CreateServerTLSConfig() to fail") + } + _, err = e.CreateServerTLSConfig("../contrib/tls/redis.crt", "../contrib/tls/redis.key", "/nonexisting/file") + if err == nil { + t.Errorf("Expected CreateServerTLSConfig() to fail") + } +} + func TestGetServerCertificateFunc(t *testing.T) { // positive test _, err := GetServerCertificateFunc("../contrib/tls/ca.crt", "../contrib/tls/ca.key")(nil) @@ -52,3 +80,17 @@ func TestGetServerCertificateFunc(t *testing.T) { t.Errorf("Expected GetServerCertificateFunc() to fail") } } + +func TestGetConfigForClientFunc(t *testing.T) { + // positive test + _, err := GetConfigForClientFunc("../contrib/tls/redis.crt", "../contrib/tls/redis.key", "../contrib/tls/ca.crt")(nil) + if err != nil { + t.Errorf("GetConfigForClientFunc() err: %s", err) + } + + // negative test + _, err = GetConfigForClientFunc("/nonexisting/file", "/nonexisting/file", "/nonexisting/file")(nil) + if err == nil { + t.Errorf("Expected GetConfigForClientFunc() to fail") + } +} diff --git a/main.go b/main.go index e5019f3e..22f6bcb2 100644 --- a/main.go +++ b/main.go @@ -1,7 +1,6 @@ package main import ( - "crypto/tls" "flag" "io/ioutil" "net/http" @@ -77,6 +76,7 @@ func main() { tlsCaCertFile = flag.String("tls-ca-cert-file", getEnv("REDIS_EXPORTER_TLS_CA_CERT_FILE", ""), "Name of the CA certificate file (including full path) if the server requires TLS client authentication") tlsServerKeyFile = flag.String("tls-server-key-file", getEnv("REDIS_EXPORTER_TLS_SERVER_KEY_FILE", ""), "Name of the server key file (including full path) if the web interface and telemetry should use TLS") tlsServerCertFile = flag.String("tls-server-cert-file", getEnv("REDIS_EXPORTER_TLS_SERVER_CERT_FILE", ""), "Name of the server certificate file (including full path) if the web interface and telemetry should use TLS") + tlsServerCaCertFile = flag.String("tls-server-ca-cert-file", getEnv("REDIS_EXPORTER_TLS_SERVER_CA_CERT_FILE", ""), "Name of the CA certificate file (including full path) if the web interface and telemetry should require TLS client authentication") maxDistinctKeyGroups = flag.Int64("max-distinct-key-groups", getEnvInt64("REDIS_EXPORTER_MAX_DISTINCT_KEY_GROUPS", 100), "The maximum number of distinct key groups with the most memory utilization to present as distinct metrics per database, the leftover key groups will be aggregated in the 'overflow' bucket") isDebug = flag.Bool("debug", getEnvBool("REDIS_EXPORTER_DEBUG", false), "Output verbose debug information") setClientName = flag.Bool("set-client-name", getEnvBool("REDIS_EXPORTER_SET_CLIENT_NAME", true), "Whether to set client name to redis_exporter") @@ -197,14 +197,14 @@ func main() { if *tlsServerCertFile != "" && *tlsServerKeyFile != "" { log.Debugf("Bind as TLS using cert %s and key %s", *tlsServerCertFile, *tlsServerKeyFile) - // Verify that the initial key pair is accepted - _, err := exporter.LoadKeyPair(*tlsServerCertFile, *tlsServerKeyFile) + tlsConfig, err := exp.CreateServerTLSConfig(*tlsServerCertFile, *tlsServerKeyFile, *tlsServerCaCertFile) if err != nil { - log.Fatalf("Couldn't load TLS server key pair, err: %s", err) + log.Fatal(err) } + server := &http.Server{ Addr: *listenAddress, - TLSConfig: &tls.Config{GetCertificate: exporter.GetServerCertificateFunc(*tlsServerCertFile, *tlsServerKeyFile)}, + TLSConfig: tlsConfig, Handler: exp} log.Fatal(server.ListenAndServeTLS("", "")) } else {