From ae120dde71a3f076c3ccc0b0831862acea807e3b Mon Sep 17 00:00:00 2001 From: Dmitry Makushin <24922494+dmakushin@users.noreply.github.com> Date: Fri, 29 Nov 2024 16:18:01 +0400 Subject: [PATCH] Add RunInTx method for DB --- go.sum | 2 -- stdlib.go | 28 ++++++++++++++++++++++++++++ 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/go.sum b/go.sum index 5f2a47bb..c49adb66 100644 --- a/go.sum +++ b/go.sum @@ -109,8 +109,6 @@ github.com/spf13/cast v1.5.0 h1:rj3WzYc11XZaIZMPKmwP96zkFEnnAmV8s6XbB2aY32w= github.com/spf13/cast v1.5.0/go.mod h1:SpXXQ5YoyJw6s3/6cMTQuxvgRl3PCJiyaX9p6b155UU= github.com/stephenafamo/fakedb v0.0.0-20221230081958-0b86f816ed97 h1:XItoZNmhOih06TC02jK7l3wlpZ0XT/sPQYutDcGOQjg= github.com/stephenafamo/fakedb v0.0.0-20221230081958-0b86f816ed97/go.mod h1:bM3Vmw1IakoaXocHmMIGgJFYob0vuK+CFWiJHQvz0jQ= -github.com/stephenafamo/scan v0.6.0 h1:N0joyP/wriC9VvP6w9SDxHIuQGatW4c2YW7Z5L4m45s= -github.com/stephenafamo/scan v0.6.0/go.mod h1:FhIUJ8pLNyex36xGFiazDJJ5Xry0UkAi+RkWRrEcRMg= github.com/stephenafamo/scan v0.6.1 h1:nXokGCQwYazMuyvdNAoK0T8Z76FWcpMvDdtengpz6PU= github.com/stephenafamo/scan v0.6.1/go.mod h1:FhIUJ8pLNyex36xGFiazDJJ5Xry0UkAi+RkWRrEcRMg= github.com/stephenafamo/sqlparser v0.0.0-20241111104950-b04fa8a26c9c h1:JFga++XBnZG2xlnvQyHJkeBWZ9G9mGdtgvLeSRbp/BA= diff --git a/stdlib.go b/stdlib.go index 0ad4e1bc..0459e2b5 100644 --- a/stdlib.go +++ b/stdlib.go @@ -4,6 +4,8 @@ import ( "context" "database/sql" "database/sql/driver" + "errors" + "fmt" "github.com/stephenafamo/scan" "github.com/stephenafamo/scan/stdscan" @@ -96,6 +98,32 @@ func (d DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (Tx, error) { return NewTx(tx), nil } +// RunInTx runs the provided function in a transaction. +// If the function returns an error, the transaction is rolled back. +// Otherwise, the transaction is committed. +func (d DB) RunInTx(ctx context.Context, txOptions *sql.TxOptions, fn func(context.Context, Tx) error) error { + tx, err := d.BeginTx(ctx, txOptions) + if err != nil { + return fmt.Errorf("begin transaction: %w", err) + } + + if err := fn(ctx, tx); err != nil { + err = fmt.Errorf("call method in transaction: %w", err) + + if rollbackErr := tx.Rollback(); rollbackErr != nil { + return errors.Join(err, rollbackErr) + } + + return err + } + + if err := tx.Commit(); err != nil { + return fmt.Errorf("commit transaction: %w", err) + } + + return nil +} + // NewTx wraps an [*sql.Tx] and returns a type that implements [Queryer] but still // retains the expected methods used by *sql.Tx // This is useful when an existing *sql.Tx is used in other places in the codebase