Skip to content

Commit

Permalink
feat: handle websocket connections sperately
Browse files Browse the repository at this point in the history
Signed-off-by: Smuu <18609909+Smuu@users.noreply.github.com>
  • Loading branch information
smuu committed Oct 30, 2023
1 parent 44438ce commit c1d76aa
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 5 deletions.
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@ go 1.21.1

require (
github.com/andybalholm/brotli v1.0.6
github.com/gorilla/websocket v1.5.0
)
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
github.com/andybalholm/brotli v1.0.6 h1:Yf9fFpf49Zrxb9NlQaluyE92/+X7UVHlhMNJN2sxfOI=
github.com/andybalholm/brotli v1.0.6/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig=
github.com/google/brotli/go/cbrotli v0.0.0-20231026090320-9b83be233e0e h1:q7hxkkc0sikXStQES2u4sYPVuNGMj8NUagcKAZp0sDI=
github.com/google/brotli/go/cbrotli v0.0.0-20231026090320-9b83be233e0e/go.mod h1:nOPhAkwVliJdNTkj3gXpljmWhjc4wCaVqbMJcPKWP4s=
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
71 changes: 68 additions & 3 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"strings"

"github.com/andybalholm/brotli"
"github.com/gorilla/websocket"
)

var (
Expand All @@ -30,9 +31,6 @@ func replaceDomainInResponse(originalSubdomain, replaceSubdomain, originalDomain
replacedBody := strings.ReplaceAll(body, fullReplace, fullOriginal)
buffer.Reset()
buffer.WriteString(replacedBody)

// Logging for troubleshooting
debugLog.Printf("fullReplace: %s, fullOriginal: %s", fullReplace, fullOriginal)
}

func proxyRequest(fullSubdomain, path string, buffer *bytes.Buffer, r *http.Request) (int, map[string]string, error) {
Expand Down Expand Up @@ -70,6 +68,7 @@ func proxyRequest(fullSubdomain, path string, buffer *bytes.Buffer, r *http.Requ

func handleHttpRequest(w http.ResponseWriter, r *http.Request) {
infoLog.Printf("Received request from %s", r.Host)

hostParts := strings.Split(r.Host, ".")
if len(hostParts) < 3 {
errorLog.Printf("Invalid domain: %s", r.Host)
Expand All @@ -80,6 +79,13 @@ func handleHttpRequest(w http.ResponseWriter, r *http.Request) {
subdomain := hostParts[0] // Extract original domain
originalDomain := strings.Join(hostParts[1:], ".")

// Check for WebSocket upgrade headers
if strings.ToLower(r.Header.Get("Upgrade")) == "websocket" {
// Handle WebSocket requests by proxying to snapscale
proxyWebSocketRequest(subdomain, w, r)
return
}

buffer := new(bytes.Buffer)
backupBuffer := new(bytes.Buffer)

Expand Down Expand Up @@ -117,6 +123,65 @@ func handleHttpRequest(w http.ResponseWriter, r *http.Request) {
io.Copy(w, buffer)
}

var upgrader = websocket.Upgrader{} // use default options

func proxyWebSocketRequest(subdomain string, w http.ResponseWriter, r *http.Request) {
// Build target URL
fullSubdomain := subdomain + "-snapscale"
target := "wss://" + fullSubdomain + ".lunaroasis.net" + r.RequestURI

// Create a new WebSocket connection to the target
dialer := websocket.Dialer{}
targetConn, _, err := dialer.Dial(target, nil)
if err != nil {
errorLog.Printf("Failed to connect to target: %v", err)
http.Error(w, "Internal server error", http.StatusInternalServerError)
return
}
defer targetConn.Close()

// Upgrade the client connection to a WebSocket connection
clientConn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
errorLog.Printf("Failed to upgrade client connection: %v", err)
return // No need to send an error response, Upgrade already did if there was an error
}
defer clientConn.Close()

// Start goroutines to copy data between the client and target
go func() {
for {
messageType, message, err := targetConn.ReadMessage()
if err != nil {
errorLog.Printf("Failed to read from target: %v", err)
return
}
err = clientConn.WriteMessage(messageType, message)
if err != nil {
errorLog.Printf("Failed to write to client: %v", err)
return
}
}
}()
go func() {
for {
messageType, message, err := clientConn.ReadMessage()
if err != nil {
errorLog.Printf("Failed to read from client: %v", err)
return
}
err = targetConn.WriteMessage(messageType, message)
if err != nil {
errorLog.Printf("Failed to write to target: %v", err)
return
}
}
}()

// The goroutines will run until one of the connections is closed
select {}
}

func compressBrotli(buffer *bytes.Buffer) *bytes.Buffer {
var compressedData bytes.Buffer
writer := brotli.NewWriterLevel(&compressedData, brotli.DefaultCompression)
Expand Down

0 comments on commit c1d76aa

Please sign in to comment.