diff --git a/named_template_build.go b/named_template_build.go index 2e8851b..bf3dc39 100644 --- a/named_template_build.go +++ b/named_template_build.go @@ -8,7 +8,7 @@ import ( ) func (n *namedTemplate) buildArgs() error { - if err := n.replaceTokens(); err != nil { + if err := n.replaceTokens(true); err != nil { return err } var builder strings.Builder @@ -65,7 +65,7 @@ func (n *namedTemplate) buildArgs() error { var tokenRegexp = regexp.MustCompile(`\{\{([^}]*)}}`) -func (n *namedTemplate) replaceTokens() error { +func (n *namedTemplate) replaceTokens(first bool) error { errs := make([]string, 0) n.originalStatement = tokenRegexp.ReplaceAllStringFunc(n.originalStatement, func(s string) string { token := s[2 : len(s)-2] @@ -82,6 +82,9 @@ func (n *namedTemplate) replaceTokens() error { } else if len(errs) > 0 { return fmt.Errorf("unknown tokens: %s", strings.Join(errs, ", ")) } + if first && strings.Contains(n.originalStatement, "{{") && strings.Contains(n.originalStatement, "}}") { + return n.replaceTokens(false) + } return nil } diff --git a/named_template_test.go b/named_template_test.go index 598c2aa..85c9fce 100644 --- a/named_template_test.go +++ b/named_template_test.go @@ -406,7 +406,24 @@ func TestNamedTemplate(t *testing.T) { statement: `INSERT INTO {{tableName}} ({{cols}}) VALUES(:{{argA}},:{{argB}},:{{argC}})`, expectStatement: `INSERT INTO foo (col_a,col_b,col_c) VALUES($1,$2,$3)`, expectOriginal: `INSERT INTO foo (col_a,col_b,col_c) VALUES(:a,:b,:c)`, - options: []any{PostgresOption, &testTokenOption{}}, + options: []any{PostgresOption, testTokenOption}, + expectArgsCount: 3, + expectArgNamesCount: 3, + expectArgNames: []string{"a", "b", "c"}, + inArgs: []any{ + map[string]any{ + "a": "a value", + "b": "b value", + "c": "c value", + }, + }, + expectOutArgs: []any{"a value", "b value", "c value"}, + }, + { + statement: `INSERT INTO {{nested}} ({{cols}}) VALUES(:{{argA}},:{{argB}},:{{argC}})`, + expectStatement: `INSERT INTO foo (col_a,col_b,col_c) VALUES($1,$2,$3)`, + expectOriginal: `INSERT INTO foo (col_a,col_b,col_c) VALUES(:a,:b,:c)`, + options: []any{PostgresOption, testTokenOption}, expectArgsCount: 3, expectArgNamesCount: 3, expectArgNames: []string{"a", "b", "c"}, @@ -421,13 +438,13 @@ func TestNamedTemplate(t *testing.T) { }, { statement: `INSERT INTO {{unknownToken}} ({{cols}}) VALUES({{argA}},{{argB}},{{argC}})`, - options: []any{PostgresOption, &testTokenOption{}}, + options: []any{PostgresOption, testTokenOption}, expectError: true, expectErrorMessage: "unknown token: unknownToken", }, { statement: `INSERT INTO {{unknown token}} ({{another unknown}}) VALUES({{argA}},{{argB}},{{argC}})`, - options: []any{PostgresOption, &testTokenOption{}}, + options: []any{PostgresOption, testTokenOption}, expectError: true, expectErrorMessage: "unknown tokens: unknown token, another unknown", }, @@ -539,17 +556,13 @@ func TestNamedTemplate(t *testing.T) { } } -type testTokenOption struct{} - -func (t *testTokenOption) Replace(token string) (string, bool) { - ok, r := (map[string]string{ - "tableName": "foo", - "cols": "col_a,col_b,col_c", - "argA": "a", - "argB": "b", - "argC": "c", - })[token] - return ok, r +var testTokenOption = TokenOptionMap{ + "tableName": "foo", + "cols": "col_a,col_b,col_c", + "argA": "a", + "argB": "b", + "argC": "c", + "nested": "{{tableName}}", } type unmarshalable struct{} diff --git a/options.go b/options.go index fce1d96..6642717 100644 --- a/options.go +++ b/options.go @@ -28,6 +28,13 @@ type TokenOption interface { Replace(token string) (string, bool) } +type TokenOptionMap map[string]string + +func (m TokenOptionMap) Replace(token string) (string, bool) { + s, ok := m[token] + return s, ok +} + var ( MySqlOption Option = _MySqlOption // option to produce final args like ?, ?, ? (e.g. for https://github.com/go-sql-driver/mysql) PostgresOption Option = _PostgresOption // option to produce final args like $1, $2, $3 (e.g. for https://github.com/lib/pq or https://github.com/jackc/pgx) diff --git a/template_set.go b/template_set.go new file mode 100644 index 0000000..4655f57 --- /dev/null +++ b/template_set.go @@ -0,0 +1,88 @@ +package sqlnt + +import ( + "errors" + "fmt" + "reflect" + "strings" +) + +const sqlTag = "sql" + +// NewTemplateSet builds a set of templates for the given struct type T +// +// Fields of type sqlnt.NamedTemplate are created and set from the field tag 'sql' +// +// Example: +// +// type MyTemplateSet struct { +// Select sqlnt.NamedTemplate `sql:"SELECT * FROM foo WHERE col_a = :a"` +// Insert sqlnt.NamedTemplate `sql:"INSERT INTO foo (col_a, col_b, col_c) VALUES(:a, :b, :c)"` +// Delete sqlnt.NamedTemplate `sql:"DELETE FROM foo WHERE col_a = :a"` +// } +// set, err := sqlnt.NewTemplateSet[MyTemplateSet]() +// +// Note: If the overall field tag does not contain a 'sql' tag nor any other tags (i.e. there are no double-quotes in it) +// then the entire field tag value is used as the template - enabling the use of carriage returns to format the statement +// +// Example: +// +// type MyTemplateSet struct { +// Select sqlnt.NamedTemplate `SELECT * +// FROM foo +// WHERE col_a = :a` +// } +func NewTemplateSet[T any](options ...any) (*T, error) { + var chk T + if reflect.TypeOf(chk).Kind() != reflect.Struct { + return nil, errors.New("not a struct") + } + r := new(T) + if err := setTemplateFields(reflect.ValueOf(r).Elem(), options...); err != nil { + return nil, err + } + return r, nil +} + +// MustCreateTemplateSet is the same as NewTemplateSet except that it panics on error +func MustCreateTemplateSet[T any](options ...any) *T { + r, err := NewTemplateSet[T](options...) + if err != nil { + panic(err) + } + return r +} + +var ntt = reflect.TypeOf((*NamedTemplate)(nil)).Elem() + +func setTemplateFields(rv reflect.Value, options ...any) error { + rvt := rv.Type() + for i := 0; i < rv.NumField(); i++ { + fld := rv.Field(i) + if ft := rvt.Field(i); ft.IsExported() { + if fld.Type() == ntt { + tag, ok := ft.Tag.Lookup(sqlTag) + if !ok { + if ft.Tag != "" && !strings.ContainsRune(string(ft.Tag), '"') { + tag = string(ft.Tag) + } else { + return fmt.Errorf("field '%s' does not have '%s' tag", ft.Name, sqlTag) + } + } + if tmp, err := NewNamedTemplate(tag, options...); err == nil { + fld.Set(reflect.ValueOf(tmp)) + } else { + return err + } + } else if fld.Kind() == reflect.Struct { + sub := reflect.New(fld.Type()).Elem() + if err := setTemplateFields(sub, options...); err == nil { + fld.Set(sub) + } else { + return err + } + } + } + } + return nil +} diff --git a/template_set_test.go b/template_set_test.go new file mode 100644 index 0000000..938382d --- /dev/null +++ b/template_set_test.go @@ -0,0 +1,105 @@ +package sqlnt + +import ( + "github.com/stretchr/testify/assert" + "testing" +) + +type MySet struct { + Select NamedTemplate `SELECT * +FROM {{tableName}} +WHERE col_a = :a` + Insert NamedTemplate `sql:"INSERT INTO {{tableName}} (col_a,col_b,col_c) VALUES (:a,:b,:c)"` + Delete NamedTemplate `sql:"DELETE FROM {{tableName}} WHERE col_a = :a"` + delete NamedTemplate // unexported fields not used +} + +type MySet2 struct { + MySet +} + +type MySet3 struct { + MySet2 +} + +func TestNewTemplateSet(t *testing.T) { + ts, err := NewTemplateSet[MySet](testTokenOption) + assert.NoError(t, err) + assert.NotNil(t, ts.Select) + assert.Equal(t, "SELECT *\nFROM foo\nWHERE col_a = ?", ts.Select.Statement()) + assert.NotNil(t, ts.Insert) + assert.Equal(t, "INSERT INTO foo (col_a,col_b,col_c) VALUES (?,?,?)", ts.Insert.Statement()) + assert.NotNil(t, ts.Delete) + assert.Equal(t, "DELETE FROM foo WHERE col_a = ?", ts.Delete.Statement()) + assert.Nil(t, ts.delete) +} + +func TestNewTemplateSet_Anonymous(t *testing.T) { + ts, err := NewTemplateSet[struct { + Select NamedTemplate `SELECT * +FROM {{tableName}} +WHERE col_a = :a` + }](testTokenOption) + assert.NoError(t, err) + assert.NotNil(t, ts.Select) + assert.Equal(t, "SELECT *\nFROM foo\nWHERE col_a = ?", ts.Select.Statement()) +} + +func TestNewTemplateSet_Nested(t *testing.T) { + ts, err := NewTemplateSet[MySet2](testTokenOption) + assert.NoError(t, err) + assert.NotNil(t, ts.Select) +} + +func TestNewTemplateSet_NestedDouble(t *testing.T) { + ts, err := NewTemplateSet[MySet3](testTokenOption) + assert.NoError(t, err) + assert.NotNil(t, ts.Select) +} + +func TestNewTemplateSet_Error_NotStruct(t *testing.T) { + _, err := NewTemplateSet[string](testTokenOption) + assert.Error(t, err) + assert.Equal(t, "not a struct", err.Error()) +} + +type BadSet1 struct { + Select NamedTemplate // doesn't have 'sql' tag +} + +func TestNewTemplateSet_Error_NoSqlTag(t *testing.T) { + _, err := NewTemplateSet[BadSet1](testTokenOption) + assert.Error(t, err) + assert.Equal(t, "field 'Select' does not have 'sql' tag", err.Error()) +} + +type BadSet2 struct { + Select NamedTemplate `sql:"{{unknown_token}}"` +} + +func TestNewTemplateSet_Error_BadTemplate(t *testing.T) { + _, err := NewTemplateSet[BadSet2](testTokenOption) + assert.Error(t, err) + assert.Equal(t, "unknown token: unknown_token", err.Error()) +} + +type BadSet3 struct { + BadSet2 +} + +func TestNewTemplateSet_Error_NestedBadTemplate(t *testing.T) { + _, err := NewTemplateSet[BadSet3](testTokenOption) + assert.Error(t, err) + assert.Equal(t, "unknown token: unknown_token", err.Error()) +} + +func TestMustCreateTemplateSet(t *testing.T) { + ts := MustCreateTemplateSet[MySet](testTokenOption) + assert.NotNil(t, ts.Select) +} + +func TestMustCreateTemplateSet_Panics(t *testing.T) { + assert.Panics(t, func() { + _ = MustCreateTemplateSet[BadSet1]() + }) +}