Skip to content

Commit

Permalink
added wait for signals
Browse files Browse the repository at this point in the history
  • Loading branch information
partyzanex committed Aug 20, 2023
1 parent 6d99fd3 commit 5abc959
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 53 deletions.
57 changes: 51 additions & 6 deletions closure.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package shutdown
import (
"context"
"io"
"os"
"os/signal"
"sync"
)

Expand All @@ -20,6 +22,7 @@ type Closure interface {
var (
closure Closure = &Lifo{} // Default implementation of Closure using Lifo (Last In First Out) strategy
mu sync.Mutex // Mutex to ensure thread safety
once sync.Once
)

// SetPackageClosure allows for setting a different Closure implementation.
Expand All @@ -38,14 +41,56 @@ func Append(closer Closer) {

// Close attempts to close all appended resources.
func Close() error {
mu.Lock() // Acquiring the lock
defer mu.Unlock() // Making sure to release the lock after the function exits
return closure.Close() // Close all resources and return any encountered error
return CloseContext(context.Background()) // Close all resources and return any encountered error
}

// CloseContext attempts to close all appended resources with context support.
func CloseContext(ctx context.Context) error {
mu.Lock() // Acquiring the lock
defer mu.Unlock() // Making sure to release the lock after the function exits
return closure.CloseContext(ctx) // Close resources with context support and return any encountered error
mu.Lock() // Acquiring the lock
defer mu.Unlock() // Making sure to release the lock after the function exits

var err error

once.Do(func() {
err = closure.CloseContext(ctx) // Close all resources and return any encountered error
})

return err
}

// Logger is an interface representing logging capabilities. It provides a method to log warning messages.
type Logger interface {
Warnf(format string, args ...interface{})
}

// WaitForSignals blocks until a given signal (or signals) is received.
// Once the signal is caught, it logs a warning message using the provided logger.
func WaitForSignals(logger Logger, sig ...os.Signal) {
// Create a channel to listen for signals.
c := make(chan os.Signal, 1)

// Register the given signals to the channel.
signal.Notify(c, sig...)

// Ensure that we stop the signal notifications to the channel when the function returns.
defer signal.Stop(c)

// Log a warning when a signal is received.
logger.Warnf("Received signal: %s", <-c)
}

// WaitForSignalsContext is similar to WaitForSignals but with support for context.
// It blocks until a given signal (or signals) is received or the context is done.
func WaitForSignalsContext(ctx context.Context, logger Logger, sig ...os.Signal) {
// Create a context that will be done when the given signals are caught or the parent context is done.
sigCtx, cancel := signal.NotifyContext(ctx, sig...)

// Ensure resources are released when the function returns.
defer cancel()

// Wait until the signal context is done (either from a caught signal or the parent context).
<-sigCtx.Done()

// Log a warning indicating which signal or context-related error occurred.
logger.Warnf("Received signal: %s", sigCtx.Err())
}
105 changes: 59 additions & 46 deletions closure_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,74 +3,87 @@ package shutdown
import (
"context"
"errors"
"fmt"
"os"
"sync"
"testing"

"github.com/stretchr/testify/assert"
"time"
)

type pkgCloser struct {
closeCalled bool
err error
mu sync.Mutex
isClose bool
err error
}

func (mc *pkgCloser) Close() error {
mc.closeCalled = true
mc.mu.Lock()
defer mc.mu.Unlock()

mc.isClose = true
return mc.err
}

type mockClosure struct {
appendCalled bool
closeCalled bool
ctxCalled bool
closers []Closer
func TestAppendAndClose(t *testing.T) {
SetPackageClosure(&Lifo{})
once = sync.Once{}
mCloser := &pkgCloser{}
Append(mCloser)
if err := Close(); err != nil || !mCloser.isClose {
t.Fatalf("Expected closer to be closed without errors, got: %v", err)
}
}

func (mc *mockClosure) Append(closer Closer) {
mc.appendCalled = true
mc.closers = append(mc.closers, closer)
func TestAppendAndCloseWithError(t *testing.T) {
SetPackageClosure(&Fifo{})
once = sync.Once{}
expectedErr := errors.New("close error")
mCloser := &pkgCloser{err: expectedErr}
Append(mCloser)
if err := Close(); err == nil || err.Error() != expectedErr.Error() {
t.Fatalf("Expected error: %v, got: %v", expectedErr, err)
}
}

func (mc *mockClosure) Close() error {
mc.closeCalled = true
if len(mc.closers) > 0 {
return mc.closers[0].Close()
}
return nil
type mockLogger struct {
messages []string
mu sync.Mutex
}

func (mc *mockClosure) CloseContext(ctx context.Context) error {
mc.ctxCalled = true
if len(mc.closers) > 0 {
return mc.closers[0].Close()
}
return nil
func (ml *mockLogger) Warnf(format string, args ...interface{}) {
ml.mu.Lock()
defer ml.mu.Unlock()

ml.messages = append(ml.messages, fmt.Sprintf(format, args...))
}

func TestPackageFunctions(t *testing.T) {
// Reset package state after the test
defer func() {
SetPackageClosure(&Lifo{})
func TestWaitForSignals(t *testing.T) {
ml := &mockLogger{}
signals := []os.Signal{os.Interrupt}

go func() {
// Simulate a signal after a short delay
time.Sleep(100 * time.Millisecond)
process, _ := os.FindProcess(os.Getpid())
_ = process.Signal(os.Interrupt)
}()

// Mock closure for testing
mc := &mockClosure{}
SetPackageClosure(mc)
WaitForSignals(ml, signals...)

// Test Append
closer := &pkgCloser{}
Append(closer)
assert.True(t, mc.appendCalled, "Append was not called on the mock closure")
assert.Contains(t, mc.closers, closer, "Closer was not added to the mock closure")
if len(ml.messages) == 0 || ml.messages[0] != "Received signal: interrupt" {
t.Errorf("Expected log message about received signal, got: %v", ml.messages)
}
}

// Test Close and CloseContext
errClose := errors.New("test close error")
closer.err = errClose
func TestWaitForSignalsContext(t *testing.T) {
ml := &mockLogger{}
signals := []os.Signal{os.Interrupt}
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()

err := Close()
assert.True(t, mc.closeCalled, "Close was not called on the mock closure")
assert.Equal(t, errClose, err, "Unexpected error returned from Close")
WaitForSignalsContext(ctx, ml, signals...)

err = CloseContext(context.Background())
assert.True(t, mc.ctxCalled, "CloseContext was not called on the mock closure")
assert.Equal(t, errClose, err, "Unexpected error returned from CloseContext")
if len(ml.messages) == 0 || ml.messages[0] != "Received signal: context deadline exceeded" {
t.Errorf("Expected log message about context deadline, got: %v", ml.messages)
}
}
2 changes: 1 addition & 1 deletion lifo_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ func TestLifoCloseContext(t *testing.T) {

lifo.Append(timeoutCloser)

ctx, cancel := context.WithTimeout(context.Background(), time.Second)
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
err = lifo.CloseContext(ctx)
assert.NotNil(t, err)
Expand Down

0 comments on commit 5abc959

Please sign in to comment.