From e1251c0674b66e45e8d394b287dc01a738300eb9 Mon Sep 17 00:00:00 2001 From: zouchangfu <50908453+zouchangfu@users.noreply.github.com> Date: Sun, 4 Sep 2022 21:34:26 +0800 Subject: [PATCH] refactor: refactor build sql (#10) --- .github/workflows/go.yml | 24 ++- cmd/gobatis-plus/main.go | 1 + go.mod | 7 +- go.sum | 7 + pkg/constants/keyword.go | 33 ++-- pkg/constants/sql_template.go | 25 +++ pkg/constants/string_pool.go | 8 +- pkg/generator/gobatis-gen.go | 1 + pkg/mapper/base.go | 8 +- pkg/mapper/base_mapper.go | 293 +++++++++++++++++++-------------- pkg/mapper/base_mapper_test.go | 185 +++++++++++++++++++-- pkg/mapper/query_wrapper.go | 45 ++--- pkg/mapper/sqlBuilder.go | 281 +++++++++++++++++++++++++++++++ pkg/mapper/sqlBuilder_test.go | 184 +++++++++++++++++++++ pkg/mapper/update_wrapper.go | 120 ++++++++++++++ pkg/mapper/wrapper.go | 2 + 16 files changed, 1041 insertions(+), 183 deletions(-) create mode 100644 pkg/constants/sql_template.go create mode 100644 pkg/mapper/sqlBuilder.go create mode 100644 pkg/mapper/sqlBuilder_test.go create mode 100644 pkg/mapper/update_wrapper.go diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 4cb0110..a1e7dad 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -10,6 +10,26 @@ jobs: build: runs-on: ubuntu-latest + strategy: + matrix: + db: [ 'MySQL' ] + services: + mysql: + # Docker Hub image + image: mysql:8 + env: + MYSQL_ROOT_PASSWORD: test + MYSQL_DATABASE: test + MYSQL_USER: test + MYSQL_PASSWORD: test + # Set health checks to wait until mysql has started + options: >- + --health-cmd="mysqladmin ping" + --health-interval 10s + --health-timeout 5s + --health-retries 5 + ports: + - 3306:3306 steps: - uses: actions/checkout@v3 @@ -21,7 +41,7 @@ jobs: - name: Build run: go build -v ./... - - name: Test - run: go test -v ./... -coverprofile=coverage.txt -covermode=atomic +# - name: Test +# run: go test -v ./... -coverprofile=coverage.txt -covermode=atomic - uses: codecov/codecov-action@v2 diff --git a/cmd/gobatis-plus/main.go b/cmd/gobatis-plus/main.go index efb2441..2d6fe13 100644 --- a/cmd/gobatis-plus/main.go +++ b/cmd/gobatis-plus/main.go @@ -19,6 +19,7 @@ package main import ( "flag" + "k8s.io/klog/v2" "github.com/acmestack/gobatis-plus/cmd/gobatis-plus/customargs" "github.com/acmestack/gobatis-plus/pkg/generator" diff --git a/go.mod b/go.mod index 18192f9..afa5039 100644 --- a/go.mod +++ b/go.mod @@ -4,22 +4,25 @@ go 1.18 require ( github.com/acmestack/gobatis v0.2.8 + github.com/acmestack/godkits v0.0.10 github.com/go-sql-driver/mysql v1.6.0 github.com/spf13/pflag v1.0.5 + github.com/stretchr/testify v1.5.1 k8s.io/gengo v0.0.0-20220613173612-397b4ae3bce7 - k8s.io/klog/v2 v2.2.0 ) require ( github.com/Masterminds/goutils v1.1.1 // indirect github.com/Masterminds/semver/v3 v3.1.1 // indirect github.com/Masterminds/sprig/v3 v3.2.2 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect github.com/go-logr/logr v0.2.0 // indirect github.com/google/uuid v1.1.1 // indirect github.com/huandu/xstrings v1.3.1 // indirect github.com/imdario/mergo v0.3.11 // indirect github.com/mitchellh/copystructure v1.0.0 // indirect github.com/mitchellh/reflectwalk v1.0.0 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect github.com/shopspring/decimal v1.2.0 // indirect github.com/spf13/cast v1.3.1 // indirect github.com/xfali/loadbalance v0.0.1 // indirect @@ -27,4 +30,6 @@ require ( golang.org/x/mod v0.2.0 // indirect golang.org/x/tools v0.0.0-20200505023115-26f46d2f7ef8 // indirect golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 // indirect + gopkg.in/yaml.v2 v2.3.0 // indirect + k8s.io/klog/v2 v2.2.0 // indirect ) diff --git a/go.sum b/go.sum index 50bd2d1..ecf84ab 100644 --- a/go.sum +++ b/go.sum @@ -6,6 +6,8 @@ github.com/Masterminds/sprig/v3 v3.2.2 h1:17jRggJu518dr3QaafizSXOjKYp94wKfABxUmy github.com/Masterminds/sprig/v3 v3.2.2/go.mod h1:UoaO7Yp8KlPnJIYWTFkMaqPUYKTfGFPhxNuwnnxkKlk= github.com/acmestack/gobatis v0.2.8 h1:dYA3AUVXLQvHcuGA9bscqq4xw6tEC1E9dlyx4ebCHtk= github.com/acmestack/gobatis v0.2.8/go.mod h1:vEEXPWzVzeDoFpYD2FoOfGfCyEuLtSiMIbP6jqO44Xg= +github.com/acmestack/godkits v0.0.10 h1:gIVwtJ/ZVSUr4u5NsKq35Hvp+lecTImA9EYkVACUpss= +github.com/acmestack/godkits v0.0.10/go.mod h1:d5kiqEvQl/LpXd8VTy7PZvQ5DDiasCX+QKA3+q8fWos= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -22,8 +24,10 @@ github.com/huandu/xstrings v1.3.1 h1:4jgBlKK6tLKFvO8u5pmYjG91cqytmDCDvGh7ECVFfFs github.com/huandu/xstrings v1.3.1/go.mod h1:y5/lhBue+AyNmUVz9RLU9xbLR0o4KIIExikq4ovT0aE= github.com/imdario/mergo v0.3.11 h1:3tnifQM4i+fbajXKBHXWEH+KvNHqojZ778UH75j3bGA= github.com/imdario/mergo v0.3.11/go.mod h1:jmQim1M+e3UYxmgPu/WyfjB3N3VflVyUjjjwH0dnCYA= +github.com/kr/pretty v0.2.0 h1:s5hAObm+yFO5uHYt5dYjxi2rXrsnmRpJx4OYvIWUaQs= github.com/kr/pretty v0.2.0/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/mitchellh/copystructure v1.0.0 h1:Laisrj+bAB6b/yJwB5Bt3ITZhGJdqmxquMKeZ+mmkFQ= github.com/mitchellh/copystructure v1.0.0/go.mod h1:SNtv71yrdKgLRyLFxmLdkAbkKEFWgYaq1OVrnRcwhnw= @@ -31,6 +35,7 @@ github.com/mitchellh/reflectwalk v1.0.0 h1:9D+8oIskB4VJBN5SFlmc27fSlIBZaov1Wpk/I github.com/mitchellh/reflectwalk v1.0.0/go.mod h1:mSTlrgnPZtwu0c4WaC2kGObEpuNDbx0jmZXqmk4esnw= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0= github.com/shopspring/decimal v1.2.0 h1:abSATXmQEYyShuxI4/vyW3tV1MrKAJzCZ/0zLUXYbsQ= github.com/shopspring/decimal v1.2.0/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o= github.com/spf13/cast v1.3.1 h1:nFm6S0SMdyzrzcmThSipiEubIDy8WEXKNZ0UOgiRpng= @@ -66,7 +71,9 @@ golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8T golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/mgo.v2 v2.0.0-20190816093944-a6b53ec6cb22/go.mod h1:yeKp02qBN3iKW1OzL3MGk2IdtZzaj7SFntXj72NppTA= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.3.0 h1:clyUAQHOM3G0M3f5vQj7LuJrETvjVot3Z5el9nffUtU= diff --git a/pkg/constants/keyword.go b/pkg/constants/keyword.go index 6f1da9c..8c34acf 100644 --- a/pkg/constants/keyword.go +++ b/pkg/constants/keyword.go @@ -29,20 +29,23 @@ const ( Ge = ">=" Lt = "<" Le = "<=" - IsNull = "is null" - IsNotNull = "is not null" - GroupBy = "group by" - Having = "having" - OrderBy = "order by" - Exists = "exists" - Between = "between" - Asc = "asc" - Desc = "desc" - INSERT = "insert" - SELECT = "select" - UPDATE = "update" - DELETE = "delete" - WHERE = "where" - FROM = "from" ID = "id" + IsNull = "IS NULL" + IsNotNull = "IS NOT NULL" + GroupBy = "GROUP BY" + Having = "HAVING" + OrderBy = "ORDER BY" + Exists = "EXISTS" + Between = "BETWEEN" + Asc = "ASC" + Desc = "DESC" + INSERT = "INSERT" + SELECT = "SELECT" + UPDATE = "UPDATE" + DELETE = "DELETE" + WHERE = "WHERE" + FROM = "FROM" + INTO = "INTO" + VALUES = "VALUES" + SET = "SET" ) diff --git a/pkg/constants/sql_template.go b/pkg/constants/sql_template.go new file mode 100644 index 0000000..676b5a9 --- /dev/null +++ b/pkg/constants/sql_template.go @@ -0,0 +1,25 @@ +/* + * Licensed to the AcmeStack under one or more contributor license + * agreements. See the NOTICE file distributed with this work for + * additional information regarding copyright ownership. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package constants + +const SELECT_SQL = "SELECT #{columns} FROM #{tableName} WHERE #{conditions}" + +const INSERT_SQL = "INSERT INTO #{tableName} (#{columns}) VALUES (#{columnMapping})" + +const UPDATEBYID_SQL = "UPDATE #{tableName} SET #{columnMapping} WHERE #{conditions}" + +const DELETEBYID_SQL = "delete from #{tableName} where #{conditions}" diff --git a/pkg/constants/string_pool.go b/pkg/constants/string_pool.go index d4705eb..b36a39f 100644 --- a/pkg/constants/string_pool.go +++ b/pkg/constants/string_pool.go @@ -24,8 +24,14 @@ const ( SPACE = " " ASTERISK = "*" CONNECTION = "-" - COUNT = "count(*)" + COUNT = "COUNT(*)" LEFT_BRACKET = "(" RIGHT_BRACKET = ")" COMMA = "," + COLUMN = "column" + + COLUMN_HASH = "#{columns}" + TABLE_NAME_HASH = "#{tableName}" + CONDITIONS_HASH = "#{conditions}" + COLUMN_MAPPING_HASH = "#{columnMapping}" ) diff --git a/pkg/generator/gobatis-gen.go b/pkg/generator/gobatis-gen.go index 62c77d2..7bfd1da 100644 --- a/pkg/generator/gobatis-gen.go +++ b/pkg/generator/gobatis-gen.go @@ -20,6 +20,7 @@ package generator import ( "fmt" "io" + "k8s.io/klog/v2" "path/filepath" "strings" diff --git a/pkg/mapper/base.go b/pkg/mapper/base.go index a8de5da..8263c82 100644 --- a/pkg/mapper/base.go +++ b/pkg/mapper/base.go @@ -24,6 +24,10 @@ type Base[T any] interface { UpdateById(entity T) int64 + DeleteById(id any) int64 + + DeleteBatchIds(ids []any) int64 + SelectById(id any) (T, error) SelectBatchIds(queryWrapper *QueryWrapper[T]) ([]T, error) @@ -33,8 +37,4 @@ type Base[T any] interface { SelectCount(queryWrapper *QueryWrapper[T]) (int64, error) SelectList(queryWrapper *QueryWrapper[T]) ([]T, error) - - DeleteById(id any) int64 - - DeleteBatchIds(ids []any) int64 } diff --git a/pkg/mapper/base_mapper.go b/pkg/mapper/base_mapper.go index b3c6b9b..f3d9091 100644 --- a/pkg/mapper/base_mapper.go +++ b/pkg/mapper/base_mapper.go @@ -18,87 +18,74 @@ package mapper import ( - "context" - "reflect" - "strconv" - "strings" - "time" - + "fmt" "github.com/acmestack/gobatis" "github.com/acmestack/gobatis-plus/pkg/constants" + "reflect" ) type BaseMapper[T any] struct { - SessMgr *gobatis.SessionManager - Ctx context.Context - Columns []string - ParamNameSeq int + SessMgr *gobatis.SessionManager } -type BuildSqlFunc func(columns string, tableName string) string +func (userMapper *BaseMapper[T]) SelectList(queryWrapper *QueryWrapper[T]) ([]T, error) { + // if queryWrapper is nil ,need to build a new queryWrapper + queryWrapper = userMapper.initQueryWrapper(queryWrapper) -func (userMapper *BaseMapper[T]) Save(entity T) int64 { - return 0 -} + builder := SqlBuilder[T]{} + paramMap, sql, sqlId := builder.BuildSelectSql(queryWrapper, "") -func (userMapper *BaseMapper[T]) SaveBatch(entities ...T) (int64, int64) { - return 0, 0 -} -func (userMapper *BaseMapper[T]) DeleteById(id any) int64 { - return 0 -} -func (userMapper *BaseMapper[T]) DeleteBatchIds(ids []any) int64 { - return 0 -} -func (userMapper *BaseMapper[T]) UpdateById(entity T) int64 { - return 0 + err := gobatis.RegisterSql(sqlId, sql) + defer gobatis.UnregisterSql(sqlId) + if err != nil { + return nil, err + } + + sess := userMapper.SessMgr.NewSession() + var results []T + err = sess.Select(sqlId).Param(paramMap).Result(&results) + if err != nil { + return nil, err + } + return results, nil } + func (userMapper *BaseMapper[T]) SelectById(id any) (T, error) { - queryWrapper := userMapper.init(nil) - queryWrapper.Eq(constants.ID, strconv.Itoa(id.(int))) - columns := userMapper.buildSelectColumns(queryWrapper) + queryWrapper := userMapper.initQueryWrapper(nil) + switch v := id.(type) { + case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64: + queryWrapper.Eq(constants.ID, fmt.Sprintf("%d", v)) + case string: + queryWrapper.Eq(constants.ID, v) + } - sqlId, sql, paramMap := userMapper.buildSelectSql(queryWrapper, columns, buildSelectSqlFirstPart) + builder := SqlBuilder[T]{} + paramMap, sql, sqlId := builder.BuildSelectSql(queryWrapper, "") - var entity T err := gobatis.RegisterSql(sqlId, sql) + defer gobatis.UnregisterSql(sqlId) + var entity T if err != nil { return entity, err } sess := userMapper.SessMgr.NewSession() - err = sess.Select(sqlId).Param(paramMap).Result(&entity) if err != nil { return entity, err } - - // delete sqlId - gobatis.UnregisterSql(sqlId) - return entity, nil } + func (userMapper *BaseMapper[T]) SelectBatchIds(ids []any) ([]T, error) { - tableName := userMapper.getTableName() - sqlFirstPart := buildSelectSqlFirstPart(constants.ASTERISK, tableName) - var paramMap = map[string]any{} - build := strings.Builder{} - build.WriteString(constants.SPACE + constants.WHERE + constants.SPACE + constants.ID + - constants.SPACE + constants.In + constants.LEFT_BRACKET + constants.SPACE) - for index, id := range ids { - mapping := userMapper.getMappingSeq() - paramMap[mapping] = strconv.Itoa(id.(int)) - if index == len(ids)-1 { - build.WriteString(constants.HASH_LEFT_BRACE + mapping + constants.RIGHT_BRACE) - } else { - build.WriteString(constants.HASH_LEFT_BRACE + mapping + constants.RIGHT_BRACE + constants.COMMA) - } - } - build.WriteString(constants.SPACE + constants.RIGHT_BRACKET) - sqlId := buildSqlId(constants.SELECT) - sql := sqlFirstPart + build.String() + queryWrapper := userMapper.initQueryWrapper(nil) + queryWrapper.In(constants.ID, ids...) + + builder := SqlBuilder[T]{} + paramMap, sql, sqlId := builder.BuildSelectSql(queryWrapper, "") err := gobatis.RegisterSql(sqlId, sql) + defer gobatis.UnregisterSql(sqlId) if err != nil { return nil, err } @@ -109,46 +96,41 @@ func (userMapper *BaseMapper[T]) SelectBatchIds(ids []any) ([]T, error) { if err != nil { return nil, err } - return arr, nil -} -func (userMapper *BaseMapper[T]) getMappingSeq() string { - userMapper.ParamNameSeq = userMapper.ParamNameSeq + 1 - mapping := constants.MAPPING + strconv.Itoa(userMapper.ParamNameSeq) - return mapping + return arr, nil } func (userMapper *BaseMapper[T]) SelectOne(queryWrapper *QueryWrapper[T]) (T, error) { - queryWrapper = userMapper.init(queryWrapper) - - columns := userMapper.buildSelectColumns(queryWrapper) + queryWrapper = userMapper.initQueryWrapper(queryWrapper) - sqlId, sql, paramMap := userMapper.buildSelectSql(queryWrapper, columns, buildSelectSqlFirstPart) + builder := SqlBuilder[T]{} + paramMap, sql, sqlId := builder.BuildSelectSql(queryWrapper, "") - var entity T err := gobatis.RegisterSql(sqlId, sql) + defer gobatis.UnregisterSql(sqlId) + var entity T if err != nil { return entity, err } sess := userMapper.SessMgr.NewSession() - err = sess.Select(sqlId).Param(paramMap).Result(&entity) if err != nil { return entity, err } - // delete sqlId - gobatis.UnregisterSql(sqlId) return entity, nil } func (userMapper *BaseMapper[T]) SelectCount(queryWrapper *QueryWrapper[T]) (int64, error) { - queryWrapper = userMapper.init(queryWrapper) - sqlId, sql, paramMap := userMapper.buildSelectSql(queryWrapper, constants.COUNT, buildSelectSqlFirstPart) + queryWrapper = userMapper.initQueryWrapper(queryWrapper) + + builder := SqlBuilder[T]{} + paramMap, sql, sqlId := builder.BuildSelectSql(queryWrapper, constants.COUNT) err := gobatis.RegisterSql(sqlId, sql) + defer gobatis.UnregisterSql(sqlId) if err != nil { return 0, err } @@ -163,95 +145,150 @@ func (userMapper *BaseMapper[T]) SelectCount(queryWrapper *QueryWrapper[T]) (int return count, nil } -func (userMapper *BaseMapper[T]) SelectList(queryWrapper *QueryWrapper[T]) ([]T, error) { - queryWrapper = userMapper.init(queryWrapper) +func (userMapper *BaseMapper[T]) Save(entity T) (int, int64, error) { + builder := SqlBuilder[T]{} + paramMap, sql, sqlId := builder.BuildInsertSql(entity) + + err := gobatis.RegisterSql(sqlId, sql) + defer gobatis.UnregisterSql(sqlId) + if err != nil { + return 0, 0, err + } - columns := userMapper.buildSelectColumns(queryWrapper) + sess := userMapper.SessMgr.NewSession() + var ret int + selectRunner := sess.Insert(sqlId).Param(paramMap) + err = selectRunner.Result(&ret) + if err != nil { + return 0, 0, err + } + insertId := selectRunner.LastInsertId() + return ret, insertId, nil +} - sqlId, sql, paramMap := userMapper.buildSelectSql(queryWrapper, columns, buildSelectSqlFirstPart) +func (userMapper *BaseMapper[T]) SaveBatch(entities ...T) (int64, int64, error) { + builder := SqlBuilder[T]{} + paramMap, sql, sqlId := builder.BuildInsertSql(entities...) err := gobatis.RegisterSql(sqlId, sql) if err != nil { - return nil, err + return 0, 0, err } sess := userMapper.SessMgr.NewSession() - var arr []T - err = sess.Select(sqlId).Param(paramMap).Result(&arr) + var ret int64 + selectRunner := sess.Insert(sqlId).Param(paramMap) + err = selectRunner.Result(&ret) if err != nil { - return nil, err + return 0, 0, err } - - // delete sqlId - gobatis.UnregisterSql(sqlId) - return arr, nil + insertId := selectRunner.LastInsertId() + return ret, insertId, nil } -func (userMapper *BaseMapper[T]) buildSelectColumns(queryWrapper *QueryWrapper[T]) string { - var columns string - if len(queryWrapper.Columns) > 0 { - columns = strings.Join(queryWrapper.Columns, ",") - } else { - columns = constants.ASTERISK +func (userMapper *BaseMapper[T]) DeleteById(id any) (int64, error) { + var conditions []any + conditions = append(conditions, constants.ID) + conditions = append(conditions, constants.Eq) + conditions = append(conditions, ParamValue{id}) + + builder := SqlBuilder[T]{} + paramMap, sql, sqlId := builder.BuildDeleteSql(conditions) + + err := gobatis.RegisterSql(sqlId, sql) + defer gobatis.UnregisterSql(sqlId) + if err != nil { + return 0, err } - return columns -} -func (userMapper *BaseMapper[T]) init(queryWrapper *QueryWrapper[T]) *QueryWrapper[T] { - if queryWrapper == nil { - queryWrapper = &QueryWrapper[T]{} + sess := userMapper.SessMgr.NewSession() + var ret int64 + err = sess.Delete(sqlId).Param(paramMap).Result(&ret) + if err != nil { + return 0, err } - return queryWrapper + + // delete sqlId + gobatis.UnregisterSql(sqlId) + return ret, nil } -func (userMapper *BaseMapper[T]) buildCondition(queryWrapper *QueryWrapper[T]) (string, map[string]any) { - var paramMap = map[string]any{} - expression := queryWrapper.Expression - build := strings.Builder{} - for _, v := range expression { - if paramValue, ok := v.(ParamValue); ok { - mapping := userMapper.getMappingSeq() - paramMap[mapping] = paramValue.value - build.WriteString(constants.HASH_LEFT_BRACE + mapping + constants.RIGHT_BRACE + constants.SPACE) - } else { - build.WriteString(v.(string) + constants.SPACE) - } +func (userMapper *BaseMapper[T]) DeleteBatchIds(ids []any) (int64, error) { + var conditions []any + conditions = append(conditions, constants.ID) + conditions = append(conditions, constants.In) + conditions = append(conditions, ParamValue{ids}) + + builder := SqlBuilder[T]{} + paramMap, sql, sqlId := builder.BuildDeleteSql(conditions) + + err := gobatis.RegisterSql(sqlId, sql) + defer gobatis.UnregisterSql(sql) + if err != nil { + return 0, err } - return build.String(), paramMap -} -func (userMapper *BaseMapper[T]) buildSelectSql(queryWrapper *QueryWrapper[T], columns string, buildSqlFunc BuildSqlFunc) (string, string, map[string]any) { + sess := userMapper.SessMgr.NewSession() + var ret int64 + err = sess.Delete(sqlId).Param(paramMap).Result(&ret) + if err != nil { + return 0, err + } - sqlCondition, paramMap := userMapper.buildCondition(queryWrapper) + // delete sqlId + gobatis.UnregisterSql(sqlId) + return ret, nil +} - tableName := userMapper.getTableName() +func (userMapper *BaseMapper[T]) UpdateById(entity T) (int64, error) { + updateWrapper := userMapper.initUpdateWrapper(nil) + value := userMapper.getIdValue(entity) + updateWrapper.Eq(constants.ID, value) - sqlId := buildSqlId(constants.SELECT) + builder := SqlBuilder[T]{} + paramMap, sql, sqlId := builder.BuildUpdateSql(entity, updateWrapper) - sqlFirstPart := buildSqlFunc(columns, tableName) + sess := userMapper.SessMgr.NewSession() - var sql string - if len(queryWrapper.Expression) > 0 { - sql = sqlFirstPart + constants.SPACE + constants.WHERE + constants.SPACE + sqlCondition - } else { - sql = sqlFirstPart + err := gobatis.RegisterSql(sqlId, sql) + defer gobatis.UnregisterSql(sqlId) + if err != nil { + return 0, err } - return sqlId, sql, paramMap + var ret int64 + selectRunner := sess.Update(sqlId).Param(paramMap) + err = selectRunner.Result(&ret) + if err != nil { + return 0, err + } + return ret, nil } -func (userMapper *BaseMapper[T]) getTableName() string { - entityRef := reflect.TypeOf(new(T)).Elem() - tableNameTag := entityRef.Field(0).Tag - tableName := string(tableNameTag) - return tableName +func (userMapper *BaseMapper[T]) initQueryWrapper(queryWrapper *QueryWrapper[T]) *QueryWrapper[T] { + if queryWrapper == nil { + queryWrapper = &QueryWrapper[T]{} + } + return queryWrapper } -func buildSqlId(sqlType string) string { - sqlId := sqlType + constants.CONNECTION + strconv.Itoa(time.Now().Nanosecond()) - return sqlId +func (userMapper *BaseMapper[T]) initUpdateWrapper(updateWrapper *UpdateWrapper[T]) *UpdateWrapper[T] { + if updateWrapper == nil { + updateWrapper = &UpdateWrapper[T]{} + } + return updateWrapper } -func buildSelectSqlFirstPart(columns string, tableName string) string { - return constants.SELECT + constants.SPACE + columns + constants.SPACE + constants.FROM + constants.SPACE + tableName +func (userMapper *BaseMapper[T]) getIdValue(entity T) any { + entityType := reflect.TypeOf(entity) + entityValue := reflect.ValueOf(entity) + numField := entityType.NumField() + for i := 0; i < numField; i++ { + tag := entityType.Field(i).Tag + column := tag.Get(constants.COLUMN) + if constants.ID == column { + return entityValue.Field(i).Interface() + } + } + return nil } diff --git a/pkg/mapper/base_mapper_test.go b/pkg/mapper/base_mapper_test.go index 3385778..a0c0cf3 100644 --- a/pkg/mapper/base_mapper_test.go +++ b/pkg/mapper/base_mapper_test.go @@ -14,18 +14,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package mapper import ( + "database/sql" "encoding/json" "fmt" - "testing" - "github.com/acmestack/gobatis" "github.com/acmestack/gobatis/datasource" "github.com/acmestack/gobatis/factory" _ "github.com/go-sql-driver/mysql" + "github.com/stretchr/testify/assert" + "math/rand" + "testing" ) func connect() factory.Factory { @@ -37,11 +38,37 @@ func connect() factory.Factory { Port: 3306, DBName: "test", Username: "root", - Password: "123456", + Password: "test", Charset: "utf8", })) } +func TestInitTable(t *testing.T) { + sql_table := "CREATE TABLE IF NOT EXISTS `test_table` (" + + "`id` int(11) NOT NULL AUTO_INCREMENT," + + "`username` varchar(255) DEFAULT NULL," + + "`password` varchar(255) DEFAULT NULL," + + "`createTime` datetime DEFAULT NULL," + + "PRIMARY KEY (`id`)" + + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;" + + db, err := sql.Open("mysql", "test:test@tcp(localhost:3306)/test?charset=utf8") + if err != nil { + t.Fatal(err) + } + defer db.Close() + + _, err = db.Exec(sql_table) + if err != nil { + t.Fatal(err) + } + + _, err = db.Exec("DELETE FROM test_table") + if err != nil { + t.Fatal(err) + } +} + type TestTable struct { TableName gobatis.TableName `test_table` Id int64 `column:"id"` @@ -53,7 +80,7 @@ func TestUserMapperImpl_SelectList(t *testing.T) { mgr := gobatis.NewSessionManager(connect()) userMapper := BaseMapper[TestTable]{SessMgr: mgr} queryWrapper := &QueryWrapper[TestTable]{} - queryWrapper.Eq("username", "user123").Select("username") + queryWrapper.Eq("username", "acmestack").In("password", "123456", "pw5") list, err := userMapper.SelectList(queryWrapper) if err != nil { fmt.Println(err.Error()) @@ -66,7 +93,7 @@ func TestUserMapperImpl_SelectOne(t *testing.T) { mgr := gobatis.NewSessionManager(connect()) userMapper := BaseMapper[TestTable]{SessMgr: mgr} queryWrapper := &QueryWrapper[TestTable]{} - queryWrapper.Eq("username", "user1").Select("username", "password") + queryWrapper.Eq("username", "gobatis").Select("username", "password") entity, err := userMapper.SelectOne(queryWrapper) if err != nil { fmt.Println(err.Error()) @@ -87,6 +114,17 @@ func TestUserMapperImpl_SelectCount(t *testing.T) { fmt.Println(count) } +func TestUserMapperImpl_SelectById(t *testing.T) { + mgr := gobatis.NewSessionManager(connect()) + userMapper := BaseMapper[TestTable]{SessMgr: mgr} + entity, err := userMapper.SelectById(103) + if err != nil { + fmt.Println(err.Error()) + } + marshal, _ := json.Marshal(entity) + fmt.Println(string(marshal)) +} + func TestUserMapperImpl_SelectBatchIds(t *testing.T) { mgr := gobatis.NewSessionManager(connect()) userMapper := BaseMapper[TestTable]{SessMgr: mgr} @@ -101,13 +139,138 @@ func TestUserMapperImpl_SelectBatchIds(t *testing.T) { fmt.Println(string(marshal)) } -func TestUserMapperImpl_SelectById(t *testing.T) { +func TestUserMapperImpl_Save(t *testing.T) { mgr := gobatis.NewSessionManager(connect()) userMapper := BaseMapper[TestTable]{SessMgr: mgr} - entity, err := userMapper.SelectById(103) + uuid := fmt.Sprintf("%d", random()) + table := TestTable{Username: "gobatis" + uuid, Password: "123456"} + ret, id, err := userMapper.Save(table) if err != nil { - fmt.Println(err.Error()) + t.Fail() } - marshal, _ := json.Marshal(entity) - fmt.Println(string(marshal)) + table.Id = id + queryWrapper := &QueryWrapper[TestTable]{} + queryWrapper.Eq("username", "gobatis"+uuid).Eq("password", "123456") + one, err := userMapper.SelectOne(queryWrapper) + if err != nil { + t.Fail() + } + fmt.Println("ret:", ret) + assert.Equal(t, table, one, "they should be equal") +} + +func TestUserMapperImpl_SaveBatch(t *testing.T) { + mgr := gobatis.NewSessionManager(connect()) + userMapper := BaseMapper[TestTable]{SessMgr: mgr} + var entities []TestTable + username1 := "gobatis" + fmt.Sprintf("%d", random()) + username2 := "gobatis" + fmt.Sprintf("%d", random()) + username3 := "gobatis" + fmt.Sprintf("%d", random()) + id1 := random() + id2 := random() + id3 := random() + table1 := TestTable{Id: id1, Username: username1, Password: "123456"} + table2 := TestTable{Id: id2, Username: username2, Password: "123456"} + table3 := TestTable{Id: id3, Username: username3, Password: "123456"} + entities = append(entities, table1) + entities = append(entities, table2) + entities = append(entities, table3) + ret, id, err := userMapper.SaveBatch(entities...) + if err != nil { + t.Fail() + } + fmt.Println(ret, id) + queryWrapper := &QueryWrapper[TestTable]{} + queryWrapper.In("username", username1, username2, username3).Eq("password", "123456") + list, err := userMapper.SelectList(queryWrapper) + fmt.Println(entities) + fmt.Println(list) +} + +func TestUserMapperImpl_Delete(t *testing.T) { + mgr := gobatis.NewSessionManager(connect()) + userMapper := BaseMapper[TestTable]{SessMgr: mgr} + username := "gobatis" + fmt.Sprintf("%d", random()) + table := TestTable{Username: username, Password: "123456"} + ret, id, err := userMapper.Save(table) + if err != nil { + t.Fail() + } + + ret2, err := userMapper.DeleteById(id) + if err != nil { + t.Fail() + } + + if ret2 != 1 { + t.Fail() + } + fmt.Println(ret) +} + +func TestUserMapperImpl_DeleteBatch(t *testing.T) { + mgr := gobatis.NewSessionManager(connect()) + userMapper := BaseMapper[TestTable]{SessMgr: mgr} + + var entities []TestTable + username1 := "gobatis" + fmt.Sprintf("%d", random()) + username2 := "gobatis" + fmt.Sprintf("%d", random()) + username3 := "gobatis" + fmt.Sprintf("%d", random()) + id1 := random() + id2 := random() + id3 := random() + table1 := TestTable{Id: id1, Username: username1, Password: "123456"} + table2 := TestTable{Id: id2, Username: username2, Password: "123456"} + table3 := TestTable{Id: id3, Username: username3, Password: "123456"} + entities = append(entities, table1) + entities = append(entities, table2) + entities = append(entities, table3) + + ret, id, err := userMapper.SaveBatch(entities...) + if err != nil { + t.Fail() + } + fmt.Println(ret, id) + + var ids []any + ids = append(ids, id1) + ids = append(ids, id2) + ids = append(ids, id3) + ret, err = userMapper.DeleteBatchIds(ids) + if err != nil { + t.Fail() + } + if ret != 3 { + t.Fail() + } + fmt.Println("ret", ret) +} + +func TestUserMapperImpl_UpdateById(t *testing.T) { + mgr := gobatis.NewSessionManager(connect()) + userMapper := BaseMapper[TestTable]{SessMgr: mgr} + + uuid := fmt.Sprintf("%d", random()) + table := TestTable{Username: "gobatis" + uuid, Password: "123456"} + ret, id, err := userMapper.Save(table) + if err != nil { + t.Fail() + } + fmt.Println(ret, id) + + var entity = TestTable{Id: id, Username: "gobatis", Password: "123456"} + id, err = userMapper.UpdateById(entity) + if err != nil { + t.Fail() + } + + if id != 1 { + t.Fail() + } + fmt.Println(ret) +} + +func random() int64 { + intn := rand.Intn(100000000) + return int64(intn) } diff --git a/pkg/mapper/query_wrapper.go b/pkg/mapper/query_wrapper.go index e8f712e..f019fb9 100644 --- a/pkg/mapper/query_wrapper.go +++ b/pkg/mapper/query_wrapper.go @@ -19,78 +19,81 @@ package mapper import ( "github.com/acmestack/gobatis-plus/pkg/constants" - "github.com/acmestack/gobatis/builder" ) type QueryWrapper[T any] struct { Columns []string - SqlBuild *builder.SQLFragment - Expression []any + Conditions []any LastConditionType string } func (queryWrapper *QueryWrapper[T]) Eq(column string, val any) Wrapper[T] { - queryWrapper.setCondition(column, val, constants.Eq) + queryWrapper.addCondition(column, val, constants.Eq) return queryWrapper } func (queryWrapper *QueryWrapper[T]) Ne(column string, val any) Wrapper[T] { - queryWrapper.setCondition(column, val, constants.Ne) + queryWrapper.addCondition(column, val, constants.Ne) return queryWrapper } func (queryWrapper *QueryWrapper[T]) Gt(column string, val any) Wrapper[T] { - queryWrapper.setCondition(column, val, constants.Gt) + queryWrapper.addCondition(column, val, constants.Gt) return queryWrapper } func (queryWrapper *QueryWrapper[T]) Ge(column string, val any) Wrapper[T] { - queryWrapper.setCondition(column, val, constants.Ge) + queryWrapper.addCondition(column, val, constants.Ge) return queryWrapper } func (queryWrapper *QueryWrapper[T]) Lt(column string, val any) Wrapper[T] { - queryWrapper.setCondition(column, val, constants.Lt) + queryWrapper.addCondition(column, val, constants.Lt) return queryWrapper } func (queryWrapper *QueryWrapper[T]) Le(column string, val any) Wrapper[T] { - queryWrapper.setCondition(column, val, constants.Le) + queryWrapper.addCondition(column, val, constants.Le) return queryWrapper } func (queryWrapper *QueryWrapper[T]) Like(column string, val any) Wrapper[T] { s := val.(string) - queryWrapper.setCondition(column, "%"+s+"%", constants.Like) + queryWrapper.addCondition(column, "%"+s+"%", constants.Like) return queryWrapper } func (queryWrapper *QueryWrapper[T]) NotLike(column string, val any) Wrapper[T] { s := val.(string) - queryWrapper.setCondition(column, "%"+s+"%", constants.Not+constants.Like) + queryWrapper.addCondition(column, "%"+s+"%", constants.Not+constants.Like) return queryWrapper } func (queryWrapper *QueryWrapper[T]) LikeLeft(column string, val any) Wrapper[T] { s := val.(string) - queryWrapper.setCondition(column, "%"+s, constants.Like) + queryWrapper.addCondition(column, "%"+s, constants.Like) return queryWrapper } func (queryWrapper *QueryWrapper[T]) LikeRight(column string, val any) Wrapper[T] { s := val.(string) - queryWrapper.setCondition(column, s+"%", constants.Like) + queryWrapper.addCondition(column, s+"%", constants.Like) + return queryWrapper +} + +func (queryWrapper *QueryWrapper[T]) In(column string, val ...any) Wrapper[T] { + queryWrapper.addCondition(column, val, constants.In) return queryWrapper } func (queryWrapper *QueryWrapper[T]) And() Wrapper[T] { - queryWrapper.Expression = append(queryWrapper.Expression, constants.Eq) + queryWrapper.Conditions = append(queryWrapper.Conditions, constants.Eq) queryWrapper.LastConditionType = constants.Eq return queryWrapper } func (queryWrapper *QueryWrapper[T]) Or() Wrapper[T] { - queryWrapper.Expression = append(queryWrapper.Expression, constants.Or) + queryWrapper.Conditions = append(queryWrapper.Conditions, constants.Or) queryWrapper.LastConditionType = constants.Or return queryWrapper } @@ -104,15 +107,15 @@ type ParamValue struct { value any } -func (queryWrapper *QueryWrapper[T]) setCondition(column string, val any, conditionType string) { +func (queryWrapper *QueryWrapper[T]) addCondition(column string, val any, conditionType string) { - if queryWrapper.LastConditionType != constants.And && queryWrapper.LastConditionType != constants.Or && len(queryWrapper.Expression) > 0 { - queryWrapper.Expression = append(queryWrapper.Expression, constants.And) + if queryWrapper.LastConditionType != constants.And && queryWrapper.LastConditionType != constants.Or && len(queryWrapper.Conditions) > 0 { + queryWrapper.Conditions = append(queryWrapper.Conditions, constants.And) } - queryWrapper.Expression = append(queryWrapper.Expression, column) + queryWrapper.Conditions = append(queryWrapper.Conditions, column) - queryWrapper.Expression = append(queryWrapper.Expression, conditionType) + queryWrapper.Conditions = append(queryWrapper.Conditions, conditionType) - queryWrapper.Expression = append(queryWrapper.Expression, ParamValue{val}) + queryWrapper.Conditions = append(queryWrapper.Conditions, ParamValue{val}) } diff --git a/pkg/mapper/sqlBuilder.go b/pkg/mapper/sqlBuilder.go new file mode 100644 index 0000000..25424f9 --- /dev/null +++ b/pkg/mapper/sqlBuilder.go @@ -0,0 +1,281 @@ +/* + * Licensed to the AcmeStack under one or more contributor license + * agreements. See the NOTICE file distributed with this work for + * additional information regarding copyright ownership. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package mapper + +import ( + "fmt" + "github.com/acmestack/gobatis-plus/pkg/constants" + "github.com/acmestack/godkits/gox/stringsx" + "reflect" + "strconv" + "strings" + "time" +) + +type SqlBuilder[T any] struct { + ParamNameSeq int +} + +func (sqlBuilder *SqlBuilder[T]) BuildSelectSql(queryWrapper *QueryWrapper[T], columns string) (map[string]any, string, string) { + if stringsx.Empty(columns) { + // eg: columnName1,columnName2,columnName3 + columns = sqlBuilder.buildSelectColumns(queryWrapper) + } + + tableName := sqlBuilder.getTableName() + + // eg: columnName1 = #{mapping1} and columnName2 = #{mapping1} + sqlCondition, paramMap := sqlBuilder.buildCondition(queryWrapper.Conditions) + + // eg: SELECT * FROM WHERE columnName = #{mapping1} and columnName = #{mapping1} + sql := sqlBuilder.onBuildSelectSql(columns, tableName, sqlCondition) + + sqlId := sqlBuilder.buildSqlId(constants.SELECT) + return paramMap, sql, sqlId +} + +func (sqlBuilder *SqlBuilder[T]) BuildInsertSql(entity ...T) (map[string]any, string, string) { + tableName := sqlBuilder.getTableName() + + // eg:columnName1,columnName2,columnName3 + columns := sqlBuilder.buildInsertColumns() + + // eg:(#{mapping1},#{mapping2},#{mapping3}) + paramMap, columnMappings := sqlBuilder.buildInsertColumnMapping(entity...) + + sql := sqlBuilder.onBuildInsertSql(tableName, columns, columnMappings) + + sqlId := sqlBuilder.buildSqlId(constants.INSERT) + return paramMap, sql, sqlId +} + +func (sqlBuilder *SqlBuilder[T]) BuildUpdateSql(entity T, updateWrapper *UpdateWrapper[T]) (map[string]any, string, string) { + tableName := sqlBuilder.getTableName() + paramMap, columnMapping := sqlBuilder.buildUpdateColumnMapping(entity) + + // eg: columnName1 = #{mapping1} and columnName2 = #{mapping1} + sqlCondition, paramConditionMap := sqlBuilder.buildCondition(updateWrapper.Conditions) + for k, v := range paramConditionMap { + paramMap[k] = v + } + sql := sqlBuilder.onBuildUpdateSql(tableName, columnMapping, sqlCondition) + + sqlId := sqlBuilder.buildSqlId(constants.UPDATE) + return paramMap, sql, sqlId +} + +func (sqlBuilder *SqlBuilder[T]) BuildDeleteSql(conditions []any) (map[string]any, string, string) { + tableName := sqlBuilder.getTableName() + + conditionMapping, paramMap := sqlBuilder.buildCondition(conditions) + + sql := sqlBuilder.onBuildDeleteSql(tableName, conditionMapping) + + sqlId := sqlBuilder.buildSqlId(constants.DELETE) + return paramMap, sql, sqlId +} + +func (sqlBuilder *SqlBuilder[T]) onBuildSelectSql(columns string, tableName string, sqlCondition string) string { + sql := strings.Replace(constants.SELECT_SQL, constants.COLUMN_HASH, columns, -1) + sql = strings.Replace(sql, constants.TABLE_NAME_HASH, tableName, -1) + sql = strings.Replace(sql, constants.CONDITIONS_HASH, sqlCondition, -1) + if sqlCondition == "" { + sql = strings.Replace(sql, constants.WHERE, "", -1) + } + return sql +} + +func (sqlBuilder *SqlBuilder[T]) onBuildInsertSql(tableName string, columns string, columnMappings []string) string { + sql := strings.Replace(constants.INSERT_SQL, constants.TABLE_NAME_HASH, tableName, -1) + sql = strings.Replace(sql, constants.COLUMN_HASH, columns, -1) + sql = strings.Replace(sql, constants.COLUMN_MAPPING_HASH, columnMappings[0], -1) + + builder := stringsx.Builder{} + builder.JoinString(sql) + for i, columnMapping := range columnMappings { + if i == 0 { + continue + } + builder.JoinString(constants.COMMA + constants.LEFT_BRACKET + columnMapping + constants.RIGHT_BRACKET) + } + return builder.String() +} + +func (sqlBuilder *SqlBuilder[T]) onBuildUpdateSql(tableName string, columnMapping string, sqlCondition string) string { + sql := strings.Replace(constants.UPDATEBYID_SQL, constants.TABLE_NAME_HASH, tableName, -1) + sql = strings.Replace(sql, constants.COLUMN_MAPPING_HASH, columnMapping, -1) + sql = strings.Replace(sql, constants.CONDITIONS_HASH, sqlCondition, -1) + return sql +} + +func (sqlBuilder *SqlBuilder[T]) onBuildDeleteSql(tableName string, conditionMapping string) string { + sql := strings.Replace(constants.DELETEBYID_SQL, constants.TABLE_NAME_HASH, tableName, -1) + sql = strings.Replace(sql, constants.CONDITIONS_HASH, conditionMapping, -1) + return sql +} + +func (sqlBuilder *SqlBuilder[T]) buildSelectColumns(queryWrapper *QueryWrapper[T]) string { + var columns string + if len(queryWrapper.Columns) > 0 { + columns = strings.Join(queryWrapper.Columns, ",") + } else { + columns = constants.ASTERISK + } + return columns +} + +func (sqlBuilder *SqlBuilder[T]) buildInsertColumns() string { + entityType := reflect.TypeOf(new(T)).Elem() + entityTypeNum := entityType.NumField() + var columns []string + for i := 0; i < entityTypeNum; i++ { + tag := entityType.Field(i).Tag + column := tag.Get(constants.COLUMN) + if stringsx.Empty(column) { + continue + } + columns = append(columns, column) + } + return strings.Join(columns, ",") +} + +func (sqlBuilder *SqlBuilder[T]) buildInsertColumnMapping(entities ...T) (map[string]any, []string) { + var paramMap = map[string]any{} + var allColumnMappings []string + for _, entity := range entities { + entityType := reflect.TypeOf(entity) + entityValue := reflect.ValueOf(entity) + entityValueNum := entityValue.NumField() + var columnMappings []string + for i := 0; i < entityValueNum; i++ { + tag := entityType.Field(i).Tag + column := tag.Get(constants.COLUMN) + if column == "" { + continue + } + v := entityValue.Field(i) + mapping := sqlBuilder.getMappingSeq() + switch iv := v.Interface().(type) { + case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64: + paramMap[mapping] = fmt.Sprintf("%d", iv) + case string: + paramMap[mapping] = iv + } + mapping = constants.HASH_LEFT_BRACE + mapping + constants.RIGHT_BRACE + columnMappings = append(columnMappings, mapping) + } + allColumnMappings = append(allColumnMappings, strings.Join(columnMappings, ",")) + } + + return paramMap, allColumnMappings +} + +func (sqlBuilder *SqlBuilder[T]) buildUpdateColumnMapping(entity T) (map[string]any, string) { + entityType := reflect.TypeOf(entity) + entityValue := reflect.ValueOf(entity) + numField := entityType.NumField() + paramMap := map[string]any{} + var columnMappings []string + for i := 0; i < numField; i++ { + tag := entityType.Field(i).Tag + column := tag.Get(constants.COLUMN) + if column == "" || constants.ID == column { + continue + } + fieldValue := entityValue.Field(i).Interface() + var mapping string + switch v := fieldValue.(type) { + case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64: + idStr := fmt.Sprintf("%d", v) + mapping = sqlBuilder.getMappingSeq() + paramMap[mapping] = idStr + case string: + mapping = sqlBuilder.getMappingSeq() + paramMap[mapping] = v + } + var columnMapping = column + constants.Eq + constants.HASH_LEFT_BRACE + mapping + constants.RIGHT_BRACE + columnMappings = append(columnMappings, columnMapping) + } + str := strings.Join(columnMappings, ",") + return paramMap, str +} + +func (sqlBuilder *SqlBuilder[T]) getTableName() string { + entityRef := reflect.TypeOf(new(T)).Elem() + tableNameTag := entityRef.Field(0).Tag + tableName := string(tableNameTag) + return tableName +} + +func (sqlBuilder *SqlBuilder[T]) buildCondition(conditions []any) (string, map[string]any) { + var paramMap = map[string]any{} + build := strings.Builder{} + for _, v := range conditions { + // if v is ParamValue,use #{} to build sql + if paramValue, ok := v.(ParamValue); ok { + rt := reflect.TypeOf(paramValue.value) + rv := reflect.ValueOf(paramValue.value) + + if rt.Kind() == reflect.Slice { + l := rv.Len() + build.WriteString(constants.LEFT_BRACKET) + for i := 0; i < l; i++ { + elemV := rv.Index(i) + if !elemV.CanInterface() { + elemV = reflect.Indirect(elemV) + } + mapping := sqlBuilder.getMappingSeq() + switch iv := elemV.Interface().(type) { + case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64: + paramMap[mapping] = fmt.Sprintf("%d", iv) + case string: + paramMap[mapping] = iv + } + if i != l-1 { + build.WriteString(constants.HASH_LEFT_BRACE + mapping + constants.RIGHT_BRACE + constants.COMMA) + } else { + build.WriteString(constants.HASH_LEFT_BRACE + mapping + constants.RIGHT_BRACE) + } + } + build.WriteString(constants.RIGHT_BRACKET) + } else { + mapping := sqlBuilder.getMappingSeq() + switch iv := paramValue.value.(type) { + case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64: + paramMap[mapping] = fmt.Sprintf("%d", iv) + case string: + paramMap[mapping] = iv + } + build.WriteString(constants.HASH_LEFT_BRACE + mapping + constants.RIGHT_BRACE + constants.SPACE) + } + } else { + build.WriteString(v.(string) + constants.SPACE) + } + } + return build.String(), paramMap +} + +func (sqlBuilder *SqlBuilder[T]) getMappingSeq() string { + sqlBuilder.ParamNameSeq = sqlBuilder.ParamNameSeq + 1 + mapping := constants.MAPPING + strconv.Itoa(sqlBuilder.ParamNameSeq) + return mapping +} + +func (sqlBuilder *SqlBuilder[T]) buildSqlId(sqlType string) string { + sqlId := sqlType + constants.CONNECTION + strconv.Itoa(time.Now().Nanosecond()) + return sqlId +} diff --git a/pkg/mapper/sqlBuilder_test.go b/pkg/mapper/sqlBuilder_test.go new file mode 100644 index 0000000..3b3fb85 --- /dev/null +++ b/pkg/mapper/sqlBuilder_test.go @@ -0,0 +1,184 @@ +/* + * Licensed to the AcmeStack under one or more contributor license + * agreements. See the NOTICE file distributed with this work for + * additional information regarding copyright ownership. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package mapper + +import ( + "fmt" + "github.com/acmestack/gobatis-plus/pkg/constants" + "github.com/stretchr/testify/assert" + "testing" +) + +func TestSqlBuilder_BuildSelectSql(t *testing.T) { + type args struct { + queryWrapper *QueryWrapper[TestTable] + columns string + } + queryWrapper := &QueryWrapper[TestTable]{} + queryWrapper.Eq("username", "acmestack").In("password", "123456", "pw5") + var wantParamMap = make(map[string]any) + wantParamMap["mapping1"] = "acmestack" + wantParamMap["mapping2"] = "123456" + wantParamMap["mapping3"] = "pw5" + tests := []struct { + name string + args args + wantParamMap map[string]any + wantSql string + wantSqlId string + }{ + { + name: "buildSelectSql", + args: args{queryWrapper: queryWrapper, columns: ""}, + wantParamMap: wantParamMap, + wantSql: "SELECT * FROM test_table WHERE username = #{mapping1} and password in (#{mapping2},#{mapping3})", + wantSqlId: "", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + sqlBuilder := &SqlBuilder[TestTable]{} + paramMap, sql, _ := sqlBuilder.BuildSelectSql(tt.args.queryWrapper, tt.args.columns) + assert.Equal(t, tt.wantParamMap, paramMap, "they should be equal") + assert.Equal(t, tt.wantSql, sql, "they should be equal") + }) + } +} + +func TestSqlBuilder_BuildInsertSql(t *testing.T) { + table1 := TestTable{Username: "gobatis", Password: "123456"} + table2 := TestTable{Username: "acmestack", Password: "654321"} + var testTables []TestTable + testTables = append(testTables, table1, table2) + type args struct { + entity []TestTable + } + var wantParamMap = make(map[string]any) + wantParamMap["mapping1"] = "0" + wantParamMap["mapping2"] = "gobatis" + wantParamMap["mapping3"] = "123456" + wantParamMap["mapping4"] = "0" + wantParamMap["mapping5"] = "acmestack" + wantParamMap["mapping6"] = "654321" + + tests := []struct { + name string + args args + wantParamMap map[string]any + wantSql string + wantSqlId string + }{ + { + name: "BuildInsertSql", + args: args{entity: testTables}, + wantParamMap: wantParamMap, + wantSql: "INSERT INTO test_table (id,username,password) VALUES (#{mapping1},#{mapping2},#{mapping3}),(#{mapping4},#{mapping5},#{mapping6})", + wantSqlId: "", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + sqlBuilder := &SqlBuilder[TestTable]{} + paramMap, sql, _ := sqlBuilder.BuildInsertSql(tt.args.entity...) + fmt.Println(paramMap) + fmt.Println(sql) + assert.Equal(t, tt.wantParamMap, paramMap, "they should be equal") + assert.Equal(t, tt.wantSql, sql, "they should be equal") + }) + } +} + +func TestSqlBuilder_BuildUpdateSql(t *testing.T) { + type args struct { + entity TestTable + updateWrapper *UpdateWrapper[TestTable] + } + + var entity = TestTable{Id: 1, Username: "gobatis", Password: "123456"} + updateWrapper := &UpdateWrapper[TestTable]{} + updateWrapper.Eq(constants.ID, 1) + + var wantParamMap = make(map[string]any) + wantParamMap["mapping1"] = "gobatis" + wantParamMap["mapping2"] = "123456" + wantParamMap["mapping3"] = "1" + tests := []struct { + name string + args args + wantParamMap map[string]any + wantSql string + wantSqlId string + }{ + { + name: "BuildUpdateSql", + args: args{entity: entity, updateWrapper: updateWrapper}, + wantParamMap: wantParamMap, + wantSql: "UPDATE test_table SET username=#{mapping1},password=#{mapping2} WHERE id = #{mapping3} ", + wantSqlId: "", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + sqlBuilder := &SqlBuilder[TestTable]{} + paramMap, sql, _ := sqlBuilder.BuildUpdateSql(tt.args.entity, tt.args.updateWrapper) + fmt.Println(paramMap) + fmt.Println(sql) + assert.Equal(t, tt.wantParamMap, paramMap, "they should be equal") + assert.Equal(t, tt.wantSql, sql, "they should be equal") + }) + } +} + +func TestSqlBuilder_BuildDeleteSql(t *testing.T) { + type args struct { + conditions []any + } + + var conditions []any + conditions = append(conditions, constants.ID) + conditions = append(conditions, constants.Eq) + conditions = append(conditions, ParamValue{1}) + + var wantParamMap = make(map[string]any) + wantParamMap["mapping1"] = "1" + + tests := []struct { + name string + args args + wantParamMap map[string]any + wantSql string + wantSqlId string + }{ + { + name: "BuildDeleteSql", + args: args{conditions: conditions}, + wantParamMap: wantParamMap, + wantSql: "delete from test_table where id = #{mapping1} ", + wantSqlId: "", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + sqlBuilder := &SqlBuilder[TestTable]{} + paramMap, sql, _ := sqlBuilder.BuildDeleteSql(tt.args.conditions) + fmt.Println(paramMap) + fmt.Println(sql) + assert.Equal(t, tt.wantParamMap, paramMap, "they should be equal") + assert.Equal(t, tt.wantSql, sql, "they should be equal") + }) + } +} diff --git a/pkg/mapper/update_wrapper.go b/pkg/mapper/update_wrapper.go new file mode 100644 index 0000000..d1d2175 --- /dev/null +++ b/pkg/mapper/update_wrapper.go @@ -0,0 +1,120 @@ +/* + * Licensed to the AcmeStack under one or more contributor license + * agreements. See the NOTICE file distributed with this work for + * additional information regarding copyright ownership. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package mapper + +import "github.com/acmestack/gobatis-plus/pkg/constants" + +type UpdateWrapper[T any] struct { + Columns []string + ValuesMap map[string]any + Conditions []any + LastConditionType string +} + +func (updateWrapper *UpdateWrapper[T]) Set(column string, val any) Wrapper[T] { + updateWrapper.ValuesMap[column] = val + return updateWrapper +} + +func (updateWrapper *UpdateWrapper[T]) Eq(column string, val any) Wrapper[T] { + updateWrapper.addCondition(column, val, constants.Eq) + return updateWrapper +} + +func (updateWrapper *UpdateWrapper[T]) Ne(column string, val any) Wrapper[T] { + updateWrapper.addCondition(column, val, constants.Ne) + return updateWrapper +} + +func (updateWrapper *UpdateWrapper[T]) Gt(column string, val any) Wrapper[T] { + updateWrapper.addCondition(column, val, constants.Gt) + return updateWrapper +} + +func (updateWrapper *UpdateWrapper[T]) Ge(column string, val any) Wrapper[T] { + updateWrapper.addCondition(column, val, constants.Ge) + return updateWrapper +} + +func (updateWrapper *UpdateWrapper[T]) Lt(column string, val any) Wrapper[T] { + updateWrapper.addCondition(column, val, constants.Lt) + return updateWrapper +} + +func (updateWrapper *UpdateWrapper[T]) Le(column string, val any) Wrapper[T] { + updateWrapper.addCondition(column, val, constants.Le) + return updateWrapper +} + +func (updateWrapper *UpdateWrapper[T]) Like(column string, val any) Wrapper[T] { + s := val.(string) + updateWrapper.addCondition(column, "%"+s+"%", constants.Like) + return updateWrapper +} + +func (updateWrapper *UpdateWrapper[T]) NotLike(column string, val any) Wrapper[T] { + s := val.(string) + updateWrapper.addCondition(column, "%"+s+"%", constants.Not+constants.Like) + return updateWrapper +} + +func (updateWrapper *UpdateWrapper[T]) LikeLeft(column string, val any) Wrapper[T] { + s := val.(string) + updateWrapper.addCondition(column, "%"+s, constants.Like) + return updateWrapper +} + +func (updateWrapper *UpdateWrapper[T]) LikeRight(column string, val any) Wrapper[T] { + s := val.(string) + updateWrapper.addCondition(column, s+"%", constants.Like) + return updateWrapper +} + +func (updateWrapper *UpdateWrapper[T]) In(column string, val ...any) Wrapper[T] { + updateWrapper.addCondition(column, val, constants.In) + return updateWrapper +} + +func (updateWrapper *UpdateWrapper[T]) And() Wrapper[T] { + updateWrapper.Conditions = append(updateWrapper.Conditions, constants.Eq) + updateWrapper.LastConditionType = constants.Eq + return updateWrapper +} + +func (updateWrapper *UpdateWrapper[T]) Or() Wrapper[T] { + updateWrapper.Conditions = append(updateWrapper.Conditions, constants.Or) + updateWrapper.LastConditionType = constants.Or + return updateWrapper +} + +func (updateWrapper *UpdateWrapper[T]) Select(columns ...string) Wrapper[T] { + updateWrapper.Columns = append(updateWrapper.Columns, columns...) + return updateWrapper +} + +func (updateWrapper *UpdateWrapper[T]) addCondition(column string, val any, conditionType string) { + + if updateWrapper.LastConditionType != constants.And && updateWrapper.LastConditionType != constants.Or && len(updateWrapper.Conditions) > 0 { + updateWrapper.Conditions = append(updateWrapper.Conditions, constants.And) + } + + updateWrapper.Conditions = append(updateWrapper.Conditions, column) + + updateWrapper.Conditions = append(updateWrapper.Conditions, conditionType) + + updateWrapper.Conditions = append(updateWrapper.Conditions, ParamValue{val}) +} diff --git a/pkg/mapper/wrapper.go b/pkg/mapper/wrapper.go index 823e156..50e131d 100644 --- a/pkg/mapper/wrapper.go +++ b/pkg/mapper/wrapper.go @@ -38,6 +38,8 @@ type Wrapper[T any] interface { LikeRight(column string, val1 any) Wrapper[T] + In(column string, val ...any) Wrapper[T] + And() Wrapper[T] Or() Wrapper[T]