From b12434c8249b7c6700654dc72e76c9699d2c17c0 Mon Sep 17 00:00:00 2001 From: xgfone Date: Tue, 21 May 2024 18:44:47 +0800 Subject: [PATCH] reimplement the condition IN --- op_condition.go | 65 ++++++++++++++++++++++++++++++++++---------- op_condition_test.go | 32 ---------------------- 2 files changed, 50 insertions(+), 47 deletions(-) diff --git a/op_condition.go b/op_condition.go index b3000f5..abf607b 100644 --- a/op_condition.go +++ b/op_condition.go @@ -16,6 +16,7 @@ package sqlx import ( "fmt" + "reflect" "strings" "github.com/xgfone/go-op" @@ -99,32 +100,66 @@ func newCondLike(format string) OpBuilder { func newCondIn(format string) OpBuilder { return OpBuilderFunc(func(ab *ArgsBuilder, op op.Op) string { - vs := op.Val.([]interface{}) - - switch len(vs) { - case 0: + switch vs := op.Val.(type) { + case nil: return "1=0" - case 1: - if _vs, ok := vs[0].([]interface{}); ok { - vs = _vs + case []interface{}: + return fmtcondin(format, ab, op, vs) + + case []int: + return fmtcondin(format, ab, op, vs) + + case []uint: + return fmtcondin(format, ab, op, vs) + + case []int64: + return fmtcondin(format, ab, op, vs) + + case []uint64: + return fmtcondin(format, ab, op, vs) + + case []string: + return fmtcondin(format, ab, op, vs) + + default: + vf := reflect.ValueOf(op.Val) + switch vf.Kind() { + case reflect.Array, reflect.Slice: + default: + panic(fmt.Errorf("sqlx: condition IN not support type %T", op.Val)) } - switch len(vs) { - case 0: + _len := vf.Len() + if _len == 0 { return "1=0" + } - case 1: - return fmt.Sprintf(format, ab.Quote(getOpKey(op)), ab.Add(vs[0])) + ss := make([]string, _len) + for i := 0; i < _len; i++ { + ss[i] = ab.Add(vf.Index(i).Interface()) } + + return fmt.Sprintf(format, ab.Quote(getOpKey(op)), strings.Join(ss, ", ")) } + }) +} - ss := make([]string, 0, len(vs)) - for _, v := range vs { - ss = append(ss, ab.Add(v)) +func fmtcondin[T any](format string, ab *ArgsBuilder, op op.Op, vs []T) string { + switch _len := len(vs); _len { + case 0: + return "1=0" + + case 1: + return fmt.Sprintf(format, ab.Quote(getOpKey(op)), ab.Add(vs[0])) + + default: + ss := make([]string, _len) + for i := 0; i < _len; i++ { + ss[i] = ab.Add(vs[i]) } return fmt.Sprintf(format, ab.Quote(getOpKey(op)), strings.Join(ss, ", ")) - }) + } } func newCondBetween(format string) OpBuilder { diff --git a/op_condition_test.go b/op_condition_test.go index 18b1736..8797190 100644 --- a/op_condition_test.go +++ b/op_condition_test.go @@ -101,35 +101,3 @@ func TestCondInForOne(t *testing.T) { t.Errorf("expect args %v, but got %v", expectargs, args) } } - -func TestCondInForOneSliceOne(t *testing.T) { - ab := NewArgsBuilder(MySQL) - sql := BuildOper(ab, op.In("field", []any{"value"})) - args := ab.Args() - - expectsql := "`field` IN (?)" - expectargs := []any{"value"} - - if sql != expectsql { - t.Errorf("expect sql '%s', but got '%s'", expectsql, sql) - } - if !reflect.DeepEqual(args, expectargs) { - t.Errorf("expect args %v, but got %v", expectargs, args) - } -} - -func TestCondInForOneSliceTwo(t *testing.T) { - ab := NewArgsBuilder(MySQL) - sql := BuildOper(ab, op.In("field", []any{"value1", "value2"})) - args := ab.Args() - - expectsql := "`field` IN (?, ?)" - expectargs := []any{"value1", "value2"} - - if sql != expectsql { - t.Errorf("expect sql '%s', but got '%s'", expectsql, sql) - } - if !reflect.DeepEqual(args, expectargs) { - t.Errorf("expect args %v, but got %v", expectargs, args) - } -}