diff --git a/config/base.json b/config/base.json index ec93d48..ba6befc 100644 --- a/config/base.json +++ b/config/base.json @@ -48,8 +48,6 @@ "artifact_private_hostport": "{{ .plugins.artifact_private_hostport }}" }, "sse-streaming": { - "endpoint": "/v1beta/sse/{id}", - "backend_url_pattern": "/sse/{id}", "backend_host": "{{ .plugins.pipeline_public_hostport }}" } } diff --git a/plugins/grpc-proxy/server.go b/plugins/grpc-proxy/server.go index 9cb6eea..22fc09d 100644 --- a/plugins/grpc-proxy/server.go +++ b/plugins/grpc-proxy/server.go @@ -81,9 +81,14 @@ func (r registerer) RegisterHandlers(f func( func (r registerer) registerHandlers(ctx context.Context, extra map[string]interface{}, h http.Handler) (http.Handler, error) { return h2c.NewHandler(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - w = NewResponseHijacker(w) - h.ServeHTTP(w, req) - w.(Trailer).WriteTrailer() + if req.Header.Get("instill-use-sse") == "true" { + // For SSE, we need to skip this plugin. + h.ServeHTTP(w, req) + } else { + w = NewResponseHijacker(w) + h.ServeHTTP(w, req) + w.(Trailer).WriteTrailer() + } }), &http2.Server{}), nil } diff --git a/plugins/multi-auth/main.go b/plugins/multi-auth/main.go index 18db24d..efecc90 100644 --- a/plugins/multi-auth/main.go +++ b/plugins/multi-auth/main.go @@ -3,8 +3,10 @@ package main import ( "context" "encoding/base64" + "encoding/json" "errors" "fmt" + "io" "net/http" "strings" @@ -54,6 +56,8 @@ func (r registerer) registerHandlers(ctx context.Context, extra map[string]inter mgmtClient, _ := InitMgmtPublicServiceClient(context.Background(), config["grpc_server"].(string), "", "") + httpClient := http.Client{Transport: http.DefaultTransport} + return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { authorization := req.Header.Get("Authorization") @@ -103,6 +107,51 @@ func (r registerer) registerHandlers(ctx context.Context, extra map[string]inter req.Header.Set("Instill-Visitor-Uid", visitorID.String()) h.ServeHTTP(w, req) + } else if req.Header.Get("instill-use-sse") == "true" { + // Currently, KrakenD doesn’t support event-stream. To make + // authentication work, we send a request to the management API + // first for verification. + r, err := http.NewRequest("GET", "http://localhost:8080/v1beta/user", nil) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + r.Header = req.Header + r.Header.Del("instill-use-sse") + + resp, err := httpClient.Do(r) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + if resp.StatusCode == 401 { + writeStatusUnauthorized(req, w) + return + } + type user struct { + User struct { + UID string `json:"uid"` + } `json:"user"` + } + respBytes, err := io.ReadAll(resp.Body) + if err != nil { + writeStatusUnauthorized(req, w) + return + } + defer resp.Body.Close() + + u := user{} + err = json.Unmarshal(respBytes, &u) + if err != nil { + writeStatusUnauthorized(req, w) + return + } + + req.Header.Set("Instill-Auth-Type", "user") + req.Header.Set("Instill-User-Uid", u.User.UID) + req.Header.Set("instill-Use-SSE", "true") + h.ServeHTTP(w, req) + } else { req.Header.Set("Instill-Auth-Type", "user") req.URL.Path = "/internal" + req.URL.Path diff --git a/plugins/sse-streaming/go.mod b/plugins/sse-streaming/go.mod index 7dae764..b1058fd 100644 --- a/plugins/sse-streaming/go.mod +++ b/plugins/sse-streaming/go.mod @@ -1,3 +1,7 @@ module sse_streaming_plugin go 1.22.6 + +require golang.org/x/net v0.26.0 + +require golang.org/x/text v0.16.0 // indirect diff --git a/plugins/sse-streaming/go.sum b/plugins/sse-streaming/go.sum new file mode 100644 index 0000000..1292c61 --- /dev/null +++ b/plugins/sse-streaming/go.sum @@ -0,0 +1,4 @@ +golang.org/x/net v0.26.0 h1:soB7SVo0PWrY4vPW/+ay0jKDNScG2X9wFeYlXIvJsOQ= +golang.org/x/net v0.26.0/go.mod h1:5YKkiSynbBIh3p6iOc/vibscux0x38BZDkn8sCUPxHE= +golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4= +golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI= diff --git a/plugins/sse-streaming/main.go b/plugins/sse-streaming/main.go index b56305a..a40b634 100644 --- a/plugins/sse-streaming/main.go +++ b/plugins/sse-streaming/main.go @@ -18,8 +18,6 @@ the following configuration must be present in the krakend.json file: "plugin/http-server": { "name": ["sse-streaming"], "sse-streaming": { - "endpoint": "/sse/{id}", - "backend_url_pattern": "/events-stream/{id}", "backend_host": "http://localhost:9081" } } @@ -53,22 +51,13 @@ func (r registerer) registerHandlers(ctx context.Context, extra map[string]inter } // Extract configuration values - endpoint, endpointOk := config["endpoint"].(string) - backendURLPattern, backendURLPatternOk := config["backend_url_pattern"].(string) backendHost, backendHostOk := config["backend_host"].(string) // Check if all required configuration values are present - if !endpointOk || !backendURLPatternOk || !backendHostOk { + if !backendHostOk { return h, errors.New("missing required configuration values") } - // Basic sanity checks on the configuration values - if endpoint == "" { - return h, errors.New("endpoint cannot be empty") - } - if backendURLPattern == "" { - return h, errors.New("backend_url_pattern cannot be empty") - } if backendHost == "" { return h, errors.New("backend_host cannot be empty") } @@ -80,35 +69,36 @@ func (r registerer) registerHandlers(ctx context.Context, extra map[string]inter // Return a new HTTP handler that wraps the original handler with custom logic. return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - // TODO: Performance optimize matchStrings; critical, every request to the API gateway uses this. - matchPaths, id := matchStrings(endpoint, req.URL.Path) - if !matchPaths { + httpClient := http.Client{Transport: http.DefaultTransport} + + // This is a quick solution since we only support sse for pipeline trigger endpoint + if req.Header.Get("instill-use-sse") == "true" { + proxyHandler(w, req, httpClient, backendHost) + } else { h.ServeHTTP(w, req) - return } - // Construct serverURL using the extracted ID - serverURL := fmt.Sprintf("http://%s%s", backendHost, strings.Replace(backendURLPattern, "{id}", id, 1)) - // Call proxyHandler if the path matches - proxyHandler(w, req, serverURL) }), nil } // proxyHandler forwards the request to the actual SSE server and streams the response back to the client. -func proxyHandler(w http.ResponseWriter, r *http.Request, serverURL string) { - logger.Debug("server URL", serverURL) - // Forward the request to the actual SSE server - resp, err := http.Get(serverURL) +func proxyHandler(w http.ResponseWriter, r *http.Request, httpClient http.Client, backendHost string) { + + url := string(r.URL.Path) + url = strings.ReplaceAll(url, "/internal", "") + req, err := http.NewRequest("POST", fmt.Sprintf("http://%s%s", backendHost, url), r.Body) + if err != nil { - errM := "failed to connect to downstream SSE server" - logger.Critical(errM) - http.Error(w, errM, http.StatusInternalServerError) + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + req.Header = r.Header + resp, err := httpClient.Do(req) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) return } - defer resp.Body.Close() - - // Set headers for the client w.Header().Set("Content-Type", "text/event-stream") w.Header().Set("Cache-Control", "no-cache") w.Header().Set("Connection", "keep-alive") @@ -134,28 +124,6 @@ func proxyHandler(w http.ResponseWriter, r *http.Request, serverURL string) { } } -// matchStrings checks if the request path matches the pattern and extracts the ID. -func matchStrings(pattern, str string) (bool, string) { - patternParts := strings.Split(pattern, "/") - strParts := strings.Split(str, "/") - - if len(patternParts) != len(strParts) { - return false, "" - } - - var id string - for i := 0; i < len(patternParts); i++ { - if patternParts[i] != strParts[i] && patternParts[i] != "{id}" { - return false, "" - } - if patternParts[i] == "{id}" { - id = strParts[i] - } - } - - return true, id -} - func main() {} // This logger is replaced by the RegisterLogger method to load the one from KrakenD diff --git a/plugins/sse-streaming/main_test.go b/plugins/sse-streaming/main_test.go deleted file mode 100644 index f9e4fbf..0000000 --- a/plugins/sse-streaming/main_test.go +++ /dev/null @@ -1,114 +0,0 @@ -package main - -import ( - "io" - "net/http" - "net/http/httptest" - "testing" -) - -// Mock HTTP server to simulate the backend SSE server -func mockSSEServer() *httptest.Server { - return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "text/event-stream") - w.WriteHeader(http.StatusOK) - w.Write([]byte("data: event1\n\n")) - w.Write([]byte("data: event2\n\n")) - })) -} - -func TestProxyHandler(t *testing.T) { - tests := []struct { - serverURL string - wantCode int - wantBody string - }{ - // Positive case: valid SSE server response - { - serverURL: "/valid", - wantCode: http.StatusOK, - wantBody: "data: event1\n\ndata: event2\n\n", - }, - // Negative case: invalid SSE server URL - { - serverURL: "/invalid", - wantCode: http.StatusInternalServerError, - wantBody: "Failed to connect to downstream SSE server\n", - }, - } - - // Create a mock SSE server - mockServer := mockSSEServer() - defer mockServer.Close() - - for _, tt := range tests { - t.Run(tt.serverURL, func(t *testing.T) { - req := httptest.NewRequest("GET", "http://example.com"+tt.serverURL, nil) - w := httptest.NewRecorder() - - // Adjust the serverURL for the positive case to point to the mock server - if tt.serverURL == "/valid" { - tt.serverURL = mockServer.URL - } - - proxyHandler(w, req, tt.serverURL) - - resp := w.Result() - body, _ := io.ReadAll(resp.Body) - - if resp.StatusCode != tt.wantCode { - t.Errorf("status code = %v; want %v", resp.StatusCode, tt.wantCode) - } - if string(body) != tt.wantBody { - t.Errorf("body = %v; want %v", string(body), tt.wantBody) - } - }) - } -} - -func TestMatchStrings(t *testing.T) { - tests := []struct { - pattern string - str string - want bool - wantID string - }{ - // Positive cases - {pattern: "/api/user/{id}", str: "/api/user/123", want: true, wantID: "123"}, - {pattern: "/product/{id}/details", str: "/product/456/details", want: true, wantID: "456"}, - - // Negative cases - {pattern: "/api/user/{id}", str: "/api/admin/123", want: false, wantID: ""}, - {pattern: "/product/{id}/details", str: "/product/456/info", want: false, wantID: ""}, - } - - for _, tt := range tests { - t.Run(tt.pattern+"_"+tt.str, func(t *testing.T) { - got, gotID := matchStrings(tt.pattern, tt.str) - if got != tt.want || gotID != tt.wantID { - t.Errorf("matchStrings(%q, %q) = %v, %v; want %v, %v", tt.pattern, tt.str, got, gotID, tt.want, tt.wantID) - } - }) - } -} - -func BenchmarkMatchStrings(b *testing.B) { - testCases := []struct { - pattern string - str string - }{ - {"/users/{id}", "/users/123"}, - {"/products/{id}/details", "/products/456/details"}, - {"/orders/{id}/items", "/orders/789/items"}, - {"/categories/{id}/subcategories", "/categories/101/subcategories"}, - {"/users/{id}/posts/{postId}", "/users/202/posts/303"}, - } - - for _, tc := range testCases { - b.Run(tc.pattern+"_"+tc.str, func(b *testing.B) { - for i := 0; i < b.N; i++ { - matchStrings(tc.pattern, tc.str) - } - }) - } -}