From 98572ac9c7165601350ca9ebe5b954c00a237f59 Mon Sep 17 00:00:00 2001 From: eahydra <616941303@qq.com> Date: Fri, 23 Feb 2018 13:27:47 +0800 Subject: [PATCH] fix: operations hang if attempted after connection closed Operations hang forever if attempted after client connection is closed Fixes: #125 --- zk/conn.go | 22 +++++++++++++++++----- zk/zk_test.go | 16 ++++++++++++++++ 2 files changed, 33 insertions(+), 5 deletions(-) diff --git a/zk/conn.go b/zk/conn.go index f79a51b3..28cb4441 100644 --- a/zk/conn.go +++ b/zk/conn.go @@ -310,10 +310,14 @@ func WithMaxConnBufferSize(maxBufferSize int) connOption { } func (c *Conn) Close() { + rc, err := c.queueRequest(opClose, &closeRequest{}, &closeResponse{}, nil) + if err != nil { + return + } close(c.shouldQuit) select { - case <-c.queueRequest(opClose, &closeRequest{}, &closeResponse{}, nil): + case <-rc: case <-time.After(time.Second): } } @@ -933,7 +937,7 @@ func (c *Conn) addWatcher(path string, watchType watchType) <-chan Event { return ch } -func (c *Conn) queueRequest(opcode int32, req interface{}, res interface{}, recvFunc func(*request, *responseHeader, error)) <-chan response { +func (c *Conn) queueRequest(opcode int32, req interface{}, res interface{}, recvFunc func(*request, *responseHeader, error)) (<-chan response, error) { rq := &request{ xid: c.nextXid(), opcode: opcode, @@ -942,12 +946,20 @@ func (c *Conn) queueRequest(opcode int32, req interface{}, res interface{}, recv recvChan: make(chan response, 1), recvFunc: recvFunc, } - c.sendChan <- rq - return rq.recvChan + select { + case c.sendChan <- rq: + return rq.recvChan, nil + case <-c.shouldQuit: + return nil, ErrClosing + } } func (c *Conn) request(opcode int32, req interface{}, res interface{}, recvFunc func(*request, *responseHeader, error)) (int64, error) { - r := <-c.queueRequest(opcode, req, res, recvFunc) + rc, err := c.queueRequest(opcode, req, res, recvFunc) + if err != nil { + return 0, err + } + r := <-rc return r.zxid, r.err } diff --git a/zk/zk_test.go b/zk/zk_test.go index c81ef9fb..81c4c8ff 100644 --- a/zk/zk_test.go +++ b/zk/zk_test.go @@ -666,6 +666,22 @@ func TestRequestFail(t *testing.T) { } } +func TestRequestFailAfterClosed(t *testing.T) { + ts, err := StartTestCluster(1, nil, logWriter{t: t, p: "[ZKERR] "}) + if err != nil { + t.Fatal(err) + } + defer ts.Stop() + zk, _, err := ts.ConnectAll() + if err != nil { + t.Fatalf("Connect returned error: %+v", err) + } + zk.Close() + _, _, err = zk.Get("/blah") + if err != ErrClosing { + t.Fatalf("unexpected err: %+v", err) + } +} func TestSlowServer(t *testing.T) { ts, err := StartTestCluster(1, nil, logWriter{t: t, p: "[ZKERR] "}) if err != nil {