-
Notifications
You must be signed in to change notification settings - Fork 34
/
Copy pathloader.go
153 lines (132 loc) · 4.27 KB
/
loader.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
package gonymizer
import (
"database/sql"
"encoding/csv"
"errors"
"fmt"
"io"
"os"
"strconv"
"time"
log "github.com/sirupsen/logrus"
)
// LoadFile loads the SQL file at filePath into the database described by conf.
//
// To minimize downtime during the reload, the file is first imported into a
// newly created temporary database named "<DefaultDBName>_gonymizer_loading".
// Once the import succeeds, all connections to the main database are killed,
// the main database is renamed to "<DefaultDBName>_old_<unix timestamp>", and
// the temporary database is renamed to take its place.
//
// An error is returned if the temporary database already exists (which
// suggests another run is in progress or a previous run did not finish) or if
// any of the steps above fails.
func LoadFile(conf PGConfig, filePath string) (err error) {
	var (
		dbExists   bool
		mainConn   *sql.DB
		psqlConn   *sql.DB
		tempDbConf PGConfig
		psqlDbConf PGConfig
	)

	// Build the temporary database config. This is where we will load the new
	// data to minimize downtime during the reload.
	tempDbConf = conf
	tempDbConf.DefaultDBName = conf.DefaultDBName + "_gonymizer_loading"

	mainConn, err = OpenDB(conf)
	if err != nil {
		return err
	}
	defer mainConn.Close()

	// Refuse to continue if a temporary database from a previous or concurrent
	// run is still present.
	log.Infof("Checking to see if database '%s' exists", tempDbConf.DefaultDBName)
	dbExists, err = CheckIfDbExists(mainConn, tempDbConf.DefaultDBName)
	if err != nil {
		return err
	}
	if dbExists {
		return fmt.Errorf("Found a previous version of the %s database. Is there another copy "+
			"of Gonymizer running?", tempDbConf.DefaultDBName)
	}

	// Create temp database
	log.Info("Creating database: ", tempDbConf.DefaultDBName)
	if err = CreateDatabase(tempDbConf); err != nil {
		log.Error("Unable to create database: ", tempDbConf.DefaultDBName)
		return err
	}

	// Load the file into the temp database.
	// NOTE(review): the temp database is not removed when the import fails, so
	// the next run will stop at the existence check above until it is dropped
	// manually.
	log.Infof("Reloading database file '%s' -> '%s' ", filePath, tempDbConf.DefaultDBName)
	if err = SQLCommandFile(tempDbConf, filePath, true); err != nil {
		// Log and return instead of log.Fatalf: Fatalf exits the process, which
		// would skip the deferred Close and never reach the return.
		log.Errorf("There was an error importing '%s' to: %s", filePath, tempDbConf.DefaultDBName)
		return err
	}

	// Connect to the "postgres" database so the main and temp databases can be
	// renamed while nothing is connected to them.
	psqlDbConf = conf
	psqlDbConf.DefaultDBName = "postgres"
	psqlConn, err = OpenDB(psqlDbConf)
	if err != nil {
		return err
	}
	defer psqlConn.Close()

	// Kill db connections so we can rename the database
	log.Info("Killing all connections on database: ", conf.DefaultDBName)
	if err = KillDatabaseConnections(psqlConn, conf.DefaultDBName); err != nil {
		log.Error("Unable to kill connections on database: ", conf.DefaultDBName)
		return err
	}

	// Rename main database -> old database
	oldDbName := conf.DefaultDBName + "_old_" + strconv.FormatInt(time.Now().Unix(), 10)
	log.Infof("Renaming database '%s' -> '%s'", conf.DefaultDBName, oldDbName)
	if err = RenameDatabase(psqlConn, conf.DefaultDBName, oldDbName); err != nil {
		return err
	}

	// Rename temp database -> main database
	log.Infof("Renaming database '%s' -> '%s'", tempDbConf.DefaultDBName, conf.DefaultDBName)
	if err = RenameDatabase(psqlConn, tempDbConf.DefaultDBName, conf.DefaultDBName); err != nil {
		// The first rename already succeeded, so tell the operator where the
		// original data went.
		log.Errorf("Unable to rename database '%s' -> '%s'; the previous database was renamed to '%s'",
			tempDbConf.DefaultDBName, conf.DefaultDBName, oldDbName)
		return err
	}
	return nil
}
// VerifyRowCount compares the per-table row counts of the database described
// by conf with the expected counts listed in the CSV file at filePath (see
// command/dump).
//
// Every CSV record must have exactly three columns: schema, table, count. A
// count that differs from the database is only reported through a warning log
// entry; it does not cause an error to be returned.
//
// An error is returned when the database counts cannot be fetched, the file
// cannot be opened or parsed, a record does not have three columns, or a count
// is not a valid integer.
func VerifyRowCount(conf PGConfig, filePath string) (err error) {
	rowObjs, err := GetTableRowCountsInDB(conf, "", []string{})
	if err != nil {
		return err
	}

	// Index the database counts as schema -> table -> count for quick lookups.
	dbRowCount := make(map[string]map[string]int)
	for _, row := range *rowObjs {
		tables, ok := dbRowCount[*row.SchemaName]
		if !ok {
			tables = make(map[string]int)
			dbRowCount[*row.SchemaName] = tables
		}
		tables[*row.TableName] = *row.Count
	}

	// Now read in the CSV file and compare it to our DB counts.
	file, err := os.Open(filePath)
	if err != nil {
		return err
	}
	defer file.Close()

	csvReader := csv.NewReader(file)

	// recordNum counts CSV records (1-based) so errors can point at the
	// offending entry.
	for recordNum := 1; ; recordNum++ {
		csvRow, err := csvReader.Read()
		if errors.Is(err, io.EOF) {
			break
		}
		if err != nil {
			return err
		}

		// Verify we have schema, table, count
		if len(csvRow) != 3 {
			return fmt.Errorf("CSV should contain exactly 3 columns, but has %d (record %d)",
				len(csvRow), recordNum)
		}

		schema, table := csvRow[0], csvRow[1]
		count, err := strconv.Atoi(csvRow[2])
		if err != nil {
			return fmt.Errorf("invalid row count %q for %s.%s (record %d): %w",
				csvRow[2], schema, table, recordNum, err)
		}

		// A schema or table missing from the database reads as 0 here, so it is
		// reported as a mismatch unless the CSV count is also 0.
		if dbCount := dbRowCount[schema][table]; dbCount != count {
			log.Warnf("Production row counts do not match: (prod) %s.%s = %d / %d",
				schema,
				table,
				count,
				dbCount,
			)
		}
	}
	return nil
}