From 8d8eaa24fa637a9b387080a80bd242095b046bf4 Mon Sep 17 00:00:00 2001 From: Janez Troha Date: Tue, 12 Sep 2023 20:24:44 +0200 Subject: [PATCH] Add support for nested tables ref: https://github.com/teamniteo/operations/issues/2103 --- .editorconfig | 14 ++ .gitignore | 5 +- cli/main.go | 12 +- devenv.nix | 46 +++++++ flake.lock | 232 +++++++++++++++++++++++++++++-- flake.nix | 52 +++---- go.mod | 1 + go.sum | 2 + shell.nix | 1 + subsetter/graph.go | 20 +++ subsetter/graph_test.go | 50 +++++++ subsetter/query.go | 50 ++++++- subsetter/query_test.go | 2 +- subsetter/relations.go | 28 ++-- subsetter/relations_test.go | 14 +- subsetter/sync.go | 263 +++++++++++++++++++++--------------- subsetter/sync_test.go | 26 ---- 17 files changed, 610 insertions(+), 208 deletions(-) create mode 100644 .editorconfig create mode 100644 devenv.nix create mode 100644 shell.nix create mode 100644 subsetter/graph.go create mode 100644 subsetter/graph_test.go diff --git a/.editorconfig b/.editorconfig new file mode 100644 index 0000000..9edcaf5 --- /dev/null +++ b/.editorconfig @@ -0,0 +1,14 @@ +# EditorConfig is awesome: https://EditorConfig.org + +# top-most EditorConfig file +root = true + +# Unix-style newlines with a newline ending every file +[*] +end_of_line = lf +insert_final_newline = true +charset = utf-8 + +# Tab indentation (no size specified) +[Makefile] +indent_style = tab diff --git a/.gitignore b/.gitignore index 01b41f1..a760101 100644 --- a/.gitignore +++ b/.gitignore @@ -2,4 +2,7 @@ bin dist/ *.sql -.nix-profile* \ No newline at end of file +.nix-profile* +*.sh +.devenv +.pre-commit-config.yaml diff --git a/cli/main.go b/cli/main.go index 7adebd8..c8d1cea 100644 --- a/cli/main.go +++ b/cli/main.go @@ -19,7 +19,7 @@ var ( var src = flag.String("src", "", "Source database DSN") var dst = flag.String("dst", "", "Destination database DSN") var fraction = flag.Float64("f", 0.05, "Fraction of rows to copy") -var verbose = flag.Bool("verbose", true, "Show more information 
during sync") +var verbose = flag.Bool("verbose", false, "Show more information during sync") var ver = flag.Bool("v", false, "Release information") var extraInclude arrayExtra var extraExclude arrayExtra @@ -45,11 +45,17 @@ func main() { log.Fatal().Msg("Fraction must be between 0 and 1") } + if *verbose { + zerolog.SetGlobalLevel(zerolog.DebugLevel) + } else { + zerolog.SetGlobalLevel(zerolog.InfoLevel) + } + if len(extraInclude) > 0 { - log.Info().Str("include", extraInclude.String()).Msg("Forcibly including") + log.Info().Str("include", extraInclude.String()).Msg("Forcibly") } if len(extraExclude) > 0 { - log.Info().Str("exclude", extraExclude.String()).Msg("Forcibly ignoring") + log.Info().Str("exclude", extraExclude.String()).Msg("Forcibly") } s, err := subsetter.NewSync(*src, *dst, *fraction, extraInclude, extraExclude, *verbose) diff --git a/devenv.nix b/devenv.nix new file mode 100644 index 0000000..1feb94a --- /dev/null +++ b/devenv.nix @@ -0,0 +1,46 @@ +{ pkgs, lib, rootDir, ... 
}: + +{ + # See https://devenv.sh/getting-started/ for more information + + packages = with pkgs; + [ + entr # Run arbitrary commands when files change + gitAndTools.gh # GitHub CLI + heroku # Heroku CLI + process-compose # Run multiple processes in a single terminal + golangci-lint # Linter for Go + postgresql_15 # PostgreSQL database + eclint # EditorConfig linter and fixer + gnumake # GNU Make + goreleaser # Go binary release tool + ]; + + languages.javascript.enable = true; + languages.go.enable = true; + languages.go.package = pkgs.go_1_21; + + + pre-commit.hooks = { + shellcheck.enable = true; + nixpkgs-fmt.enable = true; + gofmt.enable = true; + shfmt.enable = true; + golangci-lint = { + enable = false; + pass_filenames = false; + name = "golangci-lint"; + files = ".*"; + entry = "bash -c 'cd $(${rootDir})/backend; ${pkgs.golangci-lint}/bin/golangci-lint run --fix'"; + }; + eclint = { + enable = true; + pass_filenames = false; + name = "eclint"; + files = ".*"; + entry = "${pkgs.eclint}/bin/eclint --fix"; + }; + }; + + +} diff --git a/flake.lock b/flake.lock index 1231073..65c3c87 100644 --- a/flake.lock +++ b/flake.lock @@ -1,35 +1,168 @@ { "nodes": { + "devenv": { + "inputs": { + "flake-compat": "flake-compat", + "nix": "nix", + "nixpkgs": [ + "nixpkgs" + ], + "pre-commit-hooks": "pre-commit-hooks" + }, + "locked": { + "lastModified": 1694422554, + "narHash": "sha256-s5NTPzT66yIMmau+ZGP7q9z4NjgceDETL4xZ6HJ/TBg=", + "owner": "cachix", + "repo": "devenv", + "rev": "63d20fe09aa09060ea9ec9bb6d582c025402ba15", + "type": "github" + }, + "original": { + "owner": "cachix", + "repo": "devenv", + "type": "github" + } + }, + "flake-compat": { + "flake": false, + "locked": { + "lastModified": 1673956053, + "narHash": "sha256-4gtG9iQuiKITOjNQQeQIpoIB6b16fm+504Ch3sNKLd8=", + "owner": "edolstra", + "repo": "flake-compat", + "rev": "35bb57c0c8d8b62bbfd284272c928ceb64ddbde9", + "type": "github" + }, + "original": { + "owner": "edolstra", + "repo": "flake-compat", + "type": 
"github" + } + }, "flake-parts": { "inputs": { "nixpkgs-lib": "nixpkgs-lib" }, "locked": { - "lastModified": 1690933134, - "narHash": "sha256-ab989mN63fQZBFrkk4Q8bYxQCktuHmBIBqUG1jl6/FQ=", + "lastModified": 1693611461, + "narHash": "sha256-aPODl8vAgGQ0ZYFIRisxYG5MOGSkIczvu2Cd8Gb9+1Y=", "owner": "hercules-ci", "repo": "flake-parts", - "rev": "59cf3f1447cfc75087e7273b04b31e689a8599fb", + "rev": "7f53fdb7bdc5bb237da7fefef12d099e4fd611ca", + "type": "github" + }, + "original": { + "id": "flake-parts", + "type": "indirect" + } + }, + "flake-root": { + "locked": { + "lastModified": 1692742795, + "narHash": "sha256-f+Y0YhVCIJ06LemO+3Xx00lIcqQxSKJHXT/yk1RTKxw=", + "owner": "srid", + "repo": "flake-root", + "rev": "d9a70d9c7a5fd7f3258ccf48da9335e9b47c3937", "type": "github" }, "original": { + "owner": "srid", + "repo": "flake-root", + "type": "github" + } + }, + "flake-utils": { + "inputs": { + "systems": "systems" + }, + "locked": { + "lastModified": 1685518550, + "narHash": "sha256-o2d0KcvaXzTrPRIo0kOLV0/QXHhDQ5DTi+OxcjO8xqY=", + "owner": "numtide", + "repo": "flake-utils", + "rev": "a1720a10a6cfe8234c0e93907ffe81be440f4cef", + "type": "github" + }, + "original": { + "owner": "numtide", + "repo": "flake-utils", + "type": "github" + } + }, + "gitignore": { + "inputs": { + "nixpkgs": [ + "devenv", + "pre-commit-hooks", + "nixpkgs" + ] + }, + "locked": { + "lastModified": 1660459072, + "narHash": "sha256-8DFJjXG8zqoONA1vXtgeKXy68KdJL5UaXR8NtVMUbx8=", "owner": "hercules-ci", - "repo": "flake-parts", + "repo": "gitignore.nix", + "rev": "a20de23b925fd8264fd7fad6454652e142fd7f73", + "type": "github" + }, + "original": { + "owner": "hercules-ci", + "repo": "gitignore.nix", + "type": "github" + } + }, + "lowdown-src": { + "flake": false, + "locked": { + "lastModified": 1633514407, + "narHash": "sha256-Dw32tiMjdK9t3ETl5fzGrutQTzh2rufgZV4A/BbxuD4=", + "owner": "kristapsdz", + "repo": "lowdown", + "rev": "d2c2b44ff6c27b936ec27358a2653caaef8f73b8", + "type": "github" + }, + 
"original": { + "owner": "kristapsdz", + "repo": "lowdown", + "type": "github" + } + }, + "nix": { + "inputs": { + "lowdown-src": "lowdown-src", + "nixpkgs": [ + "devenv", + "nixpkgs" + ], + "nixpkgs-regression": "nixpkgs-regression" + }, + "locked": { + "lastModified": 1676545802, + "narHash": "sha256-EK4rZ+Hd5hsvXnzSzk2ikhStJnD63odF7SzsQ8CuSPU=", + "owner": "domenkozar", + "repo": "nix", + "rev": "7c91803598ffbcfe4a55c44ac6d49b2cf07a527f", + "type": "github" + }, + "original": { + "owner": "domenkozar", + "ref": "relaxed-flakes", + "repo": "nix", "type": "github" } }, "nixpkgs": { "locked": { - "lastModified": 1691709280, - "narHash": "sha256-zmfH2OlZEXwv572d0g8f6M5Ac6RiO8TxymOpY3uuqrM=", + "lastModified": 1694422566, + "narHash": "sha256-lHJ+A9esOz9vln/3CJG23FV6Wd2OoOFbDeEs4cMGMqc=", "owner": "NixOS", "repo": "nixpkgs", - "rev": "cf73a86c35a84de0e2f3ba494327cf6fb51c0dfd", + "rev": "3a2786eea085f040a66ecde1bc3ddc7099f6dbeb", "type": "github" }, "original": { "owner": "NixOS", - "ref": "nixpkgs-unstable", + "ref": "nixos-unstable", "repo": "nixpkgs", "type": "github" } @@ -37,11 +170,11 @@ "nixpkgs-lib": { "locked": { "dir": "lib", - "lastModified": 1690881714, - "narHash": "sha256-h/nXluEqdiQHs1oSgkOOWF+j8gcJMWhwnZ9PFabN6q0=", + "lastModified": 1693471703, + "narHash": "sha256-0l03ZBL8P1P6z8MaSDS/MvuU8E75rVxe5eE1N6gxeTo=", "owner": "NixOS", "repo": "nixpkgs", - "rev": "9e1960bc196baf6881340d53dccb203a951745a2", + "rev": "3e52e76b70d5508f3cec70b882a29199f4d1ee85", "type": "github" }, "original": { @@ -52,11 +185,88 @@ "type": "github" } }, + "nixpkgs-regression": { + "locked": { + "lastModified": 1643052045, + "narHash": "sha256-uGJ0VXIhWKGXxkeNnq4TvV3CIOkUJ3PAoLZ3HMzNVMw=", + "owner": "NixOS", + "repo": "nixpkgs", + "rev": "215d4d0fd80ca5163643b03a33fde804a29cc1e2", + "type": "github" + }, + "original": { + "owner": "NixOS", + "repo": "nixpkgs", + "rev": "215d4d0fd80ca5163643b03a33fde804a29cc1e2", + "type": "github" + } + }, + "nixpkgs-stable": { + "locked": { + 
"lastModified": 1685801374, + "narHash": "sha256-otaSUoFEMM+LjBI1XL/xGB5ao6IwnZOXc47qhIgJe8U=", + "owner": "NixOS", + "repo": "nixpkgs", + "rev": "c37ca420157f4abc31e26f436c1145f8951ff373", + "type": "github" + }, + "original": { + "owner": "NixOS", + "ref": "nixos-23.05", + "repo": "nixpkgs", + "type": "github" + } + }, + "pre-commit-hooks": { + "inputs": { + "flake-compat": [ + "devenv", + "flake-compat" + ], + "flake-utils": "flake-utils", + "gitignore": "gitignore", + "nixpkgs": [ + "devenv", + "nixpkgs" + ], + "nixpkgs-stable": "nixpkgs-stable" + }, + "locked": { + "lastModified": 1688056373, + "narHash": "sha256-2+SDlNRTKsgo3LBRiMUcoEUb6sDViRNQhzJquZ4koOI=", + "owner": "cachix", + "repo": "pre-commit-hooks.nix", + "rev": "5843cf069272d92b60c3ed9e55b7a8989c01d4c7", + "type": "github" + }, + "original": { + "owner": "cachix", + "repo": "pre-commit-hooks.nix", + "type": "github" + } + }, "root": { "inputs": { + "devenv": "devenv", "flake-parts": "flake-parts", + "flake-root": "flake-root", "nixpkgs": "nixpkgs" } + }, + "systems": { + "locked": { + "lastModified": 1681028828, + "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", + "owner": "nix-systems", + "repo": "default", + "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", + "type": "github" + }, + "original": { + "owner": "nix-systems", + "repo": "default", + "type": "github" + } } }, "root": "root", diff --git a/flake.nix b/flake.nix index 76b7e50..3c7237b 100644 --- a/flake.nix +++ b/flake.nix @@ -1,46 +1,30 @@ +# Minimal flake layer to support nix-shell and devenv { - nixConfig = { - allowed-users = [ "@wheel" "@staff" ]; # allow compiling on every device/machine - }; inputs = { - nixpkgs.url = "github:NixOS/nixpkgs/nixpkgs-unstable"; - flake-parts.url = "github:hercules-ci/flake-parts"; + nixpkgs.url = "github:NixOS/nixpkgs/nixos-unstable"; + flake-root.url = "github:srid/flake-root"; + devenv = { + url = "github:cachix/devenv"; + inputs.nixpkgs.follows = "nixpkgs"; + }; }; - outputs = 
inputs@{ self, nixpkgs, flake-parts, ... }: - flake-parts.lib.mkFlake { inherit inputs; } { - systems = nixpkgs.lib.systems.flakeExposed; + outputs = inputs@{ flake-parts, nixpkgs, ... }: + flake-parts.lib.mkFlake { inherit inputs; } { imports = [ - inputs.flake-parts.flakeModules.easyOverlay + inputs.devenv.flakeModule + inputs.flake-root.flakeModule ]; - - perSystem = { config, self', inputs', pkgs, system, ... }: + systems = nixpkgs.lib.systems.flakeExposed; + perSystem = { config, self', inputs', pkgs, system, lib, ... }: let - - # dev env without compile tools - stdenvMinimal = pkgs.stdenvNoCC.override { - cc = null; - preHook = ""; - allowedRequisites = null; - initialPath = pkgs.lib.filter - (a: pkgs.lib.hasPrefix "coreutils" a.name) - pkgs.stdenvNoCC.initialPath; - extraNativeBuildInputs = [ ]; - }; + rootDir = lib.getExe config.flake-root.package; in { - devShells.default = pkgs.mkShell { - stdenv = stdenvMinimal; - packages = with pkgs; [ - go - goreleaser - golangci-lint - postgresql_15 - process-compose - nixpkgs-fmt - pgweb - ]; - }; + devenv.shells.default = + (import ./devenv.nix { + inherit inputs pkgs lib rootDir; + }); }; }; } diff --git a/go.mod b/go.mod index e75a2ee..2e0783b 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.20 require ( github.com/pkg/errors v0.9.1 github.com/rs/zerolog v1.30.0 + github.com/stevenle/topsort v0.2.0 ) require ( diff --git a/go.sum b/go.sum index 131fcde..afe55b4 100644 --- a/go.sum +++ b/go.sum @@ -24,6 +24,8 @@ github.com/rs/zerolog v1.30.0 h1:SymVODrcRsaRaSInD9yQtKbtWqwsfoPcRff/oRXLj4c= github.com/rs/zerolog v1.30.0/go.mod h1:/tk+P47gFdPXq4QYjvCmT5/Gsug2nagsFWBWhAiSi1w= github.com/samber/lo v1.38.1 h1:j2XEAqXKb09Am4ebOg31SpvzUTTs6EN3VfgeLUhPdXM= github.com/samber/lo v1.38.1/go.mod h1:+m/ZKRl6ClXCE2Lgf3MsQlWfh4bn1bz6CXEOxnEXnEA= +github.com/stevenle/topsort v0.2.0 h1:LLWgtp34HPX6/RBDRS0kElVxGOTzGBLI1lSAa5Lb46k= +github.com/stevenle/topsort v0.2.0/go.mod h1:ck2WG2/ZrOr6dLApQ/5Xrqy5wv3T0qhKYWE7r9tkibc= 
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= diff --git a/shell.nix b/shell.nix new file mode 100644 index 0000000..e093f6d --- /dev/null +++ b/shell.nix @@ -0,0 +1 @@ +(builtins.getFlake ("git+file://" + toString ./. + "?shallow=1")).devShells.${builtins.currentSystem}.default diff --git a/subsetter/graph.go b/subsetter/graph.go new file mode 100644 index 0000000..f13a515 --- /dev/null +++ b/subsetter/graph.go @@ -0,0 +1,20 @@ +package subsetter + +import ( + "slices" + + "github.com/stevenle/topsort" +) + +func TableGraph(primary string, relations []Relation) (l []string, e error) { + graph := topsort.NewGraph() // Create a new graph + + for _, r := range relations { + if !r.IsSelfRelated() { + graph.AddEdge(r.PrimaryTable, r.ForeignTable) + } + } + l, e = graph.TopSort(primary) + slices.Reverse(l) + return +} diff --git a/subsetter/graph_test.go b/subsetter/graph_test.go new file mode 100644 index 0000000..7b414ff --- /dev/null +++ b/subsetter/graph_test.go @@ -0,0 +1,50 @@ +package subsetter + +import ( + "testing" + + "github.com/samber/lo" +) + +func TestTableGraph(t *testing.T) { + + relations := []Relation{ + {"blog_networks", "id", "blogs", "blog_id"}, + {"users", "id", "blog_networks", "user_id"}, + {"users", "id", "blogs", "user_id"}, + {"users", "id", "users", "owner_id"}, // self reference + {"users", "id", "collaborator_api_keys", "user_id"}, + {"blogs", "id", "backups", "blog_id"}, + {"blogs", "id", "blog_imports", "blog_id"}, + {"blogs", "id", "blog_imports", "blog_id"}, + {"blogs", "id", "cleanup_notification", "blog_id"}, + } + + got, _ := TableGraph("users", relations) + + if want, _ := lo.Last(got); want != "users" { + t.Fatalf("TableGraph() = %v, want %v", got, "users") + } + +} +func TestTableGraphNnoRelation(t *testing.T) { 
+ + relations := []Relation{ + {"blog_networks", "id", "blogs", "blog_id"}, + {"users", "id", "blog_networks", "user_id"}, + {"users", "id", "blogs", "user_id"}, + {"users", "id", "users", "owner_id"}, // self reference + {"users", "id", "collaborator_api_keys", "user_id"}, + {"blogs", "id", "backups", "blog_id"}, + {"blogs", "id", "blog_imports", "blog_id"}, + {"blogs", "id", "blog_imports", "blog_id"}, + {"blogs", "id", "cleanup_notification", "blog_id"}, + } + + got, _ := TableGraph("simple", relations) + + if want, _ := lo.Last(got); want != "simple" { + t.Fatalf("TableGraph() = %v, want %v", got, "simple") + } + +} diff --git a/subsetter/query.go b/subsetter/query.go index 1061443..24709fb 100644 --- a/subsetter/query.go +++ b/subsetter/query.go @@ -7,6 +7,7 @@ import ( "strings" "github.com/jackc/pgx/v5/pgxpool" + "github.com/rs/zerolog/log" "github.com/samber/lo" ) @@ -16,16 +17,27 @@ type Table struct { Relations []Relation } -func (t *Table) RelationNames() (names string) { - rel := lo.Map(t.Relations, func(r Relation, _ int) string { +// RelationNames returns a list of relation names in human readable format. +func (t *Table) RelationNames() (names []string) { + names = lo.Map(t.Relations, func(r Relation, _ int) string { return r.PrimaryTable + ">" + r.PrimaryColumn }) - if len(rel) > 0 { - return strings.Join(rel, ", ") + + return +} + +// IsSelfRelated returns true if a table is self related. +func (t *Table) IsSelfRelated() bool { + for _, r := range t.Relations { + if r.IsSelfRelated() { + return true + } } - return "none" + return false } +// GetTablesWithRows returns a list of tables with the number of rows in each table. +// Warning reltuples used to dermine size is an estimate of the number of rows in the table and can be zero for small tables. 
func GetTablesWithRows(conn *pgxpool.Pool) (tables []Table, err error) { q := `SELECT relname, @@ -41,14 +53,30 @@ func GetTablesWithRows(conn *pgxpool.Pool) (tables []Table, err error) { var table Table if err := rows.Scan(&table.Name, &table.Rows); err == nil { + // skip system tables that are marked public + if strings.HasPrefix(table.Name, "pg_") { + continue + } + // fix for tables with no rows if table.Rows == -1 { table.Rows = 0 } + + // Do a precise count for small tables + if table.Rows == 0 { + table.Rows, err = CountRows(table.Name, conn) + if err != nil { + return nil, err + } + } + + // Get relations table.Relations, err = GetRelations(table.Name, conn) if err != nil { return nil, err } + tables = append(tables, table) } @@ -58,6 +86,7 @@ func GetTablesWithRows(conn *pgxpool.Pool) (tables []Table, err error) { return } +// GetKeys returns a list of keys from a query. func GetKeys(q string, conn *pgxpool.Pool) (ids []string, err error) { rows, err := conn.Query(context.Background(), q) for rows.Next() { @@ -73,6 +102,7 @@ func GetKeys(q string, conn *pgxpool.Pool) (ids []string, err error) { return } +// GetPrimaryKeyName returns the name of the primary key for a table. func GetPrimaryKeyName(table string, conn *pgxpool.Pool) (name string, err error) { q := fmt.Sprintf(`SELECT a.attname FROM pg_index i @@ -90,12 +120,14 @@ func GetPrimaryKeyName(table string, conn *pgxpool.Pool) (name string, err error return } +// DeleteRows deletes rows from a table. func DeleteRows(table string, where string, conn *pgxpool.Pool) (err error) { q := fmt.Sprintf(`DELETE FROM %s WHERE %s`, table, where) _, err = conn.Exec(context.Background(), q) return } +// CopyQueryToString copies a query to a string. 
func CopyQueryToString(query string, conn *pgxpool.Pool) (result string, err error) { q := fmt.Sprintf(`copy (%s) to stdout`, query) var buff bytes.Buffer @@ -112,11 +144,14 @@ func CopyQueryToString(query string, conn *pgxpool.Pool) (result string, err err return } -func CopyTableToString(table string, limit int, where string, conn *pgxpool.Pool) (result string, err error) { - q := fmt.Sprintf(`SELECT * FROM %s %s order by random() limit %d`, table, where, limit) +// CopyTableToString copies a table to a string. +func CopyTableToString(table string, limit string, where string, conn *pgxpool.Pool) (result string, err error) { + q := fmt.Sprintf(`SELECT * FROM %s %s order by random() %s`, table, where, limit) + log.Debug().Msgf("Query: %s", q) return CopyQueryToString(q, conn) } +// CopyStringToTable copies a string to a table. func CopyStringToTable(table string, data string, conn *pgxpool.Pool) (err error) { q := fmt.Sprintf(`copy %s from stdin`, table) var buff bytes.Buffer @@ -134,6 +169,7 @@ func CopyStringToTable(table string, data string, conn *pgxpool.Pool) (err error return } +// CountRows returns the number of rows in a table. 
func CountRows(s string, conn *pgxpool.Pool) (count int, err error) { q := "SELECT count(*) FROM " + s err = conn.QueryRow(context.Background(), q).Scan(&count) diff --git a/subsetter/query_test.go b/subsetter/query_test.go index 6bad69d..aff3aea 100644 --- a/subsetter/query_test.go +++ b/subsetter/query_test.go @@ -53,7 +53,7 @@ func TestCopyTableToString(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - gotResult, err := CopyTableToString(tt.table, 10, "", tt.conn) + gotResult, err := CopyTableToString(tt.table, "", "", tt.conn) if (err != nil) != tt.wantErr { t.Errorf("CopyTableToString() error = %v, wantErr %v", err, tt.wantErr) return diff --git a/subsetter/relations.go b/subsetter/relations.go index 71dc9ff..5f69019 100644 --- a/subsetter/relations.go +++ b/subsetter/relations.go @@ -17,6 +17,10 @@ type Relation struct { ForeignColumn string } +func (r *Relation) IsSelfRelated() bool { + return r.PrimaryTable == r.ForeignTable +} + func (r *Relation) Query(subset []string) string { subset = lo.Map(subset, func(s string, _ int) string { @@ -30,13 +34,15 @@ func (r *Relation) PrimaryQuery() string { return fmt.Sprintf(`SELECT %s FROM %s`, r.ForeignColumn, r.ForeignTable) } -type RelationInfo struct { - TableName string +// RelationRaw is a raw representation of a relation in the database. +type RelationRaw struct { + PrimaryTable string ForeignTable string SQL string } -func (r *RelationInfo) toRelation() Relation { +// toRelation converts a RelationRaw to a Relation. 
+func (r *RelationRaw) toRelation() Relation { var rel Relation re := regexp.MustCompile(`FOREIGN KEY \((\w+)\) REFERENCES (\w+)\((\w+)\).*`) matches := re.FindStringSubmatch(r.SQL) @@ -45,7 +51,7 @@ func (r *RelationInfo) toRelation() Relation { rel.ForeignTable = matches[2] rel.ForeignColumn = matches[3] } - rel.PrimaryTable = r.TableName + rel.PrimaryTable = r.PrimaryTable return rel } @@ -53,7 +59,7 @@ func (r *RelationInfo) toRelation() Relation { func GetRelations(table string, conn *pgxpool.Pool) (relations []Relation, err error) { q := `SELECT - conrelid::regclass AS table_name, + conrelid::regclass AS primary_table, confrelid::regclass AS refrerenced_table, pg_get_constraintdef(c.oid, TRUE) AS sql FROM @@ -70,17 +76,17 @@ func GetRelations(table string, conn *pgxpool.Pool) (relations []Relation, err e defer rows.Close() for rows.Next() { - var rel RelationInfo + var rel RelationRaw - err = rows.Scan(&rel.TableName, &rel.ForeignTable, &rel.SQL) + err = rows.Scan(&rel.PrimaryTable, &rel.ForeignTable, &rel.SQL) if err != nil { return } - relations = append(relations, rel.toRelation()) + if table == rel.PrimaryTable { + relations = append(relations, rel.toRelation()) + } + } - relations = lo.Filter(relations, func(rel Relation, _ int) bool { - return rel.ForeignTable == table - }) return } diff --git a/subsetter/relations_test.go b/subsetter/relations_test.go index edb3534..4134edb 100644 --- a/subsetter/relations_test.go +++ b/subsetter/relations_test.go @@ -46,33 +46,33 @@ func TestRelation_Query(t *testing.T) { } } -func TestRelationInfo_toRelation(t *testing.T) { +func TestRelationRaw_toRelation(t *testing.T) { tests := []struct { name string - fields RelationInfo + fields RelationRaw want Relation }{ { "Simple", - RelationInfo{"relation", "simple", "FOREIGN KEY (simple_id) REFERENCES simple(id)"}, + RelationRaw{"relation", "simple", "FOREIGN KEY (simple_id) REFERENCES simple(id)"}, Relation{"relation", "simple_id", "simple", "id"}, }, { "Simple with 
cascade", - RelationInfo{"relation", "simple", "FOREIGN KEY (simple_id) REFERENCES simple(id) ON DELETE CASCADE"}, + RelationRaw{"relation", "simple", "FOREIGN KEY (simple_id) REFERENCES simple(id) ON DELETE CASCADE"}, Relation{"relation", "simple_id", "simple", "id"}, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - r := &RelationInfo{ - TableName: tt.fields.TableName, + r := &RelationRaw{ + PrimaryTable: tt.fields.PrimaryTable, ForeignTable: tt.fields.ForeignTable, SQL: tt.fields.SQL, } if got := r.toRelation(); !reflect.DeepEqual(got, tt.want) { - t.Errorf("RelationInfo.toRelation() = %v, want %v", spew.Sdump(got), spew.Sdump(tt.want)) + t.Errorf("RelationRaw.toRelation() = %v, want %v", spew.Sdump(got), spew.Sdump(tt.want)) } }) } diff --git a/subsetter/sync.go b/subsetter/sync.go index a148d08..4f509fb 100644 --- a/subsetter/sync.go +++ b/subsetter/sync.go @@ -3,7 +3,6 @@ package subsetter import ( "context" "fmt" - "sort" "strings" "github.com/jackc/pgx/v5/pgconn" @@ -38,6 +37,18 @@ func (r *Rule) Query() string { return fmt.Sprintf("SELECT * FROM %s WHERE %s", r.Table, r.Where) } +func (r *Rule) Copy(s *Sync) (err error) { + log.Debug().Str("query", r.Where).Msgf("Copying forced rows for table %s", r.Table) + var data string + if data, err = CopyQueryToString(r.Query(), s.source); err != nil { + return errors.Wrapf(err, "Error copying forced rows for table %s", r.Table) + } + if err = CopyStringToTable(r.Table, data, s.destination); err != nil { + return errors.Wrapf(err, "Error inserting forced rows for table %s", r.Table) + } + return +} + func NewSync(source string, target string, fraction float64, include []Rule, exclude []Rule, verbose bool) (*Sync, error) { src, err := pgxpool.New(context.Background(), source) if err != nil { @@ -75,103 +86,167 @@ func (s *Sync) Close() { } // copyTableData copies the data from a table in the source database to the destination database -func copyTableData(table Table, source *pgxpool.Pool, 
destination *pgxpool.Pool) (err error) { +func copyTableData(table Table, relatedQueries []string, withLimit bool, source *pgxpool.Pool, destination *pgxpool.Pool) (err error) { // Backtrace the inserted ids from main table to related table - - // Get primary keys - primaryKeyName, err := GetPrimaryKeyName(table.Name, source) - if err != nil { - return errors.Wrapf(err, "Error getting primary key for %s", table.Name) + subselectQeury := "" + if len(relatedQueries) > 0 { + subselectQeury = "WHERE " + strings.Join(relatedQueries, " AND ") } - var ignoredPrimaryKeys []string - if ignoredPrimaryKeys, err = GetKeys(fmt.Sprintf("SELECT %s FROM %s", primaryKeyName, table.Name), destination); err != nil { - return errors.Wrapf(err, "Error getting primary keys for %s", table.Name) - } - ignoredPrimaryQuery := "" - if len(ignoredPrimaryKeys) > 0 { - keys := lo.Map(ignoredPrimaryKeys, func(key string, _ int) string { - return QuoteString(key) - }) - ignoredPrimaryQuery = fmt.Sprintf("WHERE %s NOT IN (%s)", primaryKeyName, strings.Join(keys, ",")) + limit := "" + if withLimit { + limit = fmt.Sprintf("LIMIT %d", table.Rows) } var data string - if data, err = CopyTableToString(table.Name, table.Rows, ignoredPrimaryQuery, source); err != nil { - log.Error().Err(err).Msgf("Error copying table %s", table.Name) + if data, err = CopyTableToString(table.Name, limit, subselectQeury, source); err != nil { + log.Error().Err(err).Msgf("Error getting table data for %s", table.Name) return } if err = CopyStringToTable(table.Name, data, destination); err != nil { - log.Error().Err(err).Msgf("Error pasting table %s", table.Name) + log.Error().Err(err).Msgf("Error pushing table data for %s", table.Name) return } return + } -// ViableSubset returns a subset of tables that can be copied to the destination database -func ViableSubset(tables []Table) (subset []Table) { +func relatedQueriesBuilder( + depth *int, + tables []Table, + relation Relation, + table Table, + source *pgxpool.Pool, + 
destination *pgxpool.Pool, + visitedTables *[]string, + relatedQueries *[]string, +) (err error) { + +retry: + q := fmt.Sprintf(`SELECT %s FROM %s`, relation.ForeignColumn, relation.ForeignTable) + log.Debug().Str("query", q).Msgf("Getting keys for %s from target", table.Name) + + if primaryKeys, err := GetKeys(q, destination); err != nil { + log.Error().Err(err).Msgf("Error getting keys for %s", table.Name) + return err + } else { + if len(primaryKeys) == 0 { + log.Warn().Int("depth", *depth).Msgf("No keys found for %s", relation.ForeignTable) + missingTable := lo.Filter(tables, func(table Table, _ int) bool { + return table.Name == relation.ForeignTable + })[0] + RelationalCopy(depth, tables, missingTable, visitedTables, source, destination) + *depth++ + log.Debug().Int("depth", *depth).Msgf("Retrying keys for %s", relation.ForeignTable) + if *depth < 1 { + goto retry + } else { + log.Warn().Int("depth", *depth).Msgf("Max depth reached for %s", relation.ForeignTable) + return errors.New("Max depth reached") + } - // Filter out tables with no rows - subset = lo.Filter(tables, func(table Table, _ int) bool { return table.Rows > 0 }) + } else { + *depth = 0 + keys := lo.Map(primaryKeys, func(key string, _ int) string { + return QuoteString(key) + }) + rq := fmt.Sprintf(`%s IN (%s)`, relation.PrimaryColumn, strings.Join(keys, ",")) + *relatedQueries = append(*relatedQueries, rq) + } + } + return nil +} - // Ignore tables with relations to tables - // they are populated by the primary table - tablesWithRelations := lo.Filter(tables, func(table Table, _ int) bool { - return len(table.Relations) > 0 - }) +func RelationalCopy( + depth *int, + tables []Table, + table Table, + visitedTables *[]string, + source *pgxpool.Pool, + destination *pgxpool.Pool, +) error { + log.Debug().Str("table", table.Name).Msg("Preparing") + + relatedTables, err := TableGraph(table.Name, table.Relations) + if err != nil { + return errors.Wrapf(err, "Error sorting tables from graph") + } + 
log.Debug().Strs("tables", relatedTables).Msgf("Order of copy") - var relatedTables []string - for _, table := range tablesWithRelations { - for _, relation := range table.Relations { - if table.Name != relation.PrimaryTable { - relatedTables = append(relatedTables, relation.PrimaryTable) + for _, tableName := range relatedTables { + + if lo.Contains(*visitedTables, tableName) { + continue + } + relatedTable := lo.Filter(tables, func(table Table, _ int) bool { + return table.Name == tableName + })[0] + *visitedTables = append(*visitedTables, relatedTable.Name) + // Use realized query to get priamry keys that are already in the destination for all related tables + + // Selection query for this table + relatedQueries := []string{} + + for _, relation := range relatedTable.Relations { + relatedQueriesBuilder(depth, tables, relation, relatedTable, source, destination, visitedTables, &relatedQueries) + } + + if len(relatedQueries) > 0 { + log.Debug().Str("table", relatedTable.Name).Strs("relatedQueries", relatedQueries).Msg("Copying with RelationalCopy") + } + + if err = copyTableData(relatedTable, relatedQueries, false, source, destination); err != nil { + if condition, ok := err.(*pgconn.PgError); ok && condition.Code == "23503" { // foreign key violation + RelationalCopy(depth, tables, relatedTable, visitedTables, source, destination) } } - } - subset = lo.Filter(subset, func(table Table, _ int) bool { - return !lo.Contains(relatedTables, table.Name) - }) + } - sort.Slice(subset, func(i, j int) bool { - return len(subset[i].Relations) < len(subset[j].Relations) - }) - return + return nil } // CopyTables copies the data from a list of tables in the source database to the destination database func (s *Sync) CopyTables(tables []Table) (err error) { - excludedTables := lo.Map(s.exclude, func(rule Rule, _ int) string { - return rule.Table - }) - - tables = lo.Filter(tables, func(table Table, _ int) bool { - return !lo.Contains(excludedTables, table.Name) - }) - - for _, 
table := range tables { + visitedTables := []string{} + // Copy tables without relations first + for _, table := range lo.Filter(tables, func(table Table, _ int) bool { + return len(table.Relations) == 0 + }) { + log.Info().Str("table", table.Name).Msg("Copying") + if err = copyTableData(table, []string{}, true, s.source, s.destination); err != nil { + return errors.Wrapf(err, "Error copying table %s", table.Name) + } for _, include := range s.include { if include.Table == table.Name { - log.Info().Str("query", include.Where).Msgf("Copying forced rows for table %s", table.Name) - var data string - if data, err = CopyQueryToString(include.Query(), s.source); err != nil { - return errors.Wrapf(err, "Error copying forced rows for table %s", table.Name) - } - if err = CopyStringToTable(table.Name, data, s.destination); err != nil { - return errors.Wrapf(err, "Error inserting forced rows for table %s", table.Name) - } + include.Copy(s) } } + + visitedTables = append(visitedTables, table.Name) } - for _, table := range tables { - log.Info().Msgf("Preparing %s", table.Name) - if err = copyTableData(table, s.source, s.destination); err != nil { - return errors.Wrapf(err, "Error copying table %s", table.Name) + // Prevent infinite loop, by setting max depth + depth := 0 + // Copy tables with relations + for _, complexTable := range lo.Filter(tables, func(table Table, _ int) bool { + return len(table.Relations) > 0 + }) { + log.Info().Str("table", complexTable.Name).Msg("Copying") + RelationalCopy(&depth, tables, complexTable, &visitedTables, s.source, s.destination) + + for _, include := range s.include { + if include.Table == complexTable.Name { + log.Warn().Str("table", complexTable.Name).Msgf("Copying forced rows for relational table is not supported.") + } } + } + // Remove excluded rows and print reports + for _, table := range tables { + // to ensure no data is in excluded tables for _, exclude := range s.exclude { if exclude.Table == table.Name { 
log.Info().Str("query", exclude.Where).Msgf("Deleting excluded rows for table %s", table.Name) @@ -183,67 +258,41 @@ func (s *Sync) CopyTables(tables []Table) (err error) { count, _ := CountRows(table.Name, s.destination) log.Info().Int("count", count).Msgf("Copied table %s", table.Name) - - for _, relation := range table.Relations { - if lo.Contains(excludedTables, relation.PrimaryTable) { - continue - } - - // Backtrace the inserted ids from main table to related table - log.Info().Msgf("Preparing %s for %s", relation.PrimaryTable, table.Name) - var pKeys []string - if pKeys, err = GetKeys(relation.PrimaryQuery(), s.destination); err != nil { - return errors.Wrapf(err, "Error getting primary keys for %s", relation.PrimaryTable) - } - var data string - if data, err = CopyQueryToString(relation.Query(pKeys), s.source); err != nil { - return errors.Wrapf(err, "Error copying related table %s", relation.PrimaryTable) - } - if err = CopyStringToTable(relation.PrimaryTable, data, s.destination); err != nil { - if condition, ok := err.(*pgconn.PgError); ok && condition.Code == "23503" { - log.Warn().AnErr("sql", err).Msgf("Skipping %s because of cyclic foreign key", relation.PrimaryTable) - err = nil - } else if condition, ok := err.(*pgconn.PgError); ok && condition.Code == "23505" { - log.Warn().AnErr("sql", err).Msgf("Skipping %s because of present foreign key", relation.PrimaryTable) - err = nil - } else { - return errors.Wrapf(err, "Error inserting related table %s", relation.PrimaryTable) - } - - } - count, _ := CountRows(relation.PrimaryTable, s.destination) - log.Info().Int("count", count).Msgf("Copied %s for %s", relation.PrimaryTable, table.Name) - } } + return } // Sync copies a subset of tables from source to destination func (s *Sync) Sync() (err error) { var tables []Table + + // Get all tables with rows if tables, err = GetTablesWithRows(s.source); err != nil { return } - var allTables []Table - if allTables = GetTargetSet(s.fraction, tables); err != nil { 
+ // Filter out tables that are not in the include list + ruleExcludedTables := lo.Map(s.exclude, func(rule Rule, _ int) string { + return rule.Table + }) + tables = lo.Filter(tables, func(table Table, _ int) bool { + return !lo.Contains(ruleExcludedTables, table.Name) // excluded tables + }) + + // Calculate fraction to be copied over + if tables = GetTargetSet(s.fraction, tables); err != nil { return } - subset := ViableSubset(allTables) - if s.verbose { - for _, t := range subset { - log.Info(). - Str("table", t.Name). - Int("rows", t.Rows). - Str("related", t.RelationNames()). - Msg("Prepared for sync") - - } + log.Info().Strs("tables", lo.Map(tables, func(table Table, _ int) string { + return table.Name + })).Msg("Tables to be copied") } - if err = s.CopyTables(subset); err != nil { + // Copy tables + if err = s.CopyTables(tables); err != nil { return } diff --git a/subsetter/sync_test.go b/subsetter/sync_test.go index b3f9c8f..bffdb7b 100644 --- a/subsetter/sync_test.go +++ b/subsetter/sync_test.go @@ -1,35 +1,9 @@ package subsetter import ( - "reflect" "testing" ) -func TestViableSubset(t *testing.T) { - tests := []struct { - name string - tables []Table - wantSubset []Table - }{ - { - "Simple", - []Table{{"simple", 10, []Relation{}}}, - []Table{{"simple", 10, []Relation{}}}, - }, - { - "No rows", - []Table{{"simple", 0, []Relation{}}}, - []Table{}}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if gotSubset := ViableSubset(tt.tables); !reflect.DeepEqual(gotSubset, tt.wantSubset) { - t.Errorf("ViableSubset() = %v, want %v", gotSubset, tt.wantSubset) - } - }) - } -} - func TestSync_CopyTables(t *testing.T) { src := getTestConnection() dst := getTestConnectionDst()