diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 55e09eb..470a1bb 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -21,6 +21,18 @@ jobs: runs-on: ubuntu-latest container: golang:${{ matrix.containerGoVer }} services: + postgres: + image: postgres + env: + POSTGRES_DB: gorptest + POSTGRES_USER: gorptest + POSTGRES_PASSWORD: gorptest + options: >- + --health-cmd pg_isready + --health-interval 10s + --health-timeout 5s + --health-retries 10 + mysql: image: mariadb:10.5 env: diff --git a/README.md b/README.md index f140696..34115fd 100644 --- a/README.md +++ b/README.md @@ -55,6 +55,9 @@ func main() { // fetch one row - note use of "post_id" instead of "Id" since column is aliased // + // Postgres users should use $1 instead of ? placeholders + // See 'Known Issues' below + err = dbmap.SelectOne(&p2, "select * from posts where post_id=?", p2.Id) checkErr(err, "SelectOne failed") log.Println("p2 row:", p2) @@ -368,7 +371,7 @@ if reflect.DeepEqual(list[0], expected) { Borp provides a few convenience methods for selecting a single string or int64. ```go -// select single int64 from db +// select single int64 from db (use $1 instead of ? for postgresql) i64, err := dbmap.SelectInt("select count(*) from foo where blah=?", blahVal) // select single string from db: @@ -579,6 +582,7 @@ interface that should be implemented per database vendor. Dialects are provided for: * MySQL +* PostgreSQL * sqlite3 Each of these three databases pass the test suite. See `borp_test.go` @@ -612,6 +616,41 @@ func customDriver() (*sql.DB, error) { ## Known Issues +### SQL placeholder portability + +Different databases use different strings to indicate variable +placeholders in prepared SQL statements. Unlike some database +abstraction layers (such as JDBC), Go's `database/sql` does not +standardize this. + +SQL generated by borp in the `Insert`, `Update`, `Delete`, and `Get` +methods delegates to a Dialect implementation for each database, and +will generate portable SQL. + +Raw SQL strings passed to `Exec`, `Select`, `SelectOne`, `SelectInt`, +etc will not be parsed. Consequently you may have portability issues +if you write a query like this: + +```go +// works on MySQL and Sqlite3, but not with Postgresql err := +dbmap.SelectOne(&val, "select * from foo where id = ?", 30) +``` + +In `Select` and `SelectOne` you can use named parameters to work +around this. The following is portable: + +```go +err := dbmap.SelectOne(&val, "select * from foo where id = :id", +map[string]interface{} { "id": 30}) +``` + +Additionally, when using Postgres as your database, you should utilize +`$1` instead of `?` placeholders as utilizing `?` placeholders when +querying Postgres will result in `pq: operator does not exist` +errors. Alternatively, use `dbMap.Dialect.BindVar(varIdx)` to get the +proper variable binding for your dialect. + + ### time.Time and time zones borp will pass `time.Time` fields through to the `database/sql` @@ -630,7 +669,7 @@ To avoid any potential issues with timezone/DST, consider: ## Running the tests -The included tests may be run against MySQL or sqlite3. +The included tests may be run against MySQL, Postgres, or sqlite3. You must set two environment variables so the test code knows which driver to use, and how to connect to your database. @@ -647,7 +686,7 @@ go test -bench="Bench" -benchtime 10 ``` Valid `GORP_TEST_DIALECT` values are: "mysql"(for mymysql), -"gomysql"(for go-sql-driver), or "sqlite" See the +"gomysql"(for go-sql-driver), "postgres", or "sqlite" See the `test_all.sh` script for examples of all 3 databases. This is the script I run locally to test the library. diff --git a/context_test.go b/context_test.go index d39a52e..a044651 100644 --- a/context_test.go +++ b/context_test.go @@ -62,6 +62,11 @@ func TestWithCanceledContext(t *testing.T) { } switch driver { + case "postgres": + // pq doesn't return standard deadline exceeded error + if err.Error() != "pq: canceling statement due to user request" { + t.Errorf("expected context.DeadlineExceeded, got %v", err) + } default: if err != context.DeadlineExceeded { t.Errorf("expected context.DeadlineExceeded, got %v", err) diff --git a/db.go b/db.go index 5399414..d6e50a8 100644 --- a/db.go +++ b/db.go @@ -102,6 +102,9 @@ func (m *DbMap) createIndexImpl(ctx context.Context, dialect reflect.Type, } s.WriteString(" index") s.WriteString(fmt.Sprintf(" %s on %s", index.IndexName, table.TableName)) + if dname := dialect.Name(); dname == "PostgresDialect" && index.IndexType != "" { + s.WriteString(fmt.Sprintf(" %s %s", m.Dialect.CreateIndexSuffix(), index.IndexType)) + } s.WriteString(" (") for x, col := range index.columns { if x > 0 { diff --git a/dialect.go b/dialect.go index 2d48e06..a2591ab 100644 --- a/dialect.go +++ b/dialect.go @@ -45,7 +45,7 @@ type Dialect interface { TruncateClause() string // Bind variable string to use when forming SQL statements - // in many dbs it is "?". + // in many dbs it is "?", but Postgres appears to use $1 // // i is a zero based index of the bind variable in this statement // diff --git a/dialect_postgres.go b/dialect_postgres.go new file mode 100644 index 0000000..937f81e --- /dev/null +++ b/dialect_postgres.go @@ -0,0 +1,150 @@ +// Copyright 2012 James Cooper. All rights reserved. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. + +package borp + +import ( + "context" + "fmt" + "reflect" + "strings" + "time" +) + +type PostgresDialect struct { + suffix string + LowercaseFields bool +} + +func (d PostgresDialect) QuerySuffix() string { return ";" } + +func (d PostgresDialect) ToSqlType(val reflect.Type, maxsize int, isAutoIncr bool) string { + switch val.Kind() { + case reflect.Ptr: + return d.ToSqlType(val.Elem(), maxsize, isAutoIncr) + case reflect.Bool: + return "boolean" + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32: + if isAutoIncr { + return "serial" + } + return "integer" + case reflect.Int64, reflect.Uint64: + if isAutoIncr { + return "bigserial" + } + return "bigint" + case reflect.Float64: + return "double precision" + case reflect.Float32: + return "real" + case reflect.Slice: + if val.Elem().Kind() == reflect.Uint8 { + return "bytea" + } + } + + switch val.Name() { + case "NullInt64": + return "bigint" + case "NullFloat64": + return "double precision" + case "NullBool": + return "boolean" + case "Time", "NullTime": + return "timestamp with time zone" + } + + if maxsize > 0 { + return fmt.Sprintf("varchar(%d)", maxsize) + } else { + return "text" + } + +} + +// Returns empty string +func (d PostgresDialect) AutoIncrStr() string { + return "" +} + +func (d PostgresDialect) AutoIncrBindValue() string { + return "default" +} + +func (d PostgresDialect) AutoIncrInsertSuffix(col *ColumnMap) string { + return " returning " + d.QuoteField(col.ColumnName) +} + +// Returns suffix +func (d PostgresDialect) CreateTableSuffix() string { + return d.suffix +} + +func (d PostgresDialect) CreateIndexSuffix() string { + return "using" +} + +func (d PostgresDialect) DropIndexSuffix() string { + return "" +} + +func (d PostgresDialect) TruncateClause() string { + return "truncate" +} + +func (d PostgresDialect) SleepClause(s time.Duration) string { + return fmt.Sprintf("pg_sleep(%f)", s.Seconds()) +} + +// Returns "$(i+1)" +func (d PostgresDialect) BindVar(i int) string { + return fmt.Sprintf("$%d", i+1) +} + +func (d PostgresDialect) InsertAutoIncrToTarget(ctx context.Context, exec SqlExecutor, insertSql string, target interface{}, params ...interface{}) error { + rows, err := exec.QueryContext(ctx, insertSql, params...) + if err != nil { + return err + } + defer rows.Close() + + if !rows.Next() { + return fmt.Errorf("No serial value returned for insert: %s Encountered error: %s", insertSql, rows.Err()) + } + if err := rows.Scan(target); err != nil { + return err + } + if rows.Next() { + return fmt.Errorf("more than two serial value returned for insert: %s", insertSql) + } + return rows.Err() +} + +func (d PostgresDialect) QuoteField(f string) string { + if d.LowercaseFields { + return `"` + strings.ToLower(f) + `"` + } + return `"` + f + `"` +} + +func (d PostgresDialect) QuotedTableForQuery(schema string, table string) string { + if strings.TrimSpace(schema) == "" { + return d.QuoteField(table) + } + + return schema + "." + d.QuoteField(table) +} + +func (d PostgresDialect) IfSchemaNotExists(command, schema string) string { + return fmt.Sprintf("%s if not exists", command) +} + +func (d PostgresDialect) IfTableExists(command, schema, table string) string { + return fmt.Sprintf("%s if exists", command) +} + +func (d PostgresDialect) IfTableNotExists(command, schema, table string) string { + return fmt.Sprintf("%s if not exists", command) +} diff --git a/dialect_postgres_test.go b/dialect_postgres_test.go new file mode 100644 index 0000000..4a2f674 --- /dev/null +++ b/dialect_postgres_test.go @@ -0,0 +1,158 @@ +// Copyright 2012 James Cooper. All rights reserved. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. + +//go:build !integration +// +build !integration + +package borp_test + +import ( + "database/sql" + "reflect" + "testing" + "time" + + "github.com/letsencrypt/borp" + "github.com/poy/onpar" + "github.com/poy/onpar/expect" + "github.com/poy/onpar/matchers" +) + +type postgresTestContext struct { + expect expect.Expectation + dialect borp.PostgresDialect +} + +func TestPostgresDialect(t *testing.T) { + o := onpar.BeforeEach(onpar.New(t), func(t *testing.T) postgresTestContext { + return postgresTestContext{ + expect.New(t), + borp.PostgresDialect{ + LowercaseFields: false, + }, + } + }) + + defer o.Run() + + o.Group("ToSqlType", func() { + tests := []struct { + name string + value interface{} + maxSize int + autoIncr bool + expected string + }{ + {"bool", true, 0, false, "boolean"}, + {"int8", int8(1), 0, false, "integer"}, + {"uint8", uint8(1), 0, false, "integer"}, + {"int16", int16(1), 0, false, "integer"}, + {"uint16", uint16(1), 0, false, "integer"}, + {"int32", int32(1), 0, false, "integer"}, + {"int (treated as int32)", int(1), 0, false, "integer"}, + {"uint32", uint32(1), 0, false, "integer"}, + {"uint (treated as uint32)", uint(1), 0, false, "integer"}, + {"int64", int64(1), 0, false, "bigint"}, + {"uint64", uint64(1), 0, false, "bigint"}, + {"float32", float32(1), 0, false, "real"}, + {"float64", float64(1), 0, false, "double precision"}, + {"[]uint8", []uint8{1}, 0, false, "bytea"}, + {"NullInt64", sql.NullInt64{}, 0, false, "bigint"}, + {"NullFloat64", sql.NullFloat64{}, 0, false, "double precision"}, + {"NullBool", sql.NullBool{}, 0, false, "boolean"}, + {"Time", time.Time{}, 0, false, "timestamp with time zone"}, + {"default-size string", "", 0, false, "text"}, + {"sized string", "", 50, false, "varchar(50)"}, + {"large string", "", 1024, false, "varchar(1024)"}, + } + for _, t := range tests { + o.Spec(t.name, func(tcx postgresTestContext) { + typ := reflect.TypeOf(t.value) + sqlType := tcx.dialect.ToSqlType(typ, t.maxSize, t.autoIncr) + tcx.expect(sqlType).To(matchers.Equal(t.expected)) + }) + } + }) + + o.Spec("AutoIncrStr", func(tcx postgresTestContext) { + tcx.expect(tcx.dialect.AutoIncrStr()).To(matchers.Equal("")) + }) + + o.Spec("AutoIncrBindValue", func(tcx postgresTestContext) { + tcx.expect(tcx.dialect.AutoIncrBindValue()).To(matchers.Equal("default")) + }) + + o.Spec("AutoIncrInsertSuffix", func(tcx postgresTestContext) { + cm := borp.ColumnMap{ + ColumnName: "foo", + } + tcx.expect(tcx.dialect.AutoIncrInsertSuffix(&cm)).To(matchers.Equal(` returning "foo"`)) + }) + + o.Spec("CreateTableSuffix", func(tcx postgresTestContext) { + tcx.expect(tcx.dialect.CreateTableSuffix()).To(matchers.Equal("")) + }) + + o.Spec("CreateIndexSuffix", func(tcx postgresTestContext) { + tcx.expect(tcx.dialect.CreateIndexSuffix()).To(matchers.Equal("using")) + }) + + o.Spec("DropIndexSuffix", func(tcx postgresTestContext) { + tcx.expect(tcx.dialect.DropIndexSuffix()).To(matchers.Equal("")) + }) + + o.Spec("TruncateClause", func(tcx postgresTestContext) { + tcx.expect(tcx.dialect.TruncateClause()).To(matchers.Equal("truncate")) + }) + + o.Spec("SleepClause", func(tcx postgresTestContext) { + tcx.expect(tcx.dialect.SleepClause(1 * time.Second)).To(matchers.Equal("pg_sleep(1.000000)")) + tcx.expect(tcx.dialect.SleepClause(100 * time.Millisecond)).To(matchers.Equal("pg_sleep(0.100000)")) + }) + + o.Spec("BindVar", func(tcx postgresTestContext) { + tcx.expect(tcx.dialect.BindVar(0)).To(matchers.Equal("$1")) + tcx.expect(tcx.dialect.BindVar(4)).To(matchers.Equal("$5")) + }) + + o.Group("QuoteField", func() { + o.Spec("By default, case is preserved", func(tcx postgresTestContext) { + tcx.expect(tcx.dialect.QuoteField("Foo")).To(matchers.Equal(`"Foo"`)) + tcx.expect(tcx.dialect.QuoteField("bar")).To(matchers.Equal(`"bar"`)) + }) + + o.Group("With LowercaseFields set to true", func() { + o := onpar.BeforeEach(o, func(tcx postgresTestContext) postgresTestContext { + tcx.dialect.LowercaseFields = true + return postgresTestContext{tcx.expect, tcx.dialect} + }) + + o.Spec("fields are lowercased", func(tcx postgresTestContext) { + tcx.expect(tcx.dialect.QuoteField("Foo")).To(matchers.Equal(`"foo"`)) + }) + }) + }) + + o.Group("QuotedTableForQuery", func() { + o.Spec("using the default schema", func(tcx postgresTestContext) { + tcx.expect(tcx.dialect.QuotedTableForQuery("", "foo")).To(matchers.Equal(`"foo"`)) + }) + + o.Spec("with a supplied schema", func(tcx postgresTestContext) { + tcx.expect(tcx.dialect.QuotedTableForQuery("foo", "bar")).To(matchers.Equal(`foo."bar"`)) + }) + }) + + o.Spec("IfSchemaNotExists", func(tcx postgresTestContext) { + tcx.expect(tcx.dialect.IfSchemaNotExists("foo", "bar")).To(matchers.Equal("foo if not exists")) + }) + + o.Spec("IfTableExists", func(tcx postgresTestContext) { + tcx.expect(tcx.dialect.IfTableExists("foo", "bar", "baz")).To(matchers.Equal("foo if exists")) + }) + + o.Spec("IfTableNotExists", func(tcx postgresTestContext) { + tcx.expect(tcx.dialect.IfTableNotExists("foo", "bar", "baz")).To(matchers.Equal("foo if not exists")) + }) +} diff --git a/go.mod b/go.mod index 930002d..f2a2c0b 100644 --- a/go.mod +++ b/go.mod @@ -11,6 +11,7 @@ require ( require ( github.com/davecgh/go-spew v1.1.1 // indirect + github.com/lib/pq v1.10.9 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index fb6c220..d5d5763 100644 --- a/go.sum +++ b/go.sum @@ -3,6 +3,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/go-sql-driver/mysql v1.7.1 h1:lUIinVbN1DY0xBg0eMOzmmtGoHwWBbvnWubQUrtU8EI= github.com/go-sql-driver/mysql v1.7.1/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= +github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= +github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM= github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= diff --git a/gorp.go b/gorp.go index a78caa7..c5c32d8 100644 --- a/gorp.go +++ b/gorp.go @@ -281,8 +281,8 @@ func fieldByName(val reflect.Value, fieldName string) *reflect.Value { return &f } - // try to find by case insensitive match in the case where columns are - // aliased in the sql + // try to find by case insensitive match - only the Postgres driver + // seems to require this fieldNameL := strings.ToLower(fieldName) fieldCount := val.NumField() t := val.Type() diff --git a/gorp_test.go b/gorp_test.go index 9d67632..668ab51 100644 --- a/gorp_test.go +++ b/gorp_test.go @@ -27,6 +27,7 @@ import ( _ "github.com/go-sql-driver/mysql" "github.com/letsencrypt/borp" + _ "github.com/lib/pq" _ "github.com/mattn/go-sqlite3" ) @@ -34,6 +35,7 @@ var ( // verify interface compliance _ = []borp.Dialect{ borp.SqliteDialect{}, + borp.PostgresDialect{}, borp.MySQLDialect{}, } @@ -1426,6 +1428,9 @@ func TestTransaction(t *testing.T) { } func TestTransactionExecNamed(t *testing.T) { + if os.Getenv("GORP_TEST_DIALECT") == "postgres" { + return + } dbmap := initDBMap(t) defer dropAndClose(dbmap) trans, err := dbmap.BeginTx(context.Background()) @@ -1480,6 +1485,56 @@ func TestTransactionExecNamed(t *testing.T) { } } +func TestTransactionExecNamedPostgres(t *testing.T) { + if os.Getenv("GORP_TEST_DIALECT") != "postgres" { + return + } + ctx := context.Background() + dbmap := initDBMap(t) + defer dropAndClose(dbmap) + trans, err := dbmap.BeginTx(ctx) + if err != nil { + panic(err) + } + // exec should support named params + args := map[string]interface{}{ + "created": 100, + "updated": 200, + "memo": "zzTest", + "personID": 0, + "isPaid": false, + } + _, err = trans.ExecContext(ctx, `INSERT INTO invoice_test ("Created", "Updated", "Memo", "PersonId", "IsPaid") Values(:created, :updated, :memo, :personID, :isPaid)`, args) + if err != nil { + panic(err) + } + var checkMemo = func(want string) { + args := map[string]interface{}{ + "memo": want, + } + memo, err := trans.SelectStr(ctx, `select "Memo" from invoice_test where "Memo" = :memo`, args) + if err != nil { + panic(err) + } + if memo != want { + t.Errorf("%q != %q", want, memo) + } + } + checkMemo("zzTest") + + // exec should still work with ? params + _, err = trans.ExecContext(ctx, `INSERT INTO invoice_test ("Created", "Updated", "Memo", "PersonId", "IsPaid") Values($1, $2, $3, $4, $5)`, 10, 15, "yyTest", 0, true) + + if err != nil { + panic(err) + } + checkMemo("yyTest") + err = trans.Commit() + if err != nil { + panic(err) + } +} + func TestSavepoint(t *testing.T) { dbmap := initDBMap(t) defer dropAndClose(dbmap) @@ -2480,10 +2535,18 @@ func BenchmarkNativeCrud(b *testing.B) { columnPersonId := columnName(dbmap, Invoice{}, "PersonId") b.StartTimer() - insert := "insert into invoice_test (" + columnCreated + ", " + columnUpdated + ", " + columnMemo + ", " + columnPersonId + ") values (?, ?, ?, ?)" - sel := "select " + columnId + ", " + columnCreated + ", " + columnUpdated + ", " + columnMemo + ", " + columnPersonId + " from invoice_test where " + columnId + "=?" - update := "update invoice_test set " + columnCreated + "=?, " + columnUpdated + "=?, " + columnMemo + "=?, " + columnPersonId + "=? where " + columnId + "=?" - delete := "delete from invoice_test where " + columnId + "=?" + var insert, sel, update, delete string + if os.Getenv("GORP_TEST_DIALECT") != "postgres" { + insert = "insert into invoice_test (" + columnCreated + ", " + columnUpdated + ", " + columnMemo + ", " + columnPersonId + ") values (?, ?, ?, ?)" + sel = "select " + columnId + ", " + columnCreated + ", " + columnUpdated + ", " + columnMemo + ", " + columnPersonId + " from invoice_test where " + columnId + "=?" + update = "update invoice_test set " + columnCreated + "=?, " + columnUpdated + "=?, " + columnMemo + "=?, " + columnPersonId + "=? where " + columnId + "=?" + delete = "delete from invoice_test where " + columnId + "=?" + } else { + insert = "insert into invoice_test (" + columnCreated + ", " + columnUpdated + ", " + columnMemo + ", " + columnPersonId + ") values ($1, $2, $3, $4)" + sel = "select " + columnId + ", " + columnCreated + ", " + columnUpdated + ", " + columnMemo + ", " + columnPersonId + " from invoice_test where " + columnId + "=$1" + update = "update invoice_test set " + columnCreated + "=$1, " + columnUpdated + "=$2, " + columnMemo + "=$3, " + columnPersonId + "=$4 where " + columnId + "=$5" + delete = "delete from invoice_test where " + columnId + "=$1" + } inv := &Invoice{0, 100, 200, "my memo", 0, false} @@ -2674,6 +2737,8 @@ func dialectAndDriver() (borp.Dialect, string) { // seems mostly unmaintained recently. We've dropped it from tests, at least for // now. return borp.MySQLDialect{"InnoDB", "UTF8"}, "mysql" + case "postgres": + return borp.PostgresDialect{}, "postgres" case "sqlite": return borp.SqliteDialect{}, "sqlite3" } diff --git a/test_all.sh b/test_all.sh index c3dd424..91007d6 100755 --- a/test_all.sh +++ b/test_all.sh @@ -6,6 +6,11 @@ echo "Running unit tests" go test -race +echo "Testing against postgres" +export GORP_TEST_DSN="host=postgres user=gorptest password=gorptest dbname=gorptest sslmode=disable" +export GORP_TEST_DIALECT=postgres +go test -tags integration $GOBUILDFLAG $@ . + echo "Testing against sqlite" export GORP_TEST_DSN=/tmp/gorptest.bin export GORP_TEST_DIALECT=sqlite