From 584198ddd8a08cfa2445a16db8a644714b5fc63f Mon Sep 17 00:00:00 2001 From: Farzad Ghanei <644113+farzadghanei@users.noreply.github.com> Date: Thu, 9 May 2024 07:51:19 +0200 Subject: [PATCH 1/4] feat: allow setting max concurrent requests This is useful for limiting the number of simultaneous connections the web runner handles at any given time, preventing overloading the server --- examples/config.yaml | 1 + internal/conf.go | 89 ++++++++++++++++++++++++++++--------------- internal/conf_test.go | 31 ++++++++------- internal/run_modes.go | 23 ++++++----- 4 files changed, 90 insertions(+), 54 deletions(-) diff --git a/examples/config.yaml b/examples/config.yaml index f0b892b..ecb8570 100644 --- a/examples/config.yaml +++ b/examples/config.yaml @@ -18,6 +18,7 @@ runners: # response_write_timeout: 2s # timeout: 5s # max_header_bytes: 8192 + # max_concurrent_requests: 1 # 0 means no limit check_suites: diff --git a/internal/conf.go b/internal/conf.go index 34adb2d..e68d59b 100644 --- a/internal/conf.go +++ b/internal/conf.go @@ -21,15 +21,17 @@ type Conf struct { // ConfRunner is config for the check runners type ConfRunner struct { - Timeout *time.Duration - ShutdownSignalHeader *string `yaml:"shutdown_signal_header"` - MaxHeaderBytes *int `yaml:"max_header_bytes"` - ListenAddress string `yaml:"listen_address"` - RequestReadTimeout *time.Duration `yaml:"request_read_timeout"` - ResponseWriteTimeout *time.Duration `yaml:"response_write_timeout"` - ResponseOK *string `yaml:"response_ok"` - ResponseFailed *string `yaml:"response_failed"` - ResponseTimeout *string `yaml:"response_timeout"` + Timeout *time.Duration + ShutdownSignalHeader *string `yaml:"shutdown_signal_header"` + MaxHeaderBytes *int `yaml:"max_header_bytes"` + MaxConcurrentRequests *int `yaml:"max_concurrent_requests"` + ListenAddress string `yaml:"listen_address"` + RequestReadTimeout *time.Duration `yaml:"request_read_timeout"` + ResponseWriteTimeout *time.Duration `yaml:"response_write_timeout"` + ResponseOK *string `yaml:"response_ok"` + ResponseFailed *string `yaml:"response_failed"` + ResponseTimeout *string `yaml:"response_timeout"` + ResponseUnavailable *string `yaml:"response_unavailable"` } // ConfCheckSpec is the spec for each check configuration @@ -58,23 +60,33 @@ func ReadConf(path string) (*Conf, error) { return &conf, err } -// GetDefaultConfRunner returns a ConfRunner based on the default configuration -func GetDefaultConfRunner(runners *ConfRunners) ConfRunner { +// GetBaseConfRunner returns a base ConfRunner with default literal values +func GetBaseConfRunner() ConfRunner { var timeout, readTimeout, writeTimout time.Duration = 5 * time.Minute, 30 * time.Second, 30 * time.Second var maxHeaderBytes int = 8 * 1024 + var MaxConcurrentRequests int = 1 var respOK, respFailed, respTimeout string = "OK", "FAILED", "TIMEOUT" + var respUnavailable string = "UNAVAILABLE" baseConf := ConfRunner{ - Timeout: &timeout, - ShutdownSignalHeader: nil, - MaxHeaderBytes: &maxHeaderBytes, - ListenAddress: "127.0.0.1:8880", - RequestReadTimeout: &readTimeout, - ResponseWriteTimeout: &writeTimout, - ResponseOK: &respOK, - ResponseFailed: &respFailed, - ResponseTimeout: &respTimeout, + Timeout: &timeout, + ShutdownSignalHeader: nil, + MaxHeaderBytes: &maxHeaderBytes, + ListenAddress: "127.0.0.1:8880", + RequestReadTimeout: &readTimeout, + ResponseWriteTimeout: &writeTimout, + ResponseOK: &respOK, + ResponseFailed: &respFailed, + ResponseTimeout: &respTimeout, + ResponseUnavailable: &respUnavailable, + MaxConcurrentRequests: &MaxConcurrentRequests, } + return baseConf +} + +// GetDefaultConfRunner returns a ConfRunner based on the default configuration +func GetDefaultConfRunner(runners *ConfRunners) ConfRunner { + baseConf := GetBaseConfRunner() if defaultConf, defaultExists := (*runners)["default"]; defaultExists { baseConf = MergedConfRunners(&baseConf, &defaultConf) @@ -100,17 +112,7 @@ func GetConfRunner(runners *ConfRunners, name string) (ConfRunner, bool) { // MergedConfRunners merges the baseConf with the overrideConf and returns the merged ConfRunner func MergedConfRunners(baseConf, overrideConf *ConfRunner) ConfRunner { - mergedConf := ConfRunner{ - Timeout: overrideConf.Timeout, - ShutdownSignalHeader: overrideConf.ShutdownSignalHeader, - ListenAddress: overrideConf.ListenAddress, - RequestReadTimeout: overrideConf.RequestReadTimeout, - ResponseWriteTimeout: overrideConf.ResponseWriteTimeout, - ResponseOK: overrideConf.ResponseOK, - ResponseFailed: overrideConf.ResponseFailed, - ResponseTimeout: overrideConf.ResponseTimeout, - MaxHeaderBytes: overrideConf.MaxHeaderBytes, - } + mergedConf := CopyConfRunner(overrideConf) if mergedConf.Timeout == nil { mergedConf.Timeout = baseConf.Timeout @@ -148,5 +150,30 @@ func MergedConfRunners(baseConf, overrideConf *ConfRunner) ConfRunner { mergedConf.ResponseTimeout = baseConf.ResponseTimeout } + if mergedConf.ResponseUnavailable == nil { + mergedConf.ResponseUnavailable = baseConf.ResponseUnavailable + } + + if mergedConf.MaxConcurrentRequests == nil { + mergedConf.MaxConcurrentRequests = baseConf.MaxConcurrentRequests + } + return mergedConf } + +// CopyConfRunner returns a copy of the ConfRunner with the same values +func CopyConfRunner(conf *ConfRunner) ConfRunner { + return ConfRunner{ + Timeout: conf.Timeout, + ShutdownSignalHeader: conf.ShutdownSignalHeader, + ListenAddress: conf.ListenAddress, + RequestReadTimeout: conf.RequestReadTimeout, + ResponseWriteTimeout: conf.ResponseWriteTimeout, + ResponseOK: conf.ResponseOK, + ResponseFailed: conf.ResponseFailed, + ResponseTimeout: conf.ResponseTimeout, + ResponseUnavailable: conf.ResponseUnavailable, + MaxHeaderBytes: conf.MaxHeaderBytes, + MaxConcurrentRequests: conf.MaxConcurrentRequests, + } +} diff --git a/internal/conf_test.go b/internal/conf_test.go index 86b7cf9..fe5e6c6 100644 --- a/internal/conf_test.go +++ b/internal/conf_test.go @@ -74,8 +74,6 @@ func TestGetDefaultConfRunner(t *testing.T) { respNo := "NO" respMaybe := "MAYBE" - wantMaxHeaderBytes := 8 * 1024 - runners := ConfRunners{ "default": ConfRunner{ Timeout: &wantTimeout, @@ -115,25 +113,30 @@ func TestGetDefaultConfRunner(t *testing.T) { } // Test case where default key does not exist + baseRunner := GetBaseConfRunner() runners = ConfRunners{} defaultRunner = GetDefaultConfRunner(&runners) - if *defaultRunner.Timeout != 5*time.Minute { - t.Errorf("Expected Timeout to be 0, got %v", *defaultRunner.Timeout) + if *defaultRunner.Timeout != *baseRunner.Timeout { + t.Errorf("Timeout want %v, got %v", *baseRunner.Timeout, *defaultRunner.Timeout) + } + if defaultRunner.ListenAddress != baseRunner.ListenAddress { + t.Errorf("ListenAddress want %v, got %s", baseRunner.ListenAddress, defaultRunner.ListenAddress) } - if defaultRunner.ListenAddress != "127.0.0.1:8880" { - t.Errorf("Expected ListenAddress to be 127.0.0.1:8080, got %s", defaultRunner.ListenAddress) + if *defaultRunner.ResponseOK != *baseRunner.ResponseOK { + t.Errorf("ResponseOK want %v, got %s", *baseRunner.ResponseOK, *defaultRunner.ResponseOK) } - if *defaultRunner.ResponseOK != "OK" { - t.Errorf("Expected ResponseOK to be OK, got %s", *defaultRunner.ResponseOK) + if *defaultRunner.ResponseFailed != *baseRunner.ResponseFailed { + t.Errorf("ResponseFailed want %v, got %s", *baseRunner.ResponseFailed, *defaultRunner.ResponseFailed) } - if *defaultRunner.ResponseFailed != "FAILED" { - t.Errorf("Expected ResponseFailed to be FAILED, got %s", *defaultRunner.ResponseFailed) + if *defaultRunner.ResponseTimeout != *baseRunner.ResponseTimeout { + t.Errorf("ResponseTimeout want %v, got %s", *baseRunner.ResponseTimeout, *defaultRunner.ResponseTimeout) } - if *defaultRunner.ResponseTimeout != "TIMEOUT" { - t.Errorf("Expected ResponseTimeout to be TIMEOUT, got %s", *defaultRunner.ResponseTimeout) + if *defaultRunner.MaxHeaderBytes != *baseRunner.MaxHeaderBytes { + t.Errorf("MaxHeaderBytes want %v, got %v", *baseRunner.MaxHeaderBytes, *defaultRunner.MaxHeaderBytes) } - if *defaultRunner.MaxHeaderBytes != wantMaxHeaderBytes { - t.Errorf("Expected MaxHeaderBytes to be %v, got %v", wantMaxHeaderBytes, *defaultRunner.MaxHeaderBytes) + if *defaultRunner.MaxConcurrentRequests != *baseRunner.MaxConcurrentRequests { + t.Errorf("MaxConcurrentRequests want %v, got %v", *baseRunner.MaxConcurrentRequests, + *defaultRunner.MaxConcurrentRequests) } } diff --git a/internal/run_modes.go b/internal/run_modes.go index 3caddef..5205f66 100644 --- a/internal/run_modes.go +++ b/internal/run_modes.go @@ -42,19 +42,24 @@ func RunModeHTTP(checkGroups *CheckSuites, conf *ConfRunner, logger *log.Logger) if conf.ShutdownSignalHeader != nil { shutdownSignalHeaderValue = *conf.ShutdownSignalHeader } - listenAddress := conf.ListenAddress - timeout := *conf.Timeout + maxConcurrentRequests := *conf.MaxConcurrentRequests responseOK := *conf.ResponseOK responseFailed := *conf.ResponseFailed responseTimeout := *conf.ResponseTimeout - requestReadTimeout := *conf.RequestReadTimeout - responseWriteTimeout := *conf.ResponseWriteTimeout + responseUnavailable := *conf.ResponseUnavailable - runner := Runner{Log: logger, Timeout: timeout} + runner := Runner{Log: logger, Timeout: *conf.Timeout} + var runningRequests atomic.Int32 var reqHandlerChan = make(chan *http.Request, 1) httpHandler := func(w http.ResponseWriter, r *http.Request) { + runningRequests.Add(1) + if maxConcurrentRequests > 0 && runningRequests.Load() > int32(maxConcurrentRequests) { + w.WriteHeader(http.StatusServiceUnavailable) // 503 + fmt.Fprint(w, responseUnavailable) + } + defer runningRequests.Add(-1) logger.Printf("processing http request: %s", httpRequestAsString(r)) _, failed, timedout := runChecks(&runner, checkGroups, logger) if timedout > 0 { @@ -73,10 +78,10 @@ func RunModeHTTP(checkGroups *CheckSuites, conf *ConfRunner, logger *log.Logger) http.HandleFunc("/", httpHandler) server := &http.Server{ - Addr: listenAddress, + Addr: conf.ListenAddress, Handler: nil, // use http.DefaultServeMux - ReadTimeout: requestReadTimeout, - WriteTimeout: responseWriteTimeout, + ReadTimeout: *conf.RequestReadTimeout, + WriteTimeout: *conf.ResponseWriteTimeout, IdleTimeout: 0 * time.Second, // set to 0 so uses read timeout MaxHeaderBytes: *conf.MaxHeaderBytes, } @@ -101,7 +106,7 @@ func RunModeHTTP(checkGroups *CheckSuites, conf *ConfRunner, logger *log.Logger) } }() - logger.Printf("starting http server listening on %s", listenAddress) + logger.Printf("starting http server listening on %s", conf.ListenAddress) err := server.ListenAndServe() close(reqHandlerChan) if err != nil && err != http.ErrServerClosed { From 786bca168cf265fb8cee4e0b36ee87a6e32d1a69 Mon Sep 17 00:00:00 2001 From: Farzad Ghanei <644113+farzadghanei@users.noreply.github.com> Date: Thu, 9 May 2024 15:24:11 +0200 Subject: [PATCH 2/4] feat(web): Add support for required headers for http runner These changes allow to configure a list of required HTTP headers to limit clients accessing the http runner. a list of headers can be specified using a map --- .github/workflows/tests.yml | 2 + .golangci.yaml | 5 +- .pre-commit-config.yaml | 1 + cmd/chkok_test.go | 2 + examples/config.yaml | 11 +++-- examples/test-http.yaml | 10 +++- internal/conf.go | 92 ++++++++++++++++++++++--------------- internal/conf_test.go | 81 ++++++++++++++++---------------- internal/run_modes.go | 88 +++++++++++++++++++++++------------ 9 files changed, 178 insertions(+), 114 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 7d7d03a..1f0a2c9 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -1,3 +1,5 @@ +--- + name: tests on: diff --git a/.golangci.yaml b/.golangci.yaml index 1703e78..ceb8fd0 100644 --- a/.golangci.yaml +++ b/.golangci.yaml @@ -1,3 +1,4 @@ +--- # https://golangci-lint.run/usage/configuration/ run: @@ -66,7 +67,7 @@ linters-settings: - performance - style disabled-checks: - - dupImport # https://github.com/go-critic/go-critic/issues/845 + - dupImport # https://github.com/go-critic/go-critic/issues/845 - ifElseChain - octalLiteral - whyNoLint @@ -98,7 +99,7 @@ linters-settings: # There are three different modes: `original`, `strict`, and `lax`. # Default: "original" list-mode: original - # List of file globs that will match this list of settings to compare against. + # File globs that will match this list of settings to compare against. # Default: $all files: - "$all" diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 5168e37..bd8241a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,3 +1,4 @@ +--- exclude: "^docs/|/.vscode/" default_stages: [commit] diff --git a/cmd/chkok_test.go b/cmd/chkok_test.go index dda7de2..f271d6f 100644 --- a/cmd/chkok_test.go +++ b/cmd/chkok_test.go @@ -47,6 +47,8 @@ func TestRunHttp(t *testing.T) { t.Fatalf("Failed to create HTTP request: %v", err) } req.Header.Set("X-Server-Shutdown", "test-shutdown-signal") // shutdown the server after the request + req.Header.Set("X-Required-Header", "required-value") + req.Header.Set("X-Required-Header2", "anything") // Send the request multiple times, waiting for the server to // start up and respond diff --git a/examples/config.yaml b/examples/config.yaml index ecb8570..f15eafb 100644 --- a/examples/config.yaml +++ b/examples/config.yaml @@ -9,16 +9,21 @@ runners: # response_timeout: "TIMEOUT" cli: {} # override default runner only for CLI mode http: # override default runner only for HTTP mode - # shutdown_signal_header is mainly useful for testing http mode, do not set it in production - # if set, better be treated like a secret, and a secure transport layer should be used. + listen_address: "127.0.0.1:51234" + # shutdown_signal_header is mainly useful for testing http mode, + # do not set it in production + # if set, better be treated like a secret, and a secure transport + # layer should be used. # this is the value set on "X-Shutdown-Signal" header in the http request # shutdown_signal_header: "test-shutdown-signal" - # listen_address: "127.0.0.1:51234" # request_read_timeout: 2s # response_write_timeout: 2s # timeout: 5s # max_header_bytes: 8192 # max_concurrent_requests: 1 # 0 means no limit + # request_required_headers: + # "X-Required-Header": "required-value" + # "X-Required-Header2": "" # header existance is required, not value check_suites: diff --git a/examples/test-http.yaml b/examples/test-http.yaml index 1ed63ad..69cc91c 100644 --- a/examples/test-http.yaml +++ b/examples/test-http.yaml @@ -4,15 +4,21 @@ runners: default: timeout: 1m + request_required_headers: + "X-Required-Header": "required-value" http: listen_address: "127.0.0.1:51234" request_read_timeout: 2s response_write_timeout: 2s - # shutdown_signal_header is mainly useful for testing http mode, do not set it in production - # if set, better be treated like a secret, and a secure transport layer should be used. + # shutdown_signal_header is mainly useful for testing http mode, + # do not set it in production + # if set, better be treated like a secret, and a secure transport + # layer should be used. # this is the value set on "X-Shutdown-Signal" header in the http request shutdown_signal_header: "test-shutdown-signal" timeout: 5s + request_required_headers: + "X-Required-Header2": "" # header existance is required check_suites: diff --git a/internal/conf.go b/internal/conf.go index e68d59b..ea5962e 100644 --- a/internal/conf.go +++ b/internal/conf.go @@ -1,6 +1,7 @@ package chkok import ( + "maps" "os" "time" @@ -21,17 +22,19 @@ type Conf struct { // ConfRunner is config for the check runners type ConfRunner struct { - Timeout *time.Duration - ShutdownSignalHeader *string `yaml:"shutdown_signal_header"` - MaxHeaderBytes *int `yaml:"max_header_bytes"` - MaxConcurrentRequests *int `yaml:"max_concurrent_requests"` - ListenAddress string `yaml:"listen_address"` - RequestReadTimeout *time.Duration `yaml:"request_read_timeout"` - ResponseWriteTimeout *time.Duration `yaml:"response_write_timeout"` - ResponseOK *string `yaml:"response_ok"` - ResponseFailed *string `yaml:"response_failed"` - ResponseTimeout *string `yaml:"response_timeout"` - ResponseUnavailable *string `yaml:"response_unavailable"` + Timeout *time.Duration + ShutdownSignalHeader *string `yaml:"shutdown_signal_header"` + MaxHeaderBytes *int `yaml:"max_header_bytes"` + MaxConcurrentRequests *int `yaml:"max_concurrent_requests"` + ListenAddress string `yaml:"listen_address"` + RequestReadTimeout *time.Duration `yaml:"request_read_timeout"` + RequestRequiredHeaders map[string]string `yaml:"request_required_headers"` + ResponseWriteTimeout *time.Duration `yaml:"response_write_timeout"` + ResponseOK *string `yaml:"response_ok"` + ResponseFailed *string `yaml:"response_failed"` + ResponseTimeout *string `yaml:"response_timeout"` + ResponseUnavailable *string `yaml:"response_unavailable"` + ResponseInvalidRequest *string `yaml:"response_invalid_request"` } // ConfCheckSpec is the spec for each check configuration @@ -66,20 +69,22 @@ func GetBaseConfRunner() ConfRunner { var maxHeaderBytes int = 8 * 1024 var MaxConcurrentRequests int = 1 var respOK, respFailed, respTimeout string = "OK", "FAILED", "TIMEOUT" - var respUnavailable string = "UNAVAILABLE" + var respUnavailable, respInvalidRequest string = "UNAVAILABLE", "INVALID REQUEST" baseConf := ConfRunner{ - Timeout: &timeout, - ShutdownSignalHeader: nil, - MaxHeaderBytes: &maxHeaderBytes, - ListenAddress: "127.0.0.1:8880", - RequestReadTimeout: &readTimeout, - ResponseWriteTimeout: &writeTimout, - ResponseOK: &respOK, - ResponseFailed: &respFailed, - ResponseTimeout: &respTimeout, - ResponseUnavailable: &respUnavailable, - MaxConcurrentRequests: &MaxConcurrentRequests, + Timeout: &timeout, + ShutdownSignalHeader: nil, + MaxHeaderBytes: &maxHeaderBytes, + ListenAddress: "127.0.0.1:8880", + RequestReadTimeout: &readTimeout, + RequestRequiredHeaders: map[string]string{}, + ResponseWriteTimeout: &writeTimout, + ResponseOK: &respOK, + ResponseFailed: &respFailed, + ResponseTimeout: &respTimeout, + ResponseInvalidRequest: &respInvalidRequest, + ResponseUnavailable: &respUnavailable, + MaxConcurrentRequests: &MaxConcurrentRequests, } return baseConf } @@ -134,6 +139,13 @@ func MergedConfRunners(baseConf, overrideConf *ConfRunner) ConfRunner { mergedConf.RequestReadTimeout = baseConf.RequestReadTimeout } + // Merge the request required headers map with the baseConf + for key, value := range baseConf.RequestRequiredHeaders { + if _, exists := mergedConf.RequestRequiredHeaders[key]; !exists { + mergedConf.RequestRequiredHeaders[key] = value + } + } + if mergedConf.ResponseWriteTimeout == nil { mergedConf.ResponseWriteTimeout = baseConf.ResponseWriteTimeout } @@ -158,22 +170,30 @@ func MergedConfRunners(baseConf, overrideConf *ConfRunner) ConfRunner { mergedConf.MaxConcurrentRequests = baseConf.MaxConcurrentRequests } + if mergedConf.ResponseInvalidRequest == nil { + mergedConf.ResponseInvalidRequest = baseConf.ResponseInvalidRequest + } + return mergedConf } // CopyConfRunner returns a copy of the ConfRunner with the same values func CopyConfRunner(conf *ConfRunner) ConfRunner { - return ConfRunner{ - Timeout: conf.Timeout, - ShutdownSignalHeader: conf.ShutdownSignalHeader, - ListenAddress: conf.ListenAddress, - RequestReadTimeout: conf.RequestReadTimeout, - ResponseWriteTimeout: conf.ResponseWriteTimeout, - ResponseOK: conf.ResponseOK, - ResponseFailed: conf.ResponseFailed, - ResponseTimeout: conf.ResponseTimeout, - ResponseUnavailable: conf.ResponseUnavailable, - MaxHeaderBytes: conf.MaxHeaderBytes, - MaxConcurrentRequests: conf.MaxConcurrentRequests, - } + newConfRunner := ConfRunner{ + Timeout: conf.Timeout, + ShutdownSignalHeader: conf.ShutdownSignalHeader, + ListenAddress: conf.ListenAddress, + RequestReadTimeout: conf.RequestReadTimeout, + RequestRequiredHeaders: map[string]string{}, + ResponseWriteTimeout: conf.ResponseWriteTimeout, + ResponseOK: conf.ResponseOK, + ResponseFailed: conf.ResponseFailed, + ResponseTimeout: conf.ResponseTimeout, + ResponseUnavailable: conf.ResponseUnavailable, + ResponseInvalidRequest: conf.ResponseInvalidRequest, + MaxHeaderBytes: conf.MaxHeaderBytes, + MaxConcurrentRequests: conf.MaxConcurrentRequests, + } + maps.Copy(newConfRunner.RequestRequiredHeaders, conf.RequestRequiredHeaders) + return newConfRunner } diff --git a/internal/conf_test.go b/internal/conf_test.go index fe5e6c6..38ce68d 100644 --- a/internal/conf_test.go +++ b/internal/conf_test.go @@ -1,10 +1,12 @@ package chkok import ( + "maps" "testing" "time" ) +// TestReadConfErrors tests the ReadConf function for error handling func TestReadConfErrors(t *testing.T) { var conf *Conf var err error @@ -147,17 +149,19 @@ func TestGetConfRunner(t *testing.T) { runners := ConfRunners{ "default": ConfRunner{ - Timeout: &fiveSecond, - ListenAddress: "localhost:8080", - ResponseWriteTimeout: &tenSecond, + Timeout: &fiveSecond, + ListenAddress: "localhost:8080", + ResponseWriteTimeout: &tenSecond, + RequestRequiredHeaders: map[string]string{"X-Test-Default": "test"}, }, "testMinimalHttpRunner": ConfRunner{}, "testHttpRunner": ConfRunner{ - Timeout: &tenSecond, - ShutdownSignalHeader: &shutdownSignalHeader, - ListenAddress: "localhost:9090", - RequestReadTimeout: &fiveSecond, - ResponseWriteTimeout: &fiveSecond, + Timeout: &tenSecond, + ShutdownSignalHeader: &shutdownSignalHeader, + ListenAddress: "localhost:9090", + RequestReadTimeout: &fiveSecond, + ResponseWriteTimeout: &fiveSecond, + RequestRequiredHeaders: map[string]string{"X-Test-2": "http-test"}, }, } @@ -173,14 +177,15 @@ func TestGetConfRunner(t *testing.T) { name: "Existing runner", runnerName: "testHttpRunner", expectedRunner: ConfRunner{ - Timeout: &tenSecond, - ShutdownSignalHeader: &shutdownSignalHeader, - ListenAddress: "localhost:9090", - RequestReadTimeout: &fiveSecond, - ResponseWriteTimeout: &fiveSecond, - ResponseOK: &ok, - ResponseFailed: &failed, - ResponseTimeout: &timeout, + Timeout: &tenSecond, + ShutdownSignalHeader: &shutdownSignalHeader, + ListenAddress: "localhost:9090", + RequestReadTimeout: &fiveSecond, + RequestRequiredHeaders: map[string]string{"X-Test-2": "http-test", "X-Test-Default": "test"}, + ResponseWriteTimeout: &fiveSecond, + ResponseOK: &ok, + ResponseFailed: &failed, + ResponseTimeout: &timeout, }, expectedExists: true, }, @@ -198,42 +203,36 @@ func TestGetConfRunner(t *testing.T) { }, } - var wantTimeout, wantReadTimeout, wantWriteTimeout time.Duration = 0, 0, 0 - var wantResponseOK, wantResponseFailed, wantResponseTimeout, wantListenAddr string = "", "", "", "" - for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { runner, exists := GetConfRunner(&runners, tt.runnerName) if exists != tt.expectedExists { - t.Errorf("expected runner exists to be %v, got %v", tt.expectedExists, exists) + t.Errorf("exists want %v got %v", tt.expectedExists, exists) + } + expRunner := tt.expectedRunner + if *runner.Timeout != *expRunner.Timeout { + t.Errorf("timout want %v got %+v", *expRunner.Timeout, runner.Timeout) } - wantTimeout = *tt.expectedRunner.Timeout - wantReadTimeout = *tt.expectedRunner.RequestReadTimeout - wantWriteTimeout = *tt.expectedRunner.ResponseWriteTimeout - wantListenAddr = tt.expectedRunner.ListenAddress - wantResponseOK = *tt.expectedRunner.ResponseOK - wantResponseFailed = *tt.expectedRunner.ResponseFailed - wantResponseTimeout = *tt.expectedRunner.ResponseTimeout - if *runner.Timeout != wantTimeout { - t.Errorf("expected runner timeout to be %+v, got %+v", wantTimeout, runner.Timeout) + if *runner.RequestReadTimeout != *expRunner.RequestReadTimeout { + t.Errorf("read timeout want %+v got %+v", *expRunner.Timeout, runner.RequestReadTimeout) } - if *runner.RequestReadTimeout != wantReadTimeout { - t.Errorf("expected runner read timeout to be %+v, got %+v", wantReadTimeout, runner.RequestReadTimeout) + if *runner.ResponseWriteTimeout != *expRunner.ResponseWriteTimeout { + t.Errorf("write timeout want %+v got %+v", *expRunner.ResponseWriteTimeout, runner.ResponseWriteTimeout) } - if *runner.ResponseWriteTimeout != wantWriteTimeout { - t.Errorf("expected runner write timeout to be %+v, got %+v", wantWriteTimeout, runner.ResponseWriteTimeout) + if runner.ListenAddress != expRunner.ListenAddress { + t.Errorf("listen address want %s got %s", expRunner.ListenAddress, runner.ListenAddress) } - if runner.ListenAddress != wantListenAddr { - t.Errorf("expected runner listen address to be %s, got %s", wantListenAddr, runner.ListenAddress) + if *runner.ResponseOK != *expRunner.ResponseOK { + t.Errorf("response ok want %s got %s", *expRunner.ResponseOK, *runner.ResponseOK) } - if *runner.ResponseOK != wantResponseOK { - t.Errorf("expected runner response ok to be %s, got %s", wantResponseOK, *runner.ResponseOK) + if *runner.ResponseFailed != *expRunner.ResponseFailed { + t.Errorf("response failed want %s got %s", *expRunner.ResponseFailed, *runner.ResponseFailed) } - if *runner.ResponseFailed != wantResponseFailed { - t.Errorf("expected runner response failed to be %s, got %s", wantResponseFailed, *runner.ResponseFailed) + if *runner.ResponseTimeout != *expRunner.ResponseTimeout { + t.Errorf("response timeout want %s got %s", *expRunner.ResponseTimeout, *runner.ResponseTimeout) } - if *runner.ResponseTimeout != wantResponseTimeout { - t.Errorf("expected runner response timeout to be %s, got %s", wantResponseTimeout, *runner.ResponseTimeout) + if !maps.Equal(expRunner.RequestRequiredHeaders, runner.RequestRequiredHeaders) { + t.Errorf("request headers want %+v got %+v", expRunner.RequestRequiredHeaders, runner.RequestRequiredHeaders) } }) } diff --git a/internal/run_modes.go b/internal/run_modes.go index 5205f66..949d128 100644 --- a/internal/run_modes.go +++ b/internal/run_modes.go @@ -42,39 +42,10 @@ func RunModeHTTP(checkGroups *CheckSuites, conf *ConfRunner, logger *log.Logger) if conf.ShutdownSignalHeader != nil { shutdownSignalHeaderValue = *conf.ShutdownSignalHeader } - maxConcurrentRequests := *conf.MaxConcurrentRequests - responseOK := *conf.ResponseOK - responseFailed := *conf.ResponseFailed - responseTimeout := *conf.ResponseTimeout - responseUnavailable := *conf.ResponseUnavailable - - runner := Runner{Log: logger, Timeout: *conf.Timeout} - var runningRequests atomic.Int32 var reqHandlerChan = make(chan *http.Request, 1) - httpHandler := func(w http.ResponseWriter, r *http.Request) { - runningRequests.Add(1) - if maxConcurrentRequests > 0 && runningRequests.Load() > int32(maxConcurrentRequests) { - w.WriteHeader(http.StatusServiceUnavailable) // 503 - fmt.Fprint(w, responseUnavailable) - } - defer runningRequests.Add(-1) - logger.Printf("processing http request: %s", httpRequestAsString(r)) - _, failed, timedout := runChecks(&runner, checkGroups, logger) - if timedout > 0 { - w.WriteHeader(http.StatusGatewayTimeout) // 504 - fmt.Fprint(w, responseTimeout) - } else if failed > 0 { - w.WriteHeader(http.StatusInternalServerError) // 500 - fmt.Fprint(w, responseFailed) - } else { - w.WriteHeader(http.StatusOK) - fmt.Fprint(w, responseOK) - } - reqHandlerChan <- r - } - + httpHandler := makeHTTPRequestHandler(reqHandlerChan, conf, checkGroups, logger) http.HandleFunc("/", httpHandler) server := &http.Server{ @@ -117,6 +88,63 @@ func RunModeHTTP(checkGroups *CheckSuites, conf *ConfRunner, logger *log.Logger) return ExOK } +// makeHTTPRequestHandler creates a http request handler function used by RunModeHTTP +func makeHTTPRequestHandler(reqHandlerChan chan *http.Request, + conf *ConfRunner, checkGroups *CheckSuites, logger *log.Logger) func(http.ResponseWriter, *http.Request) { + maxConcurrentRequests := int32(*conf.MaxConcurrentRequests) + responseOK, responseFailed := *conf.ResponseOK, *conf.ResponseFailed + responseTimeout := *conf.ResponseTimeout + responseUnavailable, responseInvalidRequest := *conf.ResponseUnavailable, *conf.ResponseInvalidRequest + requieredHeaders := conf.RequestRequiredHeaders + shouldCheckHeaders := len(requieredHeaders) > 0 + + runner := Runner{Log: logger, Timeout: *conf.Timeout} + + var runningRequests atomic.Int32 + httpRequestHandler := func(w http.ResponseWriter, r *http.Request) { + runningRequests.Add(1) + if maxConcurrentRequests > 0 && runningRequests.Load() > maxConcurrentRequests { + logger.Printf("runner reached max conccurent requests. rejecting request: %s", httpRequestAsString(r)) + w.WriteHeader(http.StatusServiceUnavailable) // 503 + fmt.Fprint(w, responseUnavailable) + return + } + defer runningRequests.Add(-1) + if shouldCheckHeaders { + for header, value := range requieredHeaders { + reqHeader, ok := r.Header[header] + if !ok { + logger.Printf("http request missing required header %s: %s", header, httpRequestAsString(r)) + w.WriteHeader(http.StatusBadRequest) // 400 + fmt.Print(w, responseInvalidRequest) + return + } + if value != "" && reqHeader[0] != value { + logger.Printf("http request doesn't match required header %s: %s", header, httpRequestAsString(r)) + w.WriteHeader(http.StatusBadRequest) // 400 + fmt.Print(w, responseInvalidRequest) + return + } + } + } + + logger.Printf("processing http request: %s", httpRequestAsString(r)) + _, failed, timedout := runChecks(&runner, checkGroups, logger) + if timedout > 0 { + w.WriteHeader(http.StatusGatewayTimeout) // 504 + fmt.Fprint(w, responseTimeout) + } else if failed > 0 { + w.WriteHeader(http.StatusInternalServerError) // 500 + fmt.Fprint(w, responseFailed) + } else { + w.WriteHeader(http.StatusOK) + fmt.Fprint(w, responseOK) + } + reqHandlerChan <- r + } + return httpRequestHandler +} + // runChecks runs checks with logs, and returns number of passed, failed and timedout checks func runChecks(runner *Runner, checkGroups *CheckSuites, logger *log.Logger) (passed, failed, timedout int) { checks := runner.RunChecks(*checkGroups) From efba21a43cd78c6ea47e65188b3dd4ab69f8ffa1 Mon Sep 17 00:00:00 2001 From: Farzad Ghanei <644113+farzadghanei@users.noreply.github.com> Date: Sat, 11 May 2024 09:44:33 +0200 Subject: [PATCH 3/4] refactor: improve merging conf runners Split merging function into two separate functions for timeouts and responses --- internal/conf.go | 33 ++++++++++++++++++++------------- 1 file changed, 20 insertions(+), 13 deletions(-) diff --git a/internal/conf.go b/internal/conf.go index ea5962e..a3874d7 100644 --- a/internal/conf.go +++ b/internal/conf.go @@ -119,10 +119,6 @@ func GetConfRunner(runners *ConfRunners, name string) (ConfRunner, bool) { func MergedConfRunners(baseConf, overrideConf *ConfRunner) ConfRunner { mergedConf := CopyConfRunner(overrideConf) - if mergedConf.Timeout == nil { - mergedConf.Timeout = baseConf.Timeout - } - if mergedConf.ShutdownSignalHeader == nil { mergedConf.ShutdownSignalHeader = baseConf.ShutdownSignalHeader } @@ -134,11 +130,12 @@ func MergedConfRunners(baseConf, overrideConf *ConfRunner) ConfRunner { if mergedConf.ListenAddress == "" { mergedConf.ListenAddress = baseConf.ListenAddress } - - if mergedConf.RequestReadTimeout == nil { - mergedConf.RequestReadTimeout = baseConf.RequestReadTimeout + if mergedConf.MaxConcurrentRequests == nil { + mergedConf.MaxConcurrentRequests = baseConf.MaxConcurrentRequests } + mergeConfRunnerTimeouts(&mergedConf, baseConf) + // Merge the request required headers map with the baseConf for key, value := range baseConf.RequestRequiredHeaders { if _, exists := mergedConf.RequestRequiredHeaders[key]; !exists { @@ -146,10 +143,26 @@ func MergedConfRunners(baseConf, overrideConf *ConfRunner) ConfRunner { } } + mergeConfRunnerResponses(&mergedConf, baseConf) + + return mergedConf +} + +// mergeConfRunnerTimeouts merges the timeout fields of the mergedConf with the baseConf in place +func mergeConfRunnerTimeouts(mergedConf, baseConf *ConfRunner) { + if mergedConf.Timeout == nil { + mergedConf.Timeout = baseConf.Timeout + } + if mergedConf.RequestReadTimeout == nil { + mergedConf.RequestReadTimeout = baseConf.RequestReadTimeout + } if mergedConf.ResponseWriteTimeout == nil { mergedConf.ResponseWriteTimeout = baseConf.ResponseWriteTimeout } +} +// mergeConfRunnerResponses merges the response fields of the mergedConf with the baseConf in place +func mergeConfRunnerResponses(mergedConf, baseConf *ConfRunner) { if mergedConf.ResponseOK == nil { mergedConf.ResponseOK = baseConf.ResponseOK } @@ -166,15 +179,9 @@ func MergedConfRunners(baseConf, overrideConf *ConfRunner) ConfRunner { mergedConf.ResponseUnavailable = baseConf.ResponseUnavailable } - if mergedConf.MaxConcurrentRequests == nil { - mergedConf.MaxConcurrentRequests = baseConf.MaxConcurrentRequests - } - if mergedConf.ResponseInvalidRequest == nil { mergedConf.ResponseInvalidRequest = baseConf.ResponseInvalidRequest } - - return mergedConf } // CopyConfRunner returns a copy of the ConfRunner with the same values From 00c492406728d6f53ff7a3638507fbdce0a05ebc Mon Sep 17 00:00:00 2001 From: Farzad Ghanei <644113+farzadghanei@users.noreply.github.com> Date: Sat, 11 May 2024 09:50:23 +0200 Subject: [PATCH 4/4] Update Chkok app version - Update the Chkok application version from 0.2.0 to 0.3.0 --- cmd/chkok.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/cmd/chkok.go b/cmd/chkok.go index 49bdf39..0b50295 100644 --- a/cmd/chkok.go +++ b/cmd/chkok.go @@ -10,7 +10,8 @@ import ( chkok "github.com/farzadghanei/chkok/internal" ) -const Version string = "0.2.0" +// Version of the app +const Version string = "0.3.0" // ModeHTTP run checks in http server mode const ModeHTTP string = "http"