From 1a4e61bddfdab608f84439163f8154c0adff91f1 Mon Sep 17 00:00:00 2001
From: Taher Lakdawala <78196491+taherkl@users.noreply.github.com>
Date: Thu, 26 Dec 2024 12:57:20 +0530
Subject: [PATCH] Add Support for Check Constraint - Backend code changes (#945)

* Check constraint backend (#9) Backend Support for Check Constraint
* update api
* fix PR comment
* remove api call while validating constraints
* Fixed db collation regex to remove collation name from the results
* renamed function name to formatCheckConstraints and added check if constraint name is empty
* fixed PR comments
* added test case for the empty check constraint name
* fix: added regular expression to match the exact column
* fix: added regular expression to replace table name
* Added test case for the column rename for check constraint
* 1. Refactored GetConstraint function 2. Fixed infoschema unit tests
* added comment at handling case for check constraints
* reverted white spaces
* reverted white spaces
* nit: doesCheckConstraintNameExist
* added comments for doesCheckConstraintNameExist
* PR and UT fixes
* fix UT
* UT fix
* Removed isCheckConstraintsTablePresent function
* moved regex globally
* Fix UT
* fixed UT
* fixed handling of the constraints
* removed unused function
* added unit tests for incompatible name
* Combined unit tests
* added test case for the renaming column having substring of other column
* added the query changes which return distinct value
* Updating version of msprod (#969)
* fix(deps): update module golang.org/x/net to v0.33.0 [security] (#967)
Co-authored-by: Vardhan Vinay Thigle <39047439+VardhanThigle@users.noreply.github.com>
* feat: APIs for Backend Changes for Default Values (#965)
* backend apis
* linting
* comment changes
* comment changes
* feat: default value for mysql source (#963)
* source dv
* test fix
* change
* comment changes
* test fix
* change
* fix github v
* change
* Check constraint backend (#9) Backend Support for Check Constraint
* 1. Refactored GetConstraint function 2.
Fixed inforschema unit tests * removed duplicate function * Fixed UT --------- Co-authored-by: taherkl Co-authored-by: Akash Thawait Co-authored-by: Vivek Yadav Co-authored-by: Vardhan Vinay Thigle <39047439+VardhanThigle@users.noreply.github.com> Co-authored-by: Mend Renovate Co-authored-by: Astha Mohta <35952883+asthamohta@users.noreply.github.com> --- common/constants/constants.go | 21 +- internal/convert.go | 1 + internal/helpers.go | 5 + internal/mapping.go | 8 + internal/reports/report_helpers.go | 7 + schema/schema.go | 26 ++- sources/common/infoschema.go | 23 +- sources/common/toddl.go | 38 +++- sources/common/toddl_test.go | 41 ++++ sources/dynamodb/schema.go | 6 +- sources/dynamodb/schema_test.go | 4 +- sources/mysql/infoschema.go | 136 ++++++++---- sources/mysql/infoschema_test.go | 274 ++++++++++++++++++------ sources/mysql/mysqldump.go | 1 + sources/oracle/infoschema.go | 6 +- sources/postgres/infoschema.go | 6 +- sources/spanner/infoschema.go | 8 +- sources/sqlserver/infoschema.go | 6 +- spanner/ddl/ast.go | 62 ++++-- spanner/ddl/ast_test.go | 126 +++++++++-- webv2/api/schema.go | 137 +++++++++--- webv2/api/schema_test.go | 81 +++++++ webv2/routes.go | 2 +- webv2/table/review_table_schema.go | 18 +- webv2/table/review_table_schema_test.go | 211 ++++++++++++++---- webv2/table/update_table_schema.go | 21 +- 26 files changed, 993 insertions(+), 282 deletions(-) diff --git a/common/constants/constants.go b/common/constants/constants.go index b5cdb0246..fb150c216 100644 --- a/common/constants/constants.go +++ b/common/constants/constants.go @@ -74,22 +74,22 @@ const ( AddIndex = "add_index" EditColumnMaxLength = "edit_column_max_length" AddShardIdPrimaryKey = "add_shard_id_primary_key" - //bulk migration type + // bulk migration type BULK_MIGRATION = "bulk" - //dataflow migration type + // dataflow migration type DATAFLOW_MIGRATION = "dataflow" - //DMS migration type + // DMS migration type DMS_MIGRATION = "dms" SESSION_FILE = "sessionFile" - //Default shardId + // Default shardId DEFAULT_SHARD_ID string = "smt-default" - //Metadata database name + // Metadata database name METADATA_DB string = "spannermigrationtool_metadata" - //Migration types + // Migration types MINIMAL_DOWNTIME_MIGRATION = "minimal_downtime" - //Job Resource Types + // Job Resource Types DATAFLOW_RESOURCE string = "dataflow" PUBSUB_RESOURCE string = "pubsub" DLQ_PUBSUB_RESOURCE string = "dlq_pubsub" @@ -111,7 +111,7 @@ const ( // Default gcs path of the Dataflow template. 
DEFAULT_TEMPLATE_PATH string = "gs://dataflow-templates/latest/flex/Cloud_Datastream_to_Spanner" - //FK Actions + // FK Actions FK_NO_ACTION string = "NO ACTION" FK_CASCADE string = "CASCADE" FK_SET_DEFAULT string = "SET DEFAULT" @@ -122,9 +122,12 @@ const ( REGULAR_GCS string = "data" DLQ_GCS string = "dlq" - //VerifyExpresions API + // VerifyExpresions API CHECK_EXPRESSION = "CHECK" DEFAUT_EXPRESSION = "DEFAULT" DEFAULT_GENERATED = "DEFAULT_GENERATED" TEMP_DB = "smt-staging-db" + + // Regex for matching database collation + DB_COLLATION_REGEX = `(_[a-zA-Z0-9]+\\|\\)` ) diff --git a/internal/convert.go b/internal/convert.go index 413691260..d2e892131 100644 --- a/internal/convert.go +++ b/internal/convert.go @@ -131,6 +131,7 @@ const ( SequenceCreated ForeignKeyActionNotSupported NumericPKNotSupported + TypeMismatch DefaultValueError ) diff --git a/internal/helpers.go b/internal/helpers.go index 621b36365..a452e98a1 100644 --- a/internal/helpers.go +++ b/internal/helpers.go @@ -65,6 +65,11 @@ func GenerateForeignkeyId() string { func GenerateIndexesId() string { return GenerateId("i") } + +func GenerateCheckConstrainstId() string { + return GenerateId("cc") +} + func GenerateRuleId() string { return GenerateId("r") } diff --git a/internal/mapping.go b/internal/mapping.go index 0eb9bd055..d98008058 100644 --- a/internal/mapping.go +++ b/internal/mapping.go @@ -243,6 +243,14 @@ func ToSpannerIndexName(conv *Conv, srcIndexName string) string { return getSpannerValidName(conv, srcIndexName) } +// Note that the check constraints names in spanner have to be globally unique +// (across the database). But in some source databases, such as MySQL, +// they only have to be unique for a table. Hence we must map each source +// constraint name to a unique spanner constraint name. +func ToSpannerCheckConstraintName(conv *Conv, srcCheckConstraintName string) string { + return getSpannerValidName(conv, srcCheckConstraintName) +} + // conv.UsedNames tracks Spanner names that have been used for table names, foreign key constraints // and indexes. We use this to ensure we generate unique names when // we map from source dbs to Spanner since Spanner requires all these names to be diff --git a/internal/reports/report_helpers.go b/internal/reports/report_helpers.go index 38df589af..c4f3084bd 100644 --- a/internal/reports/report_helpers.go +++ b/internal/reports/report_helpers.go @@ -409,6 +409,13 @@ func buildTableReportBody(conv *internal.Conv, tableId string, issues map[string Description: fmt.Sprintf("%s for table '%s' column '%s'", IssueDB[i].Brief, conv.SpSchema[tableId].Name, spColName), } l = append(l, toAppend) + case internal.TypeMismatch: + toAppend := Issue{ + Category: IssueDB[i].Category, + Description: fmt.Sprintf("Table '%s': Type mismatch in '%s'column affecting check constraints. Verify data type compatibility with constraint logic", conv.SpSchema[tableId].Name, conv.SpSchema[tableId].ColDefs[colId].Name), + } + l = append(l, toAppend) + default: toAppend := Issue{ Category: IssueDB[i].Category, diff --git a/schema/schema.go b/schema/schema.go index 7d73cd799..eab021cf0 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -35,15 +35,16 @@ import ( // Table represents a database table. type Table struct { - Name string - Schema string - ColIds []string // List of column Ids (for predictable iteration order e.g. printing). - ColDefs map[string]Column // Details of columns. 
- ColNameIdMap map[string]string `json:"-"` // Computed every time just after conv is generated or after any column renaming - PrimaryKeys []Key - ForeignKeys []ForeignKey - Indexes []Index - Id string + Name string + Schema string + ColIds []string // List of column Ids (for predictable iteration order e.g. printing). + ColDefs map[string]Column // Details of columns. + ColNameIdMap map[string]string `json:"-"` // Computed every time just after conv is generated or after any column renaming + PrimaryKeys []Key + ForeignKeys []ForeignKey + CheckConstraints []CheckConstraint + Indexes []Index + Id string } // Column represents a database column. @@ -77,6 +78,13 @@ type ForeignKey struct { Id string } +// CheckConstraints represents a check constraint defined in the schema. +type CheckConstraint struct { + Name string + Expr string + Id string +} + // Key respresents a primary key or index key. type Key struct { ColId string diff --git a/sources/common/infoschema.go b/sources/common/infoschema.go index b4f9c9e7c..ae43ea4a0 100644 --- a/sources/common/infoschema.go +++ b/sources/common/infoschema.go @@ -38,7 +38,7 @@ type InfoSchema interface { GetColumns(conv *internal.Conv, table SchemaAndName, constraints map[string][]string, primaryKeys []string) (map[string]schema.Column, []string, error) GetRowsFromTable(conv *internal.Conv, srcTable string) (interface{}, error) GetRowCount(table SchemaAndName) (int64, error) - GetConstraints(conv *internal.Conv, table SchemaAndName) ([]string, map[string][]string, error) + GetConstraints(conv *internal.Conv, table SchemaAndName) ([]string, []schema.CheckConstraint, map[string][]string, error) GetForeignKeys(conv *internal.Conv, table SchemaAndName) (foreignKeys []schema.ForeignKey, err error) GetIndexes(conv *internal.Conv, table SchemaAndName, colNameIdMp map[string]string) ([]schema.Index, error) ProcessData(conv *internal.Conv, tableId string, srcSchema schema.Table, spCols []string, spSchema ddl.CreateTable, additionalAttributes internal.AdditionalDataAttributes) error @@ -187,7 +187,7 @@ func (is *InfoSchemaImpl) processTable(conv *internal.Conv, table SchemaAndName, var t schema.Table fmt.Println("processing schema for table", table) tblId := internal.GenerateTableId() - primaryKeys, constraints, err := infoSchema.GetConstraints(conv, table) + primaryKeys, checkConstraints, constraints, err := infoSchema.GetConstraints(conv, table) if err != nil { return t, fmt.Errorf("couldn't get constraints for table %s.%s: %s", table.Schema, table.Name, err) } @@ -217,15 +217,16 @@ func (is *InfoSchemaImpl) processTable(conv *internal.Conv, table SchemaAndName, schemaPKeys = append(schemaPKeys, schema.Key{ColId: colNameIdMap[k]}) } t = schema.Table{ - Id: tblId, - Name: name, - Schema: table.Schema, - ColIds: colIds, - ColNameIdMap: colNameIdMap, - ColDefs: colDefs, - PrimaryKeys: schemaPKeys, - Indexes: indexes, - ForeignKeys: foreignKeys} + Id: tblId, + Name: name, + Schema: table.Schema, + ColIds: colIds, + ColNameIdMap: colNameIdMap, + ColDefs: colDefs, + PrimaryKeys: schemaPKeys, + CheckConstraints: checkConstraints, + Indexes: indexes, + ForeignKeys: foreignKeys} return t, nil } diff --git a/sources/common/toddl.go b/sources/common/toddl.go index 1b706274e..460523575 100644 --- a/sources/common/toddl.go +++ b/sources/common/toddl.go @@ -114,7 +114,7 @@ func (ss *SchemaToSpannerImpl) SchemaToSpannerDDLHelper(conv *internal.Conv, tod if srcCol.Ignored.Default { issues = append(issues, internal.DefaultValue) } - if srcCol.Ignored.AutoIncrement { 
//TODO(adibh) - check why this is not there in postgres + if srcCol.Ignored.AutoIncrement { // TODO(adibh) - check why this is not there in postgres issues = append(issues, internal.AutoIncrement) } // Set the not null constraint to false for unsupported source datatypes @@ -167,14 +167,16 @@ func (ss *SchemaToSpannerImpl) SchemaToSpannerDDLHelper(conv *internal.Conv, tod } comment := "Spanner schema for source table " + quoteIfNeeded(srcTable.Name) conv.SpSchema[srcTable.Id] = ddl.CreateTable{ - Name: spTableName, - ColIds: spColIds, - ColDefs: spColDef, - PrimaryKeys: cvtPrimaryKeys(srcTable.PrimaryKeys), - ForeignKeys: cvtForeignKeys(conv, spTableName, srcTable.Id, srcTable.ForeignKeys, isRestore), - Indexes: cvtIndexes(conv, srcTable.Id, srcTable.Indexes, spColIds, spColDef), - Comment: comment, - Id: srcTable.Id} + Name: spTableName, + ColIds: spColIds, + ColDefs: spColDef, + PrimaryKeys: cvtPrimaryKeys(srcTable.PrimaryKeys), + ForeignKeys: cvtForeignKeys(conv, spTableName, srcTable.Id, srcTable.ForeignKeys, isRestore), + CheckConstraints: cvtCheckConstraint(conv, srcTable.CheckConstraints), + Indexes: cvtIndexes(conv, srcTable.Id, srcTable.Indexes, spColIds, spColDef), + Comment: comment, + Id: srcTable.Id, + } return nil } @@ -234,6 +236,20 @@ func cvtForeignKeys(conv *internal.Conv, spTableName string, srcTableId string, return spKeys } +// cvtCheckConstraint converts check constraints from source to Spanner. +func cvtCheckConstraint(conv *internal.Conv, srcKeys []schema.CheckConstraint) []ddl.CheckConstraint { + var spcc []ddl.CheckConstraint + + for _, cc := range srcKeys { + spcc = append(spcc, ddl.CheckConstraint{ + Id: cc.Id, + Name: internal.ToSpannerCheckConstraintName(conv, cc.Name), + Expr: cc.Expr, + }) + } + return spcc +} + func CvtForeignKeysHelper(conv *internal.Conv, spTableName string, srcTableId string, srcKey schema.ForeignKey, isRestore bool) (ddl.Foreignkey, error) { if len(srcKey.ColIds) != len(srcKey.ReferColumnIds) { conv.Unexpected(fmt.Sprintf("ConvertForeignKeys: ColIds and referColumns don't have the same lengths: len(columns)=%d, len(referColumns)=%d for source tableId: %s, referenced table: %s", len(srcKey.ColIds), len(srcKey.ReferColumnIds), srcTableId, srcKey.ReferTableId)) @@ -330,8 +346,8 @@ func CvtIndexHelper(conv *internal.Conv, tableId string, srcIndex schema.Index, isPresent = true if conv.SpDialect == constants.DIALECT_POSTGRESQL { if spColDef[v].T.Name == ddl.Numeric { - //index on NUMERIC is not supported in PGSQL Dialect currently. - //Indexes which contains a NUMERIC column in it will need to be skipped. + // index on NUMERIC is not supported in PGSQL Dialect currently. + // Indexes which contains a NUMERIC column in it will need to be skipped. 
return ddl.CreateIndex{} } } diff --git a/sources/common/toddl_test.go b/sources/common/toddl_test.go index dcf5b3651..6a98b8ca9 100644 --- a/sources/common/toddl_test.go +++ b/sources/common/toddl_test.go @@ -429,6 +429,47 @@ func Test_SchemaToSpannerSequenceHelper(t *testing.T) { } } +func Test_cvtCheckContraint(t *testing.T) { + + conv := internal.MakeConv() + srcSchema := []schema.CheckConstraint{ + { + Id: "cc1", + Name: "check_1", + Expr: "age > 0", + }, + { + Id: "cc2", + Name: "check_2", + Expr: "age < 99", + }, + { + Id: "cc3", + Name: "@invalid_name", // incompatabile name + Expr: "age != 0", + }, + } + spSchema := []ddl.CheckConstraint{ + { + Id: "cc1", + Name: "check_1", + Expr: "age > 0", + }, + { + Id: "cc2", + Name: "check_2", + Expr: "age < 99", + }, + { + Id: "cc3", + Name: "Ainvalid_name", + Expr: "age != 0", + }, + } + result := cvtCheckConstraint(conv, srcSchema) + assert.Equal(t, spSchema, result) +} + func TestSpannerSchemaApplyExpressions(t *testing.T) { makeConv := func() *internal.Conv { conv := internal.MakeConv() diff --git a/sources/dynamodb/schema.go b/sources/dynamodb/schema.go index 4bc38f0ea..4af603988 100644 --- a/sources/dynamodb/schema.go +++ b/sources/dynamodb/schema.go @@ -129,20 +129,20 @@ func (isi InfoSchemaImpl) GetRowCount(table common.SchemaAndName) (int64, error) return *result.Table.ItemCount, err } -func (isi InfoSchemaImpl) GetConstraints(conv *internal.Conv, table common.SchemaAndName) (primaryKeys []string, constraints map[string][]string, err error) { +func (isi InfoSchemaImpl) GetConstraints(conv *internal.Conv, table common.SchemaAndName) (primaryKeys []string, checkConstraints []schema.CheckConstraint, constraints map[string][]string, err error) { input := &dynamodb.DescribeTableInput{ TableName: aws.String(table.Name), } result, err := isi.DynamoClient.DescribeTable(input) if err != nil { - return primaryKeys, constraints, fmt.Errorf("failed to make a DescribeTable API call for table %v: %v", table.Name, err) + return primaryKeys, checkConstraints, constraints, fmt.Errorf("failed to make a DescribeTable API call for table %v: %v", table.Name, err) } // Primary keys. 
for _, i := range result.Table.KeySchema { primaryKeys = append(primaryKeys, *i.AttributeName) } - return primaryKeys, constraints, nil + return primaryKeys, checkConstraints, constraints, nil } func (isi InfoSchemaImpl) GetForeignKeys(conv *internal.Conv, table common.SchemaAndName) (foreignKeys []schema.ForeignKey, err error) { diff --git a/sources/dynamodb/schema_test.go b/sources/dynamodb/schema_test.go index b9d10ab3a..9919d3d5a 100644 --- a/sources/dynamodb/schema_test.go +++ b/sources/dynamodb/schema_test.go @@ -633,7 +633,7 @@ func TestInfoSchemaImpl_GetConstraints(t *testing.T) { dySchema := common.SchemaAndName{Name: "test"} conv := internal.MakeConv() isi := InfoSchemaImpl{client, nil, 10} - primaryKeys, constraints, err := isi.GetConstraints(conv, dySchema) + primaryKeys, _, constraints, err := isi.GetConstraints(conv, dySchema) assert.Nil(t, err) pKeys := []string{"a", "b"} @@ -705,7 +705,7 @@ func TestInfoSchemaImpl_GetColumns(t *testing.T) { client := &mockDynamoClient{ scanOutputs: scanOutputs, } - dySchema := common.SchemaAndName{Name: "test", Id: "t1"} + dySchema := common.SchemaAndName{Name: "test", Id: "t1"} isi := InfoSchemaImpl{client, nil, 10} diff --git a/sources/mysql/infoschema.go b/sources/mysql/infoschema.go index 0cb4ad7c8..0769b34b7 100644 --- a/sources/mysql/infoschema.go +++ b/sources/mysql/infoschema.go @@ -147,7 +147,7 @@ func (isi InfoSchemaImpl) GetRowCount(table common.SchemaAndName) (int64, error) err := rows.Scan(&count) return count, err } - return 0, nil //Check if 0 is ok to return + return 0, nil // Check if 0 is ok to return } // GetTables return list of tables in the selected database. @@ -194,16 +194,6 @@ func (isi InfoSchemaImpl) GetColumns(conv *internal.Conv, table common.SchemaAnd continue } ignored := schema.Ignored{} - for _, c := range constraints[colName] { - // c can be UNIQUE, PRIMARY KEY, FOREIGN KEY or CHECK - // We've already filtered out PRIMARY KEY. - switch c { - case "CHECK": - ignored.Check = true - case "FOREIGN KEY", "PRIMARY KEY", "UNIQUE": - // Nothing to do here -- these are all handled elsewhere. - } - } ignored.Default = colDefault.Valid colId := internal.GenerateColumnId() if colExtra.String == "auto_increment" { @@ -227,7 +217,7 @@ func (isi InfoSchemaImpl) GetColumns(conv *internal.Conv, table common.SchemaAnd if colDefault.Valid { defaultVal.Value = ddl.Expression{ ExpressionId: internal.GenerateExpressionId(), - Statement: common.SanitizeDefaultValue(colDefault.String, dataType, colExtra.String == constants.DEFAULT_GENERATED), + Statement: common.SanitizeDefaultValue(colDefault.String, dataType, colExtra.String == constants.DEFAULT_GENERATED), } } @@ -250,38 +240,108 @@ func (isi InfoSchemaImpl) GetColumns(conv *internal.Conv, table common.SchemaAnd // other constraints. Note: we need to preserve ordinal order of // columns in primary key constraints. // Note that foreign key constraints are handled in getForeignKeys. -func (isi InfoSchemaImpl) GetConstraints(conv *internal.Conv, table common.SchemaAndName) ([]string, map[string][]string, error) { - q := `SELECT k.COLUMN_NAME, t.CONSTRAINT_TYPE - FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS AS t - INNER JOIN INFORMATION_SCHEMA.KEY_COLUMN_USAGE AS k - ON t.CONSTRAINT_NAME = k.CONSTRAINT_NAME AND t.CONSTRAINT_SCHEMA = k.CONSTRAINT_SCHEMA AND t.TABLE_NAME=k.TABLE_NAME - WHERE k.TABLE_SCHEMA = ? AND k.TABLE_NAME = ? 
ORDER BY k.ordinal_position;` - rows, err := isi.Db.Query(q, table.Schema, table.Name) +func (isi InfoSchemaImpl) GetConstraints(conv *internal.Conv, table common.SchemaAndName) ([]string, []schema.CheckConstraint, map[string][]string, error) { + finalQuery, err := isi.getConstraintsDQL() + if err != nil { + return nil, nil, nil, err + } + rows, err := isi.Db.Query(finalQuery, table.Schema, table.Name) if err != nil { - return nil, nil, err + return nil, nil, nil, err } defer rows.Close() + var primaryKeys []string - var col, constraint string + var checkKeys []schema.CheckConstraint m := make(map[string][]string) + for rows.Next() { - err := rows.Scan(&col, &constraint) - if err != nil { - conv.Unexpected(fmt.Sprintf("Can't scan: %v", err)) + if err := isi.processRow(rows, conv, &primaryKeys, &checkKeys, m); err != nil { + conv.Unexpected(fmt.Sprintf("Can't scan constrants. error: %v", err)) continue } - if col == "" || constraint == "" { - conv.Unexpected(fmt.Sprintf("Got empty col or constraint")) - continue - } - switch constraint { - case "PRIMARY KEY": - primaryKeys = append(primaryKeys, col) - default: - m[col] = append(m[col], constraint) - } } - return primaryKeys, m, nil + + return primaryKeys, checkKeys, m, nil +} + +// getConstraintsDQL returns the appropriate SQL query based on the existence of CHECK_CONSTRAINTS. +func (isi InfoSchemaImpl) getConstraintsDQL() (string, error) { + var tableExistsCount int + // check if CHECK_CONSTRAINTS table exists. + checkQuery := `SELECT COUNT(*) FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA = 'INFORMATION_SCHEMA' AND TABLE_NAME = 'CHECK_CONSTRAINTS';` + err := isi.Db.QueryRow(checkQuery).Scan(&tableExistsCount) + if err != nil { + return "", err + } + + // mysql version 8.0.16 and above has CHECK_CONSTRAINTS table. + if tableExistsCount > 0 { + return `SELECT DISTINCT COALESCE(k.COLUMN_NAME,'') AS COLUMN_NAME,t.CONSTRAINT_NAME, t.CONSTRAINT_TYPE, COALESCE(c.CHECK_CLAUSE, '') AS CHECK_CLAUSE + FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS AS t + LEFT JOIN INFORMATION_SCHEMA.KEY_COLUMN_USAGE AS k + ON t.CONSTRAINT_NAME = k.CONSTRAINT_NAME + AND t.CONSTRAINT_SCHEMA = k.CONSTRAINT_SCHEMA + AND t.TABLE_NAME = k.TABLE_NAME + LEFT JOIN INFORMATION_SCHEMA.CHECK_CONSTRAINTS AS c + ON t.CONSTRAINT_NAME = c.CONSTRAINT_NAME + WHERE t.TABLE_SCHEMA = ? + AND t.TABLE_NAME = ?;`, nil + } + return `SELECT k.COLUMN_NAME, t.CONSTRAINT_TYPE + FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS AS t + INNER JOIN INFORMATION_SCHEMA.KEY_COLUMN_USAGE AS k + ON t.CONSTRAINT_NAME = k.CONSTRAINT_NAME + AND t.CONSTRAINT_SCHEMA = k.CONSTRAINT_SCHEMA + AND t.TABLE_NAME = k.TABLE_NAME + WHERE t.TABLE_SCHEMA = ? + AND t.TABLE_NAME = ? + ORDER BY k.ORDINAL_POSITION;`, nil +} + +// processRow handles scanning and processing of a database row for GetConstraints. 
+func (isi InfoSchemaImpl) processRow( + rows *sql.Rows, conv *internal.Conv, primaryKeys *[]string, + checkKeys *[]schema.CheckConstraint, m map[string][]string, +) error { + var col, constraintType, checkClause, constraintName string + var err error + cols, err := rows.Columns() + if err != nil { + conv.Unexpected(fmt.Sprintf("Failed to get columns: %v", err)) + return err + } + + switch len(cols) { + case 2: + err = rows.Scan(&col, &constraintType) + case 4: + err = rows.Scan(&col, &constraintName, &constraintType, &checkClause) + default: + conv.Unexpected(fmt.Sprintf("unexpected number of columns: %d", len(cols))) + return fmt.Errorf("unexpected number of columns: %d", len(cols)) + } + if err != nil { + return err + } + + if col == "" && constraintType == "" { + conv.Unexpected("Got empty column or constraint type") + return nil + } + + switch constraintType { + case "PRIMARY KEY": + *primaryKeys = append(*primaryKeys, col) + + // Case added to handle check constraints + case "CHECK": + checkClause = collationRegex.ReplaceAllString(checkClause, "") + *checkKeys = append(*checkKeys, schema.CheckConstraint{Name: constraintName, Expr: checkClause, Id: internal.GenerateCheckConstrainstId()}) + default: + m[col] = append(m[col], constraintType) + } + return nil } // GetForeignKeys return list all the foreign keys constraints. @@ -380,12 +440,14 @@ func (isi InfoSchemaImpl) GetIndexes(conv *internal.Conv, table common.SchemaAnd indexMap[name] = schema.Index{ Id: internal.GenerateIndexesId(), Name: name, - Unique: (nonUnique == "0")} + Unique: (nonUnique == "0"), + } } index := indexMap[name] index.Keys = append(index.Keys, schema.Key{ ColId: colNameIdMap[column], - Desc: (collation.Valid && collation.String == "D")}) + Desc: (collation.Valid && collation.String == "D"), + }) indexMap[name] = index } for _, k := range indexNames { diff --git a/sources/mysql/infoschema_test.go b/sources/mysql/infoschema_test.go index 37836f382..987a9b4a3 100644 --- a/sources/mysql/infoschema_test.go +++ b/sources/mysql/infoschema_test.go @@ -17,6 +17,7 @@ package mysql import ( "database/sql" "database/sql/driver" + "regexp" "testing" "github.com/DATA-DOG/go-sqlmock" @@ -39,7 +40,6 @@ type mockSpec struct { func TestProcessSchemaMYSQL(t *testing.T) { ms := []mockSpec{ - { query: "SELECT (.+) FROM information_schema.tables where table_type = 'BASE TABLE' and (.+)", args: []driver.Value{"test"}, @@ -49,22 +49,35 @@ func TestProcessSchemaMYSQL(t *testing.T) { {"cart"}, {"product"}, {"test"}, - {"test_ref"}}, - }, { + {"test_ref"}, + }, + }, + { + query: `SELECT COUNT\(\*\) FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA = 'INFORMATION_SCHEMA' AND TABLE_NAME = 'CHECK_CONSTRAINTS';`, + args: nil, + cols: []string{"count"}, + rows: [][]driver.Value{ + {int64(0)}, + }, + }, + { query: "SELECT (.+) FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS (.+)", args: []driver.Value{"test", "user"}, cols: []string{"column_name", "constraint_type"}, rows: [][]driver.Value{ {"user_id", "PRIMARY KEY"}, - {"ref", "FOREIGN KEY"}}, - }, { + {"ref", "FOREIGN KEY"}, + }, + }, + { query: "SELECT (.+) FROM INFORMATION_SCHEMA.REFERENTIAL_CONSTRAINTS (.+)", args: []driver.Value{"test", "user"}, cols: []string{"REFERENCED_TABLE_NAME", "COLUMN_NAME", "REFERENCED_COLUMN_NAME", "CONSTRAINT_NAME", "DELETE_RULE", "UPDATE_RULE"}, rows: [][]driver.Value{ {"test", "ref", "id", "fk_test", constants.FK_SET_NULL, constants.FK_CASCADE}, }, - }, { + }, + { query: "SELECT (.+) FROM information_schema.COLUMNS (.+)", args: []driver.Value{"test", "user"}, 
cols: []string{"column_name", "data_type", "column_type", "is_nullable", "column_default", "character_maximum_length", "numeric_precision", "numeric_scale", "extra"}, @@ -79,14 +92,24 @@ func TestProcessSchemaMYSQL(t *testing.T) { args: []driver.Value{"test", "user"}, cols: []string{"INDEX_NAME", "COLUMN_NAME", "SEQ_IN_INDEX", "COLLATION", "NON_UNIQUE"}, }, + { + query: `SELECT COUNT\(\*\) FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA = 'INFORMATION_SCHEMA' AND TABLE_NAME = 'CHECK_CONSTRAINTS';`, + args: nil, + cols: []string{"count"}, + rows: [][]driver.Value{ + {int64(0)}, + }, + }, { query: "SELECT (.+) FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS (.+)", args: []driver.Value{"test", "cart"}, cols: []string{"column_name", "constraint_type"}, rows: [][]driver.Value{ {"productid", "PRIMARY KEY"}, - {"userid", "PRIMARY KEY"}}, - }, { + {"userid", "PRIMARY KEY"}, + }, + }, + { query: "SELECT (.+) FROM INFORMATION_SCHEMA.REFERENTIAL_CONSTRAINTS (.+)", args: []driver.Value{"test", "cart"}, cols: []string{"REFERENCED_TABLE_NAME", "COLUMN_NAME", "REFERENCED_COLUMN_NAME", "CONSTRAINT_NAME", "DELETE_RULE", "UPDATE_RULE"}, @@ -94,14 +117,16 @@ func TestProcessSchemaMYSQL(t *testing.T) { {"product", "productid", "product_id", "fk_test2", constants.FK_NO_ACTION, constants.FK_NO_ACTION}, {"user", "userid", "user_id", "fk_test3", constants.FK_RESTRICT, constants.FK_SET_NULL}, }, - }, { + }, + { query: "SELECT (.+) FROM information_schema.COLUMNS (.+)", args: []driver.Value{"test", "cart"}, cols: []string{"column_name", "data_type", "column_type", "is_nullable", "column_default", "character_maximum_length", "numeric_precision", "numeric_scale", "extra"}, rows: [][]driver.Value{ {"productid", "text", "text", "NO", nil, nil, nil, nil, nil}, {"userid", "text", "text", "NO", nil, nil, nil, nil, nil}, - {"quantity", "bigint", "bigint", "YES", nil, nil, 64, 0, nil}}, + {"quantity", "bigint", "bigint", "YES", nil, nil, 64, 0, nil}, + }, }, // db call to fetch index happens after fetching of column { @@ -113,25 +138,38 @@ func TestProcessSchemaMYSQL(t *testing.T) { {"index2", "userid", 1, "A", "1"}, {"index2", "productid", 2, "D", "1"}, {"index3", "productid", 1, "A", "0"}, - {"index3", "userid", 2, "D", "0"}}, + {"index3", "userid", 2, "D", "0"}, + }, + }, + { + query: `SELECT COUNT\(\*\) FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA = 'INFORMATION_SCHEMA' AND TABLE_NAME = 'CHECK_CONSTRAINTS';`, + args: nil, + cols: []string{"count"}, + rows: [][]driver.Value{ + {int64(0)}, + }, }, { query: "SELECT (.+) FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS (.+)", args: []driver.Value{"test", "product"}, cols: []string{"column_name", "constraint_type"}, rows: [][]driver.Value{ - {"product_id", "PRIMARY KEY"}}, - }, { + {"product_id", "PRIMARY KEY"}, + }, + }, + { query: "SELECT (.+) FROM INFORMATION_SCHEMA.REFERENTIAL_CONSTRAINTS (.+)", args: []driver.Value{"test", "product"}, cols: []string{"REFERENCED_TABLE_NAME", "COLUMN_NAME", "REFERENCED_COLUMN_NAME", "CONSTRAINT_NAME", "DELETE_RULE", "UPDATE_RULE"}, - }, { + }, + { query: "SELECT (.+) FROM information_schema.COLUMNS (.+)", args: []driver.Value{"test", "product"}, cols: []string{"column_name", "data_type", "column_type", "is_nullable", "column_default", "character_maximum_length", "numeric_precision", "numeric_scale", "extra"}, rows: [][]driver.Value{ {"product_id", "text", "text", "NO", nil, nil, nil, nil, nil}, - {"product_name", "text", "text", "NO", nil, nil, nil, nil, nil}}, + {"product_name", "text", "text", "NO", nil, nil, nil, nil, nil}, + }, }, // db 
call to fetch index happens after fetching of column { @@ -139,18 +177,30 @@ func TestProcessSchemaMYSQL(t *testing.T) { args: []driver.Value{"test", "product"}, cols: []string{"INDEX_NAME", "COLUMN_NAME", "SEQ_IN_INDEX", "COLLATION", "NON_UNIQUE"}, }, + { + query: `SELECT COUNT\(\*\) FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA = 'INFORMATION_SCHEMA' AND TABLE_NAME = 'CHECK_CONSTRAINTS';`, + args: nil, + cols: []string{"count"}, + rows: [][]driver.Value{ + {int64(0)}, + }, + }, { query: "SELECT (.+) FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS (.+)", args: []driver.Value{"test", "test"}, cols: []string{"column_name", "constraint_type"}, rows: [][]driver.Value{{"id", "PRIMARY KEY"}, {"id", "FOREIGN KEY"}}, - }, { + }, + { query: "SELECT (.+) FROM INFORMATION_SCHEMA.REFERENTIAL_CONSTRAINTS (.+)", args: []driver.Value{"test", "test"}, cols: []string{"REFERENCED_TABLE_NAME", "COLUMN_NAME", "REFERENCED_COLUMN_NAME", "CONSTRAINT_NAME", "DELETE_RULE", "UPDATE_RULE"}, - rows: [][]driver.Value{{"test_ref", "id", "ref_id", "fk_test4", constants.FK_CASCADE, constants.FK_RESTRICT}, - {"test_ref", "txt", "ref_txt", "fk_test4", constants.FK_CASCADE, constants.FK_RESTRICT}}, - }, { + rows: [][]driver.Value{ + {"test_ref", "id", "ref_id", "fk_test4", constants.FK_CASCADE, constants.FK_RESTRICT}, + {"test_ref", "txt", "ref_txt", "fk_test4", constants.FK_CASCADE, constants.FK_RESTRICT}, + }, + }, + { query: "SELECT (.+) FROM information_schema.COLUMNS (.+)", args: []driver.Value{"test", "test"}, cols: []string{"column_name", "data_type", "column_type", "is_nullable", "column_default", "character_maximum_length", "numeric_precision", "numeric_scale", "extra"}, @@ -174,7 +224,8 @@ func TestProcessSchemaMYSQL(t *testing.T) { {"ts", "datetime", "datetime", "YES", nil, nil, nil, nil, nil}, {"tz", "timestamp", "timestamp", "YES", nil, nil, nil, nil, nil}, {"vc", "varchar", "varchar", "YES", nil, nil, nil, nil, nil}, - {"vc6", "varchar", "varchar(6)", "YES", nil, 6, nil, nil, nil}}, + {"vc6", "varchar", "varchar(6)", "YES", nil, 6, nil, nil, nil}, + }, }, // db call to fetch index happens after fetching of column { @@ -182,25 +233,37 @@ func TestProcessSchemaMYSQL(t *testing.T) { args: []driver.Value{"test", "test"}, cols: []string{"INDEX_NAME", "COLUMN_NAME", "SEQ_IN_INDEX", "COLLATION", "NON_UNIQUE"}, }, + { + query: `SELECT COUNT\(\*\) FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA = 'INFORMATION_SCHEMA' AND TABLE_NAME = 'CHECK_CONSTRAINTS';`, + args: nil, + cols: []string{"count"}, + rows: [][]driver.Value{ + {int64(0)}, + }, + }, { query: "SELECT (.+) FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS (.+)", args: []driver.Value{"test", "test_ref"}, cols: []string{"column_name", "constraint_type"}, rows: [][]driver.Value{ {"ref_id", "PRIMARY KEY"}, - {"ref_txt", "PRIMARY KEY"}}, - }, { + {"ref_txt", "PRIMARY KEY"}, + }, + }, + { query: "SELECT (.+) FROM INFORMATION_SCHEMA.REFERENTIAL_CONSTRAINTS (.+)", args: []driver.Value{"test", "test_ref"}, cols: []string{"REFERENCED_TABLE_NAME", "COLUMN_NAME", "REFERENCED_COLUMN_NAME", "CONSTRAINT_NAME", "DELETE_RULE", "UPDATE_RULE"}, - }, { + }, + { query: "SELECT (.+) FROM information_schema.COLUMNS (.+)", args: []driver.Value{"test", "test_ref"}, cols: []string{"column_name", "data_type", "column_type", "is_nullable", "column_default", "character_maximum_length", "numeric_precision", "numeric_scale", "extra"}, rows: [][]driver.Value{ {"ref_id", "bigint", "bigint", "NO", nil, nil, 64, 0, nil}, {"ref_txt", "text", "text", "NO", nil, nil, nil, nil, nil}, - {"abc", "text", 
"text", "NO", nil, nil, nil, nil, nil}}, + {"abc", "text", "text", "NO", nil, nil, nil, nil, nil}, + }, }, // db call to fetch index happens after fetching of column { @@ -216,18 +279,23 @@ func TestProcessSchemaMYSQL(t *testing.T) { _, err := commonInfoSchema.GenerateSrcSchema(conv, isi, 1) assert.Nil(t, err) expectedSchema := map[string]schema.Table{ - "cart": schema.Table{Name: "cart", Schema: "test", ColIds: []string{"productid", "userid", "quantity"}, ColDefs: map[string]schema.Column{ - "productid": schema.Column{Name: "productid", Type: schema.Type{Name: "text", Mods: []int64(nil), ArrayBounds: []int64(nil)}, NotNull: true, Ignored: schema.Ignored{Check: false, Identity: false, Default: false, Exclusion: false, ForeignKey: false, AutoIncrement: false}, Id: ""}, - "quantity": schema.Column{Name: "quantity", Type: schema.Type{Name: "bigint", Mods: []int64{64}, ArrayBounds: []int64(nil)}, NotNull: false, Ignored: schema.Ignored{Check: false, Identity: false, Default: false, Exclusion: false, ForeignKey: false, AutoIncrement: false}, Id: ""}, - "userid": schema.Column{Name: "userid", Type: schema.Type{Name: "text", Mods: []int64(nil), ArrayBounds: []int64(nil)}, NotNull: true, Ignored: schema.Ignored{Check: false, Identity: false, Default: false, Exclusion: false, ForeignKey: false, AutoIncrement: false}, Id: ""}}, - PrimaryKeys: []schema.Key{schema.Key{ColId: "productid", Desc: false, Order: 0}, schema.Key{ColId: "userid", Desc: false, Order: 0}}, - ForeignKeys: []schema.ForeignKey{schema.ForeignKey{Name: "fk_test2", ColIds: []string{"productid"}, ReferTableId: "product", ReferColumnIds: []string{"product_id"}, OnDelete: constants.FK_NO_ACTION, OnUpdate: constants.FK_NO_ACTION, Id: ""}, schema.ForeignKey{Name: "fk_test3", ColIds: []string{"userid"}, ReferTableId: "user", ReferColumnIds: []string{"user_id"}, OnUpdate: constants.FK_SET_NULL, OnDelete: constants.FK_RESTRICT, Id: ""}}, - Indexes: []schema.Index{schema.Index{Name: "index1", Unique: true, Keys: []schema.Key{schema.Key{ColId: "userid", Desc: false, Order: 0}}, Id: "", StoredColumnIds: []string(nil)}, schema.Index{Name: "index2", Unique: false, Keys: []schema.Key{schema.Key{ColId: "userid", Desc: false, Order: 0}, schema.Key{ColId: "productid", Desc: true, Order: 0}}, Id: "", StoredColumnIds: []string(nil)}, schema.Index{Name: "index3", Unique: true, Keys: []schema.Key{schema.Key{ColId: "productid", Desc: false, Order: 0}, schema.Key{ColId: "userid", Desc: true, Order: 0}}, Id: "", StoredColumnIds: []string(nil)}}, Id: ""}, + "cart": { + Name: "cart", Schema: "test", ColIds: []string{"productid", "userid", "quantity"}, ColDefs: map[string]schema.Column{ + "productid": {Name: "productid", Type: schema.Type{Name: "text", Mods: []int64(nil), ArrayBounds: []int64(nil)}, NotNull: true, Ignored: schema.Ignored{Check: false, Identity: false, Default: false, Exclusion: false, ForeignKey: false, AutoIncrement: false}, Id: ""}, + "quantity": {Name: "quantity", Type: schema.Type{Name: "bigint", Mods: []int64{64}, ArrayBounds: []int64(nil)}, NotNull: false, Ignored: schema.Ignored{Check: false, Identity: false, Default: false, Exclusion: false, ForeignKey: false, AutoIncrement: false}, Id: ""}, + "userid": {Name: "userid", Type: schema.Type{Name: "text", Mods: []int64(nil), ArrayBounds: []int64(nil)}, NotNull: true, Ignored: schema.Ignored{Check: false, Identity: false, Default: false, Exclusion: false, ForeignKey: false, AutoIncrement: false}, Id: ""}, + }, + PrimaryKeys: []schema.Key{{ColId: "productid", Desc: false, Order: 0}, {ColId: 
"userid", Desc: false, Order: 0}}, + ForeignKeys: []schema.ForeignKey{{Name: "fk_test2", ColIds: []string{"productid"}, ReferTableId: "product", ReferColumnIds: []string{"product_id"}, OnDelete: constants.FK_NO_ACTION, OnUpdate: constants.FK_NO_ACTION, Id: ""}, {Name: "fk_test3", ColIds: []string{"userid"}, ReferTableId: "user", ReferColumnIds: []string{"user_id"}, OnUpdate: constants.FK_SET_NULL, OnDelete: constants.FK_RESTRICT, Id: ""}}, + Indexes: []schema.Index{{Name: "index1", Unique: true, Keys: []schema.Key{{ColId: "userid", Desc: false, Order: 0}}, Id: "", StoredColumnIds: []string(nil)}, {Name: "index2", Unique: false, Keys: []schema.Key{{ColId: "userid", Desc: false, Order: 0}, {ColId: "productid", Desc: true, Order: 0}}, Id: "", StoredColumnIds: []string(nil)}, {Name: "index3", Unique: true, Keys: []schema.Key{{ColId: "productid", Desc: false, Order: 0}, {ColId: "userid", Desc: true, Order: 0}}, Id: "", StoredColumnIds: []string(nil)}}, Id: "", + }, - "product": schema.Table{Name: "product", Schema: "test", ColIds: []string{"product_id", "product_name"}, ColDefs: map[string]schema.Column{ - "product_id": schema.Column{Name: "product_id", Type: schema.Type{Name: "text", Mods: []int64(nil), ArrayBounds: []int64(nil)}, NotNull: true, Ignored: schema.Ignored{Check: false, Identity: false, Default: false, Exclusion: false, ForeignKey: false, AutoIncrement: false}, Id: ""}, - "product_name": schema.Column{Name: "product_name", Type: schema.Type{Name: "text", Mods: []int64(nil), ArrayBounds: []int64(nil)}, NotNull: true, Ignored: schema.Ignored{Check: false, Identity: false, Default: false, Exclusion: false, ForeignKey: false, AutoIncrement: false}, Id: ""}}, - PrimaryKeys: []schema.Key{schema.Key{ColId: "product_id", Desc: false, Order: 0}}, + "product": { + Name: "product", Schema: "test", ColIds: []string{"product_id", "product_name"}, ColDefs: map[string]schema.Column{ + "product_id": {Name: "product_id", Type: schema.Type{Name: "text", Mods: []int64(nil), ArrayBounds: []int64(nil)}, NotNull: true, Ignored: schema.Ignored{Check: false, Identity: false, Default: false, Exclusion: false, ForeignKey: false, AutoIncrement: false}, Id: ""}, + "product_name": {Name: "product_name", Type: schema.Type{Name: "text", Mods: []int64(nil), ArrayBounds: []int64(nil)}, NotNull: true, Ignored: schema.Ignored{Check: false, Identity: false, Default: false, Exclusion: false, ForeignKey: false, AutoIncrement: false}, Id: ""}, + }, + PrimaryKeys: []schema.Key{{ColId: "product_id", Desc: false, Order: 0}}, ForeignKeys: []schema.ForeignKey(nil), Indexes: []schema.Index(nil), Id: ""}, "test": schema.Table{Name: "test", Schema: "test", ColIds: []string{"id", "s", "txt", "b", "bs", "bl", "c", "c8", "d", "dec", "f8", "f4", "i8", "i4", "i2", "si", "ts", "tz", "vc", "vc6"}, ColDefs: map[string]schema.Column{ @@ -280,7 +348,8 @@ func TestProcessData(t *testing.T) { rows: [][]driver.Value{ {42.3, 3, "cat"}, {6.6, 22, "dog"}, - {6.6, "2006-01-02", "dog"}}, // Test bad row logic. + {6.6, "2006-01-02", "dog"}, + }, // Test bad row logic. 
}, } db := mkMockDB(t, ms) @@ -290,9 +359,9 @@ func TestProcessData(t *testing.T) { Id: "t1", ColIds: []string{"c1", "c2", "c3"}, ColDefs: map[string]ddl.ColumnDef{ - "c1": ddl.ColumnDef{Name: "a_a", Id: "c1", T: ddl.Type{Name: ddl.Float64}}, - "c2": ddl.ColumnDef{Name: "Ab", Id: "c2", T: ddl.Type{Name: ddl.Int64}}, - "c3": ddl.ColumnDef{Name: "Ac_", Id: "c3", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}}, + "c1": {Name: "a_a", Id: "c1", T: ddl.Type{Name: ddl.Float64}}, + "c2": {Name: "Ab", Id: "c2", T: ddl.Type{Name: ddl.Int64}}, + "c3": {Name: "Ac_", Id: "c3", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}}, }, }, schema.Table{ @@ -301,9 +370,9 @@ func TestProcessData(t *testing.T) { Schema: "test", ColIds: []string{"c1", "c2", "c3"}, ColDefs: map[string]schema.Column{ - "c1": schema.Column{Name: "a a", Id: "c1", Type: schema.Type{Name: "float"}}, - "c2": schema.Column{Name: " b", Id: "c2", Type: schema.Type{Name: "int"}}, - "c3": schema.Column{Name: " c ", Id: "c3", Type: schema.Type{Name: "text"}}, + "c1": {Name: "a a", Id: "c1", Type: schema.Type{Name: "float"}}, + "c2": {Name: " b", Id: "c2", Type: schema.Type{Name: "int"}}, + "c3": {Name: " c ", Id: "c3", Type: schema.Type{Name: "text"}}, }, ColNameIdMap: map[string]string{ "a a": "c1", @@ -323,8 +392,8 @@ func TestProcessData(t *testing.T) { commonInfoSchema.ProcessData(conv, isi, internal.AdditionalDataAttributes{}) assert.Equal(t, []spannerData{ - spannerData{table: "te_st", cols: []string{"a_a", "Ab", "Ac_"}, vals: []interface{}{float64(42.3), int64(3), "cat"}}, - spannerData{table: "te_st", cols: []string{"a_a", "Ab", "Ac_"}, vals: []interface{}{float64(6.6), int64(22), "dog"}}, + {table: "te_st", cols: []string{"a_a", "Ab", "Ac_"}, vals: []interface{}{float64(42.3), int64(3), "cat"}}, + {table: "te_st", cols: []string{"a_a", "Ab", "Ac_"}, vals: []interface{}{float64(6.6), int64(22), "dog"}}, }, rows) assert.Equal(t, conv.BadRows(), int64(1)) @@ -344,23 +413,35 @@ func TestProcessData_MultiCol(t *testing.T) { args: []driver.Value{"test"}, cols: []string{"table_name"}, rows: [][]driver.Value{{"test"}}, - }, { + }, + { + query: `SELECT COUNT\(\*\) FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA = 'INFORMATION_SCHEMA' AND TABLE_NAME = 'CHECK_CONSTRAINTS';`, + args: nil, + cols: []string{"count"}, + rows: [][]driver.Value{ + {int64(0)}, + }, + }, + { query: "SELECT (.+) FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS (.+)", args: []driver.Value{"test", "test"}, cols: []string{"column_name", "constraint_type"}, rows: [][]driver.Value{}, // No primary key --> force generation of synthetic key. 
- }, { + }, + { query: "SELECT (.+) FROM INFORMATION_SCHEMA.REFERENTIAL_CONSTRAINTS (.+)", args: []driver.Value{"test", "test"}, cols: []string{"REFERENCED_TABLE_NAME", "COLUMN_NAME", "REFERENCED_COLUMN_NAME", "CONSTRAINT_NAME", "DELETE_RULE", "UPDATE_RULE"}, - }, { + }, + { query: "SELECT (.+) FROM information_schema.COLUMNS (.+)", args: []driver.Value{"test", "test"}, cols: []string{"column_name", "data_type", "column_type", "is_nullable", "column_default", "character_maximum_length", "numeric_precision", "numeric_scale", "extra"}, rows: [][]driver.Value{ {"a", "text", "text", "NO", nil, nil, nil, nil, nil}, {"b", "double", "double", "YES", nil, nil, 53, nil, nil}, - {"c", "bigint", "bigint", "YES", nil, nil, 64, 0, nil}}, + {"c", "bigint", "bigint", "YES", nil, nil, 64, 0, nil}, + }, }, { query: "SELECT (.+) FROM INFORMATION_SCHEMA.STATISTICS (.+)", @@ -372,7 +453,8 @@ func TestProcessData_MultiCol(t *testing.T) { cols: []string{"a", "b", "c"}, rows: [][]driver.Value{ {"cat", 42.3, nil}, - {"dog", nil, 22}}, + {"dog", nil, 22}, + }, }, } db := mkMockDB(t, ms) @@ -382,16 +464,17 @@ func TestProcessData_MultiCol(t *testing.T) { err := processSchema.ProcessSchema(conv, isi, 1, internal.AdditionalSchemaAttributes{}, &common.SchemaToSpannerImpl{}, &common.UtilsOrderImpl{}, &common.InfoSchemaImpl{}) assert.Nil(t, err) expectedSchema := map[string]ddl.CreateTable{ - "test": ddl.CreateTable{ + "test": { Name: "test", ColIds: []string{"a", "b", "c", "synth_id"}, ColDefs: map[string]ddl.ColumnDef{ - "a": ddl.ColumnDef{Name: "a", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}, NotNull: true}, - "b": ddl.ColumnDef{Name: "b", T: ddl.Type{Name: ddl.Float64}}, - "c": ddl.ColumnDef{Name: "c", T: ddl.Type{Name: ddl.Int64}}, - "synth_id": ddl.ColumnDef{Name: "synth_id", T: ddl.Type{Name: ddl.String, Len: 50}}, + "a": {Name: "a", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}, NotNull: true}, + "b": {Name: "b", T: ddl.Type{Name: ddl.Float64}}, + "c": {Name: "c", T: ddl.Type{Name: ddl.Int64}}, + "synth_id": {Name: "synth_id", T: ddl.Type{Name: ddl.String, Len: 50}}, }, - PrimaryKeys: []ddl.IndexKey{ddl.IndexKey{ColId: "synth_id", Order: 1}}}, + PrimaryKeys: []ddl.IndexKey{{ColId: "synth_id", Order: 1}}, + }, } internal.AssertSpSchema(conv, t, expectedSchema, stripSchemaComments(conv.SpSchema)) columnLevelIssues := map[string][]internal.SchemaIssue{ @@ -416,7 +499,8 @@ func TestProcessData_MultiCol(t *testing.T) { commonInfoSchema.ProcessData(conv, isi, internal.AdditionalDataAttributes{}) assert.Equal(t, []spannerData{ {table: "test", cols: []string{"a", "b", "synth_id"}, vals: []interface{}{"cat", float64(42.3), "0"}}, - {table: "test", cols: []string{"a", "c", "synth_id"}, vals: []interface{}{"dog", int64(22), "-9223372036854775808"}}}, + {table: "test", cols: []string{"a", "c", "synth_id"}, vals: []interface{}{"dog", int64(22), "-9223372036854775808"}}, + }, rows) assert.Equal(t, int64(0), conv.Unexpecteds()) } @@ -433,23 +517,35 @@ func TestProcessSchema_Sharded(t *testing.T) { args: []driver.Value{"test"}, cols: []string{"table_name"}, rows: [][]driver.Value{{"test"}}, - }, { + }, + { + query: `SELECT COUNT\(\*\) FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA = 'INFORMATION_SCHEMA' AND TABLE_NAME = 'CHECK_CONSTRAINTS';`, + args: nil, + cols: []string{"count"}, + rows: [][]driver.Value{ + {int64(0)}, + }, + }, + { query: "SELECT (.+) FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS (.+)", args: []driver.Value{"test", "test"}, cols: []string{"column_name", "constraint_type"}, rows: [][]driver.Value{}, 
// No primary key --> force generation of synthetic key. - }, { + }, + { query: "SELECT (.+) FROM INFORMATION_SCHEMA.REFERENTIAL_CONSTRAINTS (.+)", args: []driver.Value{"test", "test"}, cols: []string{"REFERENCED_TABLE_NAME", "COLUMN_NAME", "REFERENCED_COLUMN_NAME", "CONSTRAINT_NAME", "DELETE_RULE", "UPDATE_RULE"}, - }, { + }, + { query: "SELECT (.+) FROM information_schema.COLUMNS (.+)", args: []driver.Value{"test", "test"}, cols: []string{"column_name", "data_type", "column_type", "is_nullable", "column_default", "character_maximum_length", "numeric_precision", "numeric_scale", "extra"}, rows: [][]driver.Value{ {"a", "text", "text", "NO", nil, nil, nil, nil, nil}, {"b", "double", "double", "YES", nil, nil, 53, nil, nil}, - {"c", "bigint", "bigint", "YES", nil, nil, 64, 0, nil}}, + {"c", "bigint", "bigint", "YES", nil, nil, 64, 0, nil}, + }, }, { query: "SELECT (.+) FROM INFORMATION_SCHEMA.STATISTICS (.+)", @@ -461,7 +557,8 @@ func TestProcessSchema_Sharded(t *testing.T) { cols: []string{"a", "b", "c"}, rows: [][]driver.Value{ {"cat", 42.3, nil}, - {"dog", nil, 22}}, + {"dog", nil, 22}, + }, }, } db := mkMockDB(t, ms) @@ -481,7 +578,8 @@ func TestProcessSchema_Sharded(t *testing.T) { "synth_id": {Name: "synth_id", T: ddl.Type{Name: ddl.String, Len: 50}}, "migration_shard_id": {Name: "migration_shard_id", T: ddl.Type{Name: ddl.String, Len: 50}}, }, - PrimaryKeys: []ddl.IndexKey{{ColId: "synth_id", Order: 1}}}, + PrimaryKeys: []ddl.IndexKey{{ColId: "synth_id", Order: 1}}, + }, } internal.AssertSpSchema(conv, t, expectedSchema, stripSchemaComments(conv.SpSchema)) } @@ -530,3 +628,55 @@ func mkMockDB(t *testing.T, ms []mockSpec) *sql.DB { } return db } + +func TestGetConstraints_CheckConstraintsTableExists(t *testing.T) { + ms := []mockSpec{ + { + query: `SELECT COUNT\(\*\) FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA = 'INFORMATION_SCHEMA' AND TABLE_NAME = 'CHECK_CONSTRAINTS';`, + cols: []string{"COUNT(*)"}, + rows: [][]driver.Value{{1}}, + }, + { + query: regexp.QuoteMeta(`SELECT DISTINCT COALESCE(k.COLUMN_NAME,'') AS COLUMN_NAME,t.CONSTRAINT_NAME, t.CONSTRAINT_TYPE, COALESCE(c.CHECK_CLAUSE, '') AS CHECK_CLAUSE + FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS AS t + LEFT JOIN INFORMATION_SCHEMA.KEY_COLUMN_USAGE AS k + ON t.CONSTRAINT_NAME = k.CONSTRAINT_NAME + AND t.CONSTRAINT_SCHEMA = k.CONSTRAINT_SCHEMA + AND t.TABLE_NAME = k.TABLE_NAME + LEFT JOIN INFORMATION_SCHEMA.CHECK_CONSTRAINTS AS c + ON t.CONSTRAINT_NAME = c.CONSTRAINT_NAME + WHERE t.TABLE_SCHEMA = ? 
+ AND t.TABLE_NAME = ?;`), + args: []driver.Value{"test_schema", "test_table"}, + cols: []string{"COLUMN_NAME", "CONSTRAINT_NAME", "CONSTRAINT_TYPE", "CHECK_CLAUSE"}, + rows: [][]driver.Value{{"column1", "PRIMARY", "PRIMARY KEY", ""}, {"column2", "check_name", "CHECK", "(column2 > 0)"}}, + }, + } + db := mkMockDB(t, ms) + isi := InfoSchemaImpl{Db: db} + conv := &internal.Conv{} + + primaryKeys, checkKeys, m, err := isi.GetConstraints(conv, common.SchemaAndName{Schema: "test_schema", Name: "test_table"}) + assert.NoError(t, err) + assert.Equal(t, []string{"column1"}, primaryKeys) + assert.Equal(t, len(checkKeys), 1) + assert.Equal(t, checkKeys[0].Name, "check_name") + assert.Equal(t, checkKeys[0].Expr, "(column2 > 0)") + assert.NotNil(t, m) +} + +func TestGetConstraints_CheckConstraintsTableAbsent(t *testing.T) { + ms := []mockSpec{ + { + query: `SELECT COUNT\(\*\) FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA = 'INFORMATION_SCHEMA' AND TABLE_NAME = 'CHECK_CONSTRAINTS';`, + cols: []string{"COUNT(*)"}, + rows: [][]driver.Value{{0}}, + }, + } + db := mkMockDB(t, ms) + isi := InfoSchemaImpl{Db: db} + conv := &internal.Conv{} + + _, _, _, err := isi.GetConstraints(conv, common.SchemaAndName{Schema: "your_schema", Name: "your_table"}) + assert.Error(t, err) +} diff --git a/sources/mysql/mysqldump.go b/sources/mysql/mysqldump.go index 6145aee27..05e473f1d 100644 --- a/sources/mysql/mysqldump.go +++ b/sources/mysql/mysqldump.go @@ -36,6 +36,7 @@ import ( var valuesRegexp = regexp.MustCompile("\\((.*?)\\)") var insertRegexp = regexp.MustCompile("INSERT\\sINTO\\s(.*?)\\sVALUES\\s") var unsupportedRegexp = regexp.MustCompile("function|procedure|trigger") +var collationRegex = regexp.MustCompile(constants.DB_COLLATION_REGEX) // MysqlSpatialDataTypes is an array of all MySQL spatial data types. var MysqlSpatialDataTypes = []string{"geometrycollection", "multipoint", "multilinestring", "multipolygon", "point", "linestring", "polygon", "geometry"} diff --git a/sources/oracle/infoschema.go b/sources/oracle/infoschema.go index f62bfb270..770ef93cc 100644 --- a/sources/oracle/infoschema.go +++ b/sources/oracle/infoschema.go @@ -247,7 +247,7 @@ func (isi InfoSchemaImpl) GetColumns(conv *internal.Conv, table common.SchemaAnd // other constraints. Note: we need to preserve ordinal order of // columns in primary key constraints. // Note that foreign key constraints are handled in getForeignKeys. -func (isi InfoSchemaImpl) GetConstraints(conv *internal.Conv, table common.SchemaAndName) ([]string, map[string][]string, error) { +func (isi InfoSchemaImpl) GetConstraints(conv *internal.Conv, table common.SchemaAndName) ([]string, []schema.CheckConstraint, map[string][]string, error) { q := fmt.Sprintf(` SELECT k.column_name, @@ -260,7 +260,7 @@ func (isi InfoSchemaImpl) GetConstraints(conv *internal.Conv, table common.Schem `, table.Schema, table.Name) rows, err := isi.Db.Query(q) if err != nil { - return nil, nil, err + return nil, nil, nil, err } defer rows.Close() var primaryKeys []string @@ -292,7 +292,7 @@ func (isi InfoSchemaImpl) GetConstraints(conv *internal.Conv, table common.Schem m[col] = append(m[col], constraint) } } - return primaryKeys, m, nil + return primaryKeys, nil, m, nil } // GetForeignKeys return list all the foreign keys constraints. 
diff --git a/sources/postgres/infoschema.go b/sources/postgres/infoschema.go index ccf7b8dcf..e4d278b0b 100644 --- a/sources/postgres/infoschema.go +++ b/sources/postgres/infoschema.go @@ -337,7 +337,7 @@ func (isi InfoSchemaImpl) GetColumns(conv *internal.Conv, table common.SchemaAnd // other constraints. Note: we need to preserve ordinal order of // columns in primary key constraints. // Note that foreign key constraints are handled in getForeignKeys. -func (isi InfoSchemaImpl) GetConstraints(conv *internal.Conv, table common.SchemaAndName) ([]string, map[string][]string, error) { +func (isi InfoSchemaImpl) GetConstraints(conv *internal.Conv, table common.SchemaAndName) ([]string, []schema.CheckConstraint, map[string][]string, error) { q := `SELECT k.COLUMN_NAME, t.CONSTRAINT_TYPE FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS AS t INNER JOIN INFORMATION_SCHEMA.KEY_COLUMN_USAGE AS k @@ -345,7 +345,7 @@ func (isi InfoSchemaImpl) GetConstraints(conv *internal.Conv, table common.Schem WHERE k.TABLE_SCHEMA = $1 AND k.TABLE_NAME = $2 ORDER BY k.ordinal_position;` rows, err := isi.Db.Query(q, table.Schema, table.Name) if err != nil { - return nil, nil, err + return nil, nil, nil, err } defer rows.Close() var primaryKeys []string @@ -368,7 +368,7 @@ func (isi InfoSchemaImpl) GetConstraints(conv *internal.Conv, table common.Schem m[col] = append(m[col], constraint) } } - return primaryKeys, m, nil + return primaryKeys, nil, m, nil } // GetForeignKeys returns a list of all the foreign key constraints. diff --git a/sources/spanner/infoschema.go b/sources/spanner/infoschema.go index ca08985bc..c1bf36b4d 100644 --- a/sources/spanner/infoschema.go +++ b/sources/spanner/infoschema.go @@ -190,7 +190,7 @@ func (isi InfoSchemaImpl) GetColumns(conv *internal.Conv, table common.SchemaAnd // other constraints. Note: we need to preserve ordinal order of // columns in primary key constraints. // Note that foreign key constraints are handled in getForeignKeys. -func (isi InfoSchemaImpl) GetConstraints(conv *internal.Conv, table common.SchemaAndName) ([]string, map[string][]string, error) { +func (isi InfoSchemaImpl) GetConstraints(conv *internal.Conv, table common.SchemaAndName) ([]string, []schema.CheckConstraint, map[string][]string, error) { q := `SELECT k.column_name, t.constraint_type FROM information_schema.table_constraints AS t INNER JOIN information_schema.KEY_COLUMN_USAGE AS k @@ -221,11 +221,11 @@ func (isi InfoSchemaImpl) GetConstraints(conv *internal.Conv, table common.Schem break } if err != nil { - return nil, nil, fmt.Errorf("couldn't get row while reading constraints: %w", err) + return nil, nil, nil, fmt.Errorf("couldn't get row while reading constraints: %w", err) } err = row.Columns(&col, &constraint) if err != nil { - return nil, nil, err + return nil, nil, nil, err } if col == "" || constraint == "" { conv.Unexpected("Got empty col or constraint") @@ -238,7 +238,7 @@ func (isi InfoSchemaImpl) GetConstraints(conv *internal.Conv, table common.Schem m[col] = append(m[col], constraint) } } - return primaryKeys, m, nil + return primaryKeys, nil, m, nil } // GetForeignKeys returns a list of all the foreign key constraints. diff --git a/sources/sqlserver/infoschema.go b/sources/sqlserver/infoschema.go index 50f46d157..b9c98c544 100644 --- a/sources/sqlserver/infoschema.go +++ b/sources/sqlserver/infoschema.go @@ -280,7 +280,7 @@ func (isi InfoSchemaImpl) GetColumns(conv *internal.Conv, table common.SchemaAnd // other constraints. 
Note: we need to preserve ordinal order of // columns in primary key constraints. // Note that foreign key constraints are handled in getForeignKeys. -func (isi InfoSchemaImpl) GetConstraints(conv *internal.Conv, table common.SchemaAndName) ([]string, map[string][]string, error) { +func (isi InfoSchemaImpl) GetConstraints(conv *internal.Conv, table common.SchemaAndName) ([]string, []schema.CheckConstraint, map[string][]string, error) { q := ` SELECT k.COLUMN_NAME, @@ -292,7 +292,7 @@ func (isi InfoSchemaImpl) GetConstraints(conv *internal.Conv, table common.Schem ` rows, err := isi.Db.Query(q, table.Schema, table.Name) if err != nil { - return nil, nil, err + return nil, nil, nil, err } defer rows.Close() var primaryKeys []string @@ -315,7 +315,7 @@ func (isi InfoSchemaImpl) GetConstraints(conv *internal.Conv, table common.Schem m[col] = append(m[col], constraint) } } - return primaryKeys, m, nil + return primaryKeys, nil, m, nil } // GetForeignKeys returns a list of all the foreign key constraints. diff --git a/spanner/ddl/ast.go b/spanner/ddl/ast.go index 91646281a..f6534b262 100644 --- a/spanner/ddl/ast.go +++ b/spanner/ddl/ast.go @@ -196,7 +196,7 @@ type Config struct { Tables bool // If true, print tables ForeignKeys bool // If true, print foreign key constraints. SpDialect string - Source string //SourceDB information for determining case-sensitivity handling for PGSQL + Source string // SourceDB information for determining case-sensitivity handling for PGSQL } func isIdentifierReservedInPG(identifier string) bool { @@ -265,6 +265,12 @@ type IndexKey struct { Order int } +type CheckConstraint struct { + Id string + Name string + Expr string +} + // PrintPkOrIndexKey unparses the primary or index keys. func (idx IndexKey) PrintPkOrIndexKey(ct CreateTable, c Config) string { col := c.quote(ct.ColDefs[idx.ColId].Name) @@ -319,16 +325,17 @@ func (k Foreignkey) PrintForeignKey(c Config) string { // // create_table: CREATE TABLE table_name ([column_def, ...] ) primary_key [, cluster] type CreateTable struct { - Name string - ColIds []string // Provides names and order of columns - ShardIdColumn string - ColDefs map[string]ColumnDef // Provides definition of columns (a map for simpler/faster lookup during type processing) - PrimaryKeys []IndexKey - ForeignKeys []Foreignkey - Indexes []CreateIndex - ParentTable InterleavedParent //if not empty, this table will be interleaved - Comment string - Id string + Name string + ColIds []string // Provides names and order of columns + ShardIdColumn string + ColDefs map[string]ColumnDef // Provides definition of columns (a map for simpler/faster lookup during type processing) + PrimaryKeys []IndexKey + ForeignKeys []Foreignkey + Indexes []CreateIndex + ParentTable InterleavedParent // if not empty, this table will be interleaved + CheckConstraints []CheckConstraint + Comment string + Id string } // PrintCreateTable unparses a CREATE TABLE statement. 
@@ -382,13 +389,20 @@ func (ct CreateTable) PrintCreateTable(spSchema Schema, config Config) string { } } + var checkString string + if len(ct.CheckConstraints) > 0 { + checkString = FormatCheckConstraints(ct.CheckConstraints) + } else { + checkString = "" + } + if len(keys) == 0 { - return fmt.Sprintf("%sCREATE TABLE %s (\n%s) %s", tableComment, config.quote(ct.Name), cols, interleave) + return fmt.Sprintf("%sCREATE TABLE %s (\n%s%s) %s", tableComment, config.quote(ct.Name), cols, checkString, interleave) } if config.SpDialect == constants.DIALECT_POSTGRESQL { return fmt.Sprintf("%sCREATE TABLE %s (\n%s\tPRIMARY KEY (%s)\n)%s", tableComment, config.quote(ct.Name), cols, strings.Join(keys, ", "), interleave) } - return fmt.Sprintf("%sCREATE TABLE %s (\n%s) PRIMARY KEY (%s)%s", tableComment, config.quote(ct.Name), cols, strings.Join(keys, ", "), interleave) + return fmt.Sprintf("%sCREATE TABLE %s (\n%s%s) PRIMARY KEY (%s)%s", tableComment, config.quote(ct.Name), cols, checkString, strings.Join(keys, ", "), interleave) } // CreateIndex encodes the following DDL definition: @@ -534,6 +548,27 @@ func (k Foreignkey) PrintForeignKeyAlterTable(spannerSchema Schema, c Config, ta return s } +// FormatCheckConstraints formats the check constraints in SQL syntax. +func FormatCheckConstraints(cks []CheckConstraint) string { + var builder strings.Builder + + for _, col := range cks { + if col.Name != "" { + builder.WriteString(fmt.Sprintf("\tCONSTRAINT %s CHECK %s,\n", col.Name, col.Expr)) + } else { + builder.WriteString(fmt.Sprintf("\tCHECK %s,\n", col.Expr)) + } + } + + if builder.Len() > 0 { + // Trim the trailing comma and newline + result := builder.String() + return result[:len(result)-2] + "\n" + } + + return "" +} + // Schema stores a map of table names and Tables. type Schema map[string]CreateTable @@ -548,7 +583,6 @@ func NewSchema() Schema { // TODO: Move this method to mapping.go and preserve the table names in sorted // order in conv so that we don't need to order the table names multiple times. 
func GetSortedTableIdsBySpName(s Schema) []string { - var tableNames, sortedTableNames, sortedTableIds []string tableNameIdMap := map[string]string{} for _, t := range s { diff --git a/spanner/ddl/ast_test.go b/spanner/ddl/ast_test.go index b9e6a510e..c84dfa800 100644 --- a/spanner/ddl/ast_test.go +++ b/spanner/ddl/ast_test.go @@ -143,6 +143,10 @@ func TestPrintCreateTable(t *testing.T) { }, PrimaryKeys: []IndexKey{{ColId: "col1", Desc: true}}, ForeignKeys: nil, + CheckConstraints: []CheckConstraint{ + {Id: "ck1", Name: "check_1", Expr: "(age > 18)"}, + {Id: "ck2", Name: "check_2", Expr: "(age < 99)"}, + }, Indexes: nil, ParentTable: InterleavedParent{}, Comment: "", @@ -156,12 +160,13 @@ func TestPrintCreateTable(t *testing.T) { "col4": {Name: "col4", T: Type{Name: Int64}, NotNull: true}, "col5": {Name: "col5", T: Type{Name: String, Len: MaxLength}, NotNull: false}, }, - PrimaryKeys: []IndexKey{{ColId: "col4", Desc: true}}, - ForeignKeys: nil, - Indexes: nil, - ParentTable: InterleavedParent{Id: "t1", OnDelete: constants.FK_CASCADE}, - Comment: "", - Id: "t2", + PrimaryKeys: []IndexKey{{ColId: "col4", Desc: true}}, + ForeignKeys: nil, + Indexes: nil, + CheckConstraints: nil, + ParentTable: InterleavedParent{Id: "t1", OnDelete: constants.FK_CASCADE}, + Comment: "", + Id: "t2", }, "t3": CreateTable{ Name: "table3", @@ -170,12 +175,33 @@ func TestPrintCreateTable(t *testing.T) { ColDefs: map[string]ColumnDef{ "col6": {Name: "col6", T: Type{Name: Int64}, NotNull: true}, }, - PrimaryKeys: []IndexKey{{ColId: "col6", Desc: true}}, + PrimaryKeys: []IndexKey{{ColId: "col6", Desc: true}}, + ForeignKeys: nil, + Indexes: nil, + CheckConstraints: nil, + ParentTable: InterleavedParent{Id: "t1", OnDelete: ""}, + Comment: "", + Id: "t3", + }, + "t4": CreateTable{ + Name: "table1", + ColIds: []string{"col1", "col2", "col3"}, + ShardIdColumn: "", + ColDefs: map[string]ColumnDef{ + "col1": {Name: "col1", T: Type{Name: Int64}, NotNull: true}, + "col2": {Name: "col2", T: Type{Name: String, Len: MaxLength}, NotNull: false}, + "col3": {Name: "col3", T: Type{Name: Bytes, Len: int64(42)}, NotNull: false}, + }, + PrimaryKeys: nil, ForeignKeys: nil, + CheckConstraints: []CheckConstraint{ + {Id: "ck1", Name: "check_1", Expr: "(age > 18)"}, + {Id: "ck2", Name: "check_2", Expr: "(age < 99)"}, + }, Indexes: nil, - ParentTable: InterleavedParent{Id: "t1", OnDelete: ""}, + ParentTable: InterleavedParent{}, Comment: "", - Id: "t3", + Id: "t1", }, } tests := []struct { @@ -192,6 +218,7 @@ func TestPrintCreateTable(t *testing.T) { " col1 INT64 NOT NULL ,\n" + " col2 STRING(MAX),\n" + " col3 BYTES(42),\n" + + "\tCONSTRAINT check_1 CHECK (age > 18),\n\tCONSTRAINT check_2 CHECK (age < 99)\n" + ") PRIMARY KEY (col1 DESC)", }, { @@ -202,6 +229,7 @@ func TestPrintCreateTable(t *testing.T) { " `col1` INT64 NOT NULL ,\n" + " `col2` STRING(MAX),\n" + " `col3` BYTES(42),\n" + + "\tCONSTRAINT check_1 CHECK (age > 18),\n\tCONSTRAINT check_2 CHECK (age < 99)\n" + ") PRIMARY KEY (`col1` DESC)", }, { @@ -223,6 +251,17 @@ func TestPrintCreateTable(t *testing.T) { ") PRIMARY KEY (col6 DESC),\n" + "INTERLEAVE IN PARENT table1", }, + { + "no quote", + false, + s["t4"], + "CREATE TABLE table1 (\n" + + " col1 INT64 NOT NULL ,\n" + + " col2 STRING(MAX),\n" + + " col3 BYTES(42),\n" + + "\tCONSTRAINT check_1 CHECK (age > 18),\n\tCONSTRAINT check_2 CHECK (age < 99)\n" + + ") ", + }, } for _, tc := range tests { assert.Equal(t, tc.expected, tc.ct.PrintCreateTable(s, Config{ProtectIds: tc.protectIds})) @@ -356,7 +395,8 @@ func TestPrintCreateIndex(t 
*testing.T) { []IndexKey{{ColId: "c1", Desc: true}, {ColId: "c2"}}, "i2", nil, - }} + }, + } tests := []struct { name string protectIds bool @@ -427,13 +467,13 @@ func TestPrintForeignKey(t *testing.T) { func TestPrintForeignKeyAlterTable(t *testing.T) { spannerSchema := map[string]CreateTable{ - "t1": CreateTable{ + "t1": { Name: "table1", ColIds: []string{"c1", "c2", "c3"}, ColDefs: map[string]ColumnDef{ - "c1": ColumnDef{Name: "productid", T: Type{Name: String, Len: MaxLength}}, - "c2": ColumnDef{Name: "userid", T: Type{Name: String, Len: MaxLength}}, - "c3": ColumnDef{Name: "quantity", T: Type{Name: Int64}}, + "c1": {Name: "productid", T: Type{Name: String, Len: MaxLength}}, + "c2": {Name: "userid", T: Type{Name: String, Len: MaxLength}}, + "c3": {Name: "quantity", T: Type{Name: Int64}}, }, ForeignKeys: []Foreignkey{ { @@ -466,14 +506,15 @@ func TestPrintForeignKeyAlterTable(t *testing.T) { }, }, - "t2": CreateTable{ + "t2": { Name: "table2", ColIds: []string{"c4", "c5"}, ColDefs: map[string]ColumnDef{ - "c4": ColumnDef{Name: "productid", T: Type{Name: String, Len: MaxLength}}, - "c5": ColumnDef{Name: "userid", T: Type{Name: String, Len: MaxLength}}, + "c4": {Name: "productid", T: Type{Name: String, Len: MaxLength}}, + "c5": {Name: "userid", T: Type{Name: String, Len: MaxLength}}, }, - }} + }, + } tests := []struct { name string @@ -843,7 +884,8 @@ func TestGetDDL(t *testing.T) { StartWithCounter: "7", } e4 := []string{ - "CREATE SEQUENCE sequence1 OPTIONS (sequence_kind='bit_reversed_positive', skip_range_min = 0, skip_range_max = 5, start_with_counter = 7) "} + "CREATE SEQUENCE sequence1 OPTIONS (sequence_kind='bit_reversed_positive', skip_range_min = 0, skip_range_max = 5, start_with_counter = 7) ", + } sequencesOnly := GetDDL(Config{}, Schema{}, sequences) assert.ElementsMatch(t, e4, sequencesOnly) } @@ -955,7 +997,8 @@ func TestGetPGDDL(t *testing.T) { StartWithCounter: "7", } e4 := []string{ - "CREATE SEQUENCE sequence1 BIT_REVERSED_POSITIVE SKIP RANGE 0 5 START COUNTER WITH 7"} + "CREATE SEQUENCE sequence1 BIT_REVERSED_POSITIVE SKIP RANGE 0 5 START COUNTER WITH 7", + } sequencesOnly := GetDDL(Config{SpDialect: constants.DIALECT_POSTGRESQL}, Schema{}, sequences) assert.ElementsMatch(t, e4, sequencesOnly) } @@ -1040,3 +1083,46 @@ func TestGetSortedTableIdsBySpName(t *testing.T) { }) } } + +func TestFormatCheckConstraints(t *testing.T) { + tests := []struct { + description string + cks []CheckConstraint + expected string + }{ + { + description: "Empty constraints list", + cks: []CheckConstraint{}, + expected: "", + }, + { + description: "Single constraint", + cks: []CheckConstraint{ + {Name: "ck1", Expr: "(id > 0)"}, + }, + expected: "\tCONSTRAINT ck1 CHECK (id > 0)\n", + }, + { + description: "Constraint without name", + cks: []CheckConstraint{ + {Name: "", Expr: "(id > 0)"}, + }, + expected: "\tCHECK (id > 0)\n", + }, + { + description: "Multiple constraints", + cks: []CheckConstraint{ + {Name: "ck1", Expr: "(id > 0)"}, + {Name: "ck2", Expr: "(name IS NOT NULL)"}, + }, + expected: "\tCONSTRAINT ck1 CHECK (id > 0),\n\tCONSTRAINT ck2 CHECK (name IS NOT NULL)\n", + }, + } + + for _, tc := range tests { + t.Run(tc.description, func(t *testing.T) { + actual := FormatCheckConstraints(tc.cks) + assert.Equal(t, tc.expected, actual) + }) + } +} diff --git a/webv2/api/schema.go b/webv2/api/schema.go index 18cd3d4ff..c0f5be2ed 100644 --- a/webv2/api/schema.go +++ b/webv2/api/schema.go @@ -32,15 +32,19 @@ import ( "github.com/GoogleCloudPlatform/spanner-migration-tool/webv2/utilities" ) -var 
mysqlDefaultTypeMap = make(map[string]ddl.Type) -var postgresDefaultTypeMap = make(map[string]ddl.Type) -var sqlserverDefaultTypeMap = make(map[string]ddl.Type) -var oracleDefaultTypeMap = make(map[string]ddl.Type) +var ( + mysqlDefaultTypeMap = make(map[string]ddl.Type) + postgresDefaultTypeMap = make(map[string]ddl.Type) + sqlserverDefaultTypeMap = make(map[string]ddl.Type) + oracleDefaultTypeMap = make(map[string]ddl.Type) +) -var mysqlTypeMap = make(map[string][]types.TypeIssue) -var postgresTypeMap = make(map[string][]types.TypeIssue) -var sqlserverTypeMap = make(map[string][]types.TypeIssue) -var oracleTypeMap = make(map[string][]types.TypeIssue) +var ( + mysqlTypeMap = make(map[string][]types.TypeIssue) + postgresTypeMap = make(map[string][]types.TypeIssue) + sqlserverTypeMap = make(map[string][]types.TypeIssue) + oracleTypeMap = make(map[string][]types.TypeIssue) +) var autoGenMap = make(map[string][]types.AutoGen) @@ -330,7 +334,6 @@ func GetTypeMap(w http.ResponseWriter, r *http.Request) { } else { filteredTypeMap[key][i].DisplayT = filteredTypeMap[key][i].T } - } } w.WriteHeader(http.StatusOK) @@ -338,7 +341,6 @@ func GetTypeMap(w http.ResponseWriter, r *http.Request) { } func GetAutoGenMap(w http.ResponseWriter, r *http.Request) { - sessionState := session.GetSessionState() if sessionState.Conv == nil || sessionState.Driver == "" { http.Error(w, fmt.Sprintf("Schema is not converted or Driver is not configured properly. Please retry converting the database to Spanner."), http.StatusNotFound) @@ -485,7 +487,91 @@ func RestoreSecondaryIndex(w http.ResponseWriter, r *http.Request) { } w.WriteHeader(http.StatusOK) json.NewEncoder(w).Encode(convm) +} + +// UpdateCheckConstraint processes the request to update spanner table check constraints, ensuring session and schema validity, and responds with the updated conversion metadata. +func UpdateCheckConstraint(w http.ResponseWriter, r *http.Request) { + tableId := r.FormValue("table") + reqBody, err := ioutil.ReadAll(r.Body) + if err != nil { + http.Error(w, fmt.Sprintf("Body Read Error : %v", err), http.StatusInternalServerError) + } + sessionState := session.GetSessionState() + if sessionState.Conv == nil || sessionState.Driver == "" { + http.Error(w, fmt.Sprintf("Schema is not converted or Driver is not configured properly. Please retry converting the database to Spanner."), http.StatusNotFound) + return + } + sessionState.Conv.ConvLock.Lock() + defer sessionState.Conv.ConvLock.Unlock() + + newCc := []ddl.CheckConstraint{} + if err = json.Unmarshal(reqBody, &newCc); err != nil { + http.Error(w, fmt.Sprintf("Request Body parse error : %v", err), http.StatusBadRequest) + return + } + + sp := sessionState.Conv.SpSchema[tableId] + sp.CheckConstraints = newCc + sessionState.Conv.SpSchema[tableId] = sp + session.UpdateSessionFile() + + convm := session.ConvWithMetadata{ + SessionMetadata: sessionState.SessionMetadata, + Conv: *sessionState.Conv, + } + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(convm) +} +func doesNameExist(spcks []ddl.CheckConstraint, targetName string) bool { + for _, spck := range spcks { + if strings.Contains(spck.Expr, targetName) { + return true + } + } + return false +} + +// ValidateCheckConstraint verifies if the type of a database column has been altered and add an error if a change is detected. 
+func ValidateCheckConstraint(w http.ResponseWriter, r *http.Request) { + sessionState := session.GetSessionState() + if sessionState.Conv == nil || sessionState.Driver == "" { + http.Error(w, fmt.Sprintf("Schema is not converted or Driver is not configured properly. Please retry converting the database to Spanner."), http.StatusNotFound) + return + } + sessionState.Conv.ConvLock.Lock() + defer sessionState.Conv.ConvLock.Unlock() + + sp := sessionState.Conv.SpSchema + srcschema := sessionState.Conv.SrcSchema + flag := true + var schemaIssue []internal.SchemaIssue + + for _, src := range srcschema { + for _, col := range sp[src.Id].ColDefs { + if len(sp[src.Id].CheckConstraints) > 0 { + spType := col.T.Name + srcType := srcschema[src.Id].ColDefs[col.Id].Type + actualType := mysqlDefaultTypeMap[srcType.Name] + if actualType.Name != spType { + columnName := sp[src.Id].ColDefs[col.Id].Name + spcks := sp[src.Id].CheckConstraints + if doesNameExist(spcks, columnName) { + flag = false + schemaIssue = sessionState.Conv.SchemaIssues[src.Id].ColumnLevelIssues[col.Id] + if !utilities.IsSchemaIssuePresent(schemaIssue, internal.TypeMismatch) { + schemaIssue = append(schemaIssue, internal.TypeMismatch) + } + sessionState.Conv.SchemaIssues[src.Id].ColumnLevelIssues[col.Id] = schemaIssue + break + } + } + } + } + } + + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(flag) } // renameForeignKeys checks the new names for spanner name validity, ensures the new names are already not used by existing tables @@ -717,7 +803,8 @@ func SetParentTable(w http.ResponseWriter, r *http.Request) { } json.NewEncoder(w).Encode(map[string]interface{}{ "tableInterleaveStatus": tableInterleaveStatus, - "sessionState": convm}) + "sessionState": convm, + }) } else { json.NewEncoder(w).Encode(map[string]interface{}{ "tableInterleaveStatus": tableInterleaveStatus, @@ -801,7 +888,6 @@ func RemoveParentTable(w http.ResponseWriter, r *http.Request) { } w.WriteHeader(http.StatusOK) json.NewEncoder(w).Encode(convm) - } func UpdateIndexes(w http.ResponseWriter, r *http.Request) { @@ -837,7 +923,6 @@ func UpdateIndexes(w http.ResponseWriter, r *http.Request) { st := sessionState.Conv.SrcSchema[table] for i, ind := range sp.Indexes { - if ind.TableId == newIndexes[0].TableId && ind.Id == newIndexes[0].Id { index.RemoveIndexIssues(table, sp.Indexes[i]) @@ -1090,11 +1175,8 @@ func checkPrimaryKeyOrder(tableId string, refTableId string, fk ddl.Foreignkey) childTable := sessionState.Conv.SpSchema[tableId] parentTable := sessionState.Conv.SpSchema[refTableId] for i := 0; i < len(parentPks); i++ { - for j := 0; j < len(childPks); j++ { - for k := 0; k < len(fk.ReferColumnIds); k++ { - if childTable.ColDefs[fk.ColIds[k]].Name == parentTable.ColDefs[fk.ReferColumnIds[k]].Name && parentTable.ColDefs[parentPks[i].ColId].Name == childTable.ColDefs[childPks[j].ColId].Name && parentTable.ColDefs[parentPks[i].ColId].T.Name == childTable.ColDefs[childPks[j].ColId].T.Name && @@ -1106,15 +1188,12 @@ func checkPrimaryKeyOrder(tableId string, refTableId string, fk ddl.Foreignkey) } } } - } - } return "" } func checkPrimaryKeyPrefix(tableId string, refTableId string, fk ddl.Foreignkey, tableInterleaveStatus *types.TableInterleaveStatus) bool { - sessionState := session.GetSessionState() childTable := sessionState.Conv.SpSchema[tableId] parentTable := sessionState.Conv.SpSchema[refTableId] @@ -1151,11 +1230,8 @@ func checkPrimaryKeyPrefix(tableId string, refTableId string, fk ddl.Foreignkey, interleaved := []ddl.IndexKey{} for i := 0; i < 
len(parentPks); i++ { - for j := 0; j < len(childPks); j++ { - for k := 0; k < len(fk.ReferColumnIds); k++ { - if childTable.ColDefs[fk.ColIds[k]].Name == parentTable.ColDefs[fk.ReferColumnIds[k]].Name && parentTable.ColDefs[parentPks[i].ColId].Name == childTable.ColDefs[childPks[j].ColId].Name && parentTable.ColDefs[parentPks[i].ColId].T.Name == childTable.ColDefs[childPks[j].ColId].T.Name && @@ -1166,9 +1242,7 @@ func checkPrimaryKeyPrefix(tableId string, refTableId string, fk ddl.Foreignkey, interleaved = append(interleaved, parentPks[i]) } } - } - } if len(interleaved) == len(parentPks) { @@ -1178,18 +1252,13 @@ func checkPrimaryKeyPrefix(tableId string, refTableId string, fk ddl.Foreignkey, diff := []ddl.IndexKey{} if len(interleaved) == 0 { - for i := 0; i < len(parentPks); i++ { - for j := 0; j < len(childPks); j++ { - if parentTable.ColDefs[parentPks[i].ColId].Name != childTable.ColDefs[childPks[j].ColId].Name || parentTable.ColDefs[parentPks[i].ColId].T.Len != childTable.ColDefs[childPks[j].ColId].T.Len { diff = append(diff, parentPks[i]) } - } } - } canInterleavedOnAdd := []string{} @@ -1220,7 +1289,6 @@ func checkPrimaryKeyPrefix(tableId string, refTableId string, fk ddl.Foreignkey, } else { canInterleavedOnRename = append(canInterleavedOnRename, fk.ColIds[parentColIndex]) } - } } @@ -1326,7 +1394,7 @@ func dropTableHelper(w http.ResponseWriter, tableId string) session.ConvWithMeta issues := sessionState.Conv.SchemaIssues syntheticPkey := sessionState.Conv.SyntheticPKeys - //remove deleted name from usedName + // remove deleted name from usedName usedNames := sessionState.Conv.UsedNames delete(usedNames, strings.ToLower(sessionState.Conv.SpSchema[tableId].Name)) for _, index := range sessionState.Conv.SpSchema[tableId].Indexes { @@ -1343,7 +1411,7 @@ func dropTableHelper(w http.ResponseWriter, tableId string) session.ConvWithMeta } delete(syntheticPkey, tableId) - //drop reference foreign key + // drop reference foreign key for tableName, spTable := range spSchema { fks := []ddl.Foreignkey{} for _, fk := range spTable.ForeignKeys { @@ -1352,13 +1420,12 @@ func dropTableHelper(w http.ResponseWriter, tableId string) session.ConvWithMeta } else { delete(usedNames, fk.Name) } - } spTable.ForeignKeys = fks spSchema[tableName] = spTable } - //remove interleave that are interleaved on the drop table as parent + // remove interleave that are interleaved on the drop table as parent for id, spTable := range spSchema { if spTable.ParentTable.Id == tableId { spTable.ParentTable.Id = "" @@ -1367,7 +1434,7 @@ func dropTableHelper(w http.ResponseWriter, tableId string) session.ConvWithMeta } } - //remove interleavable suggestion on droping the parent table + // remove interleavable suggestion on droping the parent table for tableName, tableIssues := range issues { for colName, colIssues := range tableIssues.ColumnLevelIssues { updatedColIssues := []internal.SchemaIssue{} diff --git a/webv2/api/schema_test.go b/webv2/api/schema_test.go index 9e7894ea8..ef487d426 100644 --- a/webv2/api/schema_test.go +++ b/webv2/api/schema_test.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/json" "fmt" + "io" "net/http" "net/http/httptest" "strings" @@ -2541,3 +2542,83 @@ func TestGetAutoGenMapMySQL(t *testing.T) { } } + +func TestUpdateCheckConstraint(t *testing.T) { + t.Run("ValidCheckConstraints", func(t *testing.T) { + sessionState := session.GetSessionState() + sessionState.Driver = constants.MYSQL + sessionState.Conv = internal.MakeConv() + + tableID := "table1" + + expectedCheckConstraint := 
[]ddl.CheckConstraint{ + {Id: "cc1", Name: "check_1", Expr: "(age > 18)"}, + {Id: "cc2", Name: "check_2", Expr: "(age < 99)"}, + } + + checkConstraints := []schema.CheckConstraint{ + {Id: "cc1", Name: "check_1", Expr: "(age > 18)"}, + {Id: "cc2", Name: "check_2", Expr: "(age < 99)"}, + } + + body, err := json.Marshal(checkConstraints) + assert.NoError(t, err) + + req, err := http.NewRequest("POST", "update/cc", bytes.NewBuffer(body)) + assert.NoError(t, err) + + q := req.URL.Query() + q.Add("table", tableID) + req.URL.RawQuery = q.Encode() + + rr := httptest.NewRecorder() + handler := http.HandlerFunc(api.UpdateCheckConstraint) + handler.ServeHTTP(rr, req) + + assert.Equal(t, http.StatusOK, rr.Code) + + updatedSp := sessionState.Conv.SpSchema[tableID] + assert.Equal(t, expectedCheckConstraint, updatedSp.CheckConstraints) + }) + + t.Run("ParseError", func(t *testing.T) { + sessionState := session.GetSessionState() + sessionState.Driver = constants.MYSQL + sessionState.Conv = internal.MakeConv() + + invalidJSON := "invalid json body" + + rr := httptest.NewRecorder() + req, err := http.NewRequest("POST", "update/cc", io.NopCloser(strings.NewReader(invalidJSON))) + assert.NoError(t, err) + + handler := http.HandlerFunc(api.UpdateCheckConstraint) + handler.ServeHTTP(rr, req) + + assert.Equal(t, http.StatusBadRequest, rr.Code) + + expectedErrorMessage := "Request Body parse error" + assert.Contains(t, rr.Body.String(), expectedErrorMessage) + }) + + t.Run("ImproperSession", func(t *testing.T) { + sessionState := session.GetSessionState() + sessionState.Conv = nil // Simulate no conversion + + rr := httptest.NewRecorder() + req, err := http.NewRequest("POST", "update/cc", io.NopCloser(errReader{})) + assert.NoError(t, err) + + handler := http.HandlerFunc(api.UpdateCheckConstraint) + handler.ServeHTTP(rr, req) + + assert.Equal(t, http.StatusInternalServerError, rr.Code) + assert.Contains(t, rr.Body.String(), "Schema is not converted or Driver is not configured properly") + }) +} + +type errReader struct{} + +func (errReader) Read(p []byte) (n int, err error) { + return 0, fmt.Errorf("simulated read error") +} diff --git a/webv2/routes.go b/webv2/routes.go index 191c1e1a5..fdb7913da 100644 --- a/webv2/routes.go +++ b/webv2/routes.go @@ -75,7 +75,6 @@ func getRoutes() *mux.Router { router.HandleFunc("/spannerDefaultTypeMap", api.SpannerDefaultTypeMap).Methods("GET") router.HandleFunc("/autoGenMap", api.GetAutoGenMap).Methods("GET") router.HandleFunc("/getSequenceKind", api.GetSequenceKind).Methods("GET") - router.HandleFunc("/setparent", api.SetParentTable).Methods("GET") router.HandleFunc("/removeParent", api.RemoveParentTable).Methods("POST") @@ -92,6 +91,7 @@ func getRoutes() *mux.Router { router.HandleFunc("/UpdateSequence", api.UpdateSequence).Methods("POST") router.HandleFunc("/update/fks", api.UpdateForeignKeys).Methods("POST") + router.HandleFunc("/update/cc", api.UpdateCheckConstraint).Methods("POST") router.HandleFunc("/update/indexes", api.UpdateIndexes).Methods("POST") // Session Management diff --git a/webv2/table/review_table_schema.go b/webv2/table/review_table_schema.go index 0ad6fb0ec..9ec551617 100644 --- a/webv2/table/review_table_schema.go +++ b/webv2/table/review_table_schema.go @@ -19,6 +19,7 @@ import ( "fmt" "io/ioutil" "net/http" + "regexp" "strconv" "strings" @@ -50,7 +51,6 @@ type InterleaveColumn struct { // ReviewTableSchema review Spanner Table Schema. 
func ReviewTableSchema(w http.ResponseWriter, r *http.Request) { - reqBody, err := ioutil.ReadAll(r.Body) if err != nil { http.Error(w, fmt.Sprintf("Body Read Error : %v", err), http.StatusInternalServerError) @@ -89,15 +89,11 @@ func ReviewTableSchema(w http.ResponseWriter, r *http.Request) { for colId, v := range t.UpdateCols { if v.Add { - addColumn(tableId, colId, conv) - } if v.Removed { - RemoveColumn(tableId, colId, conv) - } if v.Rename != "" && v.Rename != conv.SpSchema[tableId].ColDefs[colId].Name { @@ -108,6 +104,15 @@ func ReviewTableSchema(w http.ResponseWriter, r *http.Request) { return } } + oldName := conv.SpSchema[tableId].ColDefs[colId].Name + // Using a regular expression to match the exact column name + re := regexp.MustCompile(`\b` + regexp.QuoteMeta(oldName) + `\b`) + + for i := range conv.SpSchema[tableId].CheckConstraints { + originalString := conv.SpSchema[tableId].CheckConstraints[i].Expr + updatedValue := re.ReplaceAllString(originalString, v.Rename) + conv.SpSchema[tableId].CheckConstraints[i].Expr = updatedValue + } interleaveTableSchema = reviewRenameColumn(v.Rename, tableId, colId, conv, interleaveTableSchema) @@ -117,7 +122,6 @@ func ReviewTableSchema(w http.ResponseWriter, r *http.Request) { if v.ToType != "" && found { typeChange, err := utilities.IsTypeChanged(v.ToType, tableId, colId, conv) - if err != nil { http.Error(w, err.Error(), http.StatusBadRequest) return @@ -148,7 +152,7 @@ func ReviewTableSchema(w http.ResponseWriter, r *http.Request) { } } - if !v.Removed && !v.Add && v.Rename== ""{ + if !v.Removed && !v.Add && v.Rename == "" { sequences := UpdateAutoGenCol(v.AutoGen, tableId, colId, conv) conv.SpSequences = sequences } diff --git a/webv2/table/review_table_schema_test.go b/webv2/table/review_table_schema_test.go index f448a2da5..34bfcf339 100644 --- a/webv2/table/review_table_schema_test.go +++ b/webv2/table/review_table_schema_test.go @@ -31,7 +31,6 @@ import ( ) func TestReviewTableSchema(t *testing.T) { - tc := []struct { name string tableId string @@ -61,7 +60,8 @@ func TestReviewTableSchema(t *testing.T) { "c2": {Name: "b", Id: "c2", T: ddl.Type{Name: ddl.String, Len: 6}}, }, PrimaryKeys: []ddl.IndexKey{{ColId: "c1"}}, - }}, + }, + }, SrcSchema: map[string]schema.Table{ "t1": { Name: "t1", @@ -71,7 +71,8 @@ func TestReviewTableSchema(t *testing.T) { "c2": {Name: "b", Id: "c2", Type: schema.Type{Name: "varchar", Mods: []int64{6}}}, }, PrimaryKeys: []schema.Key{{ColId: "c1"}}, - }}, + }, + }, SchemaIssues: map[string]internal.TableIssues{ "t1": { ColumnLevelIssues: make(map[string][]internal.SchemaIssue), @@ -91,7 +92,8 @@ func TestReviewTableSchema(t *testing.T) { "c2": {Name: "b", Id: "c2", T: ddl.Type{Name: ddl.Bytes, Len: 6}}, }, PrimaryKeys: []ddl.IndexKey{{ColId: "c1"}}, - }}, + }, + }, SrcSchema: map[string]schema.Table{ "t1": { Name: "t1", @@ -101,7 +103,8 @@ func TestReviewTableSchema(t *testing.T) { "c2": {Name: "b", Id: "c2", Type: schema.Type{Name: "varchar", Mods: []int64{6}}}, }, PrimaryKeys: []schema.Key{{ColId: "c1"}}, - }}, + }, + }, SchemaIssues: map[string]internal.TableIssues{ "t1": { ColumnLevelIssues: map[string][]internal.SchemaIssue{ @@ -145,7 +148,8 @@ func TestReviewTableSchema(t *testing.T) { ColDefs: map[string]ddl.ColumnDef{ "c3": {Name: "c", Id: "c3", T: ddl.Type{Name: ddl.Int64}}, }, - }}, + }, + }, SrcSchema: map[string]schema.Table{ "t1": { Name: "table1", @@ -188,7 +192,8 @@ func TestReviewTableSchema(t *testing.T) { "c2": {Name: "b", Id: "c2", T: ddl.Type{Name: ddl.Bytes, Len: 6}}, }, PrimaryKeys: 
[]ddl.IndexKey{{ColId: "c1"}}, - }}, + }, + }, SrcSchema: map[string]schema.Table{ "t1": { Name: "table1", @@ -198,7 +203,8 @@ func TestReviewTableSchema(t *testing.T) { "c2": {Name: "b", Id: "c2", Type: schema.Type{Name: "varchar", Mods: []int64{6}}}, }, PrimaryKeys: []schema.Key{{ColId: "c1"}}, - }}, + }, + }, SchemaIssues: map[string]internal.TableIssues{ "t1": { ColumnLevelIssues: map[string][]internal.SchemaIssue{ @@ -241,7 +247,8 @@ func TestReviewTableSchema(t *testing.T) { ColDefs: map[string]ddl.ColumnDef{ "c3": {Name: "c", Id: "c3", T: ddl.Type{Name: ddl.Int64}}, }, - }}, + }, + }, SrcSchema: map[string]schema.Table{ "t1": { Name: "table1", @@ -292,7 +299,8 @@ func TestReviewTableSchema(t *testing.T) { ColDefs: map[string]ddl.ColumnDef{ "c3": {Name: "c", Id: "c3", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}}, }, - }}, + }, + }, SrcSchema: map[string]schema.Table{ "t1": { Name: "table1", @@ -310,7 +318,8 @@ func TestReviewTableSchema(t *testing.T) { ColDefs: map[string]schema.Column{ "c3": {Name: "c", Id: "c3", Type: schema.Type{Name: "bigint", Mods: []int64{}}}, }, - }}, + }, + }, SchemaIssues: map[string]internal.TableIssues{ "t1": { ColumnLevelIssues: map[string][]internal.SchemaIssue{ @@ -353,7 +362,8 @@ func TestReviewTableSchema(t *testing.T) { ColDefs: map[string]ddl.ColumnDef{ "c3": {Name: "a", Id: "c3", T: ddl.Type{Name: ddl.Int64}}, }, - }}, + }, + }, SrcSchema: map[string]schema.Table{ "t1": { Name: "table1", @@ -405,7 +415,8 @@ func TestReviewTableSchema(t *testing.T) { ColDefs: map[string]ddl.ColumnDef{ "c3": {Name: "a", Id: "c3", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}}, }, - }}, + }, + }, SrcSchema: map[string]schema.Table{ "t1": { Name: "table1", @@ -424,7 +435,8 @@ func TestReviewTableSchema(t *testing.T) { ColDefs: map[string]schema.Column{ "c3": {Name: "c", Id: "c3", Type: schema.Type{Name: "bigint", Mods: []int64{}}}, }, - }}, + }, + }, SchemaIssues: map[string]internal.TableIssues{ "t1": { ColumnLevelIssues: map[string][]internal.SchemaIssue{ @@ -467,7 +479,8 @@ func TestReviewTableSchema(t *testing.T) { ColDefs: map[string]ddl.ColumnDef{ "c3": {Name: "a", Id: "c3", T: ddl.Type{Name: ddl.Int64}}, }, - }}, + }, + }, SrcSchema: map[string]schema.Table{ "t1": { Name: "table1", @@ -519,7 +532,8 @@ func TestReviewTableSchema(t *testing.T) { ColDefs: map[string]ddl.ColumnDef{ "c3": {Name: "a", Id: "c3", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}}, }, - }}, + }, + }, SrcSchema: map[string]schema.Table{ "t1": { Name: "table1", @@ -538,7 +552,8 @@ func TestReviewTableSchema(t *testing.T) { ColDefs: map[string]schema.Column{ "c3": {Name: "c", Id: "c3", Type: schema.Type{Name: "bigint", Mods: []int64{}}}, }, - }}, + }, + }, SchemaIssues: map[string]internal.TableIssues{ "t1": { ColumnLevelIssues: map[string][]internal.SchemaIssue{ @@ -581,7 +596,8 @@ func TestReviewTableSchema(t *testing.T) { ColDefs: map[string]ddl.ColumnDef{ "c3": {Name: "a", Id: "c3", T: ddl.Type{Name: ddl.Int64}}, }, - }}, + }, + }, SrcSchema: map[string]schema.Table{ "t1": { Name: "table1", @@ -633,7 +649,8 @@ func TestReviewTableSchema(t *testing.T) { ColDefs: map[string]ddl.ColumnDef{ "c3": {Name: "a", Id: "c3", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}}, }, - }}, + }, + }, SrcSchema: map[string]schema.Table{ "t1": { Name: "table1", @@ -652,7 +669,8 @@ func TestReviewTableSchema(t *testing.T) { ColDefs: map[string]schema.Column{ "c3": {Name: "c", Id: "c3", Type: schema.Type{Name: "bigint", Mods: []int64{}}}, }, - }}, + }, + }, SchemaIssues: 
map[string]internal.TableIssues{ "t1": { ColumnLevelIssues: map[string][]internal.SchemaIssue{ @@ -686,7 +704,8 @@ func TestReviewTableSchema(t *testing.T) { "c2": {Id: "c2", Name: "b", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}}, }, PrimaryKeys: []ddl.IndexKey{{ColId: "c1"}}, - }}, + }, + }, SrcSchema: map[string]schema.Table{ "t1": { Id: "t1", @@ -698,7 +717,8 @@ func TestReviewTableSchema(t *testing.T) { "c3": {Id: "c3", Name: "c", Type: schema.Type{Name: "varchar", Mods: []int64{}}}, }, PrimaryKeys: []schema.Key{{ColId: "c1"}}, - }}, + }, + }, Audit: internal.Audit{ MigrationType: migration.MigrationData_MIGRATION_TYPE_UNSPECIFIED.Enum(), @@ -716,7 +736,8 @@ func TestReviewTableSchema(t *testing.T) { "c3": {Id: "c3", Name: "c", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}}, }, PrimaryKeys: []ddl.IndexKey{{ColId: "c1"}}, - }}, + }, + }, SrcSchema: map[string]schema.Table{ "t1": { Id: "t1", @@ -728,7 +749,8 @@ func TestReviewTableSchema(t *testing.T) { "c3": {Id: "c3", Name: "c", Type: schema.Type{Name: "varchar", Mods: []int64{}}}, }, PrimaryKeys: []schema.Key{{ColId: "c1"}}, - }}, + }, + }, Audit: internal.Audit{ MigrationType: migration.MigrationData_MIGRATION_TYPE_UNSPECIFIED.Enum(), @@ -756,7 +778,8 @@ func TestReviewTableSchema(t *testing.T) { "c3": {Name: "c", Id: "c3", T: ddl.Type{Name: ddl.Int64}}, }, PrimaryKeys: []ddl.IndexKey{{ColId: "c1"}}, - }}, + }, + }, SchemaIssues: map[string]internal.TableIssues{ "t1": { ColumnLevelIssues: map[string][]internal.SchemaIssue{ @@ -778,7 +801,8 @@ func TestReviewTableSchema(t *testing.T) { "c2": {Name: "b", Id: "c2", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}}, }, PrimaryKeys: []ddl.IndexKey{{ColId: "c1"}}, - }}, + }, + }, SchemaIssues: map[string]internal.TableIssues{ "t1": {}, }, @@ -1345,7 +1369,8 @@ func TestReviewTableSchema(t *testing.T) { "c3": {Name: "c", Id: "c3", T: ddl.Type{Name: ddl.Int64}, AutoGen: ddl.AutoGenCol{Name: "seq", GenerationType: constants.SEQUENCE}}, }, PrimaryKeys: []ddl.IndexKey{{ColId: "c1"}}, - }}, + }, + }, SchemaIssues: map[string]internal.TableIssues{ "t1": { ColumnLevelIssues: map[string][]internal.SchemaIssue{ @@ -1374,7 +1399,8 @@ func TestReviewTableSchema(t *testing.T) { "c2": {Name: "b", Id: "c2", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}}, }, PrimaryKeys: []ddl.IndexKey{{ColId: "c1"}}, - }}, + }, + }, SchemaIssues: map[string]internal.TableIssues{ "t1": {}, }, @@ -1382,7 +1408,7 @@ func TestReviewTableSchema(t *testing.T) { MigrationType: migration.MigrationData_MIGRATION_TYPE_UNSPECIFIED.Enum(), }, SpSequences: map[string]ddl.Sequence{ - "s1": ddl.Sequence{ + "s1": { Id: "s1", Name: "seq", ColumnsUsingSeq: map[string][]string{"t1": {}}, @@ -1411,7 +1437,8 @@ func TestReviewTableSchema(t *testing.T) { "c3": {Name: "c", Id: "c3", T: ddl.Type{Name: ddl.Int64}}, }, PrimaryKeys: []ddl.IndexKey{{ColId: "c1"}}, - }}, + }, + }, Audit: internal.Audit{ MigrationType: migration.MigrationData_MIGRATION_TYPE_UNSPECIFIED.Enum(), }, @@ -1427,7 +1454,8 @@ func TestReviewTableSchema(t *testing.T) { "c3": {Name: "c", Id: "c3", T: ddl.Type{Name: ddl.Int64}}, }, PrimaryKeys: []ddl.IndexKey{{ColId: "c1"}}, - }}, + }, + }, Audit: internal.Audit{ MigrationType: migration.MigrationData_MIGRATION_TYPE_UNSPECIFIED.Enum(), }, @@ -1454,7 +1482,8 @@ func TestReviewTableSchema(t *testing.T) { "c3": {Name: "c", Id: "c3", T: ddl.Type{Name: ddl.Int64}}, }, PrimaryKeys: []ddl.IndexKey{{ColId: "c1"}}, - }}, + }, + }, Audit: internal.Audit{ MigrationType: 
migration.MigrationData_MIGRATION_TYPE_UNSPECIFIED.Enum(), }, @@ -1470,7 +1499,8 @@ func TestReviewTableSchema(t *testing.T) { "c3": {Name: "c", Id: "c3", T: ddl.Type{Name: ddl.Int64}}, }, PrimaryKeys: []ddl.IndexKey{{ColId: "c1"}}, - }}, + }, + }, Audit: internal.Audit{ MigrationType: migration.MigrationData_MIGRATION_TYPE_UNSPECIFIED.Enum(), }, @@ -1601,7 +1631,6 @@ func TestReviewTableSchema(t *testing.T) { }, }, { - name: "Test change type success for related foreign key columns", tableId: "t1", payload: ` @@ -1808,7 +1837,8 @@ func TestReviewTableSchema(t *testing.T) { "c2": {Name: "b", Id: "c2", Type: schema.Type{Name: "varchar", Mods: []int64{6}}}, }, PrimaryKeys: []schema.Key{{ColId: "c1"}}, - }}, + }, + }, SchemaIssues: map[string]internal.TableIssues{ "t1": { ColumnLevelIssues: map[string][]internal.SchemaIssue{ @@ -1843,7 +1873,8 @@ func TestReviewTableSchema(t *testing.T) { "c2": {Name: "b", Id: "c2", Type: schema.Type{Name: "varchar", Mods: []int64{6}}}, }, PrimaryKeys: []schema.Key{{ColId: "c1"}}, - }}, + }, + }, SchemaIssues: map[string]internal.TableIssues{ "t1": { ColumnLevelIssues: map[string][]internal.SchemaIssue{ @@ -1886,7 +1917,8 @@ func TestReviewTableSchema(t *testing.T) { "c2": {Name: "b", Id: "c2", Type: schema.Type{Name: "varchar", Mods: []int64{6}}}, }, PrimaryKeys: []schema.Key{{ColId: "c1"}}, - }}, + }, + }, SchemaIssues: map[string]internal.TableIssues{ "t1": { ColumnLevelIssues: map[string][]internal.SchemaIssue{ @@ -1929,7 +1961,8 @@ func TestReviewTableSchema(t *testing.T) { "c2": {Name: "b", Id: "c2", Type: schema.Type{Name: "varchar", Mods: []int64{6}}}, }, PrimaryKeys: []schema.Key{{ColId: "c1"}}, - }}, + }, + }, SchemaIssues: map[string]internal.TableIssues{ "t1": { ColumnLevelIssues: map[string][]internal.SchemaIssue{ @@ -1980,7 +2013,8 @@ func TestReviewTableSchema(t *testing.T) { "c2": {Name: "b", Id: "c2", Type: schema.Type{Name: "varchar", Mods: []int64{6}}}, }, PrimaryKeys: []schema.Key{{ColId: "c1"}}, - }}, + }, + }, SchemaIssues: map[string]internal.TableIssues{ "t1": { ColumnLevelIssues: map[string][]internal.SchemaIssue{ @@ -2031,7 +2065,8 @@ func TestReviewTableSchema(t *testing.T) { "c2": {Name: "b", Id: "c2", Type: schema.Type{Name: "varchar", Mods: []int64{6}}}, }, PrimaryKeys: []schema.Key{{ColId: "c1"}}, - }}, + }, + }, SchemaIssues: map[string]internal.TableIssues{ "t1": { ColumnLevelIssues: map[string][]internal.SchemaIssue{ @@ -2060,6 +2095,102 @@ func TestReviewTableSchema(t *testing.T) { }, }, }, + { + name: "rename constraints column", + tableId: "t1", + payload: `{"UpdateCols":{"c1": { "Rename": "aa" }}}`, + statusCode: http.StatusOK, + conv: &internal.Conv{ + SpSchema: map[string]ddl.CreateTable{ + "t1": { + Name: "t1", + ColIds: []string{"c1", "c2", "c3"}, + ColDefs: map[string]ddl.ColumnDef{ + "c1": {Name: "a", Id: "c1", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}}, + "c2": {Name: "b", Id: "c2", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}}, + "c3": {Name: "c", Id: "c3", T: ddl.Type{Name: ddl.Int64}}, + }, + PrimaryKeys: []ddl.IndexKey{{ColId: "c1"}}, + CheckConstraints: []ddl.CheckConstraint{{ + Name: "check1", + Expr: "a > 0", + }}, + }, + }, + Audit: internal.Audit{ + MigrationType: migration.MigrationData_MIGRATION_TYPE_UNSPECIFIED.Enum(), + }, + }, + expectedConv: &internal.Conv{ + SpSchema: map[string]ddl.CreateTable{ + "t1": { + Name: "t1", + ColIds: []string{"c1", "c2", "c3"}, + ColDefs: map[string]ddl.ColumnDef{ + "c1": {Name: "aa", Id: "c1", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}}, + "c2": 
{Name: "b", Id: "c2", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}}, + "c3": {Name: "c", Id: "c3", T: ddl.Type{Name: ddl.Int64}}, + }, + PrimaryKeys: []ddl.IndexKey{{ColId: "c1"}}, + CheckConstraints: []ddl.CheckConstraint{{ + Name: "check1", + Expr: "aa > 0", + }}, + }, + }, + Audit: internal.Audit{ + MigrationType: migration.MigrationData_MIGRATION_TYPE_UNSPECIFIED.Enum(), + }, + }, + }, + { + name: "exact match of column name", + tableId: "t1", + payload: `{"UpdateCols":{"c2": { "Rename": "c2" }}}`, + statusCode: http.StatusOK, + conv: &internal.Conv{ + SpSchema: map[string]ddl.CreateTable{ + "t1": { + Name: "t1", + ColIds: []string{"c1", "c2", "c3"}, + ColDefs: map[string]ddl.ColumnDef{ + "c1": {Name: "c1", Id: "c1", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}}, + "c2": {Name: "c1_1", Id: "c2", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}}, + "c3": {Name: "c3", Id: "c3", T: ddl.Type{Name: ddl.Int64}}, + }, + PrimaryKeys: []ddl.IndexKey{{ColId: "c1"}}, + CheckConstraints: []ddl.CheckConstraint{{ + Name: "check1", + Expr: "c1_1 > 0", + }}, + }, + }, + Audit: internal.Audit{ + MigrationType: migration.MigrationData_MIGRATION_TYPE_UNSPECIFIED.Enum(), + }, + }, + expectedConv: &internal.Conv{ + SpSchema: map[string]ddl.CreateTable{ + "t1": { + Name: "t1", + ColIds: []string{"c1", "c2", "c3"}, + ColDefs: map[string]ddl.ColumnDef{ + "c1": {Name: "c1", Id: "c1", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}}, + "c2": {Name: "c2", Id: "c2", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}}, + "c3": {Name: "c3", Id: "c3", T: ddl.Type{Name: ddl.Int64}}, + }, + PrimaryKeys: []ddl.IndexKey{{ColId: "c1"}}, + CheckConstraints: []ddl.CheckConstraint{{ + Name: "check1", + Expr: "c2 > 0", + }}, + }, + }, + Audit: internal.Audit{ + MigrationType: migration.MigrationData_MIGRATION_TYPE_UNSPECIFIED.Enum(), + }, + }, + }, } for _, tc := range tc { diff --git a/webv2/table/update_table_schema.go b/webv2/table/update_table_schema.go index 6e03c709c..b99bdf232 100644 --- a/webv2/table/update_table_schema.go +++ b/webv2/table/update_table_schema.go @@ -19,6 +19,7 @@ import ( "fmt" "io/ioutil" "net/http" + "regexp" "github.com/GoogleCloudPlatform/spanner-migration-tool/internal" "github.com/GoogleCloudPlatform/spanner-migration-tool/sources/common" @@ -55,8 +56,8 @@ type updateTable struct { // (3) Rename column. // (4) Add or Remove NotNull constraint. // (5) Update Spanner type. +// (6) Update Check constraints Name. 
func UpdateTableSchema(w http.ResponseWriter, r *http.Request) { - reqBody, err := ioutil.ReadAll(r.Body) if err != nil { http.Error(w, fmt.Sprintf("Body Read Error : %v", err), http.StatusInternalServerError) @@ -83,19 +84,26 @@ func UpdateTableSchema(w http.ResponseWriter, r *http.Request) { for colId, v := range t.UpdateCols { if v.Add { - addColumn(tableId, colId, conv) - } if v.Removed { - RemoveColumn(tableId, colId, conv) - } if v.Rename != "" && v.Rename != conv.SpSchema[tableId].ColDefs[colId].Name { + oldName := conv.SrcSchema[tableId].ColDefs[colId].Name + + // Use a regular expression to match the exact column name + re := regexp.MustCompile(`\b` + regexp.QuoteMeta(oldName) + `\b`) + + for i := range conv.SpSchema[tableId].CheckConstraints { + originalString := conv.SpSchema[tableId].CheckConstraints[i].Expr + updatedValue := re.ReplaceAllString(originalString, v.Rename) + conv.SpSchema[tableId].CheckConstraints[i].Expr = updatedValue + } + renameColumn(v.Rename, tableId, colId, conv) } @@ -103,16 +111,13 @@ func UpdateTableSchema(w http.ResponseWriter, r *http.Request) { if v.ToType != "" && found { typeChange, err := utilities.IsTypeChanged(v.ToType, tableId, colId, conv) - if err != nil { http.Error(w, err.Error(), http.StatusBadRequest) return } if typeChange { - UpdateColumnType(v.ToType, tableId, colId, conv, w) - } }
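Taken together, the web flow added in this patch is: the frontend POSTs the edited constraint list to the new /update/cc route, UpdateCheckConstraint stores it on the table's SpSchema entry, and ValidateCheckConstraint later flags a TypeMismatch issue when a constraint expression references a column whose Spanner type no longer matches the source type. A rough client-side sketch of the update call, assuming the tool's web API is reachable locally; the host, port, table id, and constraint values are illustrative, not defaults confirmed by this patch:

package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
)

// checkConstraint mirrors the fields of spanner/ddl.CheckConstraint
// expected in the request body of /update/cc.
type checkConstraint struct {
	Id   string
	Name string
	Expr string
}

func main() {
	payload := []checkConstraint{
		{Id: "ck1", Name: "check_age", Expr: "(age > 18)"},
	}
	body, err := json.Marshal(payload)
	if err != nil {
		panic(err)
	}

	// "table" carries the table id used as the key of the session's SpSchema map.
	url := "http://localhost:8080/update/cc?table=t1" // host and port are illustrative
	resp, err := http.Post(url, "application/json", bytes.NewBuffer(body))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	// On success the handler responds with the updated ConvWithMetadata as JSON.
	fmt.Println("status:", resp.Status)
}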