Skip to content
This repository has been archived by the owner on Oct 23, 2024. It is now read-only.

Commit

Permalink
conn: fix racy access to result fields when connection is closing
Browse files Browse the repository at this point in the history
  • Loading branch information
James DeFelice authored and jdef committed Aug 22, 2018
1 parent 449c45a commit 5daf934
Showing 1 changed file with 38 additions and 1 deletion.
39 changes: 38 additions & 1 deletion zk/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -1009,7 +1009,17 @@ func (c *Conn) queueRequest(opcode int32, req interface{}, res interface{}, recv

func (c *Conn) request(opcode int32, req interface{}, res interface{}, recvFunc func(*request, *responseHeader, error)) (int64, error) {
r := <-c.queueRequest(opcode, req, res, recvFunc)
return r.zxid, r.err
select {
case <-c.shouldQuit:
// queueRequest() can be racy, double-check for the race here and avoid
// a potential data-race. otherwise the client of this func may try to
// access `res` fields concurrently w/ the async response processor.
// NOTE: callers of this func should check for (at least) ErrConnectionClosed
// and avoid accessing fields of the response object if such error is present.
return -1, ErrConnectionClosed
default:
return r.zxid, r.err
}
}

func (c *Conn) AddAuth(scheme string, auth []byte) error {
Expand Down Expand Up @@ -1045,6 +1055,9 @@ func (c *Conn) Children(path string) ([]string, *Stat, error) {

res := &getChildren2Response{}
_, err := c.request(opGetChildren2, &getChildren2Request{Path: path, Watch: false}, res, nil)
if err == ErrConnectionClosed {
return nil, nil, err
}
return res.Children, &res.Stat, err
}

Expand Down Expand Up @@ -1073,6 +1086,9 @@ func (c *Conn) Get(path string) ([]byte, *Stat, error) {

res := &getDataResponse{}
_, err := c.request(opGetData, &getDataRequest{Path: path, Watch: false}, res, nil)
if err == ErrConnectionClosed {
return nil, nil, err
}
return res.Data, &res.Stat, err
}

Expand Down Expand Up @@ -1102,6 +1118,9 @@ func (c *Conn) Set(path string, data []byte, version int32) (*Stat, error) {

res := &setDataResponse{}
_, err := c.request(opSetData, &SetDataRequest{path, data, version}, res, nil)
if err == ErrConnectionClosed {
return nil, err
}
return &res.Stat, err
}

Expand All @@ -1112,6 +1131,9 @@ func (c *Conn) Create(path string, data []byte, flags int32, acl []ACL) (string,

res := &createResponse{}
_, err := c.request(opCreate, &CreateRequest{path, data, acl, flags}, res, nil)
if err == ErrConnectionClosed {
return "", err
}
return res.Path, err
}

Expand Down Expand Up @@ -1180,6 +1202,9 @@ func (c *Conn) Exists(path string) (bool, *Stat, error) {

res := &existsResponse{}
_, err := c.request(opExists, &existsRequest{Path: path, Watch: false}, res, nil)
if err == ErrConnectionClosed {
return false, nil, err
}
exists := true
if err == ErrNoNode {
exists = false
Expand Down Expand Up @@ -1220,6 +1245,9 @@ func (c *Conn) GetACL(path string) ([]ACL, *Stat, error) {

res := &getAclResponse{}
_, err := c.request(opGetAcl, &getAclRequest{Path: path}, res, nil)
if err == ErrConnectionClosed {
return nil, nil, err
}
return res.Acl, &res.Stat, err
}
func (c *Conn) SetACL(path string, acl []ACL, version int32) (*Stat, error) {
Expand All @@ -1229,6 +1257,9 @@ func (c *Conn) SetACL(path string, acl []ACL, version int32) (*Stat, error) {

res := &setAclResponse{}
_, err := c.request(opSetAcl, &setAclRequest{Path: path, Acl: acl, Version: version}, res, nil)
if err == ErrConnectionClosed {
return nil, err
}
return &res.Stat, err
}

Expand All @@ -1239,6 +1270,9 @@ func (c *Conn) Sync(path string) (string, error) {

res := &syncResponse{}
_, err := c.request(opSync, &syncRequest{Path: path}, res, nil)
if err == ErrConnectionClosed {
return "", err
}
return res.Path, err
}

Expand Down Expand Up @@ -1274,6 +1308,9 @@ func (c *Conn) Multi(ops ...interface{}) ([]MultiResponse, error) {
}
res := &multiResponse{}
_, err := c.request(opMulti, req, res, nil)
if err == ErrConnectionClosed {
return nil, err
}
mr := make([]MultiResponse, len(res.Ops))
for i, op := range res.Ops {
mr[i] = MultiResponse{Stat: op.Stat, String: op.String, Error: op.Err.toError()}
Expand Down

0 comments on commit 5daf934

Please sign in to comment.