From 5675b29a52e066a1e86aa80d54274df81ff7cf08 Mon Sep 17 00:00:00 2001 From: HEXINGZE <2046084122@qq.com> Date: Sun, 22 Dec 2024 01:59:42 +0800 Subject: [PATCH] pr725 bugfix --- ...sertonduplicate_update_undo_log_builder.go | 115 ++++++++++-------- ...nduplicate_update_undo_log_builder_test.go | 33 +++++ 2 files changed, 97 insertions(+), 51 deletions(-) diff --git a/pkg/datasource/sql/undo/builder/mysql_insertonduplicate_update_undo_log_builder.go b/pkg/datasource/sql/undo/builder/mysql_insertonduplicate_update_undo_log_builder.go index ff7bcde89..2758c9814 100644 --- a/pkg/datasource/sql/undo/builder/mysql_insertonduplicate_update_undo_log_builder.go +++ b/pkg/datasource/sql/undo/builder/mysql_insertonduplicate_update_undo_log_builder.go @@ -97,10 +97,7 @@ func (u *MySQLInsertOnDuplicateUndoLogBuilder) buildBeforeImageSQL(insertStmt *a if err := checkDuplicateKeyUpdate(insertStmt, metaData); err != nil { return "", nil, err } - - // Reset primary keys map u.BeforeImageSqlPrimaryKeys = make(map[string]bool, len(metaData.Indexs)) - pkIndexMap := u.getPkIndex(insertStmt, metaData) var pkIndexArray []int for _, val := range pkIndexMap { @@ -120,74 +117,71 @@ func (u *MySQLInsertOnDuplicateUndoLogBuilder) buildBeforeImageSQL(insertStmt *a hasPK := false for _, index := range metaData.Indexs { if strings.EqualFold("PRIMARY", index.Name) { - hasPK = true + allPKColumnsHaveValue := true + for _, col := range index.Columns { + if params, ok := paramMap[col.ColumnName]; !ok || len(params) == 0 || params[0] == nil { + allPKColumnsHaveValue = false + break + } + } + hasPK = allPKColumnsHaveValue break } } if !hasPK { - return "", nil, nil + hasValidUniqueIndex := false + for _, index := range metaData.Indexs { + if !index.NonUnique && !strings.EqualFold("PRIMARY", index.Name) { + if _, _, valid := validateIndexPrefix(index, paramMap, 0); valid { + hasValidUniqueIndex = true + break + } + } + } + if !hasValidUniqueIndex { + return "", nil, nil + } } var sql strings.Builder sql.WriteString("SELECT * FROM " + metaData.TableName + " ") - var selectArgs []driver.Value isContainWhere := false hasConditions := false for i := 0; i < len(insertRows); i++ { - var rowConditions = make([]string, 0, cap(insertRows[i])) - var rowArgs = make([]driver.Value, 0, cap(insertRows[i])) - usedParams := make(map[string]bool, len(paramMap)) + var rowConditions []string + var rowArgs []driver.Value + usedParams := make(map[string]bool) + // First try unique indexes for _, index := range metaData.Indexs { if index.NonUnique || strings.EqualFold("PRIMARY", index.Name) { continue } - if !isIndexValueNotNull(index, paramMap, i) { - continue - } - var indexConditions []string - var indexArgs []driver.Value - allColumnsPresent := true - for _, colMeta := range index.Columns { - columnName := colMeta.ColumnName - if params, ok := paramMap[columnName]; ok && len(params) > i && params[i] != nil { - indexConditions = append(indexConditions, columnName+" = ? ") - indexArgs = append(indexArgs, params[i]) - usedParams[columnName] = true - } else if colMeta.ColumnDef != nil { - indexConditions = append(indexConditions, columnName+" = DEFAULT("+columnName+")") - } else { - allColumnsPresent = false - break - } - } - if allColumnsPresent && len(indexConditions) > 0 { - rowConditions = append(rowConditions, "("+strings.Join(indexConditions, " and ")+")") - rowArgs = append(rowArgs, indexArgs...) + if conditions, args, valid := validateIndexPrefix(index, paramMap, i); valid { + rowConditions = append(rowConditions, "("+strings.Join(conditions, " and ")+")") + rowArgs = append(rowArgs, args...) hasConditions = true + for _, colMeta := range index.Columns { + usedParams[colMeta.ColumnName] = true + } } } + // Then try primary key for _, index := range metaData.Indexs { if !strings.EqualFold("PRIMARY", index.Name) { continue } - var pkConditions []string - var pkArgs []driver.Value - for _, colMeta := range index.Columns { - columnName := colMeta.ColumnName - u.BeforeImageSqlPrimaryKeys[columnName] = true - if params, ok := paramMap[columnName]; ok && len(params) > i && params[i] != nil && !usedParams[columnName] { - pkConditions = append(pkConditions, columnName+" = ? ") - pkArgs = append(pkArgs, params[i]) - } - } - if len(pkConditions) > 0 { - rowConditions = append(rowConditions, "("+strings.Join(pkConditions, " and ")+")") - rowArgs = append(rowArgs, pkArgs...) + if conditions, args, valid := validateIndexPrefix(index, paramMap, i); valid { + rowConditions = append(rowConditions, "("+strings.Join(conditions, " and ")+")") + rowArgs = append(rowArgs, args...) hasConditions = true + for _, colMeta := range index.Columns { + usedParams[colMeta.ColumnName] = true + } } } + if len(rowConditions) > 0 { if !isContainWhere { sql.WriteString("WHERE ") @@ -195,12 +189,7 @@ func (u *MySQLInsertOnDuplicateUndoLogBuilder) buildBeforeImageSQL(insertStmt *a } else { sql.WriteString(" OR ") } - for j, condition := range rowConditions { - if j > 0 { - sql.WriteString(" OR ") - } - sql.WriteString(condition + " ") - } + sql.WriteString(strings.Join(rowConditions, " OR ") + " ") selectArgs = append(selectArgs, rowArgs...) } } @@ -210,7 +199,6 @@ func (u *MySQLInsertOnDuplicateUndoLogBuilder) buildBeforeImageSQL(insertStmt *a sqlStr := sql.String() log.Infof("build select sql by insert on duplicate sourceQuery, sql: %s", sqlStr) return sqlStr, selectArgs, nil - } func (u *MySQLInsertOnDuplicateUndoLogBuilder) AfterImage(ctx context.Context, execCtx *types.ExecContext, beforeImages []*types.RecordImage) ([]*types.RecordImage, error) { @@ -371,3 +359,28 @@ func isIndexValueNotNull(indexMeta types.IndexMeta, imageParameterMap map[string } return true } + +func validateIndexPrefix(index types.IndexMeta, paramMap map[string][]driver.Value, rowIndex int) ([]string, []driver.Value, bool) { + var indexConditions []string + var indexArgs []driver.Value + if len(index.Columns) > 1 { + for _, colMeta := range index.Columns { + params, ok := paramMap[colMeta.ColumnName] + if !ok || len(params) <= rowIndex || params[rowIndex] == nil { + return nil, nil, false + } + } + } + for _, colMeta := range index.Columns { + columnName := colMeta.ColumnName + params, ok := paramMap[columnName] + if ok && len(params) > rowIndex && params[rowIndex] != nil { + indexConditions = append(indexConditions, columnName+" = ? ") + indexArgs = append(indexArgs, params[rowIndex]) + } + } + if len(indexConditions) != len(index.Columns) { + return nil, nil, false + } + return indexConditions, indexArgs, true +} diff --git a/pkg/datasource/sql/undo/builder/mysql_insertonduplicate_update_undo_log_builder_test.go b/pkg/datasource/sql/undo/builder/mysql_insertonduplicate_update_undo_log_builder_test.go index 07f5e87bd..4e831f4cd 100644 --- a/pkg/datasource/sql/undo/builder/mysql_insertonduplicate_update_undo_log_builder_test.go +++ b/pkg/datasource/sql/undo/builder/mysql_insertonduplicate_update_undo_log_builder_test.go @@ -173,6 +173,39 @@ func TestInsertOnDuplicateBuildBeforeImageSQL(t *testing.T) { expectQuery1: "", expectQueryArgs1: nil, }, + // Test case for composite index with all columns + { + name: "composite_index_full", + execCtx: &types.ExecContext{ + Query: "insert into t_user(id, name, age) values(?,?,?) on duplicate key update other = ?", + MetaDataMap: map[string]types.TableMeta{"t_user": tableMeta1}, + }, + sourceQueryArgs: []driver.Value{1, "Jack", 25, "other"}, + expectQuery1: "SELECT * FROM t_user WHERE (name = ? and age = ? ) OR (id = ? ) ", + expectQueryArgs1: []driver.Value{"Jack", 25, 1}, + }, + // Test case for composite index with null value + { + name: "composite_index_with_null", + execCtx: &types.ExecContext{ + Query: "insert into t_user(id, name, age) values(?,?,?) on duplicate key update other = ?", + MetaDataMap: map[string]types.TableMeta{"t_user": tableMeta1}, + }, + sourceQueryArgs: []driver.Value{1, "Jack", nil, "other"}, + expectQuery1: "SELECT * FROM t_user WHERE (id = ? ) ", + expectQueryArgs1: []driver.Value{1}, + }, + // Test case for composite index with leftmost prefix only + { + name: "composite_index_leftmost_prefix", + execCtx: &types.ExecContext{ + Query: "insert into t_user(id, name) values(?,?) on duplicate key update other = ?", + MetaDataMap: map[string]types.TableMeta{"t_user": tableMeta1}, + }, + sourceQueryArgs: []driver.Value{1, "Jack", "other"}, + expectQuery1: "SELECT * FROM t_user WHERE (id = ? ) ", + expectQueryArgs1: []driver.Value{1}, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) {