Skip to content

Commit

Permalink
use Executor instead of *sql.DB
Browse files Browse the repository at this point in the history
  • Loading branch information
xgfone committed Aug 30, 2024
1 parent ee5e942 commit 1d42cc9
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 10 deletions.
24 changes: 14 additions & 10 deletions db.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,8 @@ func ConnMaxIdleTime(d time.Duration) Config {

// DB is the wrapper of the sql.DB.
type DB struct {
*sql.DB
Dialect
Executor
Interceptor
}

Expand All @@ -109,30 +109,34 @@ func Open(driverName, dataSourceName string, configs ...Config) (*DB, error) {
return nil, err
}

xdb := &DB{Dialect: dialect, DB: db}
if configs == nil {
configs = DefaultConfigs
}
for _, c := range configs {
c(xdb.DB)
c(db)
}

xdb := &DB{Dialect: dialect, Executor: db}
return xdb, nil
}

func getDB(db *DB) *DB {
func getDB(db Executor) Executor {
if db != nil {
return db
}
return DefaultDB
}

// Set resets the current db to other.
func (db *DB) Set(other *DB) {
func (db *DB) Reset(other *DB) {
if other == nil {
*db = DB{}
db.Dialect = nil
db.Executor = nil
db.Interceptor = nil
} else {
*db = *other
db.Dialect = other.Dialect
db.Executor = other.Executor
db.Interceptor = other.Interceptor
}
}

Expand Down Expand Up @@ -174,15 +178,15 @@ func (db *DB) QueryRow(query string, args ...any) *sql.Row {
// ExecContext executes the sql statement.
func (db *DB) ExecContext(ctx context.Context, query string, args ...any) (r sql.Result, err error) {
if query, args, err = db.Intercept(query, args); err == nil {
r, err = db.DB.ExecContext(ctx, query, args...)
r, err = db.Executor.ExecContext(ctx, query, args...)
}
return
}

// QueryContext executes the query sql statement.
func (db *DB) QueryContext(ctx context.Context, query string, args ...any) (rows *sql.Rows, err error) {
if query, args, err = db.Intercept(query, args); err == nil {
rows, err = db.DB.QueryContext(ctx, query, args...)
rows, err = db.Executor.QueryContext(ctx, query, args...)
}
return
}
Expand All @@ -193,5 +197,5 @@ func (db *DB) QueryRowContext(ctx context.Context, query string, args ...any) *s
if err != nil {
panic(err)
}
return db.DB.QueryRowContext(ctx, query, args...)
return db.Executor.QueryRowContext(ctx, query, args...)
}
32 changes: 32 additions & 0 deletions executor.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
// Copyright 2024 xgfone
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package sqlx

import (
"context"
"database/sql"
)

var (
_ Executor = new(DB)
_ Executor = new(sql.DB)
)

// Executor is used to execute the sql statement.
type Executor interface {
ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row
}

0 comments on commit 1d42cc9

Please sign in to comment.