From ce7c97f99201b2d0c08cd1c77170cc8b1cef67ac Mon Sep 17 00:00:00 2001 From: Peyman Mortazavi Date: Mon, 22 Jan 2024 20:45:31 -0500 Subject: [PATCH] Feature: Stop pipelines after context deadline (#3) * Stop executing handlers after deadline is reached * rename shutdown to completion --- pipeline.go | 98 ++++++++++++++++++++++------------------------------- 1 file changed, 40 insertions(+), 58 deletions(-) diff --git a/pipeline.go b/pipeline.go index 9864def..346e287 100644 --- a/pipeline.go +++ b/pipeline.go @@ -15,11 +15,11 @@ import ( // steps when a shutdown is triggered. In order to start the shutdown procedure upon receiving a kernel Interrupt // signal, use WaitForInterrupt() blocking method. type Manager struct { - steps []shutdownStep - timeout time.Duration - shutdownFunc func() - logger Logger - lock sync.Mutex + steps []shutdownStep + timeout time.Duration + completionFuncnc func() + logger Logger + lock sync.Mutex } // New creates a new shutdown pipeline. @@ -49,7 +49,7 @@ func (m *Manager) SetCompletionFunc(f func()) { m.lock.Lock() defer m.lock.Unlock() - m.shutdownFunc = f + m.completionFuncnc = f } // SetLogger sets the shutdown logger. If set to nil, no logs will be written. @@ -133,76 +133,60 @@ func (m *Manager) Trigger(ctx context.Context) { errorCount := 0 resultChannel := make(chan handlerResult) - go func() { - for _, step := range m.steps { - waitGroup := sync.WaitGroup{} - waitGroup.Add(len(step.handlers)) +mainLoop: + for _, step := range m.steps { + remainingHandlers := len(step.handlers) + go func() { for _, handler := range step.handlers { if step.parallel { go func(h NamedHandler) { m.executeHandler(ctx, h, resultChannel) - waitGroup.Done() }(handler) } else { m.executeHandler(ctx, handler, resultChannel) - waitGroup.Done() } } + }() - waitGroup.Wait() + for remainingHandlers > 0 { + select { + case result := <-resultChannel: + if result.Err != nil { + errorCount++ + m.err(result.HandlerName + " shutdown failed: " + result.Err.Error()) + } else { + m.info(result.HandlerName + " shutdown completed") + } + remainingHandlers-- + case <-ctx.Done(): + m.err("context canceled") + errorCount++ + break mainLoop + } } - close(resultChannel) - }() - - // blocks until the result channel is closed. - errorCount = m.processResultChannel(ctx, resultChannel) + } - if m.shutdownFunc != nil { - m.shutdownFunc() + if m.completionFuncnc != nil { + m.completionFuncnc() } if errorCount > 0 { - if m.logger != nil { - m.logger.Error(fmt.Sprintf("shutdown pipeline completed with %d errors", errorCount)) - } + m.err(fmt.Sprintf("shutdown pipeline completed with %d errors", errorCount)) } else { - if m.logger != nil { - m.logger.Info("shutdown pipeline completed with no errors") - } + m.info("shutdown pipeline completed with no errors") } } -// processResultChannel receives from the result channel until the channel is closed. This method is blocking. In the -// end it will return the count of errors. In the event of a context cancellation, this method returns immediately and -// increases the count by 1. -func (m *Manager) processResultChannel(ctx context.Context, resultChannel <-chan handlerResult) int { - errorCount := 0 +func (m *Manager) info(text string) { + if m.logger != nil { + m.logger.Info(text) + } +} - for { - select { - case result, ok := <-resultChannel: - if ok { - if result.Err != nil { - errorCount++ - if m.logger != nil { - m.logger.Error(result.HandlerName + " shutdown failed: " + result.Err.Error()) - } - } else { - if m.logger != nil { - m.logger.Info(result.HandlerName + " shutdown completed") - } - } - } else { - // channel is closed, that means there is no more handler result to wait for. - return errorCount - } - case <-ctx.Done(): - if m.logger != nil { - m.logger.Error("context canceled") - } - return errorCount + 1 - } +func (m *Manager) err(text string) { + if m.logger != nil { + m.logger.Error(text) } } @@ -226,9 +210,7 @@ func (m *Manager) WaitForInterrupt() { signal.Notify(exit, os.Interrupt, syscall.SIGTERM) <-exit - if m.logger != nil { - m.logger.Info("received interrupt signal, starting shutdown procedures...") - } + m.info("received interrupt signal, starting shutdown procedures...") m.Trigger(context.Background()) }