diff --git a/main.go b/main.go
index a08f472..d559c5f 100644
--- a/main.go
+++ b/main.go
@@ -7,6 +7,7 @@ import (
 	"strings"
 
 	"github.com/afrase/mysqldumpsplit/msds"
+	"time"
 )
 
 type config struct {
@@ -37,12 +38,10 @@ func parseFlags() *config {
 
 func main() {
 	conf := parseFlags()
-
-	readCh := make(chan string)
-	tableNameCh := make(chan string)
-	tableDataCh := make(chan string)
-	tableSchemeCh := make(chan string)
-	doneCh := make(chan bool)
+	if conf.InputFile == "" {
+		flag.PrintDefaults()
+		os.Exit(0)
+	}
 
 	file, err := msds.OpenFile(conf.InputFile)
 	if err != nil {
@@ -50,23 +49,38 @@ func main() {
 		os.Exit(1)
 	}
-	info, _ := file.Stat()
-	fmt.Printf("Original file size %s\n", msds.StringifyFileSize(info.Size()))
-	fmt.Printf("Outputing all tables to %s\n", conf.OutputPath)
+	bus := msds.ChannelBus{
+		Finished:    make(chan bool),
+		Log:         make(chan string),
+		TableData:   make(chan string),
+		TableScheme: make(chan string),
+		TableName:   make(chan string),
+		CurrentLine: make(chan string),
+	}
+
+	go msds.Logger(bus)
+
+	bus.Log <- fmt.Sprintf("outputting all tables to %s\n", conf.OutputPath)
 
 	if len(conf.Skip) > 0 {
-		fmt.Printf("Skiping data from tables %s\n", strings.Join(conf.Skip, ", "))
+		bus.Log <- fmt.Sprintf("skipping data from tables %s\n", strings.Join(conf.Skip, ", "))
 	}
 
-	fmt.Printf("Begin processing %s\n\n", conf.InputFile)
+	start := time.Now()
+	bus.Log <- fmt.Sprintf("begin processing %s\n", conf.InputFile)
 	// create a pipeline of goroutines
-	go msds.Producer(file, readCh)
-	go msds.Consumer(readCh, tableNameCh, tableSchemeCh, tableDataCh)
-	go msds.Writer(conf.OutputPath, conf.Skip, tableNameCh, tableSchemeCh, tableDataCh, doneCh)
+	go msds.LineReader(file, bus)
+	go msds.LineParser(bus, conf.Combine)
+	go msds.Writer(conf.OutputPath, conf.Skip, bus)
 
 	// wait for the writer to finish.
-	<-doneCh
+	<-bus.Finished
 
 	if conf.Combine {
-		msds.CombineFiles(conf.CombineFilePath, conf.OutputPath)
+		msds.CombineFiles(conf.CombineFilePath, conf.OutputPath, bus)
 	}
+
+	bus.Log <- fmt.Sprintf("finished in %s", time.Since(start))
+	bus.Log <- ""
+	close(bus.Log)
+	close(bus.Finished)
 }
diff --git a/msds/mysqldumpsplit.go b/msds/mysqldumpsplit.go
index ced5536..61a8b29 100644
--- a/msds/mysqldumpsplit.go
+++ b/msds/mysqldumpsplit.go
@@ -5,47 +5,35 @@ import (
 	"compress/gzip"
 	"fmt"
 	"io/ioutil"
+	"log"
 	"os"
 	"path"
 	"path/filepath"
 	"strings"
+	"time"
 )
 
-const (
-	sentinelString = "****SENTINEL-STRING****"
-	headerData     = `/*!40101 SET @OLD_CHARACTER_SET_CLIENT=@@CHARACTER_SET_CLIENT */;
-/*!40101 SET @OLD_CHARACTER_SET_RESULTS=@@CHARACTER_SET_RESULTS */;
-/*!40101 SET @OLD_COLLATION_CONNECTION=@@COLLATION_CONNECTION */;
-/*!40101 SET NAMES utf8 */;
-/*!40103 SET @OLD_TIME_ZONE=@@TIME_ZONE */;
-/*!40103 SET TIME_ZONE='+00:00' */;
-/*!40014 SET @OLD_UNIQUE_CHECKS=@@UNIQUE_CHECKS, UNIQUE_CHECKS=0 */;
-/*!40014 SET @OLD_FOREIGN_KEY_CHECKS=@@FOREIGN_KEY_CHECKS, FOREIGN_KEY_CHECKS=0 */;
-/*!40101 SET @OLD_SQL_MODE=@@SQL_MODE, SQL_MODE='NO_AUTO_VALUE_ON_ZERO' */;
-/*!40111 SET @OLD_SQL_NOTES=@@SQL_NOTES, SQL_NOTES=0 */;
-`
-)
+const sentinelString = "****SENTINEL-STRING****"
 
-func checkBytes(b *bufio.Reader, buf []byte) bool {
-	m, err := b.Peek(len(buf))
-	if err != nil {
-		return false
-	}
-	for i := range buf {
-		if m[i] != buf[i] {
-			return false
-		}
-	}
-	return true
+// ChannelBus holds all channels shared by the goroutines in the pipeline.
+type ChannelBus struct {
+	Finished    chan bool
+	Log         chan string
+	CurrentLine chan string
+	TableName   chan string
+	TableScheme chan string
+	TableData   chan string
 }
 
 func isGzip(b *bufio.Reader) bool {
-	return checkBytes(b, []byte{0x1f, 0x8b})
+	if m, err := b.Peek(2); err != nil || m[0] != 0x1f || m[1] != 0x8b {
+		return false
+	}
+	return true
 }
 
 func openReader(f *os.File) *bufio.Reader {
 	pageSize := os.Getpagesize() * 2
-	fmt.Println(pageSize)
 	buf := bufio.NewReaderSize(f, pageSize)
 	if isGzip(buf) {
 		gbuf, _ := gzip.NewReader(buf)
@@ -54,73 +42,87 @@ func openReader(f *os.File) *bufio.Reader {
 	return buf
 }
 
-// Producer reads `file` line-by-line and adds it to the `readCh` channel.
+// LineReader reads `file` line-by-line and adds it to the `bus.CurrentLine` channel.
 // Note: This function closes `file`.
-func Producer(file *os.File, readCh chan string) {
+func LineReader(file *os.File, bus ChannelBus) {
 	r := openReader(file)
 	for line, err := r.ReadString('\n'); err == nil; line, err = r.ReadString('\n') {
-		readCh <- line
+		bus.CurrentLine <- line
 	}
 	file.Close()
-	close(readCh)
+	close(bus.CurrentLine)
 }
 
-// Consumer splits the file up and fills the different channels.
-func Consumer(readCh, tableNameCh, tableSchemeCh, tableDataCh chan string) {
-	onTableScheme, onTableData := false, false
-	for line := range readCh {
+// LineParser reads each line from CurrentLine and routes it to the appropriate channel.
+func LineParser(bus ChannelBus, combineFiles bool) {
+	onTableScheme, onTableData, pastHeader := false, false, false
+	headerMetaData := fmt.Sprintf("-- Generated with mysqldumpsplit on %s\n\n", time.Now())
+	for line := range bus.CurrentLine {
+		// The beginning of a mysqldump has some flags at the top of the file. Capture them into a variable.
+		if !pastHeader && strings.Contains(line, "/*!40") {
+			headerMetaData += line
+		}
+
 		if strings.Contains(line, "Table structure for table") {
 			onTableScheme, onTableData = true, false
 			tableName := strings.Replace(line, "-- Table structure for table ", "", 1)
-			tableNameCh <- strings.TrimSpace(strings.Replace(tableName, "`", "", -1))
-			tableSchemeCh <- "--\n"
-			tableSchemeCh <- line
+			bus.TableName <- strings.TrimSpace(strings.Replace(tableName, "`", "", -1))
+			// add the header to each file unless we are combining all of them into one file.
+			if !combineFiles {
+				bus.TableScheme <- headerMetaData
+			} else if !pastHeader {
+				// when combining, add the metadata to the first table only.
+				bus.TableScheme <- headerMetaData
+			}
+
+			pastHeader = true
+			bus.TableScheme <- "\n--\n" + line
 		} else if strings.Contains(line, "LOCK TABLES `") {
 			onTableData, onTableScheme = true, false
-			tableDataCh <- line
+			bus.TableData <- line
 		} else {
 			if onTableScheme {
-				tableSchemeCh <- line
+				bus.TableScheme <- line
 			}
 			if onTableData {
-				tableDataCh <- line
+				bus.TableData <- line
 			}
 
 			if strings.Contains(line, "-- Dumping data for table") {
 				onTableScheme = false
-				tableSchemeCh <- "--\n"
-				tableSchemeCh <- sentinelString
+				bus.TableScheme <- "--\n"
+				bus.TableScheme <- sentinelString
 			} else if strings.Contains(line, "UNLOCK TABLES;") {
 				onTableData = false
-				tableDataCh <- sentinelString
+				bus.TableData <- sentinelString
 			}
 		}
 	}
-	close(tableNameCh)
-	close(tableDataCh)
-	close(tableSchemeCh)
+	close(bus.TableName)
+	close(bus.TableData)
+	close(bus.TableScheme)
 }
 
 // Writer writes the data from the different channels to different files.
-func Writer(outputDir string, skipTables []string, tableNameCh, tableSchemeCh, tableDataCh chan string, doneCh chan bool) {
+func Writer(outputDir string, skipTables []string, bus ChannelBus) {
 	os.Mkdir(outputDir, os.ModePerm)
 	numTables := 0
 
-	for tableName := range tableNameCh {
-		fmt.Printf("extracting table: %s\n", tableName)
+	for tableName := range bus.TableName {
+		bus.Log <- fmt.Sprintf("extracting table: %s\n", tableName)
 		numTables++
 		tablePath := filepath.Join(outputDir, tableName+".sql")
 		tableFile, _ := os.Create(tablePath)
 
-		for tableData := range tableSchemeCh {
+		for tableData := range bus.TableScheme {
 			if tableData == sentinelString {
 				break
 			}
 			tableFile.WriteString(tableData)
 		}
 
-		for tableData := range tableDataCh {
+		for tableData := range bus.TableData {
 			if tableData == sentinelString {
 				break
 			}
@@ -131,18 +133,17 @@ func Writer(outputDir string, skipTables []string, tableNameCh, tableSchemeCh, tableDataCh chan string, doneCh chan bool) {
 		}
 		tableFile.Close()
 	}
-	fmt.Printf("\nExtracted %d tables\n", numTables)
-	doneCh <- true
+	bus.Log <- fmt.Sprintf("extracted %d tables\n", numTables)
+	bus.Finished <- true
 }
 
 // CombineFiles combines all files in a directory into a single file.
-func CombineFiles(filePath, outputDir string) {
+func CombineFiles(filePath, outputDir string, bus ChannelBus) {
 	combineFile, _ := os.Create(filePath)
-	combineFile.WriteString(headerData)
 	cleanUpOutputDir := true
 
 	files, _ := ioutil.ReadDir(outputDir)
-	fmt.Printf("Combining all %d files into %s\n", len(files), filePath)
+	bus.Log <- fmt.Sprintf("Combining all %d files into %s\n", len(files), filePath)
 
 	for _, file := range files {
 		fullPath := path.Join(outputDir, file.Name())
@@ -160,16 +161,26 @@ func CombineFiles(filePath, outputDir string) {
 		// write a newline between each file
 		combineFile.WriteString("\n")
 		// close then delete the table file
+		fileName := sqlFile.Name()
 		sqlFile.Close()
-		os.Remove(sqlFile.Name())
+		os.Remove(fileName)
 	}
 
 	info, _ := combineFile.Stat()
-	fmt.Printf("New file size %s\n", StringifyFileSize(info.Size()))
+	bus.Log <- fmt.Sprintf("New file size %s\n", StringifyFileSize(info.Size()))
 	combineFile.Close()
 
 	if cleanUpOutputDir {
-		fmt.Println("Deleting output directory")
+		bus.Log <- "Deleting output directory"
 		os.RemoveAll(outputDir)
 	}
 }
+
+// Logger reads messages from `bus.Log` and writes them to the standard logger.
+func Logger(bus ChannelBus) {
+	for msg := range bus.Log {
+		if msg != "" {
+			log.Output(3, msg)
+		}
+	}
+}
diff --git a/msds/mysqldumpsplit_test.go b/msds/mysqldumpsplit_test.go
index b357117..c28f14d 100644
--- a/msds/mysqldumpsplit_test.go
+++ b/msds/mysqldumpsplit_test.go
@@ -1 +1,139 @@
 package msds
+
+import (
+	"bufio"
+	"os"
+	"reflect"
+	"testing"
+)
+
+func Test_isGzip(t *testing.T) {
+	type args struct {
+		b *bufio.Reader
+	}
+	tests := []struct {
+		name string
+		args args
+		want bool
+	}{
+		// TODO: Add test cases.
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			if got := isGzip(tt.args.b); got != tt.want {
+				t.Errorf("isGzip() = %v, want %v", got, tt.want)
+			}
+		})
+	}
+}
+
+func Test_openReader(t *testing.T) {
+	type args struct {
+		f *os.File
+	}
+	tests := []struct {
+		name string
+		args args
+		want *bufio.Reader
+	}{
+		// TODO: Add test cases.
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			if got := openReader(tt.args.f); !reflect.DeepEqual(got, tt.want) {
+				t.Errorf("openReader() = %v, want %v", got, tt.want)
+			}
+		})
+	}
+}
+
+func TestLineReader(t *testing.T) {
+	type args struct {
+		file *os.File
+		bus  ChannelBus
+	}
+	tests := []struct {
+		name string
+		args args
+	}{
+		// TODO: Add test cases.
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			LineReader(tt.args.file, tt.args.bus)
+		})
+	}
+}
+
+func TestLineParser(t *testing.T) {
+	type args struct {
+		bus          ChannelBus
+		combineFiles bool
+	}
+	tests := []struct {
+		name string
+		args args
+	}{
+		// TODO: Add test cases.
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			LineParser(tt.args.bus, tt.args.combineFiles)
+		})
+	}
+}
+
+func TestWriter(t *testing.T) {
+	type args struct {
+		outputDir  string
+		skipTables []string
+		bus        ChannelBus
+	}
+	tests := []struct {
+		name string
+		args args
+	}{
+		// TODO: Add test cases.
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			Writer(tt.args.outputDir, tt.args.skipTables, tt.args.bus)
+		})
+	}
+}
+
+func TestCombineFiles(t *testing.T) {
+	type args struct {
+		filePath  string
+		outputDir string
+		bus       ChannelBus
+	}
+	tests := []struct {
+		name string
+		args args
+	}{
+		// TODO: Add test cases.
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			CombineFiles(tt.args.filePath, tt.args.outputDir, tt.args.bus)
+		})
+	}
+}
+
+func TestLogger(t *testing.T) {
+	type args struct {
+		bus ChannelBus
+	}
+	tests := []struct {
+		name string
+		args args
+	}{
+		// TODO: Add test cases.
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			Logger(tt.args.bus)
+		})
+	}
+}
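
Note: the generated test stubs above are all left as `// TODO: Add test cases.`. As a starting point, here is a minimal sketch of how the `Test_isGzip` table could be filled in; the test name and cases are illustrative and not part of this commit. It builds a genuine gzip stream with `compress/gzip` so the two-byte magic header (0x1f, 0x8b) is real, and it relies on the `||` form of `isGzip` above so that an empty reader fails the Peek and returns false instead of panicking.

package msds

import (
	"bufio"
	"bytes"
	"compress/gzip"
	"testing"
)

// Test_isGzip_cases is a hypothetical filling-in of the Test_isGzip table.
func Test_isGzip_cases(t *testing.T) {
	// Build a real gzip stream so the magic header bytes are genuine.
	var gz bytes.Buffer
	zw := gzip.NewWriter(&gz)
	zw.Write([]byte("-- MySQL dump\n"))
	zw.Close()

	tests := []struct {
		name string
		data []byte
		want bool
	}{
		{"gzip stream", gz.Bytes(), true},
		{"plain SQL", []byte("-- MySQL dump 10.13\n"), false},
		{"empty input", nil, false}, // Peek(2) fails, so isGzip must return false
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			r := bufio.NewReader(bytes.NewReader(tt.data))
			if got := isGzip(r); got != tt.want {
				t.Errorf("isGzip() = %v, want %v", got, tt.want)
			}
		})
	}
}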