From 4c22d004e43625aeabedb0bbffa0be43c38c18de Mon Sep 17 00:00:00 2001 From: Vivek Yadav Date: Fri, 22 Nov 2024 15:26:50 +0530 Subject: [PATCH 1/3] Added the backend changes for check constraints --- internal/convert.go | 1 + internal/helpers.go | 7 +- internal/mapping.go | 13 ++ internal/reports/report_helpers.go | 7 + schema/schema.go | 38 ++-- sources/common/infoschema.go | 23 +-- sources/common/toddl.go | 32 +++- sources/common/toddl_test.go | 30 ++++ sources/mysql/infoschema.go | 100 +++++++++-- sources/mysql/infoschema_test.go | 115 ++++++++++++ spanner/ddl/ast.go | 55 ++++-- spanner/ddl/ast_test.go | 61 +++++-- webv2/api/schema.go | 103 +++++++++++ webv2/api/schema_test.go | 272 +++++++++++++++++++++++++++++ webv2/routes.go | 4 +- webv2/table/review_table_schema.go | 9 +- webv2/table/update_table_schema.go | 10 ++ 17 files changed, 809 insertions(+), 71 deletions(-) diff --git a/internal/convert.go b/internal/convert.go index 6b089053d..81339bda9 100644 --- a/internal/convert.go +++ b/internal/convert.go @@ -128,6 +128,7 @@ const ( SequenceCreated ForeignKeyActionNotSupported NumericPKNotSupported + TypeMismatch ) const ( diff --git a/internal/helpers.go b/internal/helpers.go index 8690dd10c..d4f1de43d 100644 --- a/internal/helpers.go +++ b/internal/helpers.go @@ -26,7 +26,7 @@ import ( type Counter struct { counterMutex sync.Mutex - ObjectId string + ObjectId string } var Cntr Counter @@ -65,6 +65,11 @@ func GenerateForeignkeyId() string { func GenerateIndexesId() string { return GenerateId("i") } + +func GenerateCheckConstrainstId() string { + return GenerateId("ck") +} + func GenerateRuleId() string { return GenerateId("r") } diff --git a/internal/mapping.go b/internal/mapping.go index 0eb9bd055..1b955bdba 100644 --- a/internal/mapping.go +++ b/internal/mapping.go @@ -243,6 +243,19 @@ func ToSpannerIndexName(conv *Conv, srcIndexName string) string { return getSpannerValidName(conv, srcIndexName) } +// Note that the check constraints names in spanner have to be globally unique +// (across the database). But in some source databases, such as MySQL, +// they only have to be unique for a table. Hence we must map each source +// constraint name to a unique spanner constraint name. +func ToSpannerCheckConstraintName(conv *Conv, srcCheckConstraintName string) string { + return getSpannerValidName(conv, srcCheckConstraintName) +} + +func GetSpannerValidExpression(cks []ddl.Checkconstraint) []ddl.Checkconstraint { + // TODO validate the check constraints data with batch verification then send back + return cks +} + // conv.UsedNames tracks Spanner names that have been used for table names, foreign key constraints // and indexes. We use this to ensure we generate unique names when // we map from source dbs to Spanner since Spanner requires all these names to be diff --git a/internal/reports/report_helpers.go b/internal/reports/report_helpers.go index a066f0cde..34cb58511 100644 --- a/internal/reports/report_helpers.go +++ b/internal/reports/report_helpers.go @@ -403,6 +403,13 @@ func buildTableReportBody(conv *internal.Conv, tableId string, issues map[string Description: fmt.Sprintf("UNIQUE constraint on column(s) '%s' replaced with primary key since table '%s' didn't have one. Spanner requires a primary key for every table", strings.Join(uniquePK, ", "), conv.SpSchema[tableId].Name), } l = append(l, toAppend) + case internal.TypeMismatch: + toAppend := Issue{ + Category: IssueDB[i].Category, + Description: fmt.Sprintf("Table '%s': Type mismatch in '%s'column affecting check constraints. Verify data type compatibility with constraint logic", conv.SpSchema[tableId].Name, conv.SpSchema[tableId].ColDefs[colId].Name), + } + l = append(l, toAppend) + default: toAppend := Issue{ Category: IssueDB[i].Category, diff --git a/schema/schema.go b/schema/schema.go index 48b125bba..6bade2a4e 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -35,26 +35,27 @@ import ( // Table represents a database table. type Table struct { - Name string - Schema string - ColIds []string // List of column Ids (for predictable iteration order e.g. printing). - ColDefs map[string]Column // Details of columns. - ColNameIdMap map[string]string `json:"-"` // Computed every time just after conv is generated or after any column renaming - PrimaryKeys []Key - ForeignKeys []ForeignKey - Indexes []Index - Id string + Name string + Schema string + ColIds []string // List of column Ids (for predictable iteration order e.g. printing). + ColDefs map[string]Column // Details of columns. + ColNameIdMap map[string]string `json:"-"` // Computed every time just after conv is generated or after any column renaming + PrimaryKeys []Key + ForeignKeys []ForeignKey + CheckConstraints []CheckConstraints + Indexes []Index + Id string } // Column represents a database column. // TODO: add support for foreign keys. type Column struct { - Name string - Type Type - NotNull bool - Ignored Ignored - Id string - AutoGen ddl.AutoGenCol + Name string + Type Type + NotNull bool + Ignored Ignored + Id string + AutoGen ddl.AutoGenCol } // ForeignKey represents a foreign key. @@ -76,6 +77,13 @@ type ForeignKey struct { Id string } +// CheckConstraints represents a Check Constrainst. +type CheckConstraints struct { + Name string + Expr string + Id string +} + // Key respresents a primary key or index key. type Key struct { ColId string diff --git a/sources/common/infoschema.go b/sources/common/infoschema.go index 5144d2613..da2612895 100644 --- a/sources/common/infoschema.go +++ b/sources/common/infoschema.go @@ -36,7 +36,7 @@ type InfoSchema interface { GetColumns(conv *internal.Conv, table SchemaAndName, constraints map[string][]string, primaryKeys []string) (map[string]schema.Column, []string, error) GetRowsFromTable(conv *internal.Conv, srcTable string) (interface{}, error) GetRowCount(table SchemaAndName) (int64, error) - GetConstraints(conv *internal.Conv, table SchemaAndName) ([]string, map[string][]string, error) + GetConstraints(conv *internal.Conv, table SchemaAndName) ([]string, []schema.CheckConstraints, map[string][]string, error) GetForeignKeys(conv *internal.Conv, table SchemaAndName) (foreignKeys []schema.ForeignKey, err error) GetIndexes(conv *internal.Conv, table SchemaAndName, colNameIdMp map[string]string) ([]schema.Index, error) ProcessData(conv *internal.Conv, tableId string, srcSchema schema.Table, spCols []string, spSchema ddl.CreateTable, additionalAttributes internal.AdditionalDataAttributes) error @@ -185,7 +185,7 @@ func (is *InfoSchemaImpl) processTable(conv *internal.Conv, table SchemaAndName, var t schema.Table fmt.Println("processing schema for table", table) tblId := internal.GenerateTableId() - primaryKeys, constraints, err := infoSchema.GetConstraints(conv, table) + primaryKeys, checkConstraints, constraints, err := infoSchema.GetConstraints(conv, table) if err != nil { return t, fmt.Errorf("couldn't get constraints for table %s.%s: %s", table.Schema, table.Name, err) } @@ -215,15 +215,16 @@ func (is *InfoSchemaImpl) processTable(conv *internal.Conv, table SchemaAndName, schemaPKeys = append(schemaPKeys, schema.Key{ColId: colNameIdMap[k]}) } t = schema.Table{ - Id: tblId, - Name: name, - Schema: table.Schema, - ColIds: colIds, - ColNameIdMap: colNameIdMap, - ColDefs: colDefs, - PrimaryKeys: schemaPKeys, - Indexes: indexes, - ForeignKeys: foreignKeys} + Id: tblId, + Name: name, + Schema: table.Schema, + ColIds: colIds, + ColNameIdMap: colNameIdMap, + ColDefs: colDefs, + PrimaryKeys: schemaPKeys, + CheckConstraints: checkConstraints, + Indexes: indexes, + ForeignKeys: foreignKeys} return t, nil } diff --git a/sources/common/toddl.go b/sources/common/toddl.go index 16f4cf3ce..bf4379288 100644 --- a/sources/common/toddl.go +++ b/sources/common/toddl.go @@ -167,14 +167,15 @@ func (ss *SchemaToSpannerImpl) SchemaToSpannerDDLHelper(conv *internal.Conv, tod } comment := "Spanner schema for source table " + quoteIfNeeded(srcTable.Name) conv.SpSchema[srcTable.Id] = ddl.CreateTable{ - Name: spTableName, - ColIds: spColIds, - ColDefs: spColDef, - PrimaryKeys: cvtPrimaryKeys(srcTable.PrimaryKeys), - ForeignKeys: cvtForeignKeys(conv, spTableName, srcTable.Id, srcTable.ForeignKeys, isRestore), - Indexes: cvtIndexes(conv, srcTable.Id, srcTable.Indexes, spColIds, spColDef), - Comment: comment, - Id: srcTable.Id} + Name: spTableName, + ColIds: spColIds, + ColDefs: spColDef, + PrimaryKeys: cvtPrimaryKeys(srcTable.PrimaryKeys), + ForeignKeys: cvtForeignKeys(conv, spTableName, srcTable.Id, srcTable.ForeignKeys, isRestore), + CheckConstraint: cvtCheckConstraint(conv, srcTable.CheckConstraints), + Indexes: cvtIndexes(conv, srcTable.Id, srcTable.Indexes, spColIds, spColDef), + Comment: comment, + Id: srcTable.Id} return nil } @@ -234,6 +235,21 @@ func cvtForeignKeys(conv *internal.Conv, spTableName string, srcTableId string, return spKeys } +func cvtCheckConstraint(conv *internal.Conv, srcKeys []schema.CheckConstraints) []ddl.Checkconstraint { + var spcks []ddl.Checkconstraint + + for _, cks := range srcKeys { + spcks = append(spcks, ddl.Checkconstraint{ + Id: cks.Id, + Name: internal.ToSpannerCheckConstraintName(conv, cks.Name), + Expr: cks.Expr, + }) + + } + + return internal.GetSpannerValidExpression(spcks) +} + func CvtForeignKeysHelper(conv *internal.Conv, spTableName string, srcTableId string, srcKey schema.ForeignKey, isRestore bool) (ddl.Foreignkey, error) { if len(srcKey.ColIds) != len(srcKey.ReferColumnIds) { conv.Unexpected(fmt.Sprintf("ConvertForeignKeys: ColIds and referColumns don't have the same lengths: len(columns)=%d, len(referColumns)=%d for source tableId: %s, referenced table: %s", len(srcKey.ColIds), len(srcKey.ReferColumnIds), srcTableId, srcKey.ReferTableId)) diff --git a/sources/common/toddl_test.go b/sources/common/toddl_test.go index 63ecad71b..fd1504a84 100644 --- a/sources/common/toddl_test.go +++ b/sources/common/toddl_test.go @@ -428,3 +428,33 @@ func Test_SchemaToSpannerSequenceHelper(t *testing.T) { assert.Equal(t, expectedConv, conv) } } +func Test_cvtCheckContraint(t *testing.T) { + + conv := internal.MakeConv() + srcSchema := []schema.CheckConstraints{ + { + Id: "ck1", + Name: "check_1", + Expr: "age > 0", + }, + { + Id: "ck1", + Name: "check_2", + Expr: "age < 99", + }, + } + spSchema := []ddl.Checkconstraint{ + { + Id: "ck1", + Name: "check_1", + Expr: "age > 0", + }, + { + Id: "ck1", + Name: "check_2", + Expr: "age < 99", + }, + } + result := cvtCheckConstraint(conv, srcSchema) + assert.Equal(t, spSchema, result) +} diff --git a/sources/mysql/infoschema.go b/sources/mysql/infoschema.go index 87458188f..7f7a69028 100644 --- a/sources/mysql/infoschema.go +++ b/sources/mysql/infoschema.go @@ -199,7 +199,7 @@ func (isi InfoSchemaImpl) GetColumns(conv *internal.Conv, table common.SchemaAnd // We've already filtered out PRIMARY KEY. switch c { case "CHECK": - ignored.Check = true + case "FOREIGN KEY", "PRIMARY KEY", "UNIQUE": // Nothing to do here -- these are all handled elsewhere. } @@ -237,29 +237,107 @@ func (isi InfoSchemaImpl) GetColumns(conv *internal.Conv, table common.SchemaAnd // other constraints. Note: we need to preserve ordinal order of // columns in primary key constraints. // Note that foreign key constraints are handled in getForeignKeys. -func (isi InfoSchemaImpl) GetConstraints(conv *internal.Conv, table common.SchemaAndName) ([]string, map[string][]string, error) { +func (isi InfoSchemaImpl) GetConstraints(conv *internal.Conv, table common.SchemaAndName) ([]string, []schema.CheckConstraints, map[string][]string, error) { q := `SELECT k.COLUMN_NAME, t.CONSTRAINT_TYPE FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS AS t INNER JOIN INFORMATION_SCHEMA.KEY_COLUMN_USAGE AS k ON t.CONSTRAINT_NAME = k.CONSTRAINT_NAME AND t.CONSTRAINT_SCHEMA = k.CONSTRAINT_SCHEMA AND t.TABLE_NAME=k.TABLE_NAME WHERE k.TABLE_SCHEMA = ? AND k.TABLE_NAME = ? ORDER BY k.ordinal_position;` - rows, err := isi.Db.Query(q, table.Schema, table.Name) + + q1 := `SELECT + COALESCE(k.COLUMN_NAME, '') AS COLUMN_NAME, + t.CONSTRAINT_NAME, + t.CONSTRAINT_TYPE, + COALESCE(c.CHECK_CLAUSE, '') AS CHECK_CLAUSE + FROM + INFORMATION_SCHEMA.TABLE_CONSTRAINTS AS t + LEFT JOIN + INFORMATION_SCHEMA.KEY_COLUMN_USAGE AS k + ON t.CONSTRAINT_NAME = k.CONSTRAINT_NAME + AND t.CONSTRAINT_SCHEMA = k.CONSTRAINT_SCHEMA + AND t.TABLE_NAME = k.TABLE_NAME + LEFT JOIN + INFORMATION_SCHEMA.CHECK_CONSTRAINTS AS c + ON t.CONSTRAINT_NAME = c.CONSTRAINT_NAME + WHERE + t.TABLE_SCHEMA = ? + AND t.TABLE_NAME = ? + ORDER BY k.ORDINAL_POSITION; + ` + checkQuery := `SELECT COUNT(*) + FROM INFORMATION_SCHEMA.TABLES + WHERE TABLE_SCHEMA = 'INFORMATION_SCHEMA' + AND TABLE_NAME = 'CHECK_CONSTRAINTS';` + var tableExistsCount int + rows1, err := isi.Db.Query(checkQuery) if err != nil { - return nil, nil, err + return nil, nil, nil, err + } + for rows1.Next() { + err1 := rows1.Scan(&tableExistsCount) + if err1 != nil { + conv.Unexpected(fmt.Sprintf("Can't scan: %v", err)) + return nil, nil, nil, err + } + } + + defer rows1.Close() + + tableExists := tableExistsCount > 0 + + var finalQuery string + if tableExists { + finalQuery = q1 + } else { + finalQuery = q + } + + rows, err := isi.Db.Query(finalQuery, table.Schema, table.Name) + + if err != nil { + return nil, nil, nil, err } defer rows.Close() var primaryKeys []string - var col, constraint string + var checkKeys []schema.CheckConstraints + var col, constraintName, constraint, checkClause string m := make(map[string][]string) for rows.Next() { - err := rows.Scan(&col, &constraint) - if err != nil { - conv.Unexpected(fmt.Sprintf("Can't scan: %v", err)) - continue + if tableExists { + err := rows.Scan(&col, &constraintName, &constraint, &checkClause) + if err != nil { + conv.Unexpected(fmt.Sprintf("Can't scan: %v", err)) + continue + } + } else { + err := rows.Scan(&col, &constraintName, &constraint, &checkClause) + if err != nil { + conv.Unexpected(fmt.Sprintf("Can't scan: %v", err)) + continue + } } if col == "" || constraint == "" { - conv.Unexpected(fmt.Sprintf("Got empty col or constraint")) + + if tableExists { + if constraintName == "" || checkClause == "" { + conv.Unexpected(fmt.Sprintf("Got empty constraintName or checkClause")) + continue + } + switch constraint { + case "CHECK": + checkClause = strings.ReplaceAll(checkClause, "_utf8mb4\\", "") + checkClause = strings.ReplaceAll(checkClause, "\\", "") + + checkKeys = append(checkKeys, schema.CheckConstraints{Name: constraintName, Expr: string(checkClause), Id: internal.GenerateCheckConstrainstId()}) + default: + m[col] = append(m[col], constraint) + } + } else { + conv.Unexpected(fmt.Sprintf("Got empty col or constraint")) + } + continue + } switch constraint { case "PRIMARY KEY": @@ -268,7 +346,7 @@ func (isi InfoSchemaImpl) GetConstraints(conv *internal.Conv, table common.Schem m[col] = append(m[col], constraint) } } - return primaryKeys, m, nil + return primaryKeys, checkKeys, m, nil } // GetForeignKeys return list all the foreign keys constraints. diff --git a/sources/mysql/infoschema_test.go b/sources/mysql/infoschema_test.go index a61c04580..ffc7e460b 100644 --- a/sources/mysql/infoschema_test.go +++ b/sources/mysql/infoschema_test.go @@ -17,6 +17,7 @@ package mysql import ( "database/sql" "database/sql/driver" + "fmt" "testing" "github.com/DATA-DOG/go-sqlmock" @@ -530,3 +531,117 @@ func mkMockDB(t *testing.T, ms []mockSpec) *sql.DB { } return db } +func TestGetConstraints(t *testing.T) { + + case1 := []mockSpec{ + { + query: `SELECT COUNT\(\*\) + FROM INFORMATION_SCHEMA.TABLES + WHERE TABLE_SCHEMA = 'INFORMATION_SCHEMA' + AND TABLE_NAME = 'CHECK_CONSTRAINTS'; + `, + cols: []string{"COUNT"}, + rows: [][]driver.Value{{1}}, + }, + { + query: `(?i)SELECT\s+COALESCE\(k.COLUMN_NAME,\s*''\)\s+AS\s+COLUMN_NAME,\s+t\.CONSTRAINT_NAME,\s+t\.CONSTRAINT_TYPE,\s+COALESCE\(c.CHECK_CLAUSE,\s*''\)\s+AS\s+CHECK_CLAUSE\s+FROM\s+INFORMATION_SCHEMA\.TABLE_CONSTRAINTS\s+AS\s+t\s+LEFT\s+JOIN\s+INFORMATION_SCHEMA\.KEY_COLUMN_USAGE\s+AS\s+k\s+ON\s+t\.CONSTRAINT_NAME\s*=\s*k\.CONSTRAINT_NAME\s+AND\s+t\.CONSTRAINT_SCHEMA\s*=\s*k\.CONSTRAINT_SCHEMA\s+AND\s+t\.TABLE_NAME\s*=\s*k\.TABLE_NAME\s+LEFT\s+JOIN\s+INFORMATION_SCHEMA\.CHECK_CONSTRAINTS\s+AS\s+c\s+ON\s+t\.CONSTRAINT_NAME\s*=\s*c\.CONSTRAINT_NAME\s+WHERE\s+t\.TABLE_SCHEMA\s*=\s*\?\s+AND\s+t\.TABLE_NAME\s*=\s*\?\s*ORDER\s+BY\s+k\.ORDINAL_POSITION`, + args: []driver.Value{"test_schema", "test_table"}, + cols: []string{"COLUMN_NAME", "CONSTRAINT_NAME", "CONSTRAINT_TYPE", "CHECK_CLAUSE"}, + rows: [][]driver.Value{{"id", "PRIMARY", "PRIMARY KEY", ""}, {"", "chk_test", "CHECK", "amount > 0"}}, + }, + } + + case2 := []mockSpec{ + { + query: `SELECT COUNT\(\*\) + FROM INFORMATION_SCHEMA.TABLES + WHERE TABLE_SCHEMA = 'INFORMATION_SCHEMA' + AND TABLE_NAME = 'CHECK_CONSTRAINTS'; + `, + cols: []string{"COUNT"}, + rows: [][]driver.Value{{0}}, + }, + { + query: `(?i)SELECT\s+k\.COLUMN_NAME,\s+t\.CONSTRAINT_TYPE\s+FROM\s+INFORMATION_SCHEMA\.TABLE_CONSTRAINTS\s+AS\s+t\s+INNER\s+JOIN\s+INFORMATION_SCHEMA\.KEY_COLUMN_USAGE\s+AS\s+k\s+ON\s+t\.CONSTRAINT_NAME\s*=\s*k\.CONSTRAINT_NAME\s+AND\s+t\.CONSTRAINT_SCHEMA\s*=\s*k\.CONSTRAINT_SCHEMA\s+AND\s+t\.TABLE_NAME\s*=\s*k\.TABLE_NAME\s+WHERE\s+k\.TABLE_SCHEMA\s*=\s*\?\s+AND\s+k\.TABLE_NAME\s*=\s*\?\s*ORDER\s+BY\s+k\.ORDINAL_POSITION;`, + args: []driver.Value{"test_schema", "test_table"}, + cols: []string{"COLUMN_NAME", "CONSTRAINT_NAME", "CONSTRAINT_TYPE", "CHECK_CLAUSE"}, + rows: [][]driver.Value{{"id", "PRIMARY", "PRIMARY KEY", ""}}, + }, + } + + cases := []struct { + db []mockSpec + tableExists bool + }{ + { + db: case1, + tableExists: true, + }, + { + db: case2, + tableExists: false, + }, + } + + for _, tc := range cases { + if tc.tableExists { + db := mkMockDB(t, tc.db) + + defer db.Close() + + isi := InfoSchemaImpl{Db: db} + + table := common.SchemaAndName{ + Schema: "test_schema", + Name: "test_table", + } + + conv := new(internal.Conv) + + primaryKeys, checkKeys, constraints, err := isi.GetConstraints(conv, table) + if err != nil { + t.Fatalf("expected no error, but got %v", err) + } + + expectedPrimaryKeys := []string{"id"} + if fmt.Sprintf("%v", primaryKeys) != fmt.Sprintf("%v", expectedPrimaryKeys) { + t.Errorf("expected %v, got %v for primary keys", expectedPrimaryKeys, primaryKeys) + } + + expectedCheckKeys := []schema.CheckConstraints{ + {Name: "chk_test", Expr: "amount > 0", Id: "ck1"}, + } + + assert.Equal(t, expectedCheckKeys, checkKeys) + assert.Equal(t, expectedPrimaryKeys, primaryKeys) + assert.Empty(t, constraints) + } else { + db := mkMockDB(t, tc.db) + + defer db.Close() + + isi := InfoSchemaImpl{Db: db} + + table := common.SchemaAndName{ + Schema: "test_schema", + Name: "test_table", + } + + conv := new(internal.Conv) + + primaryKeys, checkKeys, constraints, err := isi.GetConstraints(conv, table) + if err != nil { + t.Fatalf("expected no error, but got %v", err) + } + + expectedPrimaryKeys := []string{"id"} + if fmt.Sprintf("%v", primaryKeys) != fmt.Sprintf("%v", expectedPrimaryKeys) { + t.Errorf("expected %v, got %v for primary keys", expectedPrimaryKeys, primaryKeys) + } + + assert.Equal(t, expectedPrimaryKeys, primaryKeys) + assert.Empty(t, checkKeys) + assert.Empty(t, constraints) + } + } +} diff --git a/spanner/ddl/ast.go b/spanner/ddl/ast.go index 9490a89e3..d38db2105 100644 --- a/spanner/ddl/ast.go +++ b/spanner/ddl/ast.go @@ -264,6 +264,12 @@ type IndexKey struct { Order int } +type Checkconstraint struct { + Id string + Name string + Expr string +} + // PrintPkOrIndexKey unparses the primary or index keys. func (idx IndexKey) PrintPkOrIndexKey(ct CreateTable, c Config) string { col := c.quote(ct.ColDefs[idx.ColId].Name) @@ -318,16 +324,17 @@ func (k Foreignkey) PrintForeignKey(c Config) string { // // create_table: CREATE TABLE table_name ([column_def, ...] ) primary_key [, cluster] type CreateTable struct { - Name string - ColIds []string // Provides names and order of columns - ShardIdColumn string - ColDefs map[string]ColumnDef // Provides definition of columns (a map for simpler/faster lookup during type processing) - PrimaryKeys []IndexKey - ForeignKeys []Foreignkey - Indexes []CreateIndex - ParentTable InterleavedParent //if not empty, this table will be interleaved - Comment string - Id string + Name string + ColIds []string // Provides names and order of columns + ShardIdColumn string + ColDefs map[string]ColumnDef // Provides definition of columns (a map for simpler/faster lookup during type processing) + PrimaryKeys []IndexKey + ForeignKeys []Foreignkey + Indexes []CreateIndex + ParentTable InterleavedParent //if not empty, this table will be interleaved + CheckConstraint []Checkconstraint + Comment string + Id string } // PrintCreateTable unparses a CREATE TABLE statement. @@ -381,13 +388,20 @@ func (ct CreateTable) PrintCreateTable(spSchema Schema, config Config) string { } } + var checkString string + if len(ct.CheckConstraint) != 0 { + checkString = PrintCheckConstraintTable(ct.CheckConstraint) + } else { + checkString = "" + } + if len(keys) == 0 { - return fmt.Sprintf("%sCREATE TABLE %s (\n%s) %s", tableComment, config.quote(ct.Name), cols, interleave) + return fmt.Sprintf("%sCREATE TABLE %s (\n%s %s) %s", tableComment, config.quote(ct.Name), cols, checkString, interleave) } if config.SpDialect == constants.DIALECT_POSTGRESQL { return fmt.Sprintf("%sCREATE TABLE %s (\n%s\tPRIMARY KEY (%s)\n)%s", tableComment, config.quote(ct.Name), cols, strings.Join(keys, ", "), interleave) } - return fmt.Sprintf("%sCREATE TABLE %s (\n%s) PRIMARY KEY (%s)%s", tableComment, config.quote(ct.Name), cols, strings.Join(keys, ", "), interleave) + return fmt.Sprintf("%sCREATE TABLE %s (\n%s %s) PRIMARY KEY (%s)%s", tableComment, config.quote(ct.Name), cols, checkString, strings.Join(keys, ", "), interleave) } // CreateIndex encodes the following DDL definition: @@ -494,6 +508,23 @@ func (k Foreignkey) PrintForeignKeyAlterTable(spannerSchema Schema, c Config, ta return s } +// PrintCheckConstraintTable unparses the check constraints using CHECK CONSTRAINTS. +func PrintCheckConstraintTable(cks []Checkconstraint) string { + + var s string + s = "" + for index, col := range cks { + if index == len(cks)-1 { + s = s + fmt.Sprintf("\tCONSTRAINT %s CHECK %s\n", col.Name, col.Expr) + } else { + s = s + fmt.Sprintf("\tCONSTRAINT %s CHECK %s,\n", col.Name, col.Expr) + } + + } + + return s +} + // Schema stores a map of table names and Tables. type Schema map[string]CreateTable diff --git a/spanner/ddl/ast_test.go b/spanner/ddl/ast_test.go index d6e53f96b..f55c22ed7 100644 --- a/spanner/ddl/ast_test.go +++ b/spanner/ddl/ast_test.go @@ -143,6 +143,10 @@ func TestPrintCreateTable(t *testing.T) { }, PrimaryKeys: []IndexKey{{ColId: "col1", Desc: true}}, ForeignKeys: nil, + CheckConstraint: []Checkconstraint{ + {Id: "ck1", Name: "check_1", Expr: "(age > 18)"}, + {Id: "ck2", Name: "check_2", Expr: "(age < 99)"}, + }, Indexes: nil, ParentTable: InterleavedParent{}, Comment: "", @@ -156,12 +160,13 @@ func TestPrintCreateTable(t *testing.T) { "col4": {Name: "col4", T: Type{Name: Int64}, NotNull: true}, "col5": {Name: "col5", T: Type{Name: String, Len: MaxLength}, NotNull: false}, }, - PrimaryKeys: []IndexKey{{ColId: "col4", Desc: true}}, - ForeignKeys: nil, - Indexes: nil, - ParentTable: InterleavedParent{Id: "t1", OnDelete: constants.FK_CASCADE}, - Comment: "", - Id: "t2", + PrimaryKeys: []IndexKey{{ColId: "col4", Desc: true}}, + ForeignKeys: nil, + Indexes: nil, + CheckConstraint: nil, + ParentTable: InterleavedParent{Id: "t1", OnDelete: constants.FK_CASCADE}, + Comment: "", + Id: "t2", }, "t3": CreateTable{ Name: "table3", @@ -170,12 +175,33 @@ func TestPrintCreateTable(t *testing.T) { ColDefs: map[string]ColumnDef{ "col6": {Name: "col6", T: Type{Name: Int64}, NotNull: true}, }, - PrimaryKeys: []IndexKey{{ColId: "col6", Desc: true}}, + PrimaryKeys: []IndexKey{{ColId: "col6", Desc: true}}, + ForeignKeys: nil, + Indexes: nil, + CheckConstraint: nil, + ParentTable: InterleavedParent{Id: "t1", OnDelete: ""}, + Comment: "", + Id: "t3", + }, + "t4": CreateTable{ + Name: "table1", + ColIds: []string{"col1", "col2", "col3"}, + ShardIdColumn: "", + ColDefs: map[string]ColumnDef{ + "col1": {Name: "col1", T: Type{Name: Int64}, NotNull: true}, + "col2": {Name: "col2", T: Type{Name: String, Len: MaxLength}, NotNull: false}, + "col3": {Name: "col3", T: Type{Name: Bytes, Len: int64(42)}, NotNull: false}, + }, + PrimaryKeys: nil, ForeignKeys: nil, + CheckConstraint: []Checkconstraint{ + {Id: "ck1", Name: "check_1", Expr: "(age > 18)"}, + {Id: "ck2", Name: "check_2", Expr: "(age < 99)"}, + }, Indexes: nil, - ParentTable: InterleavedParent{Id: "t1", OnDelete: ""}, + ParentTable: InterleavedParent{}, Comment: "", - Id: "t3", + Id: "t1", }, } tests := []struct { @@ -191,7 +217,8 @@ func TestPrintCreateTable(t *testing.T) { "CREATE TABLE table1 (\n" + " col1 INT64 NOT NULL ,\n" + " col2 STRING(MAX),\n" + - " col3 BYTES(42),\n" + + " col3 BYTES(42),\n " + + "CONSTRAINT check_1 CHECK (age > 18),\nCONSTRAINT check_2 CHECK (age < 99)\n" + ") PRIMARY KEY (col1 DESC)", }, { @@ -201,7 +228,8 @@ func TestPrintCreateTable(t *testing.T) { "CREATE TABLE `table1` (\n" + " `col1` INT64 NOT NULL ,\n" + " `col2` STRING(MAX),\n" + - " `col3` BYTES(42),\n" + + " `col3` BYTES(42),\n " + + "CONSTRAINT check_1 CHECK (age > 18),\nCONSTRAINT check_2 CHECK (age < 99)\n" + ") PRIMARY KEY (`col1` DESC)", }, { @@ -223,6 +251,17 @@ func TestPrintCreateTable(t *testing.T) { ") PRIMARY KEY (col6 DESC),\n" + "INTERLEAVE IN PARENT table1", }, + { + "no quote", + false, + s["t4"], + "CREATE TABLE table1 (\n" + + " col1 INT64 NOT NULL ,\n" + + " col2 STRING(MAX),\n" + + " col3 BYTES(42),\n " + + "CONSTRAINT check_1 CHECK (age > 18),\nCONSTRAINT check_2 CHECK (age < 99)\n" + + ") ", + }, } for _, tc := range tests { assert.Equal(t, tc.expected, tc.ct.PrintCreateTable(s, Config{ProtectIds: tc.protectIds})) diff --git a/webv2/api/schema.go b/webv2/api/schema.go index f8b8ca3d3..c1ea6d2be 100644 --- a/webv2/api/schema.go +++ b/webv2/api/schema.go @@ -488,6 +488,109 @@ func RestoreSecondaryIndex(w http.ResponseWriter, r *http.Request) { } +// UpdateCheckConstraint processes the request to update spanner table check constraints, ensuring session and schema validity, and responds with the updated conversion metadata. +func UpdateCheckConstraint(w http.ResponseWriter, r *http.Request) { + tableId := r.FormValue("table") + reqBody, err := ioutil.ReadAll(r.Body) + if err != nil { + http.Error(w, fmt.Sprintf("Body Read Error : %v", err), http.StatusInternalServerError) + } + sessionState := session.GetSessionState() + if sessionState.Conv == nil || sessionState.Driver == "" { + http.Error(w, fmt.Sprintf("Schema is not converted or Driver is not configured properly. Please retry converting the database to Spanner."), http.StatusNotFound) + return + } + sessionState.Conv.ConvLock.Lock() + defer sessionState.Conv.ConvLock.Unlock() + + newCKs := []ddl.Checkconstraint{} + if err = json.Unmarshal(reqBody, &newCKs); err != nil { + http.Error(w, fmt.Sprintf("Request Body parse error : %v", err), http.StatusBadRequest) + return + } + + sp := sessionState.Conv.SpSchema[tableId] + + sp.CheckConstraint = newCKs + + sessionState.Conv.SpSchema[tableId] = sp + + session.UpdateSessionFile() + + convm := session.ConvWithMetadata{ + SessionMetadata: sessionState.SessionMetadata, + Conv: *sessionState.Conv, + } + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(convm) + +} + +func doesNameExist(spcks []ddl.Checkconstraint, targetName string) bool { + for _, spck := range spcks { + if strings.Contains(spck.Expr, targetName) { + return true + } + } + return false +} + +// ValidateCheckConstraint verifies if the type of a database column has been altered and add an error if a change is detected. +func ValidateCheckConstraint(w http.ResponseWriter, r *http.Request) { + sessionState := session.GetSessionState() + if sessionState.Conv == nil || sessionState.Driver == "" { + http.Error(w, fmt.Sprintf("Schema is not converted or Driver is not configured properly. Please retry converting the database to Spanner."), http.StatusNotFound) + return + } + sessionState.Conv.ConvLock.Lock() + defer sessionState.Conv.ConvLock.Unlock() + + sp := sessionState.Conv.SpSchema + srcschema := sessionState.Conv.SrcSchema + + flag := true + + schemaissue := []internal.SchemaIssue{} + + for _, src := range srcschema { + + for _, col := range sp[src.Id].ColDefs { + + if len(sp[src.Id].CheckConstraint) != 0 { + spType := col.T.Name + srcType := srcschema[src.Id].ColDefs[col.Id].Type + + actualType := mysqlDefaultTypeMap[srcType.Name] + + if actualType.Name != spType { + + columnName := sp[src.Id].ColDefs[col.Id].Name + spcks := sp[src.Id].CheckConstraint + if doesNameExist(spcks, columnName) { + flag = false + + schemaissue = sessionState.Conv.SchemaIssues[src.Id].ColumnLevelIssues[col.Id] + + if !utilities.IsSchemaIssuePresent(schemaissue, internal.TypeMismatch) { + schemaissue = append(schemaissue, internal.TypeMismatch) + } + + sessionState.Conv.SchemaIssues[src.Id].ColumnLevelIssues[col.Id] = schemaissue + + break + + } + } + } + + } + + } + + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(flag) +} + // renameForeignKeys checks the new names for spanner name validity, ensures the new names are already not used by existing tables // secondary indexes or foreign key constraints. If above checks passed then foreignKey renaming reflected in the schema else appropriate // error thrown. diff --git a/webv2/api/schema_test.go b/webv2/api/schema_test.go index 9e7894ea8..c1ca74042 100644 --- a/webv2/api/schema_test.go +++ b/webv2/api/schema_test.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/json" "fmt" + "io" "net/http" "net/http/httptest" "strings" @@ -2541,3 +2542,274 @@ func TestGetAutoGenMapMySQL(t *testing.T) { } } +func TestUpdateCheckConstraint(t *testing.T) { + sessionState := session.GetSessionState() + sessionState.Driver = constants.MYSQL + sessionState.Conv = internal.MakeConv() + + tableID := "table1" + + expectedCheckConstraint := []ddl.Checkconstraint{ + {Id: "ck1", Name: "check_1", Expr: "(age > 18)"}, + {Id: "ck2", Name: "check_2", Expr: "(age < 99)"}, + } + + checkConstraints := []schema.CheckConstraints{ + {Id: "ck1", Name: "check_1", Expr: "(age > 18)"}, + {Id: "ck2", Name: "check_2", Expr: "(age < 99)"}, + } + + body, err := json.Marshal(checkConstraints) + assert.NoError(t, err) + + req, err := http.NewRequest("POST", "update/cks", bytes.NewBuffer(body)) + assert.NoError(t, err) + + q := req.URL.Query() + q.Add("table", tableID) + req.URL.RawQuery = q.Encode() + + rr := httptest.NewRecorder() + + handler := http.HandlerFunc(api.UpdateCheckConstraint) + + handler.ServeHTTP(rr, req) + + assert.Equal(t, http.StatusOK, rr.Code) + + updatedSp := sessionState.Conv.SpSchema[tableID] + + assert.Equal(t, expectedCheckConstraint, updatedSp.CheckConstraint) +} + +func TestUpdateCheckConstraint_ParseError(t *testing.T) { + sessionState := session.GetSessionState() + sessionState.Driver = constants.MYSQL + sessionState.Conv = internal.MakeConv() + + invalidJSON := "invalid json body" + + rr := httptest.NewRecorder() + req, err := http.NewRequest("POST", "update/cks", io.NopCloser(strings.NewReader(invalidJSON))) + assert.NoError(t, err) + + handler := http.HandlerFunc(api.UpdateCheckConstraint) + handler.ServeHTTP(rr, req) + + assert.Equal(t, http.StatusBadRequest, rr.Code) + + expectedErrorMessage := "Request Body parse error" + assert.Contains(t, rr.Body.String(), expectedErrorMessage) +} + +func (errReader) Read(p []byte) (n int, err error) { + return 0, fmt.Errorf("simulated read error") +} + +func TestUpdateCheckConstraint_ImproperSession(t *testing.T) { + sessionState := session.GetSessionState() + sessionState.Conv = nil // Simulate no conversion + + rr := httptest.NewRecorder() + req, err := http.NewRequest("POST", "update/cks", io.NopCloser(errReader{})) + assert.NoError(t, err) + + handler := http.HandlerFunc(api.UpdateCheckConstraint) + handler.ServeHTTP(rr, req) + + assert.Equal(t, http.StatusInternalServerError, rr.Code) + assert.Contains(t, rr.Body.String(), "Schema is not converted or Driver is not configured properly") + +} + +func TestValidateCheckConstraint_ImproperSession(t *testing.T) { + sessionState := session.GetSessionState() + sessionState.Conv = nil // Simulate no conversion + + rr := httptest.NewRecorder() + req := httptest.NewRequest("GET", "/validateCheckConstraint", nil) + + handler := http.HandlerFunc(api.ValidateCheckConstraint) + handler.ServeHTTP(rr, req) + + assert.Equal(t, http.StatusNotFound, rr.Code) + assert.Contains(t, rr.Body.String(), "Schema is not converted or Driver is not configured properly") + +} + +func TestValidateCheckConstraint_NoTypeMismatch(t *testing.T) { + sessionState := session.GetSessionState() + sessionState.Driver = constants.MYSQL + sessionState.Conv = internal.MakeConv() + + buildConvMySQL_NoTypeMatch(sessionState.Conv) + rr1 := httptest.NewRecorder() + req1 := httptest.NewRequest("GET", "/spannerDefaultTypeMap", nil) + + handler1 := http.HandlerFunc(api.SpannerDefaultTypeMap) + handler1.ServeHTTP(rr1, req1) + + rr := httptest.NewRecorder() + req := httptest.NewRequest("GET", "/validateCheckConstraint", nil) + + handler := http.HandlerFunc(api.ValidateCheckConstraint) + handler.ServeHTTP(rr, req) + + assert.Equal(t, http.StatusOK, rr.Code) + var responseFlag bool + json.NewDecoder(rr.Body).Decode(&responseFlag) + assert.True(t, responseFlag) +} + +func TestValidateCheckConstraint_TypeMismatch(t *testing.T) { + sessionState := session.GetSessionState() + sessionState.Driver = constants.MYSQL + sessionState.Conv = internal.MakeConv() + + buildConvMySQL_TypeMatch(sessionState.Conv) + + rr1 := httptest.NewRecorder() + req1 := httptest.NewRequest("GET", "/spannerDefaultTypeMap", nil) + + handler1 := http.HandlerFunc(api.SpannerDefaultTypeMap) + handler1.ServeHTTP(rr1, req1) + + rr := httptest.NewRecorder() + req := httptest.NewRequest("GET", "/validateCheckConstraint", nil) + + handler := http.HandlerFunc(api.ValidateCheckConstraint) + handler.ServeHTTP(rr, req) + + assert.Equal(t, http.StatusOK, rr.Code) + var responseFlag bool + json.NewDecoder(rr.Body).Decode(&responseFlag) + assert.False(t, responseFlag) + issues := sessionState.Conv.SchemaIssues["t1"].ColumnLevelIssues["c2"] + assert.Contains(t, issues, internal.TypeMismatch) +} + +func buildConvMySQL_NoTypeMatch(conv *internal.Conv) { + conv.SrcSchema = map[string]schema.Table{ + "t1": { + Name: "table1", + Id: "t1", + ColIds: []string{"c1", "c2", "c3"}, + CheckConstraints: []schema.CheckConstraints{ + { + Id: "ck1", + Name: "check_1", + Expr: "age > 0", + }, + { + Id: "ck1", + Name: "check_2", + Expr: "age < 99", + }, + }, + ColDefs: map[string]schema.Column{ + "c1": {Name: "a", Id: "c1", Type: schema.Type{Name: "json"}}, + "c2": {Name: "b", Id: "c2", Type: schema.Type{Name: "decimal"}}, + "c3": {Name: "c", Id: "c3", Type: schema.Type{Name: "datetime"}}, + }, + PrimaryKeys: []schema.Key{{ColId: "c1"}}}, + } + conv.SpSchema = map[string]ddl.CreateTable{ + "t1": { + Name: "table1", + Id: "t1", + ColIds: []string{"c1", "c2", "c3"}, + CheckConstraint: []ddl.Checkconstraint{ + { + Id: "ck1", + Name: "check_1", + Expr: "age > 0", + }, + { + Id: "ck1", + Name: "check_2", + Expr: "age < 99", + }, + }, + ColDefs: map[string]ddl.ColumnDef{ + "c1": {Name: "a", Id: "c1", T: ddl.Type{Name: ddl.JSON}}, + "c2": {Name: "b", Id: "c2", T: ddl.Type{Name: ddl.Numeric}}, + "c3": {Name: "c", Id: "c3", T: ddl.Type{Name: ddl.Timestamp}}, + }, + PrimaryKeys: []ddl.IndexKey{{ColId: "c1"}}, + }, + } + + conv.SchemaIssues = map[string]internal.TableIssues{ + "t1": { + ColumnLevelIssues: map[string][]internal.SchemaIssue{ + "c1": {internal.Widened}, + "c2": {internal.Time}, + }, + }, + } + conv.SyntheticPKeys["t2"] = internal.SyntheticPKey{"c20", 0} + conv.Audit.MigrationType = migration.MigrationData_SCHEMA_AND_DATA.Enum() +} + +func buildConvMySQL_TypeMatch(conv *internal.Conv) { + conv.SrcSchema = map[string]schema.Table{ + "t1": { + Name: "table1", + Id: "t1", + ColIds: []string{"c1", "c2", "c3"}, + CheckConstraints: []schema.CheckConstraints{ + { + Id: "ck1", + Name: "check_1", + Expr: "age > 0", + }, + { + Id: "ck1", + Name: "check_2", + Expr: "age < 99", + }, + }, + ColDefs: map[string]schema.Column{ + "c1": {Name: "a", Id: "c1", Type: schema.Type{Name: "json"}}, + "c2": {Name: "age", Id: "c2", Type: schema.Type{Name: "decimal"}}, + "c3": {Name: "c", Id: "c3", Type: schema.Type{Name: "datetime"}}, + }, + PrimaryKeys: []schema.Key{{ColId: "c1"}}}, + } + conv.SpSchema = map[string]ddl.CreateTable{ + "t1": { + Name: "table1", + Id: "t1", + ColIds: []string{"c1", "c2", "c3"}, + CheckConstraint: []ddl.Checkconstraint{ + { + Id: "ck1", + Name: "check_1", + Expr: "age > 0", + }, + { + Id: "ck1", + Name: "check_2", + Expr: "age < 99", + }, + }, + ColDefs: map[string]ddl.ColumnDef{ + "c1": {Name: "a", Id: "c1", T: ddl.Type{Name: ddl.JSON}}, + "c2": {Name: "age", Id: "c2", T: ddl.Type{Name: ddl.String}}, + "c3": {Name: "c", Id: "c3", T: ddl.Type{Name: ddl.Timestamp}}, + }, + PrimaryKeys: []ddl.IndexKey{{ColId: "c1"}}, + }, + } + + conv.SchemaIssues = map[string]internal.TableIssues{ + "t1": { + ColumnLevelIssues: map[string][]internal.SchemaIssue{ + "c1": {internal.Widened}, + "c2": {internal.Time}, + }, + }, + } + conv.SyntheticPKeys["t2"] = internal.SyntheticPKey{"c20", 0} + conv.Audit.MigrationType = migration.MigrationData_SCHEMA_AND_DATA.Enum() +} diff --git a/webv2/routes.go b/webv2/routes.go index bdfb3419c..bcb02d980 100644 --- a/webv2/routes.go +++ b/webv2/routes.go @@ -45,7 +45,7 @@ func getRoutes() *mux.Router { } ctx := context.Background() - spClient, _:= spinstanceadmin.NewInstanceAdminClientImpl(ctx) + spClient, _ := spinstanceadmin.NewInstanceAdminClientImpl(ctx) dsClient, _ := ds.NewDatastreamClientImpl(ctx) storageclient, _ := storageclient.NewStorageClientImpl(ctx) validateResourceImpl := conversion.NewValidateResourcesImpl(&spanneraccessor.SpannerAccessorImpl{}, spClient, &datastream_accessor.DatastreamAccessorImpl{}, @@ -76,6 +76,7 @@ func getRoutes() *mux.Router { router.HandleFunc("/spannerDefaultTypeMap", api.SpannerDefaultTypeMap).Methods("GET") router.HandleFunc("/autoGenMap", api.GetAutoGenMap).Methods("GET") router.HandleFunc("/getSequenceKind", api.GetSequenceKind).Methods("GET") + router.HandleFunc("/validateCheckConstraint", api.ValidateCheckConstraint).Methods("GET") router.HandleFunc("/setparent", api.SetParentTable).Methods("GET") router.HandleFunc("/removeParent", api.RemoveParentTable).Methods("POST") @@ -93,6 +94,7 @@ func getRoutes() *mux.Router { router.HandleFunc("/UpdateSequence", api.UpdateSequence).Methods("POST") router.HandleFunc("/update/fks", api.UpdateForeignKeys).Methods("POST") + router.HandleFunc("/update/cks", api.UpdateCheckConstraint).Methods("POST") router.HandleFunc("/update/indexes", api.UpdateIndexes).Methods("POST") // Session Management diff --git a/webv2/table/review_table_schema.go b/webv2/table/review_table_schema.go index 0ad6fb0ec..5a5d7c31b 100644 --- a/webv2/table/review_table_schema.go +++ b/webv2/table/review_table_schema.go @@ -108,6 +108,13 @@ func ReviewTableSchema(w http.ResponseWriter, r *http.Request) { return } } + oldName := conv.SrcSchema[tableId].ColDefs[colId].Name + + for i := range conv.SpSchema[tableId].CheckConstraint { + originalString := conv.SpSchema[tableId].CheckConstraint[i].Expr + updatedValue := strings.ReplaceAll(originalString, oldName, v.Rename) + conv.SpSchema[tableId].CheckConstraint[i].Expr = updatedValue + } interleaveTableSchema = reviewRenameColumn(v.Rename, tableId, colId, conv, interleaveTableSchema) @@ -148,7 +155,7 @@ func ReviewTableSchema(w http.ResponseWriter, r *http.Request) { } } - if !v.Removed && !v.Add && v.Rename== ""{ + if !v.Removed && !v.Add && v.Rename == "" { sequences := UpdateAutoGenCol(v.AutoGen, tableId, colId, conv) conv.SpSequences = sequences } diff --git a/webv2/table/update_table_schema.go b/webv2/table/update_table_schema.go index 6e03c709c..76eaeca3a 100644 --- a/webv2/table/update_table_schema.go +++ b/webv2/table/update_table_schema.go @@ -19,6 +19,7 @@ import ( "fmt" "io/ioutil" "net/http" + "strings" "github.com/GoogleCloudPlatform/spanner-migration-tool/internal" "github.com/GoogleCloudPlatform/spanner-migration-tool/sources/common" @@ -55,6 +56,7 @@ type updateTable struct { // (3) Rename column. // (4) Add or Remove NotNull constraint. // (5) Update Spanner type. +// (6) Update Check constraints Name. func UpdateTableSchema(w http.ResponseWriter, r *http.Request) { reqBody, err := ioutil.ReadAll(r.Body) @@ -96,6 +98,14 @@ func UpdateTableSchema(w http.ResponseWriter, r *http.Request) { if v.Rename != "" && v.Rename != conv.SpSchema[tableId].ColDefs[colId].Name { + oldName := conv.SrcSchema[tableId].ColDefs[colId].Name + + for i := range conv.SpSchema[tableId].CheckConstraint { + originalString := conv.SpSchema[tableId].CheckConstraint[i].Expr + updatedValue := strings.ReplaceAll(originalString, oldName, v.Rename) + conv.SpSchema[tableId].CheckConstraint[i].Expr = updatedValue + } + renameColumn(v.Rename, tableId, colId, conv) } From 2fe71f42ffd954652e233043a355a68cca18e299 Mon Sep 17 00:00:00 2001 From: Vivek Yadav Date: Fri, 22 Nov 2024 18:15:13 +0530 Subject: [PATCH 2/3] rename the constraint name --- internal/mapping.go | 2 +- sources/common/toddl.go | 6 +++--- sources/common/toddl_test.go | 2 +- spanner/ddl/ast.go | 6 +++--- spanner/ddl/ast_test.go | 4 ++-- webv2/api/schema.go | 4 ++-- 6 files changed, 12 insertions(+), 12 deletions(-) diff --git a/internal/mapping.go b/internal/mapping.go index 1b955bdba..93f9f7f37 100644 --- a/internal/mapping.go +++ b/internal/mapping.go @@ -251,7 +251,7 @@ func ToSpannerCheckConstraintName(conv *Conv, srcCheckConstraintName string) str return getSpannerValidName(conv, srcCheckConstraintName) } -func GetSpannerValidExpression(cks []ddl.Checkconstraint) []ddl.Checkconstraint { +func GetSpannerValidExpression(cks []ddl.CheckConstraint) []ddl.CheckConstraint { // TODO validate the check constraints data with batch verification then send back return cks } diff --git a/sources/common/toddl.go b/sources/common/toddl.go index bf4379288..67b313521 100644 --- a/sources/common/toddl.go +++ b/sources/common/toddl.go @@ -235,11 +235,11 @@ func cvtForeignKeys(conv *internal.Conv, spTableName string, srcTableId string, return spKeys } -func cvtCheckConstraint(conv *internal.Conv, srcKeys []schema.CheckConstraints) []ddl.Checkconstraint { - var spcks []ddl.Checkconstraint +func cvtCheckConstraint(conv *internal.Conv, srcKeys []schema.CheckConstraints) []ddl.CheckConstraint { + var spcks []ddl.CheckConstraint for _, cks := range srcKeys { - spcks = append(spcks, ddl.Checkconstraint{ + spcks = append(spcks, ddl.CheckConstraint{ Id: cks.Id, Name: internal.ToSpannerCheckConstraintName(conv, cks.Name), Expr: cks.Expr, diff --git a/sources/common/toddl_test.go b/sources/common/toddl_test.go index fd1504a84..74ac72202 100644 --- a/sources/common/toddl_test.go +++ b/sources/common/toddl_test.go @@ -443,7 +443,7 @@ func Test_cvtCheckContraint(t *testing.T) { Expr: "age < 99", }, } - spSchema := []ddl.Checkconstraint{ + spSchema := []ddl.CheckConstraint{ { Id: "ck1", Name: "check_1", diff --git a/spanner/ddl/ast.go b/spanner/ddl/ast.go index d38db2105..366222bc4 100644 --- a/spanner/ddl/ast.go +++ b/spanner/ddl/ast.go @@ -264,7 +264,7 @@ type IndexKey struct { Order int } -type Checkconstraint struct { +type CheckConstraint struct { Id string Name string Expr string @@ -332,7 +332,7 @@ type CreateTable struct { ForeignKeys []Foreignkey Indexes []CreateIndex ParentTable InterleavedParent //if not empty, this table will be interleaved - CheckConstraint []Checkconstraint + CheckConstraint []CheckConstraint Comment string Id string } @@ -509,7 +509,7 @@ func (k Foreignkey) PrintForeignKeyAlterTable(spannerSchema Schema, c Config, ta } // PrintCheckConstraintTable unparses the check constraints using CHECK CONSTRAINTS. -func PrintCheckConstraintTable(cks []Checkconstraint) string { +func PrintCheckConstraintTable(cks []CheckConstraint) string { var s string s = "" diff --git a/spanner/ddl/ast_test.go b/spanner/ddl/ast_test.go index f55c22ed7..e424354b4 100644 --- a/spanner/ddl/ast_test.go +++ b/spanner/ddl/ast_test.go @@ -143,7 +143,7 @@ func TestPrintCreateTable(t *testing.T) { }, PrimaryKeys: []IndexKey{{ColId: "col1", Desc: true}}, ForeignKeys: nil, - CheckConstraint: []Checkconstraint{ + CheckConstraint: []CheckConstraint{ {Id: "ck1", Name: "check_1", Expr: "(age > 18)"}, {Id: "ck2", Name: "check_2", Expr: "(age < 99)"}, }, @@ -194,7 +194,7 @@ func TestPrintCreateTable(t *testing.T) { }, PrimaryKeys: nil, ForeignKeys: nil, - CheckConstraint: []Checkconstraint{ + CheckConstraint: []CheckConstraint{ {Id: "ck1", Name: "check_1", Expr: "(age > 18)"}, {Id: "ck2", Name: "check_2", Expr: "(age < 99)"}, }, diff --git a/webv2/api/schema.go b/webv2/api/schema.go index c1ea6d2be..08ceca5ec 100644 --- a/webv2/api/schema.go +++ b/webv2/api/schema.go @@ -503,7 +503,7 @@ func UpdateCheckConstraint(w http.ResponseWriter, r *http.Request) { sessionState.Conv.ConvLock.Lock() defer sessionState.Conv.ConvLock.Unlock() - newCKs := []ddl.Checkconstraint{} + newCKs := []ddl.CheckConstraint{} if err = json.Unmarshal(reqBody, &newCKs); err != nil { http.Error(w, fmt.Sprintf("Request Body parse error : %v", err), http.StatusBadRequest) return @@ -526,7 +526,7 @@ func UpdateCheckConstraint(w http.ResponseWriter, r *http.Request) { } -func doesNameExist(spcks []ddl.Checkconstraint, targetName string) bool { +func doesNameExist(spcks []ddl.CheckConstraint, targetName string) bool { for _, spck := range spcks { if strings.Contains(spck.Expr, targetName) { return true From 32f40ddfa0413fe627e61aae439469bca1288e9f Mon Sep 17 00:00:00 2001 From: taher-cc Date: Wed, 27 Nov 2024 17:52:06 +0530 Subject: [PATCH 3/3] Code cleanup --- schema/schema.go | 6 +- sources/common/infoschema.go | 2 +- sources/common/toddl.go | 20 ++-- sources/common/toddl_test.go | 3 +- sources/dynamodb/schema.go | 6 +- sources/dynamodb/schema_test.go | 4 +- sources/mysql/infoschema.go | 180 ++++++++++++----------------- sources/mysql/infoschema_test.go | 3 +- sources/oracle/infoschema.go | 6 +- sources/postgres/infoschema.go | 6 +- sources/spanner/infoschema.go | 8 +- sources/sqlserver/infoschema.go | 6 +- spanner/ddl/ast.go | 50 ++++---- spanner/ddl/ast_test.go | 80 +++++++++---- webv2/api/schema.go | 34 ++---- webv2/api/schema_test.go | 16 +-- webv2/table/review_table_schema.go | 6 +- webv2/table/update_table_schema.go | 6 +- 18 files changed, 218 insertions(+), 224 deletions(-) diff --git a/schema/schema.go b/schema/schema.go index 6bade2a4e..4b18ca777 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -42,7 +42,7 @@ type Table struct { ColNameIdMap map[string]string `json:"-"` // Computed every time just after conv is generated or after any column renaming PrimaryKeys []Key ForeignKeys []ForeignKey - CheckConstraints []CheckConstraints + CheckConstraints []CheckConstraint Indexes []Index Id string } @@ -77,8 +77,8 @@ type ForeignKey struct { Id string } -// CheckConstraints represents a Check Constrainst. -type CheckConstraints struct { +// CheckConstraints represents a check constraint defined in the schema. +type CheckConstraint struct { Name string Expr string Id string diff --git a/sources/common/infoschema.go b/sources/common/infoschema.go index da2612895..998d61e3b 100644 --- a/sources/common/infoschema.go +++ b/sources/common/infoschema.go @@ -36,7 +36,7 @@ type InfoSchema interface { GetColumns(conv *internal.Conv, table SchemaAndName, constraints map[string][]string, primaryKeys []string) (map[string]schema.Column, []string, error) GetRowsFromTable(conv *internal.Conv, srcTable string) (interface{}, error) GetRowCount(table SchemaAndName) (int64, error) - GetConstraints(conv *internal.Conv, table SchemaAndName) ([]string, []schema.CheckConstraints, map[string][]string, error) + GetConstraints(conv *internal.Conv, table SchemaAndName) ([]string, []schema.CheckConstraint, map[string][]string, error) GetForeignKeys(conv *internal.Conv, table SchemaAndName) (foreignKeys []schema.ForeignKey, err error) GetIndexes(conv *internal.Conv, table SchemaAndName, colNameIdMp map[string]string) ([]schema.Index, error) ProcessData(conv *internal.Conv, tableId string, srcSchema schema.Table, spCols []string, spSchema ddl.CreateTable, additionalAttributes internal.AdditionalDataAttributes) error diff --git a/sources/common/toddl.go b/sources/common/toddl.go index 67b313521..74e0c7c50 100644 --- a/sources/common/toddl.go +++ b/sources/common/toddl.go @@ -167,15 +167,15 @@ func (ss *SchemaToSpannerImpl) SchemaToSpannerDDLHelper(conv *internal.Conv, tod } comment := "Spanner schema for source table " + quoteIfNeeded(srcTable.Name) conv.SpSchema[srcTable.Id] = ddl.CreateTable{ - Name: spTableName, - ColIds: spColIds, - ColDefs: spColDef, - PrimaryKeys: cvtPrimaryKeys(srcTable.PrimaryKeys), - ForeignKeys: cvtForeignKeys(conv, spTableName, srcTable.Id, srcTable.ForeignKeys, isRestore), - CheckConstraint: cvtCheckConstraint(conv, srcTable.CheckConstraints), - Indexes: cvtIndexes(conv, srcTable.Id, srcTable.Indexes, spColIds, spColDef), - Comment: comment, - Id: srcTable.Id} + Name: spTableName, + ColIds: spColIds, + ColDefs: spColDef, + PrimaryKeys: cvtPrimaryKeys(srcTable.PrimaryKeys), + ForeignKeys: cvtForeignKeys(conv, spTableName, srcTable.Id, srcTable.ForeignKeys, isRestore), + CheckConstraints: cvtCheckConstraint(conv, srcTable.CheckConstraints), + Indexes: cvtIndexes(conv, srcTable.Id, srcTable.Indexes, spColIds, spColDef), + Comment: comment, + Id: srcTable.Id} return nil } @@ -235,7 +235,7 @@ func cvtForeignKeys(conv *internal.Conv, spTableName string, srcTableId string, return spKeys } -func cvtCheckConstraint(conv *internal.Conv, srcKeys []schema.CheckConstraints) []ddl.CheckConstraint { +func cvtCheckConstraint(conv *internal.Conv, srcKeys []schema.CheckConstraint) []ddl.CheckConstraint { var spcks []ddl.CheckConstraint for _, cks := range srcKeys { diff --git a/sources/common/toddl_test.go b/sources/common/toddl_test.go index 74ac72202..fee535c5e 100644 --- a/sources/common/toddl_test.go +++ b/sources/common/toddl_test.go @@ -428,10 +428,11 @@ func Test_SchemaToSpannerSequenceHelper(t *testing.T) { assert.Equal(t, expectedConv, conv) } } + func Test_cvtCheckContraint(t *testing.T) { conv := internal.MakeConv() - srcSchema := []schema.CheckConstraints{ + srcSchema := []schema.CheckConstraint{ { Id: "ck1", Name: "check_1", diff --git a/sources/dynamodb/schema.go b/sources/dynamodb/schema.go index 4bc38f0ea..4af603988 100644 --- a/sources/dynamodb/schema.go +++ b/sources/dynamodb/schema.go @@ -129,20 +129,20 @@ func (isi InfoSchemaImpl) GetRowCount(table common.SchemaAndName) (int64, error) return *result.Table.ItemCount, err } -func (isi InfoSchemaImpl) GetConstraints(conv *internal.Conv, table common.SchemaAndName) (primaryKeys []string, constraints map[string][]string, err error) { +func (isi InfoSchemaImpl) GetConstraints(conv *internal.Conv, table common.SchemaAndName) (primaryKeys []string, checkConstraints []schema.CheckConstraint, constraints map[string][]string, err error) { input := &dynamodb.DescribeTableInput{ TableName: aws.String(table.Name), } result, err := isi.DynamoClient.DescribeTable(input) if err != nil { - return primaryKeys, constraints, fmt.Errorf("failed to make a DescribeTable API call for table %v: %v", table.Name, err) + return primaryKeys, checkConstraints, constraints, fmt.Errorf("failed to make a DescribeTable API call for table %v: %v", table.Name, err) } // Primary keys. for _, i := range result.Table.KeySchema { primaryKeys = append(primaryKeys, *i.AttributeName) } - return primaryKeys, constraints, nil + return primaryKeys, checkConstraints, constraints, nil } func (isi InfoSchemaImpl) GetForeignKeys(conv *internal.Conv, table common.SchemaAndName) (foreignKeys []schema.ForeignKey, err error) { diff --git a/sources/dynamodb/schema_test.go b/sources/dynamodb/schema_test.go index b9d10ab3a..9919d3d5a 100644 --- a/sources/dynamodb/schema_test.go +++ b/sources/dynamodb/schema_test.go @@ -633,7 +633,7 @@ func TestInfoSchemaImpl_GetConstraints(t *testing.T) { dySchema := common.SchemaAndName{Name: "test"} conv := internal.MakeConv() isi := InfoSchemaImpl{client, nil, 10} - primaryKeys, constraints, err := isi.GetConstraints(conv, dySchema) + primaryKeys, _, constraints, err := isi.GetConstraints(conv, dySchema) assert.Nil(t, err) pKeys := []string{"a", "b"} @@ -705,7 +705,7 @@ func TestInfoSchemaImpl_GetColumns(t *testing.T) { client := &mockDynamoClient{ scanOutputs: scanOutputs, } - dySchema := common.SchemaAndName{Name: "test", Id: "t1"} + dySchema := common.SchemaAndName{Name: "test", Id: "t1"} isi := InfoSchemaImpl{client, nil, 10} diff --git a/sources/mysql/infoschema.go b/sources/mysql/infoschema.go index 7f7a69028..3a28b186a 100644 --- a/sources/mysql/infoschema.go +++ b/sources/mysql/infoschema.go @@ -194,16 +194,6 @@ func (isi InfoSchemaImpl) GetColumns(conv *internal.Conv, table common.SchemaAnd continue } ignored := schema.Ignored{} - for _, c := range constraints[colName] { - // c can be UNIQUE, PRIMARY KEY, FOREIGN KEY or CHECK - // We've already filtered out PRIMARY KEY. - switch c { - case "CHECK": - - case "FOREIGN KEY", "PRIMARY KEY", "UNIQUE": - // Nothing to do here -- these are all handled elsewhere. - } - } ignored.Default = colDefault.Valid colId := internal.GenerateColumnId() if colExtra.String == "auto_increment" { @@ -237,118 +227,100 @@ func (isi InfoSchemaImpl) GetColumns(conv *internal.Conv, table common.SchemaAnd // other constraints. Note: we need to preserve ordinal order of // columns in primary key constraints. // Note that foreign key constraints are handled in getForeignKeys. -func (isi InfoSchemaImpl) GetConstraints(conv *internal.Conv, table common.SchemaAndName) ([]string, []schema.CheckConstraints, map[string][]string, error) { - q := `SELECT k.COLUMN_NAME, t.CONSTRAINT_TYPE - FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS AS t - INNER JOIN INFORMATION_SCHEMA.KEY_COLUMN_USAGE AS k - ON t.CONSTRAINT_NAME = k.CONSTRAINT_NAME AND t.CONSTRAINT_SCHEMA = k.CONSTRAINT_SCHEMA AND t.TABLE_NAME=k.TABLE_NAME - WHERE k.TABLE_SCHEMA = ? AND k.TABLE_NAME = ? ORDER BY k.ordinal_position;` - - q1 := `SELECT - COALESCE(k.COLUMN_NAME, '') AS COLUMN_NAME, - t.CONSTRAINT_NAME, - t.CONSTRAINT_TYPE, - COALESCE(c.CHECK_CLAUSE, '') AS CHECK_CLAUSE - FROM - INFORMATION_SCHEMA.TABLE_CONSTRAINTS AS t - LEFT JOIN - INFORMATION_SCHEMA.KEY_COLUMN_USAGE AS k - ON t.CONSTRAINT_NAME = k.CONSTRAINT_NAME - AND t.CONSTRAINT_SCHEMA = k.CONSTRAINT_SCHEMA - AND t.TABLE_NAME = k.TABLE_NAME - LEFT JOIN - INFORMATION_SCHEMA.CHECK_CONSTRAINTS AS c - ON t.CONSTRAINT_NAME = c.CONSTRAINT_NAME - WHERE - t.TABLE_SCHEMA = ? - AND t.TABLE_NAME = ? - ORDER BY k.ORDINAL_POSITION; - ` - checkQuery := `SELECT COUNT(*) - FROM INFORMATION_SCHEMA.TABLES - WHERE TABLE_SCHEMA = 'INFORMATION_SCHEMA' - AND TABLE_NAME = 'CHECK_CONSTRAINTS';` - var tableExistsCount int - rows1, err := isi.Db.Query(checkQuery) +func (isi InfoSchemaImpl) GetConstraints(conv *internal.Conv, table common.SchemaAndName) ([]string, []schema.CheckConstraint, map[string][]string, error) { + tableExists, err := isi.isCheckConstraintsTablePresent() if err != nil { return nil, nil, nil, err } - for rows1.Next() { - err1 := rows1.Scan(&tableExistsCount) - if err1 != nil { - conv.Unexpected(fmt.Sprintf("Can't scan: %v", err)) - return nil, nil, nil, err - } - } - - defer rows1.Close() - - tableExists := tableExistsCount > 0 - - var finalQuery string - if tableExists { - finalQuery = q1 - } else { - finalQuery = q - } + finalQuery := isi.getQuery(tableExists) rows, err := isi.Db.Query(finalQuery, table.Schema, table.Name) - if err != nil { return nil, nil, nil, err } defer rows.Close() + var primaryKeys []string - var checkKeys []schema.CheckConstraints - var col, constraintName, constraint, checkClause string + var checkKeys []schema.CheckConstraint m := make(map[string][]string) - for rows.Next() { - if tableExists { - err := rows.Scan(&col, &constraintName, &constraint, &checkClause) - if err != nil { - conv.Unexpected(fmt.Sprintf("Can't scan: %v", err)) - continue - } - } else { - err := rows.Scan(&col, &constraintName, &constraint, &checkClause) - if err != nil { - conv.Unexpected(fmt.Sprintf("Can't scan: %v", err)) - continue - } - } - if col == "" || constraint == "" { - - if tableExists { - if constraintName == "" || checkClause == "" { - conv.Unexpected(fmt.Sprintf("Got empty constraintName or checkClause")) - continue - } - switch constraint { - case "CHECK": - checkClause = strings.ReplaceAll(checkClause, "_utf8mb4\\", "") - checkClause = strings.ReplaceAll(checkClause, "\\", "") - - checkKeys = append(checkKeys, schema.CheckConstraints{Name: constraintName, Expr: string(checkClause), Id: internal.GenerateCheckConstrainstId()}) - default: - m[col] = append(m[col], constraint) - } - } else { - conv.Unexpected(fmt.Sprintf("Got empty col or constraint")) - } + for rows.Next() { + if err := isi.processRow(rows, tableExists, conv, &primaryKeys, &checkKeys, m); err != nil { continue - - } - switch constraint { - case "PRIMARY KEY": - primaryKeys = append(primaryKeys, col) - default: - m[col] = append(m[col], constraint) } } + return primaryKeys, checkKeys, m, nil } +// checkCheckConstraintsTableExists checks if the CHECK_CONSTRAINTS table exists. +func (isi InfoSchemaImpl) isCheckConstraintsTablePresent() (bool, error) { + var tableExistsCount int + checkQuery := `SELECT COUNT(*) FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA = 'INFORMATION_SCHEMA' AND TABLE_NAME = 'CHECK_CONSTRAINTS';` + err := isi.Db.QueryRow(checkQuery).Scan(&tableExistsCount) + if err != nil { + return false, err + } + return tableExistsCount > 0, nil +} + +// getQuery returns the appropriate SQL query based on the existence of CHECK_CONSTRAINTS. +func (isi InfoSchemaImpl) getQuery(tableExists bool) string { + if tableExists { + return `SELECT k.COLUMN_NAME, t.CONSTRAINT_TYPE, COALESCE(c.CHECK_CLAUSE, '') AS CHECK_CLAUSE + FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS AS t + LEFT JOIN INFORMATION_SCHEMA.KEY_COLUMN_USAGE AS k + ON t.CONSTRAINT_NAME = k.CONSTRAINT_NAME + AND t.CONSTRAINT_SCHEMA = k.CONSTRAINT_SCHEMA + AND t.TABLE_NAME = k.TABLE_NAME + LEFT JOIN INFORMATION_SCHEMA.CHECK_CONSTRAINTS AS c + ON t.CONSTRAINT_NAME = c.CONSTRAINT_NAME + WHERE t.TABLE_SCHEMA = ? + AND t.TABLE_NAME = ? + ORDER BY k.ORDINAL_POSITION;` + } + return `SELECT k.COLUMN_NAME, t.CONSTRAINT_TYPE + FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS AS t + INNER JOIN INFORMATION_SCHEMA.KEY_COLUMN_USAGE AS k + ON t.CONSTRAINT_NAME = k.CONSTRAINT_NAME + AND t.CONSTRAINT_SCHEMA = k.CONSTRAINT_SCHEMA + AND t.TABLE_NAME = k.TABLE_NAME + WHERE t.TABLE_SCHEMA = ? + AND t.TABLE_NAME = ? + ORDER BY k.ORDINAL_POSITION;` +} + +// processRow handles scanning and processing of a database row for GetConstraints. +func (isi InfoSchemaImpl) processRow( + rows *sql.Rows, tableExists bool, conv *internal.Conv, primaryKeys *[]string, + checkKeys *[]schema.CheckConstraint, m map[string][]string) error { + + var col, constraintType, checkClause string + err := rows.Scan(&col, &constraintType, &checkClause) + if err != nil { + conv.Unexpected(fmt.Sprintf("Can't scan: %v", err)) + return err + } + + if col == "" || constraintType == "" { + conv.Unexpected("Got empty column or constraint type") + return nil + } + + switch constraintType { + case "PRIMARY KEY": + *primaryKeys = append(*primaryKeys, col) + case "CHECK": + checkClause = strings.ReplaceAll(checkClause, "_utf8mb4\\", "") + checkClause = strings.ReplaceAll(checkClause, "\\", "") + constraintName := fmt.Sprintf("%s_check", col) + *checkKeys = append(*checkKeys, schema.CheckConstraint{Name: constraintName, Expr: checkClause, Id: internal.GenerateCheckConstrainstId()}) + default: + m[col] = append(m[col], constraintType) + } + return nil +} + // GetForeignKeys return list all the foreign keys constraints. // MySQL supports cross-database foreign key constraints. We ignore // them because the Spanner migration tool works database at a time (a specific run diff --git a/sources/mysql/infoschema_test.go b/sources/mysql/infoschema_test.go index ffc7e460b..379ae4067 100644 --- a/sources/mysql/infoschema_test.go +++ b/sources/mysql/infoschema_test.go @@ -531,6 +531,7 @@ func mkMockDB(t *testing.T, ms []mockSpec) *sql.DB { } return db } + func TestGetConstraints(t *testing.T) { case1 := []mockSpec{ @@ -608,7 +609,7 @@ func TestGetConstraints(t *testing.T) { t.Errorf("expected %v, got %v for primary keys", expectedPrimaryKeys, primaryKeys) } - expectedCheckKeys := []schema.CheckConstraints{ + expectedCheckKeys := []schema.CheckConstraint{ {Name: "chk_test", Expr: "amount > 0", Id: "ck1"}, } diff --git a/sources/oracle/infoschema.go b/sources/oracle/infoschema.go index f62bfb270..770ef93cc 100644 --- a/sources/oracle/infoschema.go +++ b/sources/oracle/infoschema.go @@ -247,7 +247,7 @@ func (isi InfoSchemaImpl) GetColumns(conv *internal.Conv, table common.SchemaAnd // other constraints. Note: we need to preserve ordinal order of // columns in primary key constraints. // Note that foreign key constraints are handled in getForeignKeys. -func (isi InfoSchemaImpl) GetConstraints(conv *internal.Conv, table common.SchemaAndName) ([]string, map[string][]string, error) { +func (isi InfoSchemaImpl) GetConstraints(conv *internal.Conv, table common.SchemaAndName) ([]string, []schema.CheckConstraint, map[string][]string, error) { q := fmt.Sprintf(` SELECT k.column_name, @@ -260,7 +260,7 @@ func (isi InfoSchemaImpl) GetConstraints(conv *internal.Conv, table common.Schem `, table.Schema, table.Name) rows, err := isi.Db.Query(q) if err != nil { - return nil, nil, err + return nil, nil, nil, err } defer rows.Close() var primaryKeys []string @@ -292,7 +292,7 @@ func (isi InfoSchemaImpl) GetConstraints(conv *internal.Conv, table common.Schem m[col] = append(m[col], constraint) } } - return primaryKeys, m, nil + return primaryKeys, nil, m, nil } // GetForeignKeys return list all the foreign keys constraints. diff --git a/sources/postgres/infoschema.go b/sources/postgres/infoschema.go index ccf7b8dcf..e4d278b0b 100644 --- a/sources/postgres/infoschema.go +++ b/sources/postgres/infoschema.go @@ -337,7 +337,7 @@ func (isi InfoSchemaImpl) GetColumns(conv *internal.Conv, table common.SchemaAnd // other constraints. Note: we need to preserve ordinal order of // columns in primary key constraints. // Note that foreign key constraints are handled in getForeignKeys. -func (isi InfoSchemaImpl) GetConstraints(conv *internal.Conv, table common.SchemaAndName) ([]string, map[string][]string, error) { +func (isi InfoSchemaImpl) GetConstraints(conv *internal.Conv, table common.SchemaAndName) ([]string, []schema.CheckConstraint, map[string][]string, error) { q := `SELECT k.COLUMN_NAME, t.CONSTRAINT_TYPE FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS AS t INNER JOIN INFORMATION_SCHEMA.KEY_COLUMN_USAGE AS k @@ -345,7 +345,7 @@ func (isi InfoSchemaImpl) GetConstraints(conv *internal.Conv, table common.Schem WHERE k.TABLE_SCHEMA = $1 AND k.TABLE_NAME = $2 ORDER BY k.ordinal_position;` rows, err := isi.Db.Query(q, table.Schema, table.Name) if err != nil { - return nil, nil, err + return nil, nil, nil, err } defer rows.Close() var primaryKeys []string @@ -368,7 +368,7 @@ func (isi InfoSchemaImpl) GetConstraints(conv *internal.Conv, table common.Schem m[col] = append(m[col], constraint) } } - return primaryKeys, m, nil + return primaryKeys, nil, m, nil } // GetForeignKeys returns a list of all the foreign key constraints. diff --git a/sources/spanner/infoschema.go b/sources/spanner/infoschema.go index ca08985bc..c1bf36b4d 100644 --- a/sources/spanner/infoschema.go +++ b/sources/spanner/infoschema.go @@ -190,7 +190,7 @@ func (isi InfoSchemaImpl) GetColumns(conv *internal.Conv, table common.SchemaAnd // other constraints. Note: we need to preserve ordinal order of // columns in primary key constraints. // Note that foreign key constraints are handled in getForeignKeys. -func (isi InfoSchemaImpl) GetConstraints(conv *internal.Conv, table common.SchemaAndName) ([]string, map[string][]string, error) { +func (isi InfoSchemaImpl) GetConstraints(conv *internal.Conv, table common.SchemaAndName) ([]string, []schema.CheckConstraint, map[string][]string, error) { q := `SELECT k.column_name, t.constraint_type FROM information_schema.table_constraints AS t INNER JOIN information_schema.KEY_COLUMN_USAGE AS k @@ -221,11 +221,11 @@ func (isi InfoSchemaImpl) GetConstraints(conv *internal.Conv, table common.Schem break } if err != nil { - return nil, nil, fmt.Errorf("couldn't get row while reading constraints: %w", err) + return nil, nil, nil, fmt.Errorf("couldn't get row while reading constraints: %w", err) } err = row.Columns(&col, &constraint) if err != nil { - return nil, nil, err + return nil, nil, nil, err } if col == "" || constraint == "" { conv.Unexpected("Got empty col or constraint") @@ -238,7 +238,7 @@ func (isi InfoSchemaImpl) GetConstraints(conv *internal.Conv, table common.Schem m[col] = append(m[col], constraint) } } - return primaryKeys, m, nil + return primaryKeys, nil, m, nil } // GetForeignKeys returns a list of all the foreign key constraints. diff --git a/sources/sqlserver/infoschema.go b/sources/sqlserver/infoschema.go index 50f46d157..b9c98c544 100644 --- a/sources/sqlserver/infoschema.go +++ b/sources/sqlserver/infoschema.go @@ -280,7 +280,7 @@ func (isi InfoSchemaImpl) GetColumns(conv *internal.Conv, table common.SchemaAnd // other constraints. Note: we need to preserve ordinal order of // columns in primary key constraints. // Note that foreign key constraints are handled in getForeignKeys. -func (isi InfoSchemaImpl) GetConstraints(conv *internal.Conv, table common.SchemaAndName) ([]string, map[string][]string, error) { +func (isi InfoSchemaImpl) GetConstraints(conv *internal.Conv, table common.SchemaAndName) ([]string, []schema.CheckConstraint, map[string][]string, error) { q := ` SELECT k.COLUMN_NAME, @@ -292,7 +292,7 @@ func (isi InfoSchemaImpl) GetConstraints(conv *internal.Conv, table common.Schem ` rows, err := isi.Db.Query(q, table.Schema, table.Name) if err != nil { - return nil, nil, err + return nil, nil, nil, err } defer rows.Close() var primaryKeys []string @@ -315,7 +315,7 @@ func (isi InfoSchemaImpl) GetConstraints(conv *internal.Conv, table common.Schem m[col] = append(m[col], constraint) } } - return primaryKeys, m, nil + return primaryKeys, nil, m, nil } // GetForeignKeys returns a list of all the foreign key constraints. diff --git a/spanner/ddl/ast.go b/spanner/ddl/ast.go index 366222bc4..d55d35595 100644 --- a/spanner/ddl/ast.go +++ b/spanner/ddl/ast.go @@ -324,17 +324,17 @@ func (k Foreignkey) PrintForeignKey(c Config) string { // // create_table: CREATE TABLE table_name ([column_def, ...] ) primary_key [, cluster] type CreateTable struct { - Name string - ColIds []string // Provides names and order of columns - ShardIdColumn string - ColDefs map[string]ColumnDef // Provides definition of columns (a map for simpler/faster lookup during type processing) - PrimaryKeys []IndexKey - ForeignKeys []Foreignkey - Indexes []CreateIndex - ParentTable InterleavedParent //if not empty, this table will be interleaved - CheckConstraint []CheckConstraint - Comment string - Id string + Name string + ColIds []string // Provides names and order of columns + ShardIdColumn string + ColDefs map[string]ColumnDef // Provides definition of columns (a map for simpler/faster lookup during type processing) + PrimaryKeys []IndexKey + ForeignKeys []Foreignkey + Indexes []CreateIndex + ParentTable InterleavedParent //if not empty, this table will be interleaved + CheckConstraints []CheckConstraint + Comment string + Id string } // PrintCreateTable unparses a CREATE TABLE statement. @@ -389,19 +389,19 @@ func (ct CreateTable) PrintCreateTable(spSchema Schema, config Config) string { } var checkString string - if len(ct.CheckConstraint) != 0 { - checkString = PrintCheckConstraintTable(ct.CheckConstraint) + if len(ct.CheckConstraints) > 0 { + checkString = PrintCheckConstraintTable(ct.CheckConstraints) } else { checkString = "" } if len(keys) == 0 { - return fmt.Sprintf("%sCREATE TABLE %s (\n%s %s) %s", tableComment, config.quote(ct.Name), cols, checkString, interleave) + return fmt.Sprintf("%sCREATE TABLE %s (\n%s%s) %s", tableComment, config.quote(ct.Name), cols, checkString, interleave) } if config.SpDialect == constants.DIALECT_POSTGRESQL { return fmt.Sprintf("%sCREATE TABLE %s (\n%s\tPRIMARY KEY (%s)\n)%s", tableComment, config.quote(ct.Name), cols, strings.Join(keys, ", "), interleave) } - return fmt.Sprintf("%sCREATE TABLE %s (\n%s %s) PRIMARY KEY (%s)%s", tableComment, config.quote(ct.Name), cols, checkString, strings.Join(keys, ", "), interleave) + return fmt.Sprintf("%sCREATE TABLE %s (\n%s%s) PRIMARY KEY (%s)%s", tableComment, config.quote(ct.Name), cols, checkString, strings.Join(keys, ", "), interleave) } // CreateIndex encodes the following DDL definition: @@ -508,21 +508,21 @@ func (k Foreignkey) PrintForeignKeyAlterTable(spannerSchema Schema, c Config, ta return s } -// PrintCheckConstraintTable unparses the check constraints using CHECK CONSTRAINTS. +// PrintCheckConstraintTable formats the check constraints in SQL syntax. func PrintCheckConstraintTable(cks []CheckConstraint) string { + var builder strings.Builder - var s string - s = "" - for index, col := range cks { - if index == len(cks)-1 { - s = s + fmt.Sprintf("\tCONSTRAINT %s CHECK %s\n", col.Name, col.Expr) - } else { - s = s + fmt.Sprintf("\tCONSTRAINT %s CHECK %s,\n", col.Name, col.Expr) - } + for _, col := range cks { + builder.WriteString(fmt.Sprintf("\tCONSTRAINT %s CHECK %s,\n", col.Name, col.Expr)) + } + if builder.Len() > 0 { + // Trim the trailing comma and newline + result := builder.String() + return result[:len(result)-2] + "\n" } - return s + return "" } // Schema stores a map of table names and Tables. diff --git a/spanner/ddl/ast_test.go b/spanner/ddl/ast_test.go index e424354b4..7b8435249 100644 --- a/spanner/ddl/ast_test.go +++ b/spanner/ddl/ast_test.go @@ -143,7 +143,7 @@ func TestPrintCreateTable(t *testing.T) { }, PrimaryKeys: []IndexKey{{ColId: "col1", Desc: true}}, ForeignKeys: nil, - CheckConstraint: []CheckConstraint{ + CheckConstraints: []CheckConstraint{ {Id: "ck1", Name: "check_1", Expr: "(age > 18)"}, {Id: "ck2", Name: "check_2", Expr: "(age < 99)"}, }, @@ -160,13 +160,13 @@ func TestPrintCreateTable(t *testing.T) { "col4": {Name: "col4", T: Type{Name: Int64}, NotNull: true}, "col5": {Name: "col5", T: Type{Name: String, Len: MaxLength}, NotNull: false}, }, - PrimaryKeys: []IndexKey{{ColId: "col4", Desc: true}}, - ForeignKeys: nil, - Indexes: nil, - CheckConstraint: nil, - ParentTable: InterleavedParent{Id: "t1", OnDelete: constants.FK_CASCADE}, - Comment: "", - Id: "t2", + PrimaryKeys: []IndexKey{{ColId: "col4", Desc: true}}, + ForeignKeys: nil, + Indexes: nil, + CheckConstraints: nil, + ParentTable: InterleavedParent{Id: "t1", OnDelete: constants.FK_CASCADE}, + Comment: "", + Id: "t2", }, "t3": CreateTable{ Name: "table3", @@ -175,13 +175,13 @@ func TestPrintCreateTable(t *testing.T) { ColDefs: map[string]ColumnDef{ "col6": {Name: "col6", T: Type{Name: Int64}, NotNull: true}, }, - PrimaryKeys: []IndexKey{{ColId: "col6", Desc: true}}, - ForeignKeys: nil, - Indexes: nil, - CheckConstraint: nil, - ParentTable: InterleavedParent{Id: "t1", OnDelete: ""}, - Comment: "", - Id: "t3", + PrimaryKeys: []IndexKey{{ColId: "col6", Desc: true}}, + ForeignKeys: nil, + Indexes: nil, + CheckConstraints: nil, + ParentTable: InterleavedParent{Id: "t1", OnDelete: ""}, + Comment: "", + Id: "t3", }, "t4": CreateTable{ Name: "table1", @@ -194,7 +194,7 @@ func TestPrintCreateTable(t *testing.T) { }, PrimaryKeys: nil, ForeignKeys: nil, - CheckConstraint: []CheckConstraint{ + CheckConstraints: []CheckConstraint{ {Id: "ck1", Name: "check_1", Expr: "(age > 18)"}, {Id: "ck2", Name: "check_2", Expr: "(age < 99)"}, }, @@ -217,8 +217,8 @@ func TestPrintCreateTable(t *testing.T) { "CREATE TABLE table1 (\n" + " col1 INT64 NOT NULL ,\n" + " col2 STRING(MAX),\n" + - " col3 BYTES(42),\n " + - "CONSTRAINT check_1 CHECK (age > 18),\nCONSTRAINT check_2 CHECK (age < 99)\n" + + " col3 BYTES(42),\n" + + "\tCONSTRAINT check_1 CHECK (age > 18),\n\tCONSTRAINT check_2 CHECK (age < 99)\n" + ") PRIMARY KEY (col1 DESC)", }, { @@ -228,8 +228,8 @@ func TestPrintCreateTable(t *testing.T) { "CREATE TABLE `table1` (\n" + " `col1` INT64 NOT NULL ,\n" + " `col2` STRING(MAX),\n" + - " `col3` BYTES(42),\n " + - "CONSTRAINT check_1 CHECK (age > 18),\nCONSTRAINT check_2 CHECK (age < 99)\n" + + " `col3` BYTES(42),\n" + + "\tCONSTRAINT check_1 CHECK (age > 18),\n\tCONSTRAINT check_2 CHECK (age < 99)\n" + ") PRIMARY KEY (`col1` DESC)", }, { @@ -258,8 +258,8 @@ func TestPrintCreateTable(t *testing.T) { "CREATE TABLE table1 (\n" + " col1 INT64 NOT NULL ,\n" + " col2 STRING(MAX),\n" + - " col3 BYTES(42),\n " + - "CONSTRAINT check_1 CHECK (age > 18),\nCONSTRAINT check_2 CHECK (age < 99)\n" + + " col3 BYTES(42),\n" + + "\tCONSTRAINT check_1 CHECK (age > 18),\n\tCONSTRAINT check_2 CHECK (age < 99)\n" + ") ", }, } @@ -989,3 +989,39 @@ func TestGetSortedTableIdsBySpName(t *testing.T) { }) } } + +func TestPrintCheckConstraintTable(t *testing.T) { + tests := []struct { + description string + cks []CheckConstraint + expected string + }{ + { + description: "Empty constraints list", + cks: []CheckConstraint{}, + expected: "", + }, + { + description: "Single constraint", + cks: []CheckConstraint{ + {Name: "ck1", Expr: "(id > 0)"}, + }, + expected: "\tCONSTRAINT ck1 CHECK (id > 0)\n", + }, + { + description: "Multiple constraints", + cks: []CheckConstraint{ + {Name: "ck1", Expr: "(id > 0)"}, + {Name: "ck2", Expr: "(name IS NOT NULL)"}, + }, + expected: "\tCONSTRAINT ck1 CHECK (id > 0),\n\tCONSTRAINT ck2 CHECK (name IS NOT NULL)\n", + }, + } + + for _, tc := range tests { + t.Run(tc.description, func(t *testing.T) { + actual := PrintCheckConstraintTable(tc.cks) + assert.Equal(t, tc.expected, actual) + }) + } +} diff --git a/webv2/api/schema.go b/webv2/api/schema.go index 08ceca5ec..b01d5eba9 100644 --- a/webv2/api/schema.go +++ b/webv2/api/schema.go @@ -510,11 +510,8 @@ func UpdateCheckConstraint(w http.ResponseWriter, r *http.Request) { } sp := sessionState.Conv.SpSchema[tableId] - - sp.CheckConstraint = newCKs - + sp.CheckConstraints = newCKs sessionState.Conv.SpSchema[tableId] = sp - session.UpdateSessionFile() convm := session.ConvWithMetadata{ @@ -523,7 +520,6 @@ func UpdateCheckConstraint(w http.ResponseWriter, r *http.Request) { } w.WriteHeader(http.StatusOK) json.NewEncoder(w).Encode(convm) - } func doesNameExist(spcks []ddl.CheckConstraint, targetName string) bool { @@ -547,44 +543,30 @@ func ValidateCheckConstraint(w http.ResponseWriter, r *http.Request) { sp := sessionState.Conv.SpSchema srcschema := sessionState.Conv.SrcSchema - flag := true - - schemaissue := []internal.SchemaIssue{} + var schemaIssue []internal.SchemaIssue for _, src := range srcschema { - for _, col := range sp[src.Id].ColDefs { - - if len(sp[src.Id].CheckConstraint) != 0 { + if len(sp[src.Id].CheckConstraints) > 0 { spType := col.T.Name srcType := srcschema[src.Id].ColDefs[col.Id].Type - actualType := mysqlDefaultTypeMap[srcType.Name] - if actualType.Name != spType { - columnName := sp[src.Id].ColDefs[col.Id].Name - spcks := sp[src.Id].CheckConstraint + spcks := sp[src.Id].CheckConstraints if doesNameExist(spcks, columnName) { flag = false - - schemaissue = sessionState.Conv.SchemaIssues[src.Id].ColumnLevelIssues[col.Id] - - if !utilities.IsSchemaIssuePresent(schemaissue, internal.TypeMismatch) { - schemaissue = append(schemaissue, internal.TypeMismatch) + schemaIssue = sessionState.Conv.SchemaIssues[src.Id].ColumnLevelIssues[col.Id] + if !utilities.IsSchemaIssuePresent(schemaIssue, internal.TypeMismatch) { + schemaIssue = append(schemaIssue, internal.TypeMismatch) } - - sessionState.Conv.SchemaIssues[src.Id].ColumnLevelIssues[col.Id] = schemaissue - + sessionState.Conv.SchemaIssues[src.Id].ColumnLevelIssues[col.Id] = schemaIssue break - } } } - } - } w.WriteHeader(http.StatusOK) diff --git a/webv2/api/schema_test.go b/webv2/api/schema_test.go index c1ca74042..8623be071 100644 --- a/webv2/api/schema_test.go +++ b/webv2/api/schema_test.go @@ -2549,12 +2549,12 @@ func TestUpdateCheckConstraint(t *testing.T) { tableID := "table1" - expectedCheckConstraint := []ddl.Checkconstraint{ + expectedCheckConstraint := []ddl.CheckConstraint{ {Id: "ck1", Name: "check_1", Expr: "(age > 18)"}, {Id: "ck2", Name: "check_2", Expr: "(age < 99)"}, } - checkConstraints := []schema.CheckConstraints{ + checkConstraints := []schema.CheckConstraint{ {Id: "ck1", Name: "check_1", Expr: "(age > 18)"}, {Id: "ck2", Name: "check_2", Expr: "(age < 99)"}, } @@ -2579,7 +2579,7 @@ func TestUpdateCheckConstraint(t *testing.T) { updatedSp := sessionState.Conv.SpSchema[tableID] - assert.Equal(t, expectedCheckConstraint, updatedSp.CheckConstraint) + assert.Equal(t, expectedCheckConstraint, updatedSp.CheckConstraints) } func TestUpdateCheckConstraint_ParseError(t *testing.T) { @@ -2602,6 +2602,8 @@ func TestUpdateCheckConstraint_ParseError(t *testing.T) { assert.Contains(t, rr.Body.String(), expectedErrorMessage) } +type errReader struct{} + func (errReader) Read(p []byte) (n int, err error) { return 0, fmt.Errorf("simulated read error") } @@ -2694,7 +2696,7 @@ func buildConvMySQL_NoTypeMatch(conv *internal.Conv) { Name: "table1", Id: "t1", ColIds: []string{"c1", "c2", "c3"}, - CheckConstraints: []schema.CheckConstraints{ + CheckConstraints: []schema.CheckConstraint{ { Id: "ck1", Name: "check_1", @@ -2718,7 +2720,7 @@ func buildConvMySQL_NoTypeMatch(conv *internal.Conv) { Name: "table1", Id: "t1", ColIds: []string{"c1", "c2", "c3"}, - CheckConstraint: []ddl.Checkconstraint{ + CheckConstraints: []ddl.CheckConstraint{ { Id: "ck1", Name: "check_1", @@ -2757,7 +2759,7 @@ func buildConvMySQL_TypeMatch(conv *internal.Conv) { Name: "table1", Id: "t1", ColIds: []string{"c1", "c2", "c3"}, - CheckConstraints: []schema.CheckConstraints{ + CheckConstraints: []schema.CheckConstraint{ { Id: "ck1", Name: "check_1", @@ -2781,7 +2783,7 @@ func buildConvMySQL_TypeMatch(conv *internal.Conv) { Name: "table1", Id: "t1", ColIds: []string{"c1", "c2", "c3"}, - CheckConstraint: []ddl.Checkconstraint{ + CheckConstraints: []ddl.CheckConstraint{ { Id: "ck1", Name: "check_1", diff --git a/webv2/table/review_table_schema.go b/webv2/table/review_table_schema.go index 5a5d7c31b..5e9635e31 100644 --- a/webv2/table/review_table_schema.go +++ b/webv2/table/review_table_schema.go @@ -110,10 +110,10 @@ func ReviewTableSchema(w http.ResponseWriter, r *http.Request) { } oldName := conv.SrcSchema[tableId].ColDefs[colId].Name - for i := range conv.SpSchema[tableId].CheckConstraint { - originalString := conv.SpSchema[tableId].CheckConstraint[i].Expr + for i := range conv.SpSchema[tableId].CheckConstraints { + originalString := conv.SpSchema[tableId].CheckConstraints[i].Expr updatedValue := strings.ReplaceAll(originalString, oldName, v.Rename) - conv.SpSchema[tableId].CheckConstraint[i].Expr = updatedValue + conv.SpSchema[tableId].CheckConstraints[i].Expr = updatedValue } interleaveTableSchema = reviewRenameColumn(v.Rename, tableId, colId, conv, interleaveTableSchema) diff --git a/webv2/table/update_table_schema.go b/webv2/table/update_table_schema.go index 76eaeca3a..27d05ad47 100644 --- a/webv2/table/update_table_schema.go +++ b/webv2/table/update_table_schema.go @@ -100,10 +100,10 @@ func UpdateTableSchema(w http.ResponseWriter, r *http.Request) { oldName := conv.SrcSchema[tableId].ColDefs[colId].Name - for i := range conv.SpSchema[tableId].CheckConstraint { - originalString := conv.SpSchema[tableId].CheckConstraint[i].Expr + for i := range conv.SpSchema[tableId].CheckConstraints { + originalString := conv.SpSchema[tableId].CheckConstraints[i].Expr updatedValue := strings.ReplaceAll(originalString, oldName, v.Rename) - conv.SpSchema[tableId].CheckConstraint[i].Expr = updatedValue + conv.SpSchema[tableId].CheckConstraints[i].Expr = updatedValue } renameColumn(v.Rename, tableId, colId, conv)