Skip to content

Commit

Permalink
Restore postgres dialect (#10)
Browse files Browse the repository at this point in the history
In #1 we removed the Postgres dialect for simplicity. Per #7396 we'd
like to be able to readily experiment with Postgres, so this restores
it.
  • Loading branch information
jsha authored Mar 29, 2024
1 parent 52168b2 commit 02fd711
Show file tree
Hide file tree
Showing 12 changed files with 450 additions and 10 deletions.
12 changes: 12 additions & 0 deletions .github/workflows/go.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,18 @@ jobs:
runs-on: ubuntu-latest
container: golang:${{ matrix.containerGoVer }}
services:
postgres:
image: postgres
env:
POSTGRES_DB: gorptest
POSTGRES_USER: gorptest
POSTGRES_PASSWORD: gorptest
options: >-
--health-cmd pg_isready
--health-interval 10s
--health-timeout 5s
--health-retries 10
mysql:
image: mariadb:10.5
env:
Expand Down
45 changes: 42 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@ func main() {

// fetch one row - note use of "post_id" instead of "Id" since column is aliased
//
// Postgres users should use $1 instead of ? placeholders
// See 'Known Issues' below

err = dbmap.SelectOne(&p2, "select * from posts where post_id=?", p2.Id)
checkErr(err, "SelectOne failed")
log.Println("p2 row:", p2)
Expand Down Expand Up @@ -368,7 +371,7 @@ if reflect.DeepEqual(list[0], expected) {
Borp provides a few convenience methods for selecting a single string or int64.

```go
// select single int64 from db
// select single int64 from db (use $1 instead of ? for postgresql)
i64, err := dbmap.SelectInt("select count(*) from foo where blah=?", blahVal)

// select single string from db:
Expand Down Expand Up @@ -579,6 +582,7 @@ interface that should be implemented per database vendor. Dialects
are provided for:

* MySQL
* PostgreSQL
* sqlite3

Each of these three databases pass the test suite. See `borp_test.go`
Expand Down Expand Up @@ -612,6 +616,41 @@ func customDriver() (*sql.DB, error) {

## Known Issues

### SQL placeholder portability

Different databases use different strings to indicate variable
placeholders in prepared SQL statements. Unlike some database
abstraction layers (such as JDBC), Go's `database/sql` does not
standardize this.

SQL generated by borp in the `Insert`, `Update`, `Delete`, and `Get`
methods delegates to a Dialect implementation for each database, and
will generate portable SQL.

Raw SQL strings passed to `Exec`, `Select`, `SelectOne`, `SelectInt`,
etc will not be parsed. Consequently you may have portability issues
if you write a query like this:

```go
// works on MySQL and Sqlite3, but not with Postgresql err :=
dbmap.SelectOne(&val, "select * from foo where id = ?", 30)
```

In `Select` and `SelectOne` you can use named parameters to work
around this. The following is portable:

```go
err := dbmap.SelectOne(&val, "select * from foo where id = :id",
map[string]interface{} { "id": 30})
```

Additionally, when using Postgres as your database, you should utilize
`$1` instead of `?` placeholders as utilizing `?` placeholders when
querying Postgres will result in `pq: operator does not exist`
errors. Alternatively, use `dbMap.Dialect.BindVar(varIdx)` to get the
proper variable binding for your dialect.


### time.Time and time zones

borp will pass `time.Time` fields through to the `database/sql`
Expand All @@ -630,7 +669,7 @@ To avoid any potential issues with timezone/DST, consider:

## Running the tests

The included tests may be run against MySQL or sqlite3.
The included tests may be run against MySQL, Postgres, or sqlite3.
You must set two environment variables so the test code knows which
driver to use, and how to connect to your database.

Expand All @@ -647,7 +686,7 @@ go test -bench="Bench" -benchtime 10
```

Valid `GORP_TEST_DIALECT` values are: "mysql"(for mymysql),
"gomysql"(for go-sql-driver), or "sqlite" See the
"gomysql"(for go-sql-driver), "postgres", or "sqlite" See the
`test_all.sh` script for examples of all 3 databases. This is the
script I run locally to test the library.

Expand Down
5 changes: 5 additions & 0 deletions context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,11 @@ func TestWithCanceledContext(t *testing.T) {
}

switch driver {
case "postgres":
// pq doesn't return standard deadline exceeded error
if err.Error() != "pq: canceling statement due to user request" {
t.Errorf("expected context.DeadlineExceeded, got %v", err)
}
default:
if err != context.DeadlineExceeded {
t.Errorf("expected context.DeadlineExceeded, got %v", err)
Expand Down
3 changes: 3 additions & 0 deletions db.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,9 @@ func (m *DbMap) createIndexImpl(ctx context.Context, dialect reflect.Type,
}
s.WriteString(" index")
s.WriteString(fmt.Sprintf(" %s on %s", index.IndexName, table.TableName))
if dname := dialect.Name(); dname == "PostgresDialect" && index.IndexType != "" {
s.WriteString(fmt.Sprintf(" %s %s", m.Dialect.CreateIndexSuffix(), index.IndexType))
}
s.WriteString(" (")
for x, col := range index.columns {
if x > 0 {
Expand Down
2 changes: 1 addition & 1 deletion dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ type Dialect interface {
TruncateClause() string

// Bind variable string to use when forming SQL statements
// in many dbs it is "?".
// in many dbs it is "?", but Postgres appears to use $1
//
// i is a zero based index of the bind variable in this statement
//
Expand Down
150 changes: 150 additions & 0 deletions dialect_postgres.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
// Copyright 2012 James Cooper. All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.

package borp

import (
"context"
"fmt"
"reflect"
"strings"
"time"
)

type PostgresDialect struct {
suffix string
LowercaseFields bool
}

func (d PostgresDialect) QuerySuffix() string { return ";" }

func (d PostgresDialect) ToSqlType(val reflect.Type, maxsize int, isAutoIncr bool) string {
switch val.Kind() {
case reflect.Ptr:
return d.ToSqlType(val.Elem(), maxsize, isAutoIncr)
case reflect.Bool:
return "boolean"
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32:
if isAutoIncr {
return "serial"
}
return "integer"
case reflect.Int64, reflect.Uint64:
if isAutoIncr {
return "bigserial"
}
return "bigint"
case reflect.Float64:
return "double precision"
case reflect.Float32:
return "real"
case reflect.Slice:
if val.Elem().Kind() == reflect.Uint8 {
return "bytea"
}
}

switch val.Name() {
case "NullInt64":
return "bigint"
case "NullFloat64":
return "double precision"
case "NullBool":
return "boolean"
case "Time", "NullTime":
return "timestamp with time zone"
}

if maxsize > 0 {
return fmt.Sprintf("varchar(%d)", maxsize)
} else {
return "text"
}

}

// Returns empty string
func (d PostgresDialect) AutoIncrStr() string {
return ""
}

func (d PostgresDialect) AutoIncrBindValue() string {
return "default"
}

func (d PostgresDialect) AutoIncrInsertSuffix(col *ColumnMap) string {
return " returning " + d.QuoteField(col.ColumnName)
}

// Returns suffix
func (d PostgresDialect) CreateTableSuffix() string {
return d.suffix
}

func (d PostgresDialect) CreateIndexSuffix() string {
return "using"
}

func (d PostgresDialect) DropIndexSuffix() string {
return ""
}

func (d PostgresDialect) TruncateClause() string {
return "truncate"
}

func (d PostgresDialect) SleepClause(s time.Duration) string {
return fmt.Sprintf("pg_sleep(%f)", s.Seconds())
}

// Returns "$(i+1)"
func (d PostgresDialect) BindVar(i int) string {
return fmt.Sprintf("$%d", i+1)
}

func (d PostgresDialect) InsertAutoIncrToTarget(ctx context.Context, exec SqlExecutor, insertSql string, target interface{}, params ...interface{}) error {
rows, err := exec.QueryContext(ctx, insertSql, params...)
if err != nil {
return err
}
defer rows.Close()

if !rows.Next() {
return fmt.Errorf("No serial value returned for insert: %s Encountered error: %s", insertSql, rows.Err())
}
if err := rows.Scan(target); err != nil {
return err
}
if rows.Next() {
return fmt.Errorf("more than two serial value returned for insert: %s", insertSql)
}
return rows.Err()
}

func (d PostgresDialect) QuoteField(f string) string {
if d.LowercaseFields {
return `"` + strings.ToLower(f) + `"`
}
return `"` + f + `"`
}

func (d PostgresDialect) QuotedTableForQuery(schema string, table string) string {
if strings.TrimSpace(schema) == "" {
return d.QuoteField(table)
}

return schema + "." + d.QuoteField(table)
}

func (d PostgresDialect) IfSchemaNotExists(command, schema string) string {
return fmt.Sprintf("%s if not exists", command)
}

func (d PostgresDialect) IfTableExists(command, schema, table string) string {
return fmt.Sprintf("%s if exists", command)
}

func (d PostgresDialect) IfTableNotExists(command, schema, table string) string {
return fmt.Sprintf("%s if not exists", command)
}
Loading

0 comments on commit 02fd711

Please sign in to comment.