diff --git a/clockwork.go b/clockwork.go index 29341f2..1018051 100644 --- a/clockwork.go +++ b/clockwork.go @@ -152,7 +152,7 @@ func (fc *fakeClock) NewTicker(d time.Duration) Ticker { clock: fc, period: d, } - go ft.tick() + ft.runTickThread() return ft } diff --git a/ticker.go b/ticker.go index 819a355..32b5d01 100644 --- a/ticker.go +++ b/ticker.go @@ -34,33 +34,39 @@ func (ft *fakeTicker) Stop() { ft.stop <- true } -// tick sends the tick time to the ticker channel after every period. -// Tick events are discarded if the underlying ticker channel does -// not have enough capacity. -func (ft *fakeTicker) tick() { - tick := ft.clock.Now() - for { - tick = tick.Add(ft.period) - remaining := tick.Sub(ft.clock.Now()) - if remaining <= 0 { - // The tick should have already happened. This can happen when - // Advance() is called on the fake clock with a duration larger - // than this ticker's period. +// runTickThread initializes a background goroutine to send the tick time to the ticker channel +// after every period. Tick events are discarded if the underlying ticker channel does not have +// enough capacity. +func (ft *fakeTicker) runTickThread() { + nextTick := ft.clock.Now().Add(ft.period) + next := ft.clock.After(ft.period) + go func() { + for { select { - case ft.c <- tick: - default: + case <-ft.stop: + return + case <-next: + // We send the time that the tick was supposed to occur at. + tick := nextTick + // Before sending the tick, we'll compute the next tick time and star the clock.After call. + now := ft.clock.Now() + // First, figure out how many periods there have been between "now" and the time we were + // supposed to have trigged, then advance over all of those. + skipTicks := (now.Sub(tick) + ft.period - 1) / ft.period + nextTick = nextTick.Add(skipTicks * ft.period) + // Now, keep advancing until we are past now. This should happen at most once. + for !nextTick.After(now) { + nextTick = nextTick.Add(ft.period) + } + // Figure out how long between now and the next scheduled tick, then wait that long. + remaining := nextTick.Sub(now) + next = ft.clock.After(remaining) + // Finally, we can actually send the tick. + select { + case ft.c <- tick: + default: + } } - continue } - - select { - case <-ft.stop: - return - case <-ft.clock.After(remaining): - select { - case ft.c <- tick: - default: - } - } - } + }() } diff --git a/ticker_test.go b/ticker_test.go index 5b11e8c..1f34036 100644 --- a/ticker_test.go +++ b/ticker_test.go @@ -2,6 +2,7 @@ package clockwork import ( "testing" + "time" ) func TestFakeTickerStop(t *testing.T) { @@ -38,7 +39,7 @@ func TestFakeTickerTick(t *testing.T) { if tick != first { t.Errorf("wrong tick time, got: %v, want: %v", tick, first) } - default: + case <-time.After(time.Millisecond): t.Errorf("expected tick!") } @@ -51,8 +52,38 @@ func TestFakeTickerTick(t *testing.T) { if tick != second { t.Errorf("wrong tick time, got: %v, want: %v", tick, second) } - default: + case <-time.After(time.Millisecond): t.Errorf("expected tick!") } ft.Stop() } + +func TestFakeTicker_Race(t *testing.T) { + fc := NewFakeClock() + + tickTime := 1 * time.Millisecond + ticker := fc.NewTicker(tickTime) + defer ticker.Stop() + + fc.Advance(tickTime) + + timeout := time.NewTimer(500 * time.Millisecond) + defer timeout.Stop() + + select { + case <-ticker.Chan(): + // Pass + case <-timeout.C: + t.Fatalf("Ticker didn't detect the clock advance!") + } +} + +func TestFakeTicker_Race2(t *testing.T) { + fc := NewFakeClock() + ft := fc.NewTicker(5 * time.Second) + for i := 0; i < 100; i++ { + fc.Advance(5 * time.Second) + <-ft.Chan() + } + ft.Stop() +}