From 971a2deec3fe7652eddbad081c6aed2119c9ee6f Mon Sep 17 00:00:00 2001 From: Astha Mohta <35952883+asthamohta@users.noreply.github.com> Date: Fri, 20 Dec 2024 11:30:54 +0530 Subject: [PATCH] feat: APIs for Backend Changes for Default Values (#965) * backend apis * linting * comment changes * comment changes --- common/constants/constants.go | 1 + conversion/conversion_from_source.go | 6 + expressions_api/expression_verify.go | 108 ++++++++++++- expressions_api/expression_verify_test.go | 187 +++++++++++++++++++++- expressions_api/mocks.go | 69 ++++++++ internal/convert.go | 22 +-- internal/helpers.go | 3 + internal/reports/report_helpers.go | 7 + schema/schema.go | 13 +- sources/common/infoschema.go | 12 ++ sources/common/toddl.go | 29 ++++ sources/common/toddl_test.go | 109 +++++++++++++ sources/common/utils.go | 15 ++ sources/common/utils_test.go | 30 ++++ spanner/ddl/ast.go | 52 +++++- spanner/ddl/ast_test.go | 90 +++++++++++ webv2/api/schema.go | 20 +++ webv2/table/utilities.go | 37 +++++ 18 files changed, 784 insertions(+), 26 deletions(-) create mode 100644 expressions_api/mocks.go diff --git a/common/constants/constants.go b/common/constants/constants.go index d9d273e0a..64a8c19b0 100644 --- a/common/constants/constants.go +++ b/common/constants/constants.go @@ -125,6 +125,7 @@ const ( // VerifyExpresions API CHECK_EXPRESSION = "CHECK" DEFAUT_EXPRESSION = "DEFAULT" + TEMP_DB = "smt-staging-db" // Regex for matching database collation DB_COLLATION_REGEX = `(_[a-zA-Z0-9]+\\|\\)` diff --git a/conversion/conversion_from_source.go b/conversion/conversion_from_source.go index 0290af537..1eec13068 100644 --- a/conversion/conversion_from_source.go +++ b/conversion/conversion_from_source.go @@ -54,6 +54,9 @@ type DataFromSourceImpl struct{} func (sads *SchemaFromSourceImpl) schemaFromDatabase(migrationProjectId string, sourceProfile profiles.SourceProfile, targetProfile profiles.TargetProfile, getInfo GetInfoInterface, processSchema common.ProcessSchemaInterface) (*internal.Conv, error) { conv := internal.MakeConv() conv.SpDialect = targetProfile.Conn.Sp.Dialect + conv.SpProjectId = targetProfile.Conn.Sp.Project + conv.SpInstanceId = targetProfile.Conn.Sp.Instance + conv.Source = sourceProfile.Driver //handle fetching schema differently for sharded migrations, we only connect to the primary shard to //fetch the schema. We reuse the SourceProfileConnection object for this purpose. var infoSchema common.InfoSchema @@ -159,6 +162,9 @@ func (sads *DataFromSourceImpl) dataFromCSV(ctx context.Context, sourceProfile p return nil, fmt.Errorf("dbName is mandatory in target-profile for csv source") } conv.SpDialect = targetProfile.Conn.Sp.Dialect + conv.SpProjectId = targetProfile.Conn.Sp.Project + conv.SpInstanceId = targetProfile.Conn.Sp.Instance + conv.Source = sourceProfile.Driver dialect, err := targetProfile.FetchTargetDialect(ctx) if err != nil { return nil, fmt.Errorf("could not fetch dialect: %v", err) diff --git a/expressions_api/expression_verify.go b/expressions_api/expression_verify.go index 9caebce83..7e1f8be24 100644 --- a/expressions_api/expression_verify.go +++ b/expressions_api/expression_verify.go @@ -6,6 +6,7 @@ import ( "fmt" "sync" + spannerclient "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/clients/spanner/client" spanneraccessor "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/spanner" "github.com/GoogleCloudPlatform/spanner-migration-tool/common/constants" "github.com/GoogleCloudPlatform/spanner-migration-tool/common/task" @@ -18,6 +19,7 @@ const THREAD_POOL = 500 type ExpressionVerificationAccessor interface { //Batch API which parallelizes expression verification calls VerifyExpressions(ctx context.Context, verifyExpressionsInput internal.VerifyExpressionsInput) internal.VerifyExpressionsOutput + RefreshSpannerClient(ctx context.Context, project string, instance string) error } type ExpressionVerificationAccessorImpl struct { @@ -25,15 +27,42 @@ type ExpressionVerificationAccessorImpl struct { } func NewExpressionVerificationAccessorImpl(ctx context.Context, project string, instance string) (*ExpressionVerificationAccessorImpl, error) { - spannerAccessor, err := spanneraccessor.NewSpannerAccessorClientImplWithSpannerClient(ctx, fmt.Sprintf("projects/%s/instances/%s/databases/%s", project, instance, "smt-staging-db")) - if err != nil { - return nil, err + var spannerAccessor *spanneraccessor.SpannerAccessorImpl + var err error + if project != "" && instance != "" { + spannerAccessor, err = spanneraccessor.NewSpannerAccessorClientImplWithSpannerClient(ctx, fmt.Sprintf("projects/%s/instances/%s/databases/%s", project, instance, constants.TEMP_DB)) + if err != nil { + return nil, err + } + } else { + spannerAccessor, err = spanneraccessor.NewSpannerAccessorClientImpl(ctx) + if err != nil { + return nil, err + } } return &ExpressionVerificationAccessorImpl{ SpannerAccessor: spannerAccessor, }, nil } +// APIs to verify and process Spanner DLL features such as Default Values, Check Constraints +type DDLVerifier interface { + VerifySpannerDDL(conv *internal.Conv, expressionDetails []internal.ExpressionDetail) (internal.VerifyExpressionsOutput, error) + GetSourceExpressionDetails(conv *internal.Conv, tableIds []string) []internal.ExpressionDetail + GetSpannerExpressionDetails(conv *internal.Conv, tableIds []string) []internal.ExpressionDetail + RefreshSpannerClient(ctx context.Context, project string, instance string) error +} +type DDLVerifierImpl struct { + Expressions ExpressionVerificationAccessor +} + +func NewDDLVerifierImpl(ctx context.Context, project string, instance string) (*DDLVerifierImpl, error) { + expVerifier, err := NewExpressionVerificationAccessorImpl(ctx, project, instance) + return &DDLVerifierImpl{ + Expressions: expVerifier, + }, err +} + func (ev *ExpressionVerificationAccessorImpl) VerifyExpressions(ctx context.Context, verifyExpressionsInput internal.VerifyExpressionsInput) internal.VerifyExpressionsOutput { err := ev.validateRequest(verifyExpressionsInput) if err != nil { @@ -79,6 +108,15 @@ func (ev *ExpressionVerificationAccessorImpl) VerifyExpressions(ctx context.Cont return verifyExpressionsOutput } +func (ev *ExpressionVerificationAccessorImpl) RefreshSpannerClient(ctx context.Context, project string, instance string) error { + spannerClient, err := spannerclient.NewSpannerClientImpl(ctx, fmt.Sprintf("projects/%s/instances/%s/databases/%s", project, instance, constants.TEMP_DB)) + if err != nil { + return err + } + ev.SpannerAccessor.SpannerClient = spannerClient + return nil +} + func (ev *ExpressionVerificationAccessorImpl) verifyExpressionInternal(expressionDetail internal.ExpressionDetail, mutex *sync.Mutex) task.TaskResult[internal.ExpressionVerificationOutput] { var sqlStatement string switch expressionDetail.Type { @@ -129,3 +167,67 @@ func (ev *ExpressionVerificationAccessorImpl) removeExpressions(inputConv *inter } return convCopy, nil } + +func (ddlv *DDLVerifierImpl) VerifySpannerDDL(conv *internal.Conv, expressionDetails []internal.ExpressionDetail) (internal.VerifyExpressionsOutput, error) { + ctx := context.Background() + verifyExpressionsInput := internal.VerifyExpressionsInput{ + Conv: conv, + Source: conv.Source, + ExpressionDetailList: expressionDetails, + } + verificationResults := ddlv.Expressions.VerifyExpressions(ctx, verifyExpressionsInput) + + return verificationResults, verificationResults.Err +} + +func (ddlv *DDLVerifierImpl) GetSourceExpressionDetails(conv *internal.Conv, tableIds []string) []internal.ExpressionDetail { + expressionDetails := []internal.ExpressionDetail{} + // Collect default values for verification + for _, tableId := range tableIds { + srcTable := conv.SrcSchema[tableId] + for _, srcColId := range srcTable.ColIds { + srcCol := srcTable.ColDefs[srcColId] + if srcCol.DefaultValue.IsPresent { + defaultValueExp := internal.ExpressionDetail{ + ReferenceElement: internal.ReferenceElement{ + Name: conv.SpSchema[tableId].ColDefs[srcColId].T.Name, + }, + ExpressionId: srcCol.DefaultValue.Value.ExpressionId, + Expression: srcCol.DefaultValue.Value.Statement, + Type: "DEFAULT", + Metadata: map[string]string{"TableId": tableId, "ColId": srcColId}, + } + expressionDetails = append(expressionDetails, defaultValueExp) + } + } + } + return expressionDetails +} + +func (ddlv *DDLVerifierImpl) GetSpannerExpressionDetails(conv *internal.Conv, tableIds []string) []internal.ExpressionDetail { + expressionDetails := []internal.ExpressionDetail{} + // Collect default values for verification + for _, tableId := range tableIds { + spTable := conv.SpSchema[tableId] + for _, spColId := range spTable.ColIds { + spCol := spTable.ColDefs[spColId] + if spCol.DefaultValue.IsPresent { + defaultValueExp := internal.ExpressionDetail{ + ReferenceElement: internal.ReferenceElement{ + Name: conv.SpSchema[tableId].ColDefs[spColId].T.Name, + }, + ExpressionId: spCol.DefaultValue.Value.ExpressionId, + Expression: spCol.DefaultValue.Value.Statement, + Type: "DEFAULT", + Metadata: map[string]string{"TableId": tableId, "ColId": spColId}, + } + expressionDetails = append(expressionDetails, defaultValueExp) + } + } + } + return expressionDetails +} + +func (ddlv *DDLVerifierImpl) RefreshSpannerClient(ctx context.Context, project string, instance string) error { + return ddlv.Expressions.RefreshSpannerClient(ctx, project, instance) +} diff --git a/expressions_api/expression_verify_test.go b/expressions_api/expression_verify_test.go index 8dadb01f8..11f3facc2 100644 --- a/expressions_api/expression_verify_test.go +++ b/expressions_api/expression_verify_test.go @@ -17,6 +17,8 @@ import ( "github.com/GoogleCloudPlatform/spanner-migration-tool/expressions_api" "github.com/GoogleCloudPlatform/spanner-migration-tool/internal" "github.com/GoogleCloudPlatform/spanner-migration-tool/logger" + "github.com/GoogleCloudPlatform/spanner-migration-tool/schema" + "github.com/GoogleCloudPlatform/spanner-migration-tool/spanner/ddl" "github.com/googleapis/gax-go/v2" "github.com/stretchr/testify/assert" "go.uber.org/zap" @@ -32,8 +34,8 @@ func TestVerifyExpressions(t *testing.T) { conv := internal.MakeConv() ReadSessionFile(conv, "../../test_data/session_expression_verify.json") input := internal.VerifyExpressionsInput{ - Conv: conv, - Source: "mysql", + Conv: conv, + Source: "mysql", ExpressionDetailList: []internal.ExpressionDetail{ { Expression: "id > 10", @@ -297,3 +299,184 @@ func ReadSessionFile(conv *internal.Conv, sessionJSON string) error { } return nil } + +func TestVerifySpannerDDL(t *testing.T) { + conv := *internal.MakeConv() + testCases := []struct { + name string + conv internal.Conv + expressionDetails []internal.ExpressionDetail + verifyExpressionMock expressions_api.MockExpressionVerificationAccessor + errorExpected bool + }{ + { + name: "no error flow", + conv: conv, + expressionDetails: []internal.ExpressionDetail{}, + verifyExpressionMock: expressions_api.MockExpressionVerificationAccessor{ + VerifyExpressionsMock: func(ctx context.Context, verifyExpressionsInput internal.VerifyExpressionsInput) internal.VerifyExpressionsOutput { + return internal.VerifyExpressionsOutput{ + ExpressionVerificationOutputList: []internal.ExpressionVerificationOutput{}, + Err: nil, + } + }, + }, + errorExpected: false, + }, + { + name: "error flow", + conv: conv, + expressionDetails: []internal.ExpressionDetail{}, + verifyExpressionMock: expressions_api.MockExpressionVerificationAccessor{ + VerifyExpressionsMock: func(ctx context.Context, verifyExpressionsInput internal.VerifyExpressionsInput) internal.VerifyExpressionsOutput { + return internal.VerifyExpressionsOutput{ + ExpressionVerificationOutputList: []internal.ExpressionVerificationOutput{}, + Err: fmt.Errorf("error"), + } + }, + }, + errorExpected: true, + }, + } + + for _, tc := range testCases { + ddlV := expressions_api.DDLVerifierImpl{ + Expressions: &tc.verifyExpressionMock, + } + _, err := ddlV.VerifySpannerDDL(&tc.conv, tc.expressionDetails) + assert.Equal(t, tc.errorExpected, err != nil) + } +} + +func TestGetSourceExpressionDetails(t *testing.T) { + conv := internal.MakeConv() + conv.SrcSchema = map[string]schema.Table{ + "table1": { + ColIds: []string{"col1", "col2"}, + ColDefs: map[string]schema.Column{ + "col1": { + DefaultValue: ddl.DefaultValue{ + IsPresent: true, + Value: ddl.Expression{ + ExpressionId: "expr1", + Statement: "SELECT 1", + }, + }, + }, + "col2": { + DefaultValue: ddl.DefaultValue{}, + }, + }, + }, + } + conv.SpSchema = ddl.Schema{ + "table1": { + ColDefs: map[string]ddl.ColumnDef{ + "col1": { + T: ddl.Type{ + Name: "INT64", + }, + }, + }, + }, + } + + testCases := []struct { + name string + conv *internal.Conv + tableIds []string + expectedDetails []internal.ExpressionDetail + }{ + { + name: "single table with default value", + conv: conv, + tableIds: []string{"table1"}, + expectedDetails: []internal.ExpressionDetail{ + { + ReferenceElement: internal.ReferenceElement{ + Name: "INT64", + }, + ExpressionId: "expr1", + Expression: "SELECT 1", + Type: "DEFAULT", + Metadata: map[string]string{"TableId": "table1", "ColId": "col1"}, + }, + }, + }, + { + name: "no tables", + conv: conv, + tableIds: []string{}, + expectedDetails: []internal.ExpressionDetail{}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + ddlv := &expressions_api.DDLVerifierImpl{} + actualDetails := ddlv.GetSourceExpressionDetails(tc.conv, tc.tableIds) + assert.Equal(t, tc.expectedDetails, actualDetails) + }) + } +} + +func TestGetSpannerExpressionDetails(t *testing.T) { + conv := internal.MakeConv() + conv.SpSchema = ddl.Schema{ + "table1": { + ColIds: []string{"col1", "col2"}, + ColDefs: map[string]ddl.ColumnDef{ + "col1": { + DefaultValue: ddl.DefaultValue{ + IsPresent: true, + Value: ddl.Expression{ + ExpressionId: "expr1", + Statement: "SELECT 1", + }, + }, + }, + "col2": { + DefaultValue: ddl.DefaultValue{}, + }, + }, + }, + } + + testCases := []struct { + name string + conv *internal.Conv + tableIds []string + expectedDetails []internal.ExpressionDetail + }{ + { + name: "single table with default value", + conv: conv, + tableIds: []string{"table1"}, + expectedDetails: []internal.ExpressionDetail{ + { + ReferenceElement: internal.ReferenceElement{ + Name: conv.SpSchema["table1"].ColDefs["col1"].T.Name, + }, + ExpressionId: "expr1", + Expression: "SELECT 1", + Type: "DEFAULT", + Metadata: map[string]string{"TableId": "table1", "ColId": "col1"}, + }, + }, + }, + { + name: "no tables", + conv: conv, + tableIds: []string{}, + expectedDetails: []internal.ExpressionDetail{}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + ddlv := &expressions_api.DDLVerifierImpl{} + actualDetails := ddlv.GetSpannerExpressionDetails(tc.conv, tc.tableIds) + assert.Equal(t, tc.expectedDetails, actualDetails) + }) + } +} diff --git a/expressions_api/mocks.go b/expressions_api/mocks.go new file mode 100644 index 000000000..b56e87060 --- /dev/null +++ b/expressions_api/mocks.go @@ -0,0 +1,69 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package expressions_api + +import ( + "context" + + "github.com/GoogleCloudPlatform/spanner-migration-tool/internal" +) + +type MockExpressionVerificationAccessor struct { + VerifyExpressionsMock func(ctx context.Context, verifyExpressionsInput internal.VerifyExpressionsInput) internal.VerifyExpressionsOutput + RefreshSpannerClientMock func(ctx context.Context, project string, instance string) error +} + +func (mev *MockExpressionVerificationAccessor) VerifyExpressions(ctx context.Context, verifyExpressionsInput internal.VerifyExpressionsInput) internal.VerifyExpressionsOutput { + return mev.VerifyExpressionsMock(ctx, verifyExpressionsInput) +} + +func (mev *MockExpressionVerificationAccessor) RefreshSpannerClient(ctx context.Context, project string, instance string) error { + return mev.RefreshSpannerClientMock(ctx, project, instance) +} + +type MockDDLVerifier struct { + VerifySpannerDDLMock func(conv *internal.Conv, expressionDetails []internal.ExpressionDetail) (internal.VerifyExpressionsOutput, error) + GetSpannerExpressionDetailsMock func(conv *internal.Conv, tableIds []string) []internal.ExpressionDetail + GetSourceExpressionDetailsMock func(conv *internal.Conv, tableIds []string) []internal.ExpressionDetail + RefreshSpannerClientMock func(ctx context.Context, project string, instance string) error +} + +func (m *MockDDLVerifier) VerifySpannerDDL(conv *internal.Conv, expressionDetails []internal.ExpressionDetail) (internal.VerifyExpressionsOutput, error) { + if m.VerifySpannerDDLMock != nil { + return m.VerifySpannerDDLMock(conv, expressionDetails) + } + return internal.VerifyExpressionsOutput{}, nil +} + +func (m *MockDDLVerifier) GetSpannerExpressionDetails(conv *internal.Conv, tableIds []string) []internal.ExpressionDetail { + if m.GetSpannerExpressionDetailsMock != nil { + return m.GetSpannerExpressionDetailsMock(conv, tableIds) + } + return []internal.ExpressionDetail{} +} + +func (m *MockDDLVerifier) GetSourceExpressionDetails(conv *internal.Conv, tableIds []string) []internal.ExpressionDetail { + if m.GetSourceExpressionDetailsMock != nil { + return m.GetSourceExpressionDetailsMock(conv, tableIds) + } + return []internal.ExpressionDetail{} +} + +func (m *MockDDLVerifier) RefreshSpannerClient(ctx context.Context, project string, instance string) error { + if m.RefreshSpannerClientMock != nil { + return m.RefreshSpannerClientMock(ctx, project, instance) + } + return nil +} diff --git a/internal/convert.go b/internal/convert.go index 41ad82aff..d2e892131 100644 --- a/internal/convert.go +++ b/internal/convert.go @@ -54,6 +54,9 @@ type Conv struct { UI bool // Flag if UI interface was used for migration. ToDo: Remove flag after resource generation is introduced to UI SpSequences map[string]ddl.Sequence // Maps Spanner Sequences to Sequence Schema SrcSequences map[string]ddl.Sequence // Maps source-DB Sequences to Sequence schema information + SpProjectId string // Spanner Project Id + SpInstanceId string // Spanner Instance Id + Source string // Source Database type being migrated } type TableIssues struct { @@ -129,6 +132,7 @@ const ( ForeignKeyActionNotSupported NumericPKNotSupported TypeMismatch + DefaultValueError ) const ( @@ -292,17 +296,17 @@ type TableDetails struct { } type VerifyExpressionsInput struct { - Conv *Conv - Source string + Conv *Conv + Source string ExpressionDetailList []ExpressionDetail } type ExpressionDetail struct { ReferenceElement ReferenceElement - ExpressionId string - Expression string - Type string - Metadata map[string]string + ExpressionId string + Expression string + Type string + Metadata map[string]string } type ReferenceElement struct { @@ -311,13 +315,13 @@ type ReferenceElement struct { type ExpressionVerificationOutput struct { ExpressionDetail ExpressionDetail - Result bool - Err error + Result bool + Err error } type VerifyExpressionsOutput struct { ExpressionVerificationOutputList []ExpressionVerificationOutput - Err error + Err error } // MakeConv returns a default-configured Conv. diff --git a/internal/helpers.go b/internal/helpers.go index 6fd3e3b6d..f587d12cf 100644 --- a/internal/helpers.go +++ b/internal/helpers.go @@ -76,6 +76,9 @@ func GenerateRuleId() string { func GenerateSequenceId() string { return GenerateId("s") } +func GenerateExpressionId() string { + return GenerateId("i") +} func GetSrcColNameIdMap(srcs schema.Table) map[string]string { if len(srcs.ColNameIdMap) > 0 { diff --git a/internal/reports/report_helpers.go b/internal/reports/report_helpers.go index 34cb58511..fa17f2762 100644 --- a/internal/reports/report_helpers.go +++ b/internal/reports/report_helpers.go @@ -410,6 +410,12 @@ func buildTableReportBody(conv *internal.Conv, tableId string, issues map[string } l = append(l, toAppend) + case internal.DefaultValueError: + toAppend := Issue{ + Category: IssueDB[i].Category, + Description: fmt.Sprintf("%s for table '%s' column '%s'", IssueDB[i].Brief, conv.SpSchema[tableId].Name, spColName), + } + l = append(l, toAppend) default: toAppend := Issue{ Category: IssueDB[i].Category, @@ -569,6 +575,7 @@ var IssueDB = map[internal.SchemaIssue]struct { internal.ForeignKeyOnUpdate: {Brief: "Spanner supports only ON UPDATE NO ACTION", Severity: warning, Category: "FOREIGN_KEY_ACTIONS"}, internal.ForeignKeyActionNotSupported: {Brief: "Spanner supports foreign key action migration only for MySQL and PostgreSQL", Severity: warning, Category: "FOREIGN_KEY_ACTIONS"}, internal.NumericPKNotSupported: {Brief: "Spanner PostgreSQL does not support numeric primary keys / unique indices", Severity: warning, Category: "NUMERIC_PK_NOT_SUPPORTED"}, + internal.DefaultValueError: {Brief: "Some columns have default value expressions not supported by Spanner. Please fix them to continue migration.", Severity: Errors, batch: true, Category: "INCOMPATIBLE_DEFAULT_VALUE_CONSTRAINTS"}, } type Severity int diff --git a/schema/schema.go b/schema/schema.go index 4b18ca777..eab021cf0 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -50,12 +50,13 @@ type Table struct { // Column represents a database column. // TODO: add support for foreign keys. type Column struct { - Name string - Type Type - NotNull bool - Ignored Ignored - Id string - AutoGen ddl.AutoGenCol + Name string + Type Type + NotNull bool + Ignored Ignored + Id string + AutoGen ddl.AutoGenCol + DefaultValue ddl.DefaultValue } // ForeignKey represents a foreign key. diff --git a/sources/common/infoschema.go b/sources/common/infoschema.go index 1d64770dd..0876731ae 100644 --- a/sources/common/infoschema.go +++ b/sources/common/infoschema.go @@ -17,6 +17,7 @@ package common import ( "context" "fmt" + "strings" "sync" sp "cloud.google.com/go/spanner" @@ -254,3 +255,14 @@ func (is *InfoSchemaImpl) GetIncludedSrcTablesFromConv(conv *internal.Conv) (sch } return schemaToTablesMap, nil } + +// SanitizeDefaultValue removes extra characters added to Default Value in information schema in MySQL. +func SanitizeDefaultValue(defaultValue string, ty string, generated bool) string { + defaultValue = strings.ReplaceAll(defaultValue, "_utf8mb4", "") + defaultValue = strings.ReplaceAll(defaultValue, "\\\\", "\\") + defaultValue = strings.ReplaceAll(defaultValue, "\\'", "'") + if !generated && (ty == "char" || ty == "varchar" || ty == "text" || ty == "STRING") && !strings.HasPrefix(defaultValue, "'") && !strings.HasSuffix(defaultValue, "'") { + defaultValue = "'" + defaultValue + "'" + } + return defaultValue +} diff --git a/sources/common/toddl.go b/sources/common/toddl.go index 8bd3291a5..460523575 100644 --- a/sources/common/toddl.go +++ b/sources/common/toddl.go @@ -390,3 +390,32 @@ func CvtIndexHelper(conv *internal.Conv, tableId string, srcIndex schema.Index, } return spIndex } + +// Applies all valid expressions which can be migrated to spanner conv object +func spannerSchemaApplyExpressions(conv *internal.Conv, expressions internal.VerifyExpressionsOutput) { + for _, expression := range expressions.ExpressionVerificationOutputList { + switch expression.ExpressionDetail.Type { + case "DEFAULT": + { + tableId := expression.ExpressionDetail.Metadata["TableId"] + columnId := expression.ExpressionDetail.Metadata["ColId"] + + if expression.Result { + col := conv.SpSchema[tableId].ColDefs[columnId] + col.DefaultValue = ddl.DefaultValue{ + IsPresent: true, + Value: ddl.Expression{ + ExpressionId: expression.ExpressionDetail.ExpressionId, + Statement: expression.ExpressionDetail.Expression, + }, + } + conv.SpSchema[tableId].ColDefs[columnId] = col + } else { + colIssues := conv.SchemaIssues[tableId].ColumnLevelIssues[columnId] + colIssues = append(colIssues, internal.DefaultValue) + conv.SchemaIssues[tableId].ColumnLevelIssues[columnId] = colIssues + } + } + } + } +} diff --git a/sources/common/toddl_test.go b/sources/common/toddl_test.go index fd12bce4b..6a98b8ca9 100644 --- a/sources/common/toddl_test.go +++ b/sources/common/toddl_test.go @@ -469,3 +469,112 @@ func Test_cvtCheckContraint(t *testing.T) { result := cvtCheckConstraint(conv, srcSchema) assert.Equal(t, spSchema, result) } + +func TestSpannerSchemaApplyExpressions(t *testing.T) { + makeConv := func() *internal.Conv { + conv := internal.MakeConv() + conv.SchemaIssues = make(map[string]internal.TableIssues) + conv.SchemaIssues["table1"] = internal.TableIssues{ + ColumnLevelIssues: make(map[string][]internal.SchemaIssue), + } + conv.SpSchema = ddl.Schema{ + "table1": { + ColDefs: map[string]ddl.ColumnDef{ + "col1": {}, + }, + }, + } + return conv + } + + makeResultConv := func(SpSchema ddl.Schema, SchemaIssues map[string]internal.TableIssues) *internal.Conv { + conv := internal.MakeConv() + conv.SpSchema = SpSchema + conv.SchemaIssues = SchemaIssues + return conv + } + + testCases := []struct { + name string + conv *internal.Conv + expressions internal.VerifyExpressionsOutput + expectedConv *internal.Conv + }{ + { + name: "successful default value application", + conv: makeConv(), + expressions: internal.VerifyExpressionsOutput{ + ExpressionVerificationOutputList: []internal.ExpressionVerificationOutput{ + { + Result: true, + ExpressionDetail: internal.ExpressionDetail{ + Type: "DEFAULT", + ExpressionId: "expr1", + Expression: "SELECT 1", + Metadata: map[string]string{"TableId": "table1", "ColId": "col1"}, + }, + }, + }, + }, + expectedConv: makeResultConv( + ddl.Schema{ + "table1": { + ColDefs: map[string]ddl.ColumnDef{ + "col1": { + DefaultValue: ddl.DefaultValue{ + IsPresent: true, + Value: ddl.Expression{ + ExpressionId: "expr1", + Statement: "SELECT 1", + }, + }, + }, + }, + }, + }, map[string]internal.TableIssues{ + "table1": { + ColumnLevelIssues: make(map[string][]internal.SchemaIssue), + }, + }), + }, + { + name: "failed default value application", + conv: makeConv(), + expressions: internal.VerifyExpressionsOutput{ + ExpressionVerificationOutputList: []internal.ExpressionVerificationOutput{ + { + Result: false, + ExpressionDetail: internal.ExpressionDetail{ + Type: "DEFAULT", + ExpressionId: "expr1", + Expression: "SELECT 1", + Metadata: map[string]string{"TableId": "table1", "ColId": "col1"}, + }, + }, + }, + }, + expectedConv: makeResultConv( + ddl.Schema{ + "table1": { + ColDefs: map[string]ddl.ColumnDef{ + "col1": {}, + }, + }, + }, + map[string]internal.TableIssues{ + "table1": { + ColumnLevelIssues: map[string][]internal.SchemaIssue{ + "col1": {internal.DefaultValue}, + }, + }, + }), + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + spannerSchemaApplyExpressions(tc.conv, tc.expressions) + assert.Equal(t, tc.expectedConv, tc.conv) + }) + } +} diff --git a/sources/common/utils.go b/sources/common/utils.go index c36c5b533..4581bd287 100644 --- a/sources/common/utils.go +++ b/sources/common/utils.go @@ -72,6 +72,21 @@ func GetSortedTableIdsBySrcName(srcSchema map[string]schema.Table) []string { return sortedTableIds } +func GetSortedTableIdsBySpName(spSchema ddl.Schema) []string { + tableNameIdMap := map[string]string{} + tableNames := []string{} + sortedTableIds := []string{} + for id, spTable := range spSchema { + tableNames = append(tableNames, spTable.Name) + tableNameIdMap[spTable.Name] = id + } + sort.Strings(tableNames) + for _, name := range tableNames { + sortedTableIds = append(sortedTableIds, tableNameIdMap[name]) + } + return sortedTableIds +} + func (uo *UtilsOrderImpl) initPrimaryKeyOrder(conv *internal.Conv) { for k, table := range conv.SrcSchema { for i := range table.PrimaryKeys { diff --git a/sources/common/utils_test.go b/sources/common/utils_test.go index 68fe6ad1c..b593c7c7a 100644 --- a/sources/common/utils_test.go +++ b/sources/common/utils_test.go @@ -270,3 +270,33 @@ func TestPrepareValues(t *testing.T) { assert.Equal(t, tc.expectedValues, res) } } + +func TestGetSortedTableIdsBySpName(t *testing.T) { + testCases := []struct { + name string + spSchema ddl.Schema + expectedIds []string + }{ + { + name: "multiple tables", + spSchema: ddl.Schema{ + "table2": {Name: "TableB"}, + "table1": {Name: "TableA"}, + "table3": {Name: "TableC"}, + }, + expectedIds: []string{"table1", "table2", "table3"}, + }, + { + name: "no tables", + spSchema: ddl.Schema{}, + expectedIds: []string{}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + sortedIds := GetSortedTableIdsBySpName(tc.spSchema) + assert.Equal(t, tc.expectedIds, sortedIds) + }) + } +} diff --git a/spanner/ddl/ast.go b/spanner/ddl/ast.go index b8f6db28f..f6534b262 100644 --- a/spanner/ddl/ast.go +++ b/spanner/ddl/ast.go @@ -180,12 +180,13 @@ func (ty Type) PGPrintColumnDefType() string { // column_def: // column_name type [NOT NULL] [options_def] type ColumnDef struct { - Name string - T Type - NotNull bool - Comment string - Id string - AutoGen AutoGenCol + Name string + T Type + NotNull bool + Comment string + Id string + AutoGen AutoGenCol + DefaultValue DefaultValue } // Config controls how AST nodes are printed (aka unparsed). @@ -424,6 +425,45 @@ type AutoGenCol struct { GenerationType string } +// DefaultValue represents a Default value. +type DefaultValue struct { + IsPresent bool + Value Expression +} + +type Expression struct { + ExpressionId string + Statement string +} + +func (dv DefaultValue) PrintDefaultValue(ty Type) string { + if !dv.IsPresent { + return "" + } + var value string + switch ty.Name { + case "FLOAT32", "NUMERIC", "BOOL": + value = fmt.Sprintf(" DEFAULT (CAST(%s AS %s))", dv.Value.Statement, ty.Name) + default: + value = " DEFAULT (" + dv.Value.Statement + ")" + } + return value +} + +func (dv DefaultValue) PGPrintDefaultValue(ty Type) string { + if !dv.IsPresent { + return "" + } + var value string + switch ty.Name { + case "FLOAT8", "FLOAT4", "REAL", "NUMERIC", "DECIMAL", "BOOL": + value = fmt.Sprintf(" DEFAULT (CAST(%s AS %s))", dv.Value.Statement, ty.Name) + default: + value = " DEFAULT (" + dv.Value.Statement + ")" + } + return value +} + func (agc AutoGenCol) PrintAutoGenCol() string { if agc.Name == constants.UUID && agc.GenerationType == "Pre-defined" { return " DEFAULT (GENERATE_UUID())" diff --git a/spanner/ddl/ast_test.go b/spanner/ddl/ast_test.go index af4a404bf..c84dfa800 100644 --- a/spanner/ddl/ast_test.go +++ b/spanner/ddl/ast_test.go @@ -537,6 +537,96 @@ func TestPrintForeignKeyAlterTable(t *testing.T) { } } +func TestPrintDefaultValue(t *testing.T) { + tests := []struct { + name string + dv DefaultValue + ty Type + expected string + }{ + { + name: "default value present", + dv: DefaultValue{ + IsPresent: true, + Value: Expression{Statement: "(`col1` + 1)"}, + }, + ty: Type{ + Name: "INT64", + }, + expected: " DEFAULT ((`col1` + 1))", + }, + { + name: "default value present", + dv: DefaultValue{ + IsPresent: true, + Value: Expression{Statement: "(`col1` + 1)"}, + }, + ty: Type{ + Name: "NUMERIC", + }, + expected: " DEFAULT (CAST((`col1` + 1) AS NUMERIC))", + }, + { + name: "empty default value", + dv: DefaultValue{}, + ty: Type{ + Name: "INT64", + }, + expected: "", + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + assert.Equal(t, tc.expected, tc.dv.PrintDefaultValue(tc.ty)) + }) + } +} + +func TestPGPrintDefaultValue(t *testing.T) { + tests := []struct { + name string + dv DefaultValue + ty Type + expected string + }{ + { + name: "default value present", + dv: DefaultValue{ + IsPresent: true, + Value: Expression{Statement: "(`col1` + 1)"}, + }, + ty: Type{ + Name: "INT64", + }, + expected: " DEFAULT ((`col1` + 1))", + }, + { + name: "default value present", + dv: DefaultValue{ + IsPresent: true, + Value: Expression{Statement: "(`col1` + 1)"}, + }, + ty: Type{ + Name: "NUMERIC", + }, + expected: " DEFAULT (CAST((`col1` + 1) AS NUMERIC))", + }, + { + name: "empty default value", + dv: DefaultValue{}, + ty: Type{ + Name: "INT64", + }, + expected: "", + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + assert.Equal(t, tc.expected, tc.dv.PGPrintDefaultValue(tc.ty)) + }) + } +} + func TestPrintAutoGenCol(t *testing.T) { tests := []struct { agc AutoGenCol diff --git a/webv2/api/schema.go b/webv2/api/schema.go index 1906e7894..b9f5ff876 100644 --- a/webv2/api/schema.go +++ b/webv2/api/schema.go @@ -1627,3 +1627,23 @@ func makeGoogleSqlDialectAutoGenMap(sequences map[string]ddl.Sequence) { } } } + +func uniqueAndSortTableIdName(tableIdName []types.TableIdAndName) []types.TableIdAndName { + uniqueMap := make(map[string]types.TableIdAndName) + for _, item := range tableIdName { + uniqueMap[item.Name] = item // Use Name as the unique key + } + + // Convert the map back to a slice + uniqueSlice := make([]types.TableIdAndName, 0, len(uniqueMap)) + for _, value := range uniqueMap { + uniqueSlice = append(uniqueSlice, value) + } + + // Sort the slice by Name + sort.Slice(uniqueSlice, func(i, j int) bool { + return uniqueSlice[i].Name < uniqueSlice[j].Name + }) + + return uniqueSlice +} diff --git a/webv2/table/utilities.go b/webv2/table/utilities.go index 76702b2ad..7494f8e53 100644 --- a/webv2/table/utilities.go +++ b/webv2/table/utilities.go @@ -16,9 +16,11 @@ package table import ( "fmt" + "regexp" "github.com/GoogleCloudPlatform/spanner-migration-tool/common/constants" "github.com/GoogleCloudPlatform/spanner-migration-tool/internal" + "github.com/GoogleCloudPlatform/spanner-migration-tool/sources/common" "github.com/GoogleCloudPlatform/spanner-migration-tool/spanner/ddl" "github.com/GoogleCloudPlatform/spanner-migration-tool/webv2/session" ) @@ -335,3 +337,38 @@ func getSequenceId(sequenceName string, spSeq map[string]ddl.Sequence) string { } return "" } + +// Add, deletes and updates default value associated with a column during edit column functionality +func UpdateDefaultValue(dv ddl.DefaultValue, tableId, colId string, conv *internal.Conv) { + col := conv.SpSchema[tableId].ColDefs[colId] + if !dv.IsPresent { + col.DefaultValue = ddl.DefaultValue{} + conv.SpSchema[tableId].ColDefs[colId] = col + return + } + + var expressionId string + if dv.Value.ExpressionId == "" { + if _, exists := conv.SrcSchema[tableId]; exists { + if column, exists := conv.SrcSchema[tableId].ColDefs[colId]; exists { + if column.DefaultValue.Value.ExpressionId != "" { + expressionId = column.DefaultValue.Value.ExpressionId + } + } + } + if expressionId != "" { + expressionId = internal.GenerateExpressionId() + } + } else { + expressionId = dv.Value.ExpressionId + } + re := regexp.MustCompile(`\([^)]*\)`) + col.DefaultValue = ddl.DefaultValue{ + Value: ddl.Expression{ + ExpressionId: expressionId, + Statement: common.SanitizeDefaultValue(dv.Value.Statement, col.T.Name, re.MatchString(dv.Value.Statement)), + }, + IsPresent: true, + } + conv.SpSchema[tableId].ColDefs[colId] = col +}