diff --git a/sqlconnect/db_factory_test.go b/sqlconnect/db_factory_test.go new file mode 100644 index 0000000..3148514 --- /dev/null +++ b/sqlconnect/db_factory_test.go @@ -0,0 +1,14 @@ +package sqlconnect_test + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/rudderlabs/sqlconnect-go/sqlconnect" +) + +func TestNewDB(t *testing.T) { + _, err := sqlconnect.NewDB("invalid", []byte{}) + require.Error(t, err, "should return error for invalid db name") +} diff --git a/sqlconnect/internal/bigquery/config.go b/sqlconnect/internal/bigquery/config.go index b665038..061aa96 100644 --- a/sqlconnect/internal/bigquery/config.go +++ b/sqlconnect/internal/bigquery/config.go @@ -2,7 +2,6 @@ package bigquery import ( "encoding/json" - "fmt" ) type Config struct { @@ -14,10 +13,6 @@ type Config struct { UseLegacyMappings bool `json:"useLegacyMappings"` } -func (c Config) ConnectionString() (string, error) { - return "", fmt.Errorf("not implemented") -} - func ParseConfig(configJSON json.RawMessage) (config Config, err error) { err = json.Unmarshal(configJSON, &config) return diff --git a/sqlconnect/internal/bigquery/db.go b/sqlconnect/internal/bigquery/db.go index 7daa44a..23ff60d 100644 --- a/sqlconnect/internal/bigquery/db.go +++ b/sqlconnect/internal/bigquery/db.go @@ -73,6 +73,7 @@ type DB struct { *base.DB } +// WithBigqueryClient runs the provided function by providing access to a native bigquery client, the underlying client that is used by the bigquery driver func (db *DB) WithBigqueryClient(ctx context.Context, f func(*bigquery.Client) error) error { sqlconn, err := db.Conn(ctx) if err != nil { diff --git a/sqlconnect/internal/bigquery/driver/connection.go b/sqlconnect/internal/bigquery/driver/connection.go index 83ca3de..0340135 100644 --- a/sqlconnect/internal/bigquery/driver/connection.go +++ b/sqlconnect/internal/bigquery/driver/connection.go @@ -68,10 +68,10 @@ func (connection *bigQueryConnection) ExecContext(ctx context.Context, query str return statement.ExecContext(ctx, args) } -func (connection *bigQueryConnection) Exec(query string, args []driver.Value) (driver.Result, error) { - statement := &bigQueryStatement{connection, query} - return statement.Exec(args) -} +// func (connection *bigQueryConnection) Exec(query string, args []driver.Value) (driver.Result, error) { +// statement := &bigQueryStatement{connection, query} +// return statement.Exec(args) +// } func (bigQueryConnection) CheckNamedValue(*driver.NamedValue) error { // TODO: Revise in the future diff --git a/sqlconnect/internal/bigquery/driver/driver.go b/sqlconnect/internal/bigquery/driver/driver.go index d2aedb3..2c31b83 100644 --- a/sqlconnect/internal/bigquery/driver/driver.go +++ b/sqlconnect/internal/bigquery/driver/driver.go @@ -23,10 +23,6 @@ type bigQueryConfig struct { } func (b bigQueryDriver) Open(uri string) (driver.Conn, error) { - if uri == "scanner" { - return &scannerConnection{}, nil - } - config, err := configFromUri(uri) if err != nil { return nil, err diff --git a/sqlconnect/internal/bigquery/driver/driver_test.go b/sqlconnect/internal/bigquery/driver/driver_test.go new file mode 100644 index 0000000..2a1dc58 --- /dev/null +++ b/sqlconnect/internal/bigquery/driver/driver_test.go @@ -0,0 +1,217 @@ +package driver_test + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + "os" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/require" + "google.golang.org/api/option" + + "github.com/rudderlabs/rudder-go-kit/testhelper/rand" + "github.com/rudderlabs/sqlconnect-go/sqlconnect/internal/bigquery/driver" +) + +func TestBigqueryDriver(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + + configJSON, ok := os.LookupEnv("BIGQUERY_TEST_ENVIRONMENT_CREDENTIALS") + if !ok { + t.Skip("skipping bigquery driver test due to lack of a test environment") + } + var c config + require.NoError(t, json.Unmarshal([]byte(configJSON), &c)) + + db := sql.OpenDB(driver.NewConnector(c.ProjectID, option.WithCredentialsJSON([]byte(c.CredentialsJSON)))) + t.Cleanup(func() { + require.NoError(t, db.Close(), "it should be able to close the database connection") + }) + + schema := GenerateTestSchema() + + t.Run("Ping", func(t *testing.T) { + require.NoError(t, db.Ping(), "it should be able to ping the database") + require.NoError(t, db.PingContext(ctx), "it should be able to ping the database using a context") + }) + + t.Run("Transaction unsupported", func(t *testing.T) { + t.Run("Begin", func(t *testing.T) { + _, err := db.Begin() + require.Error(t, err, "it should not be able to begin a transaction") + }) + + t.Run("BeginTx", func(t *testing.T) { + _, err := db.BeginTx(ctx, nil) + require.Error(t, err, "it should not be able to begin a transaction") + }) + }) + t.Run("Exec", func(t *testing.T) { + _, err := db.Exec(fmt.Sprintf("CREATE SCHEMA `%s`", schema)) + require.NoError(t, err, "it should be able to create a schema") + }) + + t.Run("ExecContext", func(t *testing.T) { + _, err := db.ExecContext(ctx, fmt.Sprintf("CREATE TABLE `%s`.`test_table` (C1 INT, C2 ARRAY)", schema)) + require.NoError(t, err, "it should be able to create a table") + }) + + t.Run("prepared statement", func(t *testing.T) { + t.Run("QueryRow", func(t *testing.T) { + stmt, err := db.Prepare(fmt.Sprintf("SELECT COUNT(*) FROM `%s`.`test_table`", schema)) + require.NoError(t, err, "it should be able to prepare a statement") + defer func() { + require.NoError(t, stmt.Close(), "it should be able to close the prepared statement") + }() + + var count int + err = stmt.QueryRow().Scan(&count) + require.NoError(t, err, "it should be able to execute a prepared statement") + }) + + t.Run("Exec", func(t *testing.T) { + stmt, err := db.Prepare(fmt.Sprintf("INSERT INTO `%s`.`test_table` (C1) VALUES (?)", schema)) + require.NoError(t, err, "it should be able to prepare a statement") + defer func() { + require.NoError(t, stmt.Close(), "it should be able to close the prepared statement") + }() + result, err := stmt.Exec(1) + require.NoError(t, err, "it should be able to execute a prepared statement") + + _, err = result.LastInsertId() + require.Error(t, err, "last insert id not supported") + + rowsAffected, err := result.RowsAffected() + require.NoError(t, err, "it should be able to get rows affected") + require.EqualValues(t, 0, rowsAffected, "rows affected should be 0 (not supported)") + }) + + t.Run("Query", func(t *testing.T) { + stmt, err := db.Prepare(fmt.Sprintf("SELECT C1 FROM `%s`.`test_table` WHERE C1 = ?", schema)) + require.NoError(t, err, "it should be able to prepare a statement") + defer func() { + require.NoError(t, stmt.Close(), "it should be able to close the prepared statement") + }() + rows, err := stmt.Query(1) + require.NoError(t, err, "it should be able to execute a prepared statement") + defer func() { + require.NoError(t, rows.Close(), "it should be able to close the rows") + }() + require.True(t, rows.Next(), "it should be able to get a row") + var c1 int + err = rows.Scan(&c1) + require.NoError(t, err, "it should be able to scan the row") + require.EqualValues(t, 1, c1, "it should be able to get the correct value") + require.False(t, rows.Next(), "it shouldn't have next row") + + require.NoError(t, rows.Err()) + }) + + t.Run("Query with named parameters", func(t *testing.T) { + stmt, err := db.PrepareContext(ctx, fmt.Sprintf("SELECT C1, C2 FROM `%s`.`test_table` WHERE C1 = @c1_value", schema)) + require.NoError(t, err, "it should be able to prepare a statement") + defer func() { + require.NoError(t, stmt.Close(), "it should be able to close the prepared statement") + }() + rows, err := stmt.QueryContext(ctx, sql.Named("c1_value", 1)) + require.NoError(t, err, "it should be able to execute a prepared statement") + defer func() { + require.NoError(t, rows.Close(), "it should be able to close the rows") + }() + + cols, err := rows.Columns() + require.NoError(t, err, "it should be able to get the columns") + require.EqualValues(t, []string{"C1", "C2"}, cols, "it should be able to get the correct columns") + + colTypes, err := rows.ColumnTypes() + require.NoError(t, err, "it should be able to get the column types") + require.Len(t, colTypes, 2, "it should be able to get the correct number of column types") + require.EqualValues(t, "INTEGER", colTypes[0].DatabaseTypeName(), "it should be able to get the correct column type") + require.EqualValues(t, "ARRAY", colTypes[1].DatabaseTypeName(), "it should be able to get the correct column type") + + require.True(t, rows.Next(), "it should be able to get a row") + var c1 int + var c2 any + err = rows.Scan(&c1, &c2) + require.NoError(t, err, "it should be able to scan the row") + require.EqualValues(t, 1, c1, "it should be able to get the correct value") + require.Nil(t, c2, "it should be able to get the correct value") + require.False(t, rows.Next(), "it shouldn't have next row") + + require.NoError(t, rows.Err()) + }) + }) + + t.Run("query", func(t *testing.T) { + t.Run("QueryRow", func(t *testing.T) { + var count int + err := db.QueryRow(fmt.Sprintf("SELECT COUNT(*) FROM `%s`.`test_table`", schema)).Scan(&count) + require.NoError(t, err, "it should be able to execute a prepared statement") + require.Equal(t, 1, count, "it should be able to get the correct value") + }) + + t.Run("Exec", func(t *testing.T) { + result, err := db.Exec(fmt.Sprintf("INSERT INTO `%s`.`test_table` (C1) VALUES (?)", schema), 2) + require.NoError(t, err, "it should be able to execute a prepared statement") + rowsAffected, err := result.RowsAffected() + require.NoError(t, err, "it should be able to get rows affected") + require.EqualValues(t, 0, rowsAffected, "rows affected should be 0 (not supported)") + }) + + t.Run("Query", func(t *testing.T) { + rows, err := db.Query(fmt.Sprintf("SELECT C1 FROM `%s`.`test_table` WHERE C1 = ?", schema), 2) + require.NoError(t, err, "it should be able to execute a prepared statement") + defer func() { + require.NoError(t, rows.Close(), "it should be able to close the rows") + }() + require.True(t, rows.Next(), "it should be able to get a row") + var c1 int + err = rows.Scan(&c1) + require.NoError(t, err, "it should be able to scan the row") + require.EqualValues(t, 2, c1, "it should be able to get the correct value") + require.False(t, rows.Next(), "it shouldn't have next row") + + require.NoError(t, rows.Err()) + }) + + t.Run("Query with named parameters", func(t *testing.T) { + rows, err := db.QueryContext(ctx, fmt.Sprintf("SELECT C1 FROM `%s`.`test_table` WHERE C1 = @c1_value", schema), sql.Named("c1_value", 2)) + require.NoError(t, err, "it should be able to execute a prepared statement") + defer func() { + require.NoError(t, rows.Close(), "it should be able to close the rows") + }() + + cols, err := rows.Columns() + require.NoError(t, err, "it should be able to get the columns") + require.EqualValues(t, []string{"C1"}, cols, "it should be able to get the correct columns") + + colTypes, err := rows.ColumnTypes() + require.NoError(t, err, "it should be able to get the column types") + require.Len(t, colTypes, 1, "it should be able to get the correct number of column types") + require.EqualValues(t, "INTEGER", colTypes[0].DatabaseTypeName(), "it should be able to get the correct column type") + + require.True(t, rows.Next(), "it should be able to get a row") + var c1 int + err = rows.Scan(&c1) + require.NoError(t, err, "it should be able to scan the row") + require.EqualValues(t, 2, c1, "it should be able to get the correct value") + require.False(t, rows.Next(), "it shouldn't have next row") + + require.NoError(t, rows.Err()) + }) + }) +} + +type config struct { + ProjectID string `json:"project"` + CredentialsJSON string `json:"credentials"` +} + +func GenerateTestSchema() string { + return strings.ToLower(fmt.Sprintf("tbqdrv_%s_%d", rand.String(12), time.Now().Unix())) +} diff --git a/sqlconnect/internal/bigquery/driver/scanner.go b/sqlconnect/internal/bigquery/driver/scanner.go deleted file mode 100644 index ae623ce..0000000 --- a/sqlconnect/internal/bigquery/driver/scanner.go +++ /dev/null @@ -1,60 +0,0 @@ -package driver - -import ( - "context" - "database/sql/driver" - "errors" -) - -type scannerConnection struct{} - -func (scannerConnection) Prepare(query string) (driver.Stmt, error) { - return &scannerStatement{}, nil -} - -func (scannerConnection) Close() error { - return nil -} - -func (scannerConnection) Begin() (driver.Tx, error) { - return nil, nil -} - -func (scannerConnection) Ping(ctx context.Context) error { - return nil -} - -func (scannerConnection) CheckNamedValue(*driver.NamedValue) error { - return nil -} - -type scannerStatement struct{} - -func (scannerStatement) CheckNamedValue(*driver.NamedValue) error { - return nil -} - -func (s scannerStatement) Close() error { - return nil -} - -func (s scannerStatement) NumInput() int { - return 1 -} - -func (s scannerStatement) Exec(args []driver.Value) (driver.Result, error) { - return nil, errors.New("execution is not supported") -} - -func (s scannerStatement) Query(args []driver.Value) (driver.Rows, error) { - if len(args) < 1 { - return nil, errors.New("scanner arguments should have an argument with rows") - } - - rows, ok := args[0].(driver.Rows) - if !ok { - return nil, errors.New("scanner arguments should have an argument with rows") - } - - return rows, nil -} diff --git a/sqlconnect/internal/bigquery/driver/statement.go b/sqlconnect/internal/bigquery/driver/statement.go index 433a006..26b0e34 100644 --- a/sqlconnect/internal/bigquery/driver/statement.go +++ b/sqlconnect/internal/bigquery/driver/statement.go @@ -3,11 +3,16 @@ package driver import ( "context" "database/sql/driver" + "regexp" + "strings" "cloud.google.com/go/bigquery" + "github.com/samber/lo" "github.com/sirupsen/logrus" ) +var namedParamsRegexp = regexp.MustCompile(`@[\w]+`) + type bigQueryStatement struct { connection *bigQueryConnection query string @@ -18,7 +23,12 @@ func (statement bigQueryStatement) Close() error { } func (statement bigQueryStatement) NumInput() int { - return 0 + params := strings.Count(statement.query, "?") + if params > 0 { + return params + } + uniqueMatches := lo.Uniq(namedParamsRegexp.FindAllString(statement.query, -1)) + return len(uniqueMatches) } func (bigQueryStatement) CheckNamedValue(*driver.NamedValue) error { @@ -26,14 +36,6 @@ func (bigQueryStatement) CheckNamedValue(*driver.NamedValue) error { } func (statement *bigQueryStatement) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { - // logrus.Debugf("exec:%s", statement.query) - - // if logrus.IsLevelEnabled(logrus.DebugLevel) { - // for _, arg := range args { - // logrus.Debugf("- param:%s", convertParameterToValue(arg)) - // } - // } - query, err := statement.buildQuery(convertParameters(args)) if err != nil { return nil, err @@ -48,14 +50,6 @@ func (statement *bigQueryStatement) ExecContext(ctx context.Context, args []driv } func (statement *bigQueryStatement) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { - // logrus.Debugf("query:%s", statement.query) - - // if logrus.IsLevelEnabled(logrus.DebugLevel) { - // for _, arg := range args { - // logrus.Debugf("- param:%s", convertParameterToValue(arg)) - // } - // } - query, err := statement.buildQuery(convertParameters(args)) if err != nil { return nil, err @@ -72,14 +66,6 @@ func (statement *bigQueryStatement) QueryContext(ctx context.Context, args []dri } func (statement bigQueryStatement) Exec(args []driver.Value) (driver.Result, error) { - // logrus.Debugf("exec:%s", statement.query) - - // if logrus.IsLevelEnabled(logrus.DebugLevel) { - // for _, arg := range args { - // logrus.Debugf("- param:%s", convertParameterToValue(arg)) - // } - // } - query, err := statement.buildQuery(args) if err != nil { return nil, err @@ -94,13 +80,6 @@ func (statement bigQueryStatement) Exec(args []driver.Value) (driver.Result, err } func (statement bigQueryStatement) Query(args []driver.Value) (driver.Rows, error) { - // logrus.Debugf("query:%s", statement.query) - // if logrus.IsLevelEnabled(logrus.DebugLevel) { - // for _, arg := range args { - // logrus.Debugf("- param:%s", convertParameterToValue(arg)) - // } - // } - query, err := statement.buildQuery(args) if err != nil { return nil, err @@ -153,8 +132,6 @@ func buildParameter(arg driver.Value, parameters []bigquery.QueryParameter) []bi } func buildParameterFromNamedValue(namedValue driver.NamedValue, parameters []bigquery.QueryParameter) []bigquery.QueryParameter { - // logrus.Debugf("-param:%s=%s", namedValue.Name, namedValue.Value) - if namedValue.Name == "" { return append(parameters, bigquery.QueryParameter{ Value: namedValue.Value, diff --git a/sqlconnect/internal/integration_test/db_integration_test_scenario.go b/sqlconnect/internal/integration_test/db_integration_test_scenario.go index ca78c9a..7d20fe9 100644 --- a/sqlconnect/internal/integration_test/db_integration_test_scenario.go +++ b/sqlconnect/internal/integration_test/db_integration_test_scenario.go @@ -3,10 +3,12 @@ package integrationtest import ( "context" "encoding/json" + "errors" "fmt" "os" "regexp" "strings" + "sync" "testing" "text/template" "time" @@ -31,6 +33,11 @@ func TestDatabaseScenarios(t *testing.T, warehouse string, configJSON json.RawMe cancelledCtx, cancel := context.WithCancel(context.Background()) cancel() + t.Run("using invalid configuration", func(t *testing.T) { + _, err := sqlconnect.NewDB(warehouse, []byte("invalid")) + require.Error(t, err, "it should return error for invalid configuration") + }) + t.Run("ping", func(t *testing.T) { t.Run("with context cancelled", func(t *testing.T) { err := db.PingContext(cancelledCtx) @@ -349,6 +356,45 @@ func TestDatabaseScenarios(t *testing.T, warehouse string, configJSON json.RawMe require.JSONEq(t, string(expectedRowsJSON), string(actualRowsJSON), "it should return the correct rows: "+string(actualRowsJSON)) }) + + t.Run("async query", func(t *testing.T) { + t.Run("QueryJSONMapAsync without error", func(t *testing.T) { + ch, leave := sqlconnect.QueryJSONMapAsync(ctx, db, selectSQL) + defer leave() + for row := range ch { + require.NoError(t, row.Err, "it should be able to scan a row") + } + }) + + t.Run("QueryJSONMapAsync with context cancelled", func(t *testing.T) { + ch, leave := sqlconnect.QueryJSONMapAsync(cancelledCtx, db, selectSQL) + defer leave() + var iterations int + for row := range ch { + iterations++ + require.Error(t, row.Err) + require.True(t, errors.Is(row.Err, context.Canceled)) + } + require.Equal(t, 1, iterations, "it should only iterate once") + }) + + t.Run("QueryJSONMapAsync with leave", func(t *testing.T) { + ch, leave := sqlconnect.QueryJSONMapAsync(cancelledCtx, db, selectSQL) + leave() + time.Sleep(10 * time.Millisecond) + var wg sync.WaitGroup + var iterations int + wg.Add(1) + go func() { + for range ch { + iterations++ + } + wg.Done() + }() + wg.Wait() + require.Equal(t, 0, iterations, "it shouldn't iterate after leaving the channel") + }) + }) }) } diff --git a/sqlconnect/internal/mysql/config_test.go b/sqlconnect/internal/mysql/config_test.go new file mode 100644 index 0000000..34163a4 --- /dev/null +++ b/sqlconnect/internal/mysql/config_test.go @@ -0,0 +1,50 @@ +package mysql_test + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/rudderlabs/sqlconnect-go/sqlconnect/internal/mysql" +) + +func TestConfig(t *testing.T) { + t.Run("host validation", func(t *testing.T) { + _, err := mysql.ParseConfig([]byte(`{"host": "localhost"}`)) + require.Error(t, err, "should not allow localhost") + + _, err = mysql.ParseConfig([]byte(`{"host": "127.0.0.1"}`)) + require.Error(t, err, "should not allow 127.0.0.1") + + _, err = mysql.ParseConfig([]byte(`{"host": "0.0.0.0"}`)) + require.Error(t, err, "should not allow 0.0.0.0") + }) + + t.Run("tls", func(t *testing.T) { + t.Run("empty ssl mode", func(t *testing.T) { + c := mysql.Config{SSLMode: ""} + tls, err := c.TLS() + require.NoError(t, err, "should allow empty tls") + require.Equal(t, "false", tls, "should return false") + }) + + t.Run("skip-verify ssl mode", func(t *testing.T) { + c := mysql.Config{SSLMode: "skip-verify"} + tls, err := c.TLS() + require.NoError(t, err, "should allow skip-verify tls") + require.Equal(t, "skip-verify", tls, "should return skip-verify") + }) + t.Run("false ssl mode", func(t *testing.T) { + c := mysql.Config{SSLMode: "false"} + tls, err := c.TLS() + require.NoError(t, err, "should allow false tls") + require.Equal(t, "false", tls, "should return false") + }) + + t.Run("other ssl mode", func(t *testing.T) { + c := mysql.Config{SSLMode: "other"} + _, err := c.TLS() + require.Error(t, err, "should not allow other tls") + }) + }) +} diff --git a/sqlconnect/internal/trino/config.go b/sqlconnect/internal/trino/config.go index deace77..be735af 100644 --- a/sqlconnect/internal/trino/config.go +++ b/sqlconnect/internal/trino/config.go @@ -22,7 +22,7 @@ type Config struct { UseLegacyMappings bool `json:"useLegacyMappings"` } -func (c Config) ConnectionString() string { +func (c Config) ConnectionString() (string, error) { uri := func() string { hostport := c.Host if c.Port != 0 { @@ -39,13 +39,11 @@ func (c Config) ConnectionString() string { ServerURI: uri, Catalog: c.Catalog, } - dsn, err := config.FormatDSN() if err != nil { - _ = fmt.Errorf("error formatting dsn %v", err) - return "nil" + return "", fmt.Errorf("formatting dsn: %w", err) } - return dsn + return dsn, nil } func ParseConfig(input json.RawMessage) (config Config, err error) { diff --git a/sqlconnect/internal/trino/db.go b/sqlconnect/internal/trino/db.go index 37adba8..01d7056 100644 --- a/sqlconnect/internal/trino/db.go +++ b/sqlconnect/internal/trino/db.go @@ -24,7 +24,11 @@ func NewDB(configJSON json.RawMessage) (*DB, error) { return nil, err } - db, err := sql.Open(DatabaseType, config.ConnectionString()) + dsn, err := config.ConnectionString() + if err != nil { + return nil, err + } + db, err := sql.Open(DatabaseType, dsn) if err != nil { return nil, err }