From 82172bebd08023dc1dc87f23b92adbb3c3d12f84 Mon Sep 17 00:00:00 2001 From: Tom Hayward Date: Thu, 19 Sep 2024 12:56:11 -0700 Subject: [PATCH] remove connection after close --- conn.go | 27 +++++++++++++++++++-------- conn_test.go | 2 +- driver.go | 13 ++++++++++++- 3 files changed, 32 insertions(+), 10 deletions(-) diff --git a/conn.go b/conn.go index cea8935..fbc2687 100644 --- a/conn.go +++ b/conn.go @@ -17,6 +17,9 @@ type managedConn struct { killed bool mu sync.RWMutex + // callback function to be called after the connection is closed + afterClose func(*managedConn) + execStmtsCounter int // count the number of exec calls in a transaction queryStmtsCounter int // count the number of query calls in a transaction } @@ -31,7 +34,7 @@ type managedConn struct { func (c *managedConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { select { case <-c.ctx.Done(): - c.conn.Close() + c.close() return nil, driver.ErrBadConn default: } @@ -67,10 +70,11 @@ func (c *managedConn) BeginTx(ctx context.Context, opts driver.TxOptions) (drive return tx, err } -func newManagedConn(ctx context.Context, conn driver.Conn) *managedConn { +func newManagedConn(ctx context.Context, conn driver.Conn, afterClose func(*managedConn)) *managedConn { return &managedConn{ - ctx: ctx, - conn: conn, + ctx: ctx, + conn: conn, + afterClose: afterClose, } } @@ -121,7 +125,7 @@ func (c *managedConn) QueryContext(ctx context.Context, query string, args []dri func (c *managedConn) Prepare(query string) (driver.Stmt, error) { select { case <-c.ctx.Done(): - c.conn.Close() + c.close() return nil, driver.ErrBadConn default: } @@ -133,7 +137,7 @@ func (c *managedConn) Prepare(query string) (driver.Stmt, error) { func (c *managedConn) Begin() (driver.Tx, error) { select { case <-c.ctx.Done(): - c.conn.Close() + c.close() return nil, driver.ErrBadConn default: } @@ -143,7 +147,7 @@ func (c *managedConn) Begin() (driver.Tx, error) { func (c *managedConn) IsValid() bool { select { case <-c.ctx.Done(): - c.conn.Close() + c.close() return false default: } @@ -170,7 +174,7 @@ func (c *managedConn) ResetSession(ctx context.Context) error { func (c *managedConn) Close() error { c.mu.Lock() defer c.mu.Unlock() - err := c.conn.Close() + err := c.close() if err == nil { c.killed = true @@ -179,6 +183,13 @@ func (c *managedConn) Close() error { return err } +func (c *managedConn) close() error { + if c.afterClose != nil { + defer c.afterClose(c) + } + return c.conn.Close() +} + func (c *managedConn) GetReset() bool { c.mu.RLock() defer c.mu.RUnlock() diff --git a/conn_test.go b/conn_test.go index 39463ab..789e08a 100644 --- a/conn_test.go +++ b/conn_test.go @@ -128,7 +128,7 @@ var _ = Describe("PrometheusMetrics", func() { ` It("Should emit the correct metrics", func() { - mc := newManagedConn(context.Background(), mockDriverConn{}) + mc := newManagedConn(context.Background(), mockDriverConn{}, nil) ctx := ContextWithExecLabels(context.Background(), map[string]string{"grpc_method": "method_1", "grpc_service": "service_1"}) diff --git a/driver.go b/driver.go index d22d4d6..3def12c 100644 --- a/driver.go +++ b/driver.go @@ -272,12 +272,23 @@ func (cg *chanGroup) Open() (driver.Conn, error) { return conn, err } - manConn := newManagedConn(cg.ctx, conn) + manConn := newManagedConn(cg.ctx, conn, cg.remove) cg.conns = append(cg.conns, manConn) return manConn, nil } +func (cg *chanGroup) remove(conn *managedConn) { + cg.mu.Lock() + defer cg.mu.Unlock() + for i, c := range cg.conns { + if c == conn { + cg.conns = append(cg.conns[:i], cg.conns[i+1:]...) + return + } + } +} + func (cg *chanGroup) parseValues(vs url.Values) { cg.mu.Lock() defer cg.mu.Unlock()