Skip to content

Commit

Permalink
Support check constraint backend (GoogleCloudPlatform#962)
Browse files Browse the repository at this point in the history
* Check constraint backend (#9)

Backend Support for Check Constraint

* update api

* fix PR comment

* remove api call to while validating constraints

* Fixed db collation regex to remove collation name from the results

* renamed function name to formatCheckConstraints and added check if constraint name is empty

* fixed PR comments

* added test case for the empty check constraint name

* fix: added regular exprression to match the exact column

* fix: added regular expression to replace table name

* Added test case for the column rename for check constraint

* 1. Refactored GetConstraint function
2. Fixed inforschema unit tests

* added comment at handling case for check constraints

* reverted white spaces

* reverted white spaces

* nit: doesCheckConstraintNameExist

* added comments for doesCheckConstraintNameExist

* PR and UT fixes

* fix UT

* UT fix

* Removed isCheckConstraintsTablePresent function

* moved regex globally

* Fix UT

* fixed UT

* fixed handling of the constraints

* removed unused function

* added unit tests for incompatable name

* Combined unit tests

* added test case for the renaming column having substring of other column

* added the query changes which return distinct value

---------

Co-authored-by: taherkl <taher.lakdawala@ollion.com>
Co-authored-by: Akash Thawait <aakash@ollion.com>
Co-authored-by: Vivek Yadav <vivek.yadav@ollion.com>
  • Loading branch information
4 people committed Dec 23, 2024
1 parent 34480de commit 84b09c2
Show file tree
Hide file tree
Showing 7 changed files with 87 additions and 106 deletions.
1 change: 0 additions & 1 deletion internal/mapping.go
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,6 @@ func ToSpannerCheckConstraintName(conv *Conv, srcCheckConstraintName string) str
return getSpannerValidName(conv, srcCheckConstraintName)
}


// conv.UsedNames tracks Spanner names that have been used for table names, foreign key constraints
// and indexes. We use this to ensure we generate unique names when
// we map from source dbs to Spanner since Spanner requires all these names to be
Expand Down
8 changes: 4 additions & 4 deletions sources/common/toddl.go
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,6 @@ func (ss *SchemaToSpannerImpl) SchemaToSpannerDDLHelper(conv *internal.Conv, tod
Comment: comment,
Id: srcTable.Id,
}

return nil
}

Expand Down Expand Up @@ -360,9 +359,10 @@ func cvtCheckConstraint(conv *internal.Conv, srcKeys []schema.CheckConstraint) [

for _, cc := range srcKeys {
spcc = append(spcc, ddl.CheckConstraint{
Id: cc.Id,
Name: internal.ToSpannerCheckConstraintName(conv, cc.Name),
Expr: cc.Expr,
Id: cc.Id,
Name: internal.ToSpannerCheckConstraintName(conv, cc.Name),
Expr: cc.Expr,
ExprId: cc.ExprId,
})
}
return spcc
Expand Down
40 changes: 28 additions & 12 deletions sources/common/toddl_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -438,14 +438,22 @@ func Test_cvtCheckContraint(t *testing.T) {
conv := internal.MakeConv()
srcSchema := []schema.CheckConstraint{
{
Id: "cc1",
Name: "check_1",
Expr: "age > 0",
Id: "cc1",
Name: "check_1",
Expr: "age > 0",
ExprId: "expr1",
},
{
Id: "cc2",
Name: "check_2",
Expr: "age < 99",
Id: "cc2",
Name: "check_2",
Expr: "age < 99",
ExprId: "expr2",
},
{
Id: "cc3",
Name: "@invalid_name", // incompatabile name
Expr: "age != 0",
ExprId: "expr3",
},
{
Id: "cc3",
Expand All @@ -455,14 +463,22 @@ func Test_cvtCheckContraint(t *testing.T) {
}
spSchema := []ddl.CheckConstraint{
{
Id: "cc1",
Name: "check_1",
Expr: "age > 0",
Id: "cc1",
Name: "check_1",
Expr: "age > 0",
ExprId: "expr1",
},
{
Id: "cc2",
Name: "check_2",
Expr: "age < 99",
ExprId: "expr2",
},
{
Id: "cc2",
Name: "check_2",
Expr: "age < 99",
Id: "cc3",
Name: "Ainvalid_name",
Expr: "age != 0",
ExprId: "expr3",
},
{
Id: "cc3",
Expand Down
2 changes: 1 addition & 1 deletion sources/mysql/infoschema.go
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ func (isi InfoSchemaImpl) processRow(
// Case added to handle check constraints
case "CHECK":
checkClause = collationRegex.ReplaceAllString(checkClause, "")
*checkKeys = append(*checkKeys, schema.CheckConstraint{Name: constraintName, Expr: checkClause, Id: internal.GenerateCheckConstrainstId()})
*checkKeys = append(*checkKeys, schema.CheckConstraint{Name: constraintName, Expr: checkClause, ExprId: internal.GenerateCheckConstrainstExprId(), Id: internal.GenerateCheckConstrainstId()})
default:
m[col] = append(m[col], constraintType)
}
Expand Down
2 changes: 1 addition & 1 deletion sources/mysql/infoschema_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -501,7 +501,7 @@ func TestProcessData_MultiCol(t *testing.T) {
}
internal.AssertSpSchema(conv, t, expectedSchema, stripSchemaComments(conv.SpSchema))
columnLevelIssues := map[string][]internal.SchemaIssue{
"c5": {
"c49": {
2,
},
}
Expand Down
34 changes: 0 additions & 34 deletions webv2/api/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -531,40 +531,6 @@ func UpdateCheckConstraint(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(convm)
}

// UpdateCheckConstraint processes the request to update spanner table check constraints, ensuring session and schema validity, and responds with the updated conversion metadata.
func UpdateCheckConstraint(w http.ResponseWriter, r *http.Request) {
tableId := r.FormValue("table")
reqBody, err := ioutil.ReadAll(r.Body)
if err != nil {
http.Error(w, fmt.Sprintf("Body Read Error : %v", err), http.StatusInternalServerError)
}
sessionState := session.GetSessionState()
if sessionState.Conv == nil || sessionState.Driver == "" {
http.Error(w, fmt.Sprintf("Schema is not converted or Driver is not configured properly. Please retry converting the database to Spanner."), http.StatusNotFound)
return
}
sessionState.Conv.ConvLock.Lock()
defer sessionState.Conv.ConvLock.Unlock()

newCKs := []ddl.CheckConstraint{}
if err = json.Unmarshal(reqBody, &newCKs); err != nil {
http.Error(w, fmt.Sprintf("Request Body parse error : %v", err), http.StatusBadRequest)
return
}

sp := sessionState.Conv.SpSchema[tableId]
sp.CheckConstraints = newCKs
sessionState.Conv.SpSchema[tableId] = sp
session.UpdateSessionFile()

convm := session.ConvWithMetadata{
SessionMetadata: sessionState.SessionMetadata,
Conv: *sessionState.Conv,
}
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(convm)
}

// findColId based on constraint condition it will return colId.
func findColId(colDefs map[string]ddl.ColumnDef, condition string) string {
for _, colDef := range colDefs {
Expand Down
106 changes: 53 additions & 53 deletions webv2/api/schema_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2550,83 +2550,83 @@ func TestGetAutoGenMapMySQL(t *testing.T) {
}

func TestUpdateCheckConstraint(t *testing.T) {
t.Run("ValidCheckConstraints", func(t *testing.T) {
sessionState := session.GetSessionState()
sessionState.Driver = constants.MYSQL
sessionState.Conv = internal.MakeConv()
t.Run("ValidCheckConstraints", func(t *testing.T) {
sessionState := session.GetSessionState()
sessionState.Driver = constants.MYSQL
sessionState.Conv = internal.MakeConv()

tableID := "table1"
tableID := "table1"

expectedCheckConstraint := []ddl.CheckConstraint{
{Id: "cc1", Name: "check_1", Expr: "(age > 18)"},
{Id: "cc2", Name: "check_2", Expr: "(age < 99)"},
}
expectedCheckConstraint := []ddl.CheckConstraint{
{Id: "cc1", Name: "check_1", Expr: "(age > 18)"},
{Id: "cc2", Name: "check_2", Expr: "(age < 99)"},
}

checkConstraints := []schema.CheckConstraint{
{Id: "cc1", Name: "check_1", Expr: "(age > 18)"},
{Id: "cc2", Name: "check_2", Expr: "(age < 99)"},
}
checkConstraints := []schema.CheckConstraint{
{Id: "cc1", Name: "check_1", Expr: "(age > 18)"},
{Id: "cc2", Name: "check_2", Expr: "(age < 99)"},
}

body, err := json.Marshal(checkConstraints)
assert.NoError(t, err)
body, err := json.Marshal(checkConstraints)
assert.NoError(t, err)

req, err := http.NewRequest("POST", "update/cc", bytes.NewBuffer(body))
assert.NoError(t, err)
req, err := http.NewRequest("POST", "update/cc", bytes.NewBuffer(body))
assert.NoError(t, err)

q := req.URL.Query()
q.Add("table", tableID)
req.URL.RawQuery = q.Encode()
q := req.URL.Query()
q.Add("table", tableID)
req.URL.RawQuery = q.Encode()

rr := httptest.NewRecorder()
handler := http.HandlerFunc(api.UpdateCheckConstraint)
handler.ServeHTTP(rr, req)
rr := httptest.NewRecorder()
handler := http.HandlerFunc(api.UpdateCheckConstraint)
handler.ServeHTTP(rr, req)

assert.Equal(t, http.StatusOK, rr.Code)
assert.Equal(t, http.StatusOK, rr.Code)

updatedSp := sessionState.Conv.SpSchema[tableID]
assert.Equal(t, expectedCheckConstraint, updatedSp.CheckConstraints)
})
updatedSp := sessionState.Conv.SpSchema[tableID]
assert.Equal(t, expectedCheckConstraint, updatedSp.CheckConstraints)
})

t.Run("ParseError", func(t *testing.T) {
sessionState := session.GetSessionState()
sessionState.Driver = constants.MYSQL
sessionState.Conv = internal.MakeConv()
t.Run("ParseError", func(t *testing.T) {
sessionState := session.GetSessionState()
sessionState.Driver = constants.MYSQL
sessionState.Conv = internal.MakeConv()

invalidJSON := "invalid json body"
invalidJSON := "invalid json body"

rr := httptest.NewRecorder()
req, err := http.NewRequest("POST", "update/cc", io.NopCloser(strings.NewReader(invalidJSON)))
assert.NoError(t, err)
rr := httptest.NewRecorder()
req, err := http.NewRequest("POST", "update/cc", io.NopCloser(strings.NewReader(invalidJSON)))
assert.NoError(t, err)

handler := http.HandlerFunc(api.UpdateCheckConstraint)
handler.ServeHTTP(rr, req)
handler := http.HandlerFunc(api.UpdateCheckConstraint)
handler.ServeHTTP(rr, req)

assert.Equal(t, http.StatusBadRequest, rr.Code)
assert.Equal(t, http.StatusBadRequest, rr.Code)

expectedErrorMessage := "Request Body parse error"
assert.Contains(t, rr.Body.String(), expectedErrorMessage)
})
expectedErrorMessage := "Request Body parse error"
assert.Contains(t, rr.Body.String(), expectedErrorMessage)
})

t.Run("ImproperSession", func(t *testing.T) {
sessionState := session.GetSessionState()
sessionState.Conv = nil // Simulate no conversion
t.Run("ImproperSession", func(t *testing.T) {
sessionState := session.GetSessionState()
sessionState.Conv = nil // Simulate no conversion

rr := httptest.NewRecorder()
req, err := http.NewRequest("POST", "update/cc", io.NopCloser(errReader{}))
assert.NoError(t, err)
rr := httptest.NewRecorder()
req, err := http.NewRequest("POST", "update/cc", io.NopCloser(errReader{}))
assert.NoError(t, err)

handler := http.HandlerFunc(api.UpdateCheckConstraint)
handler.ServeHTTP(rr, req)
handler := http.HandlerFunc(api.UpdateCheckConstraint)
handler.ServeHTTP(rr, req)

assert.Equal(t, http.StatusInternalServerError, rr.Code)
assert.Contains(t, rr.Body.String(), "Schema is not converted or Driver is not configured properly")
})
assert.Equal(t, http.StatusInternalServerError, rr.Code)
assert.Contains(t, rr.Body.String(), "Schema is not converted or Driver is not configured properly")
})
}

type errReader struct{}

func (errReader) Read(p []byte) (n int, err error) {
return 0, fmt.Errorf("simulated read error")
return 0, fmt.Errorf("simulated read error")
}

func TestVerifyCheckConstraintExpressions(t *testing.T) {
Expand Down

0 comments on commit 84b09c2

Please sign in to comment.