Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

COCOS-160 - Enable mTLS when using aTLS #172

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 109 additions & 43 deletions internal/server/grpc/grpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"math/big"
"net"
"os"
"strings"
"time"

"github.com/google/go-sev-guest/client"
Expand Down Expand Up @@ -92,7 +93,7 @@ func (s *Server) Start() error {
creds := grpc.Creds(insecure.NewCredentials())

switch {
case s.Config.AttestedTLS:
case s.Config.AttestedTLS && (s.Config.ClientCAFile != "" || s.Config.ServerCAFile != ""):
certificateBytes, privateKeyBytes, err := generateCertificatesForATLS(s.quoteProvider)
if err != nil {
return fmt.Errorf("failed to create certificate: %w", err)
Expand All @@ -103,25 +104,11 @@ func (s *Server) Start() error {
return fmt.Errorf("falied due to invalid key pair: %w", err)
}

tlsConfig := &tls.Config{
ClientAuth: tls.NoClientCert,
Certificates: []tls.Certificate{certificate},
}

creds = grpc.Creds(credentials.NewTLS(tlsConfig))
s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s with Attested TLS", s.Name, s.Address))
case s.Config.CertFile != "" || s.Config.KeyFile != "":
certificate, err := loadX509KeyPair(s.Config.CertFile, s.Config.KeyFile)
if err != nil {
return fmt.Errorf("failed to load auth certificates: %w", err)
}
tlsConfig := &tls.Config{
ClientAuth: tls.RequireAndVerifyClientCert,
Certificates: []tls.Certificate{certificate},
}

var mtlsCA string
// Loading Server CA file
rootCA, err := loadCertFile(s.Config.ServerCAFile)
if err != nil {
return fmt.Errorf("failed to load root ca file: %w", err)
Expand All @@ -133,7 +120,6 @@ func (s *Server) Start() error {
if !tlsConfig.RootCAs.AppendCertsFromPEM(rootCA) {
return fmt.Errorf("failed to append root ca to tls.Config")
}
mtlsCA = fmt.Sprintf("root ca %s", s.Config.ServerCAFile)
}

// Loading Client CA File
Expand All @@ -148,15 +134,41 @@ func (s *Server) Start() error {
if !tlsConfig.ClientCAs.AppendCertsFromPEM(clientCA) {
return fmt.Errorf("failed to append client ca to tls.Config")
}
mtlsCA = fmt.Sprintf("%s client ca %s", mtlsCA, s.Config.ClientCAFile)
}

if err != nil {
return err
}

Comment on lines +139 to +142
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what is the purpose of this error check

tlsConfig.Certificates = append(tlsConfig.Certificates, certificate)
creds = grpc.Creds(credentials.NewTLS(tlsConfig))
switch {
case mtlsCA != "":
s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s with TLS/mTLS cert %s , key %s and %s", s.Name, s.Address, s.Config.CertFile, s.Config.KeyFile, mtlsCA))
default:
s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s with TLS cert %s and key %s", s.Name, s.Address, s.Config.CertFile, s.Config.KeyFile))
s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s with Attested TLS", s.Name, s.Address))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

distinguish this with line 163


case s.Config.AttestedTLS:
certificateBytes, privateKeyBytes, err := generateCertificatesForATLS(s.quoteProvider)
if err != nil {
return fmt.Errorf("failed to create certificate: %w", err)
}

certificate, err := tls.X509KeyPair(certificateBytes, privateKeyBytes)
if err != nil {
return fmt.Errorf("falied due to invalid key pair: %w", err)
}

tlsConfig := &tls.Config{
ClientAuth: tls.NoClientCert,
Certificates: []tls.Certificate{certificate},
}

creds = grpc.Creds(credentials.NewTLS(tlsConfig))
s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s with Attested TLS", s.Name, s.Address))
case s.Config.CertFile != "" && s.Config.KeyFile != "":
tlsConfig, err := s.setupTLSConfig()
if err != nil {
return err
}
creds = grpc.Creds(credentials.NewTLS(tlsConfig))

default:
s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s without TLS", s.Name, s.Address))
}
Expand Down Expand Up @@ -196,35 +208,39 @@ func (s *Server) Stop() error {
}

func loadCertFile(certFile string) ([]byte, error) {
if certFile != "" {
return os.ReadFile(certFile)
if len(certFile) < 1000 && !strings.Contains(certFile, "\n") {
data, err := os.ReadFile(certFile)
if err == nil {
return data, nil
}
}
return []byte{}, nil
return []byte(certFile), nil
}

func loadX509KeyPair(certfile, keyfile string) (tls.Certificate, error) {
func loadX509KeyPair(certFile, keyFile string) (tls.Certificate, error) {
var cert, key []byte
var err error
if _, err = os.Stat(certfile); err == nil {
cert, err = os.ReadFile(certfile)
if err != nil {
return tls.Certificate{}, err

readFileOrData := func(input string) ([]byte, error) {
if len(input) < 1000 && !strings.Contains(input, "\n") {
data, err := os.ReadFile(input)
if err == nil {
return data, nil
}
}
} else if os.IsNotExist(err) {
cert = []byte(certfile)
} else {
return tls.Certificate{}, err
return []byte(input), nil
}
if _, err := os.Stat(keyfile); err == nil {
key, err = os.ReadFile(keyfile)
if err != nil {
return tls.Certificate{}, err
}
} else if os.IsNotExist(err) {
key = []byte(keyfile)
} else {
return tls.Certificate{}, err

cert, err = readFileOrData(certFile)
if err != nil {
return tls.Certificate{}, fmt.Errorf("failed to read cert: %v", err)
}

key, err = readFileOrData(keyFile)
if err != nil {
return tls.Certificate{}, fmt.Errorf("failed to read key: %v", err)
}

return tls.X509KeyPair(cert, key)
}

Expand Down Expand Up @@ -292,3 +308,53 @@ func generateCertificatesForATLS(qp client.QuoteProvider) ([]byte, []byte, error

return certBytes, keyBytes, nil
}

func (s *Server) setupTLSConfig() (*tls.Config, error) {
certificate, err := loadX509KeyPair(s.Config.CertFile, s.Config.KeyFile)
if err != nil {
return &tls.Config{}, fmt.Errorf("failed to load auth certificates: %w", err)
}
tlsConfig := &tls.Config{
ClientAuth: tls.RequireAndVerifyClientCert,
Certificates: []tls.Certificate{certificate},
}

var mtlsCA string
// Loading Server CA file
rootCA, err := loadCertFile(s.Config.ServerCAFile)
if err != nil {
return &tls.Config{}, fmt.Errorf("failed to load root ca file: %w", err)
}
if len(rootCA) > 0 {
if tlsConfig.RootCAs == nil {
tlsConfig.RootCAs = x509.NewCertPool()
}
if !tlsConfig.RootCAs.AppendCertsFromPEM(rootCA) {
return &tls.Config{}, fmt.Errorf("failed to append root ca to tls.Config")
}
mtlsCA = fmt.Sprintf("root ca %s", s.Config.ServerCAFile)
}

// Loading Client CA File
clientCA, err := loadCertFile(s.Config.ClientCAFile)
if err != nil {
return &tls.Config{}, fmt.Errorf("failed to load client ca file: %w", err)
}
if len(clientCA) > 0 {
if tlsConfig.ClientCAs == nil {
tlsConfig.ClientCAs = x509.NewCertPool()
}
if !tlsConfig.ClientCAs.AppendCertsFromPEM(clientCA) {
return &tls.Config{}, fmt.Errorf("failed to append client ca to tls.Config")
}
mtlsCA = fmt.Sprintf("%s client ca %s", mtlsCA, s.Config.ClientCAFile)
}
switch {
case mtlsCA != "":
s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s with TLS/mTLS", s.Name, s.Address))
default:
s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s with TLS", s.Name, s.Address))
}

return tlsConfig, nil
}
85 changes: 59 additions & 26 deletions pkg/clients/grpc/connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ const (
withoutTLS security = iota
withTLS
withmTLS
withaTLS
withmaTLS
)

const (
Expand Down Expand Up @@ -122,6 +124,10 @@ func (c *client) Secure() string {
return "with TLS"
case withmTLS:
return "with mTLS"
case withmaTLS:
return "with maTLS"
case withaTLS:
return "with mTLS"
case withoutTLS:
fallthrough
default:
Expand All @@ -141,7 +147,35 @@ func connect(cfg Config) (*grpc.ClientConn, security, error) {
secure := withoutTLS
tc := insecure.NewCredentials()

if cfg.AttestedTLS {
switch {
case cfg.AttestedTLS && cfg.ServerCAFile != "":
err := ReadManifest(cfg.Manifest, &attestationConfiguration)
if err != nil {
return nil, secure, fmt.Errorf("failed to read Manifest %w", err)
}

tlsConfig := &tls.Config{
InsecureSkipVerify: false,
VerifyPeerCertificate: verifyAttestationReportTLS,
}

// Loading root ca certificates file
rootCA, err := os.ReadFile(cfg.ServerCAFile)
if err != nil {
return nil, secure, fmt.Errorf("failed to load root ca file: %w", err)
}
if len(rootCA) > 0 {
capool := x509.NewCertPool()
if !capool.AppendCertsFromPEM(rootCA) {
return nil, secure, fmt.Errorf("failed to append root ca to tls.Config")
}
tlsConfig.RootCAs = capool
secure = withmaTLS
}

tc = credentials.NewTLS(tlsConfig)

case cfg.AttestedTLS:
err := ReadManifest(cfg.Manifest, &attestationConfiguration)
if err != nil {
return nil, secure, fmt.Errorf("failed to read Manifest %w", err)
Expand All @@ -152,36 +186,35 @@ func connect(cfg Config) (*grpc.ClientConn, security, error) {
VerifyPeerCertificate: verifyAttestationReportTLS,
}
tc = credentials.NewTLS(tlsConfig)
} else {
if cfg.ServerCAFile != "" {
tlsConfig := &tls.Config{}

// Loading root ca certificates file
rootCA, err := os.ReadFile(cfg.ServerCAFile)
if err != nil {
return nil, secure, fmt.Errorf("failed to load root ca file: %w", err)
}
if len(rootCA) > 0 {
capool := x509.NewCertPool()
if !capool.AppendCertsFromPEM(rootCA) {
return nil, secure, fmt.Errorf("failed to append root ca to tls.Config")
}
tlsConfig.RootCAs = capool
secure = withTLS
}
case cfg.ServerCAFile != "":
tlsConfig := &tls.Config{}

// Loading mTLS certificates file
if cfg.ClientCert != "" || cfg.ClientKey != "" {
certificate, err := tls.LoadX509KeyPair(cfg.ClientCert, cfg.ClientKey)
if err != nil {
return nil, secure, fmt.Errorf("failed to client certificate and key %w", err)
}
tlsConfig.Certificates = []tls.Certificate{certificate}
secure = withmTLS
// Loading root ca certificates file
rootCA, err := os.ReadFile(cfg.ServerCAFile)
if err != nil {
return nil, secure, fmt.Errorf("failed to load root ca file: %w", err)
}
if len(rootCA) > 0 {
capool := x509.NewCertPool()
if !capool.AppendCertsFromPEM(rootCA) {
return nil, secure, fmt.Errorf("failed to append root ca to tls.Config")
}
tlsConfig.RootCAs = capool
secure = withTLS
}

tc = credentials.NewTLS(tlsConfig)
// Loading mTLS certificates file
if cfg.ClientCert != "" && cfg.ClientKey != "" {
certificate, err := tls.LoadX509KeyPair(cfg.ClientCert, cfg.ClientKey)
if err != nil {
return nil, secure, fmt.Errorf("failed to load client certificate and key %w", err)
}
tlsConfig.Certificates = []tls.Certificate{certificate}
secure = withmTLS
}
tc = credentials.NewTLS(tlsConfig)
default:
}

opts = append(opts, grpc.WithTransportCredentials(tc))
Expand Down