Skip to content

Commit

Permalink
graceful shutdown (#1)
Browse files Browse the repository at this point in the history
* Update redcon.

* Add CloseAllActiveConnections.

* Update redcon.

* Close connection in defer function.

* Update redcon.

* Update redcon.

* Update logger messages.

* Add logger.

* Move logger into Server.

* Update logger.

* Add more logs.

* Add more logs.

* Avoid potential deadlock.

* Remove debug logs.
  • Loading branch information
zhoufeng1989 authored Dec 6, 2022
1 parent 90a6f64 commit efc9434
Show file tree
Hide file tree
Showing 2 changed files with 176 additions and 68 deletions.
179 changes: 142 additions & 37 deletions redcon.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@ import (
"errors"
"fmt"
"io"
"log"
"net"
"strings"
"sync"
"sync/atomic"
"time"

"github.com/tidwall/btree"
Expand All @@ -23,6 +25,8 @@ var (
errDetached = errors.New("detached")
errIncompleteCommand = errors.New("incomplete command")
errTooMuchData = errors.New("too much data")

ErrServerIsNotClosed = errors.New("server is not closed yet")
)

type errProtocol struct {
Expand All @@ -38,7 +42,7 @@ type Conn interface {
// RemoteAddr returns the remote address of the client connection.
RemoteAddr() string
// Close closes the connection.
Close() error
Close(force bool) (bool, error)
// WriteError writes an error to the client.
WriteError(msg string)
// WriteString writes a string to the client.
Expand Down Expand Up @@ -111,6 +115,12 @@ type Conn interface {
PeekPipeline() []Command
// NetConn returns the base net.Conn connection
NetConn() net.Conn
// Closed returns the connection is closed or not
Closed() bool
// InTx returns the connection in transaction or not
InTx() bool
// SetTxStatus set redis transaction status of the connection
SetTxStatus(bool)
}

// NewServer returns a new Redcon server configured on "tcp" network net.
Expand Down Expand Up @@ -182,9 +192,28 @@ func NewServerNetworkTLS(
return tls
}

// Close stops listening on the TCP address.
// Already Accepted connections will be closed.
func (s *Server) Close() error {
// Close will close server and all connections.
// Active connections will be closed after `waitDuration` duration.
// this function will be blocked for `waitDuration` duration if active connection exists
func (s *Server) Close(waitDuration time.Duration) (CloseConnsResult, error) {
err := s.close()
if err != nil {
return CloseConnsResult{}, err
}
if s.OpenConnectionCount() == 0 {
return CloseConnsResult{}, nil
}
time.Sleep(waitDuration)
return s.closeAllActiveConnections()
}

func (s *Server) IsServerClosing() bool {
s.mu.Lock()
defer s.mu.Unlock()
return s.done
}

func (s *Server) close() error {
s.mu.Lock()
defer s.mu.Unlock()
if s.ln == nil {
Expand All @@ -194,6 +223,31 @@ func (s *Server) Close() error {
return s.ln.Close()
}

type CloseConnsResult struct {
Errs []error
Count int
}

func (s *Server) closeAllActiveConnections() (CloseConnsResult, error) {
errs := make([]error, 0)
closedCount := 0
s.mu.Lock()
defer s.mu.Unlock()
if !s.done {
return CloseConnsResult{}, ErrServerIsNotClosed
}
for c := range s.conns {
closed, err := c.Close(true)
if err != nil {
errs = append(errs, err)
} else if closed {
closedCount += 1
delete(s.conns, c)
}
}
return CloseConnsResult{Errs: errs, Count: closedCount}, nil
}

// ListenAndServe serves incoming connections.
func (s *Server) ListenAndServe() error {
return s.ListenServeAndSignal(nil)
Expand All @@ -204,6 +258,12 @@ func (s *Server) Addr() net.Addr {
return s.ln.Addr()
}

func (s *Server) OpenConnectionCount() int {
s.mu.Lock()
defer s.mu.Unlock()
return len(s.conns)
}

// Close stops listening on the TCP address.
// Already Accepted connections will be closed.
func (s *TLSServer) Close() error {
Expand Down Expand Up @@ -326,28 +386,24 @@ func (s *TLSServer) ListenServeAndSignal(signal chan error) error {

func serve(s *Server) error {
defer func() {
s.ln.Close()
func() {
s.mu.Lock()
defer s.mu.Unlock()
for c := range s.conns {
c.Close()
}
s.conns = nil
}()
err := s.ln.Close()
if err != nil {
message := fmt.Sprintf("close listener in serve defer error %s", err)
s.LogMessage(message)
}
}()
for {
lnconn, err := s.ln.Accept()
if err != nil {
if s.AcceptError != nil {
s.AcceptError(err)
}
s.mu.Lock()
done := s.done
s.mu.Unlock()
if done {
return nil
}
if s.AcceptError != nil {
s.AcceptError(err)
}
continue
}
c := &conn{
Expand All @@ -364,7 +420,7 @@ func serve(s *Server) error {
s.mu.Lock()
delete(s.conns, c)
s.mu.Unlock()
c.Close()
c.Close(true)
continue
}
go handle(s, c)
Expand All @@ -376,21 +432,22 @@ func handle(s *Server, c *conn) {
var err error
defer func() {
if err != errDetached {
// do not close the connection when a detach is detected.
c.conn.Close()
_, closeErr := c.Close(true)
if closeErr != nil {
message := fmt.Sprintf("close connection error %+v %s", c, err)
s.LogMessage(message)
}
}
func() {
// remove the conn from the server
s.mu.Lock()
defer s.mu.Unlock()
delete(s.conns, c)
if s.closed != nil {
if err == io.EOF {
err = nil
}
s.closed(c, err)
// remove the conn from the server
s.mu.Lock()
delete(s.conns, c)
s.mu.Unlock()
if s.closed != nil {
if err == io.EOF {
err = nil
}
}()
s.closed(c, err)
}
}()

err = func() error {
Expand All @@ -417,12 +474,15 @@ func handle(s *Server, c *conn) {
// client has been detached
return errDetached
}
if c.closed {
if c.Closed() {
return nil
}
if err := c.wr.Flush(); err != nil {
return err
}
if s.IsServerClosing() && !c.InTx() {
return nil
}
}
}()
}
Expand All @@ -435,15 +495,38 @@ type conn struct {
addr string
ctx interface{}
detached bool
closed bool
closed int32 // atmoic access
inTx int32 // atomic access
cmds []Command
idleClose time.Duration
}

func (c *conn) Close() error {
func (c *conn) InTx() bool {
return atomic.LoadInt32(&c.inTx) == 1
}

func (c *conn) SetTxStatus(inTx bool) {
if inTx {
atomic.StoreInt32(&c.inTx, 1)
} else {
atomic.StoreInt32(&c.inTx, 0)
}
}

func (c *conn) Close(force bool) (bool, error) {
if c.InTx() && !force {
return false, nil
}
if c.Closed() {
return true, nil
}
c.wr.Flush()
c.closed = true
return c.conn.Close()
err := c.conn.Close()
if err != nil {
return false, err
}
atomic.StoreInt32(&c.closed, 1)
return true, nil
}
func (c *conn) Context() interface{} { return c.ctx }
func (c *conn) SetContext(v interface{}) { c.ctx = v }
Expand Down Expand Up @@ -472,6 +555,10 @@ func (c *conn) NetConn() net.Conn {
return c.conn
}

func (c *conn) Closed() bool {
return atomic.LoadInt32(&c.closed) == 1
}

// BaseWriter returns the underlying connection writer, if any
func BaseWriter(c Conn) *Writer {
if c, ok := c.(*conn); ok {
Expand Down Expand Up @@ -551,6 +638,8 @@ type Server struct {
done bool
idleClose time.Duration

logger *log.Logger

// AcceptError is an optional function used to handle Accept errors.
AcceptError func(err error)
}
Expand Down Expand Up @@ -1126,7 +1215,7 @@ func (sconn *pubSubConn) bgrunner(ps *PubSub) {
delete(ps.conns, sconn.conn)
sconn.mu.Lock()
defer sconn.mu.Unlock()
sconn.dconn.Close()
sconn.dconn.Close(true)
}()
for {
cmd, err := sconn.dconn.ReadCommand()
Expand Down Expand Up @@ -1172,7 +1261,7 @@ func (sconn *pubSubConn) bgrunner(ps *PubSub) {
defer sconn.mu.Unlock()
sconn.dconn.WriteString("OK")
sconn.dconn.Flush()
sconn.dconn.Close()
sconn.dconn.Close(true)
}()
return
case "ping":
Expand Down Expand Up @@ -1370,3 +1459,19 @@ func (s *Server) SetIdleClose(dur time.Duration) {
s.idleClose = dur
s.mu.Unlock()
}

func (s *Server) SetLogger(logger *log.Logger) {
s.logger = logger
}

func (s *Server) Logger() *log.Logger {
return s.logger
}

func (s *Server) LogMessage(message string) {
if s.logger != nil {
s.logger.Printf("%s\n", message)
} else {
fmt.Printf("%s\n", message)
}
}
Loading

0 comments on commit efc9434

Please sign in to comment.