diff --git a/CHANGELOG.md b/CHANGELOG.md index 8a92b54..5f3092e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,9 @@ # Golang Module Release Notes +## 1.12.0 2023-01-10 + +* Replaced internal custom header extractor function with raw header extractor function + ## 1.11.0 2022-01-18 * Improved `Content-Type` header inspection diff --git a/VERSION b/VERSION index 1cac385..0eed1a2 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -1.11.0 +1.12.0 diff --git a/config.go b/config.go index 0981ceb..2cf82d4 100644 --- a/config.go +++ b/config.go @@ -38,8 +38,8 @@ var ( DefaultServerFlavor = "" ) -// HeaderExtractorFunc is a header extraction function -type HeaderExtractorFunc func(*http.Request) (http.Header, error) +// RawHeaderExtractorFunc is a header extraction function +type RawHeaderExtractorFunc func(*http.Request) [][2]string // ModuleConfig is a configuration object for a Module type ModuleConfig struct { @@ -48,7 +48,7 @@ type ModuleConfig struct { anomalySize int64 expectedContentTypes []string debug bool - headerExtractor HeaderExtractorFunc + rawHeaderExtractor RawHeaderExtractorFunc inspector Inspector inspInit InspectorInitFunc inspFini InspectorFiniFunc @@ -69,7 +69,6 @@ func NewModuleConfig(options ...ModuleConfigOption) (*ModuleConfig, error) { anomalySize: DefaultAnomalySize, expectedContentTypes: make([]string, 0), debug: DefaultDebug, - headerExtractor: nil, inspector: DefaultInspector, inspInit: nil, inspFini: nil, @@ -158,9 +157,9 @@ func (c *ModuleConfig) Debug() bool { return c.debug } -// HeaderExtractor returns the configuration value -func (c *ModuleConfig) HeaderExtractor() func(r *http.Request) (http.Header, error) { - return c.headerExtractor +// RawHeaderExtractor returns the configuration value +func (c *ModuleConfig) RawHeaderExtractor() RawHeaderExtractorFunc { + return c.rawHeaderExtractor } // Inspector returns the inspector @@ -228,10 +227,11 @@ type ModuleConfigOption func(*ModuleConfig) error // to read the body when the content length is not specified. // // NOTE: This can be dangerous (fill RAM) if set when the max content -// length is not limited by the server itself. This is intended -// for use with gRPC where the max message receive length is limited. -// Do NOT enable this if there is no limit set on the request -// content length! +// +// length is not limited by the server itself. This is intended +// for use with gRPC where the max message receive length is limited. +// Do NOT enable this if there is no limit set on the request +// content length! func AllowUnknownContentLength(allow bool) ModuleConfigOption { return func(c *ModuleConfig) error { c.allowUnknownContentLength = allow @@ -290,12 +290,12 @@ func CustomInspector(insp Inspector, init InspectorInitFunc, fini InspectorFiniF } } -// CustomHeaderExtractor is a function argument that sets a function to extract +// RawHeaderExtractor is a function argument that sets a function to extract // an alternative header object from the request. It is primarily intended only // for internal use. -func CustomHeaderExtractor(fn func(r *http.Request) (http.Header, error)) ModuleConfigOption { +func RawHeaderExtractor(fn RawHeaderExtractorFunc) ModuleConfigOption { return func(c *ModuleConfig) error { - c.headerExtractor = fn + c.rawHeaderExtractor = fn return nil } } diff --git a/config_test.go b/config_test.go index 7d70d88..d88aff4 100644 --- a/config_test.go +++ b/config_test.go @@ -29,8 +29,8 @@ func TestDefaultModuleConfig(t *testing.T) { if c.Debug() != DefaultDebug { t.Errorf("Unexpected Debug: %v", c.Debug()) } - if c.HeaderExtractor() != nil { - t.Errorf("Unexpected HeaderExtractor: %p", c.HeaderExtractor()) + if c.RawHeaderExtractor() != nil { + t.Errorf("Unexpected RawHeaderExtractor: %p", c.RawHeaderExtractor()) } if c.Inspector() != DefaultInspector { t.Errorf("Unexpected Inspector: %v", c.Inspector()) @@ -90,7 +90,7 @@ func TestConfiguredModuleConfig(t *testing.T) { AnomalyDuration(10*time.Second), AnomalySize(8192), CustomInspector(&RPCInspector{}, func(_ *http.Request) bool { return true }, func(_ *http.Request) {}), - CustomHeaderExtractor(func(_ *http.Request) (http.Header, error) { return nil, nil }), + RawHeaderExtractor(func(_ *http.Request) [][2]string { return nil }), ExpectedContentType("application/foobar"), ExpectedContentType("application/fizzbuzz"), Debug(true), @@ -119,8 +119,8 @@ func TestConfiguredModuleConfig(t *testing.T) { if c.Debug() != true { t.Errorf("Unexpected Debug: %v", c.Debug()) } - if c.HeaderExtractor() == nil { - t.Errorf("Unexpected HeaderExtractor: %p", c.HeaderExtractor()) + if c.RawHeaderExtractor() == nil { + t.Errorf("Unexpected HeaderExtractor: %p", c.RawHeaderExtractor()) } if c.Inspector() == DefaultInspector { t.Errorf("Unexpected Inspector: %v", c.Inspector()) @@ -182,7 +182,7 @@ func TestFromModuleConfig(t *testing.T) { ExpectedContentType("application/foobar"), ExpectedContentType("application/fizzbuzz"), CustomInspector(&RPCInspector{}, func(_ *http.Request) bool { return true }, func(_ *http.Request) {}), - CustomHeaderExtractor(func(_ *http.Request) (http.Header, error) { return nil, nil }), + RawHeaderExtractor(func(_ *http.Request) [][2]string { return nil }), Debug(true), MaxContentLength(500000), Socket("tcp", "0.0.0.0:1234"), @@ -216,8 +216,8 @@ func TestFromModuleConfig(t *testing.T) { if c.Debug() != true { t.Errorf("Unexpected Debug: %v", c.Debug()) } - if c.HeaderExtractor() == nil { - t.Errorf("Unexpected HeaderExtractor: %p", c.HeaderExtractor()) + if c.RawHeaderExtractor() == nil { + t.Errorf("Unexpected HeaderExtractor: %p", c.RawHeaderExtractor()) } if c.Inspector() == DefaultInspector { t.Errorf("Unexpected Inspector: %v", c.Inspector()) diff --git a/module.go b/module.go index 0e09655..e8b4457 100644 --- a/module.go +++ b/module.go @@ -19,12 +19,11 @@ import ( // data collection and sends it to the Signal Sciences Agent for // inspection. type Module struct { - config *ModuleConfig - handler http.Handler - inspector Inspector - inspInit InspectorInitFunc - inspFini InspectorFiniFunc - headerExtractor HeaderExtractorFunc + config *ModuleConfig + handler http.Handler + inspector Inspector + inspInit InspectorInitFunc + inspFini InspectorFiniFunc } // NewModule wraps an existing http.Handler with one that extracts data and @@ -39,12 +38,11 @@ func NewModule(h http.Handler, options ...ModuleConfigOption) (*Module, error) { // The following are the defaults, overridden by passing in functional options m := Module{ - handler: h, - config: config, - inspector: config.Inspector(), - inspInit: config.InspectorInit(), - inspFini: config.InspectorFini(), - headerExtractor: config.HeaderExtractor(), + handler: h, + config: config, + inspector: config.Inspector(), + inspInit: config.InspectorInit(), + inspFini: config.InspectorFini(), } // By default, use an RPC based inspector if not configured externally @@ -169,8 +167,7 @@ func (m *Module) ServeHTTP(w http.ResponseWriter, req *http.Request) { if m.config.Debug() { log.Printf("DEBUG: calling 'RPC.PostRequest' due to anomaly: method=%s host=%s url=%s code=%d size=%d duration=%s", req.Method, req.Host, req.URL, code, size, duration) } - inspin := NewRPCMsgIn(req, nil, code, size, duration, m.config.ModuleIdentifier(), m.config.ServerIdentifier()) - m.extractHeaders(req, inspin) + inspin := NewRPCMsgIn(m.config, req, nil, code, size, duration) inspin.WAFResponse = wafresponse inspin.HeadersOut = convertHeaders(rw.Header()) @@ -225,8 +222,7 @@ func (m *Module) inspectorPreRequest(req *http.Request) (inspin2 RPCMsgIn2, out req.Body = ioutil.NopCloser(bytes.NewBuffer(reqbody)) } - inspin := NewRPCMsgInWithModuleConfig(m.config, req, reqbody) - m.extractHeaders(req, inspin) + inspin := NewRPCMsgIn(m.config, req, reqbody, -1, -1, 0) if m.config.Debug() { log.Printf("DEBUG: Making PreRequest call to inspector: %s %s", inspin.Method, inspin.URI) @@ -276,20 +272,6 @@ func (m *Module) inspectorPreRequest(req *http.Request) (inspin2 RPCMsgIn2, out return } -func (m *Module) extractHeaders(req *http.Request, inspin *RPCMsgIn) { - // If the user supplied a custom header extractor, use it to unpack the - // headers. If there no custom header extractor or it returns an error, - // fallback to the native headers on the request. - if m.headerExtractor != nil { - hin, err := m.headerExtractor(req) - if err == nil { - inspin.HeadersIn = convertHeaders(hin) - } else if m.config.Debug() { - log.Printf("DEBUG: Error extracting custom headers, using native headers: %s", err) - } - } -} - // inspectorPostRequest makes a postrequest call to the inspector func (m *Module) inspectorPostRequest(inspin *RPCMsgIn) error { // Create message to agent from the input request @@ -329,93 +311,42 @@ func (m *Module) inspectorUpdateRequest(inspin RPCMsgIn2) error { // NewRPCMsgIn creates a message from a go http.Request object // End-users of the golang module never need to use this // directly and it is only exposed for performance testing -func NewRPCMsgIn(r *http.Request, postbody []byte, code int, size int64, dur time.Duration, module, server string) *RPCMsgIn { +func NewRPCMsgIn(mcfg *ModuleConfig, r *http.Request, postbody []byte, code int, size int64, dur time.Duration) *RPCMsgIn { now := time.Now() - // assemble a message to send to inspector - tlsProtocol := "" - tlsCipher := "" - scheme := "http" - if r.TLS != nil { - // convert golang/spec integers into something human readable - scheme = "https" - tlsProtocol = tlstext.Version(r.TLS.Version) - tlsCipher = tlstext.CipherSuite(r.TLS.CipherSuite) - } - - // golang removes Host header from req.Header map and - // promotes it to r.Host field. Add it back as the first header. - hin := convertHeaders(r.Header) - if len(r.Host) > 0 { - hin = append([][2]string{{"Host", r.Host}}, hin...) - } - - return &RPCMsgIn{ - ModuleVersion: module, - ServerVersion: server, + msgIn := RPCMsgIn{ + ModuleVersion: mcfg.ModuleIdentifier(), + ServerVersion: mcfg.ServerIdentifier(), + ServerFlavor: mcfg.ServerFlavor(), ServerName: r.Host, Timestamp: now.Unix(), - NowMillis: now.UnixNano() / 1e6, + NowMillis: now.UnixMilli(), RemoteAddr: stripPort(r.RemoteAddr), Method: r.Method, - Scheme: scheme, URI: r.RequestURI, Protocol: r.Proto, - TLSProtocol: tlsProtocol, - TLSCipher: tlsCipher, ResponseCode: int32(code), - ResponseMillis: int64(dur / time.Millisecond), + ResponseMillis: dur.Milliseconds(), ResponseSize: size, PostBody: string(postbody), - HeadersIn: hin, } -} - -// NewRPCMsgInWithModuleConfig creates a message from a ModuleConfig object -// End-users of the golang module never need to use this -// directly and it is only exposed for performance testing -func NewRPCMsgInWithModuleConfig(mcfg *ModuleConfig, r *http.Request, postbody []byte) *RPCMsgIn { - - now := time.Now() - // assemble a message to send to inspector - tlsProtocol := "" - tlsCipher := "" - scheme := "http" if r.TLS != nil { // convert golang/spec integers into something human readable - scheme = "https" - tlsProtocol = tlstext.Version(r.TLS.Version) - tlsCipher = tlstext.CipherSuite(r.TLS.CipherSuite) + msgIn.Scheme = "https" + msgIn.TLSProtocol = tlstext.Version(r.TLS.Version) + msgIn.TLSCipher = tlstext.CipherSuite(r.TLS.CipherSuite) + } else { + msgIn.Scheme = "http" } - // golang removes Host header from req.Header map and - // promotes it to r.Host field. Add it back as the first header. - hin := convertHeaders(r.Header) - if len(r.Host) > 0 { - hin = append([][2]string{{"Host", r.Host}}, hin...) + if hdrs := mcfg.RawHeaderExtractor(); hdrs != nil { + msgIn.HeadersIn = hdrs(r) } - - return &RPCMsgIn{ - ModuleVersion: mcfg.ModuleIdentifier(), - ServerVersion: mcfg.ServerIdentifier(), - ServerFlavor: mcfg.ServerFlavor(), - ServerName: r.Host, - Timestamp: now.Unix(), - NowMillis: now.UnixNano() / 1e6, - RemoteAddr: stripPort(r.RemoteAddr), - Method: r.Method, - Scheme: scheme, - URI: r.RequestURI, - Protocol: r.Proto, - TLSProtocol: tlsProtocol, - TLSCipher: tlsCipher, - ResponseCode: -1, - ResponseMillis: 0, - ResponseSize: -1, - PostBody: string(postbody), - HeadersIn: hin, + if msgIn.HeadersIn == nil { + msgIn.HeadersIn = requestHeader(r) } + return &msgIn } // stripPort removes any port from an address (e.g., the client port from the RemoteAddr) @@ -505,6 +436,22 @@ func inspectableContentType(s string) bool { return false } +// requestHeader returns request headers with host header +func requestHeader(r *http.Request) [][2]string { + out := make([][2]string, 0, len(r.Header)+1) + // golang removes Host header from req.Header map and + // promotes it to r.Host field. Add it back as the first header. + if len(r.Host) > 0 { + out = append(out, [2]string{"Host", r.Host}) + } + for key, values := range r.Header { + for _, value := range values { + out = append(out, [2]string{key, value}) + } + } + return out +} + // converts a http.Header map to a [][2]string func convertHeaders(h http.Header) [][2]string { // get headers diff --git a/module_test.go b/module_test.go index e9b20ff..1d0715b 100644 --- a/module_test.go +++ b/module_test.go @@ -24,7 +24,7 @@ func TestNewRPCMsgInWithModuleConfigFromRequest(t *testing.T) { AnomalyDuration(10*time.Second), AnomalySize(8192), CustomInspector(&RPCInspector{}, func(_ *http.Request) bool { return true }, func(_ *http.Request) {}), - CustomHeaderExtractor(func(_ *http.Request) (http.Header, error) { return nil, nil }), + RawHeaderExtractor(func(r *http.Request) [][2]string { return nil }), Debug(true), MaxContentLength(500000), Socket("tcp", "0.0.0.0:1234"), @@ -78,62 +78,12 @@ func TestNewRPCMsgInWithModuleConfigFromRequest(t *testing.T) { } } - got := NewRPCMsgInWithModuleConfig(c, r, nil) + got := NewRPCMsgIn(c, r, nil, -1, -1, 0) if ne, equal := eq(*got, want); !equal { t.Errorf("NewRPCMsgInWithModuleConfig: incorrect %q", ne) } } -func TestNewRPCMsgFromRequest(t *testing.T) { - b := bytes.Buffer{} - b.WriteString("test") - r, err := http.NewRequest("GET", "http://localhost/", &b) - if err != nil { - t.Fatal(err) - } - r.RemoteAddr = "127.0.0.1" - r.Header.Add("If-None-Match", `W/"wyzzy"`) - r.RequestURI = "http://localhost/" - r.TLS = &tls.ConnectionState{} - - want := RPCMsgIn{ - ServerName: "localhost", - Method: "GET", - Scheme: "https", - URI: "http://localhost/", - Protocol: "HTTP/1.1", - RemoteAddr: "127.0.0.1", - HeadersIn: [][2]string{{"Host", "localhost"}, {"If-None-Match", `W/"wyzzy"`}}, - } - eq := func(got, want RPCMsgIn) (ne string, equal bool) { - switch { - case got.ServerName != want.ServerName: - return "ServerHostname", false - case got.Method != want.Method: - return "Method", false - case got.Scheme != want.Scheme: - return "Scheme", false - case got.URI != want.URI: - return "URI", false - case got.Protocol != want.Protocol: - return "Protocol", false - case got.RemoteAddr != want.RemoteAddr: - return "RemoteAddr", false - case !reflect.DeepEqual(got.HeadersIn, want.HeadersIn): - return "HeadersIn", false - default: - return "", true - } - } - - got := NewRPCMsgIn(r, nil, -1, -1, -1, "", "") - if ne, equal := eq(*got, want); !equal { - t.Errorf("NewRPCMsgIn: incorrect %q", ne) - } -} - -// helper functions - func TestStripPort(t *testing.T) { cases := []struct { want string @@ -204,6 +154,18 @@ func TestShouldReadBody(t *testing.T) { } } +func TestRequestHeader(t *testing.T) { + r := &http.Request{ + Host: "example.com", + Header: http.Header{"ContentType": {"text/plain"}}, + } + got := requestHeader(r) + expected := [][2]string{{"Host", "example.com"}, {"ContentType", "text/plain"}} + if !reflect.DeepEqual(expected, got) { + t.Errorf("expected %#v, got %#v", expected, got) + } +} + func TestConvertHeaders(t *testing.T) { cases := []struct { want [][2]string // Only the order of like keys matters