diff --git a/progress.go b/progress.go index 759a1a1..1caacd1 100644 --- a/progress.go +++ b/progress.go @@ -41,14 +41,18 @@ type Progress struct { // New returns a new progress bar with defaults func New() *Progress { + lw := uilive.New() + lw.Out = Out + return &Progress{ Width: Width, Out: Out, Bars: make([]*Bar, 0), RefreshInterval: RefreshInterval, - lw: uilive.New(), - mtx: &sync.RWMutex{}, + tdone: make(chan bool), + lw: uilive.New(), + mtx: &sync.RWMutex{}, } } @@ -72,6 +76,20 @@ func Listen() { defaultProgress.Listen() } +func (p *Progress) SetOut(o io.Writer) { + p.mtx.Lock() + defer p.mtx.Unlock() + + p.Out = o + p.lw.Out = o +} + +func (p *Progress) SetRefreshInterval(interval time.Duration) { + p.mtx.Lock() + defer p.mtx.Unlock() + p.RefreshInterval = interval +} + // AddBar creates a new progress bar and adds to the container func (p *Progress) AddBar(total int) *Bar { p.mtx.Lock() @@ -85,55 +103,41 @@ func (p *Progress) AddBar(total int) *Bar { // Listen listens for updates and renders the progress bars func (p *Progress) Listen() { - var tickChan = p.ticker.C - p.lw.Out = p.Out - for { - select { - case <-tickChan: - p.mtx.RLock() - if p.ticker != nil { - p.print() - p.lw.Flush() - } + p.mtx.Lock() + interval := p.RefreshInterval + p.mtx.Unlock() - p.mtx.RUnlock() + select { + case <-time.After(interval): + p.print() case <-p.tdone: - if p.ticker != nil { - p.ticker.Stop() - p.ticker = nil - return - } + p.print() + close(p.tdone) + return } } } func (p *Progress) print() { + p.mtx.Lock() + defer p.mtx.Unlock() for _, bar := range p.Bars { fmt.Fprintln(p.lw, bar.String()) } + p.lw.Flush() } // Start starts the rendering the progress of progress bars. It listens for updates using `bar.Set(n)` and new bars when added using `AddBar` func (p *Progress) Start() { - p.mtx.Lock() - if p.ticker == nil { - p.ticker = time.NewTicker(RefreshInterval) - p.tdone = make(chan bool, 1) - } - p.mtx.Unlock() - go p.Listen() } // Stop stops listening func (p *Progress) Stop() { - p.mtx.Lock() - close(p.tdone) - p.print() - p.lw.Flush() - p.mtx.Unlock() + p.tdone <- true + <-p.tdone } // Bypass returns a writer which allows non-buffered data to be written to the underlying output diff --git a/progress_test.go b/progress_test.go index d6a1b42..e631dd7 100644 --- a/progress_test.go +++ b/progress_test.go @@ -11,9 +11,11 @@ import ( func TestStoppingPrintout(t *testing.T) { progress := New() - progress.RefreshInterval = time.Millisecond * 10 + progress.SetRefreshInterval(time.Millisecond * 10) + var buffer = &bytes.Buffer{} - progress.Out = buffer + progress.SetOut(buffer) + bar := progress.AddBar(100) progress.Start()