From a947cc84e4005a2f3d63553a8d035ab2599456bf Mon Sep 17 00:00:00 2001 From: Nhat Date: Wed, 20 Nov 2024 15:29:21 +0700 Subject: [PATCH] Support revert old column attributes (#58) Support revert old column attributes --- README.md | 2 + README_zh.md | 10 +++-- avro/builder.go | 6 +-- element/column.go | 88 +++++++++++++++++++++++++++++----------- element/migration.go | 2 +- element/node.go | 2 + element/table.go | 19 +++++---- sql-parser/mysql.go | 42 +++++++++++-------- sql-parser/postgresql.go | 26 +++++++----- sql-parser/sqlite.go | 12 ++++-- 10 files changed, 138 insertions(+), 71 deletions(-) diff --git a/README.md b/README.md index eb3a6eb..959679b 100644 --- a/README.md +++ b/README.md @@ -237,6 +237,8 @@ CREATE TABLE user ( //CREATE UNIQUE INDEX `idx_name_age` ON `user`(`name`, `age`); println(sql1.StringDown()) + //ALTER TABLE `user` MODIFY COLUMN `id` int(11); + //ALTER TABLE `user` MODIFY COLUMN `updated_at` datetime; //DROP INDEX `idx_name_age` ON `user`; } ``` \ No newline at end of file diff --git a/README_zh.md b/README_zh.md index babf143..0a7f790 100644 --- a/README_zh.md +++ b/README_zh.md @@ -232,11 +232,13 @@ CREATE TABLE user ( sql1.Diff(*sql2) println(sql1.StringUp()) - // ALTER TABLE `user` MODIFY COLUMN `id` int(11) AUTO_INCREMENT PRIMARY KEY; - // ALTER TABLE `user` MODIFY COLUMN `updated_at` datetime DEFAULT CURRENT_TIMESTAMP() ON UPDATE CURRENT_TIMESTAMP(); - // CREATE UNIQUE INDEX `idx_name_age` ON `user`(`name`, `age`); + //ALTER TABLE `user` MODIFY COLUMN `id` int(11) AUTO_INCREMENT PRIMARY KEY; + //ALTER TABLE `user` MODIFY COLUMN `updated_at` datetime DEFAULT CURRENT_TIMESTAMP() ON UPDATE CURRENT_TIMESTAMP(); + //CREATE UNIQUE INDEX `idx_name_age` ON `user`(`name`, `age`); println(sql1.StringDown()) - // DROP INDEX `idx_name_age` ON `user`; + //ALTER TABLE `user` MODIFY COLUMN `id` int(11); + //ALTER TABLE `user` MODIFY COLUMN `updated_at` datetime; + //DROP INDEX `idx_name_age` ON `user`; } ``` \ No newline at end of file diff --git a/avro/builder.go b/avro/builder.go index 62946a8..a7b593d 100644 --- a/avro/builder.go +++ b/avro/builder.go @@ -59,19 +59,19 @@ func getAvroType(col element.Column) interface{} { "type": "string", "connect.version": 1, "connect.parameters": map[string]string{ - "allowed": strings.Join(col.MysqlType.Elems, ","), + "allowed": strings.Join(col.CurrentAttr.MysqlType.Elems, ","), }, "connect.default": "init", "connect.name": "io.debezium.data.Enum", } } - switch col.MysqlType.EvalType() { + switch col.CurrentAttr.MysqlType.EvalType() { case types.ETInt: return "int" case types.ETDecimal: - displayFlen, displayDecimal := col.MysqlType.Flen, col.MysqlType.Decimal + displayFlen, displayDecimal := col.CurrentAttr.MysqlType.Flen, col.CurrentAttr.MysqlType.Decimal return map[string]interface{}{ "type": "bytes", "scale": displayDecimal, diff --git a/element/column.go b/element/column.go index 6c126de..6609b15 100644 --- a/element/column.go +++ b/element/column.go @@ -22,9 +22,7 @@ const ( LowerRestoreFlag = format.RestoreStringSingleQuotes | format.RestoreKeyWordLowercase | format.RestoreNameLowercase | format.RestoreNameBackQuotes ) -// Column ... -type Column struct { - Node +type SqlAttr struct { MysqlType *types.FieldType PgType *ptypes.T LiteType *sqlite.Type @@ -32,10 +30,18 @@ type Column struct { Comment string } +// Column ... +type Column struct { + Node + + CurrentAttr SqlAttr + PreviousAttr SqlAttr +} + // GetType ... func (c Column) GetType() byte { - if c.MysqlType != nil { - return c.MysqlType.Tp + if c.CurrentAttr.MysqlType != nil { + return c.CurrentAttr.MysqlType.Tp } return 0 @@ -43,7 +49,7 @@ func (c Column) GetType() byte { // HasDefaultValue ... func (c Column) HasDefaultValue() bool { - for _, opt := range c.Options { + for _, opt := range c.CurrentAttr.Options { if opt.Tp == ast.ColumnOptionDefaultValue { return true } @@ -54,7 +60,7 @@ func (c Column) HasDefaultValue() bool { func (c Column) hashValue() string { strHash := sql.EscapeSqlName(c.Name) - strHash += c.typeDefinition() + strHash += c.typeDefinition(false) hash := md5.Sum([]byte(strHash)) return hex.EncodeToString(hash[:]) } @@ -71,7 +77,7 @@ func (c Column) migrationUp(tbName, after string, ident int) []string { strSql += strings.Repeat(" ", ident-len(c.Name)) } - strSql += c.definition() + strSql += c.definition(false) if ident < 0 { if after != "" { @@ -90,10 +96,27 @@ func (c Column) migrationUp(tbName, after string, ident int) []string { return []string{fmt.Sprintf(sql.AlterTableDropColumnStm(), sql.EscapeSqlName(tbName), sql.EscapeSqlName(c.Name))} case MigrateModifyAction: - def := strings.Replace(c.definition(), sql.PrimaryOption(), "", 1) + def, isPk := c.pkDefinition(false) + if isPk { + if _, isPrevPk := c.pkDefinition(true); isPrevPk { + // avoid repeat define primary key + def = strings.Replace(def, " "+sql.PrimaryOption(), "", 1) + } + } return []string{fmt.Sprintf(sql.AlterTableModifyColumnStm(), sql.EscapeSqlName(tbName), sql.EscapeSqlName(c.Name)+def)} + case MigrateRevertAction: + prevDef, isPrevPk := c.pkDefinition(true) + if isPrevPk { + if _, isPk := c.pkDefinition(false); isPk { + // avoid repeat define primary key + prevDef = strings.Replace(prevDef, " "+sql.PrimaryOption(), "", 1) + } + } + + return []string{fmt.Sprintf(sql.AlterTableModifyColumnStm(), sql.EscapeSqlName(tbName), sql.EscapeSqlName(c.Name)+prevDef)} + case MigrateRenameAction: return []string{fmt.Sprintf(sql.AlterTableRenameColumnStm(), sql.EscapeSqlName(tbName), sql.EscapeSqlName(c.OldName), sql.EscapeSqlName(c.Name))} @@ -103,12 +126,12 @@ func (c Column) migrationUp(tbName, after string, ident int) []string { } func (c Column) migrationCommentUp(tbName string) []string { - if c.Comment == "" || sql.GetDialect() != sql_templates.PostgresDialect { + if c.CurrentAttr.Comment == "" || sql.GetDialect() != sql_templates.PostgresDialect { return nil } // apply for postgres only - return []string{fmt.Sprintf(sql.ColumnComment(), tbName, c.Name, c.Comment)} + return []string{fmt.Sprintf(sql.ColumnComment(), tbName, c.Name, c.CurrentAttr.Comment)} } func (c Column) migrationDown(tbName, after string) []string { @@ -123,7 +146,7 @@ func (c Column) migrationDown(tbName, after string) []string { c.Action = MigrateAddAction case MigrateModifyAction: - return nil + c.Action = MigrateRevertAction case MigrateRenameAction: c.Name, c.OldName = c.OldName, c.Name @@ -135,10 +158,19 @@ func (c Column) migrationDown(tbName, after string) []string { return c.migrationUp(tbName, after, -1) } -func (c Column) definition() string { - strSql := c.typeDefinition() +func (c Column) pkDefinition(isPrev bool) (string, bool) { + attr := c.CurrentAttr + if isPrev { + attr = c.PreviousAttr + } + strSql := c.typeDefinition(isPrev) + + isPrimaryKey := false + for _, opt := range attr.Options { + if opt.Tp == ast.ColumnOptionPrimaryKey { + isPrimaryKey = true + } - for _, opt := range c.Options { b := bytes.NewBufferString("") var ctx *format.RestoreCtx @@ -157,17 +189,27 @@ func (c Column) definition() string { strSql += " " + b.String() } - return strSql + return strSql, isPrimaryKey +} + +func (c Column) definition(isPrev bool) string { + def, _ := c.pkDefinition(isPrev) + return def } -func (c Column) typeDefinition() string { +func (c Column) typeDefinition(isPrev bool) string { + attr := c.CurrentAttr + if isPrev { + attr = c.PreviousAttr + } + switch { - case sql.IsPostgres() && c.PgType != nil: - return " " + c.PgType.SQLString() - case sql.IsSqlite() && c.LiteType != nil: - return " " + c.LiteType.Name.Name - case c.MysqlType != nil: - return " " + c.MysqlType.String() + case sql.IsPostgres() && attr.PgType != nil: + return " " + attr.PgType.SQLString() + case sql.IsSqlite() && attr.LiteType != nil: + return " " + attr.LiteType.Name.Name + case attr.MysqlType != nil: + return " " + attr.MysqlType.String() } return "" // column type is empty diff --git a/element/migration.go b/element/migration.go index 3e330e5..6828aac 100644 --- a/element/migration.go +++ b/element/migration.go @@ -150,7 +150,7 @@ func (m *Migration) AddComment(tbName, colName, comment string) { return } - m.Tables[id].Columns[colIdx].Comment = comment + m.Tables[id].Columns[colIdx].CurrentAttr.Comment = comment } // AddIndex ... diff --git a/element/node.go b/element/node.go index 83ea0c0..5253188 100644 --- a/element/node.go +++ b/element/node.go @@ -12,6 +12,8 @@ const ( MigrateRemoveAction // MigrateModifyAction ... MigrateModifyAction + // MigrateRevertAction ... + MigrateRevertAction // MigrateRenameAction ... MigrateRenameAction ) diff --git a/element/table.go b/element/table.go index 417d27c..314db19 100644 --- a/element/table.go +++ b/element/table.go @@ -62,18 +62,18 @@ func (t *Table) AddColumn(col Column) { t.Columns[id] = col default: - t.Columns[id].Options = append(t.Columns[id].Options, col.Options...) + t.Columns[id].CurrentAttr.Options = append(t.Columns[id].CurrentAttr.Options, col.CurrentAttr.Options...) - if size := len(t.Columns[id].Options); size > 0 { - for i := range t.Columns[id].Options[:size-1] { - if t.Columns[id].Options[i].Tp == ast.ColumnOptionPrimaryKey { - t.Columns[id].Options[i], t.Columns[id].Options[size-1] = t.Columns[id].Options[size-1], t.Columns[id].Options[i] + if size := len(t.Columns[id].CurrentAttr.Options); size > 0 { + for i := range t.Columns[id].CurrentAttr.Options[:size-1] { + if t.Columns[id].CurrentAttr.Options[i].Tp == ast.ColumnOptionPrimaryKey { + t.Columns[id].CurrentAttr.Options[i], t.Columns[id].CurrentAttr.Options[size-1] = t.Columns[id].CurrentAttr.Options[size-1], t.Columns[id].CurrentAttr.Options[i] break } } } - t.Columns[id].MysqlType = col.MysqlType + t.Columns[id].CurrentAttr.MysqlType = col.CurrentAttr.MysqlType return } @@ -291,10 +291,11 @@ func (t *Table) Diff(old Table) { for i := range t.Columns { if j := old.getIndexColumn(t.Columns[i].Name); t.Columns[i].Action == MigrateAddAction && j >= 0 && old.Columns[j].Action != MigrateNoAction { - if hasChangedMysqlOptions(t.Columns[i].Options, old.Columns[j].Options) || - hasChangedMysqlType(t.Columns[i].MysqlType, old.Columns[j].MysqlType) || - hasChangePostgresType(t.Columns[i].PgType, old.Columns[j].PgType) { + if hasChangedMysqlOptions(t.Columns[i].CurrentAttr.Options, old.Columns[j].CurrentAttr.Options) || + hasChangedMysqlType(t.Columns[i].CurrentAttr.MysqlType, old.Columns[j].CurrentAttr.MysqlType) || + hasChangePostgresType(t.Columns[i].CurrentAttr.PgType, old.Columns[j].CurrentAttr.PgType) { t.Columns[i].Action = MigrateModifyAction + t.Columns[i].PreviousAttr = old.Columns[j].CurrentAttr } else { t.Columns[i].Action = MigrateNoAction } diff --git a/sql-parser/mysql.go b/sql-parser/mysql.go index cf440e1..999a0da 100644 --- a/sql-parser/mysql.go +++ b/sql-parser/mysql.go @@ -73,11 +73,13 @@ func (p *Parser) Enter(in ast.Node) (ast.Node, bool) { }) } else { p.Migration.AddColumn(alter.Table.Text(), element.Column{ - Node: element.Node{Name: cols[0], Action: element.MigrateAddAction}, - MysqlType: nil, - Options: []*ast.ColumnOption{ - { - Tp: ast.ColumnOptionPrimaryKey, + Node: element.Node{Name: cols[0], Action: element.MigrateAddAction}, + CurrentAttr: element.SqlAttr{ + MysqlType: nil, + Options: []*ast.ColumnOption{ + { + Tp: ast.ColumnOptionPrimaryKey, + }, }, }, }) @@ -113,9 +115,11 @@ func (p *Parser) Enter(in ast.Node) (ast.Node, bool) { if len(alter.Specs[i].NewColumns) > 0 { for j := range alter.Specs[i].NewColumns { col := element.Column{ - Node: element.Node{Name: alter.Specs[i].NewColumns[j].Name.Name.O, Action: element.MigrateModifyAction}, - MysqlType: alter.Specs[i].NewColumns[j].Tp, - Comment: alter.Specs[i].Comment, + Node: element.Node{Name: alter.Specs[i].NewColumns[j].Name.Name.O, Action: element.MigrateModifyAction}, + CurrentAttr: element.SqlAttr{ + MysqlType: alter.Specs[i].NewColumns[j].Tp, + Comment: alter.Specs[i].Comment, + }, } p.Migration.AddColumn(alter.Table.Name.O, col) } @@ -161,11 +165,13 @@ func (p *Parser) Enter(in ast.Node) (ast.Node, bool) { }) } else { tb.AddColumn(element.Column{ - Node: element.Node{Name: cols[0], Action: element.MigrateAddAction}, - MysqlType: nil, - Options: []*ast.ColumnOption{ - { - Tp: ast.ColumnOptionPrimaryKey, + Node: element.Node{Name: cols[0], Action: element.MigrateAddAction}, + CurrentAttr: element.SqlAttr{ + MysqlType: nil, + Options: []*ast.ColumnOption{ + { + Tp: ast.ColumnOptionPrimaryKey, + }, }, }, }) @@ -218,10 +224,12 @@ func (p *Parser) Enter(in ast.Node) (ast.Node, bool) { } column := element.Column{ - Node: element.Node{Name: def.Name.Name.O, Action: element.MigrateAddAction}, - MysqlType: def.Tp, - Options: def.Options, - Comment: comment, + Node: element.Node{Name: def.Name.Name.O, Action: element.MigrateAddAction}, + CurrentAttr: element.SqlAttr{ + MysqlType: def.Tp, + Options: def.Options, + Comment: comment, + }, } p.Migration.AddColumn("", column) } diff --git a/sql-parser/postgresql.go b/sql-parser/postgresql.go index 66b93f9..a9803ac 100644 --- a/sql-parser/postgresql.go +++ b/sql-parser/postgresql.go @@ -92,8 +92,10 @@ func (p *Parser) walker(ctx interface{}, node interface{}) (stop bool) { case *tree.AlterTableAlterColumnType: col := element.Column{ - Node: element.Node{Name: nc.Column.String(), Action: element.MigrateModifyAction}, - PgType: nc.ToType, + Node: element.Node{Name: nc.Column.String(), Action: element.MigrateModifyAction}, + CurrentAttr: element.SqlAttr{ + PgType: nc.ToType, + }, } p.Migration.AddColumn(n.Table.String(), col) @@ -101,11 +103,13 @@ func (p *Parser) walker(ctx interface{}, node interface{}) (stop bool) { if nc.Default != nil { col := element.Column{ Node: element.Node{Name: nc.Column.String(), Action: element.MigrateModifyAction}, - Options: []*ast.ColumnOption{{ - Expr: nil, - Tp: ast.ColumnOptionDefaultValue, - StrValue: nc.Default.String(), - }}, + CurrentAttr: element.SqlAttr{ + Options: []*ast.ColumnOption{{ + Expr: nil, + Tp: ast.ColumnOptionDefaultValue, + StrValue: nc.Default.String(), + }}, + }, } p.Migration.AddColumn(n.Table.String(), col) } @@ -166,9 +170,11 @@ func postgresColumn(n *tree.ColumnTableDef) (element.Column, []element.Index) { } return element.Column{ - Node: element.Node{Name: n.Name.String(), Action: element.MigrateAddAction}, - PgType: n.Type, - Options: opts, + Node: element.Node{Name: n.Name.String(), Action: element.MigrateAddAction}, + CurrentAttr: element.SqlAttr{ + PgType: n.Type, + Options: opts, + }, }, indexes } diff --git a/sql-parser/sqlite.go b/sql-parser/sqlite.go index e05a8df..c07ebff 100644 --- a/sql-parser/sqlite.go +++ b/sql-parser/sqlite.go @@ -46,8 +46,10 @@ func (p *Parser) Visit(node sqlite.Node) (w sqlite.Visitor, err error) { Name: n.Columns[i].Name.String(), Action: element.MigrateAddAction, }, - LiteType: n.Columns[i].Type, - Options: p.parseSqliteConstrains(tbName, n.Columns[i].Constraints), + CurrentAttr: element.SqlAttr{ + LiteType: n.Columns[i].Type, + Options: p.parseSqliteConstrains(tbName, n.Columns[i].Constraints), + }, } p.Migration.AddColumn(tbName, col) @@ -118,8 +120,10 @@ func (p *Parser) Visit(node sqlite.Node) (w sqlite.Visitor, err error) { Name: n.ColumnDef.Name.String(), Action: element.MigrateAddAction, }, - LiteType: n.ColumnDef.Type, - Options: p.parseSqliteConstrains(tbName, n.ColumnDef.Constraints), + CurrentAttr: element.SqlAttr{ + LiteType: n.ColumnDef.Type, + Options: p.parseSqliteConstrains(tbName, n.ColumnDef.Constraints), + }, }) } }