Skip to content

Commit

Permalink
reimplement the condition IN
Browse files Browse the repository at this point in the history
  • Loading branch information
xgfone committed May 21, 2024
1 parent 65bad94 commit b12434c
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 47 deletions.
65 changes: 50 additions & 15 deletions op_condition.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package sqlx

import (
"fmt"
"reflect"
"strings"

"github.com/xgfone/go-op"
Expand Down Expand Up @@ -99,32 +100,66 @@ func newCondLike(format string) OpBuilder {

func newCondIn(format string) OpBuilder {
return OpBuilderFunc(func(ab *ArgsBuilder, op op.Op) string {
vs := op.Val.([]interface{})

switch len(vs) {
case 0:
switch vs := op.Val.(type) {
case nil:
return "1=0"

case 1:
if _vs, ok := vs[0].([]interface{}); ok {
vs = _vs
case []interface{}:
return fmtcondin(format, ab, op, vs)

case []int:
return fmtcondin(format, ab, op, vs)

case []uint:
return fmtcondin(format, ab, op, vs)

case []int64:
return fmtcondin(format, ab, op, vs)

case []uint64:
return fmtcondin(format, ab, op, vs)

case []string:
return fmtcondin(format, ab, op, vs)

default:
vf := reflect.ValueOf(op.Val)
switch vf.Kind() {
case reflect.Array, reflect.Slice:
default:
panic(fmt.Errorf("sqlx: condition IN not support type %T", op.Val))
}

switch len(vs) {
case 0:
_len := vf.Len()
if _len == 0 {
return "1=0"
}

case 1:
return fmt.Sprintf(format, ab.Quote(getOpKey(op)), ab.Add(vs[0]))
ss := make([]string, _len)
for i := 0; i < _len; i++ {
ss[i] = ab.Add(vf.Index(i).Interface())
}

return fmt.Sprintf(format, ab.Quote(getOpKey(op)), strings.Join(ss, ", "))
}
})
}

ss := make([]string, 0, len(vs))
for _, v := range vs {
ss = append(ss, ab.Add(v))
func fmtcondin[T any](format string, ab *ArgsBuilder, op op.Op, vs []T) string {
switch _len := len(vs); _len {
case 0:
return "1=0"

case 1:
return fmt.Sprintf(format, ab.Quote(getOpKey(op)), ab.Add(vs[0]))

default:
ss := make([]string, _len)
for i := 0; i < _len; i++ {
ss[i] = ab.Add(vs[i])
}
return fmt.Sprintf(format, ab.Quote(getOpKey(op)), strings.Join(ss, ", "))
})
}
}

func newCondBetween(format string) OpBuilder {
Expand Down
32 changes: 0 additions & 32 deletions op_condition_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,35 +101,3 @@ func TestCondInForOne(t *testing.T) {
t.Errorf("expect args %v, but got %v", expectargs, args)
}
}

func TestCondInForOneSliceOne(t *testing.T) {
ab := NewArgsBuilder(MySQL)
sql := BuildOper(ab, op.In("field", []any{"value"}))
args := ab.Args()

expectsql := "`field` IN (?)"
expectargs := []any{"value"}

if sql != expectsql {
t.Errorf("expect sql '%s', but got '%s'", expectsql, sql)
}
if !reflect.DeepEqual(args, expectargs) {
t.Errorf("expect args %v, but got %v", expectargs, args)
}
}

func TestCondInForOneSliceTwo(t *testing.T) {
ab := NewArgsBuilder(MySQL)
sql := BuildOper(ab, op.In("field", []any{"value1", "value2"}))
args := ab.Args()

expectsql := "`field` IN (?, ?)"
expectargs := []any{"value1", "value2"}

if sql != expectsql {
t.Errorf("expect sql '%s', but got '%s'", expectsql, sql)
}
if !reflect.DeepEqual(args, expectargs) {
t.Errorf("expect args %v, but got %v", expectargs, args)
}
}

0 comments on commit b12434c

Please sign in to comment.