diff --git a/manager/runner_test.go b/manager/runner_test.go index 7ba249a3d..42dbe500f 100644 --- a/manager/runner_test.go +++ b/manager/runner_test.go @@ -5,12 +5,16 @@ package manager import ( "bytes" + "context" "fmt" + "net/http" + "net/http/httptest" "os" "os/exec" "path/filepath" "reflect" "strings" + "sync/atomic" "testing" "time" @@ -1023,6 +1027,7 @@ func TestRunner_Start(t *testing.T) { t.Fatal("expected parse only to stop runner") } }) + } func TestRunner_quiescence(t *testing.T) { @@ -1278,3 +1283,95 @@ func TestRunner_commandPath(t *testing.T) { t.Fatalf("unexpected shell: %#v\n", cmd) } } + +// TestRunner_stoppedWatcher verifies that dependencies can't be added to a +// watcher that's been stopped after the first template rendering +func TestRunner_stoppedWatcher(t *testing.T) { + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + t.Cleanup(cancel) + + stopCtx, stopNow := context.WithCancel(ctx) + t.Cleanup(stopNow) + + // Test server to simulate Vault cluster responses. + reqNum := atomic.Uint64{} + vaultServer := httptest.NewServer( + http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + t.Logf("GET %q", req.URL.Path) + if req.URL.Path == "/v1/secret/data/test1" { + if reqNum.Add(1) > 1 { + stopNow() + } + + w.WriteHeader(http.StatusOK) + + // We can't configure the default lease in tests because the + // field is global, so ensure there's a rotation_period and ttl + // that causes the runner to make multiple requests + w.Write([]byte(`{ + "data": { + "data": { + "secret": "foo" + }, + "rotation_period": "10ms", + "ttl": 1 + } +}`)) + } else { + w.WriteHeader(http.StatusOK) + } + + })) + + t.Cleanup(vaultServer.Close) + + out, err := os.CreateTemp(t.TempDir(), "") + if err != nil { + t.Fatal(err) + } + defer os.Remove(out.Name()) + + c := config.DefaultConfig().Merge(&config.Config{ + Vault: &config.VaultConfig{ + Address: &vaultServer.URL, + }, + Templates: &config.TemplateConfigs{ + &config.TemplateConfig{ + Contents: config.String( + `{{with secret "secret/data/test1"}}{{ if .Data}}{{.Data.data.secret}}{{end}}{{end}}`), + Destination: config.String(out.Name()), + }, + }, + }) + c.Finalize() + + r, err := NewRunner(c, false) + if err != nil { + t.Fatal(err) + } + t.Cleanup(r.Stop) + + go r.Start() + + // the mock vault server controls the stopCtx so that we can be sure we've + // rendered the empty template, set up the watcher, and have started the + // next poll + select { + case <-stopCtx.Done(): + r.Stop() + case <-ctx.Done(): + t.Fatal("test timed out") + } + + // there's no way for us to control whether the Start goroutine's select + // takes the watcher's DataCh or the runner's doneCh. So once the runner has + // been stopped, we fake that we've received on the DataCh so that we hit + // the code path that calls Run. This should never add a dependency to the + // watcher. + r.Run() + + if r.watcher.Size() != 0 { + t.Fatal("watcher had dependencies added after stop") + } +} diff --git a/watch/watcher.go b/watch/watcher.go index 571e43fcf..ea362e70d 100644 --- a/watch/watcher.go +++ b/watch/watcher.go @@ -51,6 +51,9 @@ type Watcher struct { // one time instead of polling infinitely. once bool + // stopped signals if this watcher should stop adding any new dependencies + stopped bool + // retryFuncs specifies the different ways to retry based on the upstream. retryFuncConsul RetryFunc retryFuncDefault RetryFunc @@ -138,13 +141,18 @@ func (w *Watcher) ServerErrCh() <-chan error { // and start the associated view. If the dependency already exists, no action is // taken. // -// If the Dependency already existed, it this function will return false. If the -// view was successfully created, it will return true. If an error occurs while -// creating the view, it will be returned here (but future errors returned by -// the view will happen on the channel). +// If the Dependency already existed or the watcher was concurrently stopped, it +// this function will return false. If the view was successfully created, it +// will return true. If an error occurs while creating the view, it will be +// returned here (but future errors returned by the view will happen on the +// channel). func (w *Watcher) Add(d dep.Dependency) (bool, error) { w.Lock() defer w.Unlock() + if w.stopped { + log.Printf("[TRACE] (watcher) did not add %s because watcher is stopped", d) + return false, nil + } log.Printf("[DEBUG] (watcher) adding %s", d) @@ -261,4 +269,5 @@ func (w *Watcher) Stop() { // Close any idle TCP connections w.clients.Stop() + w.stopped = true } diff --git a/watch/watcher_test.go b/watch/watcher_test.go index c0fee4219..bda76c443 100644 --- a/watch/watcher_test.go +++ b/watch/watcher_test.go @@ -46,6 +46,29 @@ func TestAdd_exists(t *testing.T) { } } +func TestAdd_stopped(t *testing.T) { + w := NewWatcher(&NewWatcherInput{ + Clients: dep.NewClientSet(), + Once: true, + }) + + w.Stop() + + d := &TestDep{} + added, err := w.Add(d) + if err != nil { + t.Fatal(err) + } + + if added != false { + t.Errorf("expected Add to return false when watcher is stopped") + } + + if w.Watching(d) { + t.Errorf("expected Add not to add dependency when watcher is stopped") + } +} + func TestAdd_startsViewPoll(t *testing.T) { w := NewWatcher(&NewWatcherInput{ Clients: dep.NewClientSet(),