From 473ec690564a951cebe27ff011d1550193849dd8 Mon Sep 17 00:00:00 2001 From: Attila Olbrich Date: Wed, 29 May 2024 19:41:32 +0100 Subject: [PATCH] Huge refactor, hiding non public properties and functions --- cmd/cmd-run.go.bak | 14 +++++--- cmd/cmd.go | 1 + migration.go | 12 +++---- migration_db_provider.go | 52 ++++++++++++++-------------- migration_json_provider.go | 24 ++++++------- migration_provider.go | 13 ++++--- migration_table_firebird_provider.go | 6 ++-- migration_table_mysql_provider.go | 6 ++-- migration_table_postgres_provider.go | 6 ++-- migration_table_sql_provider.go | 18 +++++----- migration_table_sqlite_provider.go | 6 ++-- migrator.go | 26 ++++++++------ test/migrator_report_json_test.go | 2 +- 13 files changed, 101 insertions(+), 85 deletions(-) diff --git a/cmd/cmd-run.go.bak b/cmd/cmd-run.go.bak index 6347a9e..9f012e6 100644 --- a/cmd/cmd-run.go.bak +++ b/cmd/cmd-run.go.bak @@ -1,3 +1,4 @@ +// main package is for testing only locally, rename it to use package main // rename cmd.go to bak, then rename this to .go and test your changes @@ -14,9 +15,13 @@ import ( func main() { migrationFilePath := "./migrations" - dbType, provider, function, count := params() + dbType, provider, function, count, add := params() - fmt.Printf("Running with %s, %s, %s, %d\n", dbType, provider, function, count) + fmt.Printf("Running with %s, %s, %s, %d %s\n", dbType, provider, function, count, add) + + if add != "" { + migrator.AddNewMigrationFiles(add, "") + } db, err := getConnection(dbType) if err != nil { @@ -55,14 +60,15 @@ func main() { } } -func params() (string, string, string, int) { +func params() (string, string, string, int, string) { db := flag.String("db", "sqlite", "Database driver name") function := flag.String("function", "migrate", "function=migrate/rollback/report") count := flag.Int("count", 0, "count=1") provider := flag.String("provider", "db", "provider=db/json") + add := flag.String("add", "", "--add=filename") flag.Parse() - return *db, *provider, *function, *count + return *db, *provider, *function, *count, *add } func connectToFirebaseDatabase() (*sql.DB, error) { diff --git a/cmd/cmd.go b/cmd/cmd.go index 7c846f1..595a7e5 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -1,3 +1,4 @@ +// Main package is to display basic package info package main import "fmt" diff --git a/migration.go b/migration.go index 1f8fce2..162bed6 100644 --- a/migration.go +++ b/migration.go @@ -10,7 +10,7 @@ import ( type migration struct { db *sql.DB - migrationProvider MigrationProvider + migrationProvider migrationProvider migrationFilePath string } @@ -41,7 +41,7 @@ func (m *migration) orderedMigrationFiles() ([]string, error) { return fileNames, nil } -func (m *migration) executeSqlFile(fileName string) (bool, error) { +func (m *migration) executeSQLFile(fileName string) (bool, error) { exists, err := m.migrationProvider.migrationExistsForFile(fileName) if err != nil { return false, err @@ -57,7 +57,7 @@ func (m *migration) executeSqlFile(fileName string) (bool, error) { return false, err } - err = m.executeSql(string(content)) + err = m.executeSQL(string(content)) if err == nil { err = m.migrationProvider.addToMigration(fileName) if err != nil { @@ -70,7 +70,7 @@ func (m *migration) executeSqlFile(fileName string) (bool, error) { return true, err } -func (m *migration) executeRollbackSqlFile(fileName string) error { +func (m *migration) executeRollbackSQLFile(fileName string) error { rollbackFileName, err := m.resolveRollbackFile(fileName) if err != nil { fmt.Printf("Skip rollback for %s as rollback file does not exists\n", fileName) @@ -88,7 +88,7 @@ func (m *migration) executeRollbackSqlFile(fileName string) error { return err } - err = m.executeSql(string(content)) + err = m.executeSQL(string(content)) if err == nil { err = m.migrationProvider.removeFromMigration(fileName) if err != nil { @@ -101,7 +101,7 @@ func (m *migration) executeRollbackSqlFile(fileName string) error { return err } -func (m *migration) executeSql(sql string) error { +func (m *migration) executeSQL(sql string) error { tx, err := m.db.Begin() if err != nil { return err diff --git a/migration_db_provider.go b/migration_db_provider.go index 2db9d84..f42caa7 100644 --- a/migration_db_provider.go +++ b/migration_db_provider.go @@ -12,11 +12,11 @@ import ( const ( dbTypeSqlite = "sqlite" dbTypePostgres = "pg" - dbTypeMySql = "mysql" + dbTypeMySQL = "mysql" dbTypeFirebird = "firebird" ) -type DbMigration struct { +type dbMigration struct { db *sql.DB timeString string sqlBindingParameter string @@ -29,31 +29,31 @@ type reportRow struct { Message string } -func newDbMigration(db *sql.DB) (*DbMigration, error) { - dbMigration := &DbMigration{db: db} +func newDbMigration(db *sql.DB) (*dbMigration, error) { + dbMigration := &dbMigration{db: db} dbMigration.resetDate() driverType, err := dbMigration.diverType() if err != nil { return nil, err } - dbMigration.setSqlBindingParameter(driverType) - createSqlProvider, err := MigrationTableProviderByDriverName(driverType) + dbMigration.setSQLBindingParameter(driverType) + createSQLProvider, err := migrationTableProviderByDriverName(driverType) if err != nil { return nil, err } - err = dbMigration.Init(createSqlProvider) + err = dbMigration.init(createSQLProvider) if err != nil { return nil, err } return dbMigration, nil } -func (m *DbMigration) resetDate() { +func (m *dbMigration) resetDate() { m.timeString = time.Now().Format("2006-01-02 15:04:05") } -func (m *DbMigration) migrations(isLatest bool) ([]string, error) { +func (m *dbMigration) migrations(isLatest bool) ([]string, error) { var migrationList []string var rows *sql.Rows var err error @@ -94,7 +94,7 @@ func (m *DbMigration) migrations(isLatest bool) ([]string, error) { return migrationList, nil } -func (m *DbMigration) latestMigrations(lastMigrationDate string) (*sql.Rows, error) { +func (m *dbMigration) latestMigrations(lastMigrationDate string) (*sql.Rows, error) { return m.db.Query(fmt.Sprintf( `SELECT file_name FROM migrations @@ -105,7 +105,7 @@ func (m *DbMigration) latestMigrations(lastMigrationDate string) (*sql.Rows, err ), lastMigrationDate) } -func (m *DbMigration) allMigrations() (*sql.Rows, error) { +func (m *dbMigration) allMigrations() (*sql.Rows, error) { return m.db.Query( `SELECT file_name FROM migrations @@ -114,7 +114,7 @@ func (m *DbMigration) allMigrations() (*sql.Rows, error) { ) } -func (m *DbMigration) addToMigration(fileName string) error { +func (m *dbMigration) addToMigration(fileName string) error { sql := fmt.Sprintf(`INSERT INTO migrations (file_name, created_at) VALUES (%s, %s)`, @@ -128,7 +128,7 @@ func (m *DbMigration) addToMigration(fileName string) error { } -func (m *DbMigration) removeFromMigration(fileName string) error { +func (m *dbMigration) removeFromMigration(fileName string) error { sql := fmt.Sprintf(`UPDATE migrations SET deleted_at = %s WHERE file_name = %s @@ -142,7 +142,7 @@ func (m *DbMigration) removeFromMigration(fileName string) error { return err } -func (m *DbMigration) migrationExistsForFile(fileName string) (bool, error) { +func (m *dbMigration) migrationExistsForFile(fileName string) (bool, error) { sql := fmt.Sprintf(`SELECT count(*) as cnt FROM migrations WHERE file_name = %s @@ -167,21 +167,21 @@ func (m *DbMigration) migrationExistsForFile(fileName string) (bool, error) { return cnt > 0, nil } -func (m *DbMigration) Init(createSqlProvider MigrationTableSqlProvider) error { - sql := createSqlProvider.CreateMigrationSql() +func (m *dbMigration) init(createSQLProvider migrationTableSQLProvider) error { + sql := createSQLProvider.createMigrationSQL() _, err := m.db.Exec(sql) if err != nil { return err } - sql = createSqlProvider.CreateReportSql() + sql = createSQLProvider.createReportSQL() _, err = m.db.Exec(sql) return err } -func (m *DbMigration) lastMigrationDate() (string, error) { +func (m *dbMigration) lastMigrationDate() (string, error) { sql := `SELECT max(created_at) as latest_migration FROM migrations WHERE deleted_at IS NULL` @@ -193,7 +193,7 @@ func (m *DbMigration) lastMigrationDate() (string, error) { return maxdate, err } -func (m *DbMigration) setSqlBindingParameter(driverType string) { +func (m *dbMigration) setSQLBindingParameter(driverType string) { if driverType == dbTypePostgres { m.sqlBindingParameter = "$" @@ -203,7 +203,7 @@ func (m *DbMigration) setSqlBindingParameter(driverType string) { m.sqlBindingParameter = "?" } -func (m *DbMigration) getBindingParameter(index int) string { +func (m *dbMigration) getBindingParameter(index int) string { if m.sqlBindingParameter == "?" { return "?" } @@ -211,11 +211,11 @@ func (m *DbMigration) getBindingParameter(index int) string { return fmt.Sprintf("$%d", index) } -func (m *DbMigration) diverType() (string, error) { +func (m *dbMigration) diverType() (string, error) { driverType := reflect.TypeOf(m.db.Driver()).String() if strings.Contains(driverType, "mysql") { - return dbTypeMySql, nil + return dbTypeMySQL, nil } if strings.Contains(driverType, "pq") || strings.Contains(driverType, "postgres") { @@ -233,16 +233,16 @@ func (m *DbMigration) diverType() (string, error) { return "", fmt.Errorf("the driver used %s does not match any known dirver by the application", driverType) } -func (m *DbMigration) getJsonFileName() string { +func (m *dbMigration) getJSONFileName() string { // dummy, not used in db version, need due to interface return "" } -func (m *DbMigration) SetJsonFilePath(filePath string) { +func (m *dbMigration) SetJSONFilePath(_ string) { // dummy, not used in db version, need due to interface } -func (m *DbMigration) AddToMigrationReport(fileName string, errorToLog error) error { +func (m *dbMigration) AddToMigrationReport(fileName string, errorToLog error) error { sql := fmt.Sprintf(`INSERT INTO migration_reports (file_name, created_at, result_status, message) VALUES (%s, %s, %s, %s)`, @@ -266,7 +266,7 @@ func (m *DbMigration) AddToMigrationReport(fileName string, errorToLog error) er return err } -func (m *DbMigration) Report() (string, error) { +func (m *dbMigration) Report() (string, error) { rows, err := m.db.Query(`SELECT file_name, created_at, result_status, message FROM migration_reports`) if err != nil { return "", err diff --git a/migration_json_provider.go b/migration_json_provider.go index c802914..c031f38 100644 --- a/migration_json_provider.go +++ b/migration_json_provider.go @@ -10,8 +10,8 @@ import ( "time" ) -const migrationJsonFileName = "./migrations/migrations.json" -const migrationJsonReportFileName = "./migrations/migration_report.json" +const migrationJSONFileName = "./migrations/migrations.json" +const migrationJSONReportFileName = "./migrations/migration_report.json" type jsonMigration struct { data map[string]string @@ -27,7 +27,7 @@ type jsonMigrationReport struct { Message string `json:"message"` } -func newJsonMigration() (*jsonMigration, error) { +func newJSONMigration() (*jsonMigration, error) { jsonMigration := &jsonMigration{} jsonMigration.resetDate() err := jsonMigration.loadMigrationFile() @@ -62,7 +62,7 @@ func (m *jsonMigration) migrations(isLatest bool) ([]string, error) { func (m *jsonMigration) loadMigrationFile() error { m.data = make(map[string]string) - jsonFileName := m.getJsonFileName() + jsonFileName := m.getJSONFileName() if !fileExists(jsonFileName) { return nil } @@ -86,7 +86,7 @@ func (m *jsonMigration) saveMigrationFile() error { return err } - jsonFileName := m.getJsonFileName() + jsonFileName := m.getJSONFileName() return ioutil.WriteFile(jsonFileName, jsonData, 0644) } @@ -106,29 +106,29 @@ func (m *jsonMigration) migrationExistsForFile(fileName string) (bool, error) { return m.data[fileName] != "", nil } -func (m *jsonMigration) getJsonFileName() string { +func (m *jsonMigration) getJSONFileName() string { if m.jsonFileName == "" { - return migrationJsonFileName + return migrationJSONFileName } return m.jsonFileName } -func (m *jsonMigration) getJsonReportFileName() string { +func (m *jsonMigration) getJSONReportFileName() string { if m.jsonFileName == "" { - return migrationJsonReportFileName + return migrationJSONReportFileName } return m.jsonReporFileName } -func (m *jsonMigration) SetJsonFilePath(filePath string) { +func (m *jsonMigration) SetJSONFilePath(filePath string) { m.jsonFileName = filePath + "/migrations.json" m.jsonReporFileName = filePath + "/migration_reports.json" } func (m *jsonMigration) AddToMigrationReport(fileName string, errorToLog error) error { - storeFileName := m.getJsonReportFileName() + storeFileName := m.getJSONReportFileName() message := "ok" status := "success" if errorToLog != nil { @@ -171,7 +171,7 @@ func (m *jsonMigration) AddToMigrationReport(fileName string, errorToLog error) } func (m *jsonMigration) Report() (string, error) { - storeFileName := m.getJsonReportFileName() + storeFileName := m.getJSONReportFileName() _, err := os.Stat(storeFileName) if os.IsNotExist(err) { diff --git a/migration_provider.go b/migration_provider.go index fa0d7db..60552ca 100644 --- a/migration_provider.go +++ b/migration_provider.go @@ -5,22 +5,25 @@ import ( "fmt" ) -type MigrationProvider interface { +type migrationProvider interface { migrations(bool) ([]string, error) addToMigration(string) error removeFromMigration(string) error migrationExistsForFile(string) (bool, error) resetDate() - getJsonFileName() string - SetJsonFilePath(string) + getJSONFileName() string + SetJSONFilePath(string) AddToMigrationReport(string, error) error Report() (string, error) } -func NewMigrationProvider(providerType string, db *sql.DB) (MigrationProvider, error) { +// NewMigrationProvider returns a migration provider, which follows the provider type +// The provider type can be json or db, error returned if the type incorrectly provided +// db should be your database *sql.DB, which can be MySQL, Postgres, Sqlite or Firebird +func NewMigrationProvider(providerType string, db *sql.DB) (migrationProvider, error) { switch providerType { case "json": - return newJsonMigration() + return newJSONMigration() case "db": return newDbMigration(db) default: diff --git a/migration_table_firebird_provider.go b/migration_table_firebird_provider.go index 3b37ce6..031b1d5 100644 --- a/migration_table_firebird_provider.go +++ b/migration_table_firebird_provider.go @@ -1,9 +1,9 @@ package migrator -type FirebirdMigrationTableSqlProvider struct { +type firebirdMigrationTableSQLProvider struct { } -func (p *FirebirdMigrationTableSqlProvider) CreateMigrationSql() string { +func (p *firebirdMigrationTableSQLProvider) createMigrationSQL() string { return `EXECUTE BLOCK AS BEGIN if (not exists(select 1 from rdb$relations where rdb$relation_name = 'MIGRATIONS')) then execute statement 'CREATE TABLE MIGRATIONS ( @@ -13,7 +13,7 @@ func (p *FirebirdMigrationTableSqlProvider) CreateMigrationSql() string { END` } -func (p *FirebirdMigrationTableSqlProvider) CreateReportSql() string { +func (p *firebirdMigrationTableSQLProvider) createReportSQL() string { return `EXECUTE BLOCK AS BEGIN if (not exists(select 1 from rdb$relations where rdb$relation_name = 'MIGRATION_REPORTS')) then execute statement 'CREATE TABLE MIGRATION_REPORTS ( diff --git a/migration_table_mysql_provider.go b/migration_table_mysql_provider.go index b81b93c..d80e529 100644 --- a/migration_table_mysql_provider.go +++ b/migration_table_mysql_provider.go @@ -1,9 +1,9 @@ package migrator -type MySqlMigrationTableSqlProvider struct { +type mySQLMigrationTableSQLProvider struct { } -func (p *MySqlMigrationTableSqlProvider) CreateMigrationSql() string { +func (p *mySQLMigrationTableSQLProvider) createMigrationSQL() string { return `CREATE TABLE IF NOT EXISTS migrations ( file_name VARCHAR(255), created_at DATETIME, @@ -11,7 +11,7 @@ func (p *MySqlMigrationTableSqlProvider) CreateMigrationSql() string { )` } -func (p *MySqlMigrationTableSqlProvider) CreateReportSql() string { +func (p *mySQLMigrationTableSQLProvider) createReportSQL() string { return `CREATE TABLE IF NOT EXISTS migration_reports ( file_name VARCHAR(255), result_status VARCHAR(12), diff --git a/migration_table_postgres_provider.go b/migration_table_postgres_provider.go index 5695a62..0073995 100644 --- a/migration_table_postgres_provider.go +++ b/migration_table_postgres_provider.go @@ -1,9 +1,9 @@ package migrator -type PostgresMigrationTableSqlProvider struct { +type postgresMigrationTableSQLProvider struct { } -func (p *PostgresMigrationTableSqlProvider) CreateMigrationSql() string { +func (p *postgresMigrationTableSQLProvider) createMigrationSQL() string { return `CREATE TABLE IF NOT EXISTS migrations ( file_name VARCHAR(255), created_at TIMESTAMP, @@ -11,7 +11,7 @@ func (p *PostgresMigrationTableSqlProvider) CreateMigrationSql() string { )` } -func (p *PostgresMigrationTableSqlProvider) CreateReportSql() string { +func (p *postgresMigrationTableSQLProvider) createReportSQL() string { return `CREATE TABLE IF NOT EXISTS migration_reports ( file_name VARCHAR(255), result_status VARCHAR(12), diff --git a/migration_table_sql_provider.go b/migration_table_sql_provider.go index acae9eb..eec035d 100644 --- a/migration_table_sql_provider.go +++ b/migration_table_sql_provider.go @@ -2,21 +2,21 @@ package migrator import "fmt" -type MigrationTableSqlProvider interface { - CreateMigrationSql() string - CreateReportSql() string +type migrationTableSQLProvider interface { + createMigrationSQL() string + createReportSQL() string } -func MigrationTableProviderByDriverName(driverName string) (MigrationTableSqlProvider, error) { +func migrationTableProviderByDriverName(driverName string) (migrationTableSQLProvider, error) { switch driverName { case dbTypeSqlite: - return &SqliteMigrationTableSqlProvider{}, nil + return &sqliteMigrationTableSQLProvider{}, nil case dbTypePostgres: - return &PostgresMigrationTableSqlProvider{}, nil - case dbTypeMySql: - return &MySqlMigrationTableSqlProvider{}, nil + return &postgresMigrationTableSQLProvider{}, nil + case dbTypeMySQL: + return &mySQLMigrationTableSQLProvider{}, nil case dbTypeFirebird: - return &FirebirdMigrationTableSqlProvider{}, nil + return &firebirdMigrationTableSQLProvider{}, nil default: return nil, fmt.Errorf("provider %s does not exists", driverName) } diff --git a/migration_table_sqlite_provider.go b/migration_table_sqlite_provider.go index 479e40d..8ea0b3e 100644 --- a/migration_table_sqlite_provider.go +++ b/migration_table_sqlite_provider.go @@ -1,9 +1,9 @@ package migrator -type SqliteMigrationTableSqlProvider struct { +type sqliteMigrationTableSQLProvider struct { } -func (p *SqliteMigrationTableSqlProvider) CreateMigrationSql() string { +func (p *sqliteMigrationTableSQLProvider) createMigrationSQL() string { return `CREATE TABLE IF NOT EXISTS migrations ( file_name VARCHAR(255), created_at DATETIME, @@ -11,7 +11,7 @@ func (p *SqliteMigrationTableSqlProvider) CreateMigrationSql() string { )` } -func (p *SqliteMigrationTableSqlProvider) CreateReportSql() string { +func (p *sqliteMigrationTableSQLProvider) createReportSQL() string { return `CREATE TABLE IF NOT EXISTS migration_reports ( file_name VARCHAR(255), result_status VARCHAR(12), diff --git a/migrator.go b/migrator.go index 06567bb..c8580fb 100644 --- a/migrator.go +++ b/migrator.go @@ -1,3 +1,4 @@ +// Package migrator is a lightweight database migrator package, pass only the *sql.DB, migrate, rollback and migration reports package migrator import ( @@ -7,18 +8,20 @@ import ( "time" ) +// Rollback rolls back last migrated items or all if count is 0 func Rollback( db *sql.DB, - migrationProvider MigrationProvider, + migrationProvider migrationProvider, migrationFilePath string, count int, ) error { return rollback(db, migrationProvider, migrationFilePath, count, false) } +// Refresh runs a full rollback and migrate again func Refresh( db *sql.DB, - migrationProvider MigrationProvider, + migrationProvider migrationProvider, migrationFilePath string, ) error { err := rollback(db, migrationProvider, migrationFilePath, 0, true) @@ -29,16 +32,17 @@ func Refresh( return Migrate(db, migrationProvider, migrationFilePath, 0) } +// Migrate execute migrations func Migrate( db *sql.DB, - migrationProvider MigrationProvider, + migrationProvider migrationProvider, migrationFilePath string, count int, ) error { m := newMigrator(db) m.migrationFilePath = migrationFilePath m.migrationProvider = migrationProvider - m.migrationProvider.SetJsonFilePath(migrationFilePath) + m.migrationProvider.SetJSONFilePath(migrationFilePath) m.migrationProvider.resetDate() fileNames, err := m.orderedMigrationFiles() if err != nil { @@ -52,7 +56,7 @@ func Migrate( break } } - migrated, err := m.executeSqlFile(fileName) + migrated, err := m.executeSQLFile(fileName) if err != nil { return err } @@ -67,19 +71,21 @@ func Migrate( return nil } +// Report return a report of the alredy executed migrations func Report( db *sql.DB, - migrationProvider MigrationProvider, + migrationProvider migrationProvider, migrationFilePath string, ) (string, error) { m := newMigrator(db) m.migrationFilePath = migrationFilePath m.migrationProvider = migrationProvider - m.migrationProvider.SetJsonFilePath(migrationFilePath) + m.migrationProvider.SetJSONFilePath(migrationFilePath) return m.migrationProvider.Report() } +// AddNewMigrationFiles adds a new blank migration file and a rollback file func AddNewMigrationFiles(migrationFilePath, customText string) error { var err error err = createNewMigrationFiles(migrationFilePath, customText, false) @@ -96,7 +102,7 @@ func AddNewMigrationFiles(migrationFilePath, customText string) error { func rollback( db *sql.DB, - migrationProvider MigrationProvider, + migrationProvider migrationProvider, migrationFilePath string, count int, isCompleteRollback bool, @@ -105,7 +111,7 @@ func rollback( m := newMigrator(db) m.migrationFilePath = migrationFilePath m.migrationProvider = migrationProvider - m.migrationProvider.SetJsonFilePath(migrationFilePath) + m.migrationProvider.SetJSONFilePath(migrationFilePath) migrations, err := m.migrationProvider.migrations(!isCompleteRollback) if err != nil { return err @@ -123,7 +129,7 @@ func rollback( } } - err = m.executeRollbackSqlFile(fileName) + err = m.executeRollbackSQLFile(fileName) if err != nil { return err } diff --git a/test/migrator_report_json_test.go b/test/migrator_report_json_test.go index 4675925..480bb04 100644 --- a/test/migrator_report_json_test.go +++ b/test/migrator_report_json_test.go @@ -33,7 +33,7 @@ func (t *ReportTestSuite) TestDBMigratorMigrateAllTables() { t.Nil(err) migrationProvider, err := migrator.NewMigrationProvider("json", t.db) - migrationProvider.SetJsonFilePath(testFixtureFolder) + migrationProvider.SetJSONFilePath(testFixtureFolder) t.Nil(err) report, err := migrationProvider.Report()