Skip to content

Commit

Permalink
database: Close query rows before returning
Browse files Browse the repository at this point in the history
  • Loading branch information
beautifulentropy committed Sep 12, 2024
1 parent 990ad07 commit fc10303
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 19 deletions.
24 changes: 18 additions & 6 deletions cmd/contact-auditor/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,12 +99,24 @@ func (c contactAuditor) writeResults(result string) {
// run retrieves a cursor from `beginAuditQuery` and then audits the
// `contact` column of all returned rows for abnormalities or policy
// violations.
func (c contactAuditor) run(ctx context.Context, resChan chan *result) error {
func (c contactAuditor) run(ctx context.Context, resChan chan *result) (err error) {
c.logger.Infof("Beginning database query")
rows, err := c.beginAuditQuery(ctx)
if err != nil {
return err
}
defer func() {
// Close the row reader when we exit. Use the named error return to combine
// any error from normal execution with any error from closing.
closeErr := rows.Close()
if closeErr != nil && err != nil {
err = fmt.Errorf("%w; also while closing the row reader: %w", err, closeErr)
} else if closeErr != nil {
err = closeErr
}
// If closeErr is nil, then just leaving the existing named return alone
// will do the right thing.
}()

for rows.Next() {
var id int64
Expand All @@ -130,12 +142,12 @@ func (c contactAuditor) run(ctx context.Context, resChan chan *result) error {
resChan <- &result{id, contacts, createdAt}
}
}
// Ensure the query wasn't interrupted before it could complete.
err = rows.Close()

err = rows.Err()
if err != nil {
return err
} else {
c.logger.Info("Query completed successfully")
// It's okay to return here, an abnormal termination automatically calls
// rows.Close(): http://go-database-sql.org/errors.html
return fmt.Errorf("querying db: %w", err)
}

// Only used for testing.
Expand Down
37 changes: 25 additions & 12 deletions db/multi.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,35 +104,48 @@ func (mi *MultiInserter) query() (string, []interface{}) {
// Insert inserts all the collected rows into the database represented by
// `queryer`. If a non-empty returningColumn was provided, then it returns
// the list of values from that column returned by the query.
func (mi *MultiInserter) Insert(ctx context.Context, queryer Queryer) ([]int64, error) {
func (mi *MultiInserter) Insert(ctx context.Context, queryer Queryer) (ids []int64, err error) {
query, queryArgs := mi.query()
rows, err := queryer.QueryContext(ctx, query, queryArgs...)
if err != nil {
return nil, err
}
defer func() {
// Hack: sometimes in unittests we make a mock Queryer that returns a nil
// `*sql.Rows`. A nil `*sql.Rows` is not actually valid— calling `Close()`
// on it will panic— but here we choose to treat it like an empty list,
// and skip calling `Close()` to avoid the panic.
if rows != nil {
// Close the row reader when we exit. Use the named error return to combine
// any error from normal execution with any error from closing.
closeErr := rows.Close()
if closeErr != nil && err != nil {
err = fmt.Errorf("%w; also while closing the row reader: %w", err, closeErr)
} else if closeErr != nil {
err = closeErr
}
// If closeErr is nil, then just leaving the existing named return alone
// will do the right thing.
}
}()

ids := make([]int64, 0, len(mi.values))
ids = make([]int64, 0, len(mi.values))
if mi.returningColumn != "" {
for rows.Next() {
var id int64
err = rows.Scan(&id)
if err != nil {
rows.Close()
return nil, err
}
ids = append(ids, id)
}
}

// Hack: sometimes in unittests we make a mock Queryer that returns a nil
// `*sql.Rows`. A nil `*sql.Rows` is not actually valid— calling `Close()`
// on it will panic— but here we choose to treat it like an empty list,
// and skip calling `Close()` to avoid the panic.
if rows != nil {
err = rows.Close()
if err != nil {
return nil, err
}
err = rows.Err()
if err != nil {
// It's okay to return here, an abnormal termination automatically calls
// rows.Close(): http://go-database-sql.org/errors.html
return nil, fmt.Errorf("querying db: %w", err)
}

return ids, nil
Expand Down
9 changes: 8 additions & 1 deletion test/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ func allTableNamesInDB(ctx context.Context, db CleanUpDB) ([]string, error) {
if err != nil {
return nil, err
}
defer r.Close()
var ts []string
for r.Next() {
tableName := ""
Expand All @@ -122,5 +123,11 @@ func allTableNamesInDB(ctx context.Context, db CleanUpDB) ([]string, error) {
}
ts = append(ts, tableName)
}
return ts, r.Err()
err = r.Err()
if err != nil {
// It's okay to return here, an abnormal termination automatically calls
// rows.Close(): http://go-database-sql.org/errors.html
return nil, err
}
return ts, nil
}

0 comments on commit fc10303

Please sign in to comment.