diff --git a/pkg/loader/util.go b/pkg/loader/util.go index dea714466..71d226db2 100644 --- a/pkg/loader/util.go +++ b/pkg/loader/util.go @@ -18,10 +18,12 @@ import ( gosql "database/sql" "fmt" "hash/crc32" + "math/rand" "net/url" "strconv" "strings" "sync/atomic" + "time" "github.com/go-sql-driver/mysql" "github.com/pingcap/errors" @@ -146,7 +148,18 @@ func createDBWitSessions(dsn string, params map[string]string) (db *gosql.DB, er // CreateDBWithSQLMode return sql.DB func CreateDBWithSQLMode(user string, password string, host string, port int, tlsConfig *tls.Config, sqlMode *string, params map[string]string) (db *gosql.DB, err error) { - dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/?charset=utf8mb4,utf8&interpolateParams=true&readTimeout=1m&multiStatements=true", user, password, host, port) + hosts := strings.Split(host, ",") + + if len(hosts) < 1 { + return nil, errors.Annotate(err, "You must provide at least one mysql address") + } + + random := rand.New(rand.NewSource(time.Now().UnixNano())) + + index := random.Intn(len(hosts)) + h := hosts[index] + + dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/?charset=utf8mb4,utf8&interpolateParams=true&readTimeout=1m&multiStatements=true", user, password, h, port) if sqlMode != nil { // same as "set sql_mode = ''" dsn += "&sql_mode='" + url.QueryEscape(*sqlMode) + "'" diff --git a/tests/util/db.go b/tests/util/db.go index deca62bf8..41d57ed19 100644 --- a/tests/util/db.go +++ b/tests/util/db.go @@ -18,7 +18,9 @@ import ( "database/sql" "fmt" "log" + "math/rand" "net/url" + "strings" "time" "github.com/pingcap/errors" @@ -47,14 +49,25 @@ func (c *DBConfig) String() string { } // CreateDB create a mysql fd -func CreateDB(cfg DBConfig) (*sql.DB, error) { +func CreateDB(cfg DBConfig) (db *sql.DB, err error) { // just set to the same timezone so the timestamp field of mysql will return the same value // timestamp field will be display as the time zone of the Local time of drainer when write to kafka, so we set it to local time to pass CI now _, offset := time.Now().Zone() zone := fmt.Sprintf("'+%02d:00'", offset/3600) - dbDSN := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8&interpolateParams=true&multiStatements=true&time_zone=%s", cfg.User, cfg.Password, cfg.Host, cfg.Port, cfg.Name, url.QueryEscape(zone)) - db, err := sql.Open("mysql", dbDSN) + hosts := strings.Split(cfg.Host, ",") + + if len(hosts) < 1 { + return nil, errors.Annotate(err, "You must provide at least one mysql address") + } + + random := rand.New(rand.NewSource(time.Now().UnixNano())) + + index := random.Intn(len(hosts)) + h := hosts[index] + + dbDSN := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8&interpolateParams=true&multiStatements=true&time_zone=%s", cfg.User, cfg.Password, h, cfg.Port, cfg.Name, url.QueryEscape(zone)) + db, err = sql.Open("mysql", dbDSN) if err != nil { return nil, errors.Trace(err) }