diff --git a/forward/fwd.go b/forward/fwd.go index 8f1cb1ce..3337f149 100644 --- a/forward/fwd.go +++ b/forward/fwd.go @@ -5,6 +5,7 @@ package forward import ( "bytes" + "context" "crypto/tls" "errors" "io" @@ -42,6 +43,9 @@ type ReqRewriter interface { Rewrite(r *http.Request) } +// WsHook websocket message hook called when message is received or sent +type WsHook func(req *http.Request, messageType int, reader io.Reader) (io.Reader, error) + type optSetter func(f *Forwarder) error // PassHostHeader specifies if a client's Host header field should be delegated. @@ -69,10 +73,21 @@ func Rewriter(r ReqRewriter) optSetter { } } -// WebsocketTLSClientConfig define the websocker client TLS configuration. +// WebsocketTLSClientConfig define the websocket client TLS configuration. func WebsocketTLSClientConfig(tcc *tls.Config) optSetter { return func(f *Forwarder) error { - f.httpForwarder.tlsClientConfig = tcc + f.websocketDialer.TLSClientConfig = tcc.Clone() + // WebSocket is only in http/1.1 + f.websocketDialer.TLSClientConfig.NextProtos = []string{"http/1.1"} + + return nil + } +} + +// WebsocketNetDialContext define the websocket client DialContext function +func WebsocketNetDialContext(dialContext func(ctx context.Context, network string, addr string) (net.Conn, error)) optSetter { + return func(f *Forwarder) error { + f.websocketDialer.NetDialContext = dialContext return nil } } @@ -136,7 +151,23 @@ func WebsocketConnectionClosedHook(hook func(req *http.Request, conn net.Conn)) } } -// ResponseModifier defines a response modifier for the HTTP forwarder. +// WebsocketMessageReceivedHook defines a hook called when websocket message is received. +func WebsocketMessageReceivedHook(hook WsHook) optSetter { + return func(f *Forwarder) error { + f.httpForwarder.websocketMessageReceivedHook = hook + return nil + } +} + +// WebsocketMessageSentHook defines a hook called when websocket message is sent. +func WebsocketMessageSentHook(hook WsHook) optSetter { + return func(f *Forwarder) error { + f.httpForwarder.websocketMessageSentHook = hook + return nil + } +} + +// ResponseModifier defines a response modifier for the HTTP forwarder func ResponseModifier(responseModifier func(*http.Response) error) optSetter { return func(f *Forwarder) error { f.httpForwarder.modifyResponse = responseModifier @@ -180,6 +211,9 @@ type httpForwarder struct { bufferPool httputil.BufferPool websocketConnectionClosedHook func(req *http.Request, conn net.Conn) + websocketMessageReceivedHook WsHook + websocketMessageSentHook WsHook + websocketDialer *websocket.Dialer } const defaultFlushInterval = 100 * clock.Millisecond @@ -203,6 +237,12 @@ func New(setters ...optSetter) (*Forwarder, error) { httpForwarder: &httpForwarder{log: &internalLogger{Logger: log.StandardLogger()}}, handlerContext: &handlerContext{}, } + + f.websocketDialer = &websocket.Dialer{ + Proxy: http.ProxyFromEnvironment, + HandshakeTimeout: 45 * time.Second, + } + for _, s := range setters { if err := s(f); err != nil { return nil, err @@ -234,6 +274,9 @@ func New(setters ...optSetter) (*Forwarder, error) { if f.tlsClientConfig == nil { if ht, ok := f.httpForwarder.roundTripper.(*http.Transport); ok { f.tlsClientConfig = ht.TLSClientConfig + if f.websocketDialer.TLSClientConfig == nil && ht.TLSClientConfig != nil { + _ = WebsocketTLSClientConfig(ht.TLSClientConfig)(f) + } } } @@ -315,14 +358,7 @@ func (f *httpForwarder) serveWebSocket(w http.ResponseWriter, req *http.Request, outReq := f.copyWebSocketRequest(req) - dialer := websocket.DefaultDialer - if outReq.URL.Scheme == "wss" && f.tlsClientConfig != nil { - dialer.TLSClientConfig = f.tlsClientConfig.Clone() - // WebSocket is only in http/1.1 - dialer.TLSClientConfig.NextProtos = []string{"http/1.1"} - } - - targetConn, resp, err := dialer.DialContext(outReq.Context(), outReq.URL.String(), outReq.Header) + targetConn, resp, err := f.websocketDialer.DialContext(outReq.Context(), outReq.URL.String(), outReq.Header) if err != nil { if resp == nil { ctx.errHandler.ServeHTTP(w, req, err) @@ -383,7 +419,8 @@ func (f *httpForwarder) serveWebSocket(w http.ResponseWriter, req *http.Request, errClient := make(chan error, 1) errBackend := make(chan error, 1) - replicateWebsocketConn := func(dst, src *websocket.Conn, errc chan error) { + replicateWebsocketConn := func(dst, src *websocket.Conn, websocketMessageHook WsHook, errc chan error) { + forward := func(messageType int, reader io.Reader) error { writer, err := dst.NextWriter(messageType) if err != nil { @@ -424,6 +461,12 @@ func (f *httpForwarder) serveWebSocket(w http.ResponseWriter, req *http.Request, } break } + if websocketMessageHook != nil { + if reader, err = websocketMessageHook(req, msgType, reader); err != nil { + errc <- err + break + } + } err = forward(msgType, reader) if err != nil { errc <- err @@ -432,8 +475,8 @@ func (f *httpForwarder) serveWebSocket(w http.ResponseWriter, req *http.Request, } } - go replicateWebsocketConn(underlyingConn, targetConn, errClient) - go replicateWebsocketConn(targetConn, underlyingConn, errBackend) + go replicateWebsocketConn(underlyingConn, targetConn, f.websocketMessageSentHook, errClient) + go replicateWebsocketConn(targetConn, underlyingConn, f.websocketMessageReceivedHook, errBackend) var message string select {