Skip to content

Commit

Permalink
Add template sets
Browse files Browse the repository at this point in the history
  • Loading branch information
marrow16 committed Aug 19, 2023
1 parent ac34603 commit 545879e
Show file tree
Hide file tree
Showing 5 changed files with 232 additions and 16 deletions.
7 changes: 5 additions & 2 deletions named_template_build.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (
)

func (n *namedTemplate) buildArgs() error {
if err := n.replaceTokens(); err != nil {
if err := n.replaceTokens(true); err != nil {
return err
}
var builder strings.Builder
Expand Down Expand Up @@ -65,7 +65,7 @@ func (n *namedTemplate) buildArgs() error {

var tokenRegexp = regexp.MustCompile(`\{\{([^}]*)}}`)

func (n *namedTemplate) replaceTokens() error {
func (n *namedTemplate) replaceTokens(first bool) error {
errs := make([]string, 0)
n.originalStatement = tokenRegexp.ReplaceAllStringFunc(n.originalStatement, func(s string) string {
token := s[2 : len(s)-2]
Expand All @@ -82,6 +82,9 @@ func (n *namedTemplate) replaceTokens() error {
} else if len(errs) > 0 {
return fmt.Errorf("unknown tokens: %s", strings.Join(errs, ", "))
}
if first && strings.Contains(n.originalStatement, "{{") && strings.Contains(n.originalStatement, "}}") {
return n.replaceTokens(false)
}
return nil
}

Expand Down
41 changes: 27 additions & 14 deletions named_template_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,24 @@ func TestNamedTemplate(t *testing.T) {
statement: `INSERT INTO {{tableName}} ({{cols}}) VALUES(:{{argA}},:{{argB}},:{{argC}})`,
expectStatement: `INSERT INTO foo (col_a,col_b,col_c) VALUES($1,$2,$3)`,
expectOriginal: `INSERT INTO foo (col_a,col_b,col_c) VALUES(:a,:b,:c)`,
options: []any{PostgresOption, &testTokenOption{}},
options: []any{PostgresOption, testTokenOption},
expectArgsCount: 3,
expectArgNamesCount: 3,
expectArgNames: []string{"a", "b", "c"},
inArgs: []any{
map[string]any{
"a": "a value",
"b": "b value",
"c": "c value",
},
},
expectOutArgs: []any{"a value", "b value", "c value"},
},
{
statement: `INSERT INTO {{nested}} ({{cols}}) VALUES(:{{argA}},:{{argB}},:{{argC}})`,
expectStatement: `INSERT INTO foo (col_a,col_b,col_c) VALUES($1,$2,$3)`,
expectOriginal: `INSERT INTO foo (col_a,col_b,col_c) VALUES(:a,:b,:c)`,
options: []any{PostgresOption, testTokenOption},
expectArgsCount: 3,
expectArgNamesCount: 3,
expectArgNames: []string{"a", "b", "c"},
Expand All @@ -421,13 +438,13 @@ func TestNamedTemplate(t *testing.T) {
},
{
statement: `INSERT INTO {{unknownToken}} ({{cols}}) VALUES({{argA}},{{argB}},{{argC}})`,
options: []any{PostgresOption, &testTokenOption{}},
options: []any{PostgresOption, testTokenOption},
expectError: true,
expectErrorMessage: "unknown token: unknownToken",
},
{
statement: `INSERT INTO {{unknown token}} ({{another unknown}}) VALUES({{argA}},{{argB}},{{argC}})`,
options: []any{PostgresOption, &testTokenOption{}},
options: []any{PostgresOption, testTokenOption},
expectError: true,
expectErrorMessage: "unknown tokens: unknown token, another unknown",
},
Expand Down Expand Up @@ -539,17 +556,13 @@ func TestNamedTemplate(t *testing.T) {
}
}

type testTokenOption struct{}

func (t *testTokenOption) Replace(token string) (string, bool) {
ok, r := (map[string]string{
"tableName": "foo",
"cols": "col_a,col_b,col_c",
"argA": "a",
"argB": "b",
"argC": "c",
})[token]
return ok, r
var testTokenOption = TokenOptionMap{
"tableName": "foo",
"cols": "col_a,col_b,col_c",
"argA": "a",
"argB": "b",
"argC": "c",
"nested": "{{tableName}}",
}

type unmarshalable struct{}
Expand Down
7 changes: 7 additions & 0 deletions options.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,13 @@ type TokenOption interface {
Replace(token string) (string, bool)
}

type TokenOptionMap map[string]string

func (m TokenOptionMap) Replace(token string) (string, bool) {
s, ok := m[token]
return s, ok
}

var (
MySqlOption Option = _MySqlOption // option to produce final args like ?, ?, ? (e.g. for https://github.com/go-sql-driver/mysql)
PostgresOption Option = _PostgresOption // option to produce final args like $1, $2, $3 (e.g. for https://github.com/lib/pq or https://github.com/jackc/pgx)
Expand Down
88 changes: 88 additions & 0 deletions template_set.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
package sqlnt

import (
"errors"
"fmt"
"reflect"
"strings"
)

const sqlTag = "sql"

// NewTemplateSet builds a set of templates for the given struct type T
//
// Fields of type sqlnt.NamedTemplate are created and set from the field tag 'sql'
//
// Example:
//
// type MyTemplateSet struct {
// Select sqlnt.NamedTemplate `sql:"SELECT * FROM foo WHERE col_a = :a"`
// Insert sqlnt.NamedTemplate `sql:"INSERT INTO foo (col_a, col_b, col_c) VALUES(:a, :b, :c)"`
// Delete sqlnt.NamedTemplate `sql:"DELETE FROM foo WHERE col_a = :a"`
// }
// set, err := sqlnt.NewTemplateSet[MyTemplateSet]()
//
// Note: If the overall field tag does not contain a 'sql' tag nor any other tags (i.e. there are no double-quotes in it)
// then the entire field tag value is used as the template - enabling the use of carriage returns to format the statement
//
// Example:
//
// type MyTemplateSet struct {
// Select sqlnt.NamedTemplate `SELECT *
// FROM foo
// WHERE col_a = :a`
// }
func NewTemplateSet[T any](options ...any) (*T, error) {
var chk T
if reflect.TypeOf(chk).Kind() != reflect.Struct {
return nil, errors.New("not a struct")
}
r := new(T)
if err := setTemplateFields(reflect.ValueOf(r).Elem(), options...); err != nil {
return nil, err
}
return r, nil
}

// MustCreateTemplateSet is the same as NewTemplateSet except that it panics on error
func MustCreateTemplateSet[T any](options ...any) *T {
r, err := NewTemplateSet[T](options...)
if err != nil {
panic(err)
}
return r
}

var ntt = reflect.TypeOf((*NamedTemplate)(nil)).Elem()

func setTemplateFields(rv reflect.Value, options ...any) error {
rvt := rv.Type()
for i := 0; i < rv.NumField(); i++ {
fld := rv.Field(i)
if ft := rvt.Field(i); ft.IsExported() {
if fld.Type() == ntt {
tag, ok := ft.Tag.Lookup(sqlTag)
if !ok {
if ft.Tag != "" && !strings.ContainsRune(string(ft.Tag), '"') {
tag = string(ft.Tag)
} else {
return fmt.Errorf("field '%s' does not have '%s' tag", ft.Name, sqlTag)
}
}
if tmp, err := NewNamedTemplate(tag, options...); err == nil {
fld.Set(reflect.ValueOf(tmp))
} else {
return err
}
} else if fld.Kind() == reflect.Struct {
sub := reflect.New(fld.Type()).Elem()
if err := setTemplateFields(sub, options...); err == nil {
fld.Set(sub)
} else {
return err
}
}
}
}
return nil
}
105 changes: 105 additions & 0 deletions template_set_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
package sqlnt

import (
"github.com/stretchr/testify/assert"
"testing"
)

type MySet struct {
Select NamedTemplate `SELECT *
FROM {{tableName}}
WHERE col_a = :a`
Insert NamedTemplate `sql:"INSERT INTO {{tableName}} (col_a,col_b,col_c) VALUES (:a,:b,:c)"`
Delete NamedTemplate `sql:"DELETE FROM {{tableName}} WHERE col_a = :a"`
delete NamedTemplate // unexported fields not used
}

type MySet2 struct {
MySet
}

type MySet3 struct {
MySet2
}

func TestNewTemplateSet(t *testing.T) {
ts, err := NewTemplateSet[MySet](testTokenOption)
assert.NoError(t, err)
assert.NotNil(t, ts.Select)
assert.Equal(t, "SELECT *\nFROM foo\nWHERE col_a = ?", ts.Select.Statement())
assert.NotNil(t, ts.Insert)
assert.Equal(t, "INSERT INTO foo (col_a,col_b,col_c) VALUES (?,?,?)", ts.Insert.Statement())
assert.NotNil(t, ts.Delete)
assert.Equal(t, "DELETE FROM foo WHERE col_a = ?", ts.Delete.Statement())
assert.Nil(t, ts.delete)
}

func TestNewTemplateSet_Anonymous(t *testing.T) {
ts, err := NewTemplateSet[struct {
Select NamedTemplate `SELECT *
FROM {{tableName}}
WHERE col_a = :a`
}](testTokenOption)
assert.NoError(t, err)
assert.NotNil(t, ts.Select)
assert.Equal(t, "SELECT *\nFROM foo\nWHERE col_a = ?", ts.Select.Statement())
}

func TestNewTemplateSet_Nested(t *testing.T) {
ts, err := NewTemplateSet[MySet2](testTokenOption)
assert.NoError(t, err)
assert.NotNil(t, ts.Select)
}

func TestNewTemplateSet_NestedDouble(t *testing.T) {
ts, err := NewTemplateSet[MySet3](testTokenOption)
assert.NoError(t, err)
assert.NotNil(t, ts.Select)
}

func TestNewTemplateSet_Error_NotStruct(t *testing.T) {
_, err := NewTemplateSet[string](testTokenOption)
assert.Error(t, err)
assert.Equal(t, "not a struct", err.Error())
}

type BadSet1 struct {
Select NamedTemplate // doesn't have 'sql' tag
}

func TestNewTemplateSet_Error_NoSqlTag(t *testing.T) {
_, err := NewTemplateSet[BadSet1](testTokenOption)
assert.Error(t, err)
assert.Equal(t, "field 'Select' does not have 'sql' tag", err.Error())
}

type BadSet2 struct {
Select NamedTemplate `sql:"{{unknown_token}}"`
}

func TestNewTemplateSet_Error_BadTemplate(t *testing.T) {
_, err := NewTemplateSet[BadSet2](testTokenOption)
assert.Error(t, err)
assert.Equal(t, "unknown token: unknown_token", err.Error())
}

type BadSet3 struct {
BadSet2
}

func TestNewTemplateSet_Error_NestedBadTemplate(t *testing.T) {
_, err := NewTemplateSet[BadSet3](testTokenOption)
assert.Error(t, err)
assert.Equal(t, "unknown token: unknown_token", err.Error())
}

func TestMustCreateTemplateSet(t *testing.T) {
ts := MustCreateTemplateSet[MySet](testTokenOption)
assert.NotNil(t, ts.Select)
}

func TestMustCreateTemplateSet_Panics(t *testing.T) {
assert.Panics(t, func() {
_ = MustCreateTemplateSet[BadSet1]()
})
}

0 comments on commit 545879e

Please sign in to comment.