diff --git a/cmd/collector/collector.go b/cmd/collector/collector.go index 4bd9fb93..e3bb3cc5 100644 --- a/cmd/collector/collector.go +++ b/cmd/collector/collector.go @@ -40,6 +40,8 @@ import ( const ( logToStdErrFlag = "logtostderr" + // An arbitrary limit, to prevent unbounded memory usage. + maxFlowRecords = 4096 ) var ( @@ -48,8 +50,14 @@ var ( IPFIXTransport string flowRecords []string mutex sync.Mutex + + flowTextSeparator = bytes.Repeat([]byte("="), 80) ) +type jsonResponse struct { + FlowRecords []string `json:"flowRecords"` +} + func initLoggingToFile(fs *pflag.FlagSet) { var err error @@ -134,6 +142,11 @@ func addIPFIXMessage(msg *entities.Message) { } mutex.Lock() defer mutex.Unlock() + if len(flowRecords) >= maxFlowRecords { + // Zero and remove first element. + flowRecords[0] = "" // Ensure first element can be garbage-collected. + flowRecords = flowRecords[1:] + } flowRecords = append(flowRecords, buf.String()) } @@ -240,21 +253,54 @@ func newCollectorCommand() *cobra.Command { func flowRecordHandler(w http.ResponseWriter, r *http.Request) { if r.Method == "GET" { + countP := r.URL.Query().Get("count") + var count int + if countP != "" { + var err error + if count, err = strconv.Atoi(countP); err != nil || count < 0 { + http.Error(w, "Invalid count query parameter", http.StatusBadRequest) + return + } + } else { + count = -1 + } + format := r.URL.Query().Get("format") + if format == "" { + format = "json" + } + if format != "text" && format != "json" { + http.Error(w, "Invalid format query parameter", http.StatusBadRequest) + return + } mutex.Lock() defer mutex.Unlock() + if count < 0 || count > len(flowRecords) { + count = len(flowRecords) + } // Retrieve data - klog.InfoS("Return flow records", "length", len(flowRecords)) - // Convert data to JSON - responseData := map[string]interface{}{"flowRecords": flowRecords} - jsonData, err := json.Marshal(responseData) - if err != nil { - http.Error(w, "Internal Server Error", http.StatusInternalServerError) - return + klog.InfoS("Return flow records", "length", count) + records := flowRecords[len(flowRecords)-count:] + if format == "json" { + // Convert data to JSON + responseData := &jsonResponse{ + FlowRecords: records, + } + jsonData, err := json.Marshal(responseData) + if err != nil { + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + return + } + w.Header().Set("Content-Type", "application/json") + // Write JSON response + w.Write(jsonData) + } else { + w.Header().Set("Content-Type", "text/plain") + for idx := range records { + w.Write([]byte(records[idx])) + // Write a separator + w.Write(flowTextSeparator) + } } - // Set response headers - w.Header().Set("Content-Type", "application/json") - // Write JSON response - w.Write(jsonData) } else { http.Error(w, "Invalid request method", http.StatusMethodNotAllowed) } diff --git a/cmd/collector/collector_test.go b/cmd/collector/collector_test.go new file mode 100644 index 00000000..a64cb3a1 --- /dev/null +++ b/cmd/collector/collector_test.go @@ -0,0 +1,154 @@ +// Copyright 2024 VMware, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/vmware/go-ipfix/pkg/entities" +) + +var testFlowRecords = []string{"flow1", "flow2", "flow3"} + +func TestFlowRecordHandler(t *testing.T) { + flowRecords = testFlowRecords + defer func() { + flowRecords = nil + }() + + testCases := []struct { + name string + countParam string + formatParam string + expectedStatus int + expectedFlows []string + }{ + { + name: "default", + expectedFlows: []string{"flow1", "flow2", "flow3"}, + }, + { + name: "last flow", + countParam: "1", + expectedFlows: []string{"flow3"}, + }, + { + name: "text format", + formatParam: "text", + expectedFlows: []string{"flow1", "flow2", "flow3"}, + }, + { + name: "json format", + formatParam: "json", + expectedFlows: []string{"flow1", "flow2", "flow3"}, + }, + { + name: "large count", + countParam: "100", + expectedFlows: []string{"flow1", "flow2", "flow3"}, + }, + { + name: "invalid count", + countParam: "-1", + expectedStatus: http.StatusBadRequest, + }, + { + name: "invalid format", + formatParam: "foobar", + expectedStatus: http.StatusBadRequest, + }, + { + name: "both params set", + countParam: "2", + formatParam: "text", + expectedFlows: []string{"flow1", "flow2"}, + }, + } + + for _, tc := range testCases { + expectedStatus := tc.expectedStatus + if expectedStatus == 0 { + expectedStatus = http.StatusOK + } + format := tc.formatParam + if format == "" { + format = "json" + } + rr := httptest.NewRecorder() + u := url.URL{ + Path: "/records", + } + q := u.Query() + if tc.countParam != "" { + q.Set("count", tc.countParam) + } + if tc.formatParam != "" { + q.Set("format", tc.formatParam) + } + u.RawQuery = q.Encode() + req, err := http.NewRequest("GET", u.String(), nil) + require.NoError(t, err) + flowRecordHandler(rr, req) + resp := rr.Result() + defer resp.Body.Close() + require.Equal(t, expectedStatus, resp.StatusCode) + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + if expectedStatus != http.StatusOK { + return + } + + contentType := resp.Header.Get("Content-type") + if format == "json" { + require.Equal(t, "application/json", contentType) + var data jsonResponse + err := json.Unmarshal(body, &data) + require.NoError(t, err, "Invalid JSON response") + assert.Equal(t, tc.expectedFlows, data.FlowRecords) + } else if format == "text" { + require.Equal(t, "text/plain", contentType) + flows := strings.Split(string(body), string(flowTextSeparator)) + // Ignore the last empty fragment. + assert.Equal(t, tc.expectedFlows, flows[:len(flows)-1]) + } else { + require.FailNow(t, "Invalid format specified for test case") + } + } +} + +func TestAddIPFIXMessage(t *testing.T) { + defer func() { + flowRecords = nil + }() + set := entities.NewSet(false) + msg := entities.NewMessage(false) + msg.AddSet(set) + for i := 0; i < maxFlowRecords; i++ { + addIPFIXMessage(msg) + require.Len(t, flowRecords, i+1) + } + addIPFIXMessage(msg) + assert.Len(t, flowRecords, maxFlowRecords) +}