diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml index de76f3840..f99d4a6a3 100644 --- a/.github/workflows/release.yaml +++ b/.github/workflows/release.yaml @@ -20,7 +20,7 @@ jobs: continue-on-error: true run: ./scripts/release-notes.sh ${{github.ref_name}} > ${{runner.temp}}/release_notes.txt - name: Run GoReleaser - uses: goreleaser/goreleaser-action@v4 + uses: goreleaser/goreleaser-action@v5 with: distribution: goreleaser version: latest diff --git a/CHANGELOG.md b/CHANGELOG.md index 76b340779..273df2f48 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,18 @@ adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). ## [Unreleased] - Added YDB support +## [v3.15.1] - 2023-10-10 + +- Fix regression that prevented registering Go migrations that didn't have the corresponding files + available in the filesystem. (#588) + - If Go migrations have been registered globally, but there are no .go files in the filesystem, + **always include** them. + - If Go migrations have been registered, and there are .go files in the filesystem, **only + include** those migrations. This was the original motivation behind #553. + - If there are .go files in the filesystem but not registered, **raise an error**. This is to + prevent accidentally adding valid looking Go migration files without explicitly registering + them. + ## [v3.15.0] - 2023-08-12 - Fix `sqlparser` to avoid skipping the last statement when it's not terminated with a semicolon @@ -50,7 +62,8 @@ adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). - Add new `context.Context`-aware functions and methods, for both sql and go migrations. - Return error when no migration files found or dir is not a directory. -[Unreleased]: https://github.com/pressly/goose/compare/v3.15.0...HEAD +[Unreleased]: https://github.com/pressly/goose/compare/v3.15.1...HEAD +[v3.15.1]: https://github.com/pressly/goose/compare/v3.15.0...v3.15.1 [v3.15.0]: https://github.com/pressly/goose/compare/v3.14.0...v3.15.0 [v3.14.0]: https://github.com/pressly/goose/compare/v3.13.4...v3.14.0 [v3.13.4]: https://github.com/pressly/goose/compare/v3.13.1...v3.13.4 diff --git a/Makefile b/Makefile index 415117de4..08e6187e6 100644 --- a/Makefile +++ b/Makefile @@ -34,6 +34,9 @@ tools: test-packages: go test $(GO_TEST_FLAGS) $$(go list ./... | grep -v -e /tests -e /bin -e /cmd -e /examples) +test-packages-short: + go test -test.short $(GO_TEST_FLAGS) $$(go list ./... | grep -v -e /tests -e /bin -e /cmd -e /examples) + test-e2e: test-e2e-postgres test-e2e-mysql test-e2e-clickhouse test-e2e-vertica test-e2e-ydb test-e2e-postgres: diff --git a/README.md b/README.md index 9e9e0c4cb..3ab30907c 100644 --- a/README.md +++ b/README.md @@ -35,17 +35,23 @@ Goose supports [embedding SQL migrations](#embedded-sql-migrations), which means # Install - $ go install github.com/pressly/goose/v3/cmd/goose@latest +```shell +go install github.com/pressly/goose/v3/cmd/goose@latest +``` This will install the `goose` binary to your `$GOPATH/bin` directory. For a lite version of the binary without DB connection dependent commands, use the exclusive build tags: - $ go build -tags='no_postgres no_mysql no_sqlite3 no_ydb' -o goose ./cmd/goose +```shell +go build -tags='no_postgres no_mysql no_sqlite3 no_ydb' -o goose ./cmd/goose +``` For macOS users `goose` is available as a [Homebrew Formulae](https://formulae.brew.sh/formula/goose#default): - $ brew install goose +```shell +brew install goose +``` See the docs for more [installation instructions](https://pressly.github.io/goose/installation/). diff --git a/create_test.go b/create_test.go index fddf48d85..34791cc65 100644 --- a/create_test.go +++ b/create_test.go @@ -11,6 +11,9 @@ import ( func TestSequential(t *testing.T) { t.Parallel() + if testing.Short() { + t.Skip("skip long running test") + } dir := t.TempDir() defer os.Remove("./bin/create-goose") // clean up diff --git a/dialect.go b/dialect.go index 440d8e4e4..9cbeeb977 100644 --- a/dialect.go +++ b/dialect.go @@ -6,6 +6,20 @@ import ( "github.com/pressly/goose/v3/internal/dialect" ) +// Dialect is the type of database dialect. +type Dialect string + +const ( + DialectClickHouse Dialect = "clickhouse" + DialectMSSQL Dialect = "mssql" + DialectMySQL Dialect = "mysql" + DialectPostgres Dialect = "postgres" + DialectRedshift Dialect = "redshift" + DialectSQLite3 Dialect = "sqlite3" + DialectTiDB Dialect = "tidb" + DialectVertica Dialect = "vertica" +) + func init() { store, _ = dialect.NewStore(dialect.Postgres) } diff --git a/fix_test.go b/fix_test.go index 5c982dbe8..6a5e0842b 100644 --- a/fix_test.go +++ b/fix_test.go @@ -11,6 +11,9 @@ import ( func TestFix(t *testing.T) { t.Parallel() + if testing.Short() { + t.Skip("skip long running test") + } dir := t.TempDir() defer os.Remove("./bin/fix-goose") // clean up diff --git a/go.mod b/go.mod index 77ee0aa28..498e3eeda 100644 --- a/go.mod +++ b/go.mod @@ -3,31 +3,34 @@ module github.com/pressly/goose/v3 go 1.19 require ( - github.com/ClickHouse/clickhouse-go/v2 v2.13.0 + github.com/ClickHouse/clickhouse-go/v2 v2.14.2 github.com/go-sql-driver/mysql v1.7.1 github.com/jackc/pgx/v5 v5.4.3 github.com/microsoft/go-mssqldb v1.6.0 github.com/ory/dockertest/v3 v3.10.0 + github.com/sethvargo/go-retry v0.2.4 github.com/vertica/vertica-sql-go v1.3.3 github.com/ydb-platform/ydb-go-sdk/v3 v3.52.2 github.com/ziutek/mymysql v1.5.4 - modernc.org/sqlite v1.25.0 + go.uber.org/multierr v1.11.0 + golang.org/x/sync v0.4.0 + modernc.org/sqlite v1.26.0 ) require ( github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 // indirect - github.com/ClickHouse/ch-go v0.57.0 // indirect + github.com/ClickHouse/ch-go v0.58.2 // indirect github.com/Microsoft/go-winio v0.6.1 // indirect github.com/Nvveen/Gotty v0.0.0-20120604004816-cd527374f1e5 // indirect github.com/andybalholm/brotli v1.0.5 // indirect github.com/cenkalti/backoff/v4 v4.2.1 // indirect - github.com/containerd/continuity v0.4.1 // indirect - github.com/docker/cli v24.0.2+incompatible // indirect - github.com/docker/docker v24.0.2+incompatible // indirect + github.com/containerd/continuity v0.4.2 // indirect + github.com/docker/cli v24.0.6+incompatible // indirect + github.com/docker/docker v24.0.6+incompatible // indirect github.com/docker/go-connections v0.4.0 // indirect github.com/docker/go-units v0.5.0 // indirect github.com/dustin/go-humanize v1.0.1 // indirect - github.com/elastic/go-sysinfo v1.11.0 // indirect + github.com/elastic/go-sysinfo v1.11.1 // indirect github.com/elastic/go-windows v1.0.1 // indirect github.com/go-faster/city v1.0.1 // indirect github.com/go-faster/errors v0.6.1 // indirect @@ -37,25 +40,25 @@ require ( github.com/golang-sql/sqlexp v0.1.0 // indirect github.com/golang/protobuf v1.5.2 // indirect github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 // indirect - github.com/google/uuid v1.3.0 // indirect + github.com/google/uuid v1.3.1 // indirect github.com/imdario/mergo v0.3.16 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect github.com/joeshaw/multierror v0.0.0-20140124173710-69b34d4ec901 // indirect github.com/jonboulle/clockwork v0.3.0 // indirect github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 // indirect - github.com/klauspost/compress v1.16.7 // indirect + github.com/klauspost/compress v1.17.0 // indirect github.com/lib/pq v1.10.9 // indirect github.com/mattn/go-isatty v0.0.19 // indirect github.com/mitchellh/mapstructure v1.5.0 // indirect github.com/moby/term v0.5.0 // indirect github.com/opencontainers/go-digest v1.0.0 // indirect - github.com/opencontainers/image-spec v1.1.0-rc4 // indirect - github.com/opencontainers/runc v1.1.7 // indirect + github.com/opencontainers/image-spec v1.1.0-rc5 // indirect + github.com/opencontainers/runc v1.1.9 // indirect github.com/paulmach/orb v0.10.0 // indirect github.com/pierrec/lz4/v4 v4.1.18 // indirect github.com/pkg/errors v0.9.1 // indirect - github.com/prometheus/procfs v0.11.0 // indirect + github.com/prometheus/procfs v0.12.0 // indirect github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect github.com/segmentio/asm v1.2.0 // indirect github.com/shopspring/decimal v1.3.1 // indirect @@ -64,15 +67,15 @@ require ( github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 // indirect github.com/xeipuuv/gojsonschema v1.2.0 // indirect github.com/ydb-platform/ydb-go-genproto v0.0.0-20230801151335-81e01be38941 // indirect - go.opentelemetry.io/otel v1.16.0 // indirect - go.opentelemetry.io/otel/trace v1.16.0 // indirect - golang.org/x/crypto v0.12.0 // indirect + go.opentelemetry.io/otel v1.19.0 // indirect + go.opentelemetry.io/otel/trace v1.19.0 // indirect + golang.org/x/crypto v0.13.0 // indirect golang.org/x/mod v0.12.0 // indirect golang.org/x/net v0.14.0 // indirect golang.org/x/sync v0.3.0 // indirect - golang.org/x/sys v0.11.0 // indirect - golang.org/x/text v0.12.0 // indirect - golang.org/x/tools v0.10.0 // indirect + golang.org/x/sys v0.13.0 // indirect + golang.org/x/text v0.13.0 // indirect + golang.org/x/tools v0.13.0 // indirect google.golang.org/genproto v0.0.0-20230110181048-76db0878b65f // indirect google.golang.org/grpc v1.53.0 // indirect google.golang.org/protobuf v1.28.1 // indirect @@ -81,12 +84,12 @@ require ( howett.net/plist v1.0.0 // indirect lukechampine.com/uint128 v1.3.0 // indirect modernc.org/cc/v3 v3.41.0 // indirect - modernc.org/ccgo/v3 v3.16.14 // indirect + modernc.org/ccgo/v3 v3.16.15 // indirect modernc.org/libc v1.24.1 // indirect modernc.org/mathutil v1.6.0 // indirect - modernc.org/memory v1.6.0 // indirect + modernc.org/memory v1.7.2 // indirect modernc.org/opt v0.1.3 // indirect - modernc.org/strutil v1.1.3 // indirect + modernc.org/strutil v1.2.0 // indirect modernc.org/token v1.1.0 // indirect ) diff --git a/go.sum b/go.sum index c96a15d70..09366a716 100644 --- a/go.sum +++ b/go.sum @@ -13,6 +13,10 @@ github.com/ClickHouse/ch-go v0.57.0 h1:X/QmUmFhpUvLgPSQb7fWOSi1wvqGn6tJ7w2a59c4x github.com/ClickHouse/ch-go v0.57.0/go.mod h1:DR3iBn7OrrDj+KeUp1LbdxLEUDbW+5Qwdl/qkc+PQ+Y= github.com/ClickHouse/clickhouse-go/v2 v2.13.0 h1:oP1OlTQIbQKKLnqLzyDhiyNFvN3pbOtM+e/3qdexG9k= github.com/ClickHouse/clickhouse-go/v2 v2.13.0/go.mod h1:xyL0De2K54/n+HGsdtPuyYJq76wefafaHfGUXTDEq/0= +github.com/ClickHouse/ch-go v0.58.2 h1:jSm2szHbT9MCAB1rJ3WuCJqmGLi5UTjlNu+f530UTS0= +github.com/ClickHouse/ch-go v0.58.2/go.mod h1:Ap/0bEmiLa14gYjCiRkYGbXvbe8vwdrfTYWhsuQ99aw= +github.com/ClickHouse/clickhouse-go/v2 v2.14.2 h1:iYGP2bgPYJ33y6rCfZTQAPrnHt8wmsM3Ut4cLoYhWY0= +github.com/ClickHouse/clickhouse-go/v2 v2.14.2/go.mod h1:ZLn63wODwGxVdnGB0EIYmFL5tjtlLcLBuwQUH6B2sYk= github.com/Microsoft/go-winio v0.6.1 h1:9/kr64B9VUZrLm5YYwbGtUJnMgqWVOdUAXu6Migciow= github.com/Microsoft/go-winio v0.6.1/go.mod h1:LRdKpFKfdobln8UmuiYcKPot9D2v6svN5+sAH+4kjUM= github.com/Nvveen/Gotty v0.0.0-20120604004816-cd527374f1e5 h1:TngWCqHvy9oXAN6lEVMRuU21PR1EtLVZJmdB18Gu3Rw= @@ -33,15 +37,17 @@ github.com/cncf/xds/go v0.0.0-20211001041855-01bcc9b48dfe/go.mod h1:eXthEFrGJvWH github.com/cncf/xds/go v0.0.0-20211011173535-cb28da3451f1/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/containerd/continuity v0.4.1 h1:wQnVrjIyQ8vhU2sgOiL5T07jo+ouqc2bnKsv5/EqGhU= github.com/containerd/continuity v0.4.1/go.mod h1:F6PTNCKepoxEaXLQp3wDAjygEnImnZ/7o4JzpodfroQ= +github.com/containerd/continuity v0.4.2 h1:v3y/4Yz5jwnvqPKJJ+7Wf93fyWoCB3F5EclWG023MDM= +github.com/containerd/continuity v0.4.2/go.mod h1:F6PTNCKepoxEaXLQp3wDAjygEnImnZ/7o4JzpodfroQ= github.com/creack/pty v1.1.18 h1:n56/Zwd5o6whRC5PMGretI4IdRLlmBXYNjScPaBgsbY= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/docker/cli v24.0.2+incompatible h1:QdqR7znue1mtkXIJ+ruQMGQhpw2JzMJLRXp6zpzF6tM= -github.com/docker/cli v24.0.2+incompatible/go.mod h1:JLrzqnKDaYBop7H2jaqPtU4hHvMKP+vjCwu2uszcLI8= +github.com/docker/cli v24.0.6+incompatible h1:fF+XCQCgJjjQNIMjzaSmiKJSCcfcXb3TWTcc7GAneOY= +github.com/docker/cli v24.0.6+incompatible/go.mod h1:JLrzqnKDaYBop7H2jaqPtU4hHvMKP+vjCwu2uszcLI8= github.com/docker/distribution v2.8.2+incompatible h1:T3de5rq0dB1j30rp0sA2rER+m322EBzniBPB6ZIzuh8= -github.com/docker/docker v24.0.2+incompatible h1:eATx+oLz9WdNVkQrr0qjQ8HvRJ4bOOxfzEo8R+dA3cg= -github.com/docker/docker v24.0.2+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk= +github.com/docker/docker v24.0.6+incompatible h1:hceabKCtUgDqPu+qm0NgsaXf28Ljf4/pWFL7xjWWDgE= +github.com/docker/docker v24.0.6+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk= github.com/docker/go-connections v0.4.0 h1:El9xVISelRB7BuFusrZozjnkIM5YnzCViNKohAFqRJQ= github.com/docker/go-connections v0.4.0/go.mod h1:Gbd7IOopHjR8Iph03tsViu4nIes5XhDvyHbTtUxmeec= github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4= @@ -49,8 +55,8 @@ github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDD github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= github.com/elastic/go-sysinfo v1.8.1/go.mod h1:JfllUnzoQV/JRYymbH3dO1yggI3mV2oTKSXsDHM+uIM= -github.com/elastic/go-sysinfo v1.11.0 h1:QW+6BF1oxBoAprH3w2yephF7xLkrrSXj7gl2xC2BM4w= -github.com/elastic/go-sysinfo v1.11.0/go.mod h1:6KQb31j0QeWBDF88jIdWSxE8cwoOB9tO4Y4osN7Q70E= +github.com/elastic/go-sysinfo v1.11.1 h1:g9mwl05njS4r69TisC+vwHWTSKywZFYYUu3so3T/Lao= +github.com/elastic/go-sysinfo v1.11.1/go.mod h1:6KQb31j0QeWBDF88jIdWSxE8cwoOB9tO4Y4osN7Q70E= github.com/elastic/go-windows v1.0.0/go.mod h1:TsU0Nrp7/y3+VwE82FoZF8gC/XFg/Elz6CcloAxnPgU= github.com/elastic/go-windows v1.0.1 h1:AlYZOldA+UJ0/2nBuqWdo90GFCgG9xuyw9SYzGUtJm0= github.com/elastic/go-windows v1.0.1/go.mod h1:FoVvqWSun28vaDQPbj2Elfc0JahhPB7WQEGa3c814Ss= @@ -110,6 +116,8 @@ github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+ github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw= +github.com/google/uuid v1.3.1 h1:KjJaJ9iWZ3jOFZIf1Lqf4laDRCasjl0BCmnEGxkdLb4= +github.com/google/uuid v1.3.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/imdario/mergo v0.3.16 h1:wwQJbIsHYGMUyLSPrEq1CT16AhnhNJQ51+4fdHUnCl4= github.com/imdario/mergo v0.3.16/go.mod h1:WBLT9ZmE3lPoWsEzCh9LPo3TiwVN+ZKEjmz+hD27ysY= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= @@ -128,8 +136,8 @@ github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51/go.mod h1:C github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/klauspost/compress v1.13.6/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk= -github.com/klauspost/compress v1.16.7 h1:2mk3MPGNzKyxErAw8YaohYh69+pa4sIQSC0fPGCFR9I= -github.com/klauspost/compress v1.16.7/go.mod h1:ntbaceVETuRiXiv4DpjP66DpAtAGkEQskQzEyD//IeE= +github.com/klauspost/compress v1.17.0 h1:Rnbp4K9EjcDuVuHtd0dgA4qNuv9yKDYKK1ulpJwgrqM= +github.com/klauspost/compress v1.17.0/go.mod h1:ntbaceVETuRiXiv4DpjP66DpAtAGkEQskQzEyD//IeE= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= @@ -150,10 +158,10 @@ github.com/moby/term v0.5.0/go.mod h1:8FzsFHVUBGZdbDsJw/ot+X+d5HLUbvklYLJ9uGfcI3 github.com/montanaflynn/stats v0.0.0-20171201202039-1bf9dbcd8cbe/go.mod h1:wL8QJuTMNUDYhXwkmfOly8iTdp5TEcJFWZD2D7SIkUc= github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U= github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM= -github.com/opencontainers/image-spec v1.1.0-rc4 h1:oOxKUJWnFC4YGHCCMNql1x4YaDfYBTS5Y4x/Cgeo1E0= -github.com/opencontainers/image-spec v1.1.0-rc4/go.mod h1:X4pATf0uXsnn3g5aiGIsVnJBR4mxhKzfwmvK/B2NTm8= -github.com/opencontainers/runc v1.1.7 h1:y2EZDS8sNng4Ksf0GUYNhKbTShZJPJg1FiXJNH/uoCk= -github.com/opencontainers/runc v1.1.7/go.mod h1:CbUumNnWCuTGFukNXahoo/RFBZvDAgRh/smNYNOhA50= +github.com/opencontainers/image-spec v1.1.0-rc5 h1:Ygwkfw9bpDvs+c9E34SdgGOj41dX/cbdlwvlWt0pnFI= +github.com/opencontainers/image-spec v1.1.0-rc5/go.mod h1:X4pATf0uXsnn3g5aiGIsVnJBR4mxhKzfwmvK/B2NTm8= +github.com/opencontainers/runc v1.1.9 h1:XR0VIHTGce5eWPkaPesqTBrhW2yAcaraWfsEalNwQLM= +github.com/opencontainers/runc v1.1.9/go.mod h1:CbUumNnWCuTGFukNXahoo/RFBZvDAgRh/smNYNOhA50= github.com/ory/dockertest/v3 v3.10.0 h1:4K3z2VMe8Woe++invjaTB7VRyQXQy5UY+loujO4aNE4= github.com/ory/dockertest/v3 v3.10.0/go.mod h1:nr57ZbRWMqfsdGdFNLHz5jjNdDb7VVFnzAeW1n5N1Lg= github.com/paulmach/orb v0.10.0 h1:guVYVqzxHE/CQ1KpfGO077TR0ATHSNjp4s6XGLn3W9s= @@ -169,14 +177,16 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/prometheus/procfs v0.0.0-20190425082905-87a4384529e0/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA= -github.com/prometheus/procfs v0.11.0 h1:5EAgkfkMl659uZPbe9AS2N68a7Cc1TJbPEuGzFuRbyk= -github.com/prometheus/procfs v0.11.0/go.mod h1:nwNm2aOCAYw8uTR/9bWRREkZFxAUcWzPHWJq+XBB/FM= +github.com/prometheus/procfs v0.12.0 h1:jluTpSng7V9hY0O2R9DzzJHYb2xULk9VTR1V1R/k6Bo= +github.com/prometheus/procfs v0.12.0/go.mod h1:pcuDEFsWDnvcgNzo4EEweacyhjeA9Zk3cnaOZAZEfOo= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ= github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ= github.com/segmentio/asm v1.2.0 h1:9BQrFxC+YOHJlTlHGkTrFWf59nbL3XnCoFLTwDCI7ys= github.com/segmentio/asm v1.2.0/go.mod h1:BqMnlJP91P8d+4ibuonYZw9mfnzI9HfxselHZr5aAcs= +github.com/sethvargo/go-retry v0.2.4 h1:T+jHEQy/zKJf5s95UkguisicE0zuF9y7+/vgz08Ocec= +github.com/sethvargo/go-retry v0.2.4/go.mod h1:1afjQuvh7s4gflMObvjLPaWgluLLyhA1wmVZ6KLpICw= github.com/shopspring/decimal v1.3.1 h1:2Usl1nmF/WZucqkFZhnfFYxxxu8LG21F6nPQBE5gKV8= github.com/shopspring/decimal v1.3.1/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o= github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= @@ -215,6 +225,12 @@ go.opentelemetry.io/otel v1.16.0/go.mod h1:vl0h9NUa1D5s1nv3A5vZOYWn8av4K8Ml6JDeH go.opentelemetry.io/otel/trace v1.16.0 h1:8JRpaObFoW0pxuVPapkgH8UhHQj+bJW8jJsCZEu5MQs= go.opentelemetry.io/otel/trace v1.16.0/go.mod h1:Yt9vYq1SdNz3xdjZZK7wcXv1qv2pwLkqr2QVwea0ef0= go.opentelemetry.io/proto/otlp v0.7.0/go.mod h1:PqfVotwruBrMGOCsRd/89rSnXhoiJIqeYNgFYFoEGnI= +go.opentelemetry.io/otel v1.19.0 h1:MuS/TNf4/j4IXsZuJegVzI1cwut7Qc00344rgH7p8bs= +go.opentelemetry.io/otel v1.19.0/go.mod h1:i0QyjOq3UPoTzff0PJB2N66fb4S0+rSbSB15/oyH9fY= +go.opentelemetry.io/otel/trace v1.19.0 h1:DFVQmlVbfVeOuBRrwdtaehRrWiL1JoVs9CPIQ1Dzxpg= +go.opentelemetry.io/otel/trace v1.19.0/go.mod h1:mfaSyvGyEJEI0nyV2I4qhNQnbBOUUmYZpYojqMnX2vo= +go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= +go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= @@ -225,6 +241,8 @@ golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= +golang.org/x/crypto v0.13.0 h1:mvySKfSWJ+UKUii46M40LOvyWfN0s2U+46/jDd0e6Ck= +golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc= golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.12.0 h1:rmsUpXtvNzj340zd98LZ4KntptpfRHwpFOHG188oHXc= @@ -246,6 +264,7 @@ golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAG golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/net v0.15.0 h1:ugBLEUaxABaB5AJqW9enI0ACdci2RUd4eP51NTBvuJ8= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -254,6 +273,8 @@ golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.3.0 h1:ftCYgMx6zT/asHUrPw8BLLscYtGznsLAnjq5RH9P66E= golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sync v0.4.0 h1:zxkM55ReGkDlKSM+Fu41A+zmbZuaPVbGMzvvdUPznYQ= +golang.org/x/sync v0.4.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -267,15 +288,15 @@ golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.11.0 h1:eG7RXZHdqOJ1i+0lgLgCpSXAp6M3LYlAo6osgSi0xOM= -golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE= +golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= -golang.org/x/text v0.12.0 h1:k+n5B8goJNdU7hSvEtMUz3d1Q6D/XW4COJSJR6fN0mc= -golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= +golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k= +golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= @@ -284,8 +305,8 @@ golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBn golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= -golang.org/x/tools v0.10.0 h1:tvDr/iQoUqNdohiYm0LmmKcBk+q86lb9EprIUFhHHGg= -golang.org/x/tools v0.10.0/go.mod h1:UJwyiVBsOA2uwvK/e5OY3GTpDUJriEd+/YlqAwLPmyM= +golang.org/x/tools v0.13.0 h1:Iey4qkscZuv0VvIt8E0neZjtPVQFSc870HQ448QgEmQ= +golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= @@ -344,22 +365,22 @@ lukechampine.com/uint128 v1.3.0 h1:cDdUVfRwDUDovz610ABgFD17nXD4/uDgVHl2sC3+sbo= lukechampine.com/uint128 v1.3.0/go.mod h1:c4eWIwlEGaxC/+H1VguhU4PHXNWDCDMUlWdIWl2j1gk= modernc.org/cc/v3 v3.41.0 h1:QoR1Sn3YWlmA1T4vLaKZfawdVtSiGx8H+cEojbC7v1Q= modernc.org/cc/v3 v3.41.0/go.mod h1:Ni4zjJYJ04CDOhG7dn640WGfwBzfE0ecX8TyMB0Fv0Y= -modernc.org/ccgo/v3 v3.16.14 h1:af6KNtFgsVmnDYrWk3PQCS9XT6BXe7o3ZFJKkIKvXNQ= -modernc.org/ccgo/v3 v3.16.14/go.mod h1:mPDSujUIaTNWQSG4eqKw+atqLOEbma6Ncsa94WbC9zo= +modernc.org/ccgo/v3 v3.16.15 h1:KbDR3ZAVU+wiLyMESPtbtE/Add4elztFyfsWoNTgxS0= +modernc.org/ccgo/v3 v3.16.15/go.mod h1:yT7B+/E2m43tmMOT51GMoM98/MtHIcQQSleGnddkUNI= modernc.org/ccorpus v1.11.6 h1:J16RXiiqiCgua6+ZvQot4yUuUy8zxgqbqEEUuGPlISk= modernc.org/httpfs v1.0.6 h1:AAgIpFZRXuYnkjftxTAZwMIiwEqAfk8aVB2/oA6nAeM= modernc.org/libc v1.24.1 h1:uvJSeCKL/AgzBo2yYIPPTy82v21KgGnizcGYfBHaNuM= modernc.org/libc v1.24.1/go.mod h1:FmfO1RLrU3MHJfyi9eYYmZBfi/R+tqZ6+hQ3yQQUkak= modernc.org/mathutil v1.6.0 h1:fRe9+AmYlaej+64JsEEhoWuAYBkOtQiMEU7n/XgfYi4= modernc.org/mathutil v1.6.0/go.mod h1:Ui5Q9q1TR2gFm0AQRqQUaBWFLAhQpCwNcuhBOSedWPo= -modernc.org/memory v1.6.0 h1:i6mzavxrE9a30whzMfwf7XWVODx2r5OYXvU46cirX7o= -modernc.org/memory v1.6.0/go.mod h1:PkUhL0Mugw21sHPeskwZW4D6VscE/GQJOnIpCnW6pSU= +modernc.org/memory v1.7.2 h1:Klh90S215mmH8c9gO98QxQFsY+W451E8AnzjoE2ee1E= +modernc.org/memory v1.7.2/go.mod h1:NO4NVCQy0N7ln+T9ngWqOQfi7ley4vpwvARR+Hjw95E= modernc.org/opt v0.1.3 h1:3XOZf2yznlhC+ibLltsDGzABUGVx8J6pnFMS3E4dcq4= modernc.org/opt v0.1.3/go.mod h1:WdSiB5evDcignE70guQKxYUl14mgWtbClRi5wmkkTX0= -modernc.org/sqlite v1.25.0 h1:AFweiwPNd/b3BoKnBOfFm+Y260guGMF+0UFk0savqeA= -modernc.org/sqlite v1.25.0/go.mod h1:FL3pVXie73rg3Rii6V/u5BoHlSoyeZeIgKZEgHARyCU= -modernc.org/strutil v1.1.3 h1:fNMm+oJklMGYfU9Ylcywl0CO5O6nTfaowNsh2wpPjzY= -modernc.org/strutil v1.1.3/go.mod h1:MEHNA7PdEnEwLvspRMtWTNnp2nnyvMfkimT1NKNAGbw= +modernc.org/sqlite v1.26.0 h1:SocQdLRSYlA8W99V8YH0NES75thx19d9sB/aFc4R8Lw= +modernc.org/sqlite v1.26.0/go.mod h1:FL3pVXie73rg3Rii6V/u5BoHlSoyeZeIgKZEgHARyCU= +modernc.org/strutil v1.2.0 h1:agBi9dp1I+eOnxXeiZawM8F4LawKv4NzGWSaLfyeNZA= +modernc.org/strutil v1.2.0/go.mod h1:/mdcBmfOibveCTBxUl5B5l6W+TTH1FXPLHZE6bTosX0= modernc.org/tcl v1.15.2 h1:C4ybAYCGJw968e+Me18oW55kD/FexcHbqH2xak1ROSY= modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y= modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM= diff --git a/internal/migrationstats/migrationstats_test.go b/internal/migrationstats/migrationstats_test.go index 67a65a3cf..26c49fd38 100644 --- a/internal/migrationstats/migrationstats_test.go +++ b/internal/migrationstats/migrationstats_test.go @@ -8,6 +8,7 @@ import ( ) func TestParsingGoMigrations(t *testing.T) { + t.Parallel() tests := []struct { name string input string @@ -38,6 +39,7 @@ func TestParsingGoMigrations(t *testing.T) { } func TestParsingGoMigrationsError(t *testing.T) { + t.Parallel() _, err := parseGoFile(strings.NewReader(emptyInit)) check.HasError(t, err) check.Contains(t, err.Error(), "no registered goose functions") diff --git a/internal/provider/collect.go b/internal/provider/collect.go new file mode 100644 index 000000000..fd7d63e75 --- /dev/null +++ b/internal/provider/collect.go @@ -0,0 +1,232 @@ +package provider + +import ( + "errors" + "fmt" + "io/fs" + "os" + "path/filepath" + "sort" + "strconv" + "strings" +) + +func NewSource(t MigrationType, fullpath string, version int64) Source { + return Source{ + Type: t, + Fullpath: fullpath, + Version: version, + } +} + +// fileSources represents a collection of migration files on the filesystem. +type fileSources struct { + sqlSources []Source + goSources []Source +} + +// TODO(mf): remove? +func (s *fileSources) lookup(t MigrationType, version int64) *Source { + switch t { + case TypeGo: + for _, source := range s.goSources { + if source.Version == version { + return &source + } + } + case TypeSQL: + for _, source := range s.sqlSources { + if source.Version == version { + return &source + } + } + } + return nil +} + +// collectFileSources scans the file system for migration files that have a numeric prefix (greater +// than one) followed by an underscore and a file extension of either .go or .sql. fsys may be nil, +// in which case an empty fileSources is returned. +// +// If strict is true, then any error parsing the numeric component of the filename will result in an +// error. The file is skipped otherwise. +// +// This function DOES NOT parse SQL migrations or merge registered Go migrations. It only collects +// migration sources from the filesystem. +func collectFileSources(fsys fs.FS, strict bool, excludes map[string]bool) (*fileSources, error) { + if fsys == nil { + return new(fileSources), nil + } + sources := new(fileSources) + versionToBaseLookup := make(map[int64]string) // map[version]filepath.Base(fullpath) + for _, pattern := range []string{ + "*.sql", + "*.go", + } { + files, err := fs.Glob(fsys, pattern) + if err != nil { + return nil, fmt.Errorf("failed to glob pattern %q: %w", pattern, err) + } + for _, fullpath := range files { + base := filepath.Base(fullpath) + // Skip explicit excludes or Go test files. + if excludes[base] || strings.HasSuffix(base, "_test.go") { + continue + } + // If the filename has a valid looking version of the form: NUMBER_.{sql,go}, then use + // that as the version. Otherwise, ignore it. This allows users to have arbitrary + // filenames, but still have versioned migrations within the same directory. For + // example, a user could have a helpers.go file which contains unexported helper + // functions for migrations. + version, err := NumericComponent(base) + if err != nil { + if strict { + return nil, fmt.Errorf("failed to parse numeric component from %q: %w", base, err) + } + continue + } + // Ensure there are no duplicate versions. + if existing, ok := versionToBaseLookup[version]; ok { + return nil, fmt.Errorf("found duplicate migration version %d:\n\texisting:%v\n\tcurrent:%v", + version, + existing, + base, + ) + } + switch filepath.Ext(base) { + case ".sql": + sources.sqlSources = append(sources.sqlSources, NewSource(TypeSQL, fullpath, version)) + case ".go": + sources.goSources = append(sources.goSources, NewSource(TypeGo, fullpath, version)) + default: + // Should never happen since we already filtered out all other file types. + return nil, fmt.Errorf("unknown migration type: %s", base) + } + // Add the version to the lookup map. + versionToBaseLookup[version] = base + } + } + return sources, nil +} + +func merge(sources *fileSources, registerd map[int64]*goMigration) ([]*migration, error) { + var migrations []*migration + migrationLookup := make(map[int64]*migration) + // Add all SQL migrations to the list of migrations. + for _, source := range sources.sqlSources { + m := &migration{ + Source: source, + SQL: nil, // SQL migrations are parsed lazily. + } + migrations = append(migrations, m) + migrationLookup[source.Version] = m + } + // If there are no Go files in the filesystem and no registered Go migrations, return early. + if len(sources.goSources) == 0 && len(registerd) == 0 { + return migrations, nil + } + // Return an error if the given sources contain a versioned Go migration that has not been + // registered. This is a sanity check to ensure users didn't accidentally create a valid looking + // Go migration file on disk and forget to register it. + // + // This is almost always a user error. + var unregistered []string + for _, s := range sources.goSources { + if _, ok := registerd[s.Version]; !ok { + unregistered = append(unregistered, s.Fullpath) + } + } + if len(unregistered) > 0 { + return nil, unregisteredError(unregistered) + } + // Add all registered Go migrations to the list of migrations, checking for duplicate versions. + // + // Important, users can register Go migrations manually via goose.Add_ functions. These + // migrations may not have a corresponding file on disk. Which is fine! We include them + // wholesale as part of migrations. This allows users to build a custom binary that only embeds + // the SQL migration files. + for version, r := range registerd { + fullpath := r.fullpath + if fullpath == "" { + if s := sources.lookup(TypeGo, version); s != nil { + fullpath = s.Fullpath + } + } + // Ensure there are no duplicate versions. + if existing, ok := migrationLookup[version]; ok { + fullpath := r.fullpath + if fullpath == "" { + fullpath = "manually registered (no source)" + } + return nil, fmt.Errorf("found duplicate migration version %d:\n\texisting:%v\n\tcurrent:%v", + version, + existing.Source.Fullpath, + fullpath, + ) + } + m := &migration{ + // Note, the fullpath may be empty if the migration was registered manually. + Source: NewSource(TypeGo, fullpath, version), + Go: r, + } + migrations = append(migrations, m) + migrationLookup[version] = m + } + // Sort migrations by version in ascending order. + sort.Slice(migrations, func(i, j int) bool { + return migrations[i].Source.Version < migrations[j].Source.Version + }) + return migrations, nil +} + +func unregisteredError(unregistered []string) error { + const ( + hintURL = "https://github.com/pressly/goose/tree/master/examples/go-migrations" + ) + f := "file" + if len(unregistered) > 1 { + f += "s" + } + var b strings.Builder + + b.WriteString(fmt.Sprintf("error: detected %d unregistered Go %s:\n", len(unregistered), f)) + for _, name := range unregistered { + b.WriteString("\t" + name + "\n") + } + hint := fmt.Sprintf("hint: go functions must be registered and built into a custom binary see:\n%s", hintURL) + b.WriteString(hint) + b.WriteString("\n") + + return errors.New(b.String()) +} + +type noopFS struct{} + +var _ fs.FS = noopFS{} + +func (f noopFS) Open(name string) (fs.File, error) { + return nil, os.ErrNotExist +} + +// NumericComponent parses the version from the migration file name. +// +// XXX_descriptivename.ext where XXX specifies the version number and ext specifies the type of +// migration, either .sql or .go. +func NumericComponent(filename string) (int64, error) { + base := filepath.Base(filename) + if ext := filepath.Ext(base); ext != ".go" && ext != ".sql" { + return 0, errors.New("migration file does not have .sql or .go file extension") + } + idx := strings.Index(base, "_") + if idx < 0 { + return 0, errors.New("no filename separator '_' found") + } + n, err := strconv.ParseInt(base[:idx], 10, 64) + if err != nil { + return 0, err + } + if n < 1 { + return 0, errors.New("migration version must be greater than zero") + } + return n, nil +} diff --git a/internal/provider/collect_test.go b/internal/provider/collect_test.go new file mode 100644 index 000000000..73b2642c5 --- /dev/null +++ b/internal/provider/collect_test.go @@ -0,0 +1,313 @@ +package provider + +import ( + "io/fs" + "testing" + "testing/fstest" + + "github.com/pressly/goose/v3/internal/check" +) + +func TestCollectFileSources(t *testing.T) { + t.Parallel() + t.Run("nil_fsys", func(t *testing.T) { + sources, err := collectFileSources(nil, false, nil) + check.NoError(t, err) + check.Bool(t, sources != nil, true) + check.Number(t, len(sources.goSources), 0) + check.Number(t, len(sources.sqlSources), 0) + }) + t.Run("empty_fsys", func(t *testing.T) { + sources, err := collectFileSources(fstest.MapFS{}, false, nil) + check.NoError(t, err) + check.Number(t, len(sources.goSources), 0) + check.Number(t, len(sources.sqlSources), 0) + check.Bool(t, sources != nil, true) + }) + t.Run("incorrect_fsys", func(t *testing.T) { + mapFS := fstest.MapFS{ + "00000_foo.sql": sqlMapFile, + } + // strict disable - should not error + sources, err := collectFileSources(mapFS, false, nil) + check.NoError(t, err) + check.Number(t, len(sources.goSources), 0) + check.Number(t, len(sources.sqlSources), 0) + // strict enabled - should error + _, err = collectFileSources(mapFS, true, nil) + check.HasError(t, err) + check.Contains(t, err.Error(), "migration version must be greater than zero") + }) + t.Run("collect", func(t *testing.T) { + fsys, err := fs.Sub(newSQLOnlyFS(), "migrations") + check.NoError(t, err) + sources, err := collectFileSources(fsys, false, nil) + check.NoError(t, err) + check.Number(t, len(sources.sqlSources), 4) + check.Number(t, len(sources.goSources), 0) + expected := fileSources{ + sqlSources: []Source{ + NewSource(TypeSQL, "00001_foo.sql", 1), + NewSource(TypeSQL, "00002_bar.sql", 2), + NewSource(TypeSQL, "00003_baz.sql", 3), + NewSource(TypeSQL, "00110_qux.sql", 110), + }, + } + for i := 0; i < len(sources.sqlSources); i++ { + check.Equal(t, sources.sqlSources[i], expected.sqlSources[i]) + } + }) + t.Run("excludes", func(t *testing.T) { + fsys, err := fs.Sub(newSQLOnlyFS(), "migrations") + check.NoError(t, err) + sources, err := collectFileSources( + fsys, + false, + // exclude 2 files explicitly + map[string]bool{ + "00002_bar.sql": true, + "00110_qux.sql": true, + }, + ) + check.NoError(t, err) + check.Number(t, len(sources.sqlSources), 2) + check.Number(t, len(sources.goSources), 0) + expected := fileSources{ + sqlSources: []Source{ + NewSource(TypeSQL, "00001_foo.sql", 1), + NewSource(TypeSQL, "00003_baz.sql", 3), + }, + } + for i := 0; i < len(sources.sqlSources); i++ { + check.Equal(t, sources.sqlSources[i], expected.sqlSources[i]) + } + }) + t.Run("strict", func(t *testing.T) { + mapFS := newSQLOnlyFS() + // Add a file with no version number + mapFS["migrations/not_valid.sql"] = &fstest.MapFile{Data: []byte("invalid")} + fsys, err := fs.Sub(mapFS, "migrations") + check.NoError(t, err) + _, err = collectFileSources(fsys, true, nil) + check.HasError(t, err) + check.Contains(t, err.Error(), `failed to parse numeric component from "not_valid.sql"`) + }) + t.Run("skip_go_test_files", func(t *testing.T) { + mapFS := fstest.MapFS{ + "1_foo.sql": sqlMapFile, + "2_bar.sql": sqlMapFile, + "3_baz.sql": sqlMapFile, + "4_qux.sql": sqlMapFile, + "5_foo_test.go": {Data: []byte(`package goose_test`)}, + } + sources, err := collectFileSources(mapFS, false, nil) + check.NoError(t, err) + check.Number(t, len(sources.sqlSources), 4) + check.Number(t, len(sources.goSources), 0) + }) + t.Run("skip_random_files", func(t *testing.T) { + mapFS := fstest.MapFS{ + "1_foo.sql": sqlMapFile, + "4_something.go": {Data: []byte(`package goose`)}, + "5_qux.sql": sqlMapFile, + "README.md": {Data: []byte(`# README`)}, + "LICENSE": {Data: []byte(`MIT`)}, + "no_a_real_migration.sql": {Data: []byte(`SELECT 1;`)}, + "some/other/dir/2_foo.sql": {Data: []byte(`SELECT 1;`)}, + } + sources, err := collectFileSources(mapFS, false, nil) + check.NoError(t, err) + check.Number(t, len(sources.sqlSources), 2) + check.Number(t, len(sources.goSources), 1) + // 1 + check.Equal(t, sources.sqlSources[0].Fullpath, "1_foo.sql") + check.Equal(t, sources.sqlSources[0].Version, int64(1)) + // 2 + check.Equal(t, sources.sqlSources[1].Fullpath, "5_qux.sql") + check.Equal(t, sources.sqlSources[1].Version, int64(5)) + // 3 + check.Equal(t, sources.goSources[0].Fullpath, "4_something.go") + check.Equal(t, sources.goSources[0].Version, int64(4)) + }) + t.Run("duplicate_versions", func(t *testing.T) { + mapFS := fstest.MapFS{ + "001_foo.sql": sqlMapFile, + "01_bar.sql": sqlMapFile, + } + _, err := collectFileSources(mapFS, false, nil) + check.HasError(t, err) + check.Contains(t, err.Error(), "found duplicate migration version 1") + }) + t.Run("dirpath", func(t *testing.T) { + mapFS := fstest.MapFS{ + "dir1/101_a.sql": sqlMapFile, + "dir1/102_b.sql": sqlMapFile, + "dir1/103_c.sql": sqlMapFile, + "dir2/201_a.sql": sqlMapFile, + "876_a.sql": sqlMapFile, + } + assertDirpath := func(dirpath string, sqlSources []Source) { + t.Helper() + f, err := fs.Sub(mapFS, dirpath) + check.NoError(t, err) + got, err := collectFileSources(f, false, nil) + check.NoError(t, err) + check.Number(t, len(got.sqlSources), len(sqlSources)) + check.Number(t, len(got.goSources), 0) + for i := 0; i < len(got.sqlSources); i++ { + check.Equal(t, got.sqlSources[i], sqlSources[i]) + } + } + assertDirpath(".", []Source{ + NewSource(TypeSQL, "876_a.sql", 876), + }) + assertDirpath("dir1", []Source{ + NewSource(TypeSQL, "101_a.sql", 101), + NewSource(TypeSQL, "102_b.sql", 102), + NewSource(TypeSQL, "103_c.sql", 103), + }) + assertDirpath("dir2", []Source{ + NewSource(TypeSQL, "201_a.sql", 201), + }) + assertDirpath("dir3", nil) + }) +} + +func TestMerge(t *testing.T) { + t.Parallel() + + t.Run("with_go_files_on_disk", func(t *testing.T) { + mapFS := fstest.MapFS{ + // SQL + "migrations/00001_foo.sql": sqlMapFile, + // Go + "migrations/00002_bar.go": {Data: []byte(`package migrations`)}, + "migrations/00003_baz.go": {Data: []byte(`package migrations`)}, + } + fsys, err := fs.Sub(mapFS, "migrations") + check.NoError(t, err) + sources, err := collectFileSources(fsys, false, nil) + check.NoError(t, err) + check.Equal(t, len(sources.sqlSources), 1) + check.Equal(t, len(sources.goSources), 2) + src1 := sources.lookup(TypeSQL, 1) + check.Bool(t, src1 != nil, true) + src2 := sources.lookup(TypeGo, 2) + check.Bool(t, src2 != nil, true) + src3 := sources.lookup(TypeGo, 3) + check.Bool(t, src3 != nil, true) + + t.Run("valid", func(t *testing.T) { + migrations, err := merge(sources, map[int64]*goMigration{ + 2: newGoMigration("", nil, nil), + 3: newGoMigration("", nil, nil), + }) + check.NoError(t, err) + check.Number(t, len(migrations), 3) + assertMigration(t, migrations[0], NewSource(TypeSQL, "00001_foo.sql", 1)) + assertMigration(t, migrations[1], NewSource(TypeGo, "00002_bar.go", 2)) + assertMigration(t, migrations[2], NewSource(TypeGo, "00003_baz.go", 3)) + }) + t.Run("unregistered_all", func(t *testing.T) { + _, err := merge(sources, nil) + check.HasError(t, err) + check.Contains(t, err.Error(), "error: detected 2 unregistered Go files:") + check.Contains(t, err.Error(), "00002_bar.go") + check.Contains(t, err.Error(), "00003_baz.go") + }) + t.Run("unregistered_some", func(t *testing.T) { + _, err := merge(sources, map[int64]*goMigration{ + 2: newGoMigration("", nil, nil), + }) + check.HasError(t, err) + check.Contains(t, err.Error(), "error: detected 1 unregistered Go file") + check.Contains(t, err.Error(), "00003_baz.go") + }) + t.Run("duplicate_sql", func(t *testing.T) { + _, err := merge(sources, map[int64]*goMigration{ + 1: newGoMigration("", nil, nil), // duplicate. SQL already exists. + 2: newGoMigration("", nil, nil), + 3: newGoMigration("", nil, nil), + }) + check.HasError(t, err) + check.Contains(t, err.Error(), "found duplicate migration version 1") + }) + }) + t.Run("no_go_files_on_disk", func(t *testing.T) { + mapFS := fstest.MapFS{ + // SQL + "migrations/00001_foo.sql": sqlMapFile, + "migrations/00002_bar.sql": sqlMapFile, + "migrations/00005_baz.sql": sqlMapFile, + } + fsys, err := fs.Sub(mapFS, "migrations") + check.NoError(t, err) + sources, err := collectFileSources(fsys, false, nil) + check.NoError(t, err) + t.Run("unregistered_all", func(t *testing.T) { + migrations, err := merge(sources, map[int64]*goMigration{ + 3: newGoMigration("", nil, nil), + // 4 is missing + 6: newGoMigration("", nil, nil), + }) + check.NoError(t, err) + check.Number(t, len(migrations), 5) + assertMigration(t, migrations[0], NewSource(TypeSQL, "00001_foo.sql", 1)) + assertMigration(t, migrations[1], NewSource(TypeSQL, "00002_bar.sql", 2)) + assertMigration(t, migrations[2], NewSource(TypeGo, "", 3)) + assertMigration(t, migrations[3], NewSource(TypeSQL, "00005_baz.sql", 5)) + assertMigration(t, migrations[4], NewSource(TypeGo, "", 6)) + }) + }) + t.Run("partial_go_files_on_disk", func(t *testing.T) { + mapFS := fstest.MapFS{ + "migrations/00001_foo.sql": sqlMapFile, + "migrations/00002_bar.go": &fstest.MapFile{Data: []byte(`package migrations`)}, + } + fsys, err := fs.Sub(mapFS, "migrations") + check.NoError(t, err) + sources, err := collectFileSources(fsys, false, nil) + check.NoError(t, err) + t.Run("unregistered_all", func(t *testing.T) { + migrations, err := merge(sources, map[int64]*goMigration{ + // This is the only Go file on disk. + 2: newGoMigration("", nil, nil), + // These are not on disk. Explicitly registered. + 3: newGoMigration("", nil, nil), + 6: newGoMigration("", nil, nil), + }) + check.NoError(t, err) + check.Number(t, len(migrations), 4) + assertMigration(t, migrations[0], NewSource(TypeSQL, "00001_foo.sql", 1)) + assertMigration(t, migrations[1], NewSource(TypeGo, "00002_bar.go", 2)) + assertMigration(t, migrations[2], NewSource(TypeGo, "", 3)) + assertMigration(t, migrations[3], NewSource(TypeGo, "", 6)) + }) + }) +} + +func assertMigration(t *testing.T, got *migration, want Source) { + t.Helper() + check.Equal(t, got.Source, want) + switch got.Source.Type { + case TypeGo: + check.Bool(t, got.Go != nil, true) + case TypeSQL: + check.Bool(t, got.SQL == nil, true) + default: + t.Fatalf("unknown migration type: %s", got.Source.Type) + } +} + +func newSQLOnlyFS() fstest.MapFS { + return fstest.MapFS{ + "migrations/00001_foo.sql": sqlMapFile, + "migrations/00002_bar.sql": sqlMapFile, + "migrations/00003_baz.sql": sqlMapFile, + "migrations/00110_qux.sql": sqlMapFile, + } +} + +var ( + sqlMapFile = &fstest.MapFile{Data: []byte(`-- +goose Up`)} +) diff --git a/internal/provider/errors.go b/internal/provider/errors.go new file mode 100644 index 000000000..e8ece3871 --- /dev/null +++ b/internal/provider/errors.go @@ -0,0 +1,39 @@ +package provider + +import ( + "errors" + "fmt" + "path/filepath" +) + +var ( + // ErrVersionNotFound when a migration version is not found. + ErrVersionNotFound = errors.New("version not found") + + // ErrAlreadyApplied when a migration has already been applied. + ErrAlreadyApplied = errors.New("already applied") + + // ErrNoMigrations is returned by [NewProvider] when no migrations are found. + ErrNoMigrations = errors.New("no migrations found") + + // ErrNoNextVersion when the next migration version is not found. + ErrNoNextVersion = errors.New("no next version found") +) + +// PartialError is returned when a migration fails, but some migrations already got applied. +type PartialError struct { + // Applied are migrations that were applied successfully before the error occurred. + Applied []*MigrationResult + // Failed contains the result of the migration that failed. + Failed *MigrationResult + // Err is the error that occurred while running the migration. + Err error +} + +func (e *PartialError) Error() string { + filename := "(file unknown)" + if e.Failed != nil && e.Failed.Source.Fullpath != "" { + filename = fmt.Sprintf("(%s)", filepath.Base(e.Failed.Source.Fullpath)) + } + return fmt.Sprintf("partial migration error %s (%d): %v", filename, e.Failed.Source.Version, e.Err) +} diff --git a/internal/provider/migration.go b/internal/provider/migration.go new file mode 100644 index 000000000..87098cf22 --- /dev/null +++ b/internal/provider/migration.go @@ -0,0 +1,166 @@ +package provider + +import ( + "context" + "database/sql" + "fmt" + "path/filepath" + + "github.com/pressly/goose/v3/internal/sqlextended" +) + +type migration struct { + Source Source + // A migration is either a Go migration or a SQL migration, but never both. + // + // Note, the SQLParsed field is used to determine if the SQL migration has been parsed. This is + // an optimization to avoid parsing the SQL migration if it is never required. Also, the + // majority of the time migrations are incremental, so it is likely that the user will only want + // to run the last few migrations, and there is no need to parse ALL prior migrations. + // + // Exactly one of these fields will be set: + Go *goMigration + // -- OR -- + SQL *sqlMigration +} + +func (m *migration) useTx(direction bool) bool { + switch m.Source.Type { + case TypeSQL: + return m.SQL.UseTx + case TypeGo: + if m.Go == nil { + return false + } + if direction { + return m.Go.up.Run != nil + } + return m.Go.down.Run != nil + } + // This should never happen. + return false +} + +func (m *migration) filename() string { + return filepath.Base(m.Source.Fullpath) +} + +// run runs the migration inside of a transaction. +func (m *migration) run(ctx context.Context, tx *sql.Tx, direction bool) error { + switch m.Source.Type { + case TypeSQL: + if m.SQL == nil { + return fmt.Errorf("tx: sql migration has not been parsed") + } + return m.SQL.run(ctx, tx, direction) + case TypeGo: + return m.Go.run(ctx, tx, direction) + } + // This should never happen. + return fmt.Errorf("tx: failed to run migration %s: neither sql or go", filepath.Base(m.Source.Fullpath)) +} + +// runNoTx runs the migration without a transaction. +func (m *migration) runNoTx(ctx context.Context, db *sql.DB, direction bool) error { + switch m.Source.Type { + case TypeSQL: + if m.SQL == nil { + return fmt.Errorf("db: sql migration has not been parsed") + } + return m.SQL.run(ctx, db, direction) + case TypeGo: + return m.Go.runNoTx(ctx, db, direction) + } + // This should never happen. + return fmt.Errorf("db: failed to run migration %s: neither sql or go", filepath.Base(m.Source.Fullpath)) +} + +// runConn runs the migration without a transaction using the provided connection. +func (m *migration) runConn(ctx context.Context, conn *sql.Conn, direction bool) error { + switch m.Source.Type { + case TypeSQL: + if m.SQL == nil { + return fmt.Errorf("conn: sql migration has not been parsed") + } + return m.SQL.run(ctx, conn, direction) + case TypeGo: + return fmt.Errorf("conn: go migrations are not supported with *sql.Conn") + } + // This should never happen. + return fmt.Errorf("conn: failed to run migration %s: neither sql or go", filepath.Base(m.Source.Fullpath)) +} + +type goMigration struct { + fullpath string + up, down *GoMigration +} + +func newGoMigration(fullpath string, up, down *GoMigration) *goMigration { + return &goMigration{ + fullpath: fullpath, + up: up, + down: down, + } +} + +func (g *goMigration) run(ctx context.Context, tx *sql.Tx, direction bool) error { + if g == nil { + return nil + } + var fn func(context.Context, *sql.Tx) error + if direction && g.up != nil { + fn = g.up.Run + } + if !direction && g.down != nil { + fn = g.down.Run + } + if fn != nil { + return fn(ctx, tx) + } + return nil +} + +func (g *goMigration) runNoTx(ctx context.Context, db *sql.DB, direction bool) error { + if g == nil { + return nil + } + var fn func(context.Context, *sql.DB) error + if direction && g.up != nil { + fn = g.up.RunNoTx + } + if !direction && g.down != nil { + fn = g.down.RunNoTx + } + if fn != nil { + return fn(ctx, db) + } + return nil +} + +type sqlMigration struct { + UseTx bool + UpStatements []string + DownStatements []string +} + +func (s *sqlMigration) IsEmpty(direction bool) bool { + if direction { + return len(s.UpStatements) == 0 + } + return len(s.DownStatements) == 0 +} + +func (s *sqlMigration) run(ctx context.Context, db sqlextended.DBTxConn, direction bool) error { + var statements []string + if direction { + statements = s.UpStatements + } else { + statements = s.DownStatements + } + for _, stmt := range statements { + if _, err := db.ExecContext(ctx, stmt); err != nil { + return err + } + } + return nil +} diff --git a/internal/provider/misc.go b/internal/provider/misc.go new file mode 100644 index 000000000..be84b4622 --- /dev/null +++ b/internal/provider/misc.go @@ -0,0 +1,39 @@ +package provider + +import ( + "context" + "database/sql" + "errors" + "fmt" +) + +type Migration struct { + Version int64 + Source string // path to .sql script or go file + Registered bool + UseTx bool + UpFnContext func(context.Context, *sql.Tx) error + DownFnContext func(context.Context, *sql.Tx) error + + UpFnNoTxContext func(context.Context, *sql.DB) error + DownFnNoTxContext func(context.Context, *sql.DB) error +} + +var registeredGoMigrations = make(map[int64]*Migration) + +func SetGlobalGoMigrations(migrations []*Migration) error { + for _, m := range migrations { + if m == nil { + return errors.New("cannot register nil go migration") + } + if _, ok := registeredGoMigrations[m.Version]; ok { + return fmt.Errorf("go migration with version %d already registered", m.Version) + } + registeredGoMigrations[m.Version] = m + } + return nil +} + +func ResetGlobalGoMigrations() { + registeredGoMigrations = make(map[int64]*Migration) +} diff --git a/internal/provider/provider.go b/internal/provider/provider.go new file mode 100644 index 000000000..3982ac37b --- /dev/null +++ b/internal/provider/provider.go @@ -0,0 +1,238 @@ +package provider + +import ( + "context" + "database/sql" + "errors" + "fmt" + "io/fs" + "math" + "sync" + + "github.com/pressly/goose/v3/internal/sqladapter" +) + +// NewProvider returns a new goose Provider. +// +// The caller is responsible for matching the database dialect with the database/sql driver. For +// example, if the database dialect is "postgres", the database/sql driver could be +// github.com/lib/pq or github.com/jackc/pgx. +// +// fsys is the filesystem used to read the migration files, but may be nil. Most users will want to +// use os.DirFS("path/to/migrations") to read migrations from the local filesystem. However, it is +// possible to use a different filesystem, such as embed.FS or filter out migrations using fs.Sub. +// +// See [ProviderOption] for more information on configuring the provider. +// +// Unless otherwise specified, all methods on Provider are safe for concurrent use. +// +// Experimental: This API is experimental and may change in the future. +func NewProvider(dialect Dialect, db *sql.DB, fsys fs.FS, opts ...ProviderOption) (*Provider, error) { + if db == nil { + return nil, errors.New("db must not be nil") + } + if dialect == "" { + return nil, errors.New("dialect must not be empty") + } + if fsys == nil { + fsys = noopFS{} + } + cfg := config{ + registered: make(map[int64]*goMigration), + } + for _, opt := range opts { + if err := opt.apply(&cfg); err != nil { + return nil, err + } + } + // Set defaults after applying user-supplied options so option funcs can check for empty values. + if cfg.tableName == "" { + cfg.tableName = DefaultTablename + } + store, err := sqladapter.NewStore(string(dialect), cfg.tableName) + if err != nil { + return nil, err + } + // Collect migrations from the filesystem and merge with registered migrations. + // + // Note, neither of these functions parse SQL migrations by default. SQL migrations are parsed + // lazily. + // + // TODO(mf): we should expose a way to parse SQL migrations eagerly. This would allow us to + // return an error if there are any SQL parsing errors. This adds a bit overhead to startup + // though, so we should make it optional. + sources, err := collectFileSources(fsys, false, cfg.excludes) + if err != nil { + return nil, err + } + // + // TODO(mf): move the merging of Go migrations into a separate function. + // + registered := make(map[int64]*goMigration) + // Add user-registered Go migrations. + for version, m := range cfg.registered { + registered[version] = newGoMigration("", m.up, m.down) + } + // Add init() functions. This is a bit ugly because we need to convert from the old Migration + // struct to the new goMigration struct. + for version, m := range registeredGoMigrations { + if _, ok := registered[version]; ok { + return nil, fmt.Errorf("go migration with version %d already registered", version) + } + if m == nil { + return nil, errors.New("registered migration with nil init function") + } + g := newGoMigration(m.Source, nil, nil) + if m.UpFnContext != nil && m.UpFnNoTxContext != nil { + return nil, errors.New("registered migration with both UpFnContext and UpFnNoTxContext") + } + if m.DownFnContext != nil && m.DownFnNoTxContext != nil { + return nil, errors.New("registered migration with both DownFnContext and DownFnNoTxContext") + } + // Up + if m.UpFnContext != nil { + g.up = &GoMigration{ + Run: m.UpFnContext, + } + } else if m.UpFnNoTxContext != nil { + g.up = &GoMigration{ + RunNoTx: m.UpFnNoTxContext, + } + } + // Down + if m.DownFnContext != nil { + g.down = &GoMigration{ + Run: m.DownFnContext, + } + } else if m.DownFnNoTxContext != nil { + g.down = &GoMigration{ + RunNoTx: m.DownFnNoTxContext, + } + } + registered[version] = g + } + migrations, err := merge(sources, registered) + if err != nil { + return nil, err + } + if len(migrations) == 0 { + return nil, ErrNoMigrations + } + return &Provider{ + db: db, + fsys: fsys, + cfg: cfg, + store: store, + migrations: migrations, + }, nil +} + +// Provider is a goose migration provider. +type Provider struct { + // mu protects all accesses to the provider and must be held when calling operations on the + // database. + mu sync.Mutex + + db *sql.DB + fsys fs.FS + cfg config + store sqladapter.Store + migrations []*migration +} + +// Status returns the status of all migrations, merging the list of migrations from the database and +// filesystem. The returned items are ordered by version, in ascending order. +func (p *Provider) Status(ctx context.Context) ([]*MigrationStatus, error) { + return p.status(ctx) +} + +// GetDBVersion returns the max version from the database, regardless of the applied order. For +// example, if migrations 1,4,2,3 were applied, this method returns 4. If no migrations have been +// applied, it returns 0. +// +// TODO(mf): this is not true? +func (p *Provider) GetDBVersion(ctx context.Context) (int64, error) { + return p.getDBVersion(ctx) +} + +// ListSources returns a list of all available migration sources the provider is aware of, sorted in +// ascending order by version. +func (p *Provider) ListSources() []Source { + sources := make([]Source, 0, len(p.migrations)) + for _, m := range p.migrations { + sources = append(sources, m.Source) + } + return sources +} + +// Ping attempts to ping the database to verify a connection is available. +func (p *Provider) Ping(ctx context.Context) error { + return p.db.PingContext(ctx) +} + +// Close closes the database connection. +func (p *Provider) Close() error { + return p.db.Close() +} + +// ApplyVersion applies exactly one migration at the specified version. If there is no source for +// the specified version, this method returns [ErrNoCurrentVersion]. If the migration has been +// applied already, this method returns [ErrAlreadyApplied]. +// +// When direction is true, the up migration is executed, and when direction is false, the down +// migration is executed. +func (p *Provider) ApplyVersion(ctx context.Context, version int64, direction bool) (*MigrationResult, error) { + return p.apply(ctx, version, direction) +} + +// Up applies all pending migrations. If there are no new migrations to apply, this method returns +// empty list and nil error. +func (p *Provider) Up(ctx context.Context) ([]*MigrationResult, error) { + return p.up(ctx, false, math.MaxInt64) +} + +// UpByOne applies the next available migration. If there are no migrations to apply, this method +// returns [ErrNoNextVersion]. The returned list will always have exactly one migration result. +func (p *Provider) UpByOne(ctx context.Context) ([]*MigrationResult, error) { + res, err := p.up(ctx, true, math.MaxInt64) + if err != nil { + return nil, err + } + if len(res) == 0 { + return nil, ErrNoNextVersion + } + return res, nil +} + +// UpTo applies all available migrations up to and including the specified version. If there are no +// migrations to apply, this method returns empty list and nil error. +// +// For instance, if there are three new migrations (9,10,11) and the current database version is 8 +// with a requested version of 10, only versions 9 and 10 will be applied. +func (p *Provider) UpTo(ctx context.Context, version int64) ([]*MigrationResult, error) { + return p.up(ctx, false, version) +} + +// Down rolls back the most recently applied migration. If there are no migrations to apply, this +// method returns [ErrNoNextVersion]. +func (p *Provider) Down(ctx context.Context) ([]*MigrationResult, error) { + res, err := p.down(ctx, true, 0) + if err != nil { + return nil, err + } + if len(res) == 0 { + return nil, ErrNoNextVersion + } + return res, nil +} + +// DownTo rolls back all migrations down to but not including the specified version. +// +// For instance, if the current database version is 11, and the requested version is 9, only +// migrations 11 and 10 will be rolled back. +func (p *Provider) DownTo(ctx context.Context, version int64) ([]*MigrationResult, error) { + if version < 0 { + return nil, fmt.Errorf("version must be a number greater than or equal zero: %d", version) + } + return p.down(ctx, false, version) +} diff --git a/internal/provider/provider_options.go b/internal/provider/provider_options.go new file mode 100644 index 000000000..0b7cd7ad6 --- /dev/null +++ b/internal/provider/provider_options.go @@ -0,0 +1,167 @@ +package provider + +import ( + "context" + "database/sql" + "errors" + "fmt" + + "github.com/pressly/goose/v3/lock" +) + +const ( + DefaultTablename = "goose_db_version" +) + +// ProviderOption is a configuration option for a goose provider. +type ProviderOption interface { + apply(*config) error +} + +// WithTableName sets the name of the database table used to track history of applied migrations. +// +// If WithTableName is not called, the default value is "goose_db_version". +func WithTableName(name string) ProviderOption { + return configFunc(func(c *config) error { + if c.tableName != "" { + return fmt.Errorf("table already set to %q", c.tableName) + } + if name == "" { + return errors.New("table must not be empty") + } + c.tableName = name + return nil + }) +} + +// WithVerbose enables verbose logging. +func WithVerbose(b bool) ProviderOption { + return configFunc(func(c *config) error { + c.verbose = b + return nil + }) +} + +// WithSessionLocker enables locking using the provided SessionLocker. +// +// If WithSessionLocker is not called, locking is disabled. +func WithSessionLocker(locker lock.SessionLocker) ProviderOption { + return configFunc(func(c *config) error { + if c.lockEnabled { + return errors.New("lock already enabled") + } + if c.sessionLocker != nil { + return errors.New("session locker already set") + } + if locker == nil { + return errors.New("session locker must not be nil") + } + c.lockEnabled = true + c.sessionLocker = locker + return nil + }) +} + +// WithExcludes excludes the given file names from the list of migrations. +// +// If WithExcludes is called multiple times, the list of excludes is merged. +func WithExcludes(excludes []string) ProviderOption { + return configFunc(func(c *config) error { + for _, name := range excludes { + c.excludes[name] = true + } + return nil + }) +} + +// GoMigration is a user-defined Go migration, registered using the option [WithGoMigration]. +type GoMigration struct { + // One of the following must be set: + Run func(context.Context, *sql.Tx) error + // -- OR -- + RunNoTx func(context.Context, *sql.DB) error +} + +// WithGoMigration registers a Go migration with the given version. +// +// If WithGoMigration is called multiple times with the same version, an error is returned. Both up +// and down functions may be nil. But if set, exactly one of Run or RunNoTx functions must be set. +func WithGoMigration(version int64, up, down *GoMigration) ProviderOption { + return configFunc(func(c *config) error { + if version < 1 { + return errors.New("version must be greater than zero") + } + if _, ok := c.registered[version]; ok { + return fmt.Errorf("go migration with version %d already registered", version) + } + // Allow nil up/down functions. This enables users to apply "no-op" migrations, while + // versioning them. + if up != nil { + if up.Run == nil && up.RunNoTx == nil { + return fmt.Errorf("go migration with version %d must have an up function", version) + } + if up.Run != nil && up.RunNoTx != nil { + return fmt.Errorf("go migration with version %d must not have both an up and upNoTx function", version) + } + } + if down != nil { + if down.Run == nil && down.RunNoTx == nil { + return fmt.Errorf("go migration with version %d must have a down function", version) + } + if down.Run != nil && down.RunNoTx != nil { + return fmt.Errorf("go migration with version %d must not have both a down and downNoTx function", version) + } + } + c.registered[version] = &goMigration{ + up: up, + down: down, + } + return nil + }) +} + +// WithAllowMissing allows the provider to apply missing (out-of-order) migrations. +// +// Example: migrations 1,6 are applied and then version 2,3,5 are introduced. If this option is +// true, then goose will apply 2,3,5 instead of raising an error. The final order of applied +// migrations will be: 1,6,2,3,5. +func WithAllowMissing(b bool) ProviderOption { + return configFunc(func(c *config) error { + c.allowMissing = b + return nil + }) +} + +// WithNoVersioning disables versioning. Disabling versioning allows the ability to apply migrations +// without tracking the versions in the database schema table. Useful for tests, seeding a database +// or running ad-hoc queries. +func WithNoVersioning(b bool) ProviderOption { + return configFunc(func(c *config) error { + c.noVersioning = b + return nil + }) +} + +type config struct { + tableName string + verbose bool + excludes map[string]bool + + // Go migrations registered by the user. These will be merged/resolved with migrations from the + // filesystem and init() functions. + registered map[int64]*goMigration + + // Locking options + lockEnabled bool + sessionLocker lock.SessionLocker + + // Feature + noVersioning bool + allowMissing bool +} + +type configFunc func(*config) error + +func (f configFunc) apply(cfg *config) error { + return f(cfg) +} diff --git a/internal/provider/provider_options_test.go b/internal/provider/provider_options_test.go new file mode 100644 index 000000000..2271111ba --- /dev/null +++ b/internal/provider/provider_options_test.go @@ -0,0 +1,66 @@ +package provider_test + +import ( + "database/sql" + "path/filepath" + "testing" + "testing/fstest" + + "github.com/pressly/goose/v3/internal/check" + "github.com/pressly/goose/v3/internal/provider" + _ "modernc.org/sqlite" +) + +func TestNewProvider(t *testing.T) { + dir := t.TempDir() + db, err := sql.Open("sqlite", filepath.Join(dir, "sql_embed.db")) + check.NoError(t, err) + fsys := fstest.MapFS{ + "1_foo.sql": {Data: []byte(migration1)}, + "2_bar.sql": {Data: []byte(migration2)}, + "3_baz.sql": {Data: []byte(migration3)}, + "4_qux.sql": {Data: []byte(migration4)}, + } + t.Run("invalid", func(t *testing.T) { + // Empty dialect not allowed + _, err = provider.NewProvider("", db, fsys) + check.HasError(t, err) + // Invalid dialect not allowed + _, err = provider.NewProvider("unknown-dialect", db, fsys) + check.HasError(t, err) + // Nil db not allowed + _, err = provider.NewProvider("sqlite3", nil, fsys) + check.HasError(t, err) + // Nil fsys not allowed + _, err = provider.NewProvider("sqlite3", db, nil) + check.HasError(t, err) + // Duplicate table name not allowed + _, err = provider.NewProvider("sqlite3", db, fsys, + provider.WithTableName("foo"), + provider.WithTableName("bar"), + ) + check.HasError(t, err) + check.Equal(t, `table already set to "foo"`, err.Error()) + // Empty table name not allowed + _, err = provider.NewProvider("sqlite3", db, fsys, + provider.WithTableName(""), + ) + check.HasError(t, err) + check.Equal(t, "table must not be empty", err.Error()) + }) + t.Run("valid", func(t *testing.T) { + // Valid dialect, db, and fsys allowed + _, err = provider.NewProvider("sqlite3", db, fsys) + check.NoError(t, err) + // Valid dialect, db, fsys, and table name allowed + _, err = provider.NewProvider("sqlite3", db, fsys, + provider.WithTableName("foo"), + ) + check.NoError(t, err) + // Valid dialect, db, fsys, and verbose allowed + _, err = provider.NewProvider("sqlite3", db, fsys, + provider.WithVerbose(testing.Verbose()), + ) + check.NoError(t, err) + }) +} diff --git a/internal/provider/provider_test.go b/internal/provider/provider_test.go new file mode 100644 index 000000000..ac4ec7e0e --- /dev/null +++ b/internal/provider/provider_test.go @@ -0,0 +1,138 @@ +package provider_test + +import ( + "context" + "database/sql" + "errors" + "io/fs" + "path/filepath" + "testing" + "testing/fstest" + + "github.com/pressly/goose/v3/internal/check" + "github.com/pressly/goose/v3/internal/provider" + _ "modernc.org/sqlite" +) + +func TestProvider(t *testing.T) { + dir := t.TempDir() + db, err := sql.Open("sqlite", filepath.Join(dir, "sql_embed.db")) + check.NoError(t, err) + t.Run("empty", func(t *testing.T) { + _, err := provider.NewProvider("sqlite3", db, fstest.MapFS{}) + check.HasError(t, err) + check.Bool(t, errors.Is(err, provider.ErrNoMigrations), true) + }) + + mapFS := fstest.MapFS{ + "migrations/001_foo.sql": {Data: []byte(`-- +goose Up`)}, + "migrations/002_bar.sql": {Data: []byte(`-- +goose Up`)}, + } + fsys, err := fs.Sub(mapFS, "migrations") + check.NoError(t, err) + p, err := provider.NewProvider("sqlite3", db, fsys) + check.NoError(t, err) + sources := p.ListSources() + check.Equal(t, len(sources), 2) + check.Equal(t, sources[0], provider.NewSource(provider.TypeSQL, "001_foo.sql", 1)) + check.Equal(t, sources[1], provider.NewSource(provider.TypeSQL, "002_bar.sql", 2)) + + t.Run("duplicate_go", func(t *testing.T) { + // Not parallel because it modifies global state. + register := []*provider.Migration{ + { + Version: 1, Source: "00001_users_table.go", Registered: true, UseTx: true, + UpFnContext: nil, + DownFnContext: nil, + }, + } + err := provider.SetGlobalGoMigrations(register) + check.NoError(t, err) + t.Cleanup(provider.ResetGlobalGoMigrations) + + db := newDB(t) + _, err = provider.NewProvider(provider.DialectSQLite3, db, nil, + provider.WithGoMigration(1, nil, nil), + ) + check.HasError(t, err) + check.Equal(t, err.Error(), "go migration with version 1 already registered") + }) + t.Run("empty_go", func(t *testing.T) { + db := newDB(t) + // explicit + _, err := provider.NewProvider(provider.DialectSQLite3, db, nil, + provider.WithGoMigration(1, &provider.GoMigration{Run: nil}, &provider.GoMigration{Run: nil}), + ) + check.HasError(t, err) + check.Contains(t, err.Error(), "go migration with version 1 must have an up function") + }) + t.Run("duplicate_up", func(t *testing.T) { + err := provider.SetGlobalGoMigrations([]*provider.Migration{ + { + Version: 1, Source: "00001_users_table.go", Registered: true, UseTx: true, + UpFnContext: func(context.Context, *sql.Tx) error { return nil }, + UpFnNoTxContext: func(ctx context.Context, db *sql.DB) error { return nil }, + }, + }) + check.NoError(t, err) + t.Cleanup(provider.ResetGlobalGoMigrations) + db := newDB(t) + _, err = provider.NewProvider(provider.DialectSQLite3, db, nil) + check.HasError(t, err) + check.Contains(t, err.Error(), "registered migration with both UpFnContext and UpFnNoTxContext") + }) + t.Run("duplicate_down", func(t *testing.T) { + err := provider.SetGlobalGoMigrations([]*provider.Migration{ + { + Version: 1, Source: "00001_users_table.go", Registered: true, UseTx: true, + DownFnContext: func(context.Context, *sql.Tx) error { return nil }, + DownFnNoTxContext: func(ctx context.Context, db *sql.DB) error { return nil }, + }, + }) + check.NoError(t, err) + t.Cleanup(provider.ResetGlobalGoMigrations) + db := newDB(t) + _, err = provider.NewProvider(provider.DialectSQLite3, db, nil) + check.HasError(t, err) + check.Contains(t, err.Error(), "registered migration with both DownFnContext and DownFnNoTxContext") + }) +} + +var ( + migration1 = ` +-- +goose Up +CREATE TABLE foo (id INTEGER PRIMARY KEY); +-- +goose Down +DROP TABLE foo; +` + migration2 = ` +-- +goose Up +ALTER TABLE foo ADD COLUMN name TEXT; +-- +goose Down +ALTER TABLE foo DROP COLUMN name; +` + migration3 = ` +-- +goose Up +CREATE TABLE bar ( + id INTEGER PRIMARY KEY, + description TEXT +); +-- +goose Down +DROP TABLE bar; +` + migration4 = ` +-- +goose Up +-- Rename the 'foo' table to 'my_foo' +ALTER TABLE foo RENAME TO my_foo; + +-- Add a new column 'timestamp' to 'my_foo' +ALTER TABLE my_foo ADD COLUMN timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP; + +-- +goose Down +-- Remove the 'timestamp' column from 'my_foo' +ALTER TABLE my_foo DROP COLUMN timestamp; + +-- Rename the 'my_foo' table back to 'foo' +ALTER TABLE my_foo RENAME TO foo; +` +) diff --git a/internal/provider/run.go b/internal/provider/run.go new file mode 100644 index 000000000..55bef9f32 --- /dev/null +++ b/internal/provider/run.go @@ -0,0 +1,375 @@ +package provider + +import ( + "context" + "database/sql" + "errors" + "fmt" + "io/fs" + "sort" + "strings" + "time" + + "github.com/pressly/goose/v3/internal/sqladapter" + "github.com/pressly/goose/v3/internal/sqlparser" + "go.uber.org/multierr" +) + +// runMigrations runs migrations sequentially in the given direction. +// +// If the migrations slice is empty, this function returns nil with no error. +func (p *Provider) runMigrations( + ctx context.Context, + conn *sql.Conn, + migrations []*migration, + direction sqlparser.Direction, + byOne bool, +) ([]*MigrationResult, error) { + if len(migrations) == 0 { + return nil, nil + } + var apply []*migration + if byOne { + apply = []*migration{migrations[0]} + } else { + apply = migrations + } + // Lazily parse SQL migrations (if any) in both directions. We do this before running any + // migrations so that we can fail fast if there are any errors and avoid leaving the database in + // a partially migrated state. + + if err := parseSQL(p.fsys, false, apply); err != nil { + return nil, err + } + + // TODO(mf): If we decide to add support for advisory locks at the transaction level, this may + // be a good place to acquire the lock. However, we need to be sure that ALL migrations are safe + // to run in a transaction. + + // + // + // + + // bug(mf): this is a potential deadlock scenario. We're running Go migrations with *sql.DB, but + // are locking the database with *sql.Conn. If the caller sets max open connections to 1, then + // this will deadlock because the Go migration will try to acquire a connection from the pool, + // but the pool is locked. + // + // A potential solution is to expose a third Go register function *sql.Conn. Or continue to use + // *sql.DB and document that the user SHOULD NOT SET max open connections to 1. This is a bit of + // an edge case. if p.opt.LockMode != LockModeNone && p.db.Stats().MaxOpenConnections == 1 { + // for _, m := range apply { + // if m.IsGo() && !m.Go.UseTx { + // return nil, errors.New("potential deadlock detected: cannot run GoMigrationNoTx with max open connections set to 1") + // } + // } + // } + + // Run migrations individually, opening a new transaction for each migration if the migration is + // safe to run in a transaction. + + // Avoid allocating a slice because we may have a partial migration error. 1. Avoid giving the + // impression that N migrations were applied when in fact some were not 2. Avoid the caller + // having to check for nil results + var results []*MigrationResult + for _, m := range apply { + current := &MigrationResult{ + Source: m.Source, + Direction: strings.ToLower(direction.String()), + // TODO(mf): empty set here + } + + start := time.Now() + if err := p.runIndividually(ctx, conn, direction.ToBool(), m); err != nil { + // TODO(mf): we should also return the pending migrations here. + current.Error = err + current.Duration = time.Since(start) + return nil, &PartialError{ + Applied: results, + Failed: current, + Err: err, + } + } + + current.Duration = time.Since(start) + results = append(results, current) + } + return results, nil +} + +// runIndividually runs an individual migration, opening a new transaction if the migration is safe +// to run in a transaction. Otherwise, it runs the migration outside of a transaction with the +// supplied connection. +func (p *Provider) runIndividually( + ctx context.Context, + conn *sql.Conn, + direction bool, + m *migration, +) error { + if m.useTx(direction) { + // Run the migration in a transaction. + return p.beginTx(ctx, conn, func(tx *sql.Tx) error { + if err := m.run(ctx, tx, direction); err != nil { + return err + } + if p.cfg.noVersioning { + return nil + } + return p.store.InsertOrDelete(ctx, tx, direction, m.Source.Version) + }) + } + // Run the migration outside of a transaction. + switch m.Source.Type { + case TypeGo: + // Note, we're using *sql.DB instead of *sql.Conn because it's the contract of the + // GoMigrationNoTx function. This may be a deadlock scenario if the caller sets max open + // connections to 1. See the comment in runMigrations for more details. + if err := m.runNoTx(ctx, p.db, direction); err != nil { + return err + } + case TypeSQL: + if err := m.runConn(ctx, conn, direction); err != nil { + return err + } + } + if p.cfg.noVersioning { + return nil + } + return p.store.InsertOrDelete(ctx, conn, direction, m.Source.Version) +} + +// beginTx begins a transaction and runs the given function. If the function returns an error, the +// transaction is rolled back. Otherwise, the transaction is committed. +// +// If the provider is configured to use versioning, this function also inserts or deletes the +// migration version. +func (p *Provider) beginTx( + ctx context.Context, + conn *sql.Conn, + fn func(tx *sql.Tx) error, +) (retErr error) { + tx, err := conn.BeginTx(ctx, nil) + if err != nil { + return err + } + defer func() { + if retErr != nil { + retErr = multierr.Append(retErr, tx.Rollback()) + } + }() + if err := fn(tx); err != nil { + return err + } + return tx.Commit() +} + +func (p *Provider) initialize(ctx context.Context) (*sql.Conn, func() error, error) { + p.mu.Lock() + conn, err := p.db.Conn(ctx) + if err != nil { + p.mu.Unlock() + return nil, nil, err + } + // cleanup is a function that cleans up the connection, and optionally, the session lock. + cleanup := func() error { + p.mu.Unlock() + return conn.Close() + } + if l := p.cfg.sessionLocker; l != nil && p.cfg.lockEnabled { + if err := l.SessionLock(ctx, conn); err != nil { + return nil, nil, multierr.Append(err, cleanup()) + } + cleanup = func() error { + p.mu.Unlock() + // Use a detached context to unlock the session. This is because the context passed to + // SessionLock may have been canceled, and we don't want to cancel the unlock. + // TODO(mf): use [context.WithoutCancel] added in go1.21 + detachedCtx := context.Background() + return multierr.Append(l.SessionUnlock(detachedCtx, conn), conn.Close()) + } + } + // If versioning is enabled, ensure the version table exists. For ad-hoc migrations, we don't + // need the version table because there is no versioning. + if !p.cfg.noVersioning { + if err := p.ensureVersionTable(ctx, conn); err != nil { + return nil, nil, multierr.Append(err, cleanup()) + } + } + return conn, cleanup, nil +} + +// parseSQL parses all SQL migrations in BOTH directions. If a migration has already been parsed, it +// will not be parsed again. +// +// Important: This function will mutate SQL migrations and is not safe for concurrent use. +func parseSQL(fsys fs.FS, debug bool, migrations []*migration) error { + for _, m := range migrations { + // If the migration is a SQL migration, and it has not been parsed, parse it. + if m.Source.Type == TypeSQL && m.SQL == nil { + parsed, err := sqlparser.ParseAllFromFS(fsys, m.Source.Fullpath, debug) + if err != nil { + return err + } + m.SQL = &sqlMigration{ + UseTx: parsed.UseTx, + UpStatements: parsed.Up, + DownStatements: parsed.Down, + } + } + } + return nil +} + +func (p *Provider) ensureVersionTable(ctx context.Context, conn *sql.Conn) (retErr error) { + // feat(mf): this is where we can check if the version table exists instead of trying to fetch + // from a table that may not exist. https://github.com/pressly/goose/issues/461 + res, err := p.store.GetMigration(ctx, conn, 0) + if err == nil && res != nil { + return nil + } + return p.beginTx(ctx, conn, func(tx *sql.Tx) error { + if err := p.store.CreateVersionTable(ctx, tx); err != nil { + return err + } + if p.cfg.noVersioning { + return nil + } + return p.store.InsertOrDelete(ctx, tx, true, 0) + }) +} + +type missingMigration struct { + versionID int64 + filename string +} + +// findMissingMigrations returns a list of migrations that are missing from the database. A missing +// migration is one that has a version less than the max version in the database. +func findMissingMigrations( + dbMigrations []*sqladapter.ListMigrationsResult, + fsMigrations []*migration, + dbMaxVersion int64, +) []missingMigration { + existing := make(map[int64]bool) + for _, m := range dbMigrations { + existing[m.Version] = true + } + var missing []missingMigration + for _, m := range fsMigrations { + version := m.Source.Version + if !existing[version] && version < dbMaxVersion { + missing = append(missing, missingMigration{ + versionID: version, + filename: m.filename(), + }) + } + } + sort.Slice(missing, func(i, j int) bool { + return missing[i].versionID < missing[j].versionID + }) + return missing +} + +// getMigration returns the migration with the given version. If no migration is found, then +// ErrVersionNotFound is returned. +func (p *Provider) getMigration(version int64) (*migration, error) { + for _, m := range p.migrations { + if m.Source.Version == version { + return m, nil + } + } + return nil, ErrVersionNotFound +} + +func (p *Provider) apply(ctx context.Context, version int64, direction bool) (_ *MigrationResult, retErr error) { + if version < 1 { + return nil, errors.New("version must be greater than zero") + } + + m, err := p.getMigration(version) + if err != nil { + return nil, err + } + + conn, cleanup, err := p.initialize(ctx) + if err != nil { + return nil, err + } + defer func() { + retErr = multierr.Append(retErr, cleanup()) + }() + + result, err := p.store.GetMigration(ctx, conn, version) + if err != nil && !errors.Is(err, sql.ErrNoRows) { + return nil, err + } + // If the migration has already been applied, return an error, unless the migration is being + // applied in the opposite direction. In that case, we allow the migration to be applied again. + if result != nil && direction { + return nil, fmt.Errorf("version %d: %w", version, ErrAlreadyApplied) + } + + d := sqlparser.DirectionDown + if direction { + d = sqlparser.DirectionUp + } + results, err := p.runMigrations(ctx, conn, []*migration{m}, d, true) + if err != nil { + return nil, err + } + if len(results) == 0 { + return nil, fmt.Errorf("version %d: %w", version, ErrAlreadyApplied) + } + return results[0], nil +} + +func (p *Provider) status(ctx context.Context) (_ []*MigrationStatus, retErr error) { + conn, cleanup, err := p.initialize(ctx) + if err != nil { + return nil, err + } + defer func() { + retErr = multierr.Append(retErr, cleanup()) + }() + + // TODO(mf): add support for limit and order. Also would be nice to refactor the list query to + // support limiting the set. + + status := make([]*MigrationStatus, 0, len(p.migrations)) + for _, m := range p.migrations { + migrationStatus := &MigrationStatus{ + Source: m.Source, + State: StatePending, + } + dbResult, err := p.store.GetMigration(ctx, conn, m.Source.Version) + if err != nil && !errors.Is(err, sql.ErrNoRows) { + return nil, err + } + if dbResult != nil { + migrationStatus.State = StateApplied + migrationStatus.AppliedAt = dbResult.Timestamp + } + status = append(status, migrationStatus) + } + + return status, nil +} + +func (p *Provider) getDBVersion(ctx context.Context) (_ int64, retErr error) { + conn, cleanup, err := p.initialize(ctx) + if err != nil { + return 0, err + } + defer func() { + retErr = multierr.Append(retErr, cleanup()) + }() + + res, err := p.store.ListMigrations(ctx, conn) + if err != nil { + return 0, err + } + if len(res) == 0 { + return 0, nil + } + return res[0].Version, nil +} diff --git a/internal/provider/run_down.go b/internal/provider/run_down.go new file mode 100644 index 000000000..011ba7990 --- /dev/null +++ b/internal/provider/run_down.go @@ -0,0 +1,53 @@ +package provider + +import ( + "context" + + "github.com/pressly/goose/v3/internal/sqlparser" + "go.uber.org/multierr" +) + +func (p *Provider) down(ctx context.Context, downByOne bool, version int64) (_ []*MigrationResult, retErr error) { + conn, cleanup, err := p.initialize(ctx) + if err != nil { + return nil, err + } + defer func() { + retErr = multierr.Append(retErr, cleanup()) + }() + + if len(p.migrations) == 0 { + return nil, nil + } + + if p.cfg.noVersioning { + var downMigrations []*migration + if downByOne { + downMigrations = append(downMigrations, p.migrations[len(p.migrations)-1]) + } else { + downMigrations = p.migrations + } + return p.runMigrations(ctx, conn, downMigrations, sqlparser.DirectionDown, downByOne) + } + + dbMigrations, err := p.store.ListMigrations(ctx, conn) + if err != nil { + return nil, err + } + if dbMigrations[0].Version == 0 { + return nil, nil + } + + var downMigrations []*migration + for _, dbMigration := range dbMigrations { + if dbMigration.Version <= version { + break + } + m, err := p.getMigration(dbMigration.Version) + if err != nil { + return nil, err + } + downMigrations = append(downMigrations, m) + } + return p.runMigrations(ctx, conn, downMigrations, sqlparser.DirectionDown, downByOne) +} diff --git a/internal/provider/run_test.go b/internal/provider/run_test.go new file mode 100644 index 000000000..97e86ed21 --- /dev/null +++ b/internal/provider/run_test.go @@ -0,0 +1,1282 @@ +package provider_test + +import ( + "context" + "database/sql" + "errors" + "fmt" + "io/fs" + "math" + "math/rand" + "os" + "path/filepath" + "reflect" + "sort" + "sync" + "testing" + "testing/fstest" + + "github.com/pressly/goose/v3/internal/check" + "github.com/pressly/goose/v3/internal/provider" + "github.com/pressly/goose/v3/internal/testdb" + "github.com/pressly/goose/v3/lock" + "golang.org/x/sync/errgroup" +) + +func TestProviderRun(t *testing.T) { + t.Parallel() + + t.Run("closed_db", func(t *testing.T) { + p, db := newProviderWithDB(t) + check.NoError(t, db.Close()) + _, err := p.Up(context.Background()) + check.HasError(t, err) + check.Equal(t, err.Error(), "sql: database is closed") + }) + t.Run("ping_and_close", func(t *testing.T) { + p, _ := newProviderWithDB(t) + t.Cleanup(func() { + check.NoError(t, p.Close()) + }) + check.NoError(t, p.Ping(context.Background())) + }) + t.Run("apply_unknown_version", func(t *testing.T) { + p, _ := newProviderWithDB(t) + _, err := p.ApplyVersion(context.Background(), 999, true) + check.HasError(t, err) + check.Bool(t, errors.Is(err, provider.ErrVersionNotFound), true) + _, err = p.ApplyVersion(context.Background(), 999, false) + check.HasError(t, err) + check.Bool(t, errors.Is(err, provider.ErrVersionNotFound), true) + }) + t.Run("run_zero", func(t *testing.T) { + p, _ := newProviderWithDB(t) + _, err := p.UpTo(context.Background(), 0) + check.HasError(t, err) + check.Equal(t, err.Error(), "version must be greater than zero") + _, err = p.DownTo(context.Background(), -1) + check.HasError(t, err) + check.Equal(t, err.Error(), "version must be a number greater than or equal zero: -1") + _, err = p.ApplyVersion(context.Background(), 0, true) + check.HasError(t, err) + check.Equal(t, err.Error(), "version must be greater than zero") + }) + t.Run("up_and_down_all", func(t *testing.T) { + ctx := context.Background() + p, _ := newProviderWithDB(t) + const ( + numCount = 7 + ) + sources := p.ListSources() + check.Number(t, len(sources), numCount) + // Ensure only SQL migrations are returned + for _, s := range sources { + check.Equal(t, s.Type, provider.TypeSQL) + } + // Test Up + res, err := p.Up(ctx) + check.NoError(t, err) + check.Number(t, len(res), numCount) + assertResult(t, res[0], provider.NewSource(provider.TypeSQL, "00001_users_table.sql", 1), "up") + assertResult(t, res[1], provider.NewSource(provider.TypeSQL, "00002_posts_table.sql", 2), "up") + assertResult(t, res[2], provider.NewSource(provider.TypeSQL, "00003_comments_table.sql", 3), "up") + assertResult(t, res[3], provider.NewSource(provider.TypeSQL, "00004_insert_data.sql", 4), "up") + assertResult(t, res[4], provider.NewSource(provider.TypeSQL, "00005_posts_view.sql", 5), "up") + assertResult(t, res[5], provider.NewSource(provider.TypeSQL, "00006_empty_up.sql", 6), "up") + assertResult(t, res[6], provider.NewSource(provider.TypeSQL, "00007_empty_up_down.sql", 7), "up") + // Test Down + res, err = p.DownTo(ctx, 0) + check.NoError(t, err) + check.Number(t, len(res), numCount) + assertResult(t, res[0], provider.NewSource(provider.TypeSQL, "00007_empty_up_down.sql", 7), "down") + assertResult(t, res[1], provider.NewSource(provider.TypeSQL, "00006_empty_up.sql", 6), "down") + assertResult(t, res[2], provider.NewSource(provider.TypeSQL, "00005_posts_view.sql", 5), "down") + assertResult(t, res[3], provider.NewSource(provider.TypeSQL, "00004_insert_data.sql", 4), "down") + assertResult(t, res[4], provider.NewSource(provider.TypeSQL, "00003_comments_table.sql", 3), "down") + assertResult(t, res[5], provider.NewSource(provider.TypeSQL, "00002_posts_table.sql", 2), "down") + assertResult(t, res[6], provider.NewSource(provider.TypeSQL, "00001_users_table.sql", 1), "down") + }) + t.Run("up_and_down_by_one", func(t *testing.T) { + ctx := context.Background() + p, _ := newProviderWithDB(t) + maxVersion := len(p.ListSources()) + // Apply all migrations one-by-one. + var counter int + for { + res, err := p.UpByOne(ctx) + counter++ + if counter > maxVersion { + if !errors.Is(err, provider.ErrNoNextVersion) { + t.Fatalf("incorrect error: got:%v want:%v", err, provider.ErrNoNextVersion) + } + break + } + check.NoError(t, err) + check.Number(t, len(res), 1) + check.Number(t, res[0].Source.Version, int64(counter)) + } + currentVersion, err := p.GetDBVersion(ctx) + check.NoError(t, err) + check.Number(t, currentVersion, int64(maxVersion)) + // Reset counter + counter = 0 + // Rollback all migrations one-by-one. + for { + res, err := p.Down(ctx) + counter++ + if counter > maxVersion { + if !errors.Is(err, provider.ErrNoNextVersion) { + t.Fatalf("incorrect error: got:%v want:%v", err, provider.ErrNoNextVersion) + } + break + } + check.NoError(t, err) + check.Number(t, len(res), 1) + check.Number(t, res[0].Source.Version, int64(maxVersion-counter+1)) + } + // Once everything is tested the version should match the highest testdata version + currentVersion, err = p.GetDBVersion(ctx) + check.NoError(t, err) + check.Number(t, currentVersion, 0) + }) + t.Run("up_to", func(t *testing.T) { + ctx := context.Background() + p, db := newProviderWithDB(t) + const ( + upToVersion int64 = 2 + ) + results, err := p.UpTo(ctx, upToVersion) + check.NoError(t, err) + check.Number(t, len(results), upToVersion) + assertResult(t, results[0], provider.NewSource(provider.TypeSQL, "00001_users_table.sql", 1), "up") + assertResult(t, results[1], provider.NewSource(provider.TypeSQL, "00002_posts_table.sql", 2), "up") + // Fetch the goose version from DB + currentVersion, err := p.GetDBVersion(ctx) + check.NoError(t, err) + check.Number(t, currentVersion, upToVersion) + // Validate the version actually matches what goose claims it is + gotVersion, err := getMaxVersionID(db, provider.DefaultTablename) + check.NoError(t, err) + check.Number(t, gotVersion, upToVersion) + }) + t.Run("sql_connections", func(t *testing.T) { + tt := []struct { + name string + maxOpenConns int + maxIdleConns int + useDefaults bool + }{ + // Single connection ensures goose is able to function correctly when multiple + // connections are not available. + {name: "single_conn", maxOpenConns: 1, maxIdleConns: 1}, + {name: "defaults", useDefaults: true}, + } + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + ctx := context.Background() + // Start a new database for each test case. + p, db := newProviderWithDB(t) + if !tc.useDefaults { + db.SetMaxOpenConns(tc.maxOpenConns) + db.SetMaxIdleConns(tc.maxIdleConns) + } + sources := p.ListSources() + check.NumberNotZero(t, len(sources)) + + currentVersion, err := p.GetDBVersion(ctx) + check.NoError(t, err) + check.Number(t, currentVersion, 0) + + { + // Apply all up migrations + upResult, err := p.Up(ctx) + check.NoError(t, err) + check.Number(t, len(upResult), len(sources)) + currentVersion, err := p.GetDBVersion(ctx) + check.NoError(t, err) + check.Number(t, currentVersion, p.ListSources()[len(sources)-1].Version) + // Validate the db migration version actually matches what goose claims it is + gotVersion, err := getMaxVersionID(db, provider.DefaultTablename) + check.NoError(t, err) + check.Number(t, gotVersion, currentVersion) + tables, err := getTableNames(db) + check.NoError(t, err) + if !reflect.DeepEqual(tables, knownTables) { + t.Logf("got tables: %v", tables) + t.Logf("known tables: %v", knownTables) + t.Fatal("failed to match tables") + } + } + { + // Apply all down migrations + downResult, err := p.DownTo(ctx, 0) + check.NoError(t, err) + check.Number(t, len(downResult), len(sources)) + gotVersion, err := getMaxVersionID(db, provider.DefaultTablename) + check.NoError(t, err) + check.Number(t, gotVersion, 0) + // Should only be left with a single table, the default goose table + tables, err := getTableNames(db) + check.NoError(t, err) + knownTables := []string{provider.DefaultTablename, "sqlite_sequence"} + if !reflect.DeepEqual(tables, knownTables) { + t.Logf("got tables: %v", tables) + t.Logf("known tables: %v", knownTables) + t.Fatal("failed to match tables") + } + } + }) + } + }) + t.Run("apply", func(t *testing.T) { + ctx := context.Background() + p, _ := newProviderWithDB(t) + sources := p.ListSources() + // Apply all migrations in the up direction. + for _, s := range sources { + res, err := p.ApplyVersion(ctx, s.Version, true) + check.NoError(t, err) + // Round-trip the migration result through the database to ensure it's valid. + assertResult(t, res, s, "up") + } + // Apply all migrations in the down direction. + for i := len(sources) - 1; i >= 0; i-- { + s := sources[i] + res, err := p.ApplyVersion(ctx, s.Version, false) + check.NoError(t, err) + // Round-trip the migration result through the database to ensure it's valid. + assertResult(t, res, s, "down") + } + // Try apply version 1 multiple times + _, err := p.ApplyVersion(ctx, 1, true) + check.NoError(t, err) + _, err = p.ApplyVersion(ctx, 1, true) + check.HasError(t, err) + check.Bool(t, errors.Is(err, provider.ErrAlreadyApplied), true) + check.Contains(t, err.Error(), "version 1: already applied") + }) + t.Run("status", func(t *testing.T) { + ctx := context.Background() + p, _ := newProviderWithDB(t) + numCount := len(p.ListSources()) + // Before any migrations are applied, the status should be empty. + status, err := p.Status(ctx) + check.NoError(t, err) + check.Number(t, len(status), numCount) + assertStatus(t, status[0], provider.StatePending, provider.NewSource(provider.TypeSQL, "00001_users_table.sql", 1), true) + assertStatus(t, status[1], provider.StatePending, provider.NewSource(provider.TypeSQL, "00002_posts_table.sql", 2), true) + assertStatus(t, status[2], provider.StatePending, provider.NewSource(provider.TypeSQL, "00003_comments_table.sql", 3), true) + assertStatus(t, status[3], provider.StatePending, provider.NewSource(provider.TypeSQL, "00004_insert_data.sql", 4), true) + assertStatus(t, status[4], provider.StatePending, provider.NewSource(provider.TypeSQL, "00005_posts_view.sql", 5), true) + assertStatus(t, status[5], provider.StatePending, provider.NewSource(provider.TypeSQL, "00006_empty_up.sql", 6), true) + assertStatus(t, status[6], provider.StatePending, provider.NewSource(provider.TypeSQL, "00007_empty_up_down.sql", 7), true) + // Apply all migrations + _, err = p.Up(ctx) + check.NoError(t, err) + status, err = p.Status(ctx) + check.NoError(t, err) + check.Number(t, len(status), numCount) + assertStatus(t, status[0], provider.StateApplied, provider.NewSource(provider.TypeSQL, "00001_users_table.sql", 1), false) + assertStatus(t, status[1], provider.StateApplied, provider.NewSource(provider.TypeSQL, "00002_posts_table.sql", 2), false) + assertStatus(t, status[2], provider.StateApplied, provider.NewSource(provider.TypeSQL, "00003_comments_table.sql", 3), false) + assertStatus(t, status[3], provider.StateApplied, provider.NewSource(provider.TypeSQL, "00004_insert_data.sql", 4), false) + assertStatus(t, status[4], provider.StateApplied, provider.NewSource(provider.TypeSQL, "00005_posts_view.sql", 5), false) + assertStatus(t, status[5], provider.StateApplied, provider.NewSource(provider.TypeSQL, "00006_empty_up.sql", 6), false) + assertStatus(t, status[6], provider.StateApplied, provider.NewSource(provider.TypeSQL, "00007_empty_up_down.sql", 7), false) + }) + t.Run("tx_partial_errors", func(t *testing.T) { + countOwners := func(db *sql.DB) (int, error) { + q := `SELECT count(*)FROM owners` + var count int + if err := db.QueryRow(q).Scan(&count); err != nil { + return 0, err + } + return count, nil + } + + ctx := context.Background() + db := newDB(t) + mapFS := fstest.MapFS{ + "00001_users_table.sql": newMapFile(` +-- +goose Up +CREATE TABLE owners ( owner_name TEXT NOT NULL ); +`), + "00002_partial_error.sql": newMapFile(` +-- +goose Up +INSERT INTO invalid_table (invalid_table) VALUES ('invalid_value'); +`), + "00003_insert_data.sql": newMapFile(` +-- +goose Up +INSERT INTO owners (owner_name) VALUES ('seed-user-1'); +INSERT INTO owners (owner_name) VALUES ('seed-user-2'); +INSERT INTO owners (owner_name) VALUES ('seed-user-3'); +`), + } + p, err := provider.NewProvider(provider.DialectSQLite3, db, mapFS) + check.NoError(t, err) + _, err = p.Up(ctx) + check.HasError(t, err) + check.Contains(t, err.Error(), "partial migration error (00002_partial_error.sql) (2)") + var expected *provider.PartialError + check.Bool(t, errors.As(err, &expected), true) + // Check Err field + check.Bool(t, expected.Err != nil, true) + check.Contains(t, expected.Err.Error(), "SQL logic error: no such table: invalid_table (1)") + // Check Results field + check.Number(t, len(expected.Applied), 1) + assertResult(t, expected.Applied[0], provider.NewSource(provider.TypeSQL, "00001_users_table.sql", 1), "up") + // Check Failed field + check.Bool(t, expected.Failed != nil, true) + assertSource(t, expected.Failed.Source, provider.TypeSQL, "00002_partial_error.sql", 2) + check.Bool(t, expected.Failed.Empty, false) + check.Bool(t, expected.Failed.Error != nil, true) + check.Contains(t, expected.Failed.Error.Error(), "SQL logic error: no such table: invalid_table (1)") + check.Equal(t, expected.Failed.Direction, "up") + check.Bool(t, expected.Failed.Duration > 0, true) + + // Ensure the partial error did not affect the database. + count, err := countOwners(db) + check.NoError(t, err) + check.Number(t, count, 0) + + status, err := p.Status(ctx) + check.NoError(t, err) + check.Number(t, len(status), 3) + assertStatus(t, status[0], provider.StateApplied, provider.NewSource(provider.TypeSQL, "00001_users_table.sql", 1), false) + assertStatus(t, status[1], provider.StatePending, provider.NewSource(provider.TypeSQL, "00002_partial_error.sql", 2), true) + assertStatus(t, status[2], provider.StatePending, provider.NewSource(provider.TypeSQL, "00003_insert_data.sql", 3), true) + }) +} + +func TestConcurrentProvider(t *testing.T) { + t.Parallel() + + t.Run("up", func(t *testing.T) { + ctx := context.Background() + p, _ := newProviderWithDB(t) + maxVersion := len(p.ListSources()) + + ch := make(chan int64) + var wg sync.WaitGroup + for i := 0; i < maxVersion; i++ { + wg.Add(1) + + go func() { + defer wg.Done() + res, err := p.UpByOne(ctx) + if err != nil { + t.Error(err) + return + } + if len(res) != 1 { + t.Errorf("expected 1 result, got %d", len(res)) + return + } + ch <- res[0].Source.Version + }() + } + go func() { + wg.Wait() + close(ch) + }() + var versions []int64 + for version := range ch { + versions = append(versions, version) + } + // Fail early if any of the goroutines failed. + if t.Failed() { + return + } + check.Number(t, len(versions), maxVersion) + for i := 0; i < maxVersion; i++ { + check.Number(t, versions[i], int64(i+1)) + } + currentVersion, err := p.GetDBVersion(ctx) + check.NoError(t, err) + check.Number(t, currentVersion, maxVersion) + }) + t.Run("down", func(t *testing.T) { + ctx := context.Background() + p, _ := newProviderWithDB(t) + maxVersion := len(p.ListSources()) + // Apply all migrations + _, err := p.Up(ctx) + check.NoError(t, err) + currentVersion, err := p.GetDBVersion(ctx) + check.NoError(t, err) + check.Number(t, currentVersion, maxVersion) + + ch := make(chan []*provider.MigrationResult) + var wg sync.WaitGroup + for i := 0; i < maxVersion; i++ { + wg.Add(1) + + go func() { + defer wg.Done() + res, err := p.DownTo(ctx, 0) + if err != nil { + t.Error(err) + return + } + ch <- res + }() + } + go func() { + wg.Wait() + close(ch) + }() + var ( + valid [][]*provider.MigrationResult + empty [][]*provider.MigrationResult + ) + for results := range ch { + if len(results) == 0 { + empty = append(empty, results) + continue + } + valid = append(valid, results) + } + // Fail early if any of the goroutines failed. + if t.Failed() { + return + } + check.Equal(t, len(valid), 1) + check.Equal(t, len(empty), maxVersion-1) + // Ensure the valid result is correct. + check.Number(t, len(valid[0]), maxVersion) + }) +} + +func TestNoVersioning(t *testing.T) { + t.Parallel() + + countSeedOwners := func(db *sql.DB) (int, error) { + q := `SELECT count(*)FROM owners WHERE owner_name LIKE'seed-user-%'` + var count int + if err := db.QueryRow(q).Scan(&count); err != nil { + return 0, err + } + return count, nil + } + countOwners := func(db *sql.DB) (int, error) { + q := `SELECT count(*)FROM owners` + var count int + if err := db.QueryRow(q).Scan(&count); err != nil { + return 0, err + } + return count, nil + } + ctx := context.Background() + dbName := fmt.Sprintf("test_%s.db", randomAlphaNumeric(8)) + db, err := sql.Open("sqlite", filepath.Join(t.TempDir(), dbName)) + check.NoError(t, err) + fsys := os.DirFS(filepath.Join("testdata", "no-versioning", "migrations")) + const ( + // Total owners created by the seed files. + wantSeedOwnerCount = 250 + // These are owners created by migration files. + wantOwnerCount = 4 + ) + p, err := provider.NewProvider(provider.DialectSQLite3, db, fsys, + provider.WithVerbose(testing.Verbose()), + provider.WithNoVersioning(false), // This is the default. + ) + check.Number(t, len(p.ListSources()), 3) + check.NoError(t, err) + _, err = p.Up(ctx) + check.NoError(t, err) + baseVersion, err := p.GetDBVersion(ctx) + check.NoError(t, err) + check.Number(t, baseVersion, 3) + t.Run("seed-up-down-to-zero", func(t *testing.T) { + fsys := os.DirFS(filepath.Join("testdata", "no-versioning", "seed")) + p, err := provider.NewProvider(provider.DialectSQLite3, db, fsys, + provider.WithVerbose(testing.Verbose()), + provider.WithNoVersioning(true), // Provider with no versioning. + ) + check.NoError(t, err) + check.Number(t, len(p.ListSources()), 2) + + // Run (all) up migrations from the seed dir + { + upResult, err := p.Up(ctx) + check.NoError(t, err) + check.Number(t, len(upResult), 2) + // Confirm no changes to the versioned schema in the DB + currentVersion, err := p.GetDBVersion(ctx) + check.NoError(t, err) + check.Number(t, baseVersion, currentVersion) + seedOwnerCount, err := countSeedOwners(db) + check.NoError(t, err) + check.Number(t, seedOwnerCount, wantSeedOwnerCount) + } + // Run (all) down migrations from the seed dir + { + downResult, err := p.DownTo(ctx, 0) + check.NoError(t, err) + check.Number(t, len(downResult), 2) + // Confirm no changes to the versioned schema in the DB + currentVersion, err := p.GetDBVersion(ctx) + check.NoError(t, err) + check.Number(t, baseVersion, currentVersion) + seedOwnerCount, err := countSeedOwners(db) + check.NoError(t, err) + check.Number(t, seedOwnerCount, 0) + } + // The migrations added 4 non-seed owners, they must remain in the database afterwards + ownerCount, err := countOwners(db) + check.NoError(t, err) + check.Number(t, ownerCount, wantOwnerCount) + }) +} + +func TestAllowMissing(t *testing.T) { + t.Parallel() + ctx := context.Background() + + // Developer A and B check out the "main" branch which is currently on version 3. Developer A + // mistakenly creates migration 5 and commits. Developer B did not pull the latest changes and + // commits migration 4. Oops -- now the migrations are out of order. + // + // When goose is set to allow missing migrations, then 5 is applied after 4 with no error. + // Otherwise it's expected to be an error. + + t.Run("missing_now_allowed", func(t *testing.T) { + db := newDB(t) + p, err := provider.NewProvider(provider.DialectSQLite3, db, newFsys(), + provider.WithAllowMissing(false), + ) + check.NoError(t, err) + + // Create and apply first 3 migrations. + _, err = p.UpTo(ctx, 3) + check.NoError(t, err) + currentVersion, err := p.GetDBVersion(ctx) + check.NoError(t, err) + check.Number(t, currentVersion, 3) + + // Developer A - migration 5 (mistakenly applied) + result, err := p.ApplyVersion(ctx, 5, true) + check.NoError(t, err) + check.Number(t, result.Source.Version, 5) + current, err := p.GetDBVersion(ctx) + check.NoError(t, err) + check.Number(t, current, 5) + + // The database has migrations 1,2,3,5 applied. + + // Developer B is on version 3 (e.g., never pulled the latest changes). Adds migration 4. By + // default goose does not allow missing (out-of-order) migrations, which means halt if a + // missing migration is detected. + _, err = p.Up(ctx) + check.HasError(t, err) + // found 1 missing (out-of-order) migration: [00004_insert_data.sql] + check.Contains(t, err.Error(), "missing (out-of-order) migration") + // Confirm db version is unchanged. + current, err = p.GetDBVersion(ctx) + check.NoError(t, err) + check.Number(t, current, 5) + + _, err = p.UpByOne(ctx) + check.HasError(t, err) + // found 1 missing (out-of-order) migration: [00004_insert_data.sql] + check.Contains(t, err.Error(), "missing (out-of-order) migration") + // Confirm db version is unchanged. + current, err = p.GetDBVersion(ctx) + check.NoError(t, err) + check.Number(t, current, 5) + + _, err = p.UpTo(ctx, math.MaxInt64) + check.HasError(t, err) + // found 1 missing (out-of-order) migration: [00004_insert_data.sql] + check.Contains(t, err.Error(), "missing (out-of-order) migration") + // Confirm db version is unchanged. + current, err = p.GetDBVersion(ctx) + check.NoError(t, err) + check.Number(t, current, 5) + }) + + t.Run("missing_allowed", func(t *testing.T) { + db := newDB(t) + p, err := provider.NewProvider(provider.DialectSQLite3, db, newFsys(), + provider.WithAllowMissing(true), + ) + check.NoError(t, err) + + // Create and apply first 3 migrations. + _, err = p.UpTo(ctx, 3) + check.NoError(t, err) + currentVersion, err := p.GetDBVersion(ctx) + check.NoError(t, err) + check.Number(t, currentVersion, 3) + + // Developer A - migration 5 (mistakenly applied) + { + _, err = p.ApplyVersion(ctx, 5, true) + check.NoError(t, err) + current, err := p.GetDBVersion(ctx) + check.NoError(t, err) + check.Number(t, current, 5) + } + // Developer B - migration 4 (missing) and 6 (new) + { + // 4 + upResult, err := p.UpByOne(ctx) + check.NoError(t, err) + check.Number(t, len(upResult), 1) + check.Number(t, upResult[0].Source.Version, 4) + // 6 + upResult, err = p.UpByOne(ctx) + check.NoError(t, err) + check.Number(t, len(upResult), 1) + check.Number(t, upResult[0].Source.Version, 6) + + count, err := getGooseVersionCount(db, provider.DefaultTablename) + check.NoError(t, err) + check.Number(t, count, 6) + current, err := p.GetDBVersion(ctx) + check.NoError(t, err) + // Expecting max(version_id) to be 8 + check.Number(t, current, 6) + } + + // The applied order in the database is expected to be: + // 1,2,3,5,4,6 + // So migrating down should be the reverse of the applied order: + // 6,4,5,3,2,1 + + expected := []int64{6, 4, 5, 3, 2, 1} + for i, v := range expected { + // TODO(mf): this is returning it by the order it was applied. + current, err := p.GetDBVersion(ctx) + check.NoError(t, err) + check.Number(t, current, v) + downResult, err := p.Down(ctx) + if i == len(expected)-1 { + check.HasError(t, provider.ErrVersionNotFound) + } else { + check.NoError(t, err) + check.Number(t, len(downResult), 1) + check.Number(t, downResult[0].Source.Version, v) + } + } + }) +} + +func getGooseVersionCount(db *sql.DB, gooseTable string) (int64, error) { + var gotVersion int64 + if err := db.QueryRow( + fmt.Sprintf("SELECT count(*) FROM %s WHERE version_id > 0", gooseTable), + ).Scan(&gotVersion); err != nil { + return 0, err + } + return gotVersion, nil +} + +func TestGoOnly(t *testing.T) { + // Not parallel because it modifies global state. + + countUser := func(db *sql.DB) int { + q := `SELECT count(*)FROM users` + var count int + err := db.QueryRow(q).Scan(&count) + check.NoError(t, err) + return count + } + + t.Run("with_tx", func(t *testing.T) { + ctx := context.Background() + register := []*provider.Migration{ + { + Version: 1, Source: "00001_users_table.go", Registered: true, UseTx: true, + UpFnContext: newTxFn("CREATE TABLE users (id INTEGER PRIMARY KEY)"), + DownFnContext: newTxFn("DROP TABLE users"), + }, + } + err := provider.SetGlobalGoMigrations(register) + check.NoError(t, err) + t.Cleanup(provider.ResetGlobalGoMigrations) + + db := newDB(t) + p, err := provider.NewProvider(provider.DialectSQLite3, db, nil, + provider.WithGoMigration( + 2, + &provider.GoMigration{Run: newTxFn("INSERT INTO users (id) VALUES (1), (2), (3)")}, + &provider.GoMigration{Run: newTxFn("DELETE FROM users")}, + ), + ) + check.NoError(t, err) + sources := p.ListSources() + check.Number(t, len(p.ListSources()), 2) + assertSource(t, sources[0], provider.TypeGo, "00001_users_table.go", 1) + assertSource(t, sources[1], provider.TypeGo, "", 2) + // Apply migration 1 + res, err := p.UpByOne(ctx) + check.NoError(t, err) + check.Number(t, len(res), 1) + assertResult(t, res[0], provider.NewSource(provider.TypeGo, "00001_users_table.go", 1), "up") + check.Number(t, countUser(db), 0) + check.Bool(t, tableExists(t, db, "users"), true) + // Apply migration 2 + res, err = p.UpByOne(ctx) + check.NoError(t, err) + check.Number(t, len(res), 1) + assertResult(t, res[0], provider.NewSource(provider.TypeGo, "", 2), "up") + check.Number(t, countUser(db), 3) + // Rollback migration 2 + res, err = p.Down(ctx) + check.NoError(t, err) + check.Number(t, len(res), 1) + assertResult(t, res[0], provider.NewSource(provider.TypeGo, "", 2), "down") + check.Number(t, countUser(db), 0) + // Rollback migration 1 + res, err = p.Down(ctx) + check.NoError(t, err) + check.Number(t, len(res), 1) + assertResult(t, res[0], provider.NewSource(provider.TypeGo, "00001_users_table.go", 1), "down") + // Check table does not exist + check.Bool(t, tableExists(t, db, "users"), false) + }) + t.Run("with_db", func(t *testing.T) { + ctx := context.Background() + register := []*provider.Migration{ + { + Version: 1, Source: "00001_users_table.go", Registered: true, UseTx: true, + UpFnNoTxContext: newDBFn("CREATE TABLE users (id INTEGER PRIMARY KEY)"), + DownFnNoTxContext: newDBFn("DROP TABLE users"), + }, + } + err := provider.SetGlobalGoMigrations(register) + check.NoError(t, err) + t.Cleanup(provider.ResetGlobalGoMigrations) + + db := newDB(t) + p, err := provider.NewProvider(provider.DialectSQLite3, db, nil, + provider.WithGoMigration( + 2, + &provider.GoMigration{RunNoTx: newDBFn("INSERT INTO users (id) VALUES (1), (2), (3)")}, + &provider.GoMigration{RunNoTx: newDBFn("DELETE FROM users")}, + ), + ) + check.NoError(t, err) + sources := p.ListSources() + check.Number(t, len(p.ListSources()), 2) + assertSource(t, sources[0], provider.TypeGo, "00001_users_table.go", 1) + assertSource(t, sources[1], provider.TypeGo, "", 2) + // Apply migration 1 + res, err := p.UpByOne(ctx) + check.NoError(t, err) + check.Number(t, len(res), 1) + assertResult(t, res[0], provider.NewSource(provider.TypeGo, "00001_users_table.go", 1), "up") + check.Number(t, countUser(db), 0) + check.Bool(t, tableExists(t, db, "users"), true) + // Apply migration 2 + res, err = p.UpByOne(ctx) + check.NoError(t, err) + check.Number(t, len(res), 1) + assertResult(t, res[0], provider.NewSource(provider.TypeGo, "", 2), "up") + check.Number(t, countUser(db), 3) + // Rollback migration 2 + res, err = p.Down(ctx) + check.NoError(t, err) + check.Number(t, len(res), 1) + assertResult(t, res[0], provider.NewSource(provider.TypeGo, "", 2), "down") + check.Number(t, countUser(db), 0) + // Rollback migration 1 + res, err = p.Down(ctx) + check.NoError(t, err) + check.Number(t, len(res), 1) + assertResult(t, res[0], provider.NewSource(provider.TypeGo, "00001_users_table.go", 1), "down") + // Check table does not exist + check.Bool(t, tableExists(t, db, "users"), false) + }) +} + +func TestLockModeAdvisorySession(t *testing.T) { + t.Parallel() + if testing.Short() { + t.Skip("skip long running test") + } + + // The migrations are written in such a way that they cannot be applied concurrently, they will + // fail 99.9999% of the time. This test ensures that the advisory session lock mode works as + // expected. + + // TODO(mf): small improvement here is to use the SAME postgres instance but different databases + // created from a template. This will speed up the test. + + db, cleanup, err := testdb.NewPostgres() + check.NoError(t, err) + t.Cleanup(cleanup) + + newProvider := func() *provider.Provider { + sessionLocker, err := lock.NewPostgresSessionLocker() + check.NoError(t, err) + p, err := provider.NewProvider(provider.DialectPostgres, db, os.DirFS("../../testdata/migrations"), + provider.WithSessionLocker(sessionLocker), // Use advisory session lock mode. + provider.WithVerbose(testing.Verbose()), + ) + check.NoError(t, err) + return p + } + provider1 := newProvider() + provider2 := newProvider() + + sources := provider1.ListSources() + maxVersion := sources[len(sources)-1].Version + + // Since the lock mode is advisory session, only one of these providers is expected to apply ALL + // the migrations. The other provider should apply NO migrations. The test MUST fail if both + // providers apply migrations. + + t.Run("up", func(t *testing.T) { + var g errgroup.Group + var res1, res2 int + g.Go(func() error { + ctx := context.Background() + results, err := provider1.Up(ctx) + check.NoError(t, err) + res1 = len(results) + currentVersion, err := provider1.GetDBVersion(ctx) + check.NoError(t, err) + check.Number(t, currentVersion, maxVersion) + return nil + }) + g.Go(func() error { + ctx := context.Background() + results, err := provider2.Up(ctx) + check.NoError(t, err) + res2 = len(results) + currentVersion, err := provider2.GetDBVersion(ctx) + check.NoError(t, err) + check.Number(t, currentVersion, maxVersion) + return nil + }) + check.NoError(t, g.Wait()) + // One of the providers should have applied all migrations and the other should have applied + // no migrations, but with no error. + if res1 == 0 && res2 == 0 { + t.Fatal("both providers applied no migrations") + } + if res1 > 0 && res2 > 0 { + t.Fatal("both providers applied migrations") + } + }) + + // Reset the database and run the same test with the advisory lock mode, but apply migrations + // one-by-one. + { + _, err := provider1.DownTo(context.Background(), 0) + check.NoError(t, err) + currentVersion, err := provider1.GetDBVersion(context.Background()) + check.NoError(t, err) + check.Number(t, currentVersion, 0) + } + t.Run("up_by_one", func(t *testing.T) { + var g errgroup.Group + var ( + mu sync.Mutex + applied []int64 + ) + g.Go(func() error { + for { + results, err := provider1.UpByOne(context.Background()) + if err != nil { + if errors.Is(err, provider.ErrNoNextVersion) { + return nil + } + return err + } + check.NoError(t, err) + if len(results) != 1 { + return fmt.Errorf("expected 1 result, got %d", len(results)) + } + mu.Lock() + applied = append(applied, results[0].Source.Version) + mu.Unlock() + } + }) + g.Go(func() error { + for { + results, err := provider2.UpByOne(context.Background()) + if err != nil { + if errors.Is(err, provider.ErrNoNextVersion) { + return nil + } + return err + } + check.NoError(t, err) + if len(results) != 1 { + return fmt.Errorf("expected 1 result, got %d", len(results)) + } + mu.Lock() + applied = append(applied, results[0].Source.Version) + mu.Unlock() + } + }) + check.NoError(t, g.Wait()) + check.Number(t, len(applied), len(sources)) + sort.Slice(applied, func(i, j int) bool { + return applied[i] < applied[j] + }) + // Each migration should have been applied up exactly once. + for i := 0; i < len(sources); i++ { + check.Number(t, applied[i], sources[i].Version) + } + }) + + // Restore the database state by applying all migrations and run the same test with the advisory + // lock mode, but apply down migrations in parallel. + { + _, err := provider1.Up(context.Background()) + check.NoError(t, err) + currentVersion, err := provider1.GetDBVersion(context.Background()) + check.NoError(t, err) + check.Number(t, currentVersion, maxVersion) + } + + t.Run("down_to", func(t *testing.T) { + var g errgroup.Group + var res1, res2 int + g.Go(func() error { + ctx := context.Background() + results, err := provider1.DownTo(ctx, 0) + check.NoError(t, err) + res1 = len(results) + currentVersion, err := provider1.GetDBVersion(ctx) + check.NoError(t, err) + check.Number(t, currentVersion, 0) + return nil + }) + g.Go(func() error { + ctx := context.Background() + results, err := provider2.DownTo(ctx, 0) + check.NoError(t, err) + res2 = len(results) + currentVersion, err := provider2.GetDBVersion(ctx) + check.NoError(t, err) + check.Number(t, currentVersion, 0) + return nil + }) + check.NoError(t, g.Wait()) + + if res1 == 0 && res2 == 0 { + t.Fatal("both providers applied no migrations") + } + if res1 > 0 && res2 > 0 { + t.Fatal("both providers applied migrations") + } + }) + + // Restore the database state by applying all migrations and run the same test with the advisory + // lock mode, but apply down migrations one-by-one. + { + _, err := provider1.Up(context.Background()) + check.NoError(t, err) + currentVersion, err := provider1.GetDBVersion(context.Background()) + check.NoError(t, err) + check.Number(t, currentVersion, maxVersion) + } + + t.Run("down_by_one", func(t *testing.T) { + var g errgroup.Group + var ( + mu sync.Mutex + applied []int64 + ) + g.Go(func() error { + for { + results, err := provider1.Down(context.Background()) + if err != nil { + if errors.Is(err, provider.ErrNoNextVersion) { + return nil + } + return err + } + if len(results) != 1 { + return fmt.Errorf("expected 1 result, got %d", len(results)) + } + check.NoError(t, err) + mu.Lock() + applied = append(applied, results[0].Source.Version) + mu.Unlock() + } + }) + g.Go(func() error { + for { + results, err := provider2.Down(context.Background()) + if err != nil { + if errors.Is(err, provider.ErrNoNextVersion) { + return nil + } + return err + } + if len(results) != 1 { + return fmt.Errorf("expected 1 result, got %d", len(results)) + } + check.NoError(t, err) + mu.Lock() + applied = append(applied, results[0].Source.Version) + mu.Unlock() + } + }) + check.NoError(t, g.Wait()) + check.Number(t, len(applied), len(sources)) + sort.Slice(applied, func(i, j int) bool { + return applied[i] < applied[j] + }) + // Each migration should have been applied down exactly once. Since this is sequential the + // applied down migrations should be in reverse order. + for i := len(sources) - 1; i >= 0; i-- { + check.Number(t, applied[i], sources[i].Version) + } + }) +} + +func newDBFn(query string) func(context.Context, *sql.DB) error { + return func(ctx context.Context, db *sql.DB) error { + _, err := db.ExecContext(ctx, query) + return err + } +} + +func newTxFn(query string) func(context.Context, *sql.Tx) error { + return func(ctx context.Context, tx *sql.Tx) error { + _, err := tx.ExecContext(ctx, query) + return err + } +} + +func tableExists(t *testing.T, db *sql.DB, table string) bool { + q := fmt.Sprintf(`SELECT CASE WHEN COUNT(*) > 0 THEN 1 ELSE 0 END AS table_exists FROM sqlite_master WHERE type = 'table' AND name = '%s'`, table) + var b string + err := db.QueryRow(q).Scan(&b) + check.NoError(t, err) + return b == "1" +} + +const ( + charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" +) + +func randomAlphaNumeric(length int) string { + b := make([]byte, length) + for i := range b { + b[i] = charset[rand.Intn(len(charset))] + } + return string(b) +} + +func newProviderWithDB(t *testing.T, opts ...provider.ProviderOption) (*provider.Provider, *sql.DB) { + t.Helper() + db := newDB(t) + opts = append( + opts, + provider.WithVerbose(testing.Verbose()), + ) + p, err := provider.NewProvider(provider.DialectSQLite3, db, newFsys(), opts...) + check.NoError(t, err) + return p, db +} + +func newDB(t *testing.T) *sql.DB { + t.Helper() + dbName := fmt.Sprintf("test_%s.db", randomAlphaNumeric(8)) + db, err := sql.Open("sqlite", filepath.Join(t.TempDir(), dbName)) + check.NoError(t, err) + return db +} + +func getMaxVersionID(db *sql.DB, gooseTable string) (int64, error) { + var gotVersion int64 + if err := db.QueryRow( + fmt.Sprintf("select max(version_id) from %s", gooseTable), + ).Scan(&gotVersion); err != nil { + return 0, err + } + return gotVersion, nil +} + +func getTableNames(db *sql.DB) ([]string, error) { + rows, err := db.Query(`SELECT name FROM sqlite_master WHERE type='table' ORDER BY name`) + if err != nil { + return nil, err + } + defer rows.Close() + var tables []string + for rows.Next() { + var name string + if err := rows.Scan(&name); err != nil { + return nil, err + } + tables = append(tables, name) + } + if err := rows.Err(); err != nil { + return nil, err + } + return tables, nil +} + +func assertStatus(t *testing.T, got *provider.MigrationStatus, state provider.State, source provider.Source, appliedIsZero bool) { + t.Helper() + check.Equal(t, got.State, state) + check.Equal(t, got.Source, source) + check.Bool(t, got.AppliedAt.IsZero(), appliedIsZero) +} + +func assertResult(t *testing.T, got *provider.MigrationResult, source provider.Source, direction string) { + t.Helper() + check.Equal(t, got.Source, source) + check.Equal(t, got.Direction, direction) + check.Equal(t, got.Empty, false) + check.Bool(t, got.Error == nil, true) + check.Bool(t, got.Duration > 0, true) +} + +func assertSource(t *testing.T, got provider.Source, typ provider.MigrationType, name string, version int64) { + t.Helper() + check.Equal(t, got.Type, typ) + check.Equal(t, got.Fullpath, name) + check.Equal(t, got.Version, version) + switch got.Type { + case provider.TypeGo: + check.Equal(t, got.Type.String(), "go") + case provider.TypeSQL: + check.Equal(t, got.Type.String(), "sql") + } +} + +func newMapFile(data string) *fstest.MapFile { + return &fstest.MapFile{ + Data: []byte(data), + } +} + +func newFsys() fs.FS { + return fstest.MapFS{ + "00001_users_table.sql": newMapFile(runMigration1), + "00002_posts_table.sql": newMapFile(runMigration2), + "00003_comments_table.sql": newMapFile(runMigration3), + "00004_insert_data.sql": newMapFile(runMigration4), + "00005_posts_view.sql": newMapFile(runMigration5), + "00006_empty_up.sql": newMapFile(runMigration6), + "00007_empty_up_down.sql": newMapFile(runMigration7), + } +} + +var ( + + // known tables are the tables (including goose table) created by running all migration files. + // If you add a table, make sure to add to this list and keep it in order. + knownTables = []string{ + "comments", + "goose_db_version", + "posts", + "sqlite_sequence", + "users", + } + + runMigration1 = ` +-- +goose Up +CREATE TABLE users ( + id INTEGER PRIMARY KEY, + username TEXT NOT NULL, + email TEXT NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP +); + +-- +goose Down +DROP TABLE users; +` + + runMigration2 = ` +-- +goose Up +-- +goose StatementBegin +CREATE TABLE posts ( + id INTEGER PRIMARY KEY, + title TEXT NOT NULL, + content TEXT NOT NULL, + author_id INTEGER NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (author_id) REFERENCES users(id) +); +-- +goose StatementEnd +SELECT 1; +SELECT 2; + +-- +goose Down +DROP TABLE posts; +` + + runMigration3 = ` +-- +goose Up +CREATE TABLE comments ( + id INTEGER PRIMARY KEY, + post_id INTEGER NOT NULL, + user_id INTEGER NOT NULL, + content TEXT NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (post_id) REFERENCES posts(id), + FOREIGN KEY (user_id) REFERENCES users(id) +); + +-- +goose Down +DROP TABLE comments; +SELECT 1; +SELECT 2; +SELECT 3; +` + + runMigration4 = ` +-- +goose Up +INSERT INTO users (id, username, email) +VALUES + (1, 'john_doe', 'john@example.com'), + (2, 'jane_smith', 'jane@example.com'), + (3, 'alice_wonderland', 'alice@example.com'); + +INSERT INTO posts (id, title, content, author_id) +VALUES + (1, 'Introduction to SQL', 'SQL is a powerful language for managing databases...', 1), + (2, 'Data Modeling Techniques', 'Choosing the right data model is crucial...', 2), + (3, 'Advanced Query Optimization', 'Optimizing queries can greatly improve...', 1); + +INSERT INTO comments (id, post_id, user_id, content) +VALUES + (1, 1, 3, 'Great introduction! Looking forward to more.'), + (2, 1, 2, 'SQL can be a bit tricky at first, but practice helps.'), + (3, 2, 1, 'You covered normalization really well in this post.'); + +-- +goose Down +DELETE FROM comments; +DELETE FROM posts; +DELETE FROM users; +` + + runMigration5 = ` +-- +goose NO TRANSACTION + +-- +goose Up +CREATE VIEW posts_view AS + SELECT + p.id, + p.title, + p.content, + p.created_at, + u.username AS author + FROM posts p + JOIN users u ON p.author_id = u.id; + +-- +goose Down +DROP VIEW posts_view; +` + + runMigration6 = ` +-- +goose Up +` + + runMigration7 = ` +-- +goose Up +-- +goose Down +` +) diff --git a/internal/provider/run_up.go b/internal/provider/run_up.go new file mode 100644 index 000000000..7ee9c6c4f --- /dev/null +++ b/internal/provider/run_up.go @@ -0,0 +1,96 @@ +package provider + +import ( + "context" + "errors" + "fmt" + "strings" + + "github.com/pressly/goose/v3/internal/sqlparser" + "go.uber.org/multierr" +) + +func (p *Provider) up(ctx context.Context, upByOne bool, version int64) (_ []*MigrationResult, retErr error) { + if version < 1 { + return nil, errors.New("version must be greater than zero") + } + + conn, cleanup, err := p.initialize(ctx) + if err != nil { + return nil, err + } + defer func() { + retErr = multierr.Append(retErr, cleanup()) + }() + + if len(p.migrations) == 0 { + return nil, nil + } + if p.cfg.noVersioning { + // Short circuit if versioning is disabled and apply all migrations. + return p.runMigrations(ctx, conn, p.migrations, sqlparser.DirectionUp, upByOne) + } + + // optimize(mf): Listing all migrations from the database isn't great. This is only required to + // support the out-of-order (allow missing) feature. For users who don't use this feature, we + // could just query the database for the current version and then apply migrations that are + // greater than that version. + dbMigrations, err := p.store.ListMigrations(ctx, conn) + if err != nil { + return nil, err + } + dbMaxVersion := dbMigrations[0].Version + // lookupAppliedInDB is a map of all applied migrations in the database. + lookupAppliedInDB := make(map[int64]bool) + for _, m := range dbMigrations { + lookupAppliedInDB[m.Version] = true + } + + missingMigrations := findMissingMigrations(dbMigrations, p.migrations, dbMaxVersion) + + // feature(mf): It is very possible someone may want to apply ONLY new migrations and skip + // missing migrations entirely. At the moment this is not supported, but leaving this comment + // because that's where that logic will be handled. + if len(missingMigrations) > 0 && !p.cfg.allowMissing { + var collected []string + for _, v := range missingMigrations { + collected = append(collected, v.filename) + } + msg := "migration" + if len(collected) > 1 { + msg += "s" + } + return nil, fmt.Errorf("found %d missing (out-of-order) %s: [%s]", + len(missingMigrations), msg, strings.Join(collected, ",")) + } + + var migrationsToApply []*migration + if p.cfg.allowMissing { + for _, v := range missingMigrations { + m, err := p.getMigration(v.versionID) + if err != nil { + return nil, err + } + migrationsToApply = append(migrationsToApply, m) + } + } + // filter all migrations with a version greater than the supplied version (min) and less than or + // equal to the requested version (max). + for _, m := range p.migrations { + if lookupAppliedInDB[m.Source.Version] { + continue + } + if m.Source.Version > dbMaxVersion && m.Source.Version <= version { + migrationsToApply = append(migrationsToApply, m) + } + } + + // feat(mf): this is where can (optionally) group multiple migrations to be run in a single + // transaction. The default is to apply each migration sequentially on its own. + // https://github.com/pressly/goose/issues/222 + // + // Note, we can't use a single transaction for all migrations because some may have to be run in + // their own transaction. + + return p.runMigrations(ctx, conn, migrationsToApply, sqlparser.DirectionUp, upByOne) +} diff --git a/internal/provider/testdata/no-versioning/migrations/00001_a.sql b/internal/provider/testdata/no-versioning/migrations/00001_a.sql new file mode 100644 index 000000000..839cb7a7b --- /dev/null +++ b/internal/provider/testdata/no-versioning/migrations/00001_a.sql @@ -0,0 +1,8 @@ +-- +goose Up +CREATE TABLE owners ( + owner_id INTEGER PRIMARY KEY AUTOINCREMENT, + owner_name TEXT NOT NULL +); + +-- +goose Down +DROP TABLE IF EXISTS owners; diff --git a/internal/provider/testdata/no-versioning/migrations/00002_b.sql b/internal/provider/testdata/no-versioning/migrations/00002_b.sql new file mode 100644 index 000000000..bd15ef51c --- /dev/null +++ b/internal/provider/testdata/no-versioning/migrations/00002_b.sql @@ -0,0 +1,9 @@ +-- +goose Up +-- +goose StatementBegin +INSERT INTO owners(owner_name) VALUES ('lucas'), ('ocean'); +-- +goose StatementEnd + +-- +goose Down +-- +goose StatementBegin +DELETE FROM owners; +-- +goose StatementEnd diff --git a/internal/provider/testdata/no-versioning/migrations/00003_c.sql b/internal/provider/testdata/no-versioning/migrations/00003_c.sql new file mode 100644 index 000000000..422fb3068 --- /dev/null +++ b/internal/provider/testdata/no-versioning/migrations/00003_c.sql @@ -0,0 +1,9 @@ +-- +goose Up +-- +goose StatementBegin +INSERT INTO owners(owner_name) VALUES ('james'), ('space'); +-- +goose StatementEnd + +-- +goose Down +-- +goose StatementBegin +DELETE FROM owners WHERE owner_name IN ('james', 'space'); +-- +goose StatementEnd diff --git a/internal/provider/testdata/no-versioning/seed/00001_a.sql b/internal/provider/testdata/no-versioning/seed/00001_a.sql new file mode 100644 index 000000000..64f9ff03c --- /dev/null +++ b/internal/provider/testdata/no-versioning/seed/00001_a.sql @@ -0,0 +1,17 @@ +-- +goose Up +-- +goose StatementBegin +-- Insert 100 owners. +INSERT INTO owners (owner_name) +WITH numbers AS ( + SELECT 1 AS n + UNION ALL + SELECT n + 1 FROM numbers WHERE n < 100 +) +SELECT 'seed-user-' || n FROM numbers; +-- +goose StatementEnd + +-- +goose Down +-- +goose StatementBegin +-- Delete the previously inserted data. +DELETE FROM owners WHERE owner_name LIKE 'seed-user-%'; +-- +goose StatementEnd diff --git a/internal/provider/testdata/no-versioning/seed/00002_b.sql b/internal/provider/testdata/no-versioning/seed/00002_b.sql new file mode 100644 index 000000000..aafe82752 --- /dev/null +++ b/internal/provider/testdata/no-versioning/seed/00002_b.sql @@ -0,0 +1,15 @@ +-- +goose Up + +-- Insert 150 more owners. +INSERT INTO owners (owner_name) +WITH numbers AS ( + SELECT 101 AS n + UNION ALL + SELECT n + 1 FROM numbers WHERE n < 250 +) +SELECT 'seed-user-' || n FROM numbers; + +-- +goose Down + +-- NOTE: there are 4 migration owners and 100 seed owners, that's why owner_id starts at 105 +DELETE FROM owners WHERE owner_name LIKE 'seed-user-%' AND owner_id BETWEEN 105 AND 254; diff --git a/internal/provider/types.go b/internal/provider/types.go new file mode 100644 index 000000000..21bb18beb --- /dev/null +++ b/internal/provider/types.go @@ -0,0 +1,99 @@ +package provider + +import ( + "fmt" + "time" +) + +// Dialect is the type of database dialect. +type Dialect string + +const ( + DialectClickHouse Dialect = "clickhouse" + DialectMSSQL Dialect = "mssql" + DialectMySQL Dialect = "mysql" + DialectPostgres Dialect = "postgres" + DialectRedshift Dialect = "redshift" + DialectSQLite3 Dialect = "sqlite3" + DialectTiDB Dialect = "tidb" + DialectVertica Dialect = "vertica" +) + +// MigrationType is the type of migration. +type MigrationType int + +const ( + TypeGo MigrationType = iota + 1 + TypeSQL +) + +func (t MigrationType) String() string { + switch t { + case TypeGo: + return "go" + case TypeSQL: + return "sql" + default: + // This should never happen. + return fmt.Sprintf("unknown (%d)", t) + } +} + +// Source represents a single migration source. +// +// For SQL migrations, Fullpath will always be set. For Go migrations, Fullpath will will be set if +// the migration has a corresponding file on disk. It will be empty if the migration was registered +// manually. +type Source struct { + // Type is the type of migration. + Type MigrationType + // Full path to the migration file. + // + // Example: /path/to/migrations/001_create_users_table.sql + Fullpath string + // Version is the version of the migration. + Version int64 +} + +// MigrationResult is the result of a single migration operation. +// +// Note, the caller is responsible for checking the Error field for any errors that occurred while +// running the migration. If the Error field is not nil, the migration failed. +type MigrationResult struct { + Source Source + Duration time.Duration + Direction string + // Empty is true if the file was valid, but no statements to apply. These are still versioned + // migrations, but typically have no effect on the database. + // + // For SQL migrations, this means there was a valid .sql file but contained no statements. For + // Go migrations, this means the function was nil. + Empty bool + + // Error is any error that occurred while running the migration. + Error error +} + +// State represents the state of a migration. +type State string + +const ( + // StatePending represents a migration that is on the filesystem, but not in the database. + StatePending State = "pending" + // StateApplied represents a migration that is in BOTH the database and on the filesystem. + StateApplied State = "applied" + + // StateUntracked represents a migration that is in the database, but not on the filesystem. + // StateUntracked State = "untracked" +) + +// MigrationStatus represents the status of a single migration. +type MigrationStatus struct { + // State is the state of the migration. + State State + // AppliedAt is the time the migration was applied. Only set if state is [StateApplied] or + // [StateUntracked]. + AppliedAt time.Time + // Source is the migration source. Only set if the state is [StatePending] or [StateApplied]. + Source Source +} diff --git a/internal/sqladapter/sqladapter.go b/internal/sqladapter/sqladapter.go new file mode 100644 index 000000000..f6c975dc4 --- /dev/null +++ b/internal/sqladapter/sqladapter.go @@ -0,0 +1,49 @@ +// Package sqladapter provides an interface for interacting with a SQL database. +// +// All supported database dialects must implement the Store interface. +package sqladapter + +import ( + "context" + "time" + + "github.com/pressly/goose/v3/internal/sqlextended" +) + +// Store is the interface that wraps the basic methods for a database dialect. +// +// A dialect is a set of SQL statements that are specific to a database. +// +// By defining a store interface, we can support multiple databases with a single codebase. +// +// The underlying implementation does not modify the error. It is the callers responsibility to +// assert for the correct error, such as [sql.ErrNoRows]. +type Store interface { + // CreateVersionTable creates the version table within a transaction. This table is used to + // record applied migrations. + CreateVersionTable(ctx context.Context, db sqlextended.DBTxConn) error + + // InsertOrDelete inserts or deletes a version id from the version table. + InsertOrDelete(ctx context.Context, db sqlextended.DBTxConn, direction bool, version int64) error + + // GetMigration retrieves a single migration by version id. + // + // Returns the raw sql error if the query fails. It is the callers responsibility to assert for + // the correct error, such as [sql.ErrNoRows]. + GetMigration(ctx context.Context, db sqlextended.DBTxConn, version int64) (*GetMigrationResult, error) + + // ListMigrations retrieves all migrations sorted in descending order by id. + // + // If there are no migrations, an empty slice is returned with no error. + ListMigrations(ctx context.Context, db sqlextended.DBTxConn) ([]*ListMigrationsResult, error) +} + +type GetMigrationResult struct { + IsApplied bool + Timestamp time.Time +} + +type ListMigrationsResult struct { + Version int64 + IsApplied bool +} diff --git a/internal/sqladapter/store.go b/internal/sqladapter/store.go new file mode 100644 index 000000000..0ee90ca49 --- /dev/null +++ b/internal/sqladapter/store.go @@ -0,0 +1,111 @@ +package sqladapter + +import ( + "context" + "errors" + "fmt" + + "github.com/pressly/goose/v3/internal/dialect/dialectquery" + "github.com/pressly/goose/v3/internal/sqlextended" +) + +var _ Store = (*store)(nil) + +type store struct { + tablename string + querier dialectquery.Querier +} + +// NewStore returns a new [Store] backed by the given dialect. +// +// The dialect must match one of the supported dialects defined in dialect.go. +func NewStore(dialect string, table string) (Store, error) { + if table == "" { + return nil, errors.New("table must not be empty") + } + if dialect == "" { + return nil, errors.New("dialect must not be empty") + } + var querier dialectquery.Querier + switch dialect { + case "clickhouse": + querier = &dialectquery.Clickhouse{} + case "mssql": + querier = &dialectquery.Sqlserver{} + case "mysql": + querier = &dialectquery.Mysql{} + case "postgres": + querier = &dialectquery.Postgres{} + case "redshift": + querier = &dialectquery.Redshift{} + case "sqlite3": + querier = &dialectquery.Sqlite3{} + case "tidb": + querier = &dialectquery.Tidb{} + case "vertica": + querier = &dialectquery.Vertica{} + default: + return nil, fmt.Errorf("unknown dialect: %q", dialect) + } + return &store{ + tablename: table, + querier: querier, + }, nil +} + +func (s *store) CreateVersionTable(ctx context.Context, db sqlextended.DBTxConn) error { + q := s.querier.CreateTable(s.tablename) + if _, err := db.ExecContext(ctx, q); err != nil { + return fmt.Errorf("failed to create version table %q: %w", s.tablename, err) + } + return nil +} + +func (s *store) InsertOrDelete(ctx context.Context, db sqlextended.DBTxConn, direction bool, version int64) error { + if direction { + q := s.querier.InsertVersion(s.tablename) + if _, err := db.ExecContext(ctx, q, version, true); err != nil { + return fmt.Errorf("failed to insert version %d: %w", version, err) + } + return nil + } + q := s.querier.DeleteVersion(s.tablename) + if _, err := db.ExecContext(ctx, q, version); err != nil { + return fmt.Errorf("failed to delete version %d: %w", version, err) + } + return nil +} + +func (s *store) GetMigration(ctx context.Context, db sqlextended.DBTxConn, version int64) (*GetMigrationResult, error) { + q := s.querier.GetMigrationByVersion(s.tablename) + var result GetMigrationResult + if err := db.QueryRowContext(ctx, q, version).Scan( + &result.Timestamp, + &result.IsApplied, + ); err != nil { + return nil, fmt.Errorf("failed to get migration %d: %w", version, err) + } + return &result, nil +} + +func (s *store) ListMigrations(ctx context.Context, db sqlextended.DBTxConn) ([]*ListMigrationsResult, error) { + q := s.querier.ListMigrations(s.tablename) + rows, err := db.QueryContext(ctx, q) + if err != nil { + return nil, fmt.Errorf("failed to list migrations: %w", err) + } + defer rows.Close() + + var migrations []*ListMigrationsResult + for rows.Next() { + var result ListMigrationsResult + if err := rows.Scan(&result.Version, &result.IsApplied); err != nil { + return nil, fmt.Errorf("failed to scan list migrations result: %w", err) + } + migrations = append(migrations, &result) + } + if err := rows.Err(); err != nil { + return nil, err + } + return migrations, nil +} diff --git a/internal/sqladapter/store_test.go b/internal/sqladapter/store_test.go new file mode 100644 index 000000000..1d0189598 --- /dev/null +++ b/internal/sqladapter/store_test.go @@ -0,0 +1,218 @@ +package sqladapter_test + +import ( + "context" + "database/sql" + "errors" + "path/filepath" + "testing" + + "github.com/jackc/pgx/v5/pgconn" + "github.com/pressly/goose/v3" + "github.com/pressly/goose/v3/internal/check" + "github.com/pressly/goose/v3/internal/sqladapter" + "github.com/pressly/goose/v3/internal/testdb" + "go.uber.org/multierr" + "modernc.org/sqlite" +) + +// The goal of this test is to verify the sqladapter package works as expected. This test is not +// meant to be exhaustive or test every possible database dialect. It is meant to verify the Store +// interface works against a real database. + +func TestStore(t *testing.T) { + t.Parallel() + t.Run("invalid", func(t *testing.T) { + // Test empty table name. + _, err := sqladapter.NewStore("sqlite3", "") + check.HasError(t, err) + // Test unknown dialect. + _, err = sqladapter.NewStore("unknown-dialect", "foo") + check.HasError(t, err) + // Test empty dialect. + _, err = sqladapter.NewStore("", "foo") + check.HasError(t, err) + }) + t.Run("postgres", func(t *testing.T) { + if testing.Short() { + t.Skip("skip long-running test") + } + // Test postgres specific behavior. + db, cleanup, err := testdb.NewPostgres() + check.NoError(t, err) + t.Cleanup(cleanup) + testStore(context.Background(), t, goose.DialectPostgres, db, func(t *testing.T, err error) { + var pgErr *pgconn.PgError + ok := errors.As(err, &pgErr) + check.Bool(t, ok, true) + check.Equal(t, pgErr.Code, "42P07") // duplicate_table + }) + }) + // Test generic behavior. + t.Run("sqlite3", func(t *testing.T) { + dir := t.TempDir() + db, err := sql.Open("sqlite", filepath.Join(dir, "sql_embed.db")) + check.NoError(t, err) + testStore(context.Background(), t, goose.DialectSQLite3, db, func(t *testing.T, err error) { + var sqliteErr *sqlite.Error + ok := errors.As(err, &sqliteErr) + check.Bool(t, ok, true) + check.Equal(t, sqliteErr.Code(), 1) // Generic error (SQLITE_ERROR) + check.Contains(t, sqliteErr.Error(), "table test_goose_db_version already exists") + }) + }) +} + +// testStore tests various store operations. +// +// If alreadyExists is not nil, it will be used to assert the error returned by CreateVersionTable +// when the version table already exists. +func testStore(ctx context.Context, t *testing.T, dialect goose.Dialect, db *sql.DB, alreadyExists func(t *testing.T, err error)) { + const ( + tablename = "test_goose_db_version" + ) + store, err := sqladapter.NewStore(string(dialect), tablename) + check.NoError(t, err) + // Create the version table. + err = runTx(ctx, db, func(tx *sql.Tx) error { + return store.CreateVersionTable(ctx, tx) + }) + check.NoError(t, err) + // Create the version table again. This should fail. + err = runTx(ctx, db, func(tx *sql.Tx) error { + return store.CreateVersionTable(ctx, tx) + }) + check.HasError(t, err) + if alreadyExists != nil { + alreadyExists(t, err) + } + + // List migrations. There should be none. + err = runConn(ctx, db, func(conn *sql.Conn) error { + res, err := store.ListMigrations(ctx, conn) + check.NoError(t, err) + check.Number(t, len(res), 0) + return nil + }) + check.NoError(t, err) + + // Insert 5 migrations in addition to the zero migration. + for i := 0; i < 6; i++ { + err = runConn(ctx, db, func(conn *sql.Conn) error { + return store.InsertOrDelete(ctx, conn, true, int64(i)) + }) + check.NoError(t, err) + } + + // List migrations. There should be 6. + err = runConn(ctx, db, func(conn *sql.Conn) error { + res, err := store.ListMigrations(ctx, conn) + check.NoError(t, err) + check.Number(t, len(res), 6) + // Check versions are in descending order. + for i := 0; i < 6; i++ { + check.Number(t, res[i].Version, 5-i) + } + return nil + }) + check.NoError(t, err) + + // Delete 3 migrations backwards + for i := 5; i >= 3; i-- { + err = runConn(ctx, db, func(conn *sql.Conn) error { + return store.InsertOrDelete(ctx, conn, false, int64(i)) + }) + check.NoError(t, err) + } + + // List migrations. There should be 3. + err = runConn(ctx, db, func(conn *sql.Conn) error { + res, err := store.ListMigrations(ctx, conn) + check.NoError(t, err) + check.Number(t, len(res), 3) + // Check that the remaining versions are in descending order. + for i := 0; i < 3; i++ { + check.Number(t, res[i].Version, 2-i) + } + return nil + }) + check.NoError(t, err) + + // Get remaining migrations one by one. + for i := 0; i < 3; i++ { + err = runConn(ctx, db, func(conn *sql.Conn) error { + res, err := store.GetMigration(ctx, conn, int64(i)) + check.NoError(t, err) + check.Equal(t, res.IsApplied, true) + check.Equal(t, res.Timestamp.IsZero(), false) + return nil + }) + check.NoError(t, err) + } + + // Delete remaining migrations one by one and use all 3 connection types: + + // 1. *sql.Tx + err = runTx(ctx, db, func(tx *sql.Tx) error { + return store.InsertOrDelete(ctx, tx, false, 2) + }) + check.NoError(t, err) + // 2. *sql.Conn + err = runConn(ctx, db, func(conn *sql.Conn) error { + return store.InsertOrDelete(ctx, conn, false, 1) + }) + check.NoError(t, err) + // 3. *sql.DB + err = store.InsertOrDelete(ctx, db, false, 0) + check.NoError(t, err) + + // List migrations. There should be none. + err = runConn(ctx, db, func(conn *sql.Conn) error { + res, err := store.ListMigrations(ctx, conn) + check.NoError(t, err) + check.Number(t, len(res), 0) + return nil + }) + check.NoError(t, err) + + // Try to get a migration that does not exist. + err = runConn(ctx, db, func(conn *sql.Conn) error { + _, err := store.GetMigration(ctx, conn, 0) + check.HasError(t, err) + check.Bool(t, errors.Is(err, sql.ErrNoRows), true) + return nil + }) + check.NoError(t, err) +} + +func runTx(ctx context.Context, db *sql.DB, fn func(*sql.Tx) error) (retErr error) { + tx, err := db.BeginTx(ctx, nil) + if err != nil { + return err + } + defer func() { + if retErr != nil { + retErr = multierr.Append(retErr, tx.Rollback()) + } + }() + if err := fn(tx); err != nil { + return err + } + return tx.Commit() +} + +func runConn(ctx context.Context, db *sql.DB, fn func(*sql.Conn) error) (retErr error) { + conn, err := db.Conn(ctx) + if err != nil { + return err + } + defer func() { + if retErr != nil { + retErr = multierr.Append(retErr, conn.Close()) + } + }() + if err := fn(conn); err != nil { + return err + } + return conn.Close() +} diff --git a/internal/sqlextended/sqlextended.go b/internal/sqlextended/sqlextended.go new file mode 100644 index 000000000..83ca7ae8b --- /dev/null +++ b/internal/sqlextended/sqlextended.go @@ -0,0 +1,23 @@ +package sqlextended + +import ( + "context" + "database/sql" +) + +// DBTxConn is a thin interface for common method that is satisfied by *sql.DB, *sql.Tx and +// *sql.Conn. +// +// There is a long outstanding issue to formalize a std lib interface, but alas... See: +// https://github.com/golang/go/issues/14468 +type DBTxConn interface { + ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) + QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) + QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row +} + +var ( + _ DBTxConn = (*sql.DB)(nil) + _ DBTxConn = (*sql.Tx)(nil) + _ DBTxConn = (*sql.Conn)(nil) +) diff --git a/internal/sqlparser/parse.go b/internal/sqlparser/parse.go new file mode 100644 index 000000000..b42fdde14 --- /dev/null +++ b/internal/sqlparser/parse.go @@ -0,0 +1,59 @@ +package sqlparser + +import ( + "fmt" + "io/fs" + + "go.uber.org/multierr" + "golang.org/x/sync/errgroup" +) + +type ParsedSQL struct { + UseTx bool + Up, Down []string +} + +func ParseAllFromFS(fsys fs.FS, filename string, debug bool) (*ParsedSQL, error) { + parsedSQL := new(ParsedSQL) + // TODO(mf): parse is called twice, once for up and once for down. This is inefficient. It + // should be possible to parse both directions in one pass. Also, UseTx is set once (but + // returned twice), which is unnecessary and potentially error-prone if the two calls to + // parseSQL disagree based on direction. + var g errgroup.Group + g.Go(func() error { + up, useTx, err := parse(fsys, filename, DirectionUp, debug) + if err != nil { + return err + } + parsedSQL.Up = up + parsedSQL.UseTx = useTx + return nil + }) + g.Go(func() error { + down, _, err := parse(fsys, filename, DirectionDown, debug) + if err != nil { + return err + } + parsedSQL.Down = down + return nil + }) + if err := g.Wait(); err != nil { + return nil, err + } + return parsedSQL, nil +} + +func parse(fsys fs.FS, filename string, direction Direction, debug bool) (_ []string, _ bool, retErr error) { + r, err := fsys.Open(filename) + if err != nil { + return nil, false, err + } + defer func() { + retErr = multierr.Append(retErr, r.Close()) + }() + stmts, useTx, err := ParseSQLMigration(r, direction, debug) + if err != nil { + return nil, false, fmt.Errorf("failed to parse %s: %w", filename, err) + } + return stmts, useTx, nil +} diff --git a/internal/sqlparser/parse_test.go b/internal/sqlparser/parse_test.go new file mode 100644 index 000000000..632bbe13b --- /dev/null +++ b/internal/sqlparser/parse_test.go @@ -0,0 +1,82 @@ +package sqlparser_test + +import ( + "errors" + "os" + "testing" + "testing/fstest" + + "github.com/pressly/goose/v3/internal/check" + "github.com/pressly/goose/v3/internal/sqlparser" +) + +func TestParseAllFromFS(t *testing.T) { + t.Parallel() + t.Run("file_not_exist", func(t *testing.T) { + mapFS := fstest.MapFS{} + _, err := sqlparser.ParseAllFromFS(mapFS, "001_foo.sql", false) + check.HasError(t, err) + check.Bool(t, errors.Is(err, os.ErrNotExist), true) + }) + t.Run("empty_file", func(t *testing.T) { + mapFS := fstest.MapFS{ + "001_foo.sql": &fstest.MapFile{}, + } + _, err := sqlparser.ParseAllFromFS(mapFS, "001_foo.sql", false) + check.HasError(t, err) + check.Contains(t, err.Error(), "failed to parse migration") + check.Contains(t, err.Error(), "must start with '-- +goose Up' annotation") + }) + t.Run("all_statements", func(t *testing.T) { + mapFS := fstest.MapFS{ + "001_foo.sql": newFile(` +-- +goose Up +`), + "002_bar.sql": newFile(` +-- +goose Up +-- +goose Down +`), + "003_baz.sql": newFile(` +-- +goose Up +CREATE TABLE foo (id int); +CREATE TABLE bar (id int); + +-- +goose Down +DROP TABLE bar; +`), + "004_qux.sql": newFile(` +-- +goose NO TRANSACTION +-- +goose Up +CREATE TABLE foo (id int); +-- +goose Down +DROP TABLE foo; +`), + } + parsedSQL, err := sqlparser.ParseAllFromFS(mapFS, "001_foo.sql", false) + check.NoError(t, err) + assertParsedSQL(t, parsedSQL, true, 0, 0) + parsedSQL, err = sqlparser.ParseAllFromFS(mapFS, "002_bar.sql", false) + check.NoError(t, err) + assertParsedSQL(t, parsedSQL, true, 0, 0) + parsedSQL, err = sqlparser.ParseAllFromFS(mapFS, "003_baz.sql", false) + check.NoError(t, err) + assertParsedSQL(t, parsedSQL, true, 2, 1) + parsedSQL, err = sqlparser.ParseAllFromFS(mapFS, "004_qux.sql", false) + check.NoError(t, err) + assertParsedSQL(t, parsedSQL, false, 1, 1) + }) +} + +func assertParsedSQL(t *testing.T, got *sqlparser.ParsedSQL, useTx bool, up, down int) { + t.Helper() + check.Bool(t, got != nil, true) + check.Equal(t, len(got.Up), up) + check.Equal(t, len(got.Down), down) + check.Equal(t, got.UseTx, useTx) +} + +func newFile(data string) *fstest.MapFile { + return &fstest.MapFile{ + Data: []byte(data), + } +} diff --git a/internal/sqlparser/parser.go b/internal/sqlparser/parser.go index 5e6c67503..a62846026 100644 --- a/internal/sqlparser/parser.go +++ b/internal/sqlparser/parser.go @@ -25,6 +25,14 @@ func FromBool(b bool) Direction { return DirectionDown } +func (d Direction) String() string { + return string(d) +} + +func (d Direction) ToBool() bool { + return d == DirectionUp +} + type parserState int const ( diff --git a/internal/testdb/vertica.go b/internal/testdb/vertica.go index abe292bec..fc9cc1d9a 100644 --- a/internal/testdb/vertica.go +++ b/internal/testdb/vertica.go @@ -15,7 +15,7 @@ import ( const ( // https://hub.docker.com/r/vertica/vertica-ce VERTICA_IMAGE = "vertica/vertica-ce" - VERTICA_VERSION = "12.0.0-0" + VERTICA_VERSION = "23.3.0-0" VERTICA_DB = "testdb" ) diff --git a/lock/postgres.go b/lock/postgres.go new file mode 100644 index 000000000..3583162e2 --- /dev/null +++ b/lock/postgres.go @@ -0,0 +1,110 @@ +package lock + +import ( + "context" + "database/sql" + "errors" + "fmt" + "time" + + "github.com/sethvargo/go-retry" +) + +// NewPostgresSessionLocker returns a SessionLocker that utilizes PostgreSQL's exclusive +// session-level advisory lock mechanism. +// +// This function creates a SessionLocker that can be used to acquire and release locks for +// synchronization purposes. The lock acquisition is retried until it is successfully acquired or +// until the maximum duration is reached. The default lock duration is set to 60 minutes, and the +// default unlock duration is set to 1 minute. +// +// See [SessionLockerOption] for options that can be used to configure the SessionLocker. +func NewPostgresSessionLocker(opts ...SessionLockerOption) (SessionLocker, error) { + cfg := sessionLockerConfig{ + lockID: DefaultLockID, + lockTimeout: DefaultLockTimeout, + unlockTimeout: DefaultUnlockTimeout, + } + for _, opt := range opts { + if err := opt.apply(&cfg); err != nil { + return nil, err + } + } + return &postgresSessionLocker{ + lockID: cfg.lockID, + retryLock: retry.WithMaxDuration( + cfg.lockTimeout, + retry.NewConstant(2*time.Second), + ), + retryUnlock: retry.WithMaxDuration( + cfg.unlockTimeout, + retry.NewConstant(2*time.Second), + ), + }, nil +} + +type postgresSessionLocker struct { + lockID int64 + retryLock retry.Backoff + retryUnlock retry.Backoff +} + +var _ SessionLocker = (*postgresSessionLocker)(nil) + +func (l *postgresSessionLocker) SessionLock(ctx context.Context, conn *sql.Conn) error { + return retry.Do(ctx, l.retryLock, func(ctx context.Context) error { + row := conn.QueryRowContext(ctx, "SELECT pg_try_advisory_lock($1)", l.lockID) + var locked bool + if err := row.Scan(&locked); err != nil { + return fmt.Errorf("failed to execute pg_try_advisory_lock: %w", err) + } + if locked { + // A session-level advisory lock was acquired. + return nil + } + // A session-level advisory lock could not be acquired. This is likely because another + // process has already acquired the lock. We will continue retrying until the lock is + // acquired or the maximum number of retries is reached. + return retry.RetryableError(errors.New("failed to acquire lock")) + }) +} + +func (l *postgresSessionLocker) SessionUnlock(ctx context.Context, conn *sql.Conn) error { + return retry.Do(ctx, l.retryUnlock, func(ctx context.Context) error { + var unlocked bool + row := conn.QueryRowContext(ctx, "SELECT pg_advisory_unlock($1)", l.lockID) + if err := row.Scan(&unlocked); err != nil { + return fmt.Errorf("failed to execute pg_advisory_unlock: %w", err) + } + if unlocked { + // A session-level advisory lock was released. + return nil + } + /* + TODO(mf): provide users with some documentation on how they can unlock the session + manually. + + This is probably not an issue for 99.99% of users since pg_advisory_unlock_all() will + release all session level advisory locks held by the current session. This function is + implicitly invoked at session end, even if the client disconnects ungracefully. + + Here is output from a session that has a lock held: + + SELECT pid,granted,((classid::bigint<<32)|objid::bigint)AS goose_lock_id FROM pg_locks + WHERE locktype='advisory'; + + | pid | granted | goose_lock_id | + |-----|---------|---------------------| + | 191 | t | 5887940537704921958 | + + A forceful way to unlock the session is to terminate the backend with SIGTERM: + + SELECT pg_terminate_backend(191); + + Subsequent commands on the same connection will fail with: + + Query 1 ERROR: FATAL: terminating connection due to administrator command + */ + return retry.RetryableError(errors.New("failed to unlock session")) + }) +} diff --git a/lock/postgres_test.go b/lock/postgres_test.go new file mode 100644 index 000000000..2622d5cb6 --- /dev/null +++ b/lock/postgres_test.go @@ -0,0 +1,194 @@ +package lock_test + +import ( + "context" + "database/sql" + "errors" + "sync" + "testing" + "time" + + "github.com/pressly/goose/v3/internal/check" + "github.com/pressly/goose/v3/internal/testdb" + "github.com/pressly/goose/v3/lock" +) + +func TestPostgresSessionLocker(t *testing.T) { + t.Parallel() + if testing.Short() { + t.Skip("skip long running test") + } + db, cleanup, err := testdb.NewPostgres() + check.NoError(t, err) + t.Cleanup(cleanup) + + // Do not run tests in parallel, because they are using the same database. + + t.Run("lock_and_unlock", func(t *testing.T) { + const ( + lockID int64 = 123456789 + ) + locker, err := lock.NewPostgresSessionLocker( + lock.WithLockID(lockID), + lock.WithLockTimeout(4*time.Second), + lock.WithUnlockTimeout(4*time.Second), + ) + check.NoError(t, err) + ctx := context.Background() + conn, err := db.Conn(ctx) + check.NoError(t, err) + t.Cleanup(func() { + check.NoError(t, conn.Close()) + }) + err = locker.SessionLock(ctx, conn) + check.NoError(t, err) + pgLocks, err := queryPgLocks(ctx, db) + check.NoError(t, err) + check.Number(t, len(pgLocks), 1) + // Check that the lock was acquired. + check.Bool(t, pgLocks[0].granted, true) + // Check that the custom lock ID is the same as the one used by the locker. + check.Equal(t, pgLocks[0].gooseLockID, lockID) + check.NumberNotZero(t, pgLocks[0].pid) + + // Check that the lock is released. + err = locker.SessionUnlock(ctx, conn) + check.NoError(t, err) + pgLocks, err = queryPgLocks(ctx, db) + check.NoError(t, err) + check.Number(t, len(pgLocks), 0) + }) + t.Run("lock_close_conn_unlock", func(t *testing.T) { + locker, err := lock.NewPostgresSessionLocker( + lock.WithLockTimeout(4*time.Second), + lock.WithUnlockTimeout(4*time.Second), + ) + check.NoError(t, err) + ctx := context.Background() + conn, err := db.Conn(ctx) + check.NoError(t, err) + + err = locker.SessionLock(ctx, conn) + check.NoError(t, err) + pgLocks, err := queryPgLocks(ctx, db) + check.NoError(t, err) + check.Number(t, len(pgLocks), 1) + check.Bool(t, pgLocks[0].granted, true) + check.Equal(t, pgLocks[0].gooseLockID, lock.DefaultLockID) + // Simulate a connection close. + err = conn.Close() + check.NoError(t, err) + // Check an error is returned when unlocking, because the connection is already closed. + err = locker.SessionUnlock(ctx, conn) + check.HasError(t, err) + check.Bool(t, errors.Is(err, sql.ErrConnDone), true) + }) + t.Run("multiple_connections", func(t *testing.T) { + const ( + workers = 5 + ) + ch := make(chan error) + var wg sync.WaitGroup + for i := 0; i < workers; i++ { + wg.Add(1) + + go func() { + defer wg.Done() + ctx := context.Background() + conn, err := db.Conn(ctx) + check.NoError(t, err) + t.Cleanup(func() { + check.NoError(t, conn.Close()) + }) + // Exactly one connection should acquire the lock. While the other connections + // should fail to acquire the lock and timeout. + locker, err := lock.NewPostgresSessionLocker( + lock.WithLockTimeout(4*time.Second), + lock.WithUnlockTimeout(4*time.Second), + ) + check.NoError(t, err) + ch <- locker.SessionLock(ctx, conn) + }() + } + go func() { + wg.Wait() + close(ch) + }() + var errors []error + for err := range ch { + if err != nil { + errors = append(errors, err) + } + } + check.Equal(t, len(errors), workers-1) // One worker succeeds, the rest fail. + for _, err := range errors { + check.HasError(t, err) + check.Equal(t, err.Error(), "failed to acquire lock") + } + pgLocks, err := queryPgLocks(context.Background(), db) + check.NoError(t, err) + check.Number(t, len(pgLocks), 1) + check.Bool(t, pgLocks[0].granted, true) + check.Equal(t, pgLocks[0].gooseLockID, lock.DefaultLockID) + }) + t.Run("unlock_with_different_connection", func(t *testing.T) { + ctx := context.Background() + const ( + lockID int64 = 999 + ) + locker, err := lock.NewPostgresSessionLocker( + lock.WithLockID(lockID), + lock.WithLockTimeout(4*time.Second), + lock.WithUnlockTimeout(4*time.Second), + ) + check.NoError(t, err) + + conn1, err := db.Conn(ctx) + check.NoError(t, err) + t.Cleanup(func() { + check.NoError(t, conn1.Close()) + }) + err = locker.SessionLock(ctx, conn1) + check.NoError(t, err) + pgLocks, err := queryPgLocks(ctx, db) + check.NoError(t, err) + check.Number(t, len(pgLocks), 1) + check.Bool(t, pgLocks[0].granted, true) + check.Equal(t, pgLocks[0].gooseLockID, lockID) + // Unlock with a different connection. + conn2, err := db.Conn(ctx) + check.NoError(t, err) + t.Cleanup(func() { + check.NoError(t, conn2.Close()) + }) + // Check an error is returned when unlocking with a different connection. + err = locker.SessionUnlock(ctx, conn2) + check.HasError(t, err) + }) +} + +type pgLock struct { + pid int + granted bool + gooseLockID int64 +} + +func queryPgLocks(ctx context.Context, db *sql.DB) ([]pgLock, error) { + q := `SELECT pid,granted,((classid::bigint<<32)|objid::bigint)AS goose_lock_id FROM pg_locks WHERE locktype='advisory'` + rows, err := db.QueryContext(ctx, q) + if err != nil { + return nil, err + } + var pgLocks []pgLock + for rows.Next() { + var p pgLock + if err = rows.Scan(&p.pid, &p.granted, &p.gooseLockID); err != nil { + return nil, err + } + pgLocks = append(pgLocks, p) + } + if err := rows.Err(); err != nil { + return nil, err + } + return pgLocks, nil +} diff --git a/lock/session_locker.go b/lock/session_locker.go new file mode 100644 index 000000000..b74187829 --- /dev/null +++ b/lock/session_locker.go @@ -0,0 +1,23 @@ +// Package lock defines the Locker interface and implements the locking logic. +package lock + +import ( + "context" + "database/sql" + "errors" +) + +var ( + // ErrLockNotImplemented is returned when the database does not support locking. + ErrLockNotImplemented = errors.New("lock not implemented") + // ErrUnlockNotImplemented is returned when the database does not support unlocking. + ErrUnlockNotImplemented = errors.New("unlock not implemented") +) + +// SessionLocker is the interface to lock and unlock the database for the duration of a session. The +// session is defined as the duration of a single connection and both methods must be called on the +// same connection. +type SessionLocker interface { + SessionLock(ctx context.Context, conn *sql.Conn) error + SessionUnlock(ctx context.Context, conn *sql.Conn) error +} diff --git a/lock/session_locker_options.go b/lock/session_locker_options.go new file mode 100644 index 000000000..c3e42151c --- /dev/null +++ b/lock/session_locker_options.go @@ -0,0 +1,63 @@ +package lock + +import ( + "time" +) + +const ( + // DefaultLockID is the id used to lock the database for migrations. It is a crc64 hash of the + // string "goose". This is used to ensure that the lock is unique to goose. + // + // crc64.Checksum([]byte("goose"), crc64.MakeTable(crc64.ECMA)) + DefaultLockID int64 = 5887940537704921958 + + // Default values for the lock (time to wait for the lock to be acquired) and unlock (time to + // wait for the lock to be released) wait durations. + DefaultLockTimeout time.Duration = 60 * time.Minute + DefaultUnlockTimeout time.Duration = 1 * time.Minute +) + +// SessionLockerOption is used to configure a SessionLocker. +type SessionLockerOption interface { + apply(*sessionLockerConfig) error +} + +// WithLockID sets the lock ID to use when locking the database. +// +// If WithLockID is not called, the DefaultLockID is used. +func WithLockID(lockID int64) SessionLockerOption { + return sessionLockerConfigFunc(func(c *sessionLockerConfig) error { + c.lockID = lockID + return nil + }) +} + +// WithLockTimeout sets the max duration to wait for the lock to be acquired. +func WithLockTimeout(duration time.Duration) SessionLockerOption { + return sessionLockerConfigFunc(func(c *sessionLockerConfig) error { + c.lockTimeout = duration + return nil + }) +} + +// WithUnlockTimeout sets the max duration to wait for the lock to be released. +func WithUnlockTimeout(duration time.Duration) SessionLockerOption { + return sessionLockerConfigFunc(func(c *sessionLockerConfig) error { + c.unlockTimeout = duration + return nil + }) +} + +type sessionLockerConfig struct { + lockID int64 + lockTimeout time.Duration + unlockTimeout time.Duration +} + +var _ SessionLockerOption = (sessionLockerConfigFunc)(nil) + +type sessionLockerConfigFunc func(*sessionLockerConfig) error + +func (f sessionLockerConfigFunc) apply(cfg *sessionLockerConfig) error { + return f(cfg) +} diff --git a/migration.go b/migration.go index dcf0c6118..619e934d0 100644 --- a/migration.go +++ b/migration.go @@ -218,27 +218,27 @@ func insertOrDeleteVersionNoTx(ctx context.Context, db *sql.DB, version int64, d return store.DeleteVersionNoTx(ctx, db, TableName(), version) } -// NumericComponent looks for migration scripts with names in the form: -// XXX_descriptivename.ext where XXX specifies the version number -// and ext specifies the type of migration -func NumericComponent(name string) (int64, error) { - base := filepath.Base(name) - +// NumericComponent parses the version from the migration file name. +// +// XXX_descriptivename.ext where XXX specifies the version number and ext specifies the type of +// migration, either .sql or .go. +func NumericComponent(filename string) (int64, error) { + base := filepath.Base(filename) if ext := filepath.Ext(base); ext != ".go" && ext != ".sql" { - return 0, errors.New("not a recognized migration file type") + return 0, errors.New("migration file does not have .sql or .go file extension") } - idx := strings.Index(base, "_") if idx < 0 { return 0, errors.New("no filename separator '_' found") } - - n, e := strconv.ParseInt(base[:idx], 10, 64) - if e == nil && n <= 0 { - return 0, errors.New("migration IDs must be greater than zero") + n, err := strconv.ParseInt(base[:idx], 10, 64) + if err != nil { + return 0, err } - - return n, e + if n < 1 { + return 0, errors.New("migration version must be greater than zero") + } + return n, nil } func truncateDuration(d time.Duration) time.Duration { diff --git a/testdata/migrations/00002_posts_table.sql b/testdata/migrations/00002_posts_table.sql index 25648ed42..be70a2348 100644 --- a/testdata/migrations/00002_posts_table.sql +++ b/testdata/migrations/00002_posts_table.sql @@ -1,4 +1,5 @@ -- +goose Up +-- +goose StatementBegin CREATE TABLE posts ( id INTEGER PRIMARY KEY, title TEXT NOT NULL, @@ -7,6 +8,7 @@ CREATE TABLE posts ( created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, FOREIGN KEY (author_id) REFERENCES users(id) ); +-- +goose StatementEnd -- +goose Down DROP TABLE posts;