Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improving SSH port forwarding audit logs #50932

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions lib/events/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -264,10 +264,13 @@ const (
X11ForwardErr = "error"

// Port forwarding event
PortForwardEvent = "port"
PortForwardAddr = "addr"
PortForwardSuccess = "success"
PortForwardErr = "error"
PortForwardEvent = "port"
PortForwardLocalEvent = "port.local"
PortForwardRemoteEvent = "port.remote"
PortForwardRemoteConnEvent = "port.remote_conn"
PortForwardAddr = "addr"
PortForwardSuccess = "success"
PortForwardErr = "error"

// AuthAttemptEvent is authentication attempt that either
// succeeded or failed based on event status
Expand Down
8 changes: 4 additions & 4 deletions lib/srv/ctx.go
Original file line number Diff line number Diff line change
Expand Up @@ -1267,12 +1267,12 @@ func (c *ServerContext) GetSessionMetadata() apievents.SessionMetadata {
}
}

func (c *ServerContext) GetPortForwardEvent() apievents.PortForward {
sconn := c.ConnectionContext.ServerConn
func (c *ServerContext) GetPortForwardEvent(evType, code string) apievents.PortForward {
sconn := c.ServerConn
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggestion: Prefer being explicit over relying on the fact that ConnectionContext is embedded in ServerContext here

Suggested change
sconn := c.ServerConn
sconn := c. ConnectionContext.ServerConn

return apievents.PortForward{
Metadata: apievents.Metadata{
Type: events.PortForwardEvent,
Code: events.PortForwardCode,
Type: evType,
Code: code,
},
UserMetadata: c.Identity.GetUserMetadata(),
ConnectionMetadata: apievents.ConnectionMetadata{
Expand Down
72 changes: 66 additions & 6 deletions lib/srv/forward/sshserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -648,6 +648,47 @@ func (s *Server) Serve() {

succeeded = true

if err := s.EmitAuditEvent(ctx, &apievents.PortForward{
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this the right place to emit this event? Isn't this the entrypoint for all connections established to the forward server? Won't this cause the event to be emitted any time a user connects to an agentless host or a host in proxy recording mode?

Metadata: apievents.Metadata{
Type: events.PortForwardEvent,
Code: events.PortForwardCode,
},
UserMetadata: s.identityContext.GetUserMetadata(),
ConnectionMetadata: apievents.ConnectionMetadata{
LocalAddr: sconn.LocalAddr().String(),
RemoteAddr: sconn.RemoteAddr().String(),
},
Addr: s.targetAddr,
Status: apievents.Status{
Success: true,
},
}); err != nil {
s.logger.WarnContext(ctx, "Failed to emit port forward event", "error", err)
}

s.connectionContext.AddCloser(utils.CloseFunc(func() error {
if err := s.EmitAuditEvent(ctx, &apievents.PortForward{
Metadata: apievents.Metadata{
Type: events.PortForwardEvent,
Code: events.PortForwardStopCode,
},
UserMetadata: s.identityContext.GetUserMetadata(),
ConnectionMetadata: apievents.ConnectionMetadata{
LocalAddr: sconn.LocalAddr().String(),
RemoteAddr: sconn.RemoteAddr().String(),
},
Addr: s.targetAddr,
Status: apievents.Status{
Success: true,
},
}); err != nil {
s.logger.WarnContext(ctx, "Failed to emit port forward event", "error", err)
return err
}

return nil
}))

// Add channel handlers immediately to avoid rejecting a channel.
forwardedTCPIP := s.remoteClient.HandleChannelOpen(teleport.ChanForwardedTCPIP)

Expand All @@ -662,7 +703,9 @@ func (s *Server) Serve() {
Interval: netConfig.GetKeepAliveInterval(),
MaxCount: netConfig.GetKeepAliveCountMax(),
CloseContext: ctx,
CloseCancel: func() { s.connectionContext.Close() },
CloseCancel: func() {
s.connectionContext.Close()
},
Comment on lines +706 to +708
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggestion: revert changes to preserve the original blame here

})

go s.handleClientChannels(ctx, forwardedTCPIP)
Expand Down Expand Up @@ -922,6 +965,11 @@ func (s *Server) handleForwardedTCPIPRequest(ctx context.Context, nch ssh.NewCha
scx.DstAddr = sshutils.JoinHostPort(req.Host, req.Port)
defer scx.Close()

event := scx.GetPortForwardEvent(events.PortForwardRemoteConnEvent, events.PortForwardCode)
if err := s.EmitAuditEvent(ctx, &event); err != nil {
s.logger.ErrorContext(ctx, "Failed to emit audit event", "error", err)
}

// Open a forwarding channel on the client.
outCh, outRequests, err := scx.ServerConn.OpenChannel(nch.ChannelType(), nch.ExtraData())
if err != nil {
Expand All @@ -941,10 +989,12 @@ func (s *Server) handleForwardedTCPIPRequest(ctx context.Context, nch ssh.NewCha
go io.Copy(io.Discard, ch.Stderr())
ch = scx.TrackActivity(ch)

event := scx.GetPortForwardEvent()
if err := s.EmitAuditEvent(ctx, &event); err != nil {
s.logger.ErrorContext(ctx, "Failed to emit audit event", "error", err)
}
defer func() {
stopEvent := scx.GetPortForwardEvent(events.PortForwardRemoteConnEvent, events.PortForwardStopCode)
if err := s.EmitAuditEvent(ctx, &stopEvent); err != nil {
s.logger.ErrorContext(ctx, "Failed to emit audit event", "error", err)
}
}()

return trace.Wrap(utils.ProxyConn(ctx, ch, outCh))
}
Expand Down Expand Up @@ -1120,13 +1170,23 @@ func (s *Server) handleDirectTCPIPRequest(ctx context.Context, ch ssh.Channel, r
}
defer conn.Close()

event := scx.GetPortForwardEvent()
event := scx.GetPortForwardEvent(events.PortForwardLocalEvent, events.PortForwardCode)
if err := s.EmitAuditEvent(s.closeContext, &event); err != nil {
s.logger.WarnContext(ctx, "Failed to emit port forward event", "error", err)
}

if err := utils.ProxyConn(ctx, ch, conn); err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, os.ErrClosed) {
s.logger.WarnContext(ctx, "Failed proxying data for port forwarding connection", "error", err)

event = scx.GetPortForwardEvent(events.PortForwardLocalEvent, events.PortForwardFailureCode)
if err := s.EmitAuditEvent(s.closeContext, &event); err != nil {
s.logger.WarnContext(ctx, "Failed to emit port forward event", "error", err)
}
}

event = scx.GetPortForwardEvent(events.PortForwardLocalEvent, events.PortForwardStopCode)
if err := s.EmitAuditEvent(s.closeContext, &event); err != nil {
s.logger.WarnContext(ctx, "Failed to emit port forward event", "error", err)
}
}

Expand Down
60 changes: 50 additions & 10 deletions lib/srv/regular/sshserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -1489,13 +1489,9 @@ func (s *Server) handleDirectTCPIPRequest(ctx context.Context, ccx *sshutils.Con
return
}

if err := utils.ProxyConn(ctx, conn, channel); err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, os.ErrClosed) {
scx.Logger.WarnContext(ctx, "Connection problem in direct-tcpip channel", "error", err)
}

if err := s.EmitAuditEvent(s.ctx, &apievents.PortForward{
Metadata: apievents.Metadata{
Type: events.PortForwardEvent,
Type: events.PortForwardLocalEvent,
Code: events.PortForwardCode,
},
UserMetadata: scx.Identity.GetUserMetadata(),
Expand All @@ -1510,6 +1506,19 @@ func (s *Server) handleDirectTCPIPRequest(ctx context.Context, ccx *sshutils.Con
}); err != nil {
scx.Logger.WarnContext(ctx, "Failed to emit port forward event", "error", err)
}

if err := utils.ProxyConn(ctx, conn, channel); err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, os.ErrClosed) {
event := scx.GetPortForwardEvent(events.PortForwardLocalEvent, events.PortForwardFailureCode)
if err := s.EmitAuditEvent(s.ctx, &event); err != nil {
scx.Logger.WarnContext(ctx, "Failed to emit port forward event", "error", err)
}
scx.Logger.WarnContext(ctx, "Connection problem in direct-tcpip channel", "error", err)
}

event := scx.GetPortForwardEvent(events.PortForwardLocalEvent, events.PortForwardStopCode)
if err := s.EmitAuditEvent(s.ctx, &event); err != nil {
scx.Logger.WarnContext(ctx, "Failed to emit port forward event", "error", err)
}
}

// handleSessionRequests handles out of band session requests once the session
Expand Down Expand Up @@ -2201,14 +2210,40 @@ func (s *Server) handleTCPIPForwardRequest(ctx context.Context, ccx *sshutils.Co
}
scx.SrcAddr = sshutils.JoinHostPort(srcHost, listenPort)

event := scx.GetPortForwardEvent()
if err := s.EmitAuditEvent(ctx, &event); err != nil {
s.logger.WarnContext(ctx, "Failed to emit audit event", "error", err)
// pregenerate audit events since ServerContext may be closed before they're used
startEvent := scx.GetPortForwardEvent(events.PortForwardRemoteConnEvent, events.PortForwardCode)
stopEvent := scx.GetPortForwardEvent(events.PortForwardRemoteConnEvent, events.PortForwardStopCode)
errEvent := scx.GetPortForwardEvent(events.PortForwardRemoteConnEvent, events.PortForwardFailureCode)
proxyWithAudit := func(ctx context.Context, remoteAddr string, client io.ReadWriteCloser, server io.ReadWriteCloser) {
startEvent.RemoteAddr = remoteAddr
if err := s.EmitAuditEvent(ctx, &startEvent); err != nil {
s.logger.WarnContext(ctx, "Failed to emit audit event", "error", err)
}

if err := utils.ProxyConn(ctx, client, server); err != nil {
s.logger.ErrorContext(ctx, "PROXY CONN FAILURE", "error", err)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggestion: remove the CAPS?

errEvent.RemoteAddr = remoteAddr
if err := s.EmitAuditEvent(ctx, &errEvent); err != nil {
s.logger.WarnContext(ctx, "Failed to emit audit event", "error", err)
}
}

s.logger.ErrorContext(ctx, "PROXY CONN DONE")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggestion: remove the CAPS?

stopEvent.RemoteAddr = remoteAddr
if err := s.EmitAuditEvent(ctx, &stopEvent); err != nil {
s.logger.WarnContext(ctx, "Failed to emit audit event", "error", err)
}
}
if err := sshutils.StartRemoteListener(ctx, scx.ConnectionContext.ServerConn, scx.SrcAddr, listener); err != nil {

if err := sshutils.StartRemoteListener(ctx, scx.ServerConn, scx.SrcAddr, listener, proxyWithAudit); err != nil {
return trace.Wrap(err)
}

event := scx.GetPortForwardEvent(events.PortForwardRemoteEvent, events.PortForwardCode)
if err := s.EmitAuditEvent(ctx, &event); err != nil {
s.logger.WarnContext(ctx, "Failed to emit audit event", "error", err)
}

// Report addr back to the client.
if r.WantReply {
var payload []byte
Expand All @@ -2232,6 +2267,11 @@ func (s *Server) handleTCPIPForwardRequest(ctx context.Context, ccx *sshutils.Co
// Close the listener once the connection is closed, if it hasn't
// been closed already via a cancel-tcpip-forward request.
ccx.AddCloser(utils.CloseFunc(func() error {
event := scx.GetPortForwardEvent(events.PortForwardRemoteEvent, events.PortForwardStopCode)
if err := s.EmitAuditEvent(context.Background(), &event); err != nil {
s.logger.WarnContext(context.Background(), "Failed to emit audit event", "error", err)
}

listener, ok := s.remoteForwardingMap.LoadAndDelete(scx.SrcAddr)
if ok {
return trace.Wrap(listener.Close())
Expand All @@ -2250,14 +2290,14 @@ func (s *Server) handleCancelTCPIPForwardRequest(ctx context.Context, ccx *sshut
return trace.Wrap(err)
}
defer scx.Close()

listener, ok := s.remoteForwardingMap.LoadAndDelete(scx.SrcAddr)
if !ok {
return trace.NotFound("no remote forwarding listener at %v", scx.SrcAddr)
}
if err := r.Reply(true, nil); err != nil {
s.logger.WarnContext(ctx, "Failed to reply to request", "request_type", r.Type, "error", err)
}

return trace.Wrap(listener.Close())
}

Expand Down
11 changes: 9 additions & 2 deletions lib/sshutils/tcpip.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ type channelOpener interface {

// StartRemoteListener listens on the given listener and forwards any accepted
// connections over a new "forwarded-tcpip" channel.
func StartRemoteListener(ctx context.Context, sshConn channelOpener, srcAddr string, listener net.Listener) error {
func StartRemoteListener(ctx context.Context, sshConn channelOpener, srcAddr string, listener net.Listener, proxyFn func(ctx context.Context, remoteAddr string, client io.ReadWriteCloser, server io.ReadWriteCloser)) error {
srcHost, srcPort, err := SplitHostPort(srcAddr)
if err != nil {
return trace.Wrap(err)
Expand Down Expand Up @@ -127,7 +127,14 @@ func StartRemoteListener(ctx context.Context, sshConn channelOpener, srcAddr str
}
go ssh.DiscardRequests(rch)
go io.Copy(io.Discard, ch.Stderr())
go utils.ProxyConn(ctx, conn, ch)
go func() {
if proxyFn != nil {
proxyFn(ctx, conn.RemoteAddr().String(), conn, ch)
return
}

utils.ProxyConn(ctx, conn, ch)
}()
}
}()

Expand Down
42 changes: 41 additions & 1 deletion lib/sshutils/tcpip_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ import (
"time"

"github.com/stretchr/testify/require"

"github.com/gravitational/teleport/lib/utils"
)

func TestStartRemoteListener(t *testing.T) {
Expand All @@ -51,7 +53,44 @@ func TestStartRemoteListener(t *testing.T) {
require.NoError(t, err)
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
t.Cleanup(cancel)
require.NoError(t, StartRemoteListener(ctx, sshConn, "127.0.0.1:12345", listener))
require.NoError(t, StartRemoteListener(ctx, sshConn, "127.0.0.1:12345", listener, nil))

// Check that dialing listener makes it all the way to the test http server.
resp, err := http.Get("http://" + listener.Addr().String())
require.NoError(t, err)
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
require.NoError(t, err)
require.Equal(t, "Hello, world", string(body))
}

func TestStartRemoteListenerWithCustomProxy(t *testing.T) {
// Create a test server to act as the other side of the channel.
tsrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprint(w, "Hello, world")
}))
t.Cleanup(tsrv.Close)
testSrvConn, err := net.Dial("tcp", tsrv.Listener.Addr().String())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this a conn that needs to be closed on cleanup as well?

require.NoError(t, err)

sshConn := &mockSSHConn{
mockChan: &mockChannel{
ReadWriter: testSrvConn,
},
}

proxied := false
proxyFn := func(ctx context.Context, remoteAddr string, client io.ReadWriteCloser, server io.ReadWriteCloser) {
proxied = true
_ = utils.ProxyConn(ctx, client, server)
}

// Start the remote listener.
listener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
t.Cleanup(cancel)
require.NoError(t, StartRemoteListener(ctx, sshConn, "127.0.0.1:12345", listener, proxyFn))

// Check that dialing listener makes it all the way to the test http server.
resp, err := http.Get("http://" + listener.Addr().String())
Expand All @@ -60,4 +99,5 @@ func TestStartRemoteListener(t *testing.T) {
body, err := io.ReadAll(resp.Body)
require.NoError(t, err)
require.Equal(t, "Hello, world", string(body))
require.True(t, proxied)
}
Loading