Skip to content

Commit

Permalink
Added config option to extract raw request Header (#35)
Browse files Browse the repository at this point in the history
Replaced custom header extractor config option with raw header extractor
Simplify code
  • Loading branch information
sada-sigsci authored Jan 10, 2023
1 parent 9c9764c commit 6ffded8
Show file tree
Hide file tree
Showing 6 changed files with 86 additions and 173 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# Golang Module Release Notes

## 1.12.0 2023-01-10

* Replaced internal custom header extractor function with raw header extractor function

## 1.11.0 2022-01-18

* Improved `Content-Type` header inspection
Expand Down
2 changes: 1 addition & 1 deletion VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
1.11.0
1.12.0
28 changes: 14 additions & 14 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ var (
DefaultServerFlavor = ""
)

// HeaderExtractorFunc is a header extraction function
type HeaderExtractorFunc func(*http.Request) (http.Header, error)
// RawHeaderExtractorFunc is a header extraction function
type RawHeaderExtractorFunc func(*http.Request) [][2]string

// ModuleConfig is a configuration object for a Module
type ModuleConfig struct {
Expand All @@ -48,7 +48,7 @@ type ModuleConfig struct {
anomalySize int64
expectedContentTypes []string
debug bool
headerExtractor HeaderExtractorFunc
rawHeaderExtractor RawHeaderExtractorFunc
inspector Inspector
inspInit InspectorInitFunc
inspFini InspectorFiniFunc
Expand All @@ -69,7 +69,6 @@ func NewModuleConfig(options ...ModuleConfigOption) (*ModuleConfig, error) {
anomalySize: DefaultAnomalySize,
expectedContentTypes: make([]string, 0),
debug: DefaultDebug,
headerExtractor: nil,
inspector: DefaultInspector,
inspInit: nil,
inspFini: nil,
Expand Down Expand Up @@ -158,9 +157,9 @@ func (c *ModuleConfig) Debug() bool {
return c.debug
}

// HeaderExtractor returns the configuration value
func (c *ModuleConfig) HeaderExtractor() func(r *http.Request) (http.Header, error) {
return c.headerExtractor
// RawHeaderExtractor returns the configuration value
func (c *ModuleConfig) RawHeaderExtractor() RawHeaderExtractorFunc {
return c.rawHeaderExtractor
}

// Inspector returns the inspector
Expand Down Expand Up @@ -228,10 +227,11 @@ type ModuleConfigOption func(*ModuleConfig) error
// to read the body when the content length is not specified.
//
// NOTE: This can be dangerous (fill RAM) if set when the max content
// length is not limited by the server itself. This is intended
// for use with gRPC where the max message receive length is limited.
// Do NOT enable this if there is no limit set on the request
// content length!
//
// length is not limited by the server itself. This is intended
// for use with gRPC where the max message receive length is limited.
// Do NOT enable this if there is no limit set on the request
// content length!
func AllowUnknownContentLength(allow bool) ModuleConfigOption {
return func(c *ModuleConfig) error {
c.allowUnknownContentLength = allow
Expand Down Expand Up @@ -290,12 +290,12 @@ func CustomInspector(insp Inspector, init InspectorInitFunc, fini InspectorFiniF
}
}

// CustomHeaderExtractor is a function argument that sets a function to extract
// RawHeaderExtractor is a function argument that sets a function to extract
// an alternative header object from the request. It is primarily intended only
// for internal use.
func CustomHeaderExtractor(fn func(r *http.Request) (http.Header, error)) ModuleConfigOption {
func RawHeaderExtractor(fn RawHeaderExtractorFunc) ModuleConfigOption {
return func(c *ModuleConfig) error {
c.headerExtractor = fn
c.rawHeaderExtractor = fn
return nil
}
}
Expand Down
16 changes: 8 additions & 8 deletions config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ func TestDefaultModuleConfig(t *testing.T) {
if c.Debug() != DefaultDebug {
t.Errorf("Unexpected Debug: %v", c.Debug())
}
if c.HeaderExtractor() != nil {
t.Errorf("Unexpected HeaderExtractor: %p", c.HeaderExtractor())
if c.RawHeaderExtractor() != nil {
t.Errorf("Unexpected RawHeaderExtractor: %p", c.RawHeaderExtractor())
}
if c.Inspector() != DefaultInspector {
t.Errorf("Unexpected Inspector: %v", c.Inspector())
Expand Down Expand Up @@ -90,7 +90,7 @@ func TestConfiguredModuleConfig(t *testing.T) {
AnomalyDuration(10*time.Second),
AnomalySize(8192),
CustomInspector(&RPCInspector{}, func(_ *http.Request) bool { return true }, func(_ *http.Request) {}),
CustomHeaderExtractor(func(_ *http.Request) (http.Header, error) { return nil, nil }),
RawHeaderExtractor(func(_ *http.Request) [][2]string { return nil }),
ExpectedContentType("application/foobar"),
ExpectedContentType("application/fizzbuzz"),
Debug(true),
Expand Down Expand Up @@ -119,8 +119,8 @@ func TestConfiguredModuleConfig(t *testing.T) {
if c.Debug() != true {
t.Errorf("Unexpected Debug: %v", c.Debug())
}
if c.HeaderExtractor() == nil {
t.Errorf("Unexpected HeaderExtractor: %p", c.HeaderExtractor())
if c.RawHeaderExtractor() == nil {
t.Errorf("Unexpected HeaderExtractor: %p", c.RawHeaderExtractor())
}
if c.Inspector() == DefaultInspector {
t.Errorf("Unexpected Inspector: %v", c.Inspector())
Expand Down Expand Up @@ -182,7 +182,7 @@ func TestFromModuleConfig(t *testing.T) {
ExpectedContentType("application/foobar"),
ExpectedContentType("application/fizzbuzz"),
CustomInspector(&RPCInspector{}, func(_ *http.Request) bool { return true }, func(_ *http.Request) {}),
CustomHeaderExtractor(func(_ *http.Request) (http.Header, error) { return nil, nil }),
RawHeaderExtractor(func(_ *http.Request) [][2]string { return nil }),
Debug(true),
MaxContentLength(500000),
Socket("tcp", "0.0.0.0:1234"),
Expand Down Expand Up @@ -216,8 +216,8 @@ func TestFromModuleConfig(t *testing.T) {
if c.Debug() != true {
t.Errorf("Unexpected Debug: %v", c.Debug())
}
if c.HeaderExtractor() == nil {
t.Errorf("Unexpected HeaderExtractor: %p", c.HeaderExtractor())
if c.RawHeaderExtractor() == nil {
t.Errorf("Unexpected HeaderExtractor: %p", c.RawHeaderExtractor())
}
if c.Inspector() == DefaultInspector {
t.Errorf("Unexpected Inspector: %v", c.Inspector())
Expand Down
143 changes: 45 additions & 98 deletions module.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,11 @@ import (
// data collection and sends it to the Signal Sciences Agent for
// inspection.
type Module struct {
config *ModuleConfig
handler http.Handler
inspector Inspector
inspInit InspectorInitFunc
inspFini InspectorFiniFunc
headerExtractor HeaderExtractorFunc
config *ModuleConfig
handler http.Handler
inspector Inspector
inspInit InspectorInitFunc
inspFini InspectorFiniFunc
}

// NewModule wraps an existing http.Handler with one that extracts data and
Expand All @@ -39,12 +38,11 @@ func NewModule(h http.Handler, options ...ModuleConfigOption) (*Module, error) {

// The following are the defaults, overridden by passing in functional options
m := Module{
handler: h,
config: config,
inspector: config.Inspector(),
inspInit: config.InspectorInit(),
inspFini: config.InspectorFini(),
headerExtractor: config.HeaderExtractor(),
handler: h,
config: config,
inspector: config.Inspector(),
inspInit: config.InspectorInit(),
inspFini: config.InspectorFini(),
}

// By default, use an RPC based inspector if not configured externally
Expand Down Expand Up @@ -169,8 +167,7 @@ func (m *Module) ServeHTTP(w http.ResponseWriter, req *http.Request) {
if m.config.Debug() {
log.Printf("DEBUG: calling 'RPC.PostRequest' due to anomaly: method=%s host=%s url=%s code=%d size=%d duration=%s", req.Method, req.Host, req.URL, code, size, duration)
}
inspin := NewRPCMsgIn(req, nil, code, size, duration, m.config.ModuleIdentifier(), m.config.ServerIdentifier())
m.extractHeaders(req, inspin)
inspin := NewRPCMsgIn(m.config, req, nil, code, size, duration)
inspin.WAFResponse = wafresponse
inspin.HeadersOut = convertHeaders(rw.Header())

Expand Down Expand Up @@ -225,8 +222,7 @@ func (m *Module) inspectorPreRequest(req *http.Request) (inspin2 RPCMsgIn2, out
req.Body = ioutil.NopCloser(bytes.NewBuffer(reqbody))
}

inspin := NewRPCMsgInWithModuleConfig(m.config, req, reqbody)
m.extractHeaders(req, inspin)
inspin := NewRPCMsgIn(m.config, req, reqbody, -1, -1, 0)

if m.config.Debug() {
log.Printf("DEBUG: Making PreRequest call to inspector: %s %s", inspin.Method, inspin.URI)
Expand Down Expand Up @@ -276,20 +272,6 @@ func (m *Module) inspectorPreRequest(req *http.Request) (inspin2 RPCMsgIn2, out
return
}

func (m *Module) extractHeaders(req *http.Request, inspin *RPCMsgIn) {
// If the user supplied a custom header extractor, use it to unpack the
// headers. If there no custom header extractor or it returns an error,
// fallback to the native headers on the request.
if m.headerExtractor != nil {
hin, err := m.headerExtractor(req)
if err == nil {
inspin.HeadersIn = convertHeaders(hin)
} else if m.config.Debug() {
log.Printf("DEBUG: Error extracting custom headers, using native headers: %s", err)
}
}
}

// inspectorPostRequest makes a postrequest call to the inspector
func (m *Module) inspectorPostRequest(inspin *RPCMsgIn) error {
// Create message to agent from the input request
Expand Down Expand Up @@ -329,93 +311,42 @@ func (m *Module) inspectorUpdateRequest(inspin RPCMsgIn2) error {
// NewRPCMsgIn creates a message from a go http.Request object
// End-users of the golang module never need to use this
// directly and it is only exposed for performance testing
func NewRPCMsgIn(r *http.Request, postbody []byte, code int, size int64, dur time.Duration, module, server string) *RPCMsgIn {
func NewRPCMsgIn(mcfg *ModuleConfig, r *http.Request, postbody []byte, code int, size int64, dur time.Duration) *RPCMsgIn {
now := time.Now()

// assemble a message to send to inspector
tlsProtocol := ""
tlsCipher := ""
scheme := "http"
if r.TLS != nil {
// convert golang/spec integers into something human readable
scheme = "https"
tlsProtocol = tlstext.Version(r.TLS.Version)
tlsCipher = tlstext.CipherSuite(r.TLS.CipherSuite)
}

// golang removes Host header from req.Header map and
// promotes it to r.Host field. Add it back as the first header.
hin := convertHeaders(r.Header)
if len(r.Host) > 0 {
hin = append([][2]string{{"Host", r.Host}}, hin...)
}

return &RPCMsgIn{
ModuleVersion: module,
ServerVersion: server,
msgIn := RPCMsgIn{
ModuleVersion: mcfg.ModuleIdentifier(),
ServerVersion: mcfg.ServerIdentifier(),
ServerFlavor: mcfg.ServerFlavor(),
ServerName: r.Host,
Timestamp: now.Unix(),
NowMillis: now.UnixNano() / 1e6,
NowMillis: now.UnixMilli(),
RemoteAddr: stripPort(r.RemoteAddr),
Method: r.Method,
Scheme: scheme,
URI: r.RequestURI,
Protocol: r.Proto,
TLSProtocol: tlsProtocol,
TLSCipher: tlsCipher,
ResponseCode: int32(code),
ResponseMillis: int64(dur / time.Millisecond),
ResponseMillis: dur.Milliseconds(),
ResponseSize: size,
PostBody: string(postbody),
HeadersIn: hin,
}
}

// NewRPCMsgInWithModuleConfig creates a message from a ModuleConfig object
// End-users of the golang module never need to use this
// directly and it is only exposed for performance testing
func NewRPCMsgInWithModuleConfig(mcfg *ModuleConfig, r *http.Request, postbody []byte) *RPCMsgIn {

now := time.Now()

// assemble a message to send to inspector
tlsProtocol := ""
tlsCipher := ""
scheme := "http"
if r.TLS != nil {
// convert golang/spec integers into something human readable
scheme = "https"
tlsProtocol = tlstext.Version(r.TLS.Version)
tlsCipher = tlstext.CipherSuite(r.TLS.CipherSuite)
msgIn.Scheme = "https"
msgIn.TLSProtocol = tlstext.Version(r.TLS.Version)
msgIn.TLSCipher = tlstext.CipherSuite(r.TLS.CipherSuite)
} else {
msgIn.Scheme = "http"
}

// golang removes Host header from req.Header map and
// promotes it to r.Host field. Add it back as the first header.
hin := convertHeaders(r.Header)
if len(r.Host) > 0 {
hin = append([][2]string{{"Host", r.Host}}, hin...)
if hdrs := mcfg.RawHeaderExtractor(); hdrs != nil {
msgIn.HeadersIn = hdrs(r)
}

return &RPCMsgIn{
ModuleVersion: mcfg.ModuleIdentifier(),
ServerVersion: mcfg.ServerIdentifier(),
ServerFlavor: mcfg.ServerFlavor(),
ServerName: r.Host,
Timestamp: now.Unix(),
NowMillis: now.UnixNano() / 1e6,
RemoteAddr: stripPort(r.RemoteAddr),
Method: r.Method,
Scheme: scheme,
URI: r.RequestURI,
Protocol: r.Proto,
TLSProtocol: tlsProtocol,
TLSCipher: tlsCipher,
ResponseCode: -1,
ResponseMillis: 0,
ResponseSize: -1,
PostBody: string(postbody),
HeadersIn: hin,
if msgIn.HeadersIn == nil {
msgIn.HeadersIn = requestHeader(r)
}
return &msgIn
}

// stripPort removes any port from an address (e.g., the client port from the RemoteAddr)
Expand Down Expand Up @@ -505,6 +436,22 @@ func inspectableContentType(s string) bool {
return false
}

// requestHeader returns request headers with host header
func requestHeader(r *http.Request) [][2]string {
out := make([][2]string, 0, len(r.Header)+1)
// golang removes Host header from req.Header map and
// promotes it to r.Host field. Add it back as the first header.
if len(r.Host) > 0 {
out = append(out, [2]string{"Host", r.Host})
}
for key, values := range r.Header {
for _, value := range values {
out = append(out, [2]string{key, value})
}
}
return out
}

// converts a http.Header map to a [][2]string
func convertHeaders(h http.Header) [][2]string {
// get headers
Expand Down
Loading

0 comments on commit 6ffded8

Please sign in to comment.