Skip to content

Commit

Permalink
Huge refactor, hiding non public properties and functions
Browse files Browse the repository at this point in the history
  • Loading branch information
olbrichattila committed May 29, 2024
1 parent 6be20eb commit 473ec69
Show file tree
Hide file tree
Showing 13 changed files with 101 additions and 85 deletions.
14 changes: 10 additions & 4 deletions cmd/cmd-run.go.bak
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
// main package is for testing only locally, rename it to use
package main

// rename cmd.go to bak, then rename this to .go and test your changes
Expand All @@ -14,9 +15,13 @@ import (

func main() {
migrationFilePath := "./migrations"
dbType, provider, function, count := params()
dbType, provider, function, count, add := params()

fmt.Printf("Running with %s, %s, %s, %d\n", dbType, provider, function, count)
fmt.Printf("Running with %s, %s, %s, %d %s\n", dbType, provider, function, count, add)

if add != "" {
migrator.AddNewMigrationFiles(add, "")
}

db, err := getConnection(dbType)
if err != nil {
Expand Down Expand Up @@ -55,14 +60,15 @@ func main() {
}
}

func params() (string, string, string, int) {
func params() (string, string, string, int, string) {
db := flag.String("db", "sqlite", "Database driver name")
function := flag.String("function", "migrate", "function=migrate/rollback/report")
count := flag.Int("count", 0, "count=1")
provider := flag.String("provider", "db", "provider=db/json")
add := flag.String("add", "", "--add=filename")
flag.Parse()

return *db, *provider, *function, *count
return *db, *provider, *function, *count, *add
}

func connectToFirebaseDatabase() (*sql.DB, error) {
Expand Down
1 change: 1 addition & 0 deletions cmd/cmd.go
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
// Main package is to display basic package info
package main

import "fmt"
Expand Down
12 changes: 6 additions & 6 deletions migration.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (

type migration struct {
db *sql.DB
migrationProvider MigrationProvider
migrationProvider migrationProvider
migrationFilePath string
}

Expand Down Expand Up @@ -41,7 +41,7 @@ func (m *migration) orderedMigrationFiles() ([]string, error) {
return fileNames, nil
}

func (m *migration) executeSqlFile(fileName string) (bool, error) {
func (m *migration) executeSQLFile(fileName string) (bool, error) {
exists, err := m.migrationProvider.migrationExistsForFile(fileName)
if err != nil {
return false, err
Expand All @@ -57,7 +57,7 @@ func (m *migration) executeSqlFile(fileName string) (bool, error) {
return false, err
}

err = m.executeSql(string(content))
err = m.executeSQL(string(content))
if err == nil {
err = m.migrationProvider.addToMigration(fileName)
if err != nil {
Expand All @@ -70,7 +70,7 @@ func (m *migration) executeSqlFile(fileName string) (bool, error) {
return true, err
}

func (m *migration) executeRollbackSqlFile(fileName string) error {
func (m *migration) executeRollbackSQLFile(fileName string) error {
rollbackFileName, err := m.resolveRollbackFile(fileName)
if err != nil {
fmt.Printf("Skip rollback for %s as rollback file does not exists\n", fileName)
Expand All @@ -88,7 +88,7 @@ func (m *migration) executeRollbackSqlFile(fileName string) error {
return err
}

err = m.executeSql(string(content))
err = m.executeSQL(string(content))
if err == nil {
err = m.migrationProvider.removeFromMigration(fileName)
if err != nil {
Expand All @@ -101,7 +101,7 @@ func (m *migration) executeRollbackSqlFile(fileName string) error {
return err
}

func (m *migration) executeSql(sql string) error {
func (m *migration) executeSQL(sql string) error {
tx, err := m.db.Begin()
if err != nil {
return err
Expand Down
52 changes: 26 additions & 26 deletions migration_db_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@ import (
const (
dbTypeSqlite = "sqlite"
dbTypePostgres = "pg"
dbTypeMySql = "mysql"
dbTypeMySQL = "mysql"
dbTypeFirebird = "firebird"
)

type DbMigration struct {
type dbMigration struct {
db *sql.DB
timeString string
sqlBindingParameter string
Expand All @@ -29,31 +29,31 @@ type reportRow struct {
Message string
}

func newDbMigration(db *sql.DB) (*DbMigration, error) {
dbMigration := &DbMigration{db: db}
func newDbMigration(db *sql.DB) (*dbMigration, error) {
dbMigration := &dbMigration{db: db}
dbMigration.resetDate()
driverType, err := dbMigration.diverType()
if err != nil {
return nil, err
}
dbMigration.setSqlBindingParameter(driverType)
createSqlProvider, err := MigrationTableProviderByDriverName(driverType)
dbMigration.setSQLBindingParameter(driverType)
createSQLProvider, err := migrationTableProviderByDriverName(driverType)
if err != nil {
return nil, err
}

err = dbMigration.Init(createSqlProvider)
err = dbMigration.init(createSQLProvider)
if err != nil {
return nil, err
}
return dbMigration, nil
}

func (m *DbMigration) resetDate() {
func (m *dbMigration) resetDate() {
m.timeString = time.Now().Format("2006-01-02 15:04:05")
}

func (m *DbMigration) migrations(isLatest bool) ([]string, error) {
func (m *dbMigration) migrations(isLatest bool) ([]string, error) {
var migrationList []string
var rows *sql.Rows
var err error
Expand Down Expand Up @@ -94,7 +94,7 @@ func (m *DbMigration) migrations(isLatest bool) ([]string, error) {
return migrationList, nil
}

func (m *DbMigration) latestMigrations(lastMigrationDate string) (*sql.Rows, error) {
func (m *dbMigration) latestMigrations(lastMigrationDate string) (*sql.Rows, error) {
return m.db.Query(fmt.Sprintf(
`SELECT file_name
FROM migrations
Expand All @@ -105,7 +105,7 @@ func (m *DbMigration) latestMigrations(lastMigrationDate string) (*sql.Rows, err
), lastMigrationDate)
}

func (m *DbMigration) allMigrations() (*sql.Rows, error) {
func (m *dbMigration) allMigrations() (*sql.Rows, error) {
return m.db.Query(
`SELECT file_name
FROM migrations
Expand All @@ -114,7 +114,7 @@ func (m *DbMigration) allMigrations() (*sql.Rows, error) {
)
}

func (m *DbMigration) addToMigration(fileName string) error {
func (m *dbMigration) addToMigration(fileName string) error {
sql := fmt.Sprintf(`INSERT INTO migrations
(file_name, created_at)
VALUES (%s, %s)`,
Expand All @@ -128,7 +128,7 @@ func (m *DbMigration) addToMigration(fileName string) error {

}

func (m *DbMigration) removeFromMigration(fileName string) error {
func (m *dbMigration) removeFromMigration(fileName string) error {
sql := fmt.Sprintf(`UPDATE migrations
SET deleted_at = %s
WHERE file_name = %s
Expand All @@ -142,7 +142,7 @@ func (m *DbMigration) removeFromMigration(fileName string) error {
return err
}

func (m *DbMigration) migrationExistsForFile(fileName string) (bool, error) {
func (m *dbMigration) migrationExistsForFile(fileName string) (bool, error) {
sql := fmt.Sprintf(`SELECT count(*) as cnt
FROM migrations
WHERE file_name = %s
Expand All @@ -167,21 +167,21 @@ func (m *DbMigration) migrationExistsForFile(fileName string) (bool, error) {
return cnt > 0, nil
}

func (m *DbMigration) Init(createSqlProvider MigrationTableSqlProvider) error {
sql := createSqlProvider.CreateMigrationSql()
func (m *dbMigration) init(createSQLProvider migrationTableSQLProvider) error {
sql := createSQLProvider.createMigrationSQL()

_, err := m.db.Exec(sql)
if err != nil {
return err
}

sql = createSqlProvider.CreateReportSql()
sql = createSQLProvider.createReportSQL()
_, err = m.db.Exec(sql)

return err
}

func (m *DbMigration) lastMigrationDate() (string, error) {
func (m *dbMigration) lastMigrationDate() (string, error) {
sql := `SELECT max(created_at) as latest_migration
FROM migrations
WHERE deleted_at IS NULL`
Expand All @@ -193,7 +193,7 @@ func (m *DbMigration) lastMigrationDate() (string, error) {
return maxdate, err
}

func (m *DbMigration) setSqlBindingParameter(driverType string) {
func (m *dbMigration) setSQLBindingParameter(driverType string) {
if driverType == dbTypePostgres {
m.sqlBindingParameter = "$"

Expand All @@ -203,19 +203,19 @@ func (m *DbMigration) setSqlBindingParameter(driverType string) {
m.sqlBindingParameter = "?"
}

func (m *DbMigration) getBindingParameter(index int) string {
func (m *dbMigration) getBindingParameter(index int) string {
if m.sqlBindingParameter == "?" {
return "?"
}

return fmt.Sprintf("$%d", index)
}

func (m *DbMigration) diverType() (string, error) {
func (m *dbMigration) diverType() (string, error) {
driverType := reflect.TypeOf(m.db.Driver()).String()

if strings.Contains(driverType, "mysql") {
return dbTypeMySql, nil
return dbTypeMySQL, nil
}

if strings.Contains(driverType, "pq") || strings.Contains(driverType, "postgres") {
Expand All @@ -233,16 +233,16 @@ func (m *DbMigration) diverType() (string, error) {
return "", fmt.Errorf("the driver used %s does not match any known dirver by the application", driverType)
}

func (m *DbMigration) getJsonFileName() string {
func (m *dbMigration) getJSONFileName() string {
// dummy, not used in db version, need due to interface
return ""
}

func (m *DbMigration) SetJsonFilePath(filePath string) {
func (m *dbMigration) SetJSONFilePath(_ string) {
// dummy, not used in db version, need due to interface
}

func (m *DbMigration) AddToMigrationReport(fileName string, errorToLog error) error {
func (m *dbMigration) AddToMigrationReport(fileName string, errorToLog error) error {
sql := fmt.Sprintf(`INSERT INTO migration_reports
(file_name, created_at, result_status, message)
VALUES (%s, %s, %s, %s)`,
Expand All @@ -266,7 +266,7 @@ func (m *DbMigration) AddToMigrationReport(fileName string, errorToLog error) er
return err
}

func (m *DbMigration) Report() (string, error) {
func (m *dbMigration) Report() (string, error) {
rows, err := m.db.Query(`SELECT file_name, created_at, result_status, message FROM migration_reports`)
if err != nil {
return "", err
Expand Down
24 changes: 12 additions & 12 deletions migration_json_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ import (
"time"
)

const migrationJsonFileName = "./migrations/migrations.json"
const migrationJsonReportFileName = "./migrations/migration_report.json"
const migrationJSONFileName = "./migrations/migrations.json"
const migrationJSONReportFileName = "./migrations/migration_report.json"

type jsonMigration struct {
data map[string]string
Expand All @@ -27,7 +27,7 @@ type jsonMigrationReport struct {
Message string `json:"message"`
}

func newJsonMigration() (*jsonMigration, error) {
func newJSONMigration() (*jsonMigration, error) {
jsonMigration := &jsonMigration{}
jsonMigration.resetDate()
err := jsonMigration.loadMigrationFile()
Expand Down Expand Up @@ -62,7 +62,7 @@ func (m *jsonMigration) migrations(isLatest bool) ([]string, error) {

func (m *jsonMigration) loadMigrationFile() error {
m.data = make(map[string]string)
jsonFileName := m.getJsonFileName()
jsonFileName := m.getJSONFileName()
if !fileExists(jsonFileName) {
return nil
}
Expand All @@ -86,7 +86,7 @@ func (m *jsonMigration) saveMigrationFile() error {
return err
}

jsonFileName := m.getJsonFileName()
jsonFileName := m.getJSONFileName()
return ioutil.WriteFile(jsonFileName, jsonData, 0644)
}

Expand All @@ -106,29 +106,29 @@ func (m *jsonMigration) migrationExistsForFile(fileName string) (bool, error) {
return m.data[fileName] != "", nil
}

func (m *jsonMigration) getJsonFileName() string {
func (m *jsonMigration) getJSONFileName() string {
if m.jsonFileName == "" {
return migrationJsonFileName
return migrationJSONFileName
}

return m.jsonFileName
}

func (m *jsonMigration) getJsonReportFileName() string {
func (m *jsonMigration) getJSONReportFileName() string {
if m.jsonFileName == "" {
return migrationJsonReportFileName
return migrationJSONReportFileName
}

return m.jsonReporFileName
}

func (m *jsonMigration) SetJsonFilePath(filePath string) {
func (m *jsonMigration) SetJSONFilePath(filePath string) {
m.jsonFileName = filePath + "/migrations.json"
m.jsonReporFileName = filePath + "/migration_reports.json"
}

func (m *jsonMigration) AddToMigrationReport(fileName string, errorToLog error) error {
storeFileName := m.getJsonReportFileName()
storeFileName := m.getJSONReportFileName()
message := "ok"
status := "success"
if errorToLog != nil {
Expand Down Expand Up @@ -171,7 +171,7 @@ func (m *jsonMigration) AddToMigrationReport(fileName string, errorToLog error)
}

func (m *jsonMigration) Report() (string, error) {
storeFileName := m.getJsonReportFileName()
storeFileName := m.getJSONReportFileName()

_, err := os.Stat(storeFileName)
if os.IsNotExist(err) {
Expand Down
13 changes: 8 additions & 5 deletions migration_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,25 @@ import (
"fmt"
)

type MigrationProvider interface {
type migrationProvider interface {
migrations(bool) ([]string, error)
addToMigration(string) error
removeFromMigration(string) error
migrationExistsForFile(string) (bool, error)
resetDate()
getJsonFileName() string
SetJsonFilePath(string)
getJSONFileName() string
SetJSONFilePath(string)
AddToMigrationReport(string, error) error
Report() (string, error)
}

func NewMigrationProvider(providerType string, db *sql.DB) (MigrationProvider, error) {
// NewMigrationProvider returns a migration provider, which follows the provider type
// The provider type can be json or db, error returned if the type incorrectly provided
// db should be your database *sql.DB, which can be MySQL, Postgres, Sqlite or Firebird
func NewMigrationProvider(providerType string, db *sql.DB) (migrationProvider, error) {
switch providerType {
case "json":
return newJsonMigration()
return newJSONMigration()
case "db":
return newDbMigration(db)
default:
Expand Down
Loading

0 comments on commit 473ec69

Please sign in to comment.