From e6834cdc98ed0ec4b388fff61283b2d43a16319d Mon Sep 17 00:00:00 2001 From: Paul Lorenz Date: Thu, 14 Nov 2024 15:23:20 -0500 Subject: [PATCH] Check peer certs --- controller/raft/mesh/mesh.go | 29 ++++++++++++++++++++++++++--- 1 file changed, 26 insertions(+), 3 deletions(-) diff --git a/controller/raft/mesh/mesh.go b/controller/raft/mesh/mesh.go index 2ef286e0b..bde9945e4 100644 --- a/controller/raft/mesh/mesh.go +++ b/controller/raft/mesh/mesh.go @@ -460,7 +460,7 @@ func (self *impl) GetOrConnectPeer(address string, timeout time.Duration) (*Peer peer.Channel = binding.GetChannel() - if err = self.checkClusterIds(peer.Channel); err != nil { + if err = self.validateConnection(peer.Channel); err != nil { return err } @@ -510,6 +510,14 @@ func (self *impl) GetOrConnectPeer(address string, timeout time.Duration) (*Peer return peer, nil } +func (self *impl) validateConnection(ch channel.Channel) error { + if err := self.checkClusterIds(ch); err != nil { + return err + } + + return self.checkCerts(ch) +} + func (self *impl) checkClusterIds(ch channel.Channel) error { clusterId := string(ch.Underlay().Headers()[ClusterIdHeader]) if clusterId != "" && self.env.GetClusterId() != "" && clusterId != self.env.GetClusterId() { @@ -518,6 +526,21 @@ func (self *impl) checkClusterIds(ch channel.Channel) error { return nil } +func (self *impl) checkCerts(ch channel.Channel) error { + certs := ch.Underlay().Certificates() + if len(certs) == 0 { + return errors.New("unable to validate peer connection, no certs presented") + } + + for _, cert := range ch.Underlay().Certificates() { + if _, err := self.env.GetNodeId().CaPool().VerifyToRoot(cert); err == nil { + return nil + } + } + + return errors.New("unable to validate peer connection, no certs presented matched the CA for this node") +} + func (self *impl) GetPeerInfo(address string, timeout time.Duration) (raft.ServerID, raft.ServerAddress, error) { log := pfxlog.Logger().WithField("address", address) addr, err := transport.ParseAddress(address) @@ -560,7 +583,7 @@ func (self *impl) GetPeerInfo(address string, timeout time.Duration) (raft.Serve return err } - if err = self.checkClusterIds(binding.GetChannel()); err != nil { + if err = self.validateConnection(binding.GetChannel()); err != nil { return err } @@ -794,7 +817,7 @@ func (self *impl) AcceptUnderlay(underlay channel.Underlay) error { } } - if err = self.checkClusterIds(peer.Channel); err != nil { + if err = self.validateConnection(peer.Channel); err != nil { return err }