From c1d76aa2ed685cbafd23ab81e1f39fee43b5a5a9 Mon Sep 17 00:00:00 2001 From: Smuu <18609909+Smuu@users.noreply.github.com> Date: Mon, 30 Oct 2023 16:11:01 +0100 Subject: [PATCH] feat: handle websocket connections sperately Signed-off-by: Smuu <18609909+Smuu@users.noreply.github.com> --- go.mod | 1 + go.sum | 4 ++-- main.go | 71 ++++++++++++++++++++++++++++++++++++++++++++++++++++++--- 3 files changed, 71 insertions(+), 5 deletions(-) diff --git a/go.mod b/go.mod index 13b48f8..f8ef008 100644 --- a/go.mod +++ b/go.mod @@ -4,4 +4,5 @@ go 1.21.1 require ( github.com/andybalholm/brotli v1.0.6 + github.com/gorilla/websocket v1.5.0 ) diff --git a/go.sum b/go.sum index 035290a..93968f7 100644 --- a/go.sum +++ b/go.sum @@ -1,4 +1,4 @@ github.com/andybalholm/brotli v1.0.6 h1:Yf9fFpf49Zrxb9NlQaluyE92/+X7UVHlhMNJN2sxfOI= github.com/andybalholm/brotli v1.0.6/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig= -github.com/google/brotli/go/cbrotli v0.0.0-20231026090320-9b83be233e0e h1:q7hxkkc0sikXStQES2u4sYPVuNGMj8NUagcKAZp0sDI= -github.com/google/brotli/go/cbrotli v0.0.0-20231026090320-9b83be233e0e/go.mod h1:nOPhAkwVliJdNTkj3gXpljmWhjc4wCaVqbMJcPKWP4s= +github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= +github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= diff --git a/main.go b/main.go index 7580793..546b049 100644 --- a/main.go +++ b/main.go @@ -9,6 +9,7 @@ import ( "strings" "github.com/andybalholm/brotli" + "github.com/gorilla/websocket" ) var ( @@ -30,9 +31,6 @@ func replaceDomainInResponse(originalSubdomain, replaceSubdomain, originalDomain replacedBody := strings.ReplaceAll(body, fullReplace, fullOriginal) buffer.Reset() buffer.WriteString(replacedBody) - - // Logging for troubleshooting - debugLog.Printf("fullReplace: %s, fullOriginal: %s", fullReplace, fullOriginal) } func proxyRequest(fullSubdomain, path string, buffer *bytes.Buffer, r *http.Request) (int, map[string]string, error) { @@ -70,6 +68,7 @@ func proxyRequest(fullSubdomain, path string, buffer *bytes.Buffer, r *http.Requ func handleHttpRequest(w http.ResponseWriter, r *http.Request) { infoLog.Printf("Received request from %s", r.Host) + hostParts := strings.Split(r.Host, ".") if len(hostParts) < 3 { errorLog.Printf("Invalid domain: %s", r.Host) @@ -80,6 +79,13 @@ func handleHttpRequest(w http.ResponseWriter, r *http.Request) { subdomain := hostParts[0] // Extract original domain originalDomain := strings.Join(hostParts[1:], ".") + // Check for WebSocket upgrade headers + if strings.ToLower(r.Header.Get("Upgrade")) == "websocket" { + // Handle WebSocket requests by proxying to snapscale + proxyWebSocketRequest(subdomain, w, r) + return + } + buffer := new(bytes.Buffer) backupBuffer := new(bytes.Buffer) @@ -117,6 +123,65 @@ func handleHttpRequest(w http.ResponseWriter, r *http.Request) { io.Copy(w, buffer) } +var upgrader = websocket.Upgrader{} // use default options + +func proxyWebSocketRequest(subdomain string, w http.ResponseWriter, r *http.Request) { + // Build target URL + fullSubdomain := subdomain + "-snapscale" + target := "wss://" + fullSubdomain + ".lunaroasis.net" + r.RequestURI + + // Create a new WebSocket connection to the target + dialer := websocket.Dialer{} + targetConn, _, err := dialer.Dial(target, nil) + if err != nil { + errorLog.Printf("Failed to connect to target: %v", err) + http.Error(w, "Internal server error", http.StatusInternalServerError) + return + } + defer targetConn.Close() + + // Upgrade the client connection to a WebSocket connection + clientConn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + errorLog.Printf("Failed to upgrade client connection: %v", err) + return // No need to send an error response, Upgrade already did if there was an error + } + defer clientConn.Close() + + // Start goroutines to copy data between the client and target + go func() { + for { + messageType, message, err := targetConn.ReadMessage() + if err != nil { + errorLog.Printf("Failed to read from target: %v", err) + return + } + err = clientConn.WriteMessage(messageType, message) + if err != nil { + errorLog.Printf("Failed to write to client: %v", err) + return + } + } + }() + go func() { + for { + messageType, message, err := clientConn.ReadMessage() + if err != nil { + errorLog.Printf("Failed to read from client: %v", err) + return + } + err = targetConn.WriteMessage(messageType, message) + if err != nil { + errorLog.Printf("Failed to write to target: %v", err) + return + } + } + }() + + // The goroutines will run until one of the connections is closed + select {} +} + func compressBrotli(buffer *bytes.Buffer) *bytes.Buffer { var compressedData bytes.Buffer writer := brotli.NewWriterLevel(&compressedData, brotli.DefaultCompression)