diff --git a/zk/conn.go b/zk/conn.go index f79a51b3..e7019d59 100644 --- a/zk/conn.go +++ b/zk/conn.go @@ -82,6 +82,7 @@ type Conn struct { eventChan chan Event eventCallback EventCallback // may be nil shouldQuit chan struct{} + shouldQuitOnce sync.Once pingInterval time.Duration recvTimeout time.Duration connectTimeout time.Duration @@ -310,12 +311,14 @@ func WithMaxConnBufferSize(maxBufferSize int) connOption { } func (c *Conn) Close() { - close(c.shouldQuit) + c.shouldQuitOnce.Do(func() { + close(c.shouldQuit) - select { - case <-c.queueRequest(opClose, &closeRequest{}, &closeResponse{}, nil): - case <-time.After(time.Second): - } + select { + case <-c.queueRequest(opClose, &closeRequest{}, &closeResponse{}, nil): + case <-time.After(time.Second): + } + }) } // State returns the current state of the connection. @@ -939,10 +942,30 @@ func (c *Conn) queueRequest(opcode int32, req interface{}, res interface{}, recv opcode: opcode, pkt: req, recvStruct: res, - recvChan: make(chan response, 1), + recvChan: make(chan response, 2), recvFunc: recvFunc, } - c.sendChan <- rq + + switch opcode { + case opClose: + // always attempt to send close ops. + c.sendChan <- rq + default: + // otherwise avoid deadlocks for dumb clients who aren't aware that + // the ZK connection is closed yet. + select { + case <-c.shouldQuit: + rq.recvChan <- response{-1, ErrConnectionClosed} + case c.sendChan <- rq: + // check for a tie + select { + case <-c.shouldQuit: + // maybe the caller gets this, maybe not- we tried. + rq.recvChan <- response{-1, ErrConnectionClosed} + default: + } + } + } return rq.recvChan } diff --git a/zk/zk_test.go b/zk/zk_test.go index c81ef9fb..07d29927 100644 --- a/zk/zk_test.go +++ b/zk/zk_test.go @@ -93,6 +93,37 @@ func TestCreate(t *testing.T) { } } +func TestOpsAfterCloseDontDeadlock(t *testing.T) { + ts, err := StartTestCluster(1, nil, logWriter{t: t, p: "[ZKERR] "}) + if err != nil { + t.Fatal(err) + } + defer ts.Stop() + zk, _, err := ts.ConnectAll() + if err != nil { + t.Fatalf("Connect returned error: %+v", err) + } + zk.Close() + + path := "/gozk-test" + + ch := make(chan struct{}) + go func() { + defer close(ch) + for range make([]struct{}, 30) { + if _, err := zk.Create(path, []byte{1, 2, 3, 4}, 0, WorldACL(PermAll)); err == nil { + t.Fatal("Create did not return error") + } + } + }() + select { + case <-ch: + // expected + case <-time.After(10 * time.Second): + t.Fatal("ZK connection deadlocked when executing ops after a Close operation") + } +} + func TestMulti(t *testing.T) { ts, err := StartTestCluster(1, nil, logWriter{t: t, p: "[ZKERR] "}) if err != nil { @@ -139,6 +170,7 @@ func TestIfAuthdataSurvivesReconnect(t *testing.T) { if err != nil { t.Fatal(err) } + defer ts.Stop() zk, _, err := ts.ConnectAll() if err != nil { @@ -666,6 +698,16 @@ func TestRequestFail(t *testing.T) { } } +func TestIdempotentClose(t *testing.T) { + zk, _, err := Connect([]string{"127.0.0.1:32444"}, time.Second*15) + if err != nil { + t.Fatal(err) + } + // multiple calls to Close() should not panic + zk.Close() + zk.Close() +} + func TestSlowServer(t *testing.T) { ts, err := StartTestCluster(1, nil, logWriter{t: t, p: "[ZKERR] "}) if err != nil {