Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[v16] [kube] add server_id to targets when monitoring exec/portforward connections #48076

Merged
merged 1 commit into from
Oct 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 17 additions & 5 deletions lib/kube/proxy/forwarder.go
Original file line number Diff line number Diff line change
Expand Up @@ -2255,7 +2255,7 @@ func (s *clusterSession) close() {
}
}

func (s *clusterSession) monitorConn(conn net.Conn, err error) (net.Conn, error) {
func (s *clusterSession) monitorConn(conn net.Conn, err error, hostID string) (net.Conn, error) {
if err != nil {
return nil, trace.Wrap(err)
}
Expand All @@ -2270,10 +2270,18 @@ func (s *clusterSession) monitorConn(conn net.Conn, err error) (net.Conn, error)
s.connMonitorCancel(err)
return nil, trace.Wrap(err)
}

lockTargets := s.LockTargets()
// when the target is not a kubernetes_service instance, we don't need to lock it.
// the target could be a remote cluster or a local Kubernetes API server. In both cases,
// hostID is empty.
if hostID != "" {
lockTargets = append(lockTargets, types.LockTarget{
ServerID: hostID,
})
}
err = srv.StartMonitor(srv.MonitorConfig{
LockWatcher: s.parent.cfg.LockWatcher,
LockTargets: s.LockTargets(),
LockTargets: lockTargets,
DisconnectExpiredCert: s.disconnectExpiredCert,
ClientIdleTimeout: s.clientIdleTimeout,
Clock: s.parent.cfg.Clock,
Expand Down Expand Up @@ -2305,12 +2313,16 @@ func (s *clusterSession) getServerMetadata() apievents.ServerMetadata {
}

func (s *clusterSession) Dial(network, addr string) (net.Conn, error) {
return s.monitorConn(s.dial(s.requestContext, network, addr))
var hostID string
conn, err := s.dial(s.requestContext, network, addr, withHostIDCollection(&hostID))
return s.monitorConn(conn, err, hostID)
}

func (s *clusterSession) DialWithContext(opts ...contextDialerOption) func(ctx context.Context, network, addr string) (net.Conn, error) {
return func(ctx context.Context, network, addr string) (net.Conn, error) {
return s.monitorConn(s.dial(ctx, network, addr, opts...))
var hostID string
conn, err := s.dial(ctx, network, addr, append(opts, withHostIDCollection(&hostID))...)
return s.monitorConn(conn, err, hostID)
}
}

Expand Down
1 change: 0 additions & 1 deletion lib/kube/proxy/roundtrip.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,6 @@ func (s *SpdyRoundTripper) Dial(req *http.Request) (net.Conn, error) {
if err != nil {
return nil, err
}

if err := req.Write(conn); err != nil {
conn.Close()
return nil, err
Expand Down
23 changes: 20 additions & 3 deletions lib/kube/proxy/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,7 @@ func (f *Forwarder) localClusterDialer(kubeClusterName string, opts ...contextDi
ProxyIDs: s.GetProxyIDs(),
})
if err == nil {
opt.collect(s.GetHostID())
return conn, nil
}
errs = append(errs, trace.Wrap(err))
Expand Down Expand Up @@ -415,13 +416,21 @@ func (f *Forwarder) getContextDialerFunc(s *clusterSession, opts ...contextDiale
// contextDialerOptions is a set of options that can be used to filter
// the hosts that the dialer connects to.
type contextDialerOptions struct {
hostID string
hostIDFilter string
collectHostID *string
}

// matches returns true if the host matches the hostID of the dialer options or
// if the dialer hostID is empty.
func (c *contextDialerOptions) matches(hostID string) bool {
return c.hostID == "" || c.hostID == hostID
return c.hostIDFilter == "" || c.hostIDFilter == hostID
}

// collect sets the hostID that the dialer connected to if collectHostID is not nil.
func (c *contextDialerOptions) collect(hostID string) {
if c.collectHostID != nil {
*c.collectHostID = hostID
}
}

// contextDialerOption is a functional option for the contextDialerOptions.
Expand All @@ -434,6 +443,14 @@ type contextDialerOption func(*contextDialerOptions)
// error.
func withTargetHostID(hostID string) contextDialerOption {
return func(o *contextDialerOptions) {
o.hostID = hostID
o.hostIDFilter = hostID
}
}

// withHostIDCollection is a functional option that sets the hostID of the dialer
// to the provided pointer.
func withHostIDCollection(hostID *string) contextDialerOption {
return func(o *contextDialerOptions) {
o.collectHostID = hostID
}
}
Loading