diff --git a/lib/events/api.go b/lib/events/api.go index b1dda9a2b6db7..055e060f57ce7 100644 --- a/lib/events/api.go +++ b/lib/events/api.go @@ -264,10 +264,13 @@ const ( X11ForwardErr = "error" // Port forwarding event - PortForwardEvent = "port" - PortForwardAddr = "addr" - PortForwardSuccess = "success" - PortForwardErr = "error" + PortForwardEvent = "port" + PortForwardLocalEvent = "port.local" + PortForwardRemoteEvent = "port.remote" + PortForwardRemoteConnEvent = "port.remote_conn" + PortForwardAddr = "addr" + PortForwardSuccess = "success" + PortForwardErr = "error" // AuthAttemptEvent is authentication attempt that either // succeeded or failed based on event status diff --git a/lib/srv/ctx.go b/lib/srv/ctx.go index 3318b3755f92c..d81c6db748957 100644 --- a/lib/srv/ctx.go +++ b/lib/srv/ctx.go @@ -1267,12 +1267,12 @@ func (c *ServerContext) GetSessionMetadata() apievents.SessionMetadata { } } -func (c *ServerContext) GetPortForwardEvent() apievents.PortForward { - sconn := c.ConnectionContext.ServerConn +func (c *ServerContext) GetPortForwardEvent(evType, code string) apievents.PortForward { + sconn := c.ServerConn return apievents.PortForward{ Metadata: apievents.Metadata{ - Type: events.PortForwardEvent, - Code: events.PortForwardCode, + Type: evType, + Code: code, }, UserMetadata: c.Identity.GetUserMetadata(), ConnectionMetadata: apievents.ConnectionMetadata{ diff --git a/lib/srv/forward/sshserver.go b/lib/srv/forward/sshserver.go index 849bc0a228c53..fae1ccdbb4196 100644 --- a/lib/srv/forward/sshserver.go +++ b/lib/srv/forward/sshserver.go @@ -648,6 +648,47 @@ func (s *Server) Serve() { succeeded = true + if err := s.EmitAuditEvent(ctx, &apievents.PortForward{ + Metadata: apievents.Metadata{ + Type: events.PortForwardEvent, + Code: events.PortForwardCode, + }, + UserMetadata: s.identityContext.GetUserMetadata(), + ConnectionMetadata: apievents.ConnectionMetadata{ + LocalAddr: sconn.LocalAddr().String(), + RemoteAddr: sconn.RemoteAddr().String(), + }, + Addr: s.targetAddr, + Status: apievents.Status{ + Success: true, + }, + }); err != nil { + s.logger.WarnContext(ctx, "Failed to emit port forward event", "error", err) + } + + s.connectionContext.AddCloser(utils.CloseFunc(func() error { + if err := s.EmitAuditEvent(ctx, &apievents.PortForward{ + Metadata: apievents.Metadata{ + Type: events.PortForwardEvent, + Code: events.PortForwardStopCode, + }, + UserMetadata: s.identityContext.GetUserMetadata(), + ConnectionMetadata: apievents.ConnectionMetadata{ + LocalAddr: sconn.LocalAddr().String(), + RemoteAddr: sconn.RemoteAddr().String(), + }, + Addr: s.targetAddr, + Status: apievents.Status{ + Success: true, + }, + }); err != nil { + s.logger.WarnContext(ctx, "Failed to emit port forward event", "error", err) + return err + } + + return nil + })) + // Add channel handlers immediately to avoid rejecting a channel. forwardedTCPIP := s.remoteClient.HandleChannelOpen(teleport.ChanForwardedTCPIP) @@ -662,7 +703,9 @@ func (s *Server) Serve() { Interval: netConfig.GetKeepAliveInterval(), MaxCount: netConfig.GetKeepAliveCountMax(), CloseContext: ctx, - CloseCancel: func() { s.connectionContext.Close() }, + CloseCancel: func() { + s.connectionContext.Close() + }, }) go s.handleClientChannels(ctx, forwardedTCPIP) @@ -922,6 +965,11 @@ func (s *Server) handleForwardedTCPIPRequest(ctx context.Context, nch ssh.NewCha scx.DstAddr = sshutils.JoinHostPort(req.Host, req.Port) defer scx.Close() + event := scx.GetPortForwardEvent(events.PortForwardRemoteConnEvent, events.PortForwardCode) + if err := s.EmitAuditEvent(ctx, &event); err != nil { + s.logger.ErrorContext(ctx, "Failed to emit audit event", "error", err) + } + // Open a forwarding channel on the client. outCh, outRequests, err := scx.ServerConn.OpenChannel(nch.ChannelType(), nch.ExtraData()) if err != nil { @@ -941,10 +989,12 @@ func (s *Server) handleForwardedTCPIPRequest(ctx context.Context, nch ssh.NewCha go io.Copy(io.Discard, ch.Stderr()) ch = scx.TrackActivity(ch) - event := scx.GetPortForwardEvent() - if err := s.EmitAuditEvent(ctx, &event); err != nil { - s.logger.ErrorContext(ctx, "Failed to emit audit event", "error", err) - } + defer func() { + stopEvent := scx.GetPortForwardEvent(events.PortForwardRemoteConnEvent, events.PortForwardStopCode) + if err := s.EmitAuditEvent(ctx, &stopEvent); err != nil { + s.logger.ErrorContext(ctx, "Failed to emit audit event", "error", err) + } + }() return trace.Wrap(utils.ProxyConn(ctx, ch, outCh)) } @@ -1120,13 +1170,23 @@ func (s *Server) handleDirectTCPIPRequest(ctx context.Context, ch ssh.Channel, r } defer conn.Close() - event := scx.GetPortForwardEvent() + event := scx.GetPortForwardEvent(events.PortForwardLocalEvent, events.PortForwardCode) if err := s.EmitAuditEvent(s.closeContext, &event); err != nil { s.logger.WarnContext(ctx, "Failed to emit port forward event", "error", err) } if err := utils.ProxyConn(ctx, ch, conn); err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, os.ErrClosed) { s.logger.WarnContext(ctx, "Failed proxying data for port forwarding connection", "error", err) + + event = scx.GetPortForwardEvent(events.PortForwardLocalEvent, events.PortForwardFailureCode) + if err := s.EmitAuditEvent(s.closeContext, &event); err != nil { + s.logger.WarnContext(ctx, "Failed to emit port forward event", "error", err) + } + } + + event = scx.GetPortForwardEvent(events.PortForwardLocalEvent, events.PortForwardStopCode) + if err := s.EmitAuditEvent(s.closeContext, &event); err != nil { + s.logger.WarnContext(ctx, "Failed to emit port forward event", "error", err) } } diff --git a/lib/srv/regular/sshserver.go b/lib/srv/regular/sshserver.go index d68b7367cbc8f..1ada2cfaed8ad 100644 --- a/lib/srv/regular/sshserver.go +++ b/lib/srv/regular/sshserver.go @@ -1489,13 +1489,9 @@ func (s *Server) handleDirectTCPIPRequest(ctx context.Context, ccx *sshutils.Con return } - if err := utils.ProxyConn(ctx, conn, channel); err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, os.ErrClosed) { - scx.Logger.WarnContext(ctx, "Connection problem in direct-tcpip channel", "error", err) - } - if err := s.EmitAuditEvent(s.ctx, &apievents.PortForward{ Metadata: apievents.Metadata{ - Type: events.PortForwardEvent, + Type: events.PortForwardLocalEvent, Code: events.PortForwardCode, }, UserMetadata: scx.Identity.GetUserMetadata(), @@ -1510,6 +1506,19 @@ func (s *Server) handleDirectTCPIPRequest(ctx context.Context, ccx *sshutils.Con }); err != nil { scx.Logger.WarnContext(ctx, "Failed to emit port forward event", "error", err) } + + if err := utils.ProxyConn(ctx, conn, channel); err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, os.ErrClosed) { + event := scx.GetPortForwardEvent(events.PortForwardLocalEvent, events.PortForwardFailureCode) + if err := s.EmitAuditEvent(s.ctx, &event); err != nil { + scx.Logger.WarnContext(ctx, "Failed to emit port forward event", "error", err) + } + scx.Logger.WarnContext(ctx, "Connection problem in direct-tcpip channel", "error", err) + } + + event := scx.GetPortForwardEvent(events.PortForwardLocalEvent, events.PortForwardStopCode) + if err := s.EmitAuditEvent(s.ctx, &event); err != nil { + scx.Logger.WarnContext(ctx, "Failed to emit port forward event", "error", err) + } } // handleSessionRequests handles out of band session requests once the session @@ -2201,14 +2210,40 @@ func (s *Server) handleTCPIPForwardRequest(ctx context.Context, ccx *sshutils.Co } scx.SrcAddr = sshutils.JoinHostPort(srcHost, listenPort) - event := scx.GetPortForwardEvent() - if err := s.EmitAuditEvent(ctx, &event); err != nil { - s.logger.WarnContext(ctx, "Failed to emit audit event", "error", err) + // pregenerate audit events since ServerContext may be closed before they're used + startEvent := scx.GetPortForwardEvent(events.PortForwardRemoteConnEvent, events.PortForwardCode) + stopEvent := scx.GetPortForwardEvent(events.PortForwardRemoteConnEvent, events.PortForwardStopCode) + errEvent := scx.GetPortForwardEvent(events.PortForwardRemoteConnEvent, events.PortForwardFailureCode) + proxyWithAudit := func(ctx context.Context, remoteAddr string, client io.ReadWriteCloser, server io.ReadWriteCloser) { + startEvent.RemoteAddr = remoteAddr + if err := s.EmitAuditEvent(ctx, &startEvent); err != nil { + s.logger.WarnContext(ctx, "Failed to emit audit event", "error", err) + } + + if err := utils.ProxyConn(ctx, client, server); err != nil { + s.logger.ErrorContext(ctx, "PROXY CONN FAILURE", "error", err) + errEvent.RemoteAddr = remoteAddr + if err := s.EmitAuditEvent(ctx, &errEvent); err != nil { + s.logger.WarnContext(ctx, "Failed to emit audit event", "error", err) + } + } + + s.logger.ErrorContext(ctx, "PROXY CONN DONE") + stopEvent.RemoteAddr = remoteAddr + if err := s.EmitAuditEvent(ctx, &stopEvent); err != nil { + s.logger.WarnContext(ctx, "Failed to emit audit event", "error", err) + } } - if err := sshutils.StartRemoteListener(ctx, scx.ConnectionContext.ServerConn, scx.SrcAddr, listener); err != nil { + + if err := sshutils.StartRemoteListener(ctx, scx.ServerConn, scx.SrcAddr, listener, proxyWithAudit); err != nil { return trace.Wrap(err) } + event := scx.GetPortForwardEvent(events.PortForwardRemoteEvent, events.PortForwardCode) + if err := s.EmitAuditEvent(ctx, &event); err != nil { + s.logger.WarnContext(ctx, "Failed to emit audit event", "error", err) + } + // Report addr back to the client. if r.WantReply { var payload []byte @@ -2232,6 +2267,11 @@ func (s *Server) handleTCPIPForwardRequest(ctx context.Context, ccx *sshutils.Co // Close the listener once the connection is closed, if it hasn't // been closed already via a cancel-tcpip-forward request. ccx.AddCloser(utils.CloseFunc(func() error { + event := scx.GetPortForwardEvent(events.PortForwardRemoteEvent, events.PortForwardStopCode) + if err := s.EmitAuditEvent(context.Background(), &event); err != nil { + s.logger.WarnContext(context.Background(), "Failed to emit audit event", "error", err) + } + listener, ok := s.remoteForwardingMap.LoadAndDelete(scx.SrcAddr) if ok { return trace.Wrap(listener.Close()) @@ -2250,7 +2290,6 @@ func (s *Server) handleCancelTCPIPForwardRequest(ctx context.Context, ccx *sshut return trace.Wrap(err) } defer scx.Close() - listener, ok := s.remoteForwardingMap.LoadAndDelete(scx.SrcAddr) if !ok { return trace.NotFound("no remote forwarding listener at %v", scx.SrcAddr) @@ -2258,6 +2297,7 @@ func (s *Server) handleCancelTCPIPForwardRequest(ctx context.Context, ccx *sshut if err := r.Reply(true, nil); err != nil { s.logger.WarnContext(ctx, "Failed to reply to request", "request_type", r.Type, "error", err) } + return trace.Wrap(listener.Close()) } diff --git a/lib/sshutils/tcpip.go b/lib/sshutils/tcpip.go index a7308f010db7c..3358fff88d015 100644 --- a/lib/sshutils/tcpip.go +++ b/lib/sshutils/tcpip.go @@ -79,7 +79,7 @@ type channelOpener interface { // StartRemoteListener listens on the given listener and forwards any accepted // connections over a new "forwarded-tcpip" channel. -func StartRemoteListener(ctx context.Context, sshConn channelOpener, srcAddr string, listener net.Listener) error { +func StartRemoteListener(ctx context.Context, sshConn channelOpener, srcAddr string, listener net.Listener, proxyFn func(ctx context.Context, remoteAddr string, client io.ReadWriteCloser, server io.ReadWriteCloser)) error { srcHost, srcPort, err := SplitHostPort(srcAddr) if err != nil { return trace.Wrap(err) @@ -127,7 +127,14 @@ func StartRemoteListener(ctx context.Context, sshConn channelOpener, srcAddr str } go ssh.DiscardRequests(rch) go io.Copy(io.Discard, ch.Stderr()) - go utils.ProxyConn(ctx, conn, ch) + go func() { + if proxyFn != nil { + proxyFn(ctx, conn.RemoteAddr().String(), conn, ch) + return + } + + utils.ProxyConn(ctx, conn, ch) + }() } }() diff --git a/lib/sshutils/tcpip_test.go b/lib/sshutils/tcpip_test.go index 5a59b5a64e57f..a39dfc6ad06a7 100644 --- a/lib/sshutils/tcpip_test.go +++ b/lib/sshutils/tcpip_test.go @@ -29,6 +29,8 @@ import ( "time" "github.com/stretchr/testify/require" + + "github.com/gravitational/teleport/lib/utils" ) func TestStartRemoteListener(t *testing.T) { @@ -51,7 +53,44 @@ func TestStartRemoteListener(t *testing.T) { require.NoError(t, err) ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) t.Cleanup(cancel) - require.NoError(t, StartRemoteListener(ctx, sshConn, "127.0.0.1:12345", listener)) + require.NoError(t, StartRemoteListener(ctx, sshConn, "127.0.0.1:12345", listener, nil)) + + // Check that dialing listener makes it all the way to the test http server. + resp, err := http.Get("http://" + listener.Addr().String()) + require.NoError(t, err) + defer resp.Body.Close() + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + require.Equal(t, "Hello, world", string(body)) +} + +func TestStartRemoteListenerWithCustomProxy(t *testing.T) { + // Create a test server to act as the other side of the channel. + tsrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, "Hello, world") + })) + t.Cleanup(tsrv.Close) + testSrvConn, err := net.Dial("tcp", tsrv.Listener.Addr().String()) + require.NoError(t, err) + + sshConn := &mockSSHConn{ + mockChan: &mockChannel{ + ReadWriter: testSrvConn, + }, + } + + proxied := false + proxyFn := func(ctx context.Context, remoteAddr string, client io.ReadWriteCloser, server io.ReadWriteCloser) { + proxied = true + _ = utils.ProxyConn(ctx, client, server) + } + + // Start the remote listener. + listener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + t.Cleanup(cancel) + require.NoError(t, StartRemoteListener(ctx, sshConn, "127.0.0.1:12345", listener, proxyFn)) // Check that dialing listener makes it all the way to the test http server. resp, err := http.Get("http://" + listener.Addr().String()) @@ -60,4 +99,5 @@ func TestStartRemoteListener(t *testing.T) { body, err := io.ReadAll(resp.Body) require.NoError(t, err) require.Equal(t, "Hello, world", string(body)) + require.True(t, proxied) }