From 1e89d6ade6cb5177a658cb9f3e041dd1c8a83965 Mon Sep 17 00:00:00 2001 From: Pavel Rybintsev <2716513+prybintsev@users.noreply.github.com> Date: Tue, 5 Jan 2021 09:29:53 -0800 Subject: [PATCH] propagating context to gorm transaction (#207) --- gorm/transaction.go | 16 ++++++++++++---- gorm/transaction_test.go | 20 ++++++++++++++++++-- 2 files changed, 30 insertions(+), 6 deletions(-) diff --git a/gorm/transaction.go b/gorm/transaction.go index faa83ff1..0d0e35af 100644 --- a/gorm/transaction.go +++ b/gorm/transaction.go @@ -69,7 +69,7 @@ func BeginFromContext(ctx context.Context) (*gorm.DB, error) { if txn.parent == nil { return nil, ErrCtxTxnNoDB } - db := txn.Begin() + db := txn.beginWithContext(ctx) if db.Error != nil { return nil, db.Error } @@ -89,7 +89,7 @@ func BeginWithOptionsFromContext(ctx context.Context, opts *sql.TxOptions) (*gor if txn.parent == nil { return nil, ErrCtxTxnNoDB } - db := txn.BeginWithOptions(opts) + db := txn.beginWithContextAndOptions(ctx, opts) if db.Error != nil { return nil, db.Error } @@ -99,11 +99,15 @@ func BeginWithOptionsFromContext(ctx context.Context, opts *sql.TxOptions) (*gor // Begin starts new transaction by calling `*gorm.DB.Begin()` // Returns new instance of `*gorm.DB` (error can be checked by `*gorm.DB.Error`) func (t *Transaction) Begin() *gorm.DB { + return t.beginWithContext(context.Background()) +} + +func (t *Transaction) beginWithContext(ctx context.Context) *gorm.DB { t.mu.Lock() defer t.mu.Unlock() if t.current == nil { - t.current = t.parent.Begin() + t.current = t.parent.BeginTx(ctx, nil) } return t.current @@ -112,11 +116,15 @@ func (t *Transaction) Begin() *gorm.DB { // BeginWithOptions starts new transaction by calling `*gorm.DB.BeginTx()` // Returns new instance of `*gorm.DB` (error can be checked by `*gorm.DB.Error`) func (t *Transaction) BeginWithOptions(opts *sql.TxOptions) *gorm.DB { + return t.beginWithContextAndOptions(context.Background(), opts) +} + +func (t *Transaction) beginWithContextAndOptions(ctx context.Context, opts *sql.TxOptions) *gorm.DB { t.mu.Lock() defer t.mu.Unlock() if t.current == nil { - t.current = t.parent.BeginTx(context.Background(), opts) + t.current = t.parent.BeginTx(ctx, opts) } return t.current diff --git a/gorm/transaction_test.go b/gorm/transaction_test.go index 8625a92d..e5461b02 100644 --- a/gorm/transaction_test.go +++ b/gorm/transaction_test.go @@ -367,8 +367,9 @@ func TestBeginFromContext_Good(t *testing.T) { func TestBeginFromContext_Bad(t *testing.T) { tests := []struct { - desc string - withOpts bool + desc string + withOpts bool + contextCanceled bool }{ { desc: "begin without options", @@ -378,11 +379,26 @@ func TestBeginFromContext_Bad(t *testing.T) { desc: "begin with options", withOpts: true, }, + { + desc: "canceled context without context", + withOpts: true, + contextCanceled: true, + }, + { + desc: "canceled context with options", + withOpts: false, + contextCanceled: true, + }, } for _, test := range tests { test := test t.Run(test.desc, func(t *testing.T) { ctx := context.Background() + if test.contextCanceled { + var cancel context.CancelFunc + ctx, cancel = context.WithCancel(ctx) + cancel() + } // Case: Transaction missing from context txn1, err := beginFromContext(ctx, test.withOpts)