Skip to content

Commit

Permalink
fixup! feat: sqlconnect library
Browse files Browse the repository at this point in the history
  • Loading branch information
atzoum committed Feb 19, 2024
1 parent 55718a9 commit eccbbff
Show file tree
Hide file tree
Showing 49 changed files with 1,081 additions and 131 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ jobs:
run: |
go install github.com/wadey/gocovmerge@latest
gocovmerge */profile.out > profile.out
- uses: codecov/codecov-action@v3
- uses: codecov/codecov-action@v4
with:
fail_ci_if_error: true
files: ./profile.out
Expand Down
35 changes: 9 additions & 26 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
.PHONY: help default test test-run test-teardown generate lint fmt
.PHONY: help default test test-run generate lint fmt

GO=go
LDFLAGS?=-s -w
Expand All @@ -9,47 +9,30 @@ default: lint
generate: install-tools
$(GO) generate ./...

test: install-tools test-run test-teardown
test: install-tools test-run

test-run: ## Run all unit tests
ifeq ($(filter 1,$(debug) $(RUNNER_DEBUG)),)
$(eval TEST_CMD = SLOW=0 gotestsum --format pkgname-and-test-fails --)
$(eval TEST_OPTIONS = -p=1 -v -failfast -shuffle=on -coverprofile=profile.out -covermode=count -coverpkg=./... -vet=all --timeout=15m)
$(eval TEST_CMD = gotestsum --format pkgname-and-test-fails --)
$(eval TEST_OPTIONS = -p=1 -v -failfast -shuffle=on -coverprofile=profile.out -covermode=atomic -coverpkg=./... -vet=all --timeout=30m)
else
$(eval TEST_CMD = SLOW=0 go test)
$(eval TEST_OPTIONS = -p=1 -v -failfast -shuffle=on -coverprofile=profile.out -covermode=count -coverpkg=./... -vet=all --timeout=15m)
$(eval TEST_OPTIONS = -p=1 -v -failfast -shuffle=on -coverprofile=profile.out -covermode=atomic -coverpkg=./... -vet=all --timeout=30m)
endif
ifdef package
ifdef exclude
$(eval FILES = `go list ./$(package)/... | egrep -iv '$(exclude)'`)
$(TEST_CMD) -count=1 $(TEST_OPTIONS) $(FILES) && touch $(TESTFILE) || true
$(TEST_CMD) -count=1 $(TEST_OPTIONS) $(FILES) && touch $(TESTFILE)
else
$(TEST_CMD) $(TEST_OPTIONS) ./$(package)/... && touch $(TESTFILE) || true
$(TEST_CMD) $(TEST_OPTIONS) ./$(package)/... && touch $(TESTFILE)
endif
else ifdef exclude
$(eval FILES = `go list ./... | egrep -iv '$(exclude)'`)
$(TEST_CMD) -count=1 $(TEST_OPTIONS) $(FILES) && touch $(TESTFILE) || true
$(TEST_CMD) -count=1 $(TEST_OPTIONS) $(FILES) && touch $(TESTFILE)
else
$(TEST_CMD) -count=1 $(TEST_OPTIONS) ./... && touch $(TESTFILE) || true
$(TEST_CMD) -count=1 $(TEST_OPTIONS) ./... && touch $(TESTFILE)
endif

test-teardown:
@if [ -f "$(TESTFILE)" ]; then \
echo "Tests passed, tearing down..." ;\
rm -f $(TESTFILE) ;\
echo "mode: atomic" > coverage.txt ;\
find . -name "profile.out" | while read file; do grep -v 'mode: atomic' $${file} >> coverage.txt; rm -f $${file}; done ;\
else \
rm -f coverage.txt coverage.html ; find . -name "profile.out" | xargs rm -f ;\
echo "Tests failed :-(" ;\
exit 1 ;\
fi

coverage:
go tool cover -html=coverage.txt -o coverage.html

test-with-coverage: test coverage

help: ## Show the available commands
@grep -E '^[0-9a-zA-Z_-]+:.*?## .*$$' ./Makefile | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}'

Expand Down
1 change: 1 addition & 0 deletions sqlconnect/async.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ func QueryAsync[T any](ctx context.Context, db DB, mapper RowMapper[T], query st
s := &async.Sender[ValueOrError[T]]{}
ctx, ch, leave = s.Begin(ctx)
go func() {
defer s.Close()
rows, err := db.QueryContext(ctx, query, params...)
if err != nil {
s.Send(ValueOrError[T]{Err: fmt.Errorf("executing query: %w", err)})
Expand Down
File renamed without changes.
2 changes: 2 additions & 0 deletions sqlconnect/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,3 +90,5 @@ type Dialect interface {
// FormatTableName formats a table name, typically by lower or upper casing it, depending on the database
FormatTableName(name string) string
}

// var ErrNotSupported = errors.New("sqlconnect: operation not supported")
35 changes: 26 additions & 9 deletions sqlconnect/internal/base/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ func NewDB(db *sql.DB, rudderSchema string, opts ...Option) *DB {
d := &DB{
DB: db,
Dialect: dialect{},
columnTypeMapper: func(databaseTypeName string) string {
return databaseTypeName
columnTypeMapper: func(c ColumnType) string {
return c.DatabaseTypeName()
},
jsonRowMapper: func(databaseTypeName string, value any) any {
return value
Expand All @@ -26,15 +26,15 @@ func NewDB(db *sql.DB, rudderSchema string, opts ...Option) *DB {
return "SELECT schema_name FROM information_schema.schemata", "schema_name"
},
SchemaExists: func(schema string) string {
return fmt.Sprintf("SELECT EXISTS (SELECT schema_name FROM information_schema.schemata where schema_name = '%[1]s')", schema)
return fmt.Sprintf("SELECT schema_name FROM information_schema.schemata where schema_name = '%[1]s'", schema)
},
DropSchema: func(schema string) string { return fmt.Sprintf("DROP SCHEMA %[1]s CASCADE", schema) },
CreateTestTable: func(table string) string {
return fmt.Sprintf("CREATE TABLE IF NOT EXISTS %[1]s (c1 INT, c2 VARCHAR(255))", table)
},
ListTables: func(schema string) []lo.Tuple2[string, string] {
return []lo.Tuple2[string, string]{
{A: fmt.Sprintf("SELECT table_name FROM information_schema.tables WHERE table_schema = %[1]s", schema), B: "table_name"},
{A: fmt.Sprintf("SELECT table_name FROM information_schema.tables WHERE table_schema = '%[1]s'", schema), B: "table_name"},
}
},
ListTablesWithPrefix: func(schema, prefix string) []lo.Tuple2[string, string] {
Expand All @@ -43,16 +43,16 @@ func NewDB(db *sql.DB, rudderSchema string, opts ...Option) *DB {
}
},
TableExists: func(schema, table string) string {
return fmt.Sprintf("SELECT table_name FROM information_schema.tables WHERE table_schema='%[1]s' and table_name = '%[1]s'", schema, table)
return fmt.Sprintf("SELECT table_name FROM information_schema.tables WHERE table_schema='%[1]s' and table_name = '%[2]s'", schema, table)
},
ListColumns: func(schema, table string) (string, string, string) {
return fmt.Sprintf("SELECT column_name, data_type FROM information_schema.columns WHERE table_schema = '%[1]s' AND table_name = '%[2]s'", schema, table), "column_name", "data_type"
},
CountTableRows: func(table string) string { return fmt.Sprintf("SELECT COUNT(*) FROM %[1]s", table) },
DropTable: func(table string) string { return fmt.Sprintf("DROP TABLE IF EXISTS %[1]s", table) },
TruncateTable: func(table string) string { return fmt.Sprintf("TRUNCATE TABLE %[1]s", table) },
RenameTable: func(oldName, newName string) string {
return fmt.Sprintf("ALTER TABLE %[1]s RENAME TO %[2]s", oldName, newName)
RenameTable: func(schema, oldName, newName string) string {
return fmt.Sprintf("ALTER TABLE %[1]s.%[2]s RENAME TO %[3]s", schema, oldName, newName)
},
},
}
Expand All @@ -67,11 +67,28 @@ type DB struct {
sqlconnect.Dialect

rudderSchema string
columnTypeMapper func(string) string // map from database type to rudder type
columnTypeMapper func(ColumnType) string // map from database type to rudder type
jsonRowMapper func(databaseTypeName string, value any) any
sqlCommands SQLCommands
}

type ColumnType interface {
DatabaseTypeName() string
DecimalSize() (precision, scale int64, ok bool)
}

type colRefTypeAdapter struct {
sqlconnect.ColumnRef
}

func (c colRefTypeAdapter) DatabaseTypeName() string {
return c.Type
}

func (c colRefTypeAdapter) DecimalSize() (precision, scale int64, ok bool) {
return 0, 0, false
}

// SqlDB returns the underlying *sql.DB
func (db *DB) SqlDB() *sql.DB {
return db.DB
Expand Down Expand Up @@ -103,5 +120,5 @@ type SQLCommands struct {
// Provides the SQL command to truncate a table
TruncateTable func(table string) string
// Provides the SQL command to rename a table
RenameTable func(oldName, newName string) string
RenameTable func(schema, oldName, newName string) string
}
17 changes: 12 additions & 5 deletions sqlconnect/internal/base/dbopts.go
Original file line number Diff line number Diff line change
@@ -1,23 +1,30 @@
package base

import "github.com/rudderlabs/sqlconnect-go/sqlconnect"
import (
"strings"

"github.com/rudderlabs/sqlconnect-go/sqlconnect"
)

type Option func(*DB)

// WithColumnTypeMappings sets the column type mappings for the client
func WithColumnTypeMappings(columnTypeMappings map[string]string) Option {
return func(db *DB) {
db.columnTypeMapper = func(dbType string) string {
if mappedType, ok := columnTypeMappings[dbType]; ok {
db.columnTypeMapper = func(c ColumnType) string {
if mappedType, ok := columnTypeMappings[strings.ToLower(c.DatabaseTypeName())]; ok {
return mappedType
}
if mappedType, ok := columnTypeMappings[strings.ToUpper(c.DatabaseTypeName())]; ok {
return mappedType
}
return dbType
return c.DatabaseTypeName()
}
}
}

// WithColumnTypeMapper sets the column type mapper for the client
func WithColumnTypeMapper(columnTypeMapper func(string) string) Option {
func WithColumnTypeMapper(columnTypeMapper func(ColumnType) string) Option {
return func(db *DB) {
db.columnTypeMapper = columnTypeMapper
}
Expand Down
30 changes: 30 additions & 0 deletions sqlconnect/internal/base/dialect_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package base

import (
"testing"

"github.com/stretchr/testify/require"

"github.com/rudderlabs/sqlconnect-go/sqlconnect"
)

func TestDialect(t *testing.T) {
var d dialect
t.Run("format table", func(t *testing.T) {
formatted := d.FormatTableName("TaBle")
require.Equal(t, "table", formatted, "table name should be lowercased")
})

t.Run("quote identifier", func(t *testing.T) {
quoted := d.QuoteIdentifier("column")
require.Equal(t, `"column"`, quoted, "column name should be quoted with double quotes")
})

t.Run("quote table", func(t *testing.T) {
quoted := d.QuoteTable(sqlconnect.NewRelationRef("table"))
require.Equal(t, `"table"`, quoted, "table name should be quoted with double quotes")

quoted = d.QuoteTable(sqlconnect.NewRelationRef("table", sqlconnect.WithSchema("schema")))
require.Equal(t, `"schema"."table"`, quoted, "schema and table name should be quoted with double quotes")
})
}
6 changes: 5 additions & 1 deletion sqlconnect/internal/base/mapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,12 @@ func (db *DB) JSONRowMapper() sqlconnect.RowMapper[json.RawMessage] {
o := map[string]any{}
for i := range values {
v := values[i].(*NilAny)
var val any
if v != nil {
val = v.Value
}
col := cols[i]
o[col.Name()] = db.jsonRowMapper(col.DatabaseTypeName(), v)
o[col.Name()] = db.jsonRowMapper(col.DatabaseTypeName(), val)
}
b, err := json.Marshal(o)
if err != nil {
Expand Down
16 changes: 12 additions & 4 deletions sqlconnect/internal/base/schemaadmin.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package base
import (
"context"
"fmt"
"strings"

"github.com/samber/lo"

Expand Down Expand Up @@ -36,20 +37,22 @@ func (db *DB) ListSchemas(ctx context.Context) ([]sqlconnect.SchemaRef, error) {
if err != nil {
return nil, fmt.Errorf("getting columns in list schemas: %w", err)
}
cols = lo.Map(cols, func(col string, _ int) string { return strings.ToLower(col) })
var schema sqlconnect.SchemaRef
scanValues := make([]any, len(cols))
if len(cols) == 1 {
scanValues[0] = &schema.Name
} else {
tableNameColIdx := lo.IndexOf(cols, colName)
tableNameColIdx := lo.IndexOf(cols, strings.ToLower(colName))
if tableNameColIdx == -1 {
return nil, fmt.Errorf("column %s not found in result set: %+v", colName, cols)
}
var otherCol NilAny
for i := 0; i < len(cols); i++ {
if i == tableNameColIdx {
scanValues[i] = &schema.Name
} else {
scanValues[i] = new(NilAny)
scanValues[i] = &otherCol
}
}
}
Expand All @@ -68,10 +71,15 @@ func (db *DB) ListSchemas(ctx context.Context) ([]sqlconnect.SchemaRef, error) {

// SchemaExists returns true if the schema exists
func (db *DB) SchemaExists(ctx context.Context, schemaRef sqlconnect.SchemaRef) (bool, error) {
var exists bool
if err := db.QueryRowContext(ctx, db.sqlCommands.SchemaExists(schemaRef.Name)).Scan(&exists); err != nil {
rows, err := db.QueryContext(ctx, db.sqlCommands.SchemaExists(schemaRef.Name))
if err != nil {
return false, fmt.Errorf("querying schema exists: %w", err)
}
defer func() { _ = rows.Close() }()
exists := rows.Next()
if err := rows.Err(); err != nil {
return false, fmt.Errorf("iterating schema exists: %w", err)
}
return exists, nil
}

Expand Down
Loading

0 comments on commit eccbbff

Please sign in to comment.