diff --git a/CHANGELOG.next.asciidoc b/CHANGELOG.next.asciidoc index aa55d383444..93de958a4e6 100644 --- a/CHANGELOG.next.asciidoc +++ b/CHANGELOG.next.asciidoc @@ -212,6 +212,7 @@ https://github.com/elastic/beats/compare/v8.8.1\...main[Check the HEAD diff] - Update CEL mito extensions to v1.12.2. {pull}39755[39755] - Add ability to remove request trace logs from http_endpoint input. {pull}40005[40005] - Add ability to remove request trace logs from entityanalytics input. {pull}40004[40004] +- Added OAuth2 support with auto token refresh for websocket streaming input. {issue}41989[41989] {pull}42212[42212] - Added infinite & blanket retry options to websockets and improved logging and retry logic. {pull}42225[42225] *Auditbeat* diff --git a/x-pack/filebeat/docs/inputs/input-streaming.asciidoc b/x-pack/filebeat/docs/inputs/input-streaming.asciidoc index 85a7c02467a..1ee343e4a9b 100644 --- a/x-pack/filebeat/docs/inputs/input-streaming.asciidoc +++ b/x-pack/filebeat/docs/inputs/input-streaming.asciidoc @@ -20,6 +20,7 @@ The websocket streaming input supports: ** Basic ** Bearer ** Custom +** OAuth2.0 NOTE: The `streaming` input websocket handler does not currently support XML messages. Auto-reconnects are also not supported at the moment so reconnection will occur on input restart. @@ -113,7 +114,7 @@ This will include any sensitive or secret information kept in the `state` object ==== Authentication -The websocket streaming input supports authentication via Basic token authentication, Bearer token authentication and authentication via a custom auth config. Unlike REST inputs Basic Authentication contains a basic auth token, Bearer Authentication contains a bearer token and custom auth contains any combination of custom header and value. These token/key values are are added to the request headers and are not exposed to the `state` object. The custom auth configuration is useful for constructing requests that require custom headers and values for authentication. The basic and bearer token configurations will always use the `Authorization` header and prepend the token with `Basic` or `Bearer` respectively. +The websocket streaming input supports authentication via Basic token authentication, Bearer token authentication, authentication via a custom auth config and OAuth2 based authentication. Unlike REST inputs Basic Authentication contains a basic auth token, Bearer Authentication contains a bearer token and custom auth contains any combination of custom header and value. These token/key values are are added to the request headers and are not exposed to the `state` object. The custom auth configuration is useful for constructing requests that require custom headers and values for authentication. The basic and bearer token configurations will always use the `Authorization` header and prepend the token with `Basic` or `Bearer` respectively. Example configurations with authentication: @@ -166,6 +167,48 @@ filebeat.inputs: token_url: https://api.crowdstrike.com/oauth2/token ---- +==== Websocket OAuth2.0 + +The `websocket` streaming input supports OAuth2.0 authentication. The `auth` configuration field is used to specify the OAuth2.0 configuration. These values are not exposed to the `state` object. + +The `auth` configuration field has the following subfields: + + - `client_id`: The client ID to use for OAuth2.0 authentication. + - `client_secret`: The client secret to use for OAuth2.0 authentication. + - `token_url`: The token URL to use for OAuth2.0 authentication. + - `scopes`: The scopes to use for OAuth2.0 authentication. + - `endpoint_params`: The endpoint parameters to use for OAuth2.0 authentication. + - `auth_style`: The authentication style to use for OAuth2.0 authentication. If left unset, the style will be automatically detected. + - `token_expiry_buffer`: Minimum valid time remaining before attempting an OAuth2 token renewal. The default value is `2m`. + +**Explanations for `auth_style` and `token_expiry_buffer`:** + +- `auth_style`: The authentication style to use for OAuth2.0 authentication which determines how the values of sensitive information like `client_id` and `client_secret` are sent in the token request. The default style value is automatically inferred and used appropriately if no value is provided. The `auth_style` configuration field is optional and can be used to specify the authentication style to use for OAuth2.0 authentication. The `auth_style` configuration field supports the following configurable values: + + * `in_header`: The `client_id` and `client_secret` is sent in the header as a base64 encoded `Authorization` header. + * `in_params`: The `client_id` and `client_secret` is sent in the request body along with the other OAuth2 parameters. + +- `token_expiry_buffer`: The token expiry buffer to use for OAuth2.0 authentication. The `token_expiry_buffer` is used as a safety net to ensure that the token does not expire before the input can refresh it. The `token_expiry_buffer` configuration field is optional. If the `token_expiry_buffer` configuration field is not set, the default value of `2m` is used. + +NOTE: We recommend leaving the `auth_style` configuration field unset (automatically inferred internally) for most scenarios, except where manual intervention is required. + +["source","yaml",subs="attributes"] +---- +filebeat.inputs: +- type: streaming + auth: + client_id: a23fcea2643868ef1a41565a1a8a1c7c + client_secret: c3VwZXJzZWNyZXRfY2xpZW50X3NlY3JldF9zaGhoaGgK + token_url: https://api.sample-url.com/oauth2/token + scopes: ["read", "write"] + endpoint_params: + param1: value1 + param2: value2 + auth_style: in_params + token_expiry_buffer: 5m + url: wss://localhost:443/_stream +---- + [[input-state-streaming]] ==== Input state diff --git a/x-pack/filebeat/input/streaming/config.go b/x-pack/filebeat/input/streaming/config.go index df557d553de..753c36febf3 100644 --- a/x-pack/filebeat/input/streaming/config.go +++ b/x-pack/filebeat/input/streaming/config.go @@ -12,10 +12,17 @@ import ( "regexp" "time" + "golang.org/x/oauth2" + "github.com/elastic/elastic-agent-libs/logp" "github.com/elastic/elastic-agent-libs/transport/httpcommon" ) +const ( + authStyleInHeader = "in_header" + authStyleInParams = "in_params" +) + type config struct { // Type is the type of the stream being followed. The // zero value indicates websocket. @@ -85,11 +92,30 @@ type customAuthConfig struct { type oAuth2Config struct { // common oauth fields - ClientID string `config:"client_id"` - ClientSecret string `config:"client_secret"` - EndpointParams map[string][]string `config:"endpoint_params"` - Scopes []string `config:"scopes"` - TokenURL string `config:"token_url"` + AuthStyle string `config:"auth_style"` + ClientID string `config:"client_id"` + ClientSecret string `config:"client_secret"` + EndpointParams url.Values `config:"endpoint_params"` + Scopes []string `config:"scopes"` + TokenExpiryBuffer time.Duration `config:"token_expiry_buffer" validate:"min=0"` + TokenURL string `config:"token_url"` + // accessToken is only used internally to set the initial headers via formHeader() if oauth2 is enabled + accessToken string +} + +func (o oAuth2Config) isEnabled() bool { + return o.ClientID != "" && o.ClientSecret != "" && o.TokenURL != "" +} + +func (o oAuth2Config) getAuthStyle() oauth2.AuthStyle { + switch o.AuthStyle { + case authStyleInHeader: + return oauth2.AuthStyleInHeader + case authStyleInParams: + return oauth2.AuthStyleInParams + default: + return oauth2.AuthStyleAutoDetect + } } type urlConfig struct { @@ -144,6 +170,12 @@ func (c config) Validate() error { return errors.New("wait_min must be less than or equal to wait_max") } } + + if c.Auth.OAuth2.isEnabled() { + if c.Auth.OAuth2.AuthStyle != authStyleInHeader && c.Auth.OAuth2.AuthStyle != authStyleInParams && c.Auth.OAuth2.AuthStyle != "" { + return fmt.Errorf("unsupported auth style: %s", c.Auth.OAuth2.AuthStyle) + } + } return nil } @@ -173,6 +205,11 @@ func defaultConfig() config { Transport: httpcommon.HTTPTransportSettings{ Timeout: 180 * time.Second, }, + Auth: authConfig{ + OAuth2: oAuth2Config{ + TokenExpiryBuffer: 2 * time.Minute, + }, + }, Retry: &retry{ MaxAttempts: 5, WaitMin: 1 * time.Second, diff --git a/x-pack/filebeat/input/streaming/config_test.go b/x-pack/filebeat/input/streaming/config_test.go index 437267bc7b7..99b3cc80559 100644 --- a/x-pack/filebeat/input/streaming/config_test.go +++ b/x-pack/filebeat/input/streaming/config_test.go @@ -142,6 +142,79 @@ var configTests = []struct { "url": "wss://localhost:443/v1/stream", }, }, + { + name: "valid_authStyle_default", + config: map[string]interface{}{ + "auth": map[string]interface{}{ + "client_id": "a_client_id", + "client_secret": "a_client_secret", + "token_url": "https://localhost:443/token", + }, + "url": "wss://localhost:443/v1/stream", + }, + }, + { + name: "valid_authStyle_in_params", + config: map[string]interface{}{ + "auth": map[string]interface{}{ + "auth_style": "in_params", + "client_id": "a_client_id", + "client_secret": "a_client_secret", + "token_url": "https://localhost:443/token", + }, + "url": "wss://localhost:443/v1/stream", + }, + }, + { + name: "valid_authStyle_in_header", + config: map[string]interface{}{ + "auth": map[string]interface{}{ + "auth_style": "in_header", + "client_id": "a_client_id", + "client_secret": "a_client_secret", + "token_url": "https://localhost:443/token", + }, + "url": "wss://localhost:443/v1/stream", + }, + }, + { + name: "invalid_authStyle", + config: map[string]interface{}{ + "auth": map[string]interface{}{ + "auth_style": "in_query", + "client_id": "a_client_id", + "client_secret": "a_client_secret", + "token_url": "https://localhost:443/token", + }, + "url": "wss://localhost:443/v1/stream", + }, + wantErr: fmt.Errorf("unsupported auth style: in_query accessing config"), + }, + { + name: "valid_tokenExpiryBuffer", + config: map[string]interface{}{ + "auth": map[string]interface{}{ + "client_id": "a_client_id", + "client_secret": "a_client_secret", + "token_url": "https://localhost:443/token", + "token_expiry_buffer": "5m", + }, + "url": "wss://localhost:443/v1/stream", + }, + }, + { + name: "invalid_tokenExpiryBuffer", + config: map[string]interface{}{ + "auth": map[string]interface{}{ + "client_id": "a_client_id", + "client_secret": "a_client_secret", + "token_url": "https://localhost:443/token", + "token_expiry_buffer": "-1s", + }, + "url": "wss://localhost:443/v1/stream", + }, + wantErr: fmt.Errorf("requires duration >= 0 accessing 'auth.token_expiry_buffer'"), + }, } func TestConfig(t *testing.T) { diff --git a/x-pack/filebeat/input/streaming/input.go b/x-pack/filebeat/input/streaming/input.go index 12a362625bf..3df6b384554 100644 --- a/x-pack/filebeat/input/streaming/input.go +++ b/x-pack/filebeat/input/streaming/input.go @@ -378,12 +378,14 @@ func errorMessage(msg string) map[string]interface{} { func formHeader(cfg config) map[string][]string { header := make(map[string][]string) switch { - case cfg.Auth.CustomAuth != nil: - header[cfg.Auth.CustomAuth.Header] = []string{cfg.Auth.CustomAuth.Value} + case cfg.Auth.OAuth2.accessToken != "": + header["Authorization"] = []string{"Bearer " + cfg.Auth.OAuth2.accessToken} case cfg.Auth.BearerToken != "": header["Authorization"] = []string{"Bearer " + cfg.Auth.BearerToken} case cfg.Auth.BasicToken != "": header["Authorization"] = []string{"Basic " + cfg.Auth.BasicToken} + case cfg.Auth.CustomAuth != nil: + header[cfg.Auth.CustomAuth.Header] = []string{cfg.Auth.CustomAuth.Value} } return header } diff --git a/x-pack/filebeat/input/streaming/input_manager.go b/x-pack/filebeat/input/streaming/input_manager.go index c685452c34f..6a1bd8bc5a4 100644 --- a/x-pack/filebeat/input/streaming/input_manager.go +++ b/x-pack/filebeat/input/streaming/input_manager.go @@ -38,7 +38,6 @@ func cursorConfigure(cfg *conf.C) ([]inputcursor.Source, inputcursor.Input, erro if err := cfg.Unpack(&src.cfg); err != nil { return nil, nil, err } - if src.cfg.Program == "" { // set default program src.cfg.Program = ` diff --git a/x-pack/filebeat/input/streaming/input_test.go b/x-pack/filebeat/input/streaming/input_test.go index e4a8eac1d41..df9b406e17c 100644 --- a/x-pack/filebeat/input/streaming/input_test.go +++ b/x-pack/filebeat/input/streaming/input_test.go @@ -43,7 +43,9 @@ var inputTests = []struct { name string server func(*testing.T, WebSocketHandler, map[string]interface{}, []string) proxyServer func(*testing.T, WebSocketHandler, map[string]interface{}, []string) *httptest.Server + oauth2Server func(*testing.T, http.HandlerFunc, map[string]interface{}) handler WebSocketHandler + oauth2Handler http.HandlerFunc config map[string]interface{} response []string time func() time.Time @@ -417,13 +419,13 @@ var inputTests = []struct { }, }, response: []string{` - { - "pps": { - "agent": "example.proofpoint.com", - "cid": "mmeng_uivm071" - }, - "ts": 1502908200 - }`, + { + "pps": { + "agent": "example.proofpoint.com", + "cid": "mmeng_uivm071" + }, + "ts": 1502908200 + }`, }, want: []map[string]interface{}{ { @@ -586,6 +588,171 @@ var inputTests = []struct { }, }, }, + { + name: "oauth2_blank_auth_style", + oauth2Server: func(t *testing.T, h http.HandlerFunc, config map[string]interface{}) { + s := httptest.NewServer(h) + config["auth.token_url"] = s.URL + "/token" + config["url"] = "ws://placeholder" + t.Cleanup(s.Close) + }, + oauth2Handler: oauth2TokenHandler, + server: webSocketTestServerWithAuth(httptest.NewServer), + handler: defaultHandler, + config: map[string]interface{}{ + "auth": map[string]interface{}{ + "client_id": "a_client_id", + "client_secret": "a_client_secret", + "scopes": []string{ + "scope1", + "scope2", + }, + "endpoint_params": map[string]string{ + "param1": "v1", + }, + }, + "program": ` + bytes(state.response).decode_json().as(inner_body,{ + "events": [inner_body], + })`, + }, + response: []string{` + { + "pps": { + "agent": "example.proofpoint.com", + "cid": "mmeng_uivm071" + }, + "ts": "2017-08-17T14:54:12.949180-07:00", + "data": "2017-08-17T14:54:12.949180-07:00 example sendmail[30641]:v7HLqYbx029423: to=/dev/null, ctladdr= (8/0),delay=00:00:00, xdelay=00:00:00, mailer=*file*, tls_verify=NONE, pri=35342,dsn=2.0.0, stat=Sent", + "sm": { + "tls": { + "verify": "NONE" + }, + "stat": "Sent", + "qid": "v7HLqYbx029423", + "dsn": "2.0.0", + "mailer": "*file*", + "to": [ + "/dev/null" + ], + "ctladdr": " (8/0)", + "delay": "00:00:00", + "xdelay": "00:00:00", + "pri": 35342 + }, + "id": "ZeYGULpZmL5N0151HN1OyA" + }`}, + want: []map[string]interface{}{ + { + "pps": map[string]interface{}{ + "agent": "example.proofpoint.com", + "cid": "mmeng_uivm071", + }, + "ts": "2017-08-17T14:54:12.949180-07:00", + "data": "2017-08-17T14:54:12.949180-07:00 example sendmail[30641]:v7HLqYbx029423: to=/dev/null, ctladdr= (8/0),delay=00:00:00, xdelay=00:00:00, mailer=*file*, tls_verify=NONE, pri=35342,dsn=2.0.0, stat=Sent", + "sm": map[string]interface{}{ + "tls": map[string]interface{}{ + "verify": "NONE", + }, + "stat": "Sent", + "qid": "v7HLqYbx029423", + "dsn": "2.0.0", + "mailer": "*file*", + "to": []interface{}{ + "/dev/null", + }, + "ctladdr": " (8/0)", + "delay": "00:00:00", + "xdelay": "00:00:00", + "pri": float64(35342), + }, + "id": "ZeYGULpZmL5N0151HN1OyA", + }, + }, + }, + { + name: "oauth2_in_params_auth_style", + oauth2Server: func(t *testing.T, h http.HandlerFunc, config map[string]interface{}) { + s := httptest.NewServer(h) + config["auth.token_url"] = s.URL + "/token" + config["url"] = "ws://placeholder" + t.Cleanup(s.Close) + }, + oauth2Handler: oauth2TokenHandler, + server: webSocketTestServerWithAuth(httptest.NewServer), + handler: defaultHandler, + config: map[string]interface{}{ + "auth": map[string]interface{}{ + "auth_style": "in_params", + "client_id": "a_client_id", + "client_secret": "a_client_secret", + "scopes": []string{ + "scope1", + "scope2", + }, + "endpoint_params": map[string]string{ + "param1": "v1", + }, + }, + "program": ` + bytes(state.response).decode_json().as(inner_body,{ + "events": [inner_body], + })`, + }, + response: []string{` + { + "pps": { + "agent": "example.proofpoint.com", + "cid": "mmeng_uivm071" + }, + "ts": "2017-08-17T14:54:12.949180-07:00", + "data": "2017-08-17T14:54:12.949180-07:00 example sendmail[30641]:v7HLqYbx029423: to=/dev/null, ctladdr= (8/0),delay=00:00:00, xdelay=00:00:00, mailer=*file*, tls_verify=NONE, pri=35342,dsn=2.0.0, stat=Sent", + "sm": { + "tls": { + "verify": "NONE" + }, + "stat": "Sent", + "qid": "v7HLqYbx029423", + "dsn": "2.0.0", + "mailer": "*file*", + "to": [ + "/dev/null" + ], + "ctladdr": " (8/0)", + "delay": "00:00:00", + "xdelay": "00:00:00", + "pri": 35342 + }, + "id": "ZeYGULpZmL5N0151HN1OyA" + }`}, + want: []map[string]interface{}{ + { + "pps": map[string]interface{}{ + "agent": "example.proofpoint.com", + "cid": "mmeng_uivm071", + }, + "ts": "2017-08-17T14:54:12.949180-07:00", + "data": "2017-08-17T14:54:12.949180-07:00 example sendmail[30641]:v7HLqYbx029423: to=/dev/null, ctladdr= (8/0),delay=00:00:00, xdelay=00:00:00, mailer=*file*, tls_verify=NONE, pri=35342,dsn=2.0.0, stat=Sent", + "sm": map[string]interface{}{ + "tls": map[string]interface{}{ + "verify": "NONE", + }, + "stat": "Sent", + "qid": "v7HLqYbx029423", + "dsn": "2.0.0", + "mailer": "*file*", + "to": []interface{}{ + "/dev/null", + }, + "ctladdr": " (8/0)", + "delay": "00:00:00", + "xdelay": "00:00:00", + "pri": float64(35342), + }, + "id": "ZeYGULpZmL5N0151HN1OyA", + }, + }, + }, } var urlEvalTests = []struct { @@ -693,6 +860,9 @@ func TestInput(t *testing.T) { logp.TestingSetup() for _, test := range inputTests { t.Run(test.name, func(t *testing.T) { + if test.oauth2Server != nil { + test.oauth2Server(t, test.oauth2Handler, test.config) + } if test.server != nil { test.server(t, test.handler, test.config, test.response) } @@ -870,7 +1040,7 @@ func webSocketTestServerWithAuth(serve func(http.Handler) *httptest.Server) func handler(t, conn, response) })) // only set the resource URL if it is not already set - if config["url"] == nil { + if config["url"] == nil || config["url"] == "ws://placeholder" { config["url"] = "ws" + server.URL[4:] } t.Cleanup(server.Close) @@ -1029,3 +1199,34 @@ func newWebSocketProxyTestServer(t *testing.T, handler WebSocketHandler, config config["proxy_url"] = "ws" + backendServer.URL[4:] return httptest.NewServer(webSocketProxyHandler(config["url"].(string))) } + +//nolint:errcheck // no point checking errors in test server. +func oauth2TokenHandler(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/token" { + return + } + w.Header().Set("content-type", "application/json") + r.ParseForm() + switch { + case r.Method != http.MethodPost: + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte(`{"error":"wrong method"}`)) + case r.FormValue("grant_type") != "client_credentials": + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte(`{"error":"wrong grant_type"}`)) + case r.FormValue("client_id") != "a_client_id": + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte(`{"error":"wrong client_id"}`)) + case r.FormValue("client_secret") != "a_client_secret": + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte(`{"error":"wrong client_secret"}`)) + case r.FormValue("scope") != "scope1 scope2": + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte(`{"error":"wrong scope"}`)) + case r.FormValue("param1") != "v1": + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte(`{"error":"wrong param1"}`)) + default: + w.Write([]byte(`{"token_type": "Bearer", "expires_in": "3600", "access_token": "` + bearerToken + `"}`)) + } +} diff --git a/x-pack/filebeat/input/streaming/websocket.go b/x-pack/filebeat/input/streaming/websocket.go index 584852aabcc..eeb89ad5c9b 100644 --- a/x-pack/filebeat/input/streaming/websocket.go +++ b/x-pack/filebeat/input/streaming/websocket.go @@ -21,6 +21,8 @@ import ( "github.com/gorilla/websocket" "go.uber.org/zap/zapcore" + "golang.org/x/oauth2" + "golang.org/x/oauth2/clientcredentials" inputcursor "github.com/elastic/beats/v7/filebeat/input/v2/input-cursor" "github.com/elastic/elastic-agent-libs/logp" @@ -31,11 +33,29 @@ import ( type websocketStream struct { processor - id string - cfg config - cursor map[string]any + id string + cfg config + cursor map[string]any + tokenSource oauth2.TokenSource + tokenExpiry <-chan time.Time + time func() time.Time +} + +type loggingRoundTripper struct { + rt http.RoundTripper + log *logp.Logger +} - time func() time.Time +func (l *loggingRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + resp, err := l.rt.RoundTrip(req) + // avoided logging request and and response body as it may contain sensitive information and can be huge + if l.log.Core().Enabled(zapcore.DebugLevel) { + l.log.Debugf("request: %v %v\nHeaders: %v\n", req.Method, req.URL, req.Header) + if err == nil { + l.log.Debugf("response: %v\nHeaders: %v\n", resp.Status, resp.Header) + } + } + return resp, err } // NewWebsocketFollower performs environment construction including CEL @@ -53,9 +73,40 @@ func NewWebsocketFollower(ctx context.Context, id string, cfg config, cursor map redact: cfg.Redact, metrics: newInputMetrics(id), }, + // the token expiry handler will never trigger unless a valid expiry time is assigned + tokenExpiry: nil, } s.metrics.url.Set(cfg.URL.String()) s.metrics.errorsTotal.Set(0) + // initialize the oauth2 token source if oauth2 is enabled and set access token in the config + if cfg.Auth.OAuth2.isEnabled() { + config := &clientcredentials.Config{ + AuthStyle: cfg.Auth.OAuth2.getAuthStyle(), + ClientID: cfg.Auth.OAuth2.ClientID, + ClientSecret: cfg.Auth.OAuth2.ClientSecret, + TokenURL: cfg.Auth.OAuth2.TokenURL, + Scopes: cfg.Auth.OAuth2.Scopes, + EndpointParams: cfg.Auth.OAuth2.EndpointParams, + } + // injecting a custom http client with loggingRoundTripper to debug-log request and response attributes for oauth2 token + client := &http.Client{ + Transport: &loggingRoundTripper{http.DefaultTransport, log}, + } + oauth2Ctx := context.WithValue(ctx, oauth2.HTTPClient, client) + s.tokenSource = config.TokenSource(oauth2Ctx) + // get the initial token + token, err := s.tokenSource.Token() + if err != nil { + s.metrics.errorsTotal.Inc() + s.Close() + return nil, fmt.Errorf("failed to obtain oauth2 token: %w", err) + } + // set the initial token in the config if oauth2 is enabled + // this allows seamless header creation in formHeader() for the initial connection + s.cfg.Auth.OAuth2.accessToken = token.AccessToken + // set the initial token expiry channel with buffer of 2 mins + s.tokenExpiry = time.After(time.Until(token.Expiry) - cfg.Auth.OAuth2.TokenExpiryBuffer) + } patterns, err := regexpsFromConfig(cfg) if err != nil { @@ -114,6 +165,32 @@ func (s *websocketStream) FollowStream(ctx context.Context) error { case <-ctx.Done(): s.log.Debugw("context cancelled, closing websocket connection") return ctx.Err() + // s.tokenExpiry channel will only trigger if oauth2 is enabled and the token is about to expire + case <-s.tokenExpiry: + // get the new token + token, err := s.tokenSource.Token() + if err != nil { + s.metrics.errorsTotal.Inc() + s.log.Errorw("failed to obtain oauth2 token during token refresh", "error", err) + return err + } + // gracefully close current connection + if err := c.Close(); err != nil { + s.metrics.errorsTotal.Inc() + s.log.Errorw("encountered an error while closing the existing websocket connection during token refresh", "error", err) + } + // set the new token in the config + s.cfg.Auth.OAuth2.accessToken = token.AccessToken + // set the new token expiry channel with 2 mins buffer + s.tokenExpiry = time.After(time.Until(token.Expiry) - s.cfg.Auth.OAuth2.TokenExpiryBuffer) + // establish a new connection with the new token + c, resp, err = connectWebSocket(ctx, s.cfg, url, s.log) + handleConnectionResponse(resp, s.metrics, s.log) + if err != nil { + s.metrics.errorsTotal.Inc() + s.log.Errorw("failed to establish a new websocket connection on token refresh", "error", err) + return err + } default: _, message, err := c.ReadMessage() if err != nil {