diff --git a/sshmuxer/handle.go b/sshmuxer/handle.go index 0fda661..fa9027a 100644 --- a/sshmuxer/handle.go +++ b/sshmuxer/handle.go @@ -26,6 +26,8 @@ func handleRequest(newRequest *ssh.Request, sshConn *utils.SSHConnection, state case "tcpip-forward": go checkSession(newRequest, sshConn, state) handleRemoteForward(newRequest, sshConn, state) + case "cancel-tcpip-forward": + handleCancelRemoteForward(newRequest, sshConn, state) case "keepalive@openssh.com": err := newRequest.Reply(true, nil) if err != nil { diff --git a/sshmuxer/requests.go b/sshmuxer/requests.go index f87f64b..3e2f88c 100644 --- a/sshmuxer/requests.go +++ b/sshmuxer/requests.go @@ -41,6 +41,55 @@ type forwardedTCPPayload struct { OriginPort uint32 } +// handleCancelRemoteForward will handle a remote forward cancellation +// request and remove the relevant listeners. +func handleCancelRemoteForward(newRequest *ssh.Request, sshConn *utils.SSHConnection, state *utils.State) { + check := &channelForwardMsg{} + + err := ssh.Unmarshal(newRequest.Payload, check) + if err != nil { + log.Println("Error unmarshaling remote forward payload:", err) + err = newRequest.Reply(false, nil) + if err != nil { + log.Println("Error replying to request:", err) + } + return + } + + closed := false + + sshConn.Listeners.Range(func(remoteAddr string, listener net.Listener) bool { + holder, ok := listener.(*utils.ListenerHolder) + if !ok { + return false + } + + if holder.OriginalAddr == check.Addr && holder.OriginalPort == check.Rport { + closed = true + holder.Close() + return false + } + + return true + }) + + if !closed { + log.Println("Unable to close tunnel") + + err = newRequest.Reply(false, nil) + if err != nil { + log.Println("Error replying to request:", err) + } + + return + } + + err = newRequest.Reply(true, nil) + if err != nil { + log.Println("Error replying to request:", err) + } +} + // handleRemoteForward will handle a remote forward request // and stand up the relevant listeners. func handleRemoteForward(newRequest *ssh.Request, sshConn *utils.SSHConnection, state *utils.State) { @@ -56,6 +105,17 @@ func handleRemoteForward(newRequest *ssh.Request, sshConn *utils.SSHConnection, err := ssh.Unmarshal(newRequest.Payload, check) if err != nil { log.Println("Error unmarshaling remote forward payload:", err) + + err = newRequest.Reply(false, nil) + if err != nil { + log.Println("Error replying to socket request:", err) + } + return + } + + originalCheck := &channelForwardMsg{ + Addr: check.Addr, + Rport: check.Rport, } originalAddress := check.Addr @@ -131,10 +191,12 @@ func handleRemoteForward(newRequest *ssh.Request, sshConn *utils.SSHConnection, } listenerHolder := &utils.ListenerHolder{ - ListenAddr: listenAddr, - Listener: chanListener, - Type: listenerType, - SSHConn: sshConn, + ListenAddr: listenAddr, + Listener: chanListener, + Type: listenerType, + SSHConn: sshConn, + OriginalAddr: originalCheck.Addr, + OriginalPort: originalCheck.Rport, } state.Listeners.Store(listenAddr, listenerHolder) diff --git a/utils/state.go b/utils/state.go index 35b7a78..8d66b40 100644 --- a/utils/state.go +++ b/utils/state.go @@ -48,9 +48,11 @@ func (w LogWriter) Write(bytes []byte) (int, error) { // ListenerHolder represents a generic listener. type ListenerHolder struct { net.Listener - ListenAddr string - Type ListenerType - SSHConn *SSHConnection + ListenAddr string + Type ListenerType + SSHConn *SSHConnection + OriginalAddr string + OriginalPort uint32 } // HTTPHolder holds proxy and connection info.