Skip to content

Commit

Permalink
remove connection after close
Browse files Browse the repository at this point in the history
  • Loading branch information
kd7lxl committed Sep 19, 2024
1 parent d9a27df commit 82172be
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 10 deletions.
27 changes: 19 additions & 8 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ type managedConn struct {
killed bool
mu sync.RWMutex

// callback function to be called after the connection is closed
afterClose func(*managedConn)

execStmtsCounter int // count the number of exec calls in a transaction
queryStmtsCounter int // count the number of query calls in a transaction
}
Expand All @@ -31,7 +34,7 @@ type managedConn struct {
func (c *managedConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
select {
case <-c.ctx.Done():
c.conn.Close()
c.close()
return nil, driver.ErrBadConn
default:
}
Expand Down Expand Up @@ -67,10 +70,11 @@ func (c *managedConn) BeginTx(ctx context.Context, opts driver.TxOptions) (drive
return tx, err
}

func newManagedConn(ctx context.Context, conn driver.Conn) *managedConn {
func newManagedConn(ctx context.Context, conn driver.Conn, afterClose func(*managedConn)) *managedConn {
return &managedConn{
ctx: ctx,
conn: conn,
ctx: ctx,
conn: conn,
afterClose: afterClose,
}
}

Expand Down Expand Up @@ -121,7 +125,7 @@ func (c *managedConn) QueryContext(ctx context.Context, query string, args []dri
func (c *managedConn) Prepare(query string) (driver.Stmt, error) {
select {
case <-c.ctx.Done():
c.conn.Close()
c.close()
return nil, driver.ErrBadConn
default:
}
Expand All @@ -133,7 +137,7 @@ func (c *managedConn) Prepare(query string) (driver.Stmt, error) {
func (c *managedConn) Begin() (driver.Tx, error) {
select {
case <-c.ctx.Done():
c.conn.Close()
c.close()
return nil, driver.ErrBadConn
default:
}
Expand All @@ -143,7 +147,7 @@ func (c *managedConn) Begin() (driver.Tx, error) {
func (c *managedConn) IsValid() bool {
select {
case <-c.ctx.Done():
c.conn.Close()
c.close()
return false
default:
}
Expand All @@ -170,7 +174,7 @@ func (c *managedConn) ResetSession(ctx context.Context) error {
func (c *managedConn) Close() error {
c.mu.Lock()
defer c.mu.Unlock()
err := c.conn.Close()
err := c.close()

if err == nil {
c.killed = true
Expand All @@ -179,6 +183,13 @@ func (c *managedConn) Close() error {
return err
}

func (c *managedConn) close() error {
if c.afterClose != nil {
defer c.afterClose(c)
}
return c.conn.Close()
}

func (c *managedConn) GetReset() bool {
c.mu.RLock()
defer c.mu.RUnlock()
Expand Down
2 changes: 1 addition & 1 deletion conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ var _ = Describe("PrometheusMetrics", func() {
`

It("Should emit the correct metrics", func() {
mc := newManagedConn(context.Background(), mockDriverConn{})
mc := newManagedConn(context.Background(), mockDriverConn{}, nil)

ctx := ContextWithExecLabels(context.Background(), map[string]string{"grpc_method": "method_1", "grpc_service": "service_1"})

Expand Down
13 changes: 12 additions & 1 deletion driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -272,12 +272,23 @@ func (cg *chanGroup) Open() (driver.Conn, error) {
return conn, err
}

manConn := newManagedConn(cg.ctx, conn)
manConn := newManagedConn(cg.ctx, conn, cg.remove)
cg.conns = append(cg.conns, manConn)

return manConn, nil
}

func (cg *chanGroup) remove(conn *managedConn) {
cg.mu.Lock()
defer cg.mu.Unlock()
for i, c := range cg.conns {
if c == conn {
cg.conns = append(cg.conns[:i], cg.conns[i+1:]...)
return
}
}
}

func (cg *chanGroup) parseValues(vs url.Values) {
cg.mu.Lock()
defer cg.mu.Unlock()
Expand Down

0 comments on commit 82172be

Please sign in to comment.