Skip to content

Commit

Permalink
feat: APIs for Backend Changes for Default Values (GoogleCloudPlatfor…
Browse files Browse the repository at this point in the history
…m#965)

* backend apis

* linting

* comment changes

* comment changes
  • Loading branch information
asthamohta authored and taherkl committed Dec 20, 2024
1 parent 10391eb commit 971a2de
Show file tree
Hide file tree
Showing 18 changed files with 784 additions and 26 deletions.
1 change: 1 addition & 0 deletions common/constants/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ const (
// VerifyExpressions API
CHECK_EXPRESSION = "CHECK"
DEFAUT_EXPRESSION = "DEFAULT"
TEMP_DB = "smt-staging-db"

// Regex for matching database collation
DB_COLLATION_REGEX = `(_[a-zA-Z0-9]+\\|\\)`
Expand Down
6 changes: 6 additions & 0 deletions conversion/conversion_from_source.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,9 @@ type DataFromSourceImpl struct{}
func (sads *SchemaFromSourceImpl) schemaFromDatabase(migrationProjectId string, sourceProfile profiles.SourceProfile, targetProfile profiles.TargetProfile, getInfo GetInfoInterface, processSchema common.ProcessSchemaInterface) (*internal.Conv, error) {
conv := internal.MakeConv()
conv.SpDialect = targetProfile.Conn.Sp.Dialect
conv.SpProjectId = targetProfile.Conn.Sp.Project
conv.SpInstanceId = targetProfile.Conn.Sp.Instance
conv.Source = sourceProfile.Driver
//handle fetching schema differently for sharded migrations, we only connect to the primary shard to
//fetch the schema. We reuse the SourceProfileConnection object for this purpose.
var infoSchema common.InfoSchema
Expand Down Expand Up @@ -159,6 +162,9 @@ func (sads *DataFromSourceImpl) dataFromCSV(ctx context.Context, sourceProfile p
return nil, fmt.Errorf("dbName is mandatory in target-profile for csv source")
}
conv.SpDialect = targetProfile.Conn.Sp.Dialect
conv.SpProjectId = targetProfile.Conn.Sp.Project
conv.SpInstanceId = targetProfile.Conn.Sp.Instance
conv.Source = sourceProfile.Driver
dialect, err := targetProfile.FetchTargetDialect(ctx)
if err != nil {
return nil, fmt.Errorf("could not fetch dialect: %v", err)
Expand Down
108 changes: 105 additions & 3 deletions expressions_api/expression_verify.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"fmt"
"sync"

spannerclient "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/clients/spanner/client"
spanneraccessor "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/spanner"
"github.com/GoogleCloudPlatform/spanner-migration-tool/common/constants"
"github.com/GoogleCloudPlatform/spanner-migration-tool/common/task"
Expand All @@ -18,22 +19,50 @@ const THREAD_POOL = 500
// ExpressionVerificationAccessor is the contract for verifying expressions and
// for re-pointing the verifier's Spanner client at another project/instance.
type ExpressionVerificationAccessor interface {
	// VerifyExpressions is the batch API: it parallelizes the individual
	// expression verification calls and returns their combined output.
	VerifyExpressions(ctx context.Context, verifyExpressionsInput internal.VerifyExpressionsInput) internal.VerifyExpressionsOutput
	// RefreshSpannerClient switches the Spanner client over to the given
	// project and instance, returning an error if that fails.
	RefreshSpannerClient(ctx context.Context, project, instance string) error
}

// ExpressionVerificationAccessorImpl is the ExpressionVerificationAccessor
// implementation defined in this file.
type ExpressionVerificationAccessorImpl struct {
// SpannerAccessor is the Spanner accessor this verifier works through; its
// SpannerClient field is replaced by RefreshSpannerClient.
SpannerAccessor *spanneraccessor.SpannerAccessorImpl
}

func NewExpressionVerificationAccessorImpl(ctx context.Context, project string, instance string) (*ExpressionVerificationAccessorImpl, error) {
spannerAccessor, err := spanneraccessor.NewSpannerAccessorClientImplWithSpannerClient(ctx, fmt.Sprintf("projects/%s/instances/%s/databases/%s", project, instance, "smt-staging-db"))
if err != nil {
return nil, err
var spannerAccessor *spanneraccessor.SpannerAccessorImpl
var err error
if project != "" && instance != "" {
spannerAccessor, err = spanneraccessor.NewSpannerAccessorClientImplWithSpannerClient(ctx, fmt.Sprintf("projects/%s/instances/%s/databases/%s", project, instance, constants.TEMP_DB))
if err != nil {
return nil, err
}
} else {
spannerAccessor, err = spanneraccessor.NewSpannerAccessorClientImpl(ctx)
if err != nil {
return nil, err
}
}
return &ExpressionVerificationAccessorImpl{
SpannerAccessor: spannerAccessor,
}, nil
}

// DDLVerifier groups the APIs used to verify and process Spanner DDL features
// such as default values and check constraints.
type DDLVerifier interface {
	// VerifySpannerDDL verifies the given expressions for conv and returns the
	// verification output together with an error.
	VerifySpannerDDL(conv *internal.Conv, expressionDetails []internal.ExpressionDetail) (internal.VerifyExpressionsOutput, error)
	// GetSourceExpressionDetails collects expression details from the source
	// schema of the listed tables.
	GetSourceExpressionDetails(conv *internal.Conv, tableIds []string) []internal.ExpressionDetail
	// GetSpannerExpressionDetails collects expression details from the Spanner
	// schema of the listed tables.
	GetSpannerExpressionDetails(conv *internal.Conv, tableIds []string) []internal.ExpressionDetail
	// RefreshSpannerClient switches the verifier's Spanner client over to the
	// given project and instance.
	RefreshSpannerClient(ctx context.Context, project, instance string) error
}
// DDLVerifierImpl is the DDLVerifier implementation defined in this file.
type DDLVerifierImpl struct {
// Expressions is the accessor that VerifySpannerDDL and RefreshSpannerClient
// delegate to.
Expressions ExpressionVerificationAccessor
}

// NewDDLVerifierImpl creates a DDLVerifierImpl backed by a freshly constructed
// ExpressionVerificationAccessorImpl for the given project and instance.
//
// It returns nil and the error when the accessor cannot be created.
func NewDDLVerifierImpl(ctx context.Context, project string, instance string) (*DDLVerifierImpl, error) {
	expVerifier, err := NewExpressionVerificationAccessorImpl(ctx, project, instance)
	if err != nil {
		// Never store the nil *ExpressionVerificationAccessorImpl in the
		// interface-typed field: the interface would compare non-nil yet hold
		// a nil pointer, so nil checks on Expressions would pass while it is
		// unusable.
		return nil, err
	}
	return &DDLVerifierImpl{
		Expressions: expVerifier,
	}, nil
}

func (ev *ExpressionVerificationAccessorImpl) VerifyExpressions(ctx context.Context, verifyExpressionsInput internal.VerifyExpressionsInput) internal.VerifyExpressionsOutput {
err := ev.validateRequest(verifyExpressionsInput)
if err != nil {
Expand Down Expand Up @@ -79,6 +108,15 @@ func (ev *ExpressionVerificationAccessorImpl) VerifyExpressions(ctx context.Cont
return verifyExpressionsOutput
}

// RefreshSpannerClient creates a new Spanner client for the staging database
// (constants.TEMP_DB) of the given project and instance and installs it on the
// accessor. If the client cannot be created, the current client is left in
// place and the error is returned.
func (ev *ExpressionVerificationAccessorImpl) RefreshSpannerClient(ctx context.Context, project string, instance string) error {
	dbURI := fmt.Sprintf("projects/%s/instances/%s/databases/%s", project, instance, constants.TEMP_DB)
	refreshedClient, err := spannerclient.NewSpannerClientImpl(ctx, dbURI)
	if err != nil {
		return err
	}
	ev.SpannerAccessor.SpannerClient = refreshedClient
	return nil
}

func (ev *ExpressionVerificationAccessorImpl) verifyExpressionInternal(expressionDetail internal.ExpressionDetail, mutex *sync.Mutex) task.TaskResult[internal.ExpressionVerificationOutput] {
var sqlStatement string
switch expressionDetail.Type {
Expand Down Expand Up @@ -129,3 +167,67 @@ func (ev *ExpressionVerificationAccessorImpl) removeExpressions(inputConv *inter
}
return convCopy, nil
}

// VerifySpannerDDL runs the given expressions through the configured
// expression accessor, using conv and conv.Source as verification input.
//
// The accessor's output is returned as-is; its Err field is also surfaced as
// the error return. A background context is used because this method does not
// take one.
func (ddlv *DDLVerifierImpl) VerifySpannerDDL(conv *internal.Conv, expressionDetails []internal.ExpressionDetail) (internal.VerifyExpressionsOutput, error) {
	input := internal.VerifyExpressionsInput{
		Conv:                 conv,
		Source:               conv.Source,
		ExpressionDetailList: expressionDetails,
	}
	output := ddlv.Expressions.VerifyExpressions(context.Background(), input)
	return output, output.Err
}

// GetSourceExpressionDetails walks the source schema of the given tables and
// returns one "DEFAULT" expression detail for every source column that has a
// default value.
//
// The reference element name is the Spanner type name of the column with the
// same table and column id; the table and column ids are recorded in Metadata.
// The result is an empty, non-nil slice when nothing matches.
func (ddlv *DDLVerifierImpl) GetSourceExpressionDetails(conv *internal.Conv, tableIds []string) []internal.ExpressionDetail {
	details := make([]internal.ExpressionDetail, 0)
	for _, tableId := range tableIds {
		srcTable := conv.SrcSchema[tableId]
		for _, colId := range srcTable.ColIds {
			defaultValue := srcTable.ColDefs[colId].DefaultValue
			if !defaultValue.IsPresent {
				// Only columns with a default value need verification.
				continue
			}
			details = append(details, internal.ExpressionDetail{
				ReferenceElement: internal.ReferenceElement{
					Name: conv.SpSchema[tableId].ColDefs[colId].T.Name,
				},
				ExpressionId: defaultValue.Value.ExpressionId,
				Expression:   defaultValue.Value.Statement,
				Type:         "DEFAULT",
				Metadata:     map[string]string{"TableId": tableId, "ColId": colId},
			})
		}
	}
	return details
}

// GetSpannerExpressionDetails walks the Spanner schema of the given tables and
// returns one "DEFAULT" expression detail for every Spanner column that has a
// default value.
//
// The reference element name is the column's Spanner type name; the table and
// column ids are recorded in Metadata. The result is an empty, non-nil slice
// when nothing matches.
func (ddlv *DDLVerifierImpl) GetSpannerExpressionDetails(conv *internal.Conv, tableIds []string) []internal.ExpressionDetail {
	details := make([]internal.ExpressionDetail, 0)
	for _, tableId := range tableIds {
		spTable := conv.SpSchema[tableId]
		for _, colId := range spTable.ColIds {
			column := spTable.ColDefs[colId]
			if !column.DefaultValue.IsPresent {
				// Only columns with a default value need verification.
				continue
			}
			details = append(details, internal.ExpressionDetail{
				ReferenceElement: internal.ReferenceElement{
					Name: column.T.Name,
				},
				ExpressionId: column.DefaultValue.Value.ExpressionId,
				Expression:   column.DefaultValue.Value.Statement,
				Type:         "DEFAULT",
				Metadata:     map[string]string{"TableId": tableId, "ColId": colId},
			})
		}
	}
	return details
}

// RefreshSpannerClient forwards the refresh request to the underlying
// expression accessor and returns its error unchanged.
func (ddlv *DDLVerifierImpl) RefreshSpannerClient(ctx context.Context, project string, instance string) error {
	accessor := ddlv.Expressions
	return accessor.RefreshSpannerClient(ctx, project, instance)
}
187 changes: 185 additions & 2 deletions expressions_api/expression_verify_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ import (
"github.com/GoogleCloudPlatform/spanner-migration-tool/expressions_api"
"github.com/GoogleCloudPlatform/spanner-migration-tool/internal"
"github.com/GoogleCloudPlatform/spanner-migration-tool/logger"
"github.com/GoogleCloudPlatform/spanner-migration-tool/schema"
"github.com/GoogleCloudPlatform/spanner-migration-tool/spanner/ddl"
"github.com/googleapis/gax-go/v2"
"github.com/stretchr/testify/assert"
"go.uber.org/zap"
Expand All @@ -32,8 +34,8 @@ func TestVerifyExpressions(t *testing.T) {
conv := internal.MakeConv()
ReadSessionFile(conv, "../../test_data/session_expression_verify.json")
input := internal.VerifyExpressionsInput{
Conv: conv,
Source: "mysql",
Conv: conv,
Source: "mysql",
ExpressionDetailList: []internal.ExpressionDetail{
{
Expression: "id > 10",
Expand Down Expand Up @@ -297,3 +299,184 @@ func ReadSessionFile(conv *internal.Conv, sessionJSON string) error {
}
return nil
}

// TestVerifySpannerDDL checks that VerifySpannerDDL surfaces the Err field of
// the accessor's output as its error return, using a mocked accessor.
//
// Each case runs as a named subtest so a failure is attributed to the case
// that produced it.
func TestVerifySpannerDDL(t *testing.T) {
	conv := *internal.MakeConv()
	testCases := []struct {
		name                 string
		conv                 internal.Conv
		expressionDetails    []internal.ExpressionDetail
		verifyExpressionMock expressions_api.MockExpressionVerificationAccessor
		errorExpected        bool
	}{
		{
			name:              "no error flow",
			conv:              conv,
			expressionDetails: []internal.ExpressionDetail{},
			verifyExpressionMock: expressions_api.MockExpressionVerificationAccessor{
				VerifyExpressionsMock: func(ctx context.Context, verifyExpressionsInput internal.VerifyExpressionsInput) internal.VerifyExpressionsOutput {
					return internal.VerifyExpressionsOutput{
						ExpressionVerificationOutputList: []internal.ExpressionVerificationOutput{},
						Err:                              nil,
					}
				},
			},
			errorExpected: false,
		},
		{
			name:              "error flow",
			conv:              conv,
			expressionDetails: []internal.ExpressionDetail{},
			verifyExpressionMock: expressions_api.MockExpressionVerificationAccessor{
				VerifyExpressionsMock: func(ctx context.Context, verifyExpressionsInput internal.VerifyExpressionsInput) internal.VerifyExpressionsOutput {
					return internal.VerifyExpressionsOutput{
						ExpressionVerificationOutputList: []internal.ExpressionVerificationOutput{},
						Err:                              fmt.Errorf("error"),
					}
				},
			},
			errorExpected: true,
		},
	}

	for _, tc := range testCases {
		// Subtests run synchronously (no t.Parallel), so taking the address of
		// tc's fields inside the closure is safe on any Go version.
		t.Run(tc.name, func(t *testing.T) {
			ddlV := expressions_api.DDLVerifierImpl{
				Expressions: &tc.verifyExpressionMock,
			}
			_, err := ddlV.VerifySpannerDDL(&tc.conv, tc.expressionDetails)
			assert.Equal(t, tc.errorExpected, err != nil)
		})
	}
}

func TestGetSourceExpressionDetails(t *testing.T) {
conv := internal.MakeConv()
conv.SrcSchema = map[string]schema.Table{
"table1": {
ColIds: []string{"col1", "col2"},
ColDefs: map[string]schema.Column{
"col1": {
DefaultValue: ddl.DefaultValue{
IsPresent: true,
Value: ddl.Expression{
ExpressionId: "expr1",
Statement: "SELECT 1",
},
},
},
"col2": {
DefaultValue: ddl.DefaultValue{},
},
},
},
}
conv.SpSchema = ddl.Schema{
"table1": {
ColDefs: map[string]ddl.ColumnDef{
"col1": {
T: ddl.Type{
Name: "INT64",
},
},
},
},
}

testCases := []struct {
name string
conv *internal.Conv
tableIds []string
expectedDetails []internal.ExpressionDetail
}{
{
name: "single table with default value",
conv: conv,
tableIds: []string{"table1"},
expectedDetails: []internal.ExpressionDetail{
{
ReferenceElement: internal.ReferenceElement{
Name: "INT64",
},
ExpressionId: "expr1",
Expression: "SELECT 1",
Type: "DEFAULT",
Metadata: map[string]string{"TableId": "table1", "ColId": "col1"},
},
},
},
{
name: "no tables",
conv: conv,
tableIds: []string{},
expectedDetails: []internal.ExpressionDetail{},
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
ddlv := &expressions_api.DDLVerifierImpl{}
actualDetails := ddlv.GetSourceExpressionDetails(tc.conv, tc.tableIds)
assert.Equal(t, tc.expectedDetails, actualDetails)
})
}
}

func TestGetSpannerExpressionDetails(t *testing.T) {
conv := internal.MakeConv()
conv.SpSchema = ddl.Schema{
"table1": {
ColIds: []string{"col1", "col2"},
ColDefs: map[string]ddl.ColumnDef{
"col1": {
DefaultValue: ddl.DefaultValue{
IsPresent: true,
Value: ddl.Expression{
ExpressionId: "expr1",
Statement: "SELECT 1",
},
},
},
"col2": {
DefaultValue: ddl.DefaultValue{},
},
},
},
}

testCases := []struct {
name string
conv *internal.Conv
tableIds []string
expectedDetails []internal.ExpressionDetail
}{
{
name: "single table with default value",
conv: conv,
tableIds: []string{"table1"},
expectedDetails: []internal.ExpressionDetail{
{
ReferenceElement: internal.ReferenceElement{
Name: conv.SpSchema["table1"].ColDefs["col1"].T.Name,
},
ExpressionId: "expr1",
Expression: "SELECT 1",
Type: "DEFAULT",
Metadata: map[string]string{"TableId": "table1", "ColId": "col1"},
},
},
},
{
name: "no tables",
conv: conv,
tableIds: []string{},
expectedDetails: []internal.ExpressionDetail{},
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
ddlv := &expressions_api.DDLVerifierImpl{}
actualDetails := ddlv.GetSpannerExpressionDetails(tc.conv, tc.tableIds)
assert.Equal(t, tc.expectedDetails, actualDetails)
})
}
}
Loading

0 comments on commit 971a2de

Please sign in to comment.