From 473a70da34ca4ca6460362ec84cf57dc6d2087d3 Mon Sep 17 00:00:00 2001 From: Michael Fridman Date: Thu, 5 Oct 2023 08:59:48 -0400 Subject: [PATCH 01/10] build: bump deps and GH actions (#602) --- .github/workflows/release.yaml | 2 +- go.mod | 42 ++++++++--------- go.sum | 86 +++++++++++++++++----------------- internal/testdb/vertica.go | 2 +- 4 files changed, 66 insertions(+), 66 deletions(-) diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml index de76f3840..f99d4a6a3 100644 --- a/.github/workflows/release.yaml +++ b/.github/workflows/release.yaml @@ -20,7 +20,7 @@ jobs: continue-on-error: true run: ./scripts/release-notes.sh ${{github.ref_name}} > ${{runner.temp}}/release_notes.txt - name: Run GoReleaser - uses: goreleaser/goreleaser-action@v4 + uses: goreleaser/goreleaser-action@v5 with: distribution: goreleaser version: latest diff --git a/go.mod b/go.mod index b8aaaa724..931d9a985 100644 --- a/go.mod +++ b/go.mod @@ -3,30 +3,30 @@ module github.com/pressly/goose/v3 go 1.19 require ( - github.com/ClickHouse/clickhouse-go/v2 v2.13.0 + github.com/ClickHouse/clickhouse-go/v2 v2.14.2 github.com/go-sql-driver/mysql v1.7.1 github.com/jackc/pgx/v5 v5.4.3 github.com/microsoft/go-mssqldb v1.6.0 github.com/ory/dockertest/v3 v3.10.0 github.com/vertica/vertica-sql-go v1.3.3 github.com/ziutek/mymysql v1.5.4 - modernc.org/sqlite v1.25.0 + modernc.org/sqlite v1.26.0 ) require ( github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 // indirect - github.com/ClickHouse/ch-go v0.57.0 // indirect + github.com/ClickHouse/ch-go v0.58.2 // indirect github.com/Microsoft/go-winio v0.6.1 // indirect github.com/Nvveen/Gotty v0.0.0-20120604004816-cd527374f1e5 // indirect github.com/andybalholm/brotli v1.0.5 // indirect github.com/cenkalti/backoff/v4 v4.2.1 // indirect - github.com/containerd/continuity v0.4.1 // indirect - github.com/docker/cli v24.0.2+incompatible // indirect - github.com/docker/docker v24.0.2+incompatible // indirect + github.com/containerd/continuity v0.4.2 // indirect + github.com/docker/cli v24.0.6+incompatible // indirect + github.com/docker/docker v24.0.6+incompatible // indirect github.com/docker/go-connections v0.4.0 // indirect github.com/docker/go-units v0.5.0 // indirect github.com/dustin/go-humanize v1.0.1 // indirect - github.com/elastic/go-sysinfo v1.11.0 // indirect + github.com/elastic/go-sysinfo v1.11.1 // indirect github.com/elastic/go-windows v1.0.1 // indirect github.com/go-faster/city v1.0.1 // indirect github.com/go-faster/errors v0.6.1 // indirect @@ -34,24 +34,24 @@ require ( github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect github.com/golang-sql/sqlexp v0.1.0 // indirect github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 // indirect - github.com/google/uuid v1.3.0 // indirect + github.com/google/uuid v1.3.1 // indirect github.com/imdario/mergo v0.3.16 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect github.com/joeshaw/multierror v0.0.0-20140124173710-69b34d4ec901 // indirect github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 // indirect - github.com/klauspost/compress v1.16.7 // indirect + github.com/klauspost/compress v1.17.0 // indirect github.com/lib/pq v1.10.9 // indirect github.com/mattn/go-isatty v0.0.19 // indirect github.com/mitchellh/mapstructure v1.5.0 // indirect github.com/moby/term v0.5.0 // indirect github.com/opencontainers/go-digest v1.0.0 // indirect - 
github.com/opencontainers/image-spec v1.1.0-rc4 // indirect - github.com/opencontainers/runc v1.1.7 // indirect + github.com/opencontainers/image-spec v1.1.0-rc5 // indirect + github.com/opencontainers/runc v1.1.9 // indirect github.com/paulmach/orb v0.10.0 // indirect github.com/pierrec/lz4/v4 v4.1.18 // indirect github.com/pkg/errors v0.9.1 // indirect - github.com/prometheus/procfs v0.11.0 // indirect + github.com/prometheus/procfs v0.12.0 // indirect github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect github.com/segmentio/asm v1.2.0 // indirect github.com/shopspring/decimal v1.3.1 // indirect @@ -59,24 +59,24 @@ require ( github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb // indirect github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 // indirect github.com/xeipuuv/gojsonschema v1.2.0 // indirect - go.opentelemetry.io/otel v1.16.0 // indirect - go.opentelemetry.io/otel/trace v1.16.0 // indirect - golang.org/x/crypto v0.12.0 // indirect + go.opentelemetry.io/otel v1.19.0 // indirect + go.opentelemetry.io/otel/trace v1.19.0 // indirect + golang.org/x/crypto v0.13.0 // indirect golang.org/x/mod v0.12.0 // indirect - golang.org/x/sys v0.11.0 // indirect - golang.org/x/text v0.12.0 // indirect - golang.org/x/tools v0.10.0 // indirect + golang.org/x/sys v0.13.0 // indirect + golang.org/x/text v0.13.0 // indirect + golang.org/x/tools v0.13.0 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect howett.net/plist v1.0.0 // indirect lukechampine.com/uint128 v1.3.0 // indirect modernc.org/cc/v3 v3.41.0 // indirect - modernc.org/ccgo/v3 v3.16.14 // indirect + modernc.org/ccgo/v3 v3.16.15 // indirect modernc.org/libc v1.24.1 // indirect modernc.org/mathutil v1.6.0 // indirect - modernc.org/memory v1.6.0 // indirect + modernc.org/memory v1.7.2 // indirect modernc.org/opt v0.1.3 // indirect - modernc.org/strutil v1.1.3 // indirect + modernc.org/strutil v1.2.0 // indirect modernc.org/token v1.1.0 // indirect ) diff --git a/go.sum b/go.sum index 659c7bb7a..0c34e9ce7 100644 --- a/go.sum +++ b/go.sum @@ -6,10 +6,10 @@ github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/internal v0.8.0 h1:T028g github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 h1:L/gRVlceqvL25UVaW/CKtUDjefjrs0SPonmDGUVOYP0= github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E= github.com/AzureAD/microsoft-authentication-library-for-go v1.1.0 h1:HCc0+LpPfpCKs6LGGLAhwBARt9632unrVcI6i8s/8os= -github.com/ClickHouse/ch-go v0.57.0 h1:X/QmUmFhpUvLgPSQb7fWOSi1wvqGn6tJ7w2a59c4xsg= -github.com/ClickHouse/ch-go v0.57.0/go.mod h1:DR3iBn7OrrDj+KeUp1LbdxLEUDbW+5Qwdl/qkc+PQ+Y= -github.com/ClickHouse/clickhouse-go/v2 v2.13.0 h1:oP1OlTQIbQKKLnqLzyDhiyNFvN3pbOtM+e/3qdexG9k= -github.com/ClickHouse/clickhouse-go/v2 v2.13.0/go.mod h1:xyL0De2K54/n+HGsdtPuyYJq76wefafaHfGUXTDEq/0= +github.com/ClickHouse/ch-go v0.58.2 h1:jSm2szHbT9MCAB1rJ3WuCJqmGLi5UTjlNu+f530UTS0= +github.com/ClickHouse/ch-go v0.58.2/go.mod h1:Ap/0bEmiLa14gYjCiRkYGbXvbe8vwdrfTYWhsuQ99aw= +github.com/ClickHouse/clickhouse-go/v2 v2.14.2 h1:iYGP2bgPYJ33y6rCfZTQAPrnHt8wmsM3Ut4cLoYhWY0= +github.com/ClickHouse/clickhouse-go/v2 v2.14.2/go.mod h1:ZLn63wODwGxVdnGB0EIYmFL5tjtlLcLBuwQUH6B2sYk= github.com/Microsoft/go-winio v0.6.1 h1:9/kr64B9VUZrLm5YYwbGtUJnMgqWVOdUAXu6Migciow= github.com/Microsoft/go-winio v0.6.1/go.mod h1:LRdKpFKfdobln8UmuiYcKPot9D2v6svN5+sAH+4kjUM= github.com/Nvveen/Gotty v0.0.0-20120604004816-cd527374f1e5 
h1:TngWCqHvy9oXAN6lEVMRuU21PR1EtLVZJmdB18Gu3Rw= @@ -18,17 +18,17 @@ github.com/andybalholm/brotli v1.0.5 h1:8uQZIdzKmjc/iuPu7O2ioW48L81FgatrcpfFmiq/ github.com/andybalholm/brotli v1.0.5/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig= github.com/cenkalti/backoff/v4 v4.2.1 h1:y4OZtCnogmCPw98Zjyt5a6+QwPLGkiQsYW5oUqylYbM= github.com/cenkalti/backoff/v4 v4.2.1/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= -github.com/containerd/continuity v0.4.1 h1:wQnVrjIyQ8vhU2sgOiL5T07jo+ouqc2bnKsv5/EqGhU= -github.com/containerd/continuity v0.4.1/go.mod h1:F6PTNCKepoxEaXLQp3wDAjygEnImnZ/7o4JzpodfroQ= +github.com/containerd/continuity v0.4.2 h1:v3y/4Yz5jwnvqPKJJ+7Wf93fyWoCB3F5EclWG023MDM= +github.com/containerd/continuity v0.4.2/go.mod h1:F6PTNCKepoxEaXLQp3wDAjygEnImnZ/7o4JzpodfroQ= github.com/creack/pty v1.1.18 h1:n56/Zwd5o6whRC5PMGretI4IdRLlmBXYNjScPaBgsbY= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/docker/cli v24.0.2+incompatible h1:QdqR7znue1mtkXIJ+ruQMGQhpw2JzMJLRXp6zpzF6tM= -github.com/docker/cli v24.0.2+incompatible/go.mod h1:JLrzqnKDaYBop7H2jaqPtU4hHvMKP+vjCwu2uszcLI8= +github.com/docker/cli v24.0.6+incompatible h1:fF+XCQCgJjjQNIMjzaSmiKJSCcfcXb3TWTcc7GAneOY= +github.com/docker/cli v24.0.6+incompatible/go.mod h1:JLrzqnKDaYBop7H2jaqPtU4hHvMKP+vjCwu2uszcLI8= github.com/docker/distribution v2.8.2+incompatible h1:T3de5rq0dB1j30rp0sA2rER+m322EBzniBPB6ZIzuh8= -github.com/docker/docker v24.0.2+incompatible h1:eATx+oLz9WdNVkQrr0qjQ8HvRJ4bOOxfzEo8R+dA3cg= -github.com/docker/docker v24.0.2+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk= +github.com/docker/docker v24.0.6+incompatible h1:hceabKCtUgDqPu+qm0NgsaXf28Ljf4/pWFL7xjWWDgE= +github.com/docker/docker v24.0.6+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk= github.com/docker/go-connections v0.4.0 h1:El9xVISelRB7BuFusrZozjnkIM5YnzCViNKohAFqRJQ= github.com/docker/go-connections v0.4.0/go.mod h1:Gbd7IOopHjR8Iph03tsViu4nIes5XhDvyHbTtUxmeec= github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4= @@ -36,8 +36,8 @@ github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDD github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= github.com/elastic/go-sysinfo v1.8.1/go.mod h1:JfllUnzoQV/JRYymbH3dO1yggI3mV2oTKSXsDHM+uIM= -github.com/elastic/go-sysinfo v1.11.0 h1:QW+6BF1oxBoAprH3w2yephF7xLkrrSXj7gl2xC2BM4w= -github.com/elastic/go-sysinfo v1.11.0/go.mod h1:6KQb31j0QeWBDF88jIdWSxE8cwoOB9tO4Y4osN7Q70E= +github.com/elastic/go-sysinfo v1.11.1 h1:g9mwl05njS4r69TisC+vwHWTSKywZFYYUu3so3T/Lao= +github.com/elastic/go-sysinfo v1.11.1/go.mod h1:6KQb31j0QeWBDF88jIdWSxE8cwoOB9tO4Y4osN7Q70E= github.com/elastic/go-windows v1.0.0/go.mod h1:TsU0Nrp7/y3+VwE82FoZF8gC/XFg/Elz6CcloAxnPgU= github.com/elastic/go-windows v1.0.1 h1:AlYZOldA+UJ0/2nBuqWdo90GFCgG9xuyw9SYzGUtJm0= github.com/elastic/go-windows v1.0.1/go.mod h1:FoVvqWSun28vaDQPbj2Elfc0JahhPB7WQEGa3c814Ss= @@ -62,8 +62,8 @@ github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26 h1:Xim43kblpZXfIBQsbuBVKCudVG457BR2GZFIz3uw3hQ= github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 
h1:El6M4kTTCOh6aBiKaUGG7oYTSPP8MxqL4YI3kZKwcP4= github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510/go.mod h1:pupxD2MaaD3pAXIBCelhxNneeOaAeabZDe5s4K6zSpQ= -github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= -github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/google/uuid v1.3.1 h1:KjJaJ9iWZ3jOFZIf1Lqf4laDRCasjl0BCmnEGxkdLb4= +github.com/google/uuid v1.3.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/imdario/mergo v0.3.16 h1:wwQJbIsHYGMUyLSPrEq1CT16AhnhNJQ51+4fdHUnCl4= github.com/imdario/mergo v0.3.16/go.mod h1:WBLT9ZmE3lPoWsEzCh9LPo3TiwVN+ZKEjmz+hD27ysY= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= @@ -80,8 +80,8 @@ github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51/go.mod h1:C github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/klauspost/compress v1.13.6/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk= -github.com/klauspost/compress v1.16.7 h1:2mk3MPGNzKyxErAw8YaohYh69+pa4sIQSC0fPGCFR9I= -github.com/klauspost/compress v1.16.7/go.mod h1:ntbaceVETuRiXiv4DpjP66DpAtAGkEQskQzEyD//IeE= +github.com/klauspost/compress v1.17.0 h1:Rnbp4K9EjcDuVuHtd0dgA4qNuv9yKDYKK1ulpJwgrqM= +github.com/klauspost/compress v1.17.0/go.mod h1:ntbaceVETuRiXiv4DpjP66DpAtAGkEQskQzEyD//IeE= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= @@ -102,10 +102,10 @@ github.com/moby/term v0.5.0/go.mod h1:8FzsFHVUBGZdbDsJw/ot+X+d5HLUbvklYLJ9uGfcI3 github.com/montanaflynn/stats v0.0.0-20171201202039-1bf9dbcd8cbe/go.mod h1:wL8QJuTMNUDYhXwkmfOly8iTdp5TEcJFWZD2D7SIkUc= github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U= github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM= -github.com/opencontainers/image-spec v1.1.0-rc4 h1:oOxKUJWnFC4YGHCCMNql1x4YaDfYBTS5Y4x/Cgeo1E0= -github.com/opencontainers/image-spec v1.1.0-rc4/go.mod h1:X4pATf0uXsnn3g5aiGIsVnJBR4mxhKzfwmvK/B2NTm8= -github.com/opencontainers/runc v1.1.7 h1:y2EZDS8sNng4Ksf0GUYNhKbTShZJPJg1FiXJNH/uoCk= -github.com/opencontainers/runc v1.1.7/go.mod h1:CbUumNnWCuTGFukNXahoo/RFBZvDAgRh/smNYNOhA50= +github.com/opencontainers/image-spec v1.1.0-rc5 h1:Ygwkfw9bpDvs+c9E34SdgGOj41dX/cbdlwvlWt0pnFI= +github.com/opencontainers/image-spec v1.1.0-rc5/go.mod h1:X4pATf0uXsnn3g5aiGIsVnJBR4mxhKzfwmvK/B2NTm8= +github.com/opencontainers/runc v1.1.9 h1:XR0VIHTGce5eWPkaPesqTBrhW2yAcaraWfsEalNwQLM= +github.com/opencontainers/runc v1.1.9/go.mod h1:CbUumNnWCuTGFukNXahoo/RFBZvDAgRh/smNYNOhA50= github.com/ory/dockertest/v3 v3.10.0 h1:4K3z2VMe8Woe++invjaTB7VRyQXQy5UY+loujO4aNE4= github.com/ory/dockertest/v3 v3.10.0/go.mod h1:nr57ZbRWMqfsdGdFNLHz5jjNdDb7VVFnzAeW1n5N1Lg= github.com/paulmach/orb v0.10.0 h1:guVYVqzxHE/CQ1KpfGO077TR0ATHSNjp4s6XGLn3W9s= @@ -120,8 +120,8 @@ github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINE github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/prometheus/procfs v0.0.0-20190425082905-87a4384529e0/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA= 
-github.com/prometheus/procfs v0.11.0 h1:5EAgkfkMl659uZPbe9AS2N68a7Cc1TJbPEuGzFuRbyk= -github.com/prometheus/procfs v0.11.0/go.mod h1:nwNm2aOCAYw8uTR/9bWRREkZFxAUcWzPHWJq+XBB/FM= +github.com/prometheus/procfs v0.12.0 h1:jluTpSng7V9hY0O2R9DzzJHYb2xULk9VTR1V1R/k6Bo= +github.com/prometheus/procfs v0.12.0/go.mod h1:pcuDEFsWDnvcgNzo4EEweacyhjeA9Zk3cnaOZAZEfOo= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ= @@ -155,16 +155,16 @@ github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9dec github.com/ziutek/mymysql v1.5.4 h1:GB0qdRGsTwQSBVYuVShFBKaXSnSnYYC2d9knnE1LHFs= github.com/ziutek/mymysql v1.5.4/go.mod h1:LMSpPZ6DbqWFxNCHW77HeMg9I646SAhApZ/wKdgO/C0= go.mongodb.org/mongo-driver v1.11.4/go.mod h1:PTSz5yu21bkT/wXpkS7WR5f0ddqw5quethTUn9WM+2g= -go.opentelemetry.io/otel v1.16.0 h1:Z7GVAX/UkAXPKsy94IU+i6thsQS4nb7LviLpnaNeW8s= -go.opentelemetry.io/otel v1.16.0/go.mod h1:vl0h9NUa1D5s1nv3A5vZOYWn8av4K8Ml6JDeHrT/bx4= -go.opentelemetry.io/otel/trace v1.16.0 h1:8JRpaObFoW0pxuVPapkgH8UhHQj+bJW8jJsCZEu5MQs= -go.opentelemetry.io/otel/trace v1.16.0/go.mod h1:Yt9vYq1SdNz3xdjZZK7wcXv1qv2pwLkqr2QVwea0ef0= +go.opentelemetry.io/otel v1.19.0 h1:MuS/TNf4/j4IXsZuJegVzI1cwut7Qc00344rgH7p8bs= +go.opentelemetry.io/otel v1.19.0/go.mod h1:i0QyjOq3UPoTzff0PJB2N66fb4S0+rSbSB15/oyH9fY= +go.opentelemetry.io/otel/trace v1.19.0 h1:DFVQmlVbfVeOuBRrwdtaehRrWiL1JoVs9CPIQ1Dzxpg= +go.opentelemetry.io/otel/trace v1.19.0/go.mod h1:mfaSyvGyEJEI0nyV2I4qhNQnbBOUUmYZpYojqMnX2vo= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= -golang.org/x/crypto v0.12.0 h1:tFM/ta59kqch6LlvYnPa0yx5a83cL2nHflFhYKvv9Yk= -golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw= +golang.org/x/crypto v0.13.0 h1:mvySKfSWJ+UKUii46M40LOvyWfN0s2U+46/jDd0e6Ck= +golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc= golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.12.0 h1:rmsUpXtvNzj340zd98LZ4KntptpfRHwpFOHG188oHXc= @@ -174,7 +174,7 @@ golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLL golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= -golang.org/x/net v0.14.0 h1:BONx9s002vGdD9umnlX1Po8vOZmrgH34qlHcD1MfK14= +golang.org/x/net v0.15.0 h1:ugBLEUaxABaB5AJqW9enI0ACdci2RUd4eP51NTBvuJ8= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync 
v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -192,21 +192,21 @@ golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.11.0 h1:eG7RXZHdqOJ1i+0lgLgCpSXAp6M3LYlAo6osgSi0xOM= -golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE= +golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= -golang.org/x/text v0.12.0 h1:k+n5B8goJNdU7hSvEtMUz3d1Q6D/XW4COJSJR6fN0mc= -golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= +golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k= +golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= -golang.org/x/tools v0.10.0 h1:tvDr/iQoUqNdohiYm0LmmKcBk+q86lb9EprIUFhHHGg= -golang.org/x/tools v0.10.0/go.mod h1:UJwyiVBsOA2uwvK/e5OY3GTpDUJriEd+/YlqAwLPmyM= +golang.org/x/tools v0.13.0 h1:Iey4qkscZuv0VvIt8E0neZjtPVQFSc870HQ448QgEmQ= +golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= @@ -231,22 +231,22 @@ lukechampine.com/uint128 v1.3.0 h1:cDdUVfRwDUDovz610ABgFD17nXD4/uDgVHl2sC3+sbo= lukechampine.com/uint128 v1.3.0/go.mod h1:c4eWIwlEGaxC/+H1VguhU4PHXNWDCDMUlWdIWl2j1gk= modernc.org/cc/v3 v3.41.0 h1:QoR1Sn3YWlmA1T4vLaKZfawdVtSiGx8H+cEojbC7v1Q= modernc.org/cc/v3 v3.41.0/go.mod h1:Ni4zjJYJ04CDOhG7dn640WGfwBzfE0ecX8TyMB0Fv0Y= -modernc.org/ccgo/v3 v3.16.14 h1:af6KNtFgsVmnDYrWk3PQCS9XT6BXe7o3ZFJKkIKvXNQ= -modernc.org/ccgo/v3 v3.16.14/go.mod h1:mPDSujUIaTNWQSG4eqKw+atqLOEbma6Ncsa94WbC9zo= +modernc.org/ccgo/v3 v3.16.15 h1:KbDR3ZAVU+wiLyMESPtbtE/Add4elztFyfsWoNTgxS0= +modernc.org/ccgo/v3 v3.16.15/go.mod h1:yT7B+/E2m43tmMOT51GMoM98/MtHIcQQSleGnddkUNI= modernc.org/ccorpus v1.11.6 h1:J16RXiiqiCgua6+ZvQot4yUuUy8zxgqbqEEUuGPlISk= modernc.org/httpfs v1.0.6 h1:AAgIpFZRXuYnkjftxTAZwMIiwEqAfk8aVB2/oA6nAeM= modernc.org/libc v1.24.1 h1:uvJSeCKL/AgzBo2yYIPPTy82v21KgGnizcGYfBHaNuM= modernc.org/libc v1.24.1/go.mod 
h1:FmfO1RLrU3MHJfyi9eYYmZBfi/R+tqZ6+hQ3yQQUkak= modernc.org/mathutil v1.6.0 h1:fRe9+AmYlaej+64JsEEhoWuAYBkOtQiMEU7n/XgfYi4= modernc.org/mathutil v1.6.0/go.mod h1:Ui5Q9q1TR2gFm0AQRqQUaBWFLAhQpCwNcuhBOSedWPo= -modernc.org/memory v1.6.0 h1:i6mzavxrE9a30whzMfwf7XWVODx2r5OYXvU46cirX7o= -modernc.org/memory v1.6.0/go.mod h1:PkUhL0Mugw21sHPeskwZW4D6VscE/GQJOnIpCnW6pSU= +modernc.org/memory v1.7.2 h1:Klh90S215mmH8c9gO98QxQFsY+W451E8AnzjoE2ee1E= +modernc.org/memory v1.7.2/go.mod h1:NO4NVCQy0N7ln+T9ngWqOQfi7ley4vpwvARR+Hjw95E= modernc.org/opt v0.1.3 h1:3XOZf2yznlhC+ibLltsDGzABUGVx8J6pnFMS3E4dcq4= modernc.org/opt v0.1.3/go.mod h1:WdSiB5evDcignE70guQKxYUl14mgWtbClRi5wmkkTX0= -modernc.org/sqlite v1.25.0 h1:AFweiwPNd/b3BoKnBOfFm+Y260guGMF+0UFk0savqeA= -modernc.org/sqlite v1.25.0/go.mod h1:FL3pVXie73rg3Rii6V/u5BoHlSoyeZeIgKZEgHARyCU= -modernc.org/strutil v1.1.3 h1:fNMm+oJklMGYfU9Ylcywl0CO5O6nTfaowNsh2wpPjzY= -modernc.org/strutil v1.1.3/go.mod h1:MEHNA7PdEnEwLvspRMtWTNnp2nnyvMfkimT1NKNAGbw= +modernc.org/sqlite v1.26.0 h1:SocQdLRSYlA8W99V8YH0NES75thx19d9sB/aFc4R8Lw= +modernc.org/sqlite v1.26.0/go.mod h1:FL3pVXie73rg3Rii6V/u5BoHlSoyeZeIgKZEgHARyCU= +modernc.org/strutil v1.2.0 h1:agBi9dp1I+eOnxXeiZawM8F4LawKv4NzGWSaLfyeNZA= +modernc.org/strutil v1.2.0/go.mod h1:/mdcBmfOibveCTBxUl5B5l6W+TTH1FXPLHZE6bTosX0= modernc.org/tcl v1.15.2 h1:C4ybAYCGJw968e+Me18oW55kD/FexcHbqH2xak1ROSY= modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y= modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM= diff --git a/internal/testdb/vertica.go b/internal/testdb/vertica.go index abe292bec..fc9cc1d9a 100644 --- a/internal/testdb/vertica.go +++ b/internal/testdb/vertica.go @@ -15,7 +15,7 @@ import ( const ( // https://hub.docker.com/r/vertica/vertica-ce VERTICA_IMAGE = "vertica/vertica-ce" - VERTICA_VERSION = "12.0.0-0" + VERTICA_VERSION = "23.3.0-0" VERTICA_DB = "testdb" ) From ccfb885423604a30b60ddce79439def8a7900e7d Mon Sep 17 00:00:00 2001 From: Michael Fridman Date: Fri, 6 Oct 2023 23:53:26 -0400 Subject: [PATCH 02/10] feat(experimental): goose provider with unimplemented methods (#596) --- dialect.go | 14 ++ go.mod | 1 + go.sum | 2 + internal/sqladapter/sqladapter.go | 49 +++++++ internal/sqladapter/store.go | 111 ++++++++++++++ internal/sqladapter/store_test.go | 218 ++++++++++++++++++++++++++++ internal/sqlextended/sqlextended.go | 23 +++ internal/sqlparser/parser.go | 8 + provider.go | 196 +++++++++++++++++++++++++ provider_options.go | 50 +++++++ provider_options_test.go | 100 +++++++++++++ 11 files changed, 772 insertions(+) create mode 100644 internal/sqladapter/sqladapter.go create mode 100644 internal/sqladapter/store.go create mode 100644 internal/sqladapter/store_test.go create mode 100644 internal/sqlextended/sqlextended.go create mode 100644 provider.go create mode 100644 provider_options.go create mode 100644 provider_options_test.go diff --git a/dialect.go b/dialect.go index a14248002..83c81c4dd 100644 --- a/dialect.go +++ b/dialect.go @@ -6,6 +6,20 @@ import ( "github.com/pressly/goose/v3/internal/dialect" ) +// Dialect is the type of database dialect. 
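//
// A Dialect value is what callers pass to the NewProvider constructor added in
// this patch. A minimal sketch, assuming an open *sql.DB named db and a
// migrations filesystem fsys defined elsewhere:
//
//	provider, err := goose.NewProvider(goose.DialectPostgres, db, fsys)
//	if err != nil {
//		return err
//	}
//	_ = provider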
+type Dialect string + +const ( + DialectClickHouse Dialect = "clickhouse" + DialectMSSQL Dialect = "mssql" + DialectMySQL Dialect = "mysql" + DialectPostgres Dialect = "postgres" + DialectRedshift Dialect = "redshift" + DialectSQLite3 Dialect = "sqlite3" + DialectTiDB Dialect = "tidb" + DialectVertica Dialect = "vertica" +) + func init() { store, _ = dialect.NewStore(dialect.Postgres) } diff --git a/go.mod b/go.mod index 931d9a985..9beb962cf 100644 --- a/go.mod +++ b/go.mod @@ -10,6 +10,7 @@ require ( github.com/ory/dockertest/v3 v3.10.0 github.com/vertica/vertica-sql-go v1.3.3 github.com/ziutek/mymysql v1.5.4 + go.uber.org/multierr v1.11.0 modernc.org/sqlite v1.26.0 ) diff --git a/go.sum b/go.sum index 0c34e9ce7..59b8881ef 100644 --- a/go.sum +++ b/go.sum @@ -159,6 +159,8 @@ go.opentelemetry.io/otel v1.19.0 h1:MuS/TNf4/j4IXsZuJegVzI1cwut7Qc00344rgH7p8bs= go.opentelemetry.io/otel v1.19.0/go.mod h1:i0QyjOq3UPoTzff0PJB2N66fb4S0+rSbSB15/oyH9fY= go.opentelemetry.io/otel/trace v1.19.0 h1:DFVQmlVbfVeOuBRrwdtaehRrWiL1JoVs9CPIQ1Dzxpg= go.opentelemetry.io/otel/trace v1.19.0/go.mod h1:mfaSyvGyEJEI0nyV2I4qhNQnbBOUUmYZpYojqMnX2vo= +go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= +go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= diff --git a/internal/sqladapter/sqladapter.go b/internal/sqladapter/sqladapter.go new file mode 100644 index 000000000..f6c975dc4 --- /dev/null +++ b/internal/sqladapter/sqladapter.go @@ -0,0 +1,49 @@ +// Package sqladapter provides an interface for interacting with a SQL database. +// +// All supported database dialects must implement the Store interface. +package sqladapter + +import ( + "context" + "time" + + "github.com/pressly/goose/v3/internal/sqlextended" +) + +// Store is the interface that wraps the basic methods for a database dialect. +// +// A dialect is a set of SQL statements that are specific to a database. +// +// By defining a store interface, we can support multiple databases with a single codebase. +// +// The underlying implementation does not modify the error. It is the callers responsibility to +// assert for the correct error, such as [sql.ErrNoRows]. +type Store interface { + // CreateVersionTable creates the version table within a transaction. This table is used to + // record applied migrations. + CreateVersionTable(ctx context.Context, db sqlextended.DBTxConn) error + + // InsertOrDelete inserts or deletes a version id from the version table. + InsertOrDelete(ctx context.Context, db sqlextended.DBTxConn, direction bool, version int64) error + + // GetMigration retrieves a single migration by version id. + // + // Returns the raw sql error if the query fails. It is the callers responsibility to assert for + // the correct error, such as [sql.ErrNoRows]. + GetMigration(ctx context.Context, db sqlextended.DBTxConn, version int64) (*GetMigrationResult, error) + + // ListMigrations retrieves all migrations sorted in descending order by id. + // + // If there are no migrations, an empty slice is returned with no error. 
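	//
	// A usage sketch, assuming a Store created with NewStore and an open
	// *sql.Conn named conn (error handling elided):
	//
	//	results, _ := store.ListMigrations(ctx, conn)
	//	for _, r := range results {
	//		fmt.Println(r.Version, r.IsApplied)
	//	}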
+ ListMigrations(ctx context.Context, db sqlextended.DBTxConn) ([]*ListMigrationsResult, error) +} + +type GetMigrationResult struct { + IsApplied bool + Timestamp time.Time +} + +type ListMigrationsResult struct { + Version int64 + IsApplied bool +} diff --git a/internal/sqladapter/store.go b/internal/sqladapter/store.go new file mode 100644 index 000000000..0ee90ca49 --- /dev/null +++ b/internal/sqladapter/store.go @@ -0,0 +1,111 @@ +package sqladapter + +import ( + "context" + "errors" + "fmt" + + "github.com/pressly/goose/v3/internal/dialect/dialectquery" + "github.com/pressly/goose/v3/internal/sqlextended" +) + +var _ Store = (*store)(nil) + +type store struct { + tablename string + querier dialectquery.Querier +} + +// NewStore returns a new [Store] backed by the given dialect. +// +// The dialect must match one of the supported dialects defined in dialect.go. +func NewStore(dialect string, table string) (Store, error) { + if table == "" { + return nil, errors.New("table must not be empty") + } + if dialect == "" { + return nil, errors.New("dialect must not be empty") + } + var querier dialectquery.Querier + switch dialect { + case "clickhouse": + querier = &dialectquery.Clickhouse{} + case "mssql": + querier = &dialectquery.Sqlserver{} + case "mysql": + querier = &dialectquery.Mysql{} + case "postgres": + querier = &dialectquery.Postgres{} + case "redshift": + querier = &dialectquery.Redshift{} + case "sqlite3": + querier = &dialectquery.Sqlite3{} + case "tidb": + querier = &dialectquery.Tidb{} + case "vertica": + querier = &dialectquery.Vertica{} + default: + return nil, fmt.Errorf("unknown dialect: %q", dialect) + } + return &store{ + tablename: table, + querier: querier, + }, nil +} + +func (s *store) CreateVersionTable(ctx context.Context, db sqlextended.DBTxConn) error { + q := s.querier.CreateTable(s.tablename) + if _, err := db.ExecContext(ctx, q); err != nil { + return fmt.Errorf("failed to create version table %q: %w", s.tablename, err) + } + return nil +} + +func (s *store) InsertOrDelete(ctx context.Context, db sqlextended.DBTxConn, direction bool, version int64) error { + if direction { + q := s.querier.InsertVersion(s.tablename) + if _, err := db.ExecContext(ctx, q, version, true); err != nil { + return fmt.Errorf("failed to insert version %d: %w", version, err) + } + return nil + } + q := s.querier.DeleteVersion(s.tablename) + if _, err := db.ExecContext(ctx, q, version); err != nil { + return fmt.Errorf("failed to delete version %d: %w", version, err) + } + return nil +} + +func (s *store) GetMigration(ctx context.Context, db sqlextended.DBTxConn, version int64) (*GetMigrationResult, error) { + q := s.querier.GetMigrationByVersion(s.tablename) + var result GetMigrationResult + if err := db.QueryRowContext(ctx, q, version).Scan( + &result.Timestamp, + &result.IsApplied, + ); err != nil { + return nil, fmt.Errorf("failed to get migration %d: %w", version, err) + } + return &result, nil +} + +func (s *store) ListMigrations(ctx context.Context, db sqlextended.DBTxConn) ([]*ListMigrationsResult, error) { + q := s.querier.ListMigrations(s.tablename) + rows, err := db.QueryContext(ctx, q) + if err != nil { + return nil, fmt.Errorf("failed to list migrations: %w", err) + } + defer rows.Close() + + var migrations []*ListMigrationsResult + for rows.Next() { + var result ListMigrationsResult + if err := rows.Scan(&result.Version, &result.IsApplied); err != nil { + return nil, fmt.Errorf("failed to scan list migrations result: %w", err) + } + migrations = append(migrations, 
&result) + } + if err := rows.Err(); err != nil { + return nil, err + } + return migrations, nil +} diff --git a/internal/sqladapter/store_test.go b/internal/sqladapter/store_test.go new file mode 100644 index 000000000..1d0189598 --- /dev/null +++ b/internal/sqladapter/store_test.go @@ -0,0 +1,218 @@ +package sqladapter_test + +import ( + "context" + "database/sql" + "errors" + "path/filepath" + "testing" + + "github.com/jackc/pgx/v5/pgconn" + "github.com/pressly/goose/v3" + "github.com/pressly/goose/v3/internal/check" + "github.com/pressly/goose/v3/internal/sqladapter" + "github.com/pressly/goose/v3/internal/testdb" + "go.uber.org/multierr" + "modernc.org/sqlite" +) + +// The goal of this test is to verify the sqladapter package works as expected. This test is not +// meant to be exhaustive or test every possible database dialect. It is meant to verify the Store +// interface works against a real database. + +func TestStore(t *testing.T) { + t.Parallel() + t.Run("invalid", func(t *testing.T) { + // Test empty table name. + _, err := sqladapter.NewStore("sqlite3", "") + check.HasError(t, err) + // Test unknown dialect. + _, err = sqladapter.NewStore("unknown-dialect", "foo") + check.HasError(t, err) + // Test empty dialect. + _, err = sqladapter.NewStore("", "foo") + check.HasError(t, err) + }) + t.Run("postgres", func(t *testing.T) { + if testing.Short() { + t.Skip("skip long-running test") + } + // Test postgres specific behavior. + db, cleanup, err := testdb.NewPostgres() + check.NoError(t, err) + t.Cleanup(cleanup) + testStore(context.Background(), t, goose.DialectPostgres, db, func(t *testing.T, err error) { + var pgErr *pgconn.PgError + ok := errors.As(err, &pgErr) + check.Bool(t, ok, true) + check.Equal(t, pgErr.Code, "42P07") // duplicate_table + }) + }) + // Test generic behavior. + t.Run("sqlite3", func(t *testing.T) { + dir := t.TempDir() + db, err := sql.Open("sqlite", filepath.Join(dir, "sql_embed.db")) + check.NoError(t, err) + testStore(context.Background(), t, goose.DialectSQLite3, db, func(t *testing.T, err error) { + var sqliteErr *sqlite.Error + ok := errors.As(err, &sqliteErr) + check.Bool(t, ok, true) + check.Equal(t, sqliteErr.Code(), 1) // Generic error (SQLITE_ERROR) + check.Contains(t, sqliteErr.Error(), "table test_goose_db_version already exists") + }) + }) +} + +// testStore tests various store operations. +// +// If alreadyExists is not nil, it will be used to assert the error returned by CreateVersionTable +// when the version table already exists. +func testStore(ctx context.Context, t *testing.T, dialect goose.Dialect, db *sql.DB, alreadyExists func(t *testing.T, err error)) { + const ( + tablename = "test_goose_db_version" + ) + store, err := sqladapter.NewStore(string(dialect), tablename) + check.NoError(t, err) + // Create the version table. + err = runTx(ctx, db, func(tx *sql.Tx) error { + return store.CreateVersionTable(ctx, tx) + }) + check.NoError(t, err) + // Create the version table again. This should fail. + err = runTx(ctx, db, func(tx *sql.Tx) error { + return store.CreateVersionTable(ctx, tx) + }) + check.HasError(t, err) + if alreadyExists != nil { + alreadyExists(t, err) + } + + // List migrations. There should be none. + err = runConn(ctx, db, func(conn *sql.Conn) error { + res, err := store.ListMigrations(ctx, conn) + check.NoError(t, err) + check.Number(t, len(res), 0) + return nil + }) + check.NoError(t, err) + + // Insert 5 migrations in addition to the zero migration. 
+ for i := 0; i < 6; i++ { + err = runConn(ctx, db, func(conn *sql.Conn) error { + return store.InsertOrDelete(ctx, conn, true, int64(i)) + }) + check.NoError(t, err) + } + + // List migrations. There should be 6. + err = runConn(ctx, db, func(conn *sql.Conn) error { + res, err := store.ListMigrations(ctx, conn) + check.NoError(t, err) + check.Number(t, len(res), 6) + // Check versions are in descending order. + for i := 0; i < 6; i++ { + check.Number(t, res[i].Version, 5-i) + } + return nil + }) + check.NoError(t, err) + + // Delete 3 migrations backwards + for i := 5; i >= 3; i-- { + err = runConn(ctx, db, func(conn *sql.Conn) error { + return store.InsertOrDelete(ctx, conn, false, int64(i)) + }) + check.NoError(t, err) + } + + // List migrations. There should be 3. + err = runConn(ctx, db, func(conn *sql.Conn) error { + res, err := store.ListMigrations(ctx, conn) + check.NoError(t, err) + check.Number(t, len(res), 3) + // Check that the remaining versions are in descending order. + for i := 0; i < 3; i++ { + check.Number(t, res[i].Version, 2-i) + } + return nil + }) + check.NoError(t, err) + + // Get remaining migrations one by one. + for i := 0; i < 3; i++ { + err = runConn(ctx, db, func(conn *sql.Conn) error { + res, err := store.GetMigration(ctx, conn, int64(i)) + check.NoError(t, err) + check.Equal(t, res.IsApplied, true) + check.Equal(t, res.Timestamp.IsZero(), false) + return nil + }) + check.NoError(t, err) + } + + // Delete remaining migrations one by one and use all 3 connection types: + + // 1. *sql.Tx + err = runTx(ctx, db, func(tx *sql.Tx) error { + return store.InsertOrDelete(ctx, tx, false, 2) + }) + check.NoError(t, err) + // 2. *sql.Conn + err = runConn(ctx, db, func(conn *sql.Conn) error { + return store.InsertOrDelete(ctx, conn, false, 1) + }) + check.NoError(t, err) + // 3. *sql.DB + err = store.InsertOrDelete(ctx, db, false, 0) + check.NoError(t, err) + + // List migrations. There should be none. + err = runConn(ctx, db, func(conn *sql.Conn) error { + res, err := store.ListMigrations(ctx, conn) + check.NoError(t, err) + check.Number(t, len(res), 0) + return nil + }) + check.NoError(t, err) + + // Try to get a migration that does not exist. + err = runConn(ctx, db, func(conn *sql.Conn) error { + _, err := store.GetMigration(ctx, conn, 0) + check.HasError(t, err) + check.Bool(t, errors.Is(err, sql.ErrNoRows), true) + return nil + }) + check.NoError(t, err) +} + +func runTx(ctx context.Context, db *sql.DB, fn func(*sql.Tx) error) (retErr error) { + tx, err := db.BeginTx(ctx, nil) + if err != nil { + return err + } + defer func() { + if retErr != nil { + retErr = multierr.Append(retErr, tx.Rollback()) + } + }() + if err := fn(tx); err != nil { + return err + } + return tx.Commit() +} + +func runConn(ctx context.Context, db *sql.DB, fn func(*sql.Conn) error) (retErr error) { + conn, err := db.Conn(ctx) + if err != nil { + return err + } + defer func() { + if retErr != nil { + retErr = multierr.Append(retErr, conn.Close()) + } + }() + if err := fn(conn); err != nil { + return err + } + return conn.Close() +} diff --git a/internal/sqlextended/sqlextended.go b/internal/sqlextended/sqlextended.go new file mode 100644 index 000000000..e3e763abf --- /dev/null +++ b/internal/sqlextended/sqlextended.go @@ -0,0 +1,23 @@ +package sqlextended + +import ( + "context" + "database/sql" +) + +// DBTxConn is a thin interface for common method that is satisfied by *sql.DB, *sql.Tx and +// *sql.Conn. +// +// There is a long outstanding issue to formalize a std lib interface, but alas... 
See: +// https://github.com/golang/go/issues/14468 +type DBTxConn interface { + ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) + QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) + QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row +} + +var ( + _ DBTxConn = (*sql.DB)(nil) + _ DBTxConn = (*sql.Tx)(nil) + _ DBTxConn = (*sql.Conn)(nil) +) diff --git a/internal/sqlparser/parser.go b/internal/sqlparser/parser.go index 5e6c67503..a62846026 100644 --- a/internal/sqlparser/parser.go +++ b/internal/sqlparser/parser.go @@ -25,6 +25,14 @@ func FromBool(b bool) Direction { return DirectionDown } +func (d Direction) String() string { + return string(d) +} + +func (d Direction) ToBool() bool { + return d == DirectionUp +} + type parserState int const ( diff --git a/provider.go b/provider.go new file mode 100644 index 000000000..c12d4ea8b --- /dev/null +++ b/provider.go @@ -0,0 +1,196 @@ +package goose + +import ( + "context" + "database/sql" + "errors" + "io/fs" + "time" + + "github.com/pressly/goose/v3/internal/sqladapter" +) + +// NewProvider returns a new goose Provider. +// +// The caller is responsible for matching the database dialect with the database/sql driver. For +// example, if the database dialect is "postgres", the database/sql driver could be +// github.com/lib/pq or github.com/jackc/pgx. +// +// fsys is the filesystem used to read the migration files. Most users will want to use +// os.DirFS("path/to/migrations") to read migrations from the local filesystem. However, it is +// possible to use a different filesystem, such as embed.FS. +// +// Functional options are used to configure the Provider. See [ProviderOption] for more information. +// +// Unless otherwise specified, all methods on Provider are safe for concurrent use. +// +// Experimental: This API is experimental and may change in the future. +func NewProvider(dialect Dialect, db *sql.DB, fsys fs.FS, opts ...ProviderOption) (*Provider, error) { + if db == nil { + return nil, errors.New("db must not be nil") + } + if dialect == "" { + return nil, errors.New("dialect must not be empty") + } + if fsys == nil { + return nil, errors.New("fsys must not be nil") + } + var cfg config + for _, opt := range opts { + if err := opt.apply(&cfg); err != nil { + return nil, err + } + } + // Set defaults + if cfg.tableName == "" { + cfg.tableName = defaultTablename + } + store, err := sqladapter.NewStore(string(dialect), cfg.tableName) + if err != nil { + return nil, err + } + // TODO(mf): implement the rest of this function - collect sources - merge sources into + // migrations + return &Provider{ + db: db, + fsys: fsys, + cfg: cfg, + store: store, + }, nil +} + +// Provider is a goose migration provider. +// Experimental: This API is experimental and may change in the future. +type Provider struct { + db *sql.DB + fsys fs.FS + cfg config + store sqladapter.Store +} + +// MigrationStatus represents the status of a single migration. +type MigrationStatus struct { + // State represents the state of the migration. One of "untracked", "pending", "applied". + // - untracked: in the database, but not on the filesystem. + // - pending: on the filesystem, but not in the database. + // - applied: in both the database and on the filesystem. + State string + // AppliedAt is the time the migration was applied. Only set if state is applied or untracked. + AppliedAt time.Time + // Source is the migration source. Only set if the state is pending or applied. 
+ Source Source +} + +// Status returns the status of all migrations, merging the list of migrations from the database and +// filesystem. The returned items are ordered by version, in ascending order. +// Experimental: This API is experimental and may change in the future. +func (p *Provider) Status(ctx context.Context) ([]*MigrationStatus, error) { + return nil, errors.New("not implemented") +} + +// GetDBVersion returns the max version from the database, regardless of the applied order. For +// example, if migrations 1,4,2,3 were applied, this method returns 4. If no migrations have been +// applied, it returns 0. +// Experimental: This API is experimental and may change in the future. +func (p *Provider) GetDBVersion(ctx context.Context) (int64, error) { + return 0, errors.New("not implemented") +} + +// SourceType represents the type of migration source. +type SourceType string + +const ( + // SourceTypeSQL represents a SQL migration. + SourceTypeSQL SourceType = "sql" + // SourceTypeGo represents a Go migration. + SourceTypeGo SourceType = "go" +) + +// Source represents a single migration source. +// +// For SQL migrations, Fullpath will always be set. For Go migrations, Fullpath will will be set if +// the migration has a corresponding file on disk. It will be empty if the migration was registered +// manually. +// Experimental: This API is experimental and may change in the future. +type Source struct { + // Type is the type of migration. + Type SourceType + // Full path to the migration file. + // + // Example: /path/to/migrations/001_create_users_table.sql + Fullpath string + // Version is the version of the migration. + Version int64 +} + +// ListSources returns a list of all available migration sources the provider is aware of. +// Experimental: This API is experimental and may change in the future. +func (p *Provider) ListSources() []*Source { + return nil +} + +// Ping attempts to ping the database to verify a connection is available. +// Experimental: This API is experimental and may change in the future. +func (p *Provider) Ping(ctx context.Context) error { + return errors.New("not implemented") +} + +// Close closes the database connection. +// Experimental: This API is experimental and may change in the future. +func (p *Provider) Close() error { + return errors.New("not implemented") +} + +// MigrationResult represents the result of a single migration. +type MigrationResult struct{} + +// ApplyVersion applies exactly one migration at the specified version. If there is no source for +// the specified version, this method returns [ErrNoCurrentVersion]. If the migration has been +// applied already, this method returns [ErrAlreadyApplied]. +// +// When direction is true, the up migration is executed, and when direction is false, the down +// migration is executed. +// Experimental: This API is experimental and may change in the future. +func (p *Provider) ApplyVersion(ctx context.Context, version int64, direction bool) (*MigrationResult, error) { + return nil, errors.New("not implemented") +} + +// Up applies all pending migrations. If there are no new migrations to apply, this method returns +// empty list and nil error. +// Experimental: This API is experimental and may change in the future. +func (p *Provider) Up(ctx context.Context) ([]*MigrationResult, error) { + return nil, errors.New("not implemented") +} + +// UpByOne applies the next available migration. If there are no migrations to apply, this method +// returns [ErrNoNextVersion]. 
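//
// A sketch of draining all pending migrations one at a time, assuming
// ErrNoNextVersion is the sentinel returned when nothing is left to apply:
//
//	for {
//		if _, err := provider.UpByOne(ctx); err != nil {
//			if errors.Is(err, goose.ErrNoNextVersion) {
//				break
//			}
//			return err
//		}
//	}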
+// Experimental: This API is experimental and may change in the future. +func (p *Provider) UpByOne(ctx context.Context) (*MigrationResult, error) { + return nil, errors.New("not implemented") +} + +// UpTo applies all available migrations up to and including the specified version. If there are no +// migrations to apply, this method returns empty list and nil error. +// +// For instance, if there are three new migrations (9,10,11) and the current database version is 8 +// with a requested version of 10, only versions 9 and 10 will be applied. +// Experimental: This API is experimental and may change in the future. +func (p *Provider) UpTo(ctx context.Context, version int64) ([]*MigrationResult, error) { + return nil, errors.New("not implemented") +} + +// Down rolls back the most recently applied migration. If there are no migrations to apply, this +// method returns [ErrNoNextVersion]. +// Experimental: This API is experimental and may change in the future. +func (p *Provider) Down(ctx context.Context) (*MigrationResult, error) { + return nil, errors.New("not implemented") +} + +// DownTo rolls back all migrations down to but not including the specified version. +// +// For instance, if the current database version is 11, and the requested version is 9, only +// migrations 11 and 10 will be rolled back. +// Experimental: This API is experimental and may change in the future. +func (p *Provider) DownTo(ctx context.Context, version int64) ([]*MigrationResult, error) { + return nil, errors.New("not implemented") +} diff --git a/provider_options.go b/provider_options.go new file mode 100644 index 000000000..904b3ed34 --- /dev/null +++ b/provider_options.go @@ -0,0 +1,50 @@ +package goose + +import ( + "errors" + "fmt" +) + +const ( + defaultTablename = "goose_db_version" +) + +// ProviderOption is a configuration option for a goose provider. +type ProviderOption interface { + apply(*config) error +} + +// WithTableName sets the name of the database table used to track history of applied migrations. +// +// If WithTableName is not called, the default value is "goose_db_version". +func WithTableName(name string) ProviderOption { + return configFunc(func(c *config) error { + if c.tableName != "" { + return fmt.Errorf("table already set to %q", c.tableName) + } + if name == "" { + return errors.New("table must not be empty") + } + c.tableName = name + return nil + }) +} + +// WithVerbose enables verbose logging. 
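//
// Options are combined at construction time. A sketch, assuming db and fsys
// are defined elsewhere:
//
//	provider, err := goose.NewProvider(goose.DialectPostgres, db, fsys,
//		goose.WithTableName("custom_goose_version"),
//		goose.WithVerbose(),
//	)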
+func WithVerbose() ProviderOption { + return configFunc(func(c *config) error { + c.verbose = true + return nil + }) +} + +type config struct { + tableName string + verbose bool +} + +type configFunc func(*config) error + +func (o configFunc) apply(cfg *config) error { + return o(cfg) +} diff --git a/provider_options_test.go b/provider_options_test.go new file mode 100644 index 000000000..629c6efaa --- /dev/null +++ b/provider_options_test.go @@ -0,0 +1,100 @@ +package goose_test + +import ( + "database/sql" + "io/fs" + "path/filepath" + "testing" + "testing/fstest" + + "github.com/pressly/goose/v3" + "github.com/pressly/goose/v3/internal/check" +) + +func TestNewProvider(t *testing.T) { + dir := t.TempDir() + db, err := sql.Open("sqlite", filepath.Join(dir, "sql_embed.db")) + check.NoError(t, err) + fsys := newFsys() + t.Run("invalid", func(t *testing.T) { + // Empty dialect not allowed + _, err = goose.NewProvider("", db, fsys) + check.HasError(t, err) + // Invalid dialect not allowed + _, err = goose.NewProvider("unknown-dialect", db, fsys) + check.HasError(t, err) + // Nil db not allowed + _, err = goose.NewProvider("sqlite3", nil, fsys) + check.HasError(t, err) + // Nil fsys not allowed + _, err = goose.NewProvider("sqlite3", db, nil) + check.HasError(t, err) + // Duplicate table name not allowed + _, err = goose.NewProvider("sqlite3", db, fsys, goose.WithTableName("foo"), goose.WithTableName("bar")) + check.HasError(t, err) + check.Equal(t, `table already set to "foo"`, err.Error()) + // Empty table name not allowed + _, err = goose.NewProvider("sqlite3", db, fsys, goose.WithTableName("")) + check.HasError(t, err) + check.Equal(t, "table must not be empty", err.Error()) + }) + t.Run("valid", func(t *testing.T) { + // Valid dialect, db, and fsys allowed + _, err = goose.NewProvider("sqlite3", db, fsys) + check.NoError(t, err) + // Valid dialect, db, fsys, and table name allowed + _, err = goose.NewProvider("sqlite3", db, fsys, goose.WithTableName("foo")) + check.NoError(t, err) + // Valid dialect, db, fsys, and verbose allowed + _, err = goose.NewProvider("sqlite3", db, fsys, goose.WithVerbose()) + check.NoError(t, err) + }) +} + +func newFsys() fs.FS { + return fstest.MapFS{ + "1_foo.sql": {Data: []byte(migration1)}, + "2_bar.sql": {Data: []byte(migration2)}, + "3_baz.sql": {Data: []byte(migration3)}, + "4_qux.sql": {Data: []byte(migration4)}, + } +} + +var ( + migration1 = ` +-- +goose Up +CREATE TABLE foo (id INTEGER PRIMARY KEY); +-- +goose Down +DROP TABLE foo; +` + migration2 = ` +-- +goose Up +ALTER TABLE foo ADD COLUMN name TEXT; +-- +goose Down +ALTER TABLE foo DROP COLUMN name; +` + migration3 = ` +-- +goose Up +CREATE TABLE bar ( + id INTEGER PRIMARY KEY, + description TEXT +); +-- +goose Down +DROP TABLE bar; +` + migration4 = ` +-- +goose Up +-- Rename the 'foo' table to 'my_foo' +ALTER TABLE foo RENAME TO my_foo; + +-- Add a new column 'timestamp' to 'my_foo' +ALTER TABLE my_foo ADD COLUMN timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP; + +-- +goose Down +-- Remove the 'timestamp' column from 'my_foo' +ALTER TABLE my_foo DROP COLUMN timestamp; + +-- Rename the 'my_foo' table back to 'foo' +ALTER TABLE my_foo RENAME TO foo; +` +) From c590380f39cdea66e5000b3836f146df836a5c74 Mon Sep 17 00:00:00 2001 From: Michael Fridman Date: Mon, 9 Oct 2023 15:08:51 -0400 Subject: [PATCH 03/10] feat(experimental): add internal migrate package and SessionLocker interface (#606) --- go.mod | 1 + go.sum | 2 + internal/migrate/doc.go | 9 ++ internal/migrate/migration.go | 166 
++++++++++++++++++++++++ internal/migrate/parse.go | 75 +++++++++++ internal/migrate/run.go | 53 ++++++++ internal/sqlextended/sqlextended.go | 2 +- lock/postgres.go | 110 ++++++++++++++++ lock/postgres_test.go | 193 ++++++++++++++++++++++++++++ lock/session_locker.go | 23 ++++ lock/session_locker_options.go | 63 +++++++++ provider_options.go | 29 ++++- 12 files changed, 723 insertions(+), 3 deletions(-) create mode 100644 internal/migrate/doc.go create mode 100644 internal/migrate/migration.go create mode 100644 internal/migrate/parse.go create mode 100644 internal/migrate/run.go create mode 100644 lock/postgres.go create mode 100644 lock/postgres_test.go create mode 100644 lock/session_locker.go create mode 100644 lock/session_locker_options.go diff --git a/go.mod b/go.mod index 9beb962cf..0230e2c97 100644 --- a/go.mod +++ b/go.mod @@ -8,6 +8,7 @@ require ( github.com/jackc/pgx/v5 v5.4.3 github.com/microsoft/go-mssqldb v1.6.0 github.com/ory/dockertest/v3 v3.10.0 + github.com/sethvargo/go-retry v0.2.4 github.com/vertica/vertica-sql-go v1.3.3 github.com/ziutek/mymysql v1.5.4 go.uber.org/multierr v1.11.0 diff --git a/go.sum b/go.sum index 59b8881ef..e2c6ea03a 100644 --- a/go.sum +++ b/go.sum @@ -127,6 +127,8 @@ github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qq github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ= github.com/segmentio/asm v1.2.0 h1:9BQrFxC+YOHJlTlHGkTrFWf59nbL3XnCoFLTwDCI7ys= github.com/segmentio/asm v1.2.0/go.mod h1:BqMnlJP91P8d+4ibuonYZw9mfnzI9HfxselHZr5aAcs= +github.com/sethvargo/go-retry v0.2.4 h1:T+jHEQy/zKJf5s95UkguisicE0zuF9y7+/vgz08Ocec= +github.com/sethvargo/go-retry v0.2.4/go.mod h1:1afjQuvh7s4gflMObvjLPaWgluLLyhA1wmVZ6KLpICw= github.com/shopspring/decimal v1.3.1 h1:2Usl1nmF/WZucqkFZhnfFYxxxu8LG21F6nPQBE5gKV8= github.com/shopspring/decimal v1.3.1/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o= github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= diff --git a/internal/migrate/doc.go b/internal/migrate/doc.go new file mode 100644 index 000000000..5fbee1582 --- /dev/null +++ b/internal/migrate/doc.go @@ -0,0 +1,9 @@ +// Package migrate defines a Migration struct and implements the migration logic for executing Go +// and SQL migrations. +// +// - For Go migrations, only *sql.Tx and *sql.DB are supported. *sql.Conn is not supported. +// - For SQL migrations, all three are supported. +// +// Lastly, SQL migrations are lazily parsed. This means that the SQL migration is parsed the first +// time it is executed. +package migrate diff --git a/internal/migrate/migration.go b/internal/migrate/migration.go new file mode 100644 index 000000000..23a0514cf --- /dev/null +++ b/internal/migrate/migration.go @@ -0,0 +1,166 @@ +package migrate + +import ( + "context" + "database/sql" + "errors" + "fmt" + + "github.com/pressly/goose/v3/internal/sqlextended" +) + +type Migration struct { + // Fullpath is the full path to the migration file. + // + // Example: /path/to/migrations/123_create_users_table.go + Fullpath string + // Version is the version of the migration. + Version int64 + // Type is the type of migration. + Type MigrationType + // A migration is either a Go migration or a SQL migration, but never both. + // + // Note, the SQLParsed field is used to determine if the SQL migration has been parsed. This is + // an optimization to avoid parsing the SQL migration if it is never required. 
Also, the + // majority of the time migrations are incremental, so it is likely that the user will only want + // to run the last few migrations, and there is no need to parse ALL prior migrations. + // + // Exactly one of these fields will be set: + Go *Go + // -- or -- + SQLParsed bool + SQL *SQL +} + +type MigrationType int + +const ( + TypeGo MigrationType = iota + 1 + TypeSQL +) + +func (t MigrationType) String() string { + switch t { + case TypeGo: + return "go" + case TypeSQL: + return "sql" + default: + // This should never happen. + return "unknown" + } +} + +func (m *Migration) UseTx() bool { + switch m.Type { + case TypeGo: + return m.Go.UseTx + case TypeSQL: + return m.SQL.UseTx + default: + // This should never happen. + panic("unknown migration type: use tx") + } +} + +func (m *Migration) IsEmpty(direction bool) bool { + switch m.Type { + case TypeGo: + return m.Go.IsEmpty(direction) + case TypeSQL: + return m.SQL.IsEmpty(direction) + default: + // This should never happen. + panic("unknown migration type: is empty") + } +} + +func (m *Migration) GetSQLStatements(direction bool) ([]string, error) { + if m.Type != TypeSQL { + return nil, fmt.Errorf("expected sql migration, got %s: no sql statements", m.Type) + } + if m.SQL == nil { + return nil, errors.New("sql migration has not been initialized") + } + if !m.SQLParsed { + return nil, errors.New("sql migration has not been parsed") + } + if direction { + return m.SQL.UpStatements, nil + } + return m.SQL.DownStatements, nil +} + +type Go struct { + // We used an explicit bool instead of relying on a pointer because registered funcs may be nil. + // These are still valid Go and versioned migrations, but they are just empty. + // + // For example: goose.AddMigration(nil, nil) + UseTx bool + + // Only one of these func pairs will be set: + UpFn, DownFn func(context.Context, *sql.Tx) error + // -- or -- + UpFnNoTx, DownFnNoTx func(context.Context, *sql.DB) error +} + +func (g *Go) IsEmpty(direction bool) bool { + if direction { + return g.UpFn == nil && g.UpFnNoTx == nil + } + return g.DownFn == nil && g.DownFnNoTx == nil +} + +func (g *Go) run(ctx context.Context, tx *sql.Tx, direction bool) error { + var fn func(context.Context, *sql.Tx) error + if direction { + fn = g.UpFn + } else { + fn = g.DownFn + } + if fn != nil { + return fn(ctx, tx) + } + return nil +} + +func (g *Go) runNoTx(ctx context.Context, db *sql.DB, direction bool) error { + var fn func(context.Context, *sql.DB) error + if direction { + fn = g.UpFnNoTx + } else { + fn = g.DownFnNoTx + } + if fn != nil { + return fn(ctx, db) + } + return nil +} + +type SQL struct { + UseTx bool + UpStatements []string + DownStatements []string +} + +func (s *SQL) IsEmpty(direction bool) bool { + if direction { + return len(s.UpStatements) == 0 + } + return len(s.DownStatements) == 0 +} + +func (s *SQL) run(ctx context.Context, db sqlextended.DBTxConn, direction bool) error { + var statements []string + if direction { + statements = s.UpStatements + } else { + statements = s.DownStatements + } + for _, stmt := range statements { + if _, err := db.ExecContext(ctx, stmt); err != nil { + return err + } + } + return nil +} diff --git a/internal/migrate/parse.go b/internal/migrate/parse.go new file mode 100644 index 000000000..18a66b499 --- /dev/null +++ b/internal/migrate/parse.go @@ -0,0 +1,75 @@ +package migrate + +import ( + "bytes" + "io" + "io/fs" + + "github.com/pressly/goose/v3/internal/sqlparser" +) + +// ParseSQL parses all SQL migrations in BOTH directions. 
If a migration has already been parsed, it +// will not be parsed again. +// +// Important: This function will mutate SQL migrations. +func ParseSQL(fsys fs.FS, debug bool, migrations []*Migration) error { + for _, m := range migrations { + if m.Type == TypeSQL && !m.SQLParsed { + parsedSQLMigration, err := parseSQL(fsys, m.Fullpath, parseAll, debug) + if err != nil { + return err + } + m.SQLParsed = true + m.SQL = parsedSQLMigration + } + } + return nil +} + +// parse is used to determine which direction to parse the SQL migration. +type parse int + +const ( + // parseAll parses all SQL statements in BOTH directions. + parseAll parse = iota + 1 + // parseUp parses all SQL statements in the UP direction. + parseUp + // parseDown parses all SQL statements in the DOWN direction. + parseDown +) + +func parseSQL(fsys fs.FS, filename string, p parse, debug bool) (*SQL, error) { + r, err := fsys.Open(filename) + if err != nil { + return nil, err + } + by, err := io.ReadAll(r) + if err != nil { + return nil, err + } + if err := r.Close(); err != nil { + return nil, err + } + s := new(SQL) + if p == parseAll || p == parseUp { + s.UpStatements, s.UseTx, err = sqlparser.ParseSQLMigration( + bytes.NewReader(by), + sqlparser.DirectionUp, + debug, + ) + if err != nil { + return nil, err + } + } + if p == parseAll || p == parseDown { + s.DownStatements, s.UseTx, err = sqlparser.ParseSQLMigration( + bytes.NewReader(by), + sqlparser.DirectionDown, + debug, + ) + if err != nil { + return nil, err + } + } + return s, nil +} diff --git a/internal/migrate/run.go b/internal/migrate/run.go new file mode 100644 index 000000000..7b7a883d8 --- /dev/null +++ b/internal/migrate/run.go @@ -0,0 +1,53 @@ +package migrate + +import ( + "context" + "database/sql" + "fmt" + "path/filepath" +) + +// Run runs the migration inside of a transaction. +func (m *Migration) Run(ctx context.Context, tx *sql.Tx, direction bool) error { + switch m.Type { + case TypeSQL: + if m.SQL == nil || !m.SQLParsed { + return fmt.Errorf("tx: sql migration has not been parsed") + } + return m.SQL.run(ctx, tx, direction) + case TypeGo: + return m.Go.run(ctx, tx, direction) + } + // This should never happen. + return fmt.Errorf("tx: failed to run migration %s: neither sql or go", filepath.Base(m.Fullpath)) +} + +// RunNoTx runs the migration without a transaction. +func (m *Migration) RunNoTx(ctx context.Context, db *sql.DB, direction bool) error { + switch m.Type { + case TypeSQL: + if m.SQL == nil || !m.SQLParsed { + return fmt.Errorf("db: sql migration has not been parsed") + } + return m.SQL.run(ctx, db, direction) + case TypeGo: + return m.Go.runNoTx(ctx, db, direction) + } + // This should never happen. + return fmt.Errorf("db: failed to run migration %s: neither sql or go", filepath.Base(m.Fullpath)) +} + +// RunConn runs the migration without a transaction using the provided connection. +func (m *Migration) RunConn(ctx context.Context, conn *sql.Conn, direction bool) error { + switch m.Type { + case TypeSQL: + if m.SQL == nil || !m.SQLParsed { + return fmt.Errorf("conn: sql migration has not been parsed") + } + return m.SQL.run(ctx, conn, direction) + case TypeGo: + return fmt.Errorf("conn: go migrations are not supported with *sql.Conn") + } + // This should never happen. 
+ return fmt.Errorf("conn: failed to run migration %s: neither sql or go", filepath.Base(m.Fullpath)) +} diff --git a/internal/sqlextended/sqlextended.go b/internal/sqlextended/sqlextended.go index e3e763abf..83ca7ae8b 100644 --- a/internal/sqlextended/sqlextended.go +++ b/internal/sqlextended/sqlextended.go @@ -11,7 +11,7 @@ import ( // There is a long outstanding issue to formalize a std lib interface, but alas... See: // https://github.com/golang/go/issues/14468 type DBTxConn interface { - ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) + ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row } diff --git a/lock/postgres.go b/lock/postgres.go new file mode 100644 index 000000000..3583162e2 --- /dev/null +++ b/lock/postgres.go @@ -0,0 +1,110 @@ +package lock + +import ( + "context" + "database/sql" + "errors" + "fmt" + "time" + + "github.com/sethvargo/go-retry" +) + +// NewPostgresSessionLocker returns a SessionLocker that utilizes PostgreSQL's exclusive +// session-level advisory lock mechanism. +// +// This function creates a SessionLocker that can be used to acquire and release locks for +// synchronization purposes. The lock acquisition is retried until it is successfully acquired or +// until the maximum duration is reached. The default lock duration is set to 60 minutes, and the +// default unlock duration is set to 1 minute. +// +// See [SessionLockerOption] for options that can be used to configure the SessionLocker. +func NewPostgresSessionLocker(opts ...SessionLockerOption) (SessionLocker, error) { + cfg := sessionLockerConfig{ + lockID: DefaultLockID, + lockTimeout: DefaultLockTimeout, + unlockTimeout: DefaultUnlockTimeout, + } + for _, opt := range opts { + if err := opt.apply(&cfg); err != nil { + return nil, err + } + } + return &postgresSessionLocker{ + lockID: cfg.lockID, + retryLock: retry.WithMaxDuration( + cfg.lockTimeout, + retry.NewConstant(2*time.Second), + ), + retryUnlock: retry.WithMaxDuration( + cfg.unlockTimeout, + retry.NewConstant(2*time.Second), + ), + }, nil +} + +type postgresSessionLocker struct { + lockID int64 + retryLock retry.Backoff + retryUnlock retry.Backoff +} + +var _ SessionLocker = (*postgresSessionLocker)(nil) + +func (l *postgresSessionLocker) SessionLock(ctx context.Context, conn *sql.Conn) error { + return retry.Do(ctx, l.retryLock, func(ctx context.Context) error { + row := conn.QueryRowContext(ctx, "SELECT pg_try_advisory_lock($1)", l.lockID) + var locked bool + if err := row.Scan(&locked); err != nil { + return fmt.Errorf("failed to execute pg_try_advisory_lock: %w", err) + } + if locked { + // A session-level advisory lock was acquired. + return nil + } + // A session-level advisory lock could not be acquired. This is likely because another + // process has already acquired the lock. We will continue retrying until the lock is + // acquired or the maximum number of retries is reached. 
+ return retry.RetryableError(errors.New("failed to acquire lock")) + }) +} + +func (l *postgresSessionLocker) SessionUnlock(ctx context.Context, conn *sql.Conn) error { + return retry.Do(ctx, l.retryUnlock, func(ctx context.Context) error { + var unlocked bool + row := conn.QueryRowContext(ctx, "SELECT pg_advisory_unlock($1)", l.lockID) + if err := row.Scan(&unlocked); err != nil { + return fmt.Errorf("failed to execute pg_advisory_unlock: %w", err) + } + if unlocked { + // A session-level advisory lock was released. + return nil + } + /* + TODO(mf): provide users with some documentation on how they can unlock the session + manually. + + This is probably not an issue for 99.99% of users since pg_advisory_unlock_all() will + release all session level advisory locks held by the current session. This function is + implicitly invoked at session end, even if the client disconnects ungracefully. + + Here is output from a session that has a lock held: + + SELECT pid,granted,((classid::bigint<<32)|objid::bigint)AS goose_lock_id FROM pg_locks + WHERE locktype='advisory'; + + | pid | granted | goose_lock_id | + |-----|---------|---------------------| + | 191 | t | 5887940537704921958 | + + A forceful way to unlock the session is to terminate the backend with SIGTERM: + + SELECT pg_terminate_backend(191); + + Subsequent commands on the same connection will fail with: + + Query 1 ERROR: FATAL: terminating connection due to administrator command + */ + return retry.RetryableError(errors.New("failed to unlock session")) + }) +} diff --git a/lock/postgres_test.go b/lock/postgres_test.go new file mode 100644 index 000000000..bfb1a0d99 --- /dev/null +++ b/lock/postgres_test.go @@ -0,0 +1,193 @@ +package lock_test + +import ( + "context" + "database/sql" + "errors" + "sync" + "testing" + "time" + + "github.com/pressly/goose/v3/internal/check" + "github.com/pressly/goose/v3/internal/testdb" + "github.com/pressly/goose/v3/lock" +) + +func TestPostgresSessionLocker(t *testing.T) { + if testing.Short() { + t.Skip("skip long running test") + } + db, cleanup, err := testdb.NewPostgres() + check.NoError(t, err) + t.Cleanup(cleanup) + const ( + lockID int64 = 123456789 + ) + + // Do not run tests in parallel, because they are using the same database. + + t.Run("lock_and_unlock", func(t *testing.T) { + locker, err := lock.NewPostgresSessionLocker( + lock.WithLockID(lockID), + lock.WithLockTimeout(4*time.Second), + lock.WithUnlockTimeout(4*time.Second), + ) + check.NoError(t, err) + ctx := context.Background() + conn, err := db.Conn(ctx) + check.NoError(t, err) + t.Cleanup(func() { + check.NoError(t, conn.Close()) + }) + err = locker.SessionLock(ctx, conn) + check.NoError(t, err) + pgLocks, err := queryPgLocks(ctx, db) + check.NoError(t, err) + check.Number(t, len(pgLocks), 1) + // Check that the lock was acquired. + check.Bool(t, pgLocks[0].granted, true) + // Check that the custom lock ID is the same as the one used by the locker. + check.Equal(t, pgLocks[0].gooseLockID, lockID) + check.NumberNotZero(t, pgLocks[0].pid) + + // Check that the lock is released. 
+ err = locker.SessionUnlock(ctx, conn) + check.NoError(t, err) + pgLocks, err = queryPgLocks(ctx, db) + check.NoError(t, err) + check.Number(t, len(pgLocks), 0) + }) + t.Run("lock_close_conn_unlock", func(t *testing.T) { + locker, err := lock.NewPostgresSessionLocker( + lock.WithLockTimeout(4*time.Second), + lock.WithUnlockTimeout(4*time.Second), + ) + check.NoError(t, err) + ctx := context.Background() + conn, err := db.Conn(ctx) + check.NoError(t, err) + + err = locker.SessionLock(ctx, conn) + check.NoError(t, err) + pgLocks, err := queryPgLocks(ctx, db) + check.NoError(t, err) + check.Number(t, len(pgLocks), 1) + check.Bool(t, pgLocks[0].granted, true) + check.Equal(t, pgLocks[0].gooseLockID, lock.DefaultLockID) + // Simulate a connection close. + err = conn.Close() + check.NoError(t, err) + // Check an error is returned when unlocking, because the connection is already closed. + err = locker.SessionUnlock(ctx, conn) + check.HasError(t, err) + check.Bool(t, errors.Is(err, sql.ErrConnDone), true) + }) + t.Run("multiple_connections", func(t *testing.T) { + const ( + workers = 5 + ) + ch := make(chan error) + var wg sync.WaitGroup + for i := 0; i < workers; i++ { + wg.Add(1) + + go func() { + defer wg.Done() + ctx := context.Background() + conn, err := db.Conn(ctx) + check.NoError(t, err) + t.Cleanup(func() { + check.NoError(t, conn.Close()) + }) + // Exactly one connection should acquire the lock. While the other connections + // should fail to acquire the lock and timeout. + locker, err := lock.NewPostgresSessionLocker( + lock.WithLockTimeout(4*time.Second), + lock.WithUnlockTimeout(4*time.Second), + ) + check.NoError(t, err) + ch <- locker.SessionLock(ctx, conn) + }() + } + go func() { + wg.Wait() + close(ch) + }() + var errors []error + for err := range ch { + if err != nil { + errors = append(errors, err) + } + } + check.Equal(t, len(errors), workers-1) // One worker succeeds, the rest fail. + for _, err := range errors { + check.HasError(t, err) + check.Equal(t, err.Error(), "failed to acquire lock") + } + pgLocks, err := queryPgLocks(context.Background(), db) + check.NoError(t, err) + check.Number(t, len(pgLocks), 1) + check.Bool(t, pgLocks[0].granted, true) + check.Equal(t, pgLocks[0].gooseLockID, lock.DefaultLockID) + }) + t.Run("unlock_with_different_connection", func(t *testing.T) { + ctx := context.Background() + const ( + lockID int64 = 999 + ) + locker, err := lock.NewPostgresSessionLocker( + lock.WithLockID(lockID), + lock.WithLockTimeout(4*time.Second), + lock.WithUnlockTimeout(4*time.Second), + ) + check.NoError(t, err) + + conn1, err := db.Conn(ctx) + check.NoError(t, err) + t.Cleanup(func() { + check.NoError(t, conn1.Close()) + }) + err = locker.SessionLock(ctx, conn1) + check.NoError(t, err) + pgLocks, err := queryPgLocks(ctx, db) + check.NoError(t, err) + check.Number(t, len(pgLocks), 1) + check.Bool(t, pgLocks[0].granted, true) + check.Equal(t, pgLocks[0].gooseLockID, lockID) + // Unlock with a different connection. + conn2, err := db.Conn(ctx) + check.NoError(t, err) + t.Cleanup(func() { + check.NoError(t, conn2.Close()) + }) + // Check an error is returned when unlocking with a different connection. 
+ err = locker.SessionUnlock(ctx, conn2) + check.HasError(t, err) + }) +} + +type pgLock struct { + pid int + granted bool + gooseLockID int64 +} + +func queryPgLocks(ctx context.Context, db *sql.DB) ([]pgLock, error) { + q := `SELECT pid,granted,((classid::bigint<<32)|objid::bigint)AS goose_lock_id FROM pg_locks WHERE locktype='advisory'` + rows, err := db.QueryContext(ctx, q) + if err != nil { + return nil, err + } + var pgLocks []pgLock + for rows.Next() { + var p pgLock + if err = rows.Scan(&p.pid, &p.granted, &p.gooseLockID); err != nil { + return nil, err + } + pgLocks = append(pgLocks, p) + } + if err := rows.Err(); err != nil { + return nil, err + } + return pgLocks, nil +} diff --git a/lock/session_locker.go b/lock/session_locker.go new file mode 100644 index 000000000..b74187829 --- /dev/null +++ b/lock/session_locker.go @@ -0,0 +1,23 @@ +// Package lock defines the Locker interface and implements the locking logic. +package lock + +import ( + "context" + "database/sql" + "errors" +) + +var ( + // ErrLockNotImplemented is returned when the database does not support locking. + ErrLockNotImplemented = errors.New("lock not implemented") + // ErrUnlockNotImplemented is returned when the database does not support unlocking. + ErrUnlockNotImplemented = errors.New("unlock not implemented") +) + +// SessionLocker is the interface to lock and unlock the database for the duration of a session. The +// session is defined as the duration of a single connection and both methods must be called on the +// same connection. +type SessionLocker interface { + SessionLock(ctx context.Context, conn *sql.Conn) error + SessionUnlock(ctx context.Context, conn *sql.Conn) error +} diff --git a/lock/session_locker_options.go b/lock/session_locker_options.go new file mode 100644 index 000000000..c3e42151c --- /dev/null +++ b/lock/session_locker_options.go @@ -0,0 +1,63 @@ +package lock + +import ( + "time" +) + +const ( + // DefaultLockID is the id used to lock the database for migrations. It is a crc64 hash of the + // string "goose". This is used to ensure that the lock is unique to goose. + // + // crc64.Checksum([]byte("goose"), crc64.MakeTable(crc64.ECMA)) + DefaultLockID int64 = 5887940537704921958 + + // Default values for the lock (time to wait for the lock to be acquired) and unlock (time to + // wait for the lock to be released) wait durations. + DefaultLockTimeout time.Duration = 60 * time.Minute + DefaultUnlockTimeout time.Duration = 1 * time.Minute +) + +// SessionLockerOption is used to configure a SessionLocker. +type SessionLockerOption interface { + apply(*sessionLockerConfig) error +} + +// WithLockID sets the lock ID to use when locking the database. +// +// If WithLockID is not called, the DefaultLockID is used. +func WithLockID(lockID int64) SessionLockerOption { + return sessionLockerConfigFunc(func(c *sessionLockerConfig) error { + c.lockID = lockID + return nil + }) +} + +// WithLockTimeout sets the max duration to wait for the lock to be acquired. +func WithLockTimeout(duration time.Duration) SessionLockerOption { + return sessionLockerConfigFunc(func(c *sessionLockerConfig) error { + c.lockTimeout = duration + return nil + }) +} + +// WithUnlockTimeout sets the max duration to wait for the lock to be released. 
+func WithUnlockTimeout(duration time.Duration) SessionLockerOption { + return sessionLockerConfigFunc(func(c *sessionLockerConfig) error { + c.unlockTimeout = duration + return nil + }) +} + +type sessionLockerConfig struct { + lockID int64 + lockTimeout time.Duration + unlockTimeout time.Duration +} + +var _ SessionLockerOption = (sessionLockerConfigFunc)(nil) + +type sessionLockerConfigFunc func(*sessionLockerConfig) error + +func (f sessionLockerConfigFunc) apply(cfg *sessionLockerConfig) error { + return f(cfg) +} diff --git a/provider_options.go b/provider_options.go index 904b3ed34..2370486f9 100644 --- a/provider_options.go +++ b/provider_options.go @@ -3,6 +3,8 @@ package goose import ( "errors" "fmt" + + "github.com/pressly/goose/v3/lock" ) const ( @@ -38,13 +40,36 @@ func WithVerbose() ProviderOption { }) } +// WithSessionLocker enables locking using the provided SessionLocker. +// +// If WithSessionLocker is not called, locking is disabled. +func WithSessionLocker(locker lock.SessionLocker) ProviderOption { + return configFunc(func(c *config) error { + if c.lockEnabled { + return errors.New("lock already enabled") + } + if c.sessionLocker != nil { + return errors.New("session locker already set") + } + if locker == nil { + return errors.New("session locker must not be nil") + } + c.lockEnabled = true + c.sessionLocker = locker + return nil + }) +} + type config struct { tableName string verbose bool + + lockEnabled bool + sessionLocker lock.SessionLocker } type configFunc func(*config) error -func (o configFunc) apply(cfg *config) error { - return o(cfg) +func (f configFunc) apply(cfg *config) error { + return f(cfg) } From e696fa3ba534f3766f50b2e9789d537316017bc4 Mon Sep 17 00:00:00 2001 From: Mike Fridman Date: Tue, 10 Oct 2023 08:59:05 -0400 Subject: [PATCH 04/10] feat(experimental): move Provider to an internal package --- provider.go => internal/provider/provider.go | 6 +++--- .../provider/provider_options.go | 2 +- .../provider/provider_options_test.go | 21 +++++++++---------- 3 files changed, 14 insertions(+), 15 deletions(-) rename provider.go => internal/provider/provider.go (97%) rename provider_options.go => internal/provider/provider_options.go (98%) rename provider_options_test.go => internal/provider/provider_options_test.go (76%) diff --git a/provider.go b/internal/provider/provider.go similarity index 97% rename from provider.go rename to internal/provider/provider.go index c12d4ea8b..c8d899511 100644 --- a/provider.go +++ b/internal/provider/provider.go @@ -1,4 +1,4 @@ -package goose +package provider import ( "context" @@ -25,7 +25,7 @@ import ( // Unless otherwise specified, all methods on Provider are safe for concurrent use. // // Experimental: This API is experimental and may change in the future. 
-func NewProvider(dialect Dialect, db *sql.DB, fsys fs.FS, opts ...ProviderOption) (*Provider, error) { +func NewProvider(dialect string, db *sql.DB, fsys fs.FS, opts ...ProviderOption) (*Provider, error) { if db == nil { return nil, errors.New("db must not be nil") } @@ -45,7 +45,7 @@ func NewProvider(dialect Dialect, db *sql.DB, fsys fs.FS, opts ...ProviderOption if cfg.tableName == "" { cfg.tableName = defaultTablename } - store, err := sqladapter.NewStore(string(dialect), cfg.tableName) + store, err := sqladapter.NewStore(dialect, cfg.tableName) if err != nil { return nil, err } diff --git a/provider_options.go b/internal/provider/provider_options.go similarity index 98% rename from provider_options.go rename to internal/provider/provider_options.go index 2370486f9..bf7b9f9b2 100644 --- a/provider_options.go +++ b/internal/provider/provider_options.go @@ -1,4 +1,4 @@ -package goose +package provider import ( "errors" diff --git a/provider_options_test.go b/internal/provider/provider_options_test.go similarity index 76% rename from provider_options_test.go rename to internal/provider/provider_options_test.go index 629c6efaa..341735401 100644 --- a/provider_options_test.go +++ b/internal/provider/provider_options_test.go @@ -1,4 +1,4 @@ -package goose_test +package provider import ( "database/sql" @@ -7,7 +7,6 @@ import ( "testing" "testing/fstest" - "github.com/pressly/goose/v3" "github.com/pressly/goose/v3/internal/check" ) @@ -18,35 +17,35 @@ func TestNewProvider(t *testing.T) { fsys := newFsys() t.Run("invalid", func(t *testing.T) { // Empty dialect not allowed - _, err = goose.NewProvider("", db, fsys) + _, err = NewProvider("", db, fsys) check.HasError(t, err) // Invalid dialect not allowed - _, err = goose.NewProvider("unknown-dialect", db, fsys) + _, err = NewProvider("unknown-dialect", db, fsys) check.HasError(t, err) // Nil db not allowed - _, err = goose.NewProvider("sqlite3", nil, fsys) + _, err = NewProvider("sqlite3", nil, fsys) check.HasError(t, err) // Nil fsys not allowed - _, err = goose.NewProvider("sqlite3", db, nil) + _, err = NewProvider("sqlite3", db, nil) check.HasError(t, err) // Duplicate table name not allowed - _, err = goose.NewProvider("sqlite3", db, fsys, goose.WithTableName("foo"), goose.WithTableName("bar")) + _, err = NewProvider("sqlite3", db, fsys, WithTableName("foo"), WithTableName("bar")) check.HasError(t, err) check.Equal(t, `table already set to "foo"`, err.Error()) // Empty table name not allowed - _, err = goose.NewProvider("sqlite3", db, fsys, goose.WithTableName("")) + _, err = NewProvider("sqlite3", db, fsys, WithTableName("")) check.HasError(t, err) check.Equal(t, "table must not be empty", err.Error()) }) t.Run("valid", func(t *testing.T) { // Valid dialect, db, and fsys allowed - _, err = goose.NewProvider("sqlite3", db, fsys) + _, err = NewProvider("sqlite3", db, fsys) check.NoError(t, err) // Valid dialect, db, fsys, and table name allowed - _, err = goose.NewProvider("sqlite3", db, fsys, goose.WithTableName("foo")) + _, err = NewProvider("sqlite3", db, fsys, WithTableName("foo")) check.NoError(t, err) // Valid dialect, db, fsys, and verbose allowed - _, err = goose.NewProvider("sqlite3", db, fsys, goose.WithVerbose()) + _, err = NewProvider("sqlite3", db, fsys, WithVerbose()) check.NoError(t, err) }) } From 091166a0b573e240700eecfd7a9c26b994effebb Mon Sep 17 00:00:00 2001 From: Mike Fridman Date: Tue, 10 Oct 2023 09:24:22 -0400 Subject: [PATCH 05/10] Release v3.15.1 --- CHANGELOG.md | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 
1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5c5a9e2e7..2f7fc1e87 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,18 @@ adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). ## [Unreleased] +## [v3.15.1] - 2023-10-10 + +- Fix regression that prevented registering Go migrations that didn't have the corresponding files + available in the filesystem. (#588) + - If Go migrations have been registered globally, but there are no .go files in the filesystem, + **always include** them. + - If Go migrations have been registered, and there are .go files in the filesystem, **only + include** those migrations. This was the original motivation behind #553. + - If there are .go files in the filesystem but not registered, **raise an error**. This is to + prevent accidentally adding valid looking Go migration files without explicitly registering + them. + ## [v3.15.0] - 2023-08-12 - Fix `sqlparser` to avoid skipping the last statement when it's not terminated with a semicolon @@ -49,7 +61,8 @@ adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). - Add new `context.Context`-aware functions and methods, for both sql and go migrations. - Return error when no migration files found or dir is not a directory. -[Unreleased]: https://github.com/pressly/goose/compare/v3.15.0...HEAD +[Unreleased]: https://github.com/pressly/goose/compare/v3.15.1...HEAD +[v3.15.1]: https://github.com/pressly/goose/compare/v3.15.0...v3.15.1 [v3.15.0]: https://github.com/pressly/goose/compare/v3.14.0...v3.15.0 [v3.14.0]: https://github.com/pressly/goose/compare/v3.13.4...v3.14.0 [v3.13.4]: https://github.com/pressly/goose/compare/v3.13.1...v3.13.4 From 3482c2fe082d15643de9ff31fb904f3da78066cc Mon Sep 17 00:00:00 2001 From: Mike Fridman Date: Tue, 10 Oct 2023 09:29:24 -0400 Subject: [PATCH 06/10] fix(test): sql driver import --- internal/provider/provider_options_test.go | 1 + 1 file changed, 1 insertion(+) diff --git a/internal/provider/provider_options_test.go b/internal/provider/provider_options_test.go index 341735401..82362bad1 100644 --- a/internal/provider/provider_options_test.go +++ b/internal/provider/provider_options_test.go @@ -8,6 +8,7 @@ import ( "testing/fstest" "github.com/pressly/goose/v3/internal/check" + _ "modernc.org/sqlite" ) func TestNewProvider(t *testing.T) { From fe8fe975d85f26fd50450159d61325f13e9d9b81 Mon Sep 17 00:00:00 2001 From: Grey Date: Fri, 13 Oct 2023 12:29:12 -0400 Subject: [PATCH 07/10] chore(readme): update install code fences (#614) --- README.md | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 94450778f..81dc24e31 100644 --- a/README.md +++ b/README.md @@ -35,17 +35,23 @@ Goose supports [embedding SQL migrations](#embedded-sql-migrations), which means # Install - $ go install github.com/pressly/goose/v3/cmd/goose@latest +```shell +go install github.com/pressly/goose/v3/cmd/goose@latest +``` This will install the `goose` binary to your `$GOPATH/bin` directory. For a lite version of the binary without DB connection dependent commands, use the exclusive build tags: - $ go build -tags='no_postgres no_mysql no_sqlite3' -o goose ./cmd/goose +```shell +go build -tags='no_postgres no_mysql no_sqlite3' -o goose ./cmd/goose +``` For macOS users `goose` is available as a [Homebrew Formulae](https://formulae.brew.sh/formula/goose#default): - $ brew install goose +```shell +brew install goose +``` See the docs for more [installation instructions](https://pressly.github.io/goose/installation/). 
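
For reference, the embedded-migrations support mentioned in the README hunk above is typically wired up from library code roughly as follows. This is a minimal sketch, not part of the patch series: the driver import, DSN, and `migrations` directory name are illustrative placeholders, while `SetBaseFS`, `SetDialect`, and `Up` are the existing goose/v3 entry points.

```go
package main

import (
	"database/sql"
	"embed"
	"log"

	_ "github.com/jackc/pgx/v5/stdlib" // assumed driver; any database/sql driver works
	"github.com/pressly/goose/v3"
)

//go:embed migrations/*.sql
var embedMigrations embed.FS

func main() {
	// The DSN and the "migrations" directory name are illustrative placeholders.
	db, err := sql.Open("pgx", "postgres://localhost:5432/example?sslmode=disable")
	if err != nil {
		log.Fatal(err)
	}
	defer db.Close()

	// Read migration files from the embedded filesystem rather than from local disk.
	goose.SetBaseFS(embedMigrations)
	if err := goose.SetDialect("postgres"); err != nil {
		log.Fatal(err)
	}
	if err := goose.Up(db, "migrations"); err != nil {
		log.Fatal(err)
	}
}
```

Note that the `no_postgres no_mysql no_sqlite3` build tags shown above only trim the drivers compiled into the CLI binary; a library consumer such as this sketch always supplies its own database/sql driver.
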
From 68853f91ea3f310637c73df806e23e29ee146c13 Mon Sep 17 00:00:00 2001 From: Michael Fridman Date: Sat, 14 Oct 2023 09:30:01 -0400 Subject: [PATCH 08/10] feat(experimental): add collect migrations logic and new Provider options (#615) --- Makefile | 3 + create_test.go | 3 + fix_test.go | 3 + .../migrationstats/migrationstats_test.go | 2 + internal/provider/collect.go | 176 +++++++++++++++++ internal/provider/collect_test.go | 185 ++++++++++++++++++ internal/provider/provider.go | 128 ++++++++---- internal/provider/provider_options.go | 14 ++ internal/provider/provider_options_test.go | 86 +++----- internal/provider/provider_test.go | 83 ++++++++ lock/postgres_test.go | 7 +- migration.go | 28 +-- 12 files changed, 600 insertions(+), 118 deletions(-) create mode 100644 internal/provider/collect.go create mode 100644 internal/provider/collect_test.go create mode 100644 internal/provider/provider_test.go diff --git a/Makefile b/Makefile index 46b444fbc..8a3a05fb1 100644 --- a/Makefile +++ b/Makefile @@ -33,6 +33,9 @@ tools: test-packages: go test $(GO_TEST_FLAGS) $$(go list ./... | grep -v -e /tests -e /bin -e /cmd -e /examples) +test-packages-short: + go test -test.short $(GO_TEST_FLAGS) $$(go list ./... | grep -v -e /tests -e /bin -e /cmd -e /examples) + test-e2e: test-e2e-postgres test-e2e-mysql test-e2e-clickhouse test-e2e-vertica test-e2e-postgres: diff --git a/create_test.go b/create_test.go index fddf48d85..34791cc65 100644 --- a/create_test.go +++ b/create_test.go @@ -11,6 +11,9 @@ import ( func TestSequential(t *testing.T) { t.Parallel() + if testing.Short() { + t.Skip("skip long running test") + } dir := t.TempDir() defer os.Remove("./bin/create-goose") // clean up diff --git a/fix_test.go b/fix_test.go index 5c982dbe8..6a5e0842b 100644 --- a/fix_test.go +++ b/fix_test.go @@ -11,6 +11,9 @@ import ( func TestFix(t *testing.T) { t.Parallel() + if testing.Short() { + t.Skip("skip long running test") + } dir := t.TempDir() defer os.Remove("./bin/fix-goose") // clean up diff --git a/internal/migrationstats/migrationstats_test.go b/internal/migrationstats/migrationstats_test.go index 67a65a3cf..26c49fd38 100644 --- a/internal/migrationstats/migrationstats_test.go +++ b/internal/migrationstats/migrationstats_test.go @@ -8,6 +8,7 @@ import ( ) func TestParsingGoMigrations(t *testing.T) { + t.Parallel() tests := []struct { name string input string @@ -38,6 +39,7 @@ func TestParsingGoMigrations(t *testing.T) { } func TestParsingGoMigrationsError(t *testing.T) { + t.Parallel() _, err := parseGoFile(strings.NewReader(emptyInit)) check.HasError(t, err) check.Contains(t, err.Error(), "no registered goose functions") diff --git a/internal/provider/collect.go b/internal/provider/collect.go new file mode 100644 index 000000000..cf12961fb --- /dev/null +++ b/internal/provider/collect.go @@ -0,0 +1,176 @@ +package provider + +import ( + "errors" + "fmt" + "io/fs" + "path/filepath" + "sort" + "strings" + + "github.com/pressly/goose/v3" + "github.com/pressly/goose/v3/internal/migrate" +) + +// fileSources represents a collection of migration files on the filesystem. +type fileSources struct { + sqlSources []Source + goSources []Source +} + +// collectFileSources scans the file system for migration files that have a numeric prefix (greater +// than one) followed by an underscore and a file extension of either .go or .sql. fsys may be nil, +// in which case an empty fileSources is returned. +// +// If strict is true, then any error parsing the numeric component of the filename will result in an +// error. 
The file is skipped otherwise. +// +// This function DOES NOT parse SQL migrations or merge registered Go migrations. It only collects +// migration sources from the filesystem. +func collectFileSources(fsys fs.FS, strict bool, excludes map[string]bool) (*fileSources, error) { + if fsys == nil { + return new(fileSources), nil + } + sources := new(fileSources) + versionToBaseLookup := make(map[int64]string) // map[version]filepath.Base(fullpath) + for _, pattern := range []string{ + "*.sql", + "*.go", + } { + files, err := fs.Glob(fsys, pattern) + if err != nil { + return nil, fmt.Errorf("failed to glob pattern %q: %w", pattern, err) + } + for _, fullpath := range files { + base := filepath.Base(fullpath) + // Skip explicit excludes or Go test files. + if excludes[base] || strings.HasSuffix(base, "_test.go") { + continue + } + // If the filename has a valid looking version of the form: NUMBER_.{sql,go}, then use + // that as the version. Otherwise, ignore it. This allows users to have arbitrary + // filenames, but still have versioned migrations within the same directory. For + // example, a user could have a helpers.go file which contains unexported helper + // functions for migrations. + version, err := goose.NumericComponent(base) + if err != nil { + if strict { + return nil, fmt.Errorf("failed to parse numeric component from %q: %w", base, err) + } + continue + } + // Ensure there are no duplicate versions. + if existing, ok := versionToBaseLookup[version]; ok { + return nil, fmt.Errorf("found duplicate migration version %d:\n\texisting:%v\n\tcurrent:%v", + version, + existing, + base, + ) + } + switch filepath.Ext(base) { + case ".sql": + sources.sqlSources = append(sources.sqlSources, Source{ + Fullpath: fullpath, + Version: version, + }) + case ".go": + sources.goSources = append(sources.goSources, Source{ + Fullpath: fullpath, + Version: version, + }) + default: + // Should never happen since we already filtered out all other file types. + return nil, fmt.Errorf("unknown migration type: %s", base) + } + // Add the version to the lookup map. + versionToBaseLookup[version] = base + } + } + return sources, nil +} + +func merge(sources *fileSources, registerd map[int64]*goose.Migration) ([]*migrate.Migration, error) { + var migrations []*migrate.Migration + migrationLookup := make(map[int64]*migrate.Migration) + // Add all SQL migrations to the list of migrations. + for _, s := range sources.sqlSources { + m := &migrate.Migration{ + Type: migrate.TypeSQL, + Fullpath: s.Fullpath, + Version: s.Version, + SQLParsed: false, + } + migrations = append(migrations, m) + migrationLookup[s.Version] = m + } + // If there are no Go files in the filesystem and no registered Go migrations, return early. + if len(sources.goSources) == 0 && len(registerd) == 0 { + return migrations, nil + } + // Return an error if the given sources contain a versioned Go migration that has not been + // registered. This is a sanity check to ensure users didn't accidentally create a valid looking + // Go migration file on disk and forget to register it. + // + // This is almost always a user error. + var unregistered []string + for _, s := range sources.goSources { + if _, ok := registerd[s.Version]; !ok { + unregistered = append(unregistered, s.Fullpath) + } + } + if len(unregistered) > 0 { + return nil, unregisteredError(unregistered) + } + // Add all registered Go migrations to the list of migrations, checking for duplicate versions. 
+ // + // Important, users can register Go migrations manually via goose.Add_ functions. These + // migrations may not have a corresponding file on disk. Which is fine! We include them + // wholesale as part of migrations. This allows users to build a custom binary that only embeds + // the SQL migration files. + for _, r := range registerd { + // Ensure there are no duplicate versions. + if existing, ok := migrationLookup[r.Version]; ok { + return nil, fmt.Errorf("found duplicate migration version %d:\n\texisting:%v\n\tcurrent:%v", + r.Version, + existing, + filepath.Base(r.Source), + ) + } + m := &migrate.Migration{ + Fullpath: r.Source, // May be empty if the migration was registered manually. + Version: r.Version, + Type: migrate.TypeGo, + Go: &migrate.Go{ + UseTx: r.UseTx, + UpFn: r.UpFnContext, + UpFnNoTx: r.UpFnNoTxContext, + DownFn: r.DownFnContext, + DownFnNoTx: r.DownFnNoTxContext, + }, + } + migrations = append(migrations, m) + migrationLookup[r.Version] = m + } + // Sort migrations by version in ascending order. + sort.Slice(migrations, func(i, j int) bool { + return migrations[i].Version < migrations[j].Version + }) + return migrations, nil +} + +func unregisteredError(unregistered []string) error { + f := "file" + if len(unregistered) > 1 { + f += "s" + } + var b strings.Builder + + b.WriteString(fmt.Sprintf("error: detected %d unregistered Go %s:\n", len(unregistered), f)) + for _, name := range unregistered { + b.WriteString("\t" + name + "\n") + } + b.WriteString("\n") + b.WriteString("go functions must be registered and built into a custom binary see:\nhttps://github.com/pressly/goose/tree/master/examples/go-migrations") + + return errors.New(b.String()) +} diff --git a/internal/provider/collect_test.go b/internal/provider/collect_test.go new file mode 100644 index 000000000..a5ee2d352 --- /dev/null +++ b/internal/provider/collect_test.go @@ -0,0 +1,185 @@ +package provider + +import ( + "io/fs" + "testing" + "testing/fstest" + + "github.com/pressly/goose/v3/internal/check" +) + +func TestCollectFileSources(t *testing.T) { + t.Parallel() + t.Run("nil", func(t *testing.T) { + sources, err := collectFileSources(nil, false, nil) + check.NoError(t, err) + check.Bool(t, sources != nil, true) + check.Number(t, len(sources.goSources), 0) + check.Number(t, len(sources.sqlSources), 0) + }) + t.Run("empty", func(t *testing.T) { + sources, err := collectFileSources(fstest.MapFS{}, false, nil) + check.NoError(t, err) + check.Number(t, len(sources.goSources), 0) + check.Number(t, len(sources.sqlSources), 0) + check.Bool(t, sources != nil, true) + }) + t.Run("incorrect_fsys", func(t *testing.T) { + mapFS := fstest.MapFS{ + "00000_foo.sql": sqlMapFile, + } + // strict disable - should not error + sources, err := collectFileSources(mapFS, false, nil) + check.NoError(t, err) + check.Number(t, len(sources.goSources), 0) + check.Number(t, len(sources.sqlSources), 0) + // strict enabled - should error + _, err = collectFileSources(mapFS, true, nil) + check.HasError(t, err) + check.Contains(t, err.Error(), "migration version must be greater than zero") + }) + t.Run("collect", func(t *testing.T) { + fsys, err := fs.Sub(newSQLOnlyFS(), "migrations") + check.NoError(t, err) + sources, err := collectFileSources(fsys, false, nil) + check.NoError(t, err) + check.Number(t, len(sources.sqlSources), 4) + check.Number(t, len(sources.goSources), 0) + expected := fileSources{ + sqlSources: []Source{ + {Fullpath: "00001_foo.sql", Version: 1}, + {Fullpath: "00002_bar.sql", Version: 2}, + {Fullpath: 
"00003_baz.sql", Version: 3}, + {Fullpath: "00110_qux.sql", Version: 110}, + }, + } + for i := 0; i < len(sources.sqlSources); i++ { + check.Equal(t, sources.sqlSources[i], expected.sqlSources[i]) + } + }) + t.Run("excludes", func(t *testing.T) { + fsys, err := fs.Sub(newSQLOnlyFS(), "migrations") + check.NoError(t, err) + sources, err := collectFileSources( + fsys, + false, + // exclude 2 files explicitly + map[string]bool{ + "00002_bar.sql": true, + "00110_qux.sql": true, + }, + ) + check.NoError(t, err) + check.Number(t, len(sources.sqlSources), 2) + check.Number(t, len(sources.goSources), 0) + expected := fileSources{ + sqlSources: []Source{ + {Fullpath: "00001_foo.sql", Version: 1}, + {Fullpath: "00003_baz.sql", Version: 3}, + }, + } + for i := 0; i < len(sources.sqlSources); i++ { + check.Equal(t, sources.sqlSources[i], expected.sqlSources[i]) + } + }) + t.Run("strict", func(t *testing.T) { + mapFS := newSQLOnlyFS() + // Add a file with no version number + mapFS["migrations/not_valid.sql"] = &fstest.MapFile{Data: []byte("invalid")} + fsys, err := fs.Sub(mapFS, "migrations") + check.NoError(t, err) + _, err = collectFileSources(fsys, true, nil) + check.HasError(t, err) + check.Contains(t, err.Error(), `failed to parse numeric component from "not_valid.sql"`) + }) + t.Run("skip_go_test_files", func(t *testing.T) { + mapFS := fstest.MapFS{ + "1_foo.sql": sqlMapFile, + "2_bar.sql": sqlMapFile, + "3_baz.sql": sqlMapFile, + "4_qux.sql": sqlMapFile, + "5_foo_test.go": {Data: []byte(`package goose_test`)}, + } + sources, err := collectFileSources(mapFS, false, nil) + check.NoError(t, err) + check.Number(t, len(sources.sqlSources), 4) + check.Number(t, len(sources.goSources), 0) + }) + t.Run("skip_random_files", func(t *testing.T) { + mapFS := fstest.MapFS{ + "1_foo.sql": sqlMapFile, + "4_something.go": {Data: []byte(`package goose`)}, + "5_qux.sql": sqlMapFile, + "README.md": {Data: []byte(`# README`)}, + "LICENSE": {Data: []byte(`MIT`)}, + "no_a_real_migration.sql": {Data: []byte(`SELECT 1;`)}, + "some/other/dir/2_foo.sql": {Data: []byte(`SELECT 1;`)}, + } + sources, err := collectFileSources(mapFS, false, nil) + check.NoError(t, err) + check.Number(t, len(sources.sqlSources), 2) + check.Number(t, len(sources.goSources), 1) + // 1 + check.Equal(t, sources.sqlSources[0].Fullpath, "1_foo.sql") + check.Equal(t, sources.sqlSources[0].Version, int64(1)) + // 2 + check.Equal(t, sources.sqlSources[1].Fullpath, "5_qux.sql") + check.Equal(t, sources.sqlSources[1].Version, int64(5)) + // 3 + check.Equal(t, sources.goSources[0].Fullpath, "4_something.go") + check.Equal(t, sources.goSources[0].Version, int64(4)) + }) + t.Run("duplicate_versions", func(t *testing.T) { + mapFS := fstest.MapFS{ + "001_foo.sql": sqlMapFile, + "01_bar.sql": sqlMapFile, + } + _, err := collectFileSources(mapFS, false, nil) + check.HasError(t, err) + check.Contains(t, err.Error(), "found duplicate migration version 1") + }) + t.Run("dirpath", func(t *testing.T) { + mapFS := fstest.MapFS{ + "dir1/101_a.sql": sqlMapFile, + "dir1/102_b.sql": sqlMapFile, + "dir1/103_c.sql": sqlMapFile, + "dir2/201_a.sql": sqlMapFile, + "876_a.sql": sqlMapFile, + } + assertDirpath := func(dirpath string, sqlSources []Source) { + t.Helper() + f, err := fs.Sub(mapFS, dirpath) + check.NoError(t, err) + got, err := collectFileSources(f, false, nil) + check.NoError(t, err) + check.Number(t, len(got.sqlSources), len(sqlSources)) + check.Number(t, len(got.goSources), 0) + for i := 0; i < len(got.sqlSources); i++ { + check.Equal(t, got.sqlSources[i], 
sqlSources[i]) + } + } + assertDirpath(".", []Source{ + {Fullpath: "876_a.sql", Version: 876}, + }) + assertDirpath("dir1", []Source{ + {Fullpath: "101_a.sql", Version: 101}, + {Fullpath: "102_b.sql", Version: 102}, + {Fullpath: "103_c.sql", Version: 103}, + }) + assertDirpath("dir2", []Source{{Fullpath: "201_a.sql", Version: 201}}) + assertDirpath("dir3", nil) + }) +} + +func newSQLOnlyFS() fstest.MapFS { + return fstest.MapFS{ + "migrations/00001_foo.sql": sqlMapFile, + "migrations/00002_bar.sql": sqlMapFile, + "migrations/00003_baz.sql": sqlMapFile, + "migrations/00110_qux.sql": sqlMapFile, + } +} + +var ( + sqlMapFile = &fstest.MapFile{Data: []byte(`-- +goose Up`)} +) diff --git a/internal/provider/provider.go b/internal/provider/provider.go index c8d899511..6702f0731 100644 --- a/internal/provider/provider.go +++ b/internal/provider/provider.go @@ -5,22 +5,29 @@ import ( "database/sql" "errors" "io/fs" + "os" "time" + "github.com/pressly/goose/v3/internal/migrate" "github.com/pressly/goose/v3/internal/sqladapter" ) +var ( + // ErrNoMigrations is returned by [NewProvider] when no migrations are found. + ErrNoMigrations = errors.New("no migrations found") +) + // NewProvider returns a new goose Provider. // // The caller is responsible for matching the database dialect with the database/sql driver. For // example, if the database dialect is "postgres", the database/sql driver could be // github.com/lib/pq or github.com/jackc/pgx. // -// fsys is the filesystem used to read the migration files. Most users will want to use -// os.DirFS("path/to/migrations") to read migrations from the local filesystem. However, it is -// possible to use a different filesystem, such as embed.FS. +// fsys is the filesystem used to read the migration files, but may be nil. Most users will want to +// use os.DirFS("path/to/migrations") to read migrations from the local filesystem. However, it is +// possible to use a different filesystem, such as embed.FS or filter out migrations using fs.Sub. // -// Functional options are used to configure the Provider. See [ProviderOption] for more information. +// See [ProviderOption] for more information on configuring the provider. // // Unless otherwise specified, all methods on Provider are safe for concurrent use. // @@ -33,7 +40,7 @@ func NewProvider(dialect string, db *sql.DB, fsys fs.FS, opts ...ProviderOption) return nil, errors.New("dialect must not be empty") } if fsys == nil { - return nil, errors.New("fsys must not be nil") + fsys = noopFS{} } var cfg config for _, opt := range opts { @@ -41,7 +48,7 @@ func NewProvider(dialect string, db *sql.DB, fsys fs.FS, opts ...ProviderOption) return nil, err } } - // Set defaults + // Set defaults after applying user-supplied options so option funcs can check for empty values. if cfg.tableName == "" { cfg.tableName = defaultTablename } @@ -49,41 +56,76 @@ func NewProvider(dialect string, db *sql.DB, fsys fs.FS, opts ...ProviderOption) if err != nil { return nil, err } - // TODO(mf): implement the rest of this function - collect sources - merge sources into - // migrations + // Collect migrations from the filesystem and merge with registered migrations. + // + // Note, neither of these functions parse SQL migrations by default. SQL migrations are parsed + // lazily. + // + // TODO(mf): we should expose a way to parse SQL migrations eagerly. This would allow us to + // return an error if there are any SQL parsing errors. This adds a bit overhead to startup + // though, so we should make it optional. 
+ sources, err := collectFileSources(fsys, false, cfg.excludes) + if err != nil { + return nil, err + } + migrations, err := merge(sources, nil) + if err != nil { + return nil, err + } + if len(migrations) == 0 { + return nil, ErrNoMigrations + } return &Provider{ - db: db, - fsys: fsys, - cfg: cfg, - store: store, + db: db, + fsys: fsys, + cfg: cfg, + store: store, + migrations: migrations, }, nil } +type noopFS struct{} + +var _ fs.FS = noopFS{} + +func (f noopFS) Open(name string) (fs.File, error) { + return nil, os.ErrNotExist +} + // Provider is a goose migration provider. -// Experimental: This API is experimental and may change in the future. type Provider struct { - db *sql.DB - fsys fs.FS - cfg config - store sqladapter.Store + db *sql.DB + fsys fs.FS + cfg config + store sqladapter.Store + migrations []*migrate.Migration } +// State represents the state of a migration. +type State string + +const ( + // StateUntracked represents a migration that is in the database, but not on the filesystem. + StateUntracked State = "untracked" + // StatePending represents a migration that is on the filesystem, but not in the database. + StatePending State = "pending" + // StateApplied represents a migration that is in BOTH the database and on the filesystem. + StateApplied State = "applied" +) + // MigrationStatus represents the status of a single migration. type MigrationStatus struct { - // State represents the state of the migration. One of "untracked", "pending", "applied". - // - untracked: in the database, but not on the filesystem. - // - pending: on the filesystem, but not in the database. - // - applied: in both the database and on the filesystem. - State string - // AppliedAt is the time the migration was applied. Only set if state is applied or untracked. + // State is the state of the migration. + State State + // AppliedAt is the time the migration was applied. Only set if state is [StateApplied] or + // [StateUntracked]. AppliedAt time.Time - // Source is the migration source. Only set if the state is pending or applied. - Source Source + // Source is the migration source. Only set if the state is [StatePending] or [StateApplied]. + Source *Source } // Status returns the status of all migrations, merging the list of migrations from the database and // filesystem. The returned items are ordered by version, in ascending order. -// Experimental: This API is experimental and may change in the future. func (p *Provider) Status(ctx context.Context) ([]*MigrationStatus, error) { return nil, errors.New("not implemented") } @@ -91,7 +133,6 @@ func (p *Provider) Status(ctx context.Context) ([]*MigrationStatus, error) { // GetDBVersion returns the max version from the database, regardless of the applied order. For // example, if migrations 1,4,2,3 were applied, this method returns 4. If no migrations have been // applied, it returns 0. -// Experimental: This API is experimental and may change in the future. func (p *Provider) GetDBVersion(ctx context.Context) (int64, error) { return 0, errors.New("not implemented") } @@ -111,7 +152,6 @@ const ( // For SQL migrations, Fullpath will always be set. For Go migrations, Fullpath will will be set if // the migration has a corresponding file on disk. It will be empty if the migration was registered // manually. -// Experimental: This API is experimental and may change in the future. type Source struct { // Type is the type of migration. 
Type SourceType @@ -123,22 +163,34 @@ type Source struct { Version int64 } -// ListSources returns a list of all available migration sources the provider is aware of. -// Experimental: This API is experimental and may change in the future. +// ListSources returns a list of all available migration sources the provider is aware of, sorted in +// ascending order by version. func (p *Provider) ListSources() []*Source { - return nil + sources := make([]*Source, 0, len(p.migrations)) + for _, m := range p.migrations { + s := &Source{ + Fullpath: m.Fullpath, + Version: m.Version, + } + switch m.Type { + case migrate.TypeSQL: + s.Type = SourceTypeSQL + case migrate.TypeGo: + s.Type = SourceTypeGo + } + sources = append(sources, s) + } + return sources } // Ping attempts to ping the database to verify a connection is available. -// Experimental: This API is experimental and may change in the future. func (p *Provider) Ping(ctx context.Context) error { - return errors.New("not implemented") + return p.db.PingContext(ctx) } // Close closes the database connection. -// Experimental: This API is experimental and may change in the future. func (p *Provider) Close() error { - return errors.New("not implemented") + return p.db.Close() } // MigrationResult represents the result of a single migration. @@ -150,21 +202,18 @@ type MigrationResult struct{} // // When direction is true, the up migration is executed, and when direction is false, the down // migration is executed. -// Experimental: This API is experimental and may change in the future. func (p *Provider) ApplyVersion(ctx context.Context, version int64, direction bool) (*MigrationResult, error) { return nil, errors.New("not implemented") } // Up applies all pending migrations. If there are no new migrations to apply, this method returns // empty list and nil error. -// Experimental: This API is experimental and may change in the future. func (p *Provider) Up(ctx context.Context) ([]*MigrationResult, error) { return nil, errors.New("not implemented") } // UpByOne applies the next available migration. If there are no migrations to apply, this method // returns [ErrNoNextVersion]. -// Experimental: This API is experimental and may change in the future. func (p *Provider) UpByOne(ctx context.Context) (*MigrationResult, error) { return nil, errors.New("not implemented") } @@ -174,14 +223,12 @@ func (p *Provider) UpByOne(ctx context.Context) (*MigrationResult, error) { // // For instance, if there are three new migrations (9,10,11) and the current database version is 8 // with a requested version of 10, only versions 9 and 10 will be applied. -// Experimental: This API is experimental and may change in the future. func (p *Provider) UpTo(ctx context.Context, version int64) ([]*MigrationResult, error) { return nil, errors.New("not implemented") } // Down rolls back the most recently applied migration. If there are no migrations to apply, this // method returns [ErrNoNextVersion]. -// Experimental: This API is experimental and may change in the future. func (p *Provider) Down(ctx context.Context) (*MigrationResult, error) { return nil, errors.New("not implemented") } @@ -190,7 +237,6 @@ func (p *Provider) Down(ctx context.Context) (*MigrationResult, error) { // // For instance, if the current database version is 11, and the requested version is 9, only // migrations 11 and 10 will be rolled back. -// Experimental: This API is experimental and may change in the future. 
func (p *Provider) DownTo(ctx context.Context, version int64) ([]*MigrationResult, error) { return nil, errors.New("not implemented") } diff --git a/internal/provider/provider_options.go b/internal/provider/provider_options.go index bf7b9f9b2..d8060c458 100644 --- a/internal/provider/provider_options.go +++ b/internal/provider/provider_options.go @@ -60,10 +60,24 @@ func WithSessionLocker(locker lock.SessionLocker) ProviderOption { }) } +// WithExcludes excludes the given file names from the list of migrations. +// +// If WithExcludes is called multiple times, the list of excludes is merged. +func WithExcludes(excludes []string) ProviderOption { + return configFunc(func(c *config) error { + for _, name := range excludes { + c.excludes[name] = true + } + return nil + }) +} + type config struct { tableName string verbose bool + excludes map[string]bool + // Locking options lockEnabled bool sessionLocker lock.SessionLocker } diff --git a/internal/provider/provider_options_test.go b/internal/provider/provider_options_test.go index 82362bad1..89a1cda16 100644 --- a/internal/provider/provider_options_test.go +++ b/internal/provider/provider_options_test.go @@ -1,13 +1,13 @@ -package provider +package provider_test import ( "database/sql" - "io/fs" "path/filepath" "testing" "testing/fstest" "github.com/pressly/goose/v3/internal/check" + "github.com/pressly/goose/v3/internal/provider" _ "modernc.org/sqlite" ) @@ -15,86 +15,52 @@ func TestNewProvider(t *testing.T) { dir := t.TempDir() db, err := sql.Open("sqlite", filepath.Join(dir, "sql_embed.db")) check.NoError(t, err) - fsys := newFsys() + fsys := fstest.MapFS{ + "1_foo.sql": {Data: []byte(migration1)}, + "2_bar.sql": {Data: []byte(migration2)}, + "3_baz.sql": {Data: []byte(migration3)}, + "4_qux.sql": {Data: []byte(migration4)}, + } t.Run("invalid", func(t *testing.T) { // Empty dialect not allowed - _, err = NewProvider("", db, fsys) + _, err = provider.NewProvider("", db, fsys) check.HasError(t, err) // Invalid dialect not allowed - _, err = NewProvider("unknown-dialect", db, fsys) + _, err = provider.NewProvider("unknown-dialect", db, fsys) check.HasError(t, err) // Nil db not allowed - _, err = NewProvider("sqlite3", nil, fsys) + _, err = provider.NewProvider("sqlite3", nil, fsys) check.HasError(t, err) // Nil fsys not allowed - _, err = NewProvider("sqlite3", db, nil) + _, err = provider.NewProvider("sqlite3", db, nil) check.HasError(t, err) // Duplicate table name not allowed - _, err = NewProvider("sqlite3", db, fsys, WithTableName("foo"), WithTableName("bar")) + _, err = provider.NewProvider("sqlite3", db, fsys, + provider.WithTableName("foo"), + provider.WithTableName("bar"), + ) check.HasError(t, err) check.Equal(t, `table already set to "foo"`, err.Error()) // Empty table name not allowed - _, err = NewProvider("sqlite3", db, fsys, WithTableName("")) + _, err = provider.NewProvider("sqlite3", db, fsys, + provider.WithTableName(""), + ) check.HasError(t, err) check.Equal(t, "table must not be empty", err.Error()) }) t.Run("valid", func(t *testing.T) { // Valid dialect, db, and fsys allowed - _, err = NewProvider("sqlite3", db, fsys) + _, err = provider.NewProvider("sqlite3", db, fsys) check.NoError(t, err) // Valid dialect, db, fsys, and table name allowed - _, err = NewProvider("sqlite3", db, fsys, WithTableName("foo")) + _, err = provider.NewProvider("sqlite3", db, fsys, + provider.WithTableName("foo"), + ) check.NoError(t, err) // Valid dialect, db, fsys, and verbose allowed - _, err = NewProvider("sqlite3", db, fsys, 
WithVerbose()) + _, err = provider.NewProvider("sqlite3", db, fsys, + provider.WithVerbose(), + ) check.NoError(t, err) }) } - -func newFsys() fs.FS { - return fstest.MapFS{ - "1_foo.sql": {Data: []byte(migration1)}, - "2_bar.sql": {Data: []byte(migration2)}, - "3_baz.sql": {Data: []byte(migration3)}, - "4_qux.sql": {Data: []byte(migration4)}, - } -} - -var ( - migration1 = ` --- +goose Up -CREATE TABLE foo (id INTEGER PRIMARY KEY); --- +goose Down -DROP TABLE foo; -` - migration2 = ` --- +goose Up -ALTER TABLE foo ADD COLUMN name TEXT; --- +goose Down -ALTER TABLE foo DROP COLUMN name; -` - migration3 = ` --- +goose Up -CREATE TABLE bar ( - id INTEGER PRIMARY KEY, - description TEXT -); --- +goose Down -DROP TABLE bar; -` - migration4 = ` --- +goose Up --- Rename the 'foo' table to 'my_foo' -ALTER TABLE foo RENAME TO my_foo; - --- Add a new column 'timestamp' to 'my_foo' -ALTER TABLE my_foo ADD COLUMN timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP; - --- +goose Down --- Remove the 'timestamp' column from 'my_foo' -ALTER TABLE my_foo DROP COLUMN timestamp; - --- Rename the 'my_foo' table back to 'foo' -ALTER TABLE my_foo RENAME TO foo; -` -) diff --git a/internal/provider/provider_test.go b/internal/provider/provider_test.go new file mode 100644 index 000000000..10aed48e0 --- /dev/null +++ b/internal/provider/provider_test.go @@ -0,0 +1,83 @@ +package provider_test + +import ( + "database/sql" + "errors" + "io/fs" + "path/filepath" + "testing" + "testing/fstest" + + "github.com/pressly/goose/v3/internal/check" + "github.com/pressly/goose/v3/internal/provider" + _ "modernc.org/sqlite" +) + +func TestProvider(t *testing.T) { + dir := t.TempDir() + db, err := sql.Open("sqlite", filepath.Join(dir, "sql_embed.db")) + check.NoError(t, err) + t.Run("empty", func(t *testing.T) { + _, err := provider.NewProvider("sqlite3", db, fstest.MapFS{}) + check.HasError(t, err) + check.Bool(t, errors.Is(err, provider.ErrNoMigrations), true) + }) + + mapFS := fstest.MapFS{ + "migrations/001_foo.sql": {Data: []byte(`-- +goose Up`)}, + "migrations/002_bar.sql": {Data: []byte(`-- +goose Up`)}, + } + fsys, err := fs.Sub(mapFS, "migrations") + check.NoError(t, err) + p, err := provider.NewProvider("sqlite3", db, fsys) + check.NoError(t, err) + sources := p.ListSources() + check.Equal(t, len(sources), 2) + // 1 + check.Equal(t, sources[0].Version, int64(1)) + check.Equal(t, sources[0].Fullpath, "001_foo.sql") + check.Equal(t, sources[0].Type, provider.SourceTypeSQL) + // 2 + check.Equal(t, sources[1].Version, int64(2)) + check.Equal(t, sources[1].Fullpath, "002_bar.sql") + check.Equal(t, sources[1].Type, provider.SourceTypeSQL) +} + +var ( + migration1 = ` +-- +goose Up +CREATE TABLE foo (id INTEGER PRIMARY KEY); +-- +goose Down +DROP TABLE foo; +` + migration2 = ` +-- +goose Up +ALTER TABLE foo ADD COLUMN name TEXT; +-- +goose Down +ALTER TABLE foo DROP COLUMN name; +` + migration3 = ` +-- +goose Up +CREATE TABLE bar ( + id INTEGER PRIMARY KEY, + description TEXT +); +-- +goose Down +DROP TABLE bar; +` + migration4 = ` +-- +goose Up +-- Rename the 'foo' table to 'my_foo' +ALTER TABLE foo RENAME TO my_foo; + +-- Add a new column 'timestamp' to 'my_foo' +ALTER TABLE my_foo ADD COLUMN timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP; + +-- +goose Down +-- Remove the 'timestamp' column from 'my_foo' +ALTER TABLE my_foo DROP COLUMN timestamp; + +-- Rename the 'my_foo' table back to 'foo' +ALTER TABLE my_foo RENAME TO foo; +` +) diff --git a/lock/postgres_test.go b/lock/postgres_test.go index bfb1a0d99..2622d5cb6 100644 --- 
a/lock/postgres_test.go +++ b/lock/postgres_test.go @@ -14,19 +14,20 @@ import ( ) func TestPostgresSessionLocker(t *testing.T) { + t.Parallel() if testing.Short() { t.Skip("skip long running test") } db, cleanup, err := testdb.NewPostgres() check.NoError(t, err) t.Cleanup(cleanup) - const ( - lockID int64 = 123456789 - ) // Do not run tests in parallel, because they are using the same database. t.Run("lock_and_unlock", func(t *testing.T) { + const ( + lockID int64 = 123456789 + ) locker, err := lock.NewPostgresSessionLocker( lock.WithLockID(lockID), lock.WithLockTimeout(4*time.Second), diff --git a/migration.go b/migration.go index dcf0c6118..619e934d0 100644 --- a/migration.go +++ b/migration.go @@ -218,27 +218,27 @@ func insertOrDeleteVersionNoTx(ctx context.Context, db *sql.DB, version int64, d return store.DeleteVersionNoTx(ctx, db, TableName(), version) } -// NumericComponent looks for migration scripts with names in the form: -// XXX_descriptivename.ext where XXX specifies the version number -// and ext specifies the type of migration -func NumericComponent(name string) (int64, error) { - base := filepath.Base(name) - +// NumericComponent parses the version from the migration file name. +// +// XXX_descriptivename.ext where XXX specifies the version number and ext specifies the type of +// migration, either .sql or .go. +func NumericComponent(filename string) (int64, error) { + base := filepath.Base(filename) if ext := filepath.Ext(base); ext != ".go" && ext != ".sql" { - return 0, errors.New("not a recognized migration file type") + return 0, errors.New("migration file does not have .sql or .go file extension") } - idx := strings.Index(base, "_") if idx < 0 { return 0, errors.New("no filename separator '_' found") } - - n, e := strconv.ParseInt(base[:idx], 10, 64) - if e == nil && n <= 0 { - return 0, errors.New("migration IDs must be greater than zero") + n, err := strconv.ParseInt(base[:idx], 10, 64) + if err != nil { + return 0, err } - - return n, e + if n < 1 { + return 0, errors.New("migration version must be greater than zero") + } + return n, nil } func truncateDuration(d time.Duration) time.Duration { From 58f85346104a06d1a6eaebc30e13745bdc88252a Mon Sep 17 00:00:00 2001 From: Michael Fridman Date: Sat, 14 Oct 2023 23:04:07 -0400 Subject: [PATCH 09/10] feat(experimental): shuffle packages & add explicit provider Go func registration (#616) --- go.mod | 1 + go.sum | 3 +- internal/migrate/doc.go | 9 -- internal/migrate/migration.go | 166 -------------------------- internal/migrate/parse.go | 75 ------------ internal/provider/collect.go | 113 ++++++++++++------ internal/provider/collect_test.go | 154 ++++++++++++++++++++++-- internal/provider/migration.go | 119 ++++++++++++++++++ internal/provider/provider.go | 122 ++++++++++++------- internal/provider/provider_options.go | 58 +++++++++ internal/provider/provider_test.go | 4 +- internal/{migrate => provider}/run.go | 26 ++-- internal/sqlparser/parse.go | 54 +++++++++ internal/sqlparser/parse_test.go | 82 +++++++++++++ 14 files changed, 629 insertions(+), 357 deletions(-) delete mode 100644 internal/migrate/doc.go delete mode 100644 internal/migrate/migration.go delete mode 100644 internal/migrate/parse.go create mode 100644 internal/provider/migration.go rename internal/{migrate => provider}/run.go (71%) create mode 100644 internal/sqlparser/parse.go create mode 100644 internal/sqlparser/parse_test.go diff --git a/go.mod b/go.mod index 0230e2c97..8276042e8 100644 --- a/go.mod +++ b/go.mod @@ -12,6 +12,7 @@ require ( 
github.com/vertica/vertica-sql-go v1.3.3 github.com/ziutek/mymysql v1.5.4 go.uber.org/multierr v1.11.0 + golang.org/x/sync v0.4.0 modernc.org/sqlite v1.26.0 ) diff --git a/go.sum b/go.sum index e2c6ea03a..537f3434b 100644 --- a/go.sum +++ b/go.sum @@ -184,7 +184,8 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.3.0 h1:ftCYgMx6zT/asHUrPw8BLLscYtGznsLAnjq5RH9P66E= +golang.org/x/sync v0.4.0 h1:zxkM55ReGkDlKSM+Fu41A+zmbZuaPVbGMzvvdUPznYQ= +golang.org/x/sync v0.4.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= diff --git a/internal/migrate/doc.go b/internal/migrate/doc.go deleted file mode 100644 index 5fbee1582..000000000 --- a/internal/migrate/doc.go +++ /dev/null @@ -1,9 +0,0 @@ -// Package migrate defines a Migration struct and implements the migration logic for executing Go -// and SQL migrations. -// -// - For Go migrations, only *sql.Tx and *sql.DB are supported. *sql.Conn is not supported. -// - For SQL migrations, all three are supported. -// -// Lastly, SQL migrations are lazily parsed. This means that the SQL migration is parsed the first -// time it is executed. -package migrate diff --git a/internal/migrate/migration.go b/internal/migrate/migration.go deleted file mode 100644 index 23a0514cf..000000000 --- a/internal/migrate/migration.go +++ /dev/null @@ -1,166 +0,0 @@ -package migrate - -import ( - "context" - "database/sql" - "errors" - "fmt" - - "github.com/pressly/goose/v3/internal/sqlextended" -) - -type Migration struct { - // Fullpath is the full path to the migration file. - // - // Example: /path/to/migrations/123_create_users_table.go - Fullpath string - // Version is the version of the migration. - Version int64 - // Type is the type of migration. - Type MigrationType - // A migration is either a Go migration or a SQL migration, but never both. - // - // Note, the SQLParsed field is used to determine if the SQL migration has been parsed. This is - // an optimization to avoid parsing the SQL migration if it is never required. Also, the - // majority of the time migrations are incremental, so it is likely that the user will only want - // to run the last few migrations, and there is no need to parse ALL prior migrations. - // - // Exactly one of these fields will be set: - Go *Go - // -- or -- - SQLParsed bool - SQL *SQL -} - -type MigrationType int - -const ( - TypeGo MigrationType = iota + 1 - TypeSQL -) - -func (t MigrationType) String() string { - switch t { - case TypeGo: - return "go" - case TypeSQL: - return "sql" - default: - // This should never happen. - return "unknown" - } -} - -func (m *Migration) UseTx() bool { - switch m.Type { - case TypeGo: - return m.Go.UseTx - case TypeSQL: - return m.SQL.UseTx - default: - // This should never happen. 
- panic("unknown migration type: use tx") - } -} - -func (m *Migration) IsEmpty(direction bool) bool { - switch m.Type { - case TypeGo: - return m.Go.IsEmpty(direction) - case TypeSQL: - return m.SQL.IsEmpty(direction) - default: - // This should never happen. - panic("unknown migration type: is empty") - } -} - -func (m *Migration) GetSQLStatements(direction bool) ([]string, error) { - if m.Type != TypeSQL { - return nil, fmt.Errorf("expected sql migration, got %s: no sql statements", m.Type) - } - if m.SQL == nil { - return nil, errors.New("sql migration has not been initialized") - } - if !m.SQLParsed { - return nil, errors.New("sql migration has not been parsed") - } - if direction { - return m.SQL.UpStatements, nil - } - return m.SQL.DownStatements, nil -} - -type Go struct { - // We used an explicit bool instead of relying on a pointer because registered funcs may be nil. - // These are still valid Go and versioned migrations, but they are just empty. - // - // For example: goose.AddMigration(nil, nil) - UseTx bool - - // Only one of these func pairs will be set: - UpFn, DownFn func(context.Context, *sql.Tx) error - // -- or -- - UpFnNoTx, DownFnNoTx func(context.Context, *sql.DB) error -} - -func (g *Go) IsEmpty(direction bool) bool { - if direction { - return g.UpFn == nil && g.UpFnNoTx == nil - } - return g.DownFn == nil && g.DownFnNoTx == nil -} - -func (g *Go) run(ctx context.Context, tx *sql.Tx, direction bool) error { - var fn func(context.Context, *sql.Tx) error - if direction { - fn = g.UpFn - } else { - fn = g.DownFn - } - if fn != nil { - return fn(ctx, tx) - } - return nil -} - -func (g *Go) runNoTx(ctx context.Context, db *sql.DB, direction bool) error { - var fn func(context.Context, *sql.DB) error - if direction { - fn = g.UpFnNoTx - } else { - fn = g.DownFnNoTx - } - if fn != nil { - return fn(ctx, db) - } - return nil -} - -type SQL struct { - UseTx bool - UpStatements []string - DownStatements []string -} - -func (s *SQL) IsEmpty(direction bool) bool { - if direction { - return len(s.UpStatements) == 0 - } - return len(s.DownStatements) == 0 -} - -func (s *SQL) run(ctx context.Context, db sqlextended.DBTxConn, direction bool) error { - var statements []string - if direction { - statements = s.UpStatements - } else { - statements = s.DownStatements - } - for _, stmt := range statements { - if _, err := db.ExecContext(ctx, stmt); err != nil { - return err - } - } - return nil -} diff --git a/internal/migrate/parse.go b/internal/migrate/parse.go deleted file mode 100644 index 18a66b499..000000000 --- a/internal/migrate/parse.go +++ /dev/null @@ -1,75 +0,0 @@ -package migrate - -import ( - "bytes" - "io" - "io/fs" - - "github.com/pressly/goose/v3/internal/sqlparser" -) - -// ParseSQL parses all SQL migrations in BOTH directions. If a migration has already been parsed, it -// will not be parsed again. -// -// Important: This function will mutate SQL migrations. -func ParseSQL(fsys fs.FS, debug bool, migrations []*Migration) error { - for _, m := range migrations { - if m.Type == TypeSQL && !m.SQLParsed { - parsedSQLMigration, err := parseSQL(fsys, m.Fullpath, parseAll, debug) - if err != nil { - return err - } - m.SQLParsed = true - m.SQL = parsedSQLMigration - } - } - return nil -} - -// parse is used to determine which direction to parse the SQL migration. -type parse int - -const ( - // parseAll parses all SQL statements in BOTH directions. - parseAll parse = iota + 1 - // parseUp parses all SQL statements in the UP direction. 
- parseUp - // parseDown parses all SQL statements in the DOWN direction. - parseDown -) - -func parseSQL(fsys fs.FS, filename string, p parse, debug bool) (*SQL, error) { - r, err := fsys.Open(filename) - if err != nil { - return nil, err - } - by, err := io.ReadAll(r) - if err != nil { - return nil, err - } - if err := r.Close(); err != nil { - return nil, err - } - s := new(SQL) - if p == parseAll || p == parseUp { - s.UpStatements, s.UseTx, err = sqlparser.ParseSQLMigration( - bytes.NewReader(by), - sqlparser.DirectionUp, - debug, - ) - if err != nil { - return nil, err - } - } - if p == parseAll || p == parseDown { - s.DownStatements, s.UseTx, err = sqlparser.ParseSQLMigration( - bytes.NewReader(by), - sqlparser.DirectionDown, - debug, - ) - if err != nil { - return nil, err - } - } - return s, nil -} diff --git a/internal/provider/collect.go b/internal/provider/collect.go index cf12961fb..6658c8067 100644 --- a/internal/provider/collect.go +++ b/internal/provider/collect.go @@ -9,15 +9,56 @@ import ( "strings" "github.com/pressly/goose/v3" - "github.com/pressly/goose/v3/internal/migrate" ) +// Source represents a single migration source. +// +// For SQL migrations, Fullpath will always be set. For Go migrations, Fullpath will will be set if +// the migration has a corresponding file on disk. It will be empty if the migration was registered +// manually. +type Source struct { + // Type is the type of migration. + Type MigrationType + // Full path to the migration file. + // + // Example: /path/to/migrations/001_create_users_table.sql + Fullpath string + // Version is the version of the migration. + Version int64 +} + +func newSource(t MigrationType, fullpath string, version int64) Source { + return Source{ + Type: t, + Fullpath: fullpath, + Version: version, + } +} + // fileSources represents a collection of migration files on the filesystem. type fileSources struct { sqlSources []Source goSources []Source } +func (s *fileSources) lookup(t MigrationType, version int64) *Source { + switch t { + case TypeGo: + for _, source := range s.goSources { + if source.Version == version { + return &source + } + } + case TypeSQL: + for _, source := range s.sqlSources { + if source.Version == version { + return &source + } + } + } + return nil +} + // collectFileSources scans the file system for migration files that have a numeric prefix (greater // than one) followed by an underscore and a file extension of either .go or .sql. fsys may be nil, // in which case an empty fileSources is returned. @@ -69,15 +110,9 @@ func collectFileSources(fsys fs.FS, strict bool, excludes map[string]bool) (*fil } switch filepath.Ext(base) { case ".sql": - sources.sqlSources = append(sources.sqlSources, Source{ - Fullpath: fullpath, - Version: version, - }) + sources.sqlSources = append(sources.sqlSources, newSource(TypeSQL, fullpath, version)) case ".go": - sources.goSources = append(sources.goSources, Source{ - Fullpath: fullpath, - Version: version, - }) + sources.goSources = append(sources.goSources, newSource(TypeGo, fullpath, version)) default: // Should never happen since we already filtered out all other file types. 
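For a sense of how these collection rules play out through the public constructor, here is a minimal sketch using the same sqlite setup as the provider tests above. The filenames are illustrative, it assumes the default (non-strict) collection mode, and since the package still lives under internal/, it only builds from inside the goose module:

package main

import (
	"database/sql"
	"fmt"
	"log"
	"testing/fstest"

	"github.com/pressly/goose/v3/internal/provider"
	_ "modernc.org/sqlite"
)

func main() {
	fsys := fstest.MapFS{
		// Collected: positive numeric prefix, underscore separator, .sql extension.
		"00001_users.sql": {Data: []byte("-- +goose Up\nCREATE TABLE users (id INTEGER PRIMARY KEY);\n")},
		"00002_posts.sql": {Data: []byte("-- +goose Up\nCREATE TABLE posts (id INTEGER PRIMARY KEY);\n")},
		// Skipped when strict mode is off: no numeric version prefix.
		"helpers.go": {Data: []byte("package migrations")},
	}
	db, err := sql.Open("sqlite", ":memory:")
	if err != nil {
		log.Fatal(err)
	}
	p, err := provider.NewProvider("sqlite3", db, fsys)
	if err != nil {
		log.Fatal(err)
	}
	for _, s := range p.ListSources() {
		fmt.Println(s.Version, s.Type, s.Fullpath) // 1 sql 00001_users.sql, then 2 sql 00002_posts.sql
	}
}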
return nil, fmt.Errorf("unknown migration type: %s", base) @@ -89,19 +124,17 @@ func collectFileSources(fsys fs.FS, strict bool, excludes map[string]bool) (*fil return sources, nil } -func merge(sources *fileSources, registerd map[int64]*goose.Migration) ([]*migrate.Migration, error) { - var migrations []*migrate.Migration - migrationLookup := make(map[int64]*migrate.Migration) +func merge(sources *fileSources, registerd map[int64]*goMigration) ([]*migration, error) { + var migrations []*migration + migrationLookup := make(map[int64]*migration) // Add all SQL migrations to the list of migrations. - for _, s := range sources.sqlSources { - m := &migrate.Migration{ - Type: migrate.TypeSQL, - Fullpath: s.Fullpath, - Version: s.Version, - SQLParsed: false, + for _, source := range sources.sqlSources { + m := &migration{ + Source: source, + SQL: nil, // SQL migrations are parsed lazily. } migrations = append(migrations, m) - migrationLookup[s.Version] = m + migrationLookup[source.Version] = m } // If there are no Go files in the filesystem and no registered Go migrations, return early. if len(sources.goSources) == 0 && len(registerd) == 0 { @@ -127,38 +160,41 @@ func merge(sources *fileSources, registerd map[int64]*goose.Migration) ([]*migra // migrations may not have a corresponding file on disk. Which is fine! We include them // wholesale as part of migrations. This allows users to build a custom binary that only embeds // the SQL migration files. - for _, r := range registerd { + for version, r := range registerd { + var fullpath string + if s := sources.lookup(TypeGo, version); s != nil { + fullpath = s.Fullpath + } // Ensure there are no duplicate versions. - if existing, ok := migrationLookup[r.Version]; ok { + if existing, ok := migrationLookup[version]; ok { + if fullpath == "" { + fullpath = "manually registered (no source)" + } return nil, fmt.Errorf("found duplicate migration version %d:\n\texisting:%v\n\tcurrent:%v", - r.Version, - existing, - filepath.Base(r.Source), + version, + existing.Source.Fullpath, + fullpath, ) } - m := &migrate.Migration{ - Fullpath: r.Source, // May be empty if the migration was registered manually. - Version: r.Version, - Type: migrate.TypeGo, - Go: &migrate.Go{ - UseTx: r.UseTx, - UpFn: r.UpFnContext, - UpFnNoTx: r.UpFnNoTxContext, - DownFn: r.DownFnContext, - DownFnNoTx: r.DownFnNoTxContext, - }, + m := &migration{ + // Note, the fullpath may be empty if the migration was registered manually. + Source: newSource(TypeGo, fullpath, version), + Go: r, } migrations = append(migrations, m) - migrationLookup[r.Version] = m + migrationLookup[version] = m } // Sort migrations by version in ascending order. 
sort.Slice(migrations, func(i, j int) bool { - return migrations[i].Version < migrations[j].Version + return migrations[i].Source.Version < migrations[j].Source.Version }) return migrations, nil } func unregisteredError(unregistered []string) error { + const ( + hintURL = "https://github.com/pressly/goose/tree/master/examples/go-migrations" + ) f := "file" if len(unregistered) > 1 { f += "s" @@ -169,8 +205,9 @@ func unregisteredError(unregistered []string) error { for _, name := range unregistered { b.WriteString("\t" + name + "\n") } + hint := fmt.Sprintf("hint: go functions must be registered and built into a custom binary see:\n%s", hintURL) + b.WriteString(hint) b.WriteString("\n") - b.WriteString("go functions must be registered and built into a custom binary see:\nhttps://github.com/pressly/goose/tree/master/examples/go-migrations") return errors.New(b.String()) } diff --git a/internal/provider/collect_test.go b/internal/provider/collect_test.go index a5ee2d352..401a1ce40 100644 --- a/internal/provider/collect_test.go +++ b/internal/provider/collect_test.go @@ -10,14 +10,14 @@ import ( func TestCollectFileSources(t *testing.T) { t.Parallel() - t.Run("nil", func(t *testing.T) { + t.Run("nil_fsys", func(t *testing.T) { sources, err := collectFileSources(nil, false, nil) check.NoError(t, err) check.Bool(t, sources != nil, true) check.Number(t, len(sources.goSources), 0) check.Number(t, len(sources.sqlSources), 0) }) - t.Run("empty", func(t *testing.T) { + t.Run("empty_fsys", func(t *testing.T) { sources, err := collectFileSources(fstest.MapFS{}, false, nil) check.NoError(t, err) check.Number(t, len(sources.goSources), 0) @@ -47,10 +47,10 @@ func TestCollectFileSources(t *testing.T) { check.Number(t, len(sources.goSources), 0) expected := fileSources{ sqlSources: []Source{ - {Fullpath: "00001_foo.sql", Version: 1}, - {Fullpath: "00002_bar.sql", Version: 2}, - {Fullpath: "00003_baz.sql", Version: 3}, - {Fullpath: "00110_qux.sql", Version: 110}, + newSource(TypeSQL, "00001_foo.sql", 1), + newSource(TypeSQL, "00002_bar.sql", 2), + newSource(TypeSQL, "00003_baz.sql", 3), + newSource(TypeSQL, "00110_qux.sql", 110), }, } for i := 0; i < len(sources.sqlSources); i++ { @@ -74,8 +74,8 @@ func TestCollectFileSources(t *testing.T) { check.Number(t, len(sources.goSources), 0) expected := fileSources{ sqlSources: []Source{ - {Fullpath: "00001_foo.sql", Version: 1}, - {Fullpath: "00003_baz.sql", Version: 3}, + newSource(TypeSQL, "00001_foo.sql", 1), + newSource(TypeSQL, "00003_baz.sql", 3), }, } for i := 0; i < len(sources.sqlSources); i++ { @@ -159,18 +159,146 @@ func TestCollectFileSources(t *testing.T) { } } assertDirpath(".", []Source{ - {Fullpath: "876_a.sql", Version: 876}, + newSource(TypeSQL, "876_a.sql", 876), }) assertDirpath("dir1", []Source{ - {Fullpath: "101_a.sql", Version: 101}, - {Fullpath: "102_b.sql", Version: 102}, - {Fullpath: "103_c.sql", Version: 103}, + newSource(TypeSQL, "101_a.sql", 101), + newSource(TypeSQL, "102_b.sql", 102), + newSource(TypeSQL, "103_c.sql", 103), + }) + assertDirpath("dir2", []Source{ + newSource(TypeSQL, "201_a.sql", 201), }) - assertDirpath("dir2", []Source{{Fullpath: "201_a.sql", Version: 201}}) assertDirpath("dir3", nil) }) } +func TestMerge(t *testing.T) { + t.Parallel() + + t.Run("with_go_files_on_disk", func(t *testing.T) { + mapFS := fstest.MapFS{ + // SQL + "migrations/00001_foo.sql": sqlMapFile, + // Go + "migrations/00002_bar.go": {Data: []byte(`package migrations`)}, + "migrations/00003_baz.go": {Data: []byte(`package migrations`)}, + } + fsys, 
err := fs.Sub(mapFS, "migrations") + check.NoError(t, err) + sources, err := collectFileSources(fsys, false, nil) + check.NoError(t, err) + check.Equal(t, len(sources.sqlSources), 1) + check.Equal(t, len(sources.goSources), 2) + src1 := sources.lookup(TypeSQL, 1) + check.Bool(t, src1 != nil, true) + src2 := sources.lookup(TypeGo, 2) + check.Bool(t, src2 != nil, true) + src3 := sources.lookup(TypeGo, 3) + check.Bool(t, src3 != nil, true) + + t.Run("valid", func(t *testing.T) { + migrations, err := merge(sources, map[int64]*goMigration{ + 2: {version: 2}, + 3: {version: 3}, + }) + check.NoError(t, err) + check.Number(t, len(migrations), 3) + assertMigration(t, migrations[0], newSource(TypeSQL, "00001_foo.sql", 1)) + assertMigration(t, migrations[1], newSource(TypeGo, "00002_bar.go", 2)) + assertMigration(t, migrations[2], newSource(TypeGo, "00003_baz.go", 3)) + }) + t.Run("unregistered_all", func(t *testing.T) { + _, err := merge(sources, nil) + check.HasError(t, err) + check.Contains(t, err.Error(), "error: detected 2 unregistered Go files:") + check.Contains(t, err.Error(), "00002_bar.go") + check.Contains(t, err.Error(), "00003_baz.go") + }) + t.Run("unregistered_some", func(t *testing.T) { + _, err := merge(sources, map[int64]*goMigration{ + 2: {version: 2}, + }) + check.HasError(t, err) + check.Contains(t, err.Error(), "error: detected 1 unregistered Go file") + check.Contains(t, err.Error(), "00003_baz.go") + }) + t.Run("duplicate_sql", func(t *testing.T) { + _, err := merge(sources, map[int64]*goMigration{ + 1: {version: 1}, // duplicate. SQL already exists. + 2: {version: 2}, + 3: {version: 3}, + }) + check.HasError(t, err) + check.Contains(t, err.Error(), "found duplicate migration version 1") + }) + }) + t.Run("no_go_files_on_disk", func(t *testing.T) { + mapFS := fstest.MapFS{ + // SQL + "migrations/00001_foo.sql": sqlMapFile, + "migrations/00002_bar.sql": sqlMapFile, + "migrations/00005_baz.sql": sqlMapFile, + } + fsys, err := fs.Sub(mapFS, "migrations") + check.NoError(t, err) + sources, err := collectFileSources(fsys, false, nil) + check.NoError(t, err) + t.Run("unregistered_all", func(t *testing.T) { + migrations, err := merge(sources, map[int64]*goMigration{ + 3: {version: 3}, + // 4 is missing + 6: {version: 6}, + }) + check.NoError(t, err) + check.Number(t, len(migrations), 5) + assertMigration(t, migrations[0], newSource(TypeSQL, "00001_foo.sql", 1)) + assertMigration(t, migrations[1], newSource(TypeSQL, "00002_bar.sql", 2)) + assertMigration(t, migrations[2], newSource(TypeGo, "", 3)) + assertMigration(t, migrations[3], newSource(TypeSQL, "00005_baz.sql", 5)) + assertMigration(t, migrations[4], newSource(TypeGo, "", 6)) + }) + }) + t.Run("partial_go_files_on_disk", func(t *testing.T) { + mapFS := fstest.MapFS{ + "migrations/00001_foo.sql": sqlMapFile, + "migrations/00002_bar.go": &fstest.MapFile{Data: []byte(`package migrations`)}, + } + fsys, err := fs.Sub(mapFS, "migrations") + check.NoError(t, err) + sources, err := collectFileSources(fsys, false, nil) + check.NoError(t, err) + t.Run("unregistered_all", func(t *testing.T) { + migrations, err := merge(sources, map[int64]*goMigration{ + // This is the only Go file on disk. + 2: {version: 2}, + // These are not on disk. Explicitly registered. 
+ 3: {version: 3}, + 6: {version: 6}, + }) + check.NoError(t, err) + check.Number(t, len(migrations), 4) + assertMigration(t, migrations[0], newSource(TypeSQL, "00001_foo.sql", 1)) + assertMigration(t, migrations[1], newSource(TypeGo, "00002_bar.go", 2)) + assertMigration(t, migrations[2], newSource(TypeGo, "", 3)) + assertMigration(t, migrations[3], newSource(TypeGo, "", 6)) + }) + }) +} + +func assertMigration(t *testing.T, got *migration, want Source) { + t.Helper() + check.Equal(t, got.Source, want) + switch got.Source.Type { + case TypeGo: + check.Equal(t, got.Go.version, want.Version) + case TypeSQL: + check.Bool(t, got.SQL == nil, true) + default: + t.Fatalf("unknown migration type: %s", got.Source.Type) + } +} + func newSQLOnlyFS() fstest.MapFS { return fstest.MapFS{ "migrations/00001_foo.sql": sqlMapFile, diff --git a/internal/provider/migration.go b/internal/provider/migration.go new file mode 100644 index 000000000..cf98abc3e --- /dev/null +++ b/internal/provider/migration.go @@ -0,0 +1,119 @@ +package provider + +import ( + "context" + "database/sql" + "errors" + "fmt" + + "github.com/pressly/goose/v3/internal/sqlextended" +) + +type migration struct { + Source Source + // A migration is either a Go migration or a SQL migration, but never both. + // + // Note, the SQLParsed field is used to determine if the SQL migration has been parsed. This is + // an optimization to avoid parsing the SQL migration if it is never required. Also, the + // majority of the time migrations are incremental, so it is likely that the user will only want + // to run the last few migrations, and there is no need to parse ALL prior migrations. + // + // Exactly one of these fields will be set: + Go *goMigration + // -- OR -- + SQL *sqlMigration +} + +type MigrationType int + +const ( + TypeGo MigrationType = iota + 1 + TypeSQL +) + +func (t MigrationType) String() string { + switch t { + case TypeGo: + return "go" + case TypeSQL: + return "sql" + default: + // This should never happen. 
+ return fmt.Sprintf("unknown (%d)", t) + } +} + +func (m *migration) GetSQLStatements(direction bool) ([]string, error) { + if m.Source.Type != TypeSQL { + return nil, fmt.Errorf("expected sql migration, got %s: no sql statements", m.Source.Type) + } + if m.SQL == nil { + return nil, errors.New("sql migration has not been parsed") + } + if direction { + return m.SQL.UpStatements, nil + } + return m.SQL.DownStatements, nil +} + +func (g *goMigration) run(ctx context.Context, tx *sql.Tx, direction bool) error { + if g == nil { + return nil + } + var fn func(context.Context, *sql.Tx) error + if direction && g.up != nil { + fn = g.up.Run + } + if !direction && g.down != nil { + fn = g.down.Run + } + if fn != nil { + return fn(ctx, tx) + } + return nil +} + +func (g *goMigration) runNoTx(ctx context.Context, db *sql.DB, direction bool) error { + if g == nil { + return nil + } + var fn func(context.Context, *sql.DB) error + if direction && g.up != nil { + fn = g.up.RunNoTx + } + if !direction && g.down != nil { + fn = g.down.RunNoTx + } + if fn != nil { + return fn(ctx, db) + } + return nil +} + +type sqlMigration struct { + UseTx bool + UpStatements []string + DownStatements []string +} + +func (s *sqlMigration) IsEmpty(direction bool) bool { + if direction { + return len(s.UpStatements) == 0 + } + return len(s.DownStatements) == 0 +} + +func (s *sqlMigration) run(ctx context.Context, db sqlextended.DBTxConn, direction bool) error { + var statements []string + if direction { + statements = s.UpStatements + } else { + statements = s.DownStatements + } + for _, stmt := range statements { + if _, err := db.ExecContext(ctx, stmt); err != nil { + return err + } + } + return nil +} diff --git a/internal/provider/provider.go b/internal/provider/provider.go index 6702f0731..7d5085069 100644 --- a/internal/provider/provider.go +++ b/internal/provider/provider.go @@ -4,12 +4,14 @@ import ( "context" "database/sql" "errors" + "fmt" "io/fs" "os" "time" - "github.com/pressly/goose/v3/internal/migrate" + "github.com/pressly/goose/v3" "github.com/pressly/goose/v3/internal/sqladapter" + "github.com/pressly/goose/v3/internal/sqlparser" ) var ( @@ -17,6 +19,8 @@ var ( ErrNoMigrations = errors.New("no migrations found") ) +var registeredGoMigrations = make(map[int64]*goose.Migration) + // NewProvider returns a new goose Provider. // // The caller is responsible for matching the database dialect with the database/sql driver. For @@ -68,7 +72,59 @@ func NewProvider(dialect string, db *sql.DB, fsys fs.FS, opts ...ProviderOption) if err != nil { return nil, err } - migrations, err := merge(sources, nil) + // + // TODO(mf): move the merging of Go migrations into a separate function. + // + registered := make(map[int64]*goMigration) + // Add user-registered Go migrations. + for version, m := range cfg.registered { + registered[version] = &goMigration{ + version: version, + up: m.up, + down: m.down, + } + } + // Add init() functions. This is a bit ugly because we need to convert from the old Migration + // struct to the new goMigration struct. 
+ for version, m := range registeredGoMigrations { + if _, ok := registered[version]; ok { + return nil, fmt.Errorf("go migration with version %d already registered", version) + } + g := &goMigration{ + version: version, + } + if m == nil { + return nil, errors.New("registered migration with nil init function") + } + if m.UpFnContext != nil && m.UpFnNoTxContext != nil { + return nil, errors.New("registered migration with both UpFnContext and UpFnNoTxContext") + } + if m.DownFnContext != nil && m.DownFnNoTxContext != nil { + return nil, errors.New("registered migration with both DownFnContext and DownFnNoTxContext") + } + // Up + if m.UpFnContext != nil { + g.up = &GoMigration{ + Run: m.UpFnContext, + } + } else if m.UpFnNoTxContext != nil { + g.up = &GoMigration{ + RunNoTx: m.UpFnNoTxContext, + } + } + // Down + if m.DownFnContext != nil { + g.down = &GoMigration{ + Run: m.DownFnContext, + } + } else if m.DownFnNoTxContext != nil { + g.down = &GoMigration{ + RunNoTx: m.DownFnNoTxContext, + } + } + registered[version] = g + } + migrations, err := merge(sources, registered) if err != nil { return nil, err } @@ -98,7 +154,7 @@ type Provider struct { fsys fs.FS cfg config store sqladapter.Store - migrations []*migrate.Migration + migrations []*migration } // State represents the state of a migration. @@ -137,48 +193,12 @@ func (p *Provider) GetDBVersion(ctx context.Context) (int64, error) { return 0, errors.New("not implemented") } -// SourceType represents the type of migration source. -type SourceType string - -const ( - // SourceTypeSQL represents a SQL migration. - SourceTypeSQL SourceType = "sql" - // SourceTypeGo represents a Go migration. - SourceTypeGo SourceType = "go" -) - -// Source represents a single migration source. -// -// For SQL migrations, Fullpath will always be set. For Go migrations, Fullpath will will be set if -// the migration has a corresponding file on disk. It will be empty if the migration was registered -// manually. -type Source struct { - // Type is the type of migration. - Type SourceType - // Full path to the migration file. - // - // Example: /path/to/migrations/001_create_users_table.sql - Fullpath string - // Version is the version of the migration. - Version int64 -} - // ListSources returns a list of all available migration sources the provider is aware of, sorted in // ascending order by version. func (p *Provider) ListSources() []*Source { sources := make([]*Source, 0, len(p.migrations)) for _, m := range p.migrations { - s := &Source{ - Fullpath: m.Fullpath, - Version: m.Version, - } - switch m.Type { - case migrate.TypeSQL: - s.Type = SourceTypeSQL - case migrate.TypeGo: - s.Type = SourceTypeGo - } - sources = append(sources, s) + sources = append(sources, &m.Source) } return sources } @@ -240,3 +260,25 @@ func (p *Provider) Down(ctx context.Context) (*MigrationResult, error) { func (p *Provider) DownTo(ctx context.Context, version int64) ([]*MigrationResult, error) { return nil, errors.New("not implemented") } + +// ParseSQL parses all SQL migrations in BOTH directions. If a migration has already been parsed, it +// will not be parsed again. +// +// Important: This function will mutate SQL migrations and is not safe for concurrent use. +func ParseSQL(fsys fs.FS, debug bool, migrations []*migration) error { + for _, m := range migrations { + // If the migration is a SQL migration, and it has not been parsed, parse it. 
+ if m.Source.Type == TypeSQL && m.SQL == nil { + parsed, err := sqlparser.ParseAllFromFS(fsys, m.Source.Fullpath, debug) + if err != nil { + return err + } + m.SQL = &sqlMigration{ + UseTx: parsed.UseTx, + UpStatements: parsed.Up, + DownStatements: parsed.Down, + } + } + } + return nil +} diff --git a/internal/provider/provider_options.go b/internal/provider/provider_options.go index d8060c458..f3ed15b28 100644 --- a/internal/provider/provider_options.go +++ b/internal/provider/provider_options.go @@ -1,6 +1,8 @@ package provider import ( + "context" + "database/sql" "errors" "fmt" @@ -72,11 +74,67 @@ func WithExcludes(excludes []string) ProviderOption { }) } +// GoMigration is a user-defined Go migration, registered using the option [WithGoMigration]. +type GoMigration struct { + // One of the following must be set: + Run func(context.Context, *sql.Tx) error + // -- OR -- + RunNoTx func(context.Context, *sql.DB) error +} + +// WithGoMigration registers a Go migration with the given version. +// +// If WithGoMigration is called multiple times with the same version, an error is returned. Both up +// and down functions may be nil. But if set, exactly one of Run or RunNoTx functions must be set. +func WithGoMigration(version int64, up, down *GoMigration) ProviderOption { + return configFunc(func(c *config) error { + if version < 1 { + return fmt.Errorf("go migration version must be greater than 0") + } + if _, ok := c.registered[version]; ok { + return fmt.Errorf("go migration with version %d already registered", version) + } + // Allow nil up/down functions. This enables users to apply "no-op" migrations, while + // versioning them. + if up != nil { + if up.Run == nil && up.RunNoTx == nil { + return fmt.Errorf("go migration with version %d must have an up function", version) + } + if up.Run != nil && up.RunNoTx != nil { + return fmt.Errorf("go migration with version %d must not have both an up and upNoTx function", version) + } + } + if down != nil { + if down.Run == nil && down.RunNoTx == nil { + return fmt.Errorf("go migration with version %d must have a down function", version) + } + if down.Run != nil && down.RunNoTx != nil { + return fmt.Errorf("go migration with version %d must not have both a down and downNoTx function", version) + } + } + c.registered[version] = &goMigration{ + version: version, + up: up, + down: down, + } + return nil + }) +} + +type goMigration struct { + version int64 + up, down *GoMigration +} + type config struct { tableName string verbose bool excludes map[string]bool + // Go migrations registered by the user. These will be merged/resolved with migrations from the + // filesystem and init() functions. 
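To make the option above concrete, here is a minimal sketch of registering a versioned Go migration alongside SQL files, again assuming the experimental internal/provider API and the sqlite driver used in the tests. Table names and versions are illustrative, and the cfg.registered map this option writes into is initialized inside NewProvider by the follow-up commit:

package main

import (
	"context"
	"database/sql"
	"log"
	"testing/fstest"

	"github.com/pressly/goose/v3/internal/provider"
	_ "modernc.org/sqlite"
)

func main() {
	db, err := sql.Open("sqlite", ":memory:")
	if err != nil {
		log.Fatal(err)
	}
	fsys := fstest.MapFS{
		"00001_users.sql": {Data: []byte("-- +goose Up\nCREATE TABLE users (id INTEGER PRIMARY KEY);\n-- +goose Down\nDROP TABLE users;\n")},
	}
	// Version 2 is a Go migration that runs inside a transaction. The nil down
	// migration keeps the version but makes rollback a no-op.
	up := &provider.GoMigration{
		Run: func(ctx context.Context, tx *sql.Tx) error {
			_, err := tx.ExecContext(ctx, "ALTER TABLE users ADD COLUMN email TEXT")
			return err
		},
	}
	p, err := provider.NewProvider("sqlite3", db, fsys,
		provider.WithGoMigration(2, up, nil),
	)
	if err != nil {
		log.Fatal(err)
	}
	log.Printf("known migrations: %d", len(p.ListSources())) // 2
}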
+ registered map[int64]*goMigration + // Locking options lockEnabled bool sessionLocker lock.SessionLocker diff --git a/internal/provider/provider_test.go b/internal/provider/provider_test.go index 10aed48e0..c8b5effe3 100644 --- a/internal/provider/provider_test.go +++ b/internal/provider/provider_test.go @@ -36,11 +36,11 @@ func TestProvider(t *testing.T) { // 1 check.Equal(t, sources[0].Version, int64(1)) check.Equal(t, sources[0].Fullpath, "001_foo.sql") - check.Equal(t, sources[0].Type, provider.SourceTypeSQL) + check.Equal(t, sources[0].Type, provider.TypeSQL) // 2 check.Equal(t, sources[1].Version, int64(2)) check.Equal(t, sources[1].Fullpath, "002_bar.sql") - check.Equal(t, sources[1].Type, provider.SourceTypeSQL) + check.Equal(t, sources[1].Type, provider.TypeSQL) } var ( diff --git a/internal/migrate/run.go b/internal/provider/run.go similarity index 71% rename from internal/migrate/run.go rename to internal/provider/run.go index 7b7a883d8..f5ca25038 100644 --- a/internal/migrate/run.go +++ b/internal/provider/run.go @@ -1,4 +1,4 @@ -package migrate +package provider import ( "context" @@ -8,10 +8,10 @@ import ( ) // Run runs the migration inside of a transaction. -func (m *Migration) Run(ctx context.Context, tx *sql.Tx, direction bool) error { - switch m.Type { +func (m *migration) Run(ctx context.Context, tx *sql.Tx, direction bool) error { + switch m.Source.Type { case TypeSQL: - if m.SQL == nil || !m.SQLParsed { + if m.SQL == nil { return fmt.Errorf("tx: sql migration has not been parsed") } return m.SQL.run(ctx, tx, direction) @@ -19,14 +19,14 @@ func (m *Migration) Run(ctx context.Context, tx *sql.Tx, direction bool) error { return m.Go.run(ctx, tx, direction) } // This should never happen. - return fmt.Errorf("tx: failed to run migration %s: neither sql or go", filepath.Base(m.Fullpath)) + return fmt.Errorf("tx: failed to run migration %s: neither sql or go", filepath.Base(m.Source.Fullpath)) } // RunNoTx runs the migration without a transaction. -func (m *Migration) RunNoTx(ctx context.Context, db *sql.DB, direction bool) error { - switch m.Type { +func (m *migration) RunNoTx(ctx context.Context, db *sql.DB, direction bool) error { + switch m.Source.Type { case TypeSQL: - if m.SQL == nil || !m.SQLParsed { + if m.SQL == nil { return fmt.Errorf("db: sql migration has not been parsed") } return m.SQL.run(ctx, db, direction) @@ -34,14 +34,14 @@ func (m *Migration) RunNoTx(ctx context.Context, db *sql.DB, direction bool) err return m.Go.runNoTx(ctx, db, direction) } // This should never happen. - return fmt.Errorf("db: failed to run migration %s: neither sql or go", filepath.Base(m.Fullpath)) + return fmt.Errorf("db: failed to run migration %s: neither sql or go", filepath.Base(m.Source.Fullpath)) } // RunConn runs the migration without a transaction using the provided connection. -func (m *Migration) RunConn(ctx context.Context, conn *sql.Conn, direction bool) error { - switch m.Type { +func (m *migration) RunConn(ctx context.Context, conn *sql.Conn, direction bool) error { + switch m.Source.Type { case TypeSQL: - if m.SQL == nil || !m.SQLParsed { + if m.SQL == nil { return fmt.Errorf("conn: sql migration has not been parsed") } return m.SQL.run(ctx, conn, direction) @@ -49,5 +49,5 @@ func (m *Migration) RunConn(ctx context.Context, conn *sql.Conn, direction bool) return fmt.Errorf("conn: go migrations are not supported with *sql.Conn") } // This should never happen. 
- return fmt.Errorf("conn: failed to run migration %s: neither sql or go", filepath.Base(m.Fullpath)) + return fmt.Errorf("conn: failed to run migration %s: neither sql or go", filepath.Base(m.Source.Fullpath)) } diff --git a/internal/sqlparser/parse.go b/internal/sqlparser/parse.go new file mode 100644 index 000000000..e993587a6 --- /dev/null +++ b/internal/sqlparser/parse.go @@ -0,0 +1,54 @@ +package sqlparser + +import ( + "io/fs" + + "go.uber.org/multierr" + "golang.org/x/sync/errgroup" +) + +type ParsedSQL struct { + UseTx bool + Up, Down []string +} + +func ParseAllFromFS(fsys fs.FS, filename string, debug bool) (*ParsedSQL, error) { + parsedSQL := new(ParsedSQL) + // TODO(mf): parse is called twice, once for up and once for down. This is inefficient. It + // should be possible to parse both directions in one pass. Also, UseTx is set once (but + // returned twice), which is unnecessary and potentially error-prone if the two calls to + // parseSQL disagree based on direction. + var g errgroup.Group + g.Go(func() error { + up, useTx, err := parse(fsys, filename, DirectionUp, debug) + if err != nil { + return err + } + parsedSQL.Up = up + parsedSQL.UseTx = useTx + return nil + }) + g.Go(func() error { + down, _, err := parse(fsys, filename, DirectionDown, debug) + if err != nil { + return err + } + parsedSQL.Down = down + return nil + }) + if err := g.Wait(); err != nil { + return nil, err + } + return parsedSQL, nil +} + +func parse(fsys fs.FS, filename string, direction Direction, debug bool) (_ []string, _ bool, retErr error) { + r, err := fsys.Open(filename) + if err != nil { + return nil, false, err + } + defer func() { + retErr = multierr.Append(retErr, r.Close()) + }() + return ParseSQLMigration(r, direction, debug) +} diff --git a/internal/sqlparser/parse_test.go b/internal/sqlparser/parse_test.go new file mode 100644 index 000000000..632bbe13b --- /dev/null +++ b/internal/sqlparser/parse_test.go @@ -0,0 +1,82 @@ +package sqlparser_test + +import ( + "errors" + "os" + "testing" + "testing/fstest" + + "github.com/pressly/goose/v3/internal/check" + "github.com/pressly/goose/v3/internal/sqlparser" +) + +func TestParseAllFromFS(t *testing.T) { + t.Parallel() + t.Run("file_not_exist", func(t *testing.T) { + mapFS := fstest.MapFS{} + _, err := sqlparser.ParseAllFromFS(mapFS, "001_foo.sql", false) + check.HasError(t, err) + check.Bool(t, errors.Is(err, os.ErrNotExist), true) + }) + t.Run("empty_file", func(t *testing.T) { + mapFS := fstest.MapFS{ + "001_foo.sql": &fstest.MapFile{}, + } + _, err := sqlparser.ParseAllFromFS(mapFS, "001_foo.sql", false) + check.HasError(t, err) + check.Contains(t, err.Error(), "failed to parse migration") + check.Contains(t, err.Error(), "must start with '-- +goose Up' annotation") + }) + t.Run("all_statements", func(t *testing.T) { + mapFS := fstest.MapFS{ + "001_foo.sql": newFile(` +-- +goose Up +`), + "002_bar.sql": newFile(` +-- +goose Up +-- +goose Down +`), + "003_baz.sql": newFile(` +-- +goose Up +CREATE TABLE foo (id int); +CREATE TABLE bar (id int); + +-- +goose Down +DROP TABLE bar; +`), + "004_qux.sql": newFile(` +-- +goose NO TRANSACTION +-- +goose Up +CREATE TABLE foo (id int); +-- +goose Down +DROP TABLE foo; +`), + } + parsedSQL, err := sqlparser.ParseAllFromFS(mapFS, "001_foo.sql", false) + check.NoError(t, err) + assertParsedSQL(t, parsedSQL, true, 0, 0) + parsedSQL, err = sqlparser.ParseAllFromFS(mapFS, "002_bar.sql", false) + check.NoError(t, err) + assertParsedSQL(t, parsedSQL, true, 0, 0) + parsedSQL, err = 
sqlparser.ParseAllFromFS(mapFS, "003_baz.sql", false) + check.NoError(t, err) + assertParsedSQL(t, parsedSQL, true, 2, 1) + parsedSQL, err = sqlparser.ParseAllFromFS(mapFS, "004_qux.sql", false) + check.NoError(t, err) + assertParsedSQL(t, parsedSQL, false, 1, 1) + }) +} + +func assertParsedSQL(t *testing.T, got *sqlparser.ParsedSQL, useTx bool, up, down int) { + t.Helper() + check.Bool(t, got != nil, true) + check.Equal(t, len(got.Up), up) + check.Equal(t, len(got.Down), down) + check.Equal(t, got.UseTx, useTx) +} + +func newFile(data string) *fstest.MapFile { + return &fstest.MapFile{ + Data: []byte(data), + } +} From 257b523dffbebe543f949a5feed7cdeda63a201e Mon Sep 17 00:00:00 2001 From: Michael Fridman Date: Mon, 16 Oct 2023 22:08:38 -0400 Subject: [PATCH 10/10] feat(experimental): add migration logic with tests (#617) --- internal/provider/collect.go | 71 +- internal/provider/collect_test.go | 70 +- internal/provider/errors.go | 39 + internal/provider/migration.go | 91 +- internal/provider/misc.go | 39 + internal/provider/provider.go | 138 +- internal/provider/provider_options.go | 40 +- internal/provider/provider_options_test.go | 2 +- internal/provider/provider_test.go | 71 +- internal/provider/run.go | 386 ++++- internal/provider/run_down.go | 53 + internal/provider/run_test.go | 1282 +++++++++++++++++ internal/provider/run_up.go | 96 ++ .../no-versioning/migrations/00001_a.sql | 8 + .../no-versioning/migrations/00002_b.sql | 9 + .../no-versioning/migrations/00003_c.sql | 9 + .../testdata/no-versioning/seed/00001_a.sql | 17 + .../testdata/no-versioning/seed/00002_b.sql | 15 + internal/provider/types.go | 99 ++ internal/sqlparser/parse.go | 7 +- testdata/migrations/00002_posts_table.sql | 2 + 21 files changed, 2317 insertions(+), 227 deletions(-) create mode 100644 internal/provider/errors.go create mode 100644 internal/provider/misc.go create mode 100644 internal/provider/run_down.go create mode 100644 internal/provider/run_test.go create mode 100644 internal/provider/run_up.go create mode 100644 internal/provider/testdata/no-versioning/migrations/00001_a.sql create mode 100644 internal/provider/testdata/no-versioning/migrations/00002_b.sql create mode 100644 internal/provider/testdata/no-versioning/migrations/00003_c.sql create mode 100644 internal/provider/testdata/no-versioning/seed/00001_a.sql create mode 100644 internal/provider/testdata/no-versioning/seed/00002_b.sql create mode 100644 internal/provider/types.go diff --git a/internal/provider/collect.go b/internal/provider/collect.go index 6658c8067..fd7d63e75 100644 --- a/internal/provider/collect.go +++ b/internal/provider/collect.go @@ -4,30 +4,14 @@ import ( "errors" "fmt" "io/fs" + "os" "path/filepath" "sort" + "strconv" "strings" - - "github.com/pressly/goose/v3" ) -// Source represents a single migration source. -// -// For SQL migrations, Fullpath will always be set. For Go migrations, Fullpath will will be set if -// the migration has a corresponding file on disk. It will be empty if the migration was registered -// manually. -type Source struct { - // Type is the type of migration. - Type MigrationType - // Full path to the migration file. - // - // Example: /path/to/migrations/001_create_users_table.sql - Fullpath string - // Version is the version of the migration. 
- Version int64 -} - -func newSource(t MigrationType, fullpath string, version int64) Source { +func NewSource(t MigrationType, fullpath string, version int64) Source { return Source{ Type: t, Fullpath: fullpath, @@ -41,6 +25,7 @@ type fileSources struct { goSources []Source } +// TODO(mf): remove? func (s *fileSources) lookup(t MigrationType, version int64) *Source { switch t { case TypeGo: @@ -93,7 +78,7 @@ func collectFileSources(fsys fs.FS, strict bool, excludes map[string]bool) (*fil // filenames, but still have versioned migrations within the same directory. For // example, a user could have a helpers.go file which contains unexported helper // functions for migrations. - version, err := goose.NumericComponent(base) + version, err := NumericComponent(base) if err != nil { if strict { return nil, fmt.Errorf("failed to parse numeric component from %q: %w", base, err) @@ -110,9 +95,9 @@ func collectFileSources(fsys fs.FS, strict bool, excludes map[string]bool) (*fil } switch filepath.Ext(base) { case ".sql": - sources.sqlSources = append(sources.sqlSources, newSource(TypeSQL, fullpath, version)) + sources.sqlSources = append(sources.sqlSources, NewSource(TypeSQL, fullpath, version)) case ".go": - sources.goSources = append(sources.goSources, newSource(TypeGo, fullpath, version)) + sources.goSources = append(sources.goSources, NewSource(TypeGo, fullpath, version)) default: // Should never happen since we already filtered out all other file types. return nil, fmt.Errorf("unknown migration type: %s", base) @@ -161,12 +146,15 @@ func merge(sources *fileSources, registerd map[int64]*goMigration) ([]*migration // wholesale as part of migrations. This allows users to build a custom binary that only embeds // the SQL migration files. for version, r := range registerd { - var fullpath string - if s := sources.lookup(TypeGo, version); s != nil { - fullpath = s.Fullpath + fullpath := r.fullpath + if fullpath == "" { + if s := sources.lookup(TypeGo, version); s != nil { + fullpath = s.Fullpath + } } // Ensure there are no duplicate versions. if existing, ok := migrationLookup[version]; ok { + fullpath := r.fullpath if fullpath == "" { fullpath = "manually registered (no source)" } @@ -178,7 +166,7 @@ func merge(sources *fileSources, registerd map[int64]*goMigration) ([]*migration } m := &migration{ // Note, the fullpath may be empty if the migration was registered manually. - Source: newSource(TypeGo, fullpath, version), + Source: NewSource(TypeGo, fullpath, version), Go: r, } migrations = append(migrations, m) @@ -211,3 +199,34 @@ func unregisteredError(unregistered []string) error { return errors.New(b.String()) } + +type noopFS struct{} + +var _ fs.FS = noopFS{} + +func (f noopFS) Open(name string) (fs.File, error) { + return nil, os.ErrNotExist +} + +// NumericComponent parses the version from the migration file name. +// +// XXX_descriptivename.ext where XXX specifies the version number and ext specifies the type of +// migration, either .sql or .go. 
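Since NumericComponent is exported here, a few worked inputs may help pin down the naming rules in the comment above (the implementation follows directly below); the filenames are made up for illustration:

package main

import (
	"fmt"

	"github.com/pressly/goose/v3/internal/provider"
)

func main() {
	v, err := provider.NumericComponent("00042_add_users_table.sql")
	fmt.Println(v, err) // 42 <nil>

	_, err = provider.NumericComponent("helpers.go")
	fmt.Println(err) // no filename separator '_' found

	_, err = provider.NumericComponent("00000_initial.sql")
	fmt.Println(err) // migration version must be greater than zero
}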
+func NumericComponent(filename string) (int64, error) { + base := filepath.Base(filename) + if ext := filepath.Ext(base); ext != ".go" && ext != ".sql" { + return 0, errors.New("migration file does not have .sql or .go file extension") + } + idx := strings.Index(base, "_") + if idx < 0 { + return 0, errors.New("no filename separator '_' found") + } + n, err := strconv.ParseInt(base[:idx], 10, 64) + if err != nil { + return 0, err + } + if n < 1 { + return 0, errors.New("migration version must be greater than zero") + } + return n, nil +} diff --git a/internal/provider/collect_test.go b/internal/provider/collect_test.go index 401a1ce40..73b2642c5 100644 --- a/internal/provider/collect_test.go +++ b/internal/provider/collect_test.go @@ -47,10 +47,10 @@ func TestCollectFileSources(t *testing.T) { check.Number(t, len(sources.goSources), 0) expected := fileSources{ sqlSources: []Source{ - newSource(TypeSQL, "00001_foo.sql", 1), - newSource(TypeSQL, "00002_bar.sql", 2), - newSource(TypeSQL, "00003_baz.sql", 3), - newSource(TypeSQL, "00110_qux.sql", 110), + NewSource(TypeSQL, "00001_foo.sql", 1), + NewSource(TypeSQL, "00002_bar.sql", 2), + NewSource(TypeSQL, "00003_baz.sql", 3), + NewSource(TypeSQL, "00110_qux.sql", 110), }, } for i := 0; i < len(sources.sqlSources); i++ { @@ -74,8 +74,8 @@ func TestCollectFileSources(t *testing.T) { check.Number(t, len(sources.goSources), 0) expected := fileSources{ sqlSources: []Source{ - newSource(TypeSQL, "00001_foo.sql", 1), - newSource(TypeSQL, "00003_baz.sql", 3), + NewSource(TypeSQL, "00001_foo.sql", 1), + NewSource(TypeSQL, "00003_baz.sql", 3), }, } for i := 0; i < len(sources.sqlSources); i++ { @@ -159,15 +159,15 @@ func TestCollectFileSources(t *testing.T) { } } assertDirpath(".", []Source{ - newSource(TypeSQL, "876_a.sql", 876), + NewSource(TypeSQL, "876_a.sql", 876), }) assertDirpath("dir1", []Source{ - newSource(TypeSQL, "101_a.sql", 101), - newSource(TypeSQL, "102_b.sql", 102), - newSource(TypeSQL, "103_c.sql", 103), + NewSource(TypeSQL, "101_a.sql", 101), + NewSource(TypeSQL, "102_b.sql", 102), + NewSource(TypeSQL, "103_c.sql", 103), }) assertDirpath("dir2", []Source{ - newSource(TypeSQL, "201_a.sql", 201), + NewSource(TypeSQL, "201_a.sql", 201), }) assertDirpath("dir3", nil) }) @@ -199,14 +199,14 @@ func TestMerge(t *testing.T) { t.Run("valid", func(t *testing.T) { migrations, err := merge(sources, map[int64]*goMigration{ - 2: {version: 2}, - 3: {version: 3}, + 2: newGoMigration("", nil, nil), + 3: newGoMigration("", nil, nil), }) check.NoError(t, err) check.Number(t, len(migrations), 3) - assertMigration(t, migrations[0], newSource(TypeSQL, "00001_foo.sql", 1)) - assertMigration(t, migrations[1], newSource(TypeGo, "00002_bar.go", 2)) - assertMigration(t, migrations[2], newSource(TypeGo, "00003_baz.go", 3)) + assertMigration(t, migrations[0], NewSource(TypeSQL, "00001_foo.sql", 1)) + assertMigration(t, migrations[1], NewSource(TypeGo, "00002_bar.go", 2)) + assertMigration(t, migrations[2], NewSource(TypeGo, "00003_baz.go", 3)) }) t.Run("unregistered_all", func(t *testing.T) { _, err := merge(sources, nil) @@ -217,7 +217,7 @@ func TestMerge(t *testing.T) { }) t.Run("unregistered_some", func(t *testing.T) { _, err := merge(sources, map[int64]*goMigration{ - 2: {version: 2}, + 2: newGoMigration("", nil, nil), }) check.HasError(t, err) check.Contains(t, err.Error(), "error: detected 1 unregistered Go file") @@ -225,9 +225,9 @@ func TestMerge(t *testing.T) { }) t.Run("duplicate_sql", func(t *testing.T) { _, err := merge(sources, 
map[int64]*goMigration{ - 1: {version: 1}, // duplicate. SQL already exists. - 2: {version: 2}, - 3: {version: 3}, + 1: newGoMigration("", nil, nil), // duplicate. SQL already exists. + 2: newGoMigration("", nil, nil), + 3: newGoMigration("", nil, nil), }) check.HasError(t, err) check.Contains(t, err.Error(), "found duplicate migration version 1") @@ -246,17 +246,17 @@ func TestMerge(t *testing.T) { check.NoError(t, err) t.Run("unregistered_all", func(t *testing.T) { migrations, err := merge(sources, map[int64]*goMigration{ - 3: {version: 3}, + 3: newGoMigration("", nil, nil), // 4 is missing - 6: {version: 6}, + 6: newGoMigration("", nil, nil), }) check.NoError(t, err) check.Number(t, len(migrations), 5) - assertMigration(t, migrations[0], newSource(TypeSQL, "00001_foo.sql", 1)) - assertMigration(t, migrations[1], newSource(TypeSQL, "00002_bar.sql", 2)) - assertMigration(t, migrations[2], newSource(TypeGo, "", 3)) - assertMigration(t, migrations[3], newSource(TypeSQL, "00005_baz.sql", 5)) - assertMigration(t, migrations[4], newSource(TypeGo, "", 6)) + assertMigration(t, migrations[0], NewSource(TypeSQL, "00001_foo.sql", 1)) + assertMigration(t, migrations[1], NewSource(TypeSQL, "00002_bar.sql", 2)) + assertMigration(t, migrations[2], NewSource(TypeGo, "", 3)) + assertMigration(t, migrations[3], NewSource(TypeSQL, "00005_baz.sql", 5)) + assertMigration(t, migrations[4], NewSource(TypeGo, "", 6)) }) }) t.Run("partial_go_files_on_disk", func(t *testing.T) { @@ -271,17 +271,17 @@ func TestMerge(t *testing.T) { t.Run("unregistered_all", func(t *testing.T) { migrations, err := merge(sources, map[int64]*goMigration{ // This is the only Go file on disk. - 2: {version: 2}, + 2: newGoMigration("", nil, nil), // These are not on disk. Explicitly registered. - 3: {version: 3}, - 6: {version: 6}, + 3: newGoMigration("", nil, nil), + 6: newGoMigration("", nil, nil), }) check.NoError(t, err) check.Number(t, len(migrations), 4) - assertMigration(t, migrations[0], newSource(TypeSQL, "00001_foo.sql", 1)) - assertMigration(t, migrations[1], newSource(TypeGo, "00002_bar.go", 2)) - assertMigration(t, migrations[2], newSource(TypeGo, "", 3)) - assertMigration(t, migrations[3], newSource(TypeGo, "", 6)) + assertMigration(t, migrations[0], NewSource(TypeSQL, "00001_foo.sql", 1)) + assertMigration(t, migrations[1], NewSource(TypeGo, "00002_bar.go", 2)) + assertMigration(t, migrations[2], NewSource(TypeGo, "", 3)) + assertMigration(t, migrations[3], NewSource(TypeGo, "", 6)) }) }) } @@ -291,7 +291,7 @@ func assertMigration(t *testing.T, got *migration, want Source) { check.Equal(t, got.Source, want) switch got.Source.Type { case TypeGo: - check.Equal(t, got.Go.version, want.Version) + check.Bool(t, got.Go != nil, true) case TypeSQL: check.Bool(t, got.SQL == nil, true) default: diff --git a/internal/provider/errors.go b/internal/provider/errors.go new file mode 100644 index 000000000..e8ece3871 --- /dev/null +++ b/internal/provider/errors.go @@ -0,0 +1,39 @@ +package provider + +import ( + "errors" + "fmt" + "path/filepath" +) + +var ( + // ErrVersionNotFound when a migration version is not found. + ErrVersionNotFound = errors.New("version not found") + + // ErrAlreadyApplied when a migration has already been applied. + ErrAlreadyApplied = errors.New("already applied") + + // ErrNoMigrations is returned by [NewProvider] when no migrations are found. + ErrNoMigrations = errors.New("no migrations found") + + // ErrNoNextVersion when the next migration version is not found. 
+ ErrNoNextVersion = errors.New("no next version found") +) + +// PartialError is returned when a migration fails, but some migrations already got applied. +type PartialError struct { + // Applied are migrations that were applied successfully before the error occurred. + Applied []*MigrationResult + // Failed contains the result of the migration that failed. + Failed *MigrationResult + // Err is the error that occurred while running the migration. + Err error +} + +func (e *PartialError) Error() string { + filename := "(file unknown)" + if e.Failed != nil && e.Failed.Source.Fullpath != "" { + filename = fmt.Sprintf("(%s)", filepath.Base(e.Failed.Source.Fullpath)) + } + return fmt.Sprintf("partial migration error %s (%d): %v", filename, e.Failed.Source.Version, e.Err) +} diff --git a/internal/provider/migration.go b/internal/provider/migration.go index cf98abc3e..87098cf22 100644 --- a/internal/provider/migration.go +++ b/internal/provider/migration.go @@ -3,8 +3,8 @@ package provider import ( "context" "database/sql" - "errors" "fmt" + "path/filepath" "github.com/pressly/goose/v3/internal/sqlextended" ) @@ -24,36 +24,83 @@ type migration struct { SQL *sqlMigration } -type MigrationType int +func (m *migration) useTx(direction bool) bool { + switch m.Source.Type { + case TypeSQL: + return m.SQL.UseTx + case TypeGo: + if m.Go == nil { + return false + } + if direction { + return m.Go.up.Run != nil + } + return m.Go.down.Run != nil + } + // This should never happen. + return false +} -const ( - TypeGo MigrationType = iota + 1 - TypeSQL -) +func (m *migration) filename() string { + return filepath.Base(m.Source.Fullpath) +} -func (t MigrationType) String() string { - switch t { - case TypeGo: - return "go" +// run runs the migration inside of a transaction. +func (m *migration) run(ctx context.Context, tx *sql.Tx, direction bool) error { + switch m.Source.Type { case TypeSQL: - return "sql" - default: - // This should never happen. - return fmt.Sprintf("unknown (%d)", t) + if m.SQL == nil { + return fmt.Errorf("tx: sql migration has not been parsed") + } + return m.SQL.run(ctx, tx, direction) + case TypeGo: + return m.Go.run(ctx, tx, direction) } + // This should never happen. + return fmt.Errorf("tx: failed to run migration %s: neither sql or go", filepath.Base(m.Source.Fullpath)) } -func (m *migration) GetSQLStatements(direction bool) ([]string, error) { - if m.Source.Type != TypeSQL { - return nil, fmt.Errorf("expected sql migration, got %s: no sql statements", m.Source.Type) +// runNoTx runs the migration without a transaction. +func (m *migration) runNoTx(ctx context.Context, db *sql.DB, direction bool) error { + switch m.Source.Type { + case TypeSQL: + if m.SQL == nil { + return fmt.Errorf("db: sql migration has not been parsed") + } + return m.SQL.run(ctx, db, direction) + case TypeGo: + return m.Go.runNoTx(ctx, db, direction) } - if m.SQL == nil { - return nil, errors.New("sql migration has not been parsed") + // This should never happen. + return fmt.Errorf("db: failed to run migration %s: neither sql or go", filepath.Base(m.Source.Fullpath)) +} + +// runConn runs the migration without a transaction using the provided connection. 
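A sketch of how a caller might act on a PartialError once the run helpers later in this commit are in place. The Up signature is inferred from run_up.go and the Applied/Failed fields above, so treat it as an assumption rather than a settled API; the package also still lives under internal/:

package example

import (
	"context"
	"errors"
	"log"

	"github.com/pressly/goose/v3/internal/provider"
)

// applyAll applies every pending migration and, on failure, reports how far the
// run got before stopping.
func applyAll(ctx context.Context, p *provider.Provider) error {
	_, err := p.Up(ctx) // assumed: returns the applied results and an error
	if err == nil {
		return nil
	}
	var partial *provider.PartialError
	if errors.As(err, &partial) && partial.Failed != nil {
		log.Printf("applied %d migration(s); version %d failed: %v",
			len(partial.Applied), partial.Failed.Source.Version, partial.Err)
	}
	return err
}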
+func (m *migration) runConn(ctx context.Context, conn *sql.Conn, direction bool) error { + switch m.Source.Type { + case TypeSQL: + if m.SQL == nil { + return fmt.Errorf("conn: sql migration has not been parsed") + } + return m.SQL.run(ctx, conn, direction) + case TypeGo: + return fmt.Errorf("conn: go migrations are not supported with *sql.Conn") } - if direction { - return m.SQL.UpStatements, nil + // This should never happen. + return fmt.Errorf("conn: failed to run migration %s: neither sql or go", filepath.Base(m.Source.Fullpath)) +} + +type goMigration struct { + fullpath string + up, down *GoMigration +} + +func newGoMigration(fullpath string, up, down *GoMigration) *goMigration { + return &goMigration{ + fullpath: fullpath, + up: up, + down: down, } - return m.SQL.DownStatements, nil } func (g *goMigration) run(ctx context.Context, tx *sql.Tx, direction bool) error { diff --git a/internal/provider/misc.go b/internal/provider/misc.go new file mode 100644 index 000000000..be84b4622 --- /dev/null +++ b/internal/provider/misc.go @@ -0,0 +1,39 @@ +package provider + +import ( + "context" + "database/sql" + "errors" + "fmt" +) + +type Migration struct { + Version int64 + Source string // path to .sql script or go file + Registered bool + UseTx bool + UpFnContext func(context.Context, *sql.Tx) error + DownFnContext func(context.Context, *sql.Tx) error + + UpFnNoTxContext func(context.Context, *sql.DB) error + DownFnNoTxContext func(context.Context, *sql.DB) error +} + +var registeredGoMigrations = make(map[int64]*Migration) + +func SetGlobalGoMigrations(migrations []*Migration) error { + for _, m := range migrations { + if m == nil { + return errors.New("cannot register nil go migration") + } + if _, ok := registeredGoMigrations[m.Version]; ok { + return fmt.Errorf("go migration with version %d already registered", m.Version) + } + registeredGoMigrations[m.Version] = m + } + return nil +} + +func ResetGlobalGoMigrations() { + registeredGoMigrations = make(map[int64]*Migration) +} diff --git a/internal/provider/provider.go b/internal/provider/provider.go index 7d5085069..3982ac37b 100644 --- a/internal/provider/provider.go +++ b/internal/provider/provider.go @@ -6,21 +6,12 @@ import ( "errors" "fmt" "io/fs" - "os" - "time" + "math" + "sync" - "github.com/pressly/goose/v3" "github.com/pressly/goose/v3/internal/sqladapter" - "github.com/pressly/goose/v3/internal/sqlparser" ) -var ( - // ErrNoMigrations is returned by [NewProvider] when no migrations are found. - ErrNoMigrations = errors.New("no migrations found") -) - -var registeredGoMigrations = make(map[int64]*goose.Migration) - // NewProvider returns a new goose Provider. // // The caller is responsible for matching the database dialect with the database/sql driver. For @@ -36,7 +27,7 @@ var registeredGoMigrations = make(map[int64]*goose.Migration) // Unless otherwise specified, all methods on Provider are safe for concurrent use. // // Experimental: This API is experimental and may change in the future. 
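SetGlobalGoMigrations and ResetGlobalGoMigrations (misc.go above) keep the legacy init()-style registration path available. A minimal sketch of how the tests in this patch use the pair, setting only fields defined on the Migration struct above; the SQL strings are placeholders:

package provider_test

import (
	"context"
	"database/sql"
	"testing"

	"github.com/pressly/goose/v3/internal/provider"
)

func setupGlobalMigration(t *testing.T) {
	t.Helper()
	err := provider.SetGlobalGoMigrations([]*provider.Migration{
		{
			Version:    1,
			Source:     "00001_users_table.go",
			Registered: true,
			UseTx:      true,
			UpFnContext: func(ctx context.Context, tx *sql.Tx) error {
				_, err := tx.ExecContext(ctx, `CREATE TABLE users (id INTEGER PRIMARY KEY)`)
				return err
			},
			DownFnContext: func(ctx context.Context, tx *sql.Tx) error {
				_, err := tx.ExecContext(ctx, `DROP TABLE users`)
				return err
			},
		},
	})
	if err != nil {
		t.Fatal(err)
	}
	// Registration mutates package-level state, so always undo it when the test finishes.
	t.Cleanup(provider.ResetGlobalGoMigrations)
}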
-func NewProvider(dialect string, db *sql.DB, fsys fs.FS, opts ...ProviderOption) (*Provider, error) { +func NewProvider(dialect Dialect, db *sql.DB, fsys fs.FS, opts ...ProviderOption) (*Provider, error) { if db == nil { return nil, errors.New("db must not be nil") } @@ -46,7 +37,9 @@ func NewProvider(dialect string, db *sql.DB, fsys fs.FS, opts ...ProviderOption) if fsys == nil { fsys = noopFS{} } - var cfg config + cfg := config{ + registered: make(map[int64]*goMigration), + } for _, opt := range opts { if err := opt.apply(&cfg); err != nil { return nil, err @@ -54,9 +47,9 @@ func NewProvider(dialect string, db *sql.DB, fsys fs.FS, opts ...ProviderOption) } // Set defaults after applying user-supplied options so option funcs can check for empty values. if cfg.tableName == "" { - cfg.tableName = defaultTablename + cfg.tableName = DefaultTablename } - store, err := sqladapter.NewStore(dialect, cfg.tableName) + store, err := sqladapter.NewStore(string(dialect), cfg.tableName) if err != nil { return nil, err } @@ -78,11 +71,7 @@ func NewProvider(dialect string, db *sql.DB, fsys fs.FS, opts ...ProviderOption) registered := make(map[int64]*goMigration) // Add user-registered Go migrations. for version, m := range cfg.registered { - registered[version] = &goMigration{ - version: version, - up: m.up, - down: m.down, - } + registered[version] = newGoMigration("", m.up, m.down) } // Add init() functions. This is a bit ugly because we need to convert from the old Migration // struct to the new goMigration struct. @@ -90,12 +79,10 @@ func NewProvider(dialect string, db *sql.DB, fsys fs.FS, opts ...ProviderOption) if _, ok := registered[version]; ok { return nil, fmt.Errorf("go migration with version %d already registered", version) } - g := &goMigration{ - version: version, - } if m == nil { return nil, errors.New("registered migration with nil init function") } + g := newGoMigration(m.Source, nil, nil) if m.UpFnContext != nil && m.UpFnNoTxContext != nil { return nil, errors.New("registered migration with both UpFnContext and UpFnNoTxContext") } @@ -140,16 +127,12 @@ func NewProvider(dialect string, db *sql.DB, fsys fs.FS, opts ...ProviderOption) }, nil } -type noopFS struct{} - -var _ fs.FS = noopFS{} - -func (f noopFS) Open(name string) (fs.File, error) { - return nil, os.ErrNotExist -} - // Provider is a goose migration provider. type Provider struct { + // mu protects all accesses to the provider and must be held when calling operations on the + // database. + mu sync.Mutex + db *sql.DB fsys fs.FS cfg config @@ -157,48 +140,27 @@ type Provider struct { migrations []*migration } -// State represents the state of a migration. -type State string - -const ( - // StateUntracked represents a migration that is in the database, but not on the filesystem. - StateUntracked State = "untracked" - // StatePending represents a migration that is on the filesystem, but not in the database. - StatePending State = "pending" - // StateApplied represents a migration that is in BOTH the database and on the filesystem. - StateApplied State = "applied" -) - -// MigrationStatus represents the status of a single migration. -type MigrationStatus struct { - // State is the state of the migration. - State State - // AppliedAt is the time the migration was applied. Only set if state is [StateApplied] or - // [StateUntracked]. - AppliedAt time.Time - // Source is the migration source. Only set if the state is [StatePending] or [StateApplied]. 
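Putting the NewProvider changes above together: the dialect is now a typed value (the tests use DialectSQLite3 and DialectPostgres), a nil fsys falls back to noopFS for Go-only setups, and options are applied before the DefaultTablename fallback. A hedged construction sketch; the file names and output format are illustrative, not part of the patch:

package main

import (
	"database/sql"
	"fmt"
	"log"
	"os"

	"github.com/pressly/goose/v3/internal/provider"
	_ "modernc.org/sqlite"
)

func main() {
	db, err := sql.Open("sqlite", "example.db")
	if err != nil {
		log.Fatal(err)
	}
	p, err := provider.NewProvider(provider.DialectSQLite3, db, os.DirFS("migrations"),
		provider.WithTableName(provider.DefaultTablename), // same as omitting the option
		provider.WithVerbose(true),
	)
	if err != nil {
		log.Fatal(err)
	}
	for _, s := range p.ListSources() {
		fmt.Printf("%v\t%s\tversion %d\n", s.Type, s.Fullpath, s.Version)
	}
}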
- Source *Source -} - // Status returns the status of all migrations, merging the list of migrations from the database and // filesystem. The returned items are ordered by version, in ascending order. func (p *Provider) Status(ctx context.Context) ([]*MigrationStatus, error) { - return nil, errors.New("not implemented") + return p.status(ctx) } // GetDBVersion returns the max version from the database, regardless of the applied order. For // example, if migrations 1,4,2,3 were applied, this method returns 4. If no migrations have been // applied, it returns 0. +// +// TODO(mf): this is not true? func (p *Provider) GetDBVersion(ctx context.Context) (int64, error) { - return 0, errors.New("not implemented") + return p.getDBVersion(ctx) } // ListSources returns a list of all available migration sources the provider is aware of, sorted in // ascending order by version. -func (p *Provider) ListSources() []*Source { - sources := make([]*Source, 0, len(p.migrations)) +func (p *Provider) ListSources() []Source { + sources := make([]Source, 0, len(p.migrations)) for _, m := range p.migrations { - sources = append(sources, &m.Source) + sources = append(sources, m.Source) } return sources } @@ -213,9 +175,6 @@ func (p *Provider) Close() error { return p.db.Close() } -// MigrationResult represents the result of a single migration. -type MigrationResult struct{} - // ApplyVersion applies exactly one migration at the specified version. If there is no source for // the specified version, this method returns [ErrNoCurrentVersion]. If the migration has been // applied already, this method returns [ErrAlreadyApplied]. @@ -223,19 +182,26 @@ type MigrationResult struct{} // When direction is true, the up migration is executed, and when direction is false, the down // migration is executed. func (p *Provider) ApplyVersion(ctx context.Context, version int64, direction bool) (*MigrationResult, error) { - return nil, errors.New("not implemented") + return p.apply(ctx, version, direction) } // Up applies all pending migrations. If there are no new migrations to apply, this method returns // empty list and nil error. func (p *Provider) Up(ctx context.Context) ([]*MigrationResult, error) { - return nil, errors.New("not implemented") + return p.up(ctx, false, math.MaxInt64) } // UpByOne applies the next available migration. If there are no migrations to apply, this method -// returns [ErrNoNextVersion]. -func (p *Provider) UpByOne(ctx context.Context) (*MigrationResult, error) { - return nil, errors.New("not implemented") +// returns [ErrNoNextVersion]. The returned list will always have exactly one migration result. +func (p *Provider) UpByOne(ctx context.Context) ([]*MigrationResult, error) { + res, err := p.up(ctx, true, math.MaxInt64) + if err != nil { + return nil, err + } + if len(res) == 0 { + return nil, ErrNoNextVersion + } + return res, nil } // UpTo applies all available migrations up to and including the specified version. If there are no @@ -244,13 +210,20 @@ func (p *Provider) UpByOne(ctx context.Context) (*MigrationResult, error) { // For instance, if there are three new migrations (9,10,11) and the current database version is 8 // with a requested version of 10, only versions 9 and 10 will be applied. func (p *Provider) UpTo(ctx context.Context, version int64) ([]*MigrationResult, error) { - return nil, errors.New("not implemented") + return p.up(ctx, false, version) } // Down rolls back the most recently applied migration. 
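UpByOne now returns a one-element slice and surfaces ErrNoNextVersion once the database is up to date, which makes the apply-one-at-a-time loop used by the tests below straightforward. A sketch, assuming p was built as shown earlier; the package name is a placeholder:

package example

import (
	"context"
	"errors"
	"log"

	"github.com/pressly/goose/v3/internal/provider"
)

func applyOneByOne(ctx context.Context, p *provider.Provider) error {
	for {
		res, err := p.UpByOne(ctx)
		if errors.Is(err, provider.ErrNoNextVersion) {
			break // nothing left to apply
		}
		if err != nil {
			return err
		}
		log.Printf("applied version %d in %s", res[0].Source.Version, res[0].Duration)
	}
	version, err := p.GetDBVersion(ctx)
	if err != nil {
		return err
	}
	log.Printf("database version: %d", version)
	// Status merges filesystem and database state; pending sources have no AppliedAt.
	statuses, err := p.Status(ctx)
	if err != nil {
		return err
	}
	for _, s := range statuses {
		log.Printf("%s\tversion %d", s.State, s.Source.Version)
	}
	return nil
}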
If there are no migrations to apply, this // method returns [ErrNoNextVersion]. -func (p *Provider) Down(ctx context.Context) (*MigrationResult, error) { - return nil, errors.New("not implemented") +func (p *Provider) Down(ctx context.Context) ([]*MigrationResult, error) { + res, err := p.down(ctx, true, 0) + if err != nil { + return nil, err + } + if len(res) == 0 { + return nil, ErrNoNextVersion + } + return res, nil } // DownTo rolls back all migrations down to but not including the specified version. @@ -258,27 +231,8 @@ func (p *Provider) Down(ctx context.Context) (*MigrationResult, error) { // For instance, if the current database version is 11, and the requested version is 9, only // migrations 11 and 10 will be rolled back. func (p *Provider) DownTo(ctx context.Context, version int64) ([]*MigrationResult, error) { - return nil, errors.New("not implemented") -} - -// ParseSQL parses all SQL migrations in BOTH directions. If a migration has already been parsed, it -// will not be parsed again. -// -// Important: This function will mutate SQL migrations and is not safe for concurrent use. -func ParseSQL(fsys fs.FS, debug bool, migrations []*migration) error { - for _, m := range migrations { - // If the migration is a SQL migration, and it has not been parsed, parse it. - if m.Source.Type == TypeSQL && m.SQL == nil { - parsed, err := sqlparser.ParseAllFromFS(fsys, m.Source.Fullpath, debug) - if err != nil { - return err - } - m.SQL = &sqlMigration{ - UseTx: parsed.UseTx, - UpStatements: parsed.Up, - DownStatements: parsed.Down, - } - } + if version < 0 { + return nil, fmt.Errorf("version must be a number greater than or equal zero: %d", version) } - return nil + return p.down(ctx, false, version) } diff --git a/internal/provider/provider_options.go b/internal/provider/provider_options.go index f3ed15b28..0b7cd7ad6 100644 --- a/internal/provider/provider_options.go +++ b/internal/provider/provider_options.go @@ -10,7 +10,7 @@ import ( ) const ( - defaultTablename = "goose_db_version" + DefaultTablename = "goose_db_version" ) // ProviderOption is a configuration option for a goose provider. @@ -35,9 +35,9 @@ func WithTableName(name string) ProviderOption { } // WithVerbose enables verbose logging. -func WithVerbose() ProviderOption { +func WithVerbose(b bool) ProviderOption { return configFunc(func(c *config) error { - c.verbose = true + c.verbose = b return nil }) } @@ -89,7 +89,7 @@ type GoMigration struct { func WithGoMigration(version int64, up, down *GoMigration) ProviderOption { return configFunc(func(c *config) error { if version < 1 { - return fmt.Errorf("go migration version must be greater than 0") + return errors.New("version must be greater than zero") } if _, ok := c.registered[version]; ok { return fmt.Errorf("go migration with version %d already registered", version) @@ -113,17 +113,33 @@ func WithGoMigration(version int64, up, down *GoMigration) ProviderOption { } } c.registered[version] = &goMigration{ - version: version, - up: up, - down: down, + up: up, + down: down, } return nil }) } -type goMigration struct { - version int64 - up, down *GoMigration +// WithAllowMissing allows the provider to apply missing (out-of-order) migrations. +// +// Example: migrations 1,6 are applied and then version 2,3,5 are introduced. If this option is +// true, then goose will apply 2,3,5 instead of raising an error. The final order of applied +// migrations will be: 1,6,2,3,5. 
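WithGoMigration registers a Go migration on a single provider instance without touching global state. The Run field's signature is not shown in this hunk; the sketch below assumes the func(context.Context, *sql.Tx) error shape implied by the useTx/run plumbing in migration.go and by the Migration struct in misc.go, and the SQL is a placeholder:

package example

import (
	"context"
	"database/sql"

	"github.com/pressly/goose/v3/internal/provider"
)

func newGoOnlyProvider(db *sql.DB) (*provider.Provider, error) {
	return provider.NewProvider(provider.DialectSQLite3, db, nil, // nil fsys: Go-only migrations
		provider.WithGoMigration(1,
			&provider.GoMigration{Run: func(ctx context.Context, tx *sql.Tx) error {
				_, err := tx.ExecContext(ctx, `CREATE TABLE users (id INTEGER PRIMARY KEY)`)
				return err
			}},
			&provider.GoMigration{Run: func(ctx context.Context, tx *sql.Tx) error {
				_, err := tx.ExecContext(ctx, `DROP TABLE users`)
				return err
			}},
		),
	)
}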
+func WithAllowMissing(b bool) ProviderOption { + return configFunc(func(c *config) error { + c.allowMissing = b + return nil + }) +} + +// WithNoVersioning disables versioning. Disabling versioning allows the ability to apply migrations +// without tracking the versions in the database schema table. Useful for tests, seeding a database +// or running ad-hoc queries. +func WithNoVersioning(b bool) ProviderOption { + return configFunc(func(c *config) error { + c.noVersioning = b + return nil + }) } type config struct { @@ -138,6 +154,10 @@ type config struct { // Locking options lockEnabled bool sessionLocker lock.SessionLocker + + // Feature + noVersioning bool + allowMissing bool } type configFunc func(*config) error diff --git a/internal/provider/provider_options_test.go b/internal/provider/provider_options_test.go index 89a1cda16..2271111ba 100644 --- a/internal/provider/provider_options_test.go +++ b/internal/provider/provider_options_test.go @@ -59,7 +59,7 @@ func TestNewProvider(t *testing.T) { check.NoError(t, err) // Valid dialect, db, fsys, and verbose allowed _, err = provider.NewProvider("sqlite3", db, fsys, - provider.WithVerbose(), + provider.WithVerbose(testing.Verbose()), ) check.NoError(t, err) }) diff --git a/internal/provider/provider_test.go b/internal/provider/provider_test.go index c8b5effe3..ac4ec7e0e 100644 --- a/internal/provider/provider_test.go +++ b/internal/provider/provider_test.go @@ -1,6 +1,7 @@ package provider_test import ( + "context" "database/sql" "errors" "io/fs" @@ -33,14 +34,68 @@ func TestProvider(t *testing.T) { check.NoError(t, err) sources := p.ListSources() check.Equal(t, len(sources), 2) - // 1 - check.Equal(t, sources[0].Version, int64(1)) - check.Equal(t, sources[0].Fullpath, "001_foo.sql") - check.Equal(t, sources[0].Type, provider.TypeSQL) - // 2 - check.Equal(t, sources[1].Version, int64(2)) - check.Equal(t, sources[1].Fullpath, "002_bar.sql") - check.Equal(t, sources[1].Type, provider.TypeSQL) + check.Equal(t, sources[0], provider.NewSource(provider.TypeSQL, "001_foo.sql", 1)) + check.Equal(t, sources[1], provider.NewSource(provider.TypeSQL, "002_bar.sql", 2)) + + t.Run("duplicate_go", func(t *testing.T) { + // Not parallel because it modifies global state. 
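WithNoVersioning and WithAllowMissing are plain boolean toggles on the config above. The no-versioning path is what the seed test later in this patch exercises: migrations run, but the goose version table is left untouched, which suits seeding or ad-hoc scripts. A sketch with the directory layout borrowed from that test:

package example

import (
	"context"
	"database/sql"
	"os"

	"github.com/pressly/goose/v3/internal/provider"
)

func seed(ctx context.Context, db *sql.DB) error {
	seeder, err := provider.NewProvider(provider.DialectSQLite3, db,
		os.DirFS("testdata/no-versioning/seed"),
		provider.WithNoVersioning(true), // apply without recording versions
	)
	if err != nil {
		return err
	}
	if _, err := seeder.Up(ctx); err != nil { // insert seed rows
		return err
	}
	// Later, the same provider can remove the seed data again:
	_, err = seeder.DownTo(ctx, 0)
	return err
}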
+ register := []*provider.Migration{ + { + Version: 1, Source: "00001_users_table.go", Registered: true, UseTx: true, + UpFnContext: nil, + DownFnContext: nil, + }, + } + err := provider.SetGlobalGoMigrations(register) + check.NoError(t, err) + t.Cleanup(provider.ResetGlobalGoMigrations) + + db := newDB(t) + _, err = provider.NewProvider(provider.DialectSQLite3, db, nil, + provider.WithGoMigration(1, nil, nil), + ) + check.HasError(t, err) + check.Equal(t, err.Error(), "go migration with version 1 already registered") + }) + t.Run("empty_go", func(t *testing.T) { + db := newDB(t) + // explicit + _, err := provider.NewProvider(provider.DialectSQLite3, db, nil, + provider.WithGoMigration(1, &provider.GoMigration{Run: nil}, &provider.GoMigration{Run: nil}), + ) + check.HasError(t, err) + check.Contains(t, err.Error(), "go migration with version 1 must have an up function") + }) + t.Run("duplicate_up", func(t *testing.T) { + err := provider.SetGlobalGoMigrations([]*provider.Migration{ + { + Version: 1, Source: "00001_users_table.go", Registered: true, UseTx: true, + UpFnContext: func(context.Context, *sql.Tx) error { return nil }, + UpFnNoTxContext: func(ctx context.Context, db *sql.DB) error { return nil }, + }, + }) + check.NoError(t, err) + t.Cleanup(provider.ResetGlobalGoMigrations) + db := newDB(t) + _, err = provider.NewProvider(provider.DialectSQLite3, db, nil) + check.HasError(t, err) + check.Contains(t, err.Error(), "registered migration with both UpFnContext and UpFnNoTxContext") + }) + t.Run("duplicate_down", func(t *testing.T) { + err := provider.SetGlobalGoMigrations([]*provider.Migration{ + { + Version: 1, Source: "00001_users_table.go", Registered: true, UseTx: true, + DownFnContext: func(context.Context, *sql.Tx) error { return nil }, + DownFnNoTxContext: func(ctx context.Context, db *sql.DB) error { return nil }, + }, + }) + check.NoError(t, err) + t.Cleanup(provider.ResetGlobalGoMigrations) + db := newDB(t) + _, err = provider.NewProvider(provider.DialectSQLite3, db, nil) + check.HasError(t, err) + check.Contains(t, err.Error(), "registered migration with both DownFnContext and DownFnNoTxContext") + }) } var ( diff --git a/internal/provider/run.go b/internal/provider/run.go index f5ca25038..55bef9f32 100644 --- a/internal/provider/run.go +++ b/internal/provider/run.go @@ -3,51 +3,373 @@ package provider import ( "context" "database/sql" + "errors" "fmt" - "path/filepath" + "io/fs" + "sort" + "strings" + "time" + + "github.com/pressly/goose/v3/internal/sqladapter" + "github.com/pressly/goose/v3/internal/sqlparser" + "go.uber.org/multierr" ) -// Run runs the migration inside of a transaction. -func (m *migration) Run(ctx context.Context, tx *sql.Tx, direction bool) error { - switch m.Source.Type { - case TypeSQL: - if m.SQL == nil { - return fmt.Errorf("tx: sql migration has not been parsed") +// runMigrations runs migrations sequentially in the given direction. +// +// If the migrations slice is empty, this function returns nil with no error. +func (p *Provider) runMigrations( + ctx context.Context, + conn *sql.Conn, + migrations []*migration, + direction sqlparser.Direction, + byOne bool, +) ([]*MigrationResult, error) { + if len(migrations) == 0 { + return nil, nil + } + var apply []*migration + if byOne { + apply = []*migration{migrations[0]} + } else { + apply = migrations + } + // Lazily parse SQL migrations (if any) in both directions. 
We do this before running any + // migrations so that we can fail fast if there are any errors and avoid leaving the database in + // a partially migrated state. + + if err := parseSQL(p.fsys, false, apply); err != nil { + return nil, err + } + + // TODO(mf): If we decide to add support for advisory locks at the transaction level, this may + // be a good place to acquire the lock. However, we need to be sure that ALL migrations are safe + // to run in a transaction. + + // + // + // + + // bug(mf): this is a potential deadlock scenario. We're running Go migrations with *sql.DB, but + // are locking the database with *sql.Conn. If the caller sets max open connections to 1, then + // this will deadlock because the Go migration will try to acquire a connection from the pool, + // but the pool is locked. + // + // A potential solution is to expose a third Go register function *sql.Conn. Or continue to use + // *sql.DB and document that the user SHOULD NOT SET max open connections to 1. This is a bit of + // an edge case. if p.opt.LockMode != LockModeNone && p.db.Stats().MaxOpenConnections == 1 { + // for _, m := range apply { + // if m.IsGo() && !m.Go.UseTx { + // return nil, errors.New("potential deadlock detected: cannot run GoMigrationNoTx with max open connections set to 1") + // } + // } + // } + + // Run migrations individually, opening a new transaction for each migration if the migration is + // safe to run in a transaction. + + // Avoid allocating a slice because we may have a partial migration error. 1. Avoid giving the + // impression that N migrations were applied when in fact some were not 2. Avoid the caller + // having to check for nil results + var results []*MigrationResult + for _, m := range apply { + current := &MigrationResult{ + Source: m.Source, + Direction: strings.ToLower(direction.String()), + // TODO(mf): empty set here } - return m.SQL.run(ctx, tx, direction) - case TypeGo: - return m.Go.run(ctx, tx, direction) + + start := time.Now() + if err := p.runIndividually(ctx, conn, direction.ToBool(), m); err != nil { + // TODO(mf): we should also return the pending migrations here. + current.Error = err + current.Duration = time.Since(start) + return nil, &PartialError{ + Applied: results, + Failed: current, + Err: err, + } + } + + current.Duration = time.Since(start) + results = append(results, current) } - // This should never happen. - return fmt.Errorf("tx: failed to run migration %s: neither sql or go", filepath.Base(m.Source.Fullpath)) + return results, nil } -// RunNoTx runs the migration without a transaction. -func (m *migration) RunNoTx(ctx context.Context, db *sql.DB, direction bool) error { +// runIndividually runs an individual migration, opening a new transaction if the migration is safe +// to run in a transaction. Otherwise, it runs the migration outside of a transaction with the +// supplied connection. +func (p *Provider) runIndividually( + ctx context.Context, + conn *sql.Conn, + direction bool, + m *migration, +) error { + if m.useTx(direction) { + // Run the migration in a transaction. + return p.beginTx(ctx, conn, func(tx *sql.Tx) error { + if err := m.run(ctx, tx, direction); err != nil { + return err + } + if p.cfg.noVersioning { + return nil + } + return p.store.InsertOrDelete(ctx, tx, direction, m.Source.Version) + }) + } + // Run the migration outside of a transaction. switch m.Source.Type { + case TypeGo: + // Note, we're using *sql.DB instead of *sql.Conn because it's the contract of the + // GoMigrationNoTx function. 
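The potential deadlock noted above (no-tx Go migrations running on *sql.DB while the provider pins its own *sql.Conn) can also be guarded against by the caller using nothing more than database/sql's pool stats. This mirrors the commented-out check and is not part of the patch:

package example

import "database/sql"

// ensurePoolHeadroom widens a single-connection pool so a no-tx Go migration
// cannot starve while the provider holds its own dedicated connection.
func ensurePoolHeadroom(db *sql.DB) {
	if db.Stats().MaxOpenConnections == 1 {
		db.SetMaxOpenConns(2) // alternatively, surface an error to the operator
	}
}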
This may be a deadlock scenario if the caller sets max open + // connections to 1. See the comment in runMigrations for more details. + if err := m.runNoTx(ctx, p.db, direction); err != nil { + return err + } case TypeSQL: - if m.SQL == nil { - return fmt.Errorf("db: sql migration has not been parsed") + if err := m.runConn(ctx, conn, direction); err != nil { + return err } - return m.SQL.run(ctx, db, direction) - case TypeGo: - return m.Go.runNoTx(ctx, db, direction) } - // This should never happen. - return fmt.Errorf("db: failed to run migration %s: neither sql or go", filepath.Base(m.Source.Fullpath)) + if p.cfg.noVersioning { + return nil + } + return p.store.InsertOrDelete(ctx, conn, direction, m.Source.Version) } -// RunConn runs the migration without a transaction using the provided connection. -func (m *migration) RunConn(ctx context.Context, conn *sql.Conn, direction bool) error { - switch m.Source.Type { - case TypeSQL: - if m.SQL == nil { - return fmt.Errorf("conn: sql migration has not been parsed") +// beginTx begins a transaction and runs the given function. If the function returns an error, the +// transaction is rolled back. Otherwise, the transaction is committed. +// +// If the provider is configured to use versioning, this function also inserts or deletes the +// migration version. +func (p *Provider) beginTx( + ctx context.Context, + conn *sql.Conn, + fn func(tx *sql.Tx) error, +) (retErr error) { + tx, err := conn.BeginTx(ctx, nil) + if err != nil { + return err + } + defer func() { + if retErr != nil { + retErr = multierr.Append(retErr, tx.Rollback()) } - return m.SQL.run(ctx, conn, direction) - case TypeGo: - return fmt.Errorf("conn: go migrations are not supported with *sql.Conn") + }() + if err := fn(tx); err != nil { + return err + } + return tx.Commit() +} + +func (p *Provider) initialize(ctx context.Context) (*sql.Conn, func() error, error) { + p.mu.Lock() + conn, err := p.db.Conn(ctx) + if err != nil { + p.mu.Unlock() + return nil, nil, err + } + // cleanup is a function that cleans up the connection, and optionally, the session lock. + cleanup := func() error { + p.mu.Unlock() + return conn.Close() + } + if l := p.cfg.sessionLocker; l != nil && p.cfg.lockEnabled { + if err := l.SessionLock(ctx, conn); err != nil { + return nil, nil, multierr.Append(err, cleanup()) + } + cleanup = func() error { + p.mu.Unlock() + // Use a detached context to unlock the session. This is because the context passed to + // SessionLock may have been canceled, and we don't want to cancel the unlock. + // TODO(mf): use [context.WithoutCancel] added in go1.21 + detachedCtx := context.Background() + return multierr.Append(l.SessionUnlock(detachedCtx, conn), conn.Close()) + } + } + // If versioning is enabled, ensure the version table exists. For ad-hoc migrations, we don't + // need the version table because there is no versioning. + if !p.cfg.noVersioning { + if err := p.ensureVersionTable(ctx, conn); err != nil { + return nil, nil, multierr.Append(err, cleanup()) + } + } + return conn, cleanup, nil +} + +// parseSQL parses all SQL migrations in BOTH directions. If a migration has already been parsed, it +// will not be parsed again. +// +// Important: This function will mutate SQL migrations and is not safe for concurrent use. +func parseSQL(fsys fs.FS, debug bool, migrations []*migration) error { + for _, m := range migrations { + // If the migration is a SQL migration, and it has not been parsed, parse it. 
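initialize above only acquires the session lock when a SessionLocker has been configured; the Postgres-backed test at the end of this section wires one up through the public lock package. A condensed sketch of that wiring (the "migrations" directory is a placeholder for the testdata path the test uses):

package example

import (
	"database/sql"
	"os"

	"github.com/pressly/goose/v3/internal/provider"
	"github.com/pressly/goose/v3/lock"
)

func newLockedProvider(db *sql.DB) (*provider.Provider, error) {
	sessionLocker, err := lock.NewPostgresSessionLocker()
	if err != nil {
		return nil, err
	}
	return provider.NewProvider(provider.DialectPostgres, db, os.DirFS("migrations"),
		// Advisory session lock: only one provider at a time applies migrations.
		provider.WithSessionLocker(sessionLocker),
	)
}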
+ if m.Source.Type == TypeSQL && m.SQL == nil { + parsed, err := sqlparser.ParseAllFromFS(fsys, m.Source.Fullpath, debug) + if err != nil { + return err + } + m.SQL = &sqlMigration{ + UseTx: parsed.UseTx, + UpStatements: parsed.Up, + DownStatements: parsed.Down, + } + } + } + return nil +} + +func (p *Provider) ensureVersionTable(ctx context.Context, conn *sql.Conn) (retErr error) { + // feat(mf): this is where we can check if the version table exists instead of trying to fetch + // from a table that may not exist. https://github.com/pressly/goose/issues/461 + res, err := p.store.GetMigration(ctx, conn, 0) + if err == nil && res != nil { + return nil + } + return p.beginTx(ctx, conn, func(tx *sql.Tx) error { + if err := p.store.CreateVersionTable(ctx, tx); err != nil { + return err + } + if p.cfg.noVersioning { + return nil + } + return p.store.InsertOrDelete(ctx, tx, true, 0) + }) +} + +type missingMigration struct { + versionID int64 + filename string +} + +// findMissingMigrations returns a list of migrations that are missing from the database. A missing +// migration is one that has a version less than the max version in the database. +func findMissingMigrations( + dbMigrations []*sqladapter.ListMigrationsResult, + fsMigrations []*migration, + dbMaxVersion int64, +) []missingMigration { + existing := make(map[int64]bool) + for _, m := range dbMigrations { + existing[m.Version] = true + } + var missing []missingMigration + for _, m := range fsMigrations { + version := m.Source.Version + if !existing[version] && version < dbMaxVersion { + missing = append(missing, missingMigration{ + versionID: version, + filename: m.filename(), + }) + } + } + sort.Slice(missing, func(i, j int) bool { + return missing[i].versionID < missing[j].versionID + }) + return missing +} + +// getMigration returns the migration with the given version. If no migration is found, then +// ErrVersionNotFound is returned. +func (p *Provider) getMigration(version int64) (*migration, error) { + for _, m := range p.migrations { + if m.Source.Version == version { + return m, nil + } + } + return nil, ErrVersionNotFound +} + +func (p *Provider) apply(ctx context.Context, version int64, direction bool) (_ *MigrationResult, retErr error) { + if version < 1 { + return nil, errors.New("version must be greater than zero") + } + + m, err := p.getMigration(version) + if err != nil { + return nil, err + } + + conn, cleanup, err := p.initialize(ctx) + if err != nil { + return nil, err + } + defer func() { + retErr = multierr.Append(retErr, cleanup()) + }() + + result, err := p.store.GetMigration(ctx, conn, version) + if err != nil && !errors.Is(err, sql.ErrNoRows) { + return nil, err + } + // If the migration has already been applied, return an error, unless the migration is being + // applied in the opposite direction. In that case, we allow the migration to be applied again. 
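findMissingMigrations only flags filesystem versions below the database max, which is exactly the out-of-order scenario the allow-missing tests construct later (1, 2, 3, 5 applied while 4 exists only on disk). An in-package sketch, assuming only the fields these functions actually touch; the literal *migration values are illustrative, in the provider they come from collecting sources:

package provider

import "github.com/pressly/goose/v3/internal/sqladapter"

func exampleFindMissing() []missingMigration {
	dbApplied := []*sqladapter.ListMigrationsResult{
		{Version: 5}, {Version: 3}, {Version: 2}, {Version: 1},
	}
	onDisk := []*migration{
		{Source: NewSource(TypeSQL, "00004_insert_data.sql", 4)},
		{Source: NewSource(TypeSQL, "00005_posts_view.sql", 5)},
	}
	// Returns a single entry: version 4 ("00004_insert_data.sql"), because it is
	// unapplied yet lower than the database max version of 5.
	return findMissingMigrations(dbApplied, onDisk, 5)
}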
+ if result != nil && direction { + return nil, fmt.Errorf("version %d: %w", version, ErrAlreadyApplied) + } + + d := sqlparser.DirectionDown + if direction { + d = sqlparser.DirectionUp + } + results, err := p.runMigrations(ctx, conn, []*migration{m}, d, true) + if err != nil { + return nil, err + } + if len(results) == 0 { + return nil, fmt.Errorf("version %d: %w", version, ErrAlreadyApplied) + } + return results[0], nil +} + +func (p *Provider) status(ctx context.Context) (_ []*MigrationStatus, retErr error) { + conn, cleanup, err := p.initialize(ctx) + if err != nil { + return nil, err + } + defer func() { + retErr = multierr.Append(retErr, cleanup()) + }() + + // TODO(mf): add support for limit and order. Also would be nice to refactor the list query to + // support limiting the set. + + status := make([]*MigrationStatus, 0, len(p.migrations)) + for _, m := range p.migrations { + migrationStatus := &MigrationStatus{ + Source: m.Source, + State: StatePending, + } + dbResult, err := p.store.GetMigration(ctx, conn, m.Source.Version) + if err != nil && !errors.Is(err, sql.ErrNoRows) { + return nil, err + } + if dbResult != nil { + migrationStatus.State = StateApplied + migrationStatus.AppliedAt = dbResult.Timestamp + } + status = append(status, migrationStatus) + } + + return status, nil +} + +func (p *Provider) getDBVersion(ctx context.Context) (_ int64, retErr error) { + conn, cleanup, err := p.initialize(ctx) + if err != nil { + return 0, err + } + defer func() { + retErr = multierr.Append(retErr, cleanup()) + }() + + res, err := p.store.ListMigrations(ctx, conn) + if err != nil { + return 0, err + } + if len(res) == 0 { + return 0, nil } - // This should never happen. - return fmt.Errorf("conn: failed to run migration %s: neither sql or go", filepath.Base(m.Source.Fullpath)) + return res[0].Version, nil } diff --git a/internal/provider/run_down.go b/internal/provider/run_down.go new file mode 100644 index 000000000..011ba7990 --- /dev/null +++ b/internal/provider/run_down.go @@ -0,0 +1,53 @@ +package provider + +import ( + "context" + + "github.com/pressly/goose/v3/internal/sqlparser" + "go.uber.org/multierr" +) + +func (p *Provider) down(ctx context.Context, downByOne bool, version int64) (_ []*MigrationResult, retErr error) { + conn, cleanup, err := p.initialize(ctx) + if err != nil { + return nil, err + } + defer func() { + retErr = multierr.Append(retErr, cleanup()) + }() + + if len(p.migrations) == 0 { + return nil, nil + } + + if p.cfg.noVersioning { + var downMigrations []*migration + if downByOne { + downMigrations = append(downMigrations, p.migrations[len(p.migrations)-1]) + } else { + downMigrations = p.migrations + } + return p.runMigrations(ctx, conn, downMigrations, sqlparser.DirectionDown, downByOne) + } + + dbMigrations, err := p.store.ListMigrations(ctx, conn) + if err != nil { + return nil, err + } + if dbMigrations[0].Version == 0 { + return nil, nil + } + + var downMigrations []*migration + for _, dbMigration := range dbMigrations { + if dbMigration.Version <= version { + break + } + m, err := p.getMigration(dbMigration.Version) + if err != nil { + return nil, err + } + downMigrations = append(downMigrations, m) + } + return p.runMigrations(ctx, conn, downMigrations, sqlparser.DirectionDown, downByOne) +} diff --git a/internal/provider/run_test.go b/internal/provider/run_test.go new file mode 100644 index 000000000..97e86ed21 --- /dev/null +++ b/internal/provider/run_test.go @@ -0,0 +1,1282 @@ +package provider_test + +import ( + "context" + "database/sql" + 
"errors" + "fmt" + "io/fs" + "math" + "math/rand" + "os" + "path/filepath" + "reflect" + "sort" + "sync" + "testing" + "testing/fstest" + + "github.com/pressly/goose/v3/internal/check" + "github.com/pressly/goose/v3/internal/provider" + "github.com/pressly/goose/v3/internal/testdb" + "github.com/pressly/goose/v3/lock" + "golang.org/x/sync/errgroup" +) + +func TestProviderRun(t *testing.T) { + t.Parallel() + + t.Run("closed_db", func(t *testing.T) { + p, db := newProviderWithDB(t) + check.NoError(t, db.Close()) + _, err := p.Up(context.Background()) + check.HasError(t, err) + check.Equal(t, err.Error(), "sql: database is closed") + }) + t.Run("ping_and_close", func(t *testing.T) { + p, _ := newProviderWithDB(t) + t.Cleanup(func() { + check.NoError(t, p.Close()) + }) + check.NoError(t, p.Ping(context.Background())) + }) + t.Run("apply_unknown_version", func(t *testing.T) { + p, _ := newProviderWithDB(t) + _, err := p.ApplyVersion(context.Background(), 999, true) + check.HasError(t, err) + check.Bool(t, errors.Is(err, provider.ErrVersionNotFound), true) + _, err = p.ApplyVersion(context.Background(), 999, false) + check.HasError(t, err) + check.Bool(t, errors.Is(err, provider.ErrVersionNotFound), true) + }) + t.Run("run_zero", func(t *testing.T) { + p, _ := newProviderWithDB(t) + _, err := p.UpTo(context.Background(), 0) + check.HasError(t, err) + check.Equal(t, err.Error(), "version must be greater than zero") + _, err = p.DownTo(context.Background(), -1) + check.HasError(t, err) + check.Equal(t, err.Error(), "version must be a number greater than or equal zero: -1") + _, err = p.ApplyVersion(context.Background(), 0, true) + check.HasError(t, err) + check.Equal(t, err.Error(), "version must be greater than zero") + }) + t.Run("up_and_down_all", func(t *testing.T) { + ctx := context.Background() + p, _ := newProviderWithDB(t) + const ( + numCount = 7 + ) + sources := p.ListSources() + check.Number(t, len(sources), numCount) + // Ensure only SQL migrations are returned + for _, s := range sources { + check.Equal(t, s.Type, provider.TypeSQL) + } + // Test Up + res, err := p.Up(ctx) + check.NoError(t, err) + check.Number(t, len(res), numCount) + assertResult(t, res[0], provider.NewSource(provider.TypeSQL, "00001_users_table.sql", 1), "up") + assertResult(t, res[1], provider.NewSource(provider.TypeSQL, "00002_posts_table.sql", 2), "up") + assertResult(t, res[2], provider.NewSource(provider.TypeSQL, "00003_comments_table.sql", 3), "up") + assertResult(t, res[3], provider.NewSource(provider.TypeSQL, "00004_insert_data.sql", 4), "up") + assertResult(t, res[4], provider.NewSource(provider.TypeSQL, "00005_posts_view.sql", 5), "up") + assertResult(t, res[5], provider.NewSource(provider.TypeSQL, "00006_empty_up.sql", 6), "up") + assertResult(t, res[6], provider.NewSource(provider.TypeSQL, "00007_empty_up_down.sql", 7), "up") + // Test Down + res, err = p.DownTo(ctx, 0) + check.NoError(t, err) + check.Number(t, len(res), numCount) + assertResult(t, res[0], provider.NewSource(provider.TypeSQL, "00007_empty_up_down.sql", 7), "down") + assertResult(t, res[1], provider.NewSource(provider.TypeSQL, "00006_empty_up.sql", 6), "down") + assertResult(t, res[2], provider.NewSource(provider.TypeSQL, "00005_posts_view.sql", 5), "down") + assertResult(t, res[3], provider.NewSource(provider.TypeSQL, "00004_insert_data.sql", 4), "down") + assertResult(t, res[4], provider.NewSource(provider.TypeSQL, "00003_comments_table.sql", 3), "down") + assertResult(t, res[5], provider.NewSource(provider.TypeSQL, 
"00002_posts_table.sql", 2), "down") + assertResult(t, res[6], provider.NewSource(provider.TypeSQL, "00001_users_table.sql", 1), "down") + }) + t.Run("up_and_down_by_one", func(t *testing.T) { + ctx := context.Background() + p, _ := newProviderWithDB(t) + maxVersion := len(p.ListSources()) + // Apply all migrations one-by-one. + var counter int + for { + res, err := p.UpByOne(ctx) + counter++ + if counter > maxVersion { + if !errors.Is(err, provider.ErrNoNextVersion) { + t.Fatalf("incorrect error: got:%v want:%v", err, provider.ErrNoNextVersion) + } + break + } + check.NoError(t, err) + check.Number(t, len(res), 1) + check.Number(t, res[0].Source.Version, int64(counter)) + } + currentVersion, err := p.GetDBVersion(ctx) + check.NoError(t, err) + check.Number(t, currentVersion, int64(maxVersion)) + // Reset counter + counter = 0 + // Rollback all migrations one-by-one. + for { + res, err := p.Down(ctx) + counter++ + if counter > maxVersion { + if !errors.Is(err, provider.ErrNoNextVersion) { + t.Fatalf("incorrect error: got:%v want:%v", err, provider.ErrNoNextVersion) + } + break + } + check.NoError(t, err) + check.Number(t, len(res), 1) + check.Number(t, res[0].Source.Version, int64(maxVersion-counter+1)) + } + // Once everything is tested the version should match the highest testdata version + currentVersion, err = p.GetDBVersion(ctx) + check.NoError(t, err) + check.Number(t, currentVersion, 0) + }) + t.Run("up_to", func(t *testing.T) { + ctx := context.Background() + p, db := newProviderWithDB(t) + const ( + upToVersion int64 = 2 + ) + results, err := p.UpTo(ctx, upToVersion) + check.NoError(t, err) + check.Number(t, len(results), upToVersion) + assertResult(t, results[0], provider.NewSource(provider.TypeSQL, "00001_users_table.sql", 1), "up") + assertResult(t, results[1], provider.NewSource(provider.TypeSQL, "00002_posts_table.sql", 2), "up") + // Fetch the goose version from DB + currentVersion, err := p.GetDBVersion(ctx) + check.NoError(t, err) + check.Number(t, currentVersion, upToVersion) + // Validate the version actually matches what goose claims it is + gotVersion, err := getMaxVersionID(db, provider.DefaultTablename) + check.NoError(t, err) + check.Number(t, gotVersion, upToVersion) + }) + t.Run("sql_connections", func(t *testing.T) { + tt := []struct { + name string + maxOpenConns int + maxIdleConns int + useDefaults bool + }{ + // Single connection ensures goose is able to function correctly when multiple + // connections are not available. + {name: "single_conn", maxOpenConns: 1, maxIdleConns: 1}, + {name: "defaults", useDefaults: true}, + } + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + ctx := context.Background() + // Start a new database for each test case. 
+ p, db := newProviderWithDB(t) + if !tc.useDefaults { + db.SetMaxOpenConns(tc.maxOpenConns) + db.SetMaxIdleConns(tc.maxIdleConns) + } + sources := p.ListSources() + check.NumberNotZero(t, len(sources)) + + currentVersion, err := p.GetDBVersion(ctx) + check.NoError(t, err) + check.Number(t, currentVersion, 0) + + { + // Apply all up migrations + upResult, err := p.Up(ctx) + check.NoError(t, err) + check.Number(t, len(upResult), len(sources)) + currentVersion, err := p.GetDBVersion(ctx) + check.NoError(t, err) + check.Number(t, currentVersion, p.ListSources()[len(sources)-1].Version) + // Validate the db migration version actually matches what goose claims it is + gotVersion, err := getMaxVersionID(db, provider.DefaultTablename) + check.NoError(t, err) + check.Number(t, gotVersion, currentVersion) + tables, err := getTableNames(db) + check.NoError(t, err) + if !reflect.DeepEqual(tables, knownTables) { + t.Logf("got tables: %v", tables) + t.Logf("known tables: %v", knownTables) + t.Fatal("failed to match tables") + } + } + { + // Apply all down migrations + downResult, err := p.DownTo(ctx, 0) + check.NoError(t, err) + check.Number(t, len(downResult), len(sources)) + gotVersion, err := getMaxVersionID(db, provider.DefaultTablename) + check.NoError(t, err) + check.Number(t, gotVersion, 0) + // Should only be left with a single table, the default goose table + tables, err := getTableNames(db) + check.NoError(t, err) + knownTables := []string{provider.DefaultTablename, "sqlite_sequence"} + if !reflect.DeepEqual(tables, knownTables) { + t.Logf("got tables: %v", tables) + t.Logf("known tables: %v", knownTables) + t.Fatal("failed to match tables") + } + } + }) + } + }) + t.Run("apply", func(t *testing.T) { + ctx := context.Background() + p, _ := newProviderWithDB(t) + sources := p.ListSources() + // Apply all migrations in the up direction. + for _, s := range sources { + res, err := p.ApplyVersion(ctx, s.Version, true) + check.NoError(t, err) + // Round-trip the migration result through the database to ensure it's valid. + assertResult(t, res, s, "up") + } + // Apply all migrations in the down direction. + for i := len(sources) - 1; i >= 0; i-- { + s := sources[i] + res, err := p.ApplyVersion(ctx, s.Version, false) + check.NoError(t, err) + // Round-trip the migration result through the database to ensure it's valid. + assertResult(t, res, s, "down") + } + // Try apply version 1 multiple times + _, err := p.ApplyVersion(ctx, 1, true) + check.NoError(t, err) + _, err = p.ApplyVersion(ctx, 1, true) + check.HasError(t, err) + check.Bool(t, errors.Is(err, provider.ErrAlreadyApplied), true) + check.Contains(t, err.Error(), "version 1: already applied") + }) + t.Run("status", func(t *testing.T) { + ctx := context.Background() + p, _ := newProviderWithDB(t) + numCount := len(p.ListSources()) + // Before any migrations are applied, the status should be empty. 
+ status, err := p.Status(ctx) + check.NoError(t, err) + check.Number(t, len(status), numCount) + assertStatus(t, status[0], provider.StatePending, provider.NewSource(provider.TypeSQL, "00001_users_table.sql", 1), true) + assertStatus(t, status[1], provider.StatePending, provider.NewSource(provider.TypeSQL, "00002_posts_table.sql", 2), true) + assertStatus(t, status[2], provider.StatePending, provider.NewSource(provider.TypeSQL, "00003_comments_table.sql", 3), true) + assertStatus(t, status[3], provider.StatePending, provider.NewSource(provider.TypeSQL, "00004_insert_data.sql", 4), true) + assertStatus(t, status[4], provider.StatePending, provider.NewSource(provider.TypeSQL, "00005_posts_view.sql", 5), true) + assertStatus(t, status[5], provider.StatePending, provider.NewSource(provider.TypeSQL, "00006_empty_up.sql", 6), true) + assertStatus(t, status[6], provider.StatePending, provider.NewSource(provider.TypeSQL, "00007_empty_up_down.sql", 7), true) + // Apply all migrations + _, err = p.Up(ctx) + check.NoError(t, err) + status, err = p.Status(ctx) + check.NoError(t, err) + check.Number(t, len(status), numCount) + assertStatus(t, status[0], provider.StateApplied, provider.NewSource(provider.TypeSQL, "00001_users_table.sql", 1), false) + assertStatus(t, status[1], provider.StateApplied, provider.NewSource(provider.TypeSQL, "00002_posts_table.sql", 2), false) + assertStatus(t, status[2], provider.StateApplied, provider.NewSource(provider.TypeSQL, "00003_comments_table.sql", 3), false) + assertStatus(t, status[3], provider.StateApplied, provider.NewSource(provider.TypeSQL, "00004_insert_data.sql", 4), false) + assertStatus(t, status[4], provider.StateApplied, provider.NewSource(provider.TypeSQL, "00005_posts_view.sql", 5), false) + assertStatus(t, status[5], provider.StateApplied, provider.NewSource(provider.TypeSQL, "00006_empty_up.sql", 6), false) + assertStatus(t, status[6], provider.StateApplied, provider.NewSource(provider.TypeSQL, "00007_empty_up_down.sql", 7), false) + }) + t.Run("tx_partial_errors", func(t *testing.T) { + countOwners := func(db *sql.DB) (int, error) { + q := `SELECT count(*)FROM owners` + var count int + if err := db.QueryRow(q).Scan(&count); err != nil { + return 0, err + } + return count, nil + } + + ctx := context.Background() + db := newDB(t) + mapFS := fstest.MapFS{ + "00001_users_table.sql": newMapFile(` +-- +goose Up +CREATE TABLE owners ( owner_name TEXT NOT NULL ); +`), + "00002_partial_error.sql": newMapFile(` +-- +goose Up +INSERT INTO invalid_table (invalid_table) VALUES ('invalid_value'); +`), + "00003_insert_data.sql": newMapFile(` +-- +goose Up +INSERT INTO owners (owner_name) VALUES ('seed-user-1'); +INSERT INTO owners (owner_name) VALUES ('seed-user-2'); +INSERT INTO owners (owner_name) VALUES ('seed-user-3'); +`), + } + p, err := provider.NewProvider(provider.DialectSQLite3, db, mapFS) + check.NoError(t, err) + _, err = p.Up(ctx) + check.HasError(t, err) + check.Contains(t, err.Error(), "partial migration error (00002_partial_error.sql) (2)") + var expected *provider.PartialError + check.Bool(t, errors.As(err, &expected), true) + // Check Err field + check.Bool(t, expected.Err != nil, true) + check.Contains(t, expected.Err.Error(), "SQL logic error: no such table: invalid_table (1)") + // Check Results field + check.Number(t, len(expected.Applied), 1) + assertResult(t, expected.Applied[0], provider.NewSource(provider.TypeSQL, "00001_users_table.sql", 1), "up") + // Check Failed field + check.Bool(t, expected.Failed != nil, true) + assertSource(t, 
expected.Failed.Source, provider.TypeSQL, "00002_partial_error.sql", 2) + check.Bool(t, expected.Failed.Empty, false) + check.Bool(t, expected.Failed.Error != nil, true) + check.Contains(t, expected.Failed.Error.Error(), "SQL logic error: no such table: invalid_table (1)") + check.Equal(t, expected.Failed.Direction, "up") + check.Bool(t, expected.Failed.Duration > 0, true) + + // Ensure the partial error did not affect the database. + count, err := countOwners(db) + check.NoError(t, err) + check.Number(t, count, 0) + + status, err := p.Status(ctx) + check.NoError(t, err) + check.Number(t, len(status), 3) + assertStatus(t, status[0], provider.StateApplied, provider.NewSource(provider.TypeSQL, "00001_users_table.sql", 1), false) + assertStatus(t, status[1], provider.StatePending, provider.NewSource(provider.TypeSQL, "00002_partial_error.sql", 2), true) + assertStatus(t, status[2], provider.StatePending, provider.NewSource(provider.TypeSQL, "00003_insert_data.sql", 3), true) + }) +} + +func TestConcurrentProvider(t *testing.T) { + t.Parallel() + + t.Run("up", func(t *testing.T) { + ctx := context.Background() + p, _ := newProviderWithDB(t) + maxVersion := len(p.ListSources()) + + ch := make(chan int64) + var wg sync.WaitGroup + for i := 0; i < maxVersion; i++ { + wg.Add(1) + + go func() { + defer wg.Done() + res, err := p.UpByOne(ctx) + if err != nil { + t.Error(err) + return + } + if len(res) != 1 { + t.Errorf("expected 1 result, got %d", len(res)) + return + } + ch <- res[0].Source.Version + }() + } + go func() { + wg.Wait() + close(ch) + }() + var versions []int64 + for version := range ch { + versions = append(versions, version) + } + // Fail early if any of the goroutines failed. + if t.Failed() { + return + } + check.Number(t, len(versions), maxVersion) + for i := 0; i < maxVersion; i++ { + check.Number(t, versions[i], int64(i+1)) + } + currentVersion, err := p.GetDBVersion(ctx) + check.NoError(t, err) + check.Number(t, currentVersion, maxVersion) + }) + t.Run("down", func(t *testing.T) { + ctx := context.Background() + p, _ := newProviderWithDB(t) + maxVersion := len(p.ListSources()) + // Apply all migrations + _, err := p.Up(ctx) + check.NoError(t, err) + currentVersion, err := p.GetDBVersion(ctx) + check.NoError(t, err) + check.Number(t, currentVersion, maxVersion) + + ch := make(chan []*provider.MigrationResult) + var wg sync.WaitGroup + for i := 0; i < maxVersion; i++ { + wg.Add(1) + + go func() { + defer wg.Done() + res, err := p.DownTo(ctx, 0) + if err != nil { + t.Error(err) + return + } + ch <- res + }() + } + go func() { + wg.Wait() + close(ch) + }() + var ( + valid [][]*provider.MigrationResult + empty [][]*provider.MigrationResult + ) + for results := range ch { + if len(results) == 0 { + empty = append(empty, results) + continue + } + valid = append(valid, results) + } + // Fail early if any of the goroutines failed. + if t.Failed() { + return + } + check.Equal(t, len(valid), 1) + check.Equal(t, len(empty), maxVersion-1) + // Ensure the valid result is correct. 
+ check.Number(t, len(valid[0]), maxVersion) + }) +} + +func TestNoVersioning(t *testing.T) { + t.Parallel() + + countSeedOwners := func(db *sql.DB) (int, error) { + q := `SELECT count(*)FROM owners WHERE owner_name LIKE'seed-user-%'` + var count int + if err := db.QueryRow(q).Scan(&count); err != nil { + return 0, err + } + return count, nil + } + countOwners := func(db *sql.DB) (int, error) { + q := `SELECT count(*)FROM owners` + var count int + if err := db.QueryRow(q).Scan(&count); err != nil { + return 0, err + } + return count, nil + } + ctx := context.Background() + dbName := fmt.Sprintf("test_%s.db", randomAlphaNumeric(8)) + db, err := sql.Open("sqlite", filepath.Join(t.TempDir(), dbName)) + check.NoError(t, err) + fsys := os.DirFS(filepath.Join("testdata", "no-versioning", "migrations")) + const ( + // Total owners created by the seed files. + wantSeedOwnerCount = 250 + // These are owners created by migration files. + wantOwnerCount = 4 + ) + p, err := provider.NewProvider(provider.DialectSQLite3, db, fsys, + provider.WithVerbose(testing.Verbose()), + provider.WithNoVersioning(false), // This is the default. + ) + check.Number(t, len(p.ListSources()), 3) + check.NoError(t, err) + _, err = p.Up(ctx) + check.NoError(t, err) + baseVersion, err := p.GetDBVersion(ctx) + check.NoError(t, err) + check.Number(t, baseVersion, 3) + t.Run("seed-up-down-to-zero", func(t *testing.T) { + fsys := os.DirFS(filepath.Join("testdata", "no-versioning", "seed")) + p, err := provider.NewProvider(provider.DialectSQLite3, db, fsys, + provider.WithVerbose(testing.Verbose()), + provider.WithNoVersioning(true), // Provider with no versioning. + ) + check.NoError(t, err) + check.Number(t, len(p.ListSources()), 2) + + // Run (all) up migrations from the seed dir + { + upResult, err := p.Up(ctx) + check.NoError(t, err) + check.Number(t, len(upResult), 2) + // Confirm no changes to the versioned schema in the DB + currentVersion, err := p.GetDBVersion(ctx) + check.NoError(t, err) + check.Number(t, baseVersion, currentVersion) + seedOwnerCount, err := countSeedOwners(db) + check.NoError(t, err) + check.Number(t, seedOwnerCount, wantSeedOwnerCount) + } + // Run (all) down migrations from the seed dir + { + downResult, err := p.DownTo(ctx, 0) + check.NoError(t, err) + check.Number(t, len(downResult), 2) + // Confirm no changes to the versioned schema in the DB + currentVersion, err := p.GetDBVersion(ctx) + check.NoError(t, err) + check.Number(t, baseVersion, currentVersion) + seedOwnerCount, err := countSeedOwners(db) + check.NoError(t, err) + check.Number(t, seedOwnerCount, 0) + } + // The migrations added 4 non-seed owners, they must remain in the database afterwards + ownerCount, err := countOwners(db) + check.NoError(t, err) + check.Number(t, ownerCount, wantOwnerCount) + }) +} + +func TestAllowMissing(t *testing.T) { + t.Parallel() + ctx := context.Background() + + // Developer A and B check out the "main" branch which is currently on version 3. Developer A + // mistakenly creates migration 5 and commits. Developer B did not pull the latest changes and + // commits migration 4. Oops -- now the migrations are out of order. + // + // When goose is set to allow missing migrations, then 5 is applied after 4 with no error. + // Otherwise it's expected to be an error. 
+ + t.Run("missing_now_allowed", func(t *testing.T) { + db := newDB(t) + p, err := provider.NewProvider(provider.DialectSQLite3, db, newFsys(), + provider.WithAllowMissing(false), + ) + check.NoError(t, err) + + // Create and apply first 3 migrations. + _, err = p.UpTo(ctx, 3) + check.NoError(t, err) + currentVersion, err := p.GetDBVersion(ctx) + check.NoError(t, err) + check.Number(t, currentVersion, 3) + + // Developer A - migration 5 (mistakenly applied) + result, err := p.ApplyVersion(ctx, 5, true) + check.NoError(t, err) + check.Number(t, result.Source.Version, 5) + current, err := p.GetDBVersion(ctx) + check.NoError(t, err) + check.Number(t, current, 5) + + // The database has migrations 1,2,3,5 applied. + + // Developer B is on version 3 (e.g., never pulled the latest changes). Adds migration 4. By + // default goose does not allow missing (out-of-order) migrations, which means halt if a + // missing migration is detected. + _, err = p.Up(ctx) + check.HasError(t, err) + // found 1 missing (out-of-order) migration: [00004_insert_data.sql] + check.Contains(t, err.Error(), "missing (out-of-order) migration") + // Confirm db version is unchanged. + current, err = p.GetDBVersion(ctx) + check.NoError(t, err) + check.Number(t, current, 5) + + _, err = p.UpByOne(ctx) + check.HasError(t, err) + // found 1 missing (out-of-order) migration: [00004_insert_data.sql] + check.Contains(t, err.Error(), "missing (out-of-order) migration") + // Confirm db version is unchanged. + current, err = p.GetDBVersion(ctx) + check.NoError(t, err) + check.Number(t, current, 5) + + _, err = p.UpTo(ctx, math.MaxInt64) + check.HasError(t, err) + // found 1 missing (out-of-order) migration: [00004_insert_data.sql] + check.Contains(t, err.Error(), "missing (out-of-order) migration") + // Confirm db version is unchanged. + current, err = p.GetDBVersion(ctx) + check.NoError(t, err) + check.Number(t, current, 5) + }) + + t.Run("missing_allowed", func(t *testing.T) { + db := newDB(t) + p, err := provider.NewProvider(provider.DialectSQLite3, db, newFsys(), + provider.WithAllowMissing(true), + ) + check.NoError(t, err) + + // Create and apply first 3 migrations. + _, err = p.UpTo(ctx, 3) + check.NoError(t, err) + currentVersion, err := p.GetDBVersion(ctx) + check.NoError(t, err) + check.Number(t, currentVersion, 3) + + // Developer A - migration 5 (mistakenly applied) + { + _, err = p.ApplyVersion(ctx, 5, true) + check.NoError(t, err) + current, err := p.GetDBVersion(ctx) + check.NoError(t, err) + check.Number(t, current, 5) + } + // Developer B - migration 4 (missing) and 6 (new) + { + // 4 + upResult, err := p.UpByOne(ctx) + check.NoError(t, err) + check.Number(t, len(upResult), 1) + check.Number(t, upResult[0].Source.Version, 4) + // 6 + upResult, err = p.UpByOne(ctx) + check.NoError(t, err) + check.Number(t, len(upResult), 1) + check.Number(t, upResult[0].Source.Version, 6) + + count, err := getGooseVersionCount(db, provider.DefaultTablename) + check.NoError(t, err) + check.Number(t, count, 6) + current, err := p.GetDBVersion(ctx) + check.NoError(t, err) + // Expecting max(version_id) to be 8 + check.Number(t, current, 6) + } + + // The applied order in the database is expected to be: + // 1,2,3,5,4,6 + // So migrating down should be the reverse of the applied order: + // 6,4,5,3,2,1 + + expected := []int64{6, 4, 5, 3, 2, 1} + for i, v := range expected { + // TODO(mf): this is returning it by the order it was applied. 
+ current, err := p.GetDBVersion(ctx) + check.NoError(t, err) + check.Number(t, current, v) + downResult, err := p.Down(ctx) + if i == len(expected)-1 { + check.HasError(t, provider.ErrVersionNotFound) + } else { + check.NoError(t, err) + check.Number(t, len(downResult), 1) + check.Number(t, downResult[0].Source.Version, v) + } + } + }) +} + +func getGooseVersionCount(db *sql.DB, gooseTable string) (int64, error) { + var gotVersion int64 + if err := db.QueryRow( + fmt.Sprintf("SELECT count(*) FROM %s WHERE version_id > 0", gooseTable), + ).Scan(&gotVersion); err != nil { + return 0, err + } + return gotVersion, nil +} + +func TestGoOnly(t *testing.T) { + // Not parallel because it modifies global state. + + countUser := func(db *sql.DB) int { + q := `SELECT count(*)FROM users` + var count int + err := db.QueryRow(q).Scan(&count) + check.NoError(t, err) + return count + } + + t.Run("with_tx", func(t *testing.T) { + ctx := context.Background() + register := []*provider.Migration{ + { + Version: 1, Source: "00001_users_table.go", Registered: true, UseTx: true, + UpFnContext: newTxFn("CREATE TABLE users (id INTEGER PRIMARY KEY)"), + DownFnContext: newTxFn("DROP TABLE users"), + }, + } + err := provider.SetGlobalGoMigrations(register) + check.NoError(t, err) + t.Cleanup(provider.ResetGlobalGoMigrations) + + db := newDB(t) + p, err := provider.NewProvider(provider.DialectSQLite3, db, nil, + provider.WithGoMigration( + 2, + &provider.GoMigration{Run: newTxFn("INSERT INTO users (id) VALUES (1), (2), (3)")}, + &provider.GoMigration{Run: newTxFn("DELETE FROM users")}, + ), + ) + check.NoError(t, err) + sources := p.ListSources() + check.Number(t, len(p.ListSources()), 2) + assertSource(t, sources[0], provider.TypeGo, "00001_users_table.go", 1) + assertSource(t, sources[1], provider.TypeGo, "", 2) + // Apply migration 1 + res, err := p.UpByOne(ctx) + check.NoError(t, err) + check.Number(t, len(res), 1) + assertResult(t, res[0], provider.NewSource(provider.TypeGo, "00001_users_table.go", 1), "up") + check.Number(t, countUser(db), 0) + check.Bool(t, tableExists(t, db, "users"), true) + // Apply migration 2 + res, err = p.UpByOne(ctx) + check.NoError(t, err) + check.Number(t, len(res), 1) + assertResult(t, res[0], provider.NewSource(provider.TypeGo, "", 2), "up") + check.Number(t, countUser(db), 3) + // Rollback migration 2 + res, err = p.Down(ctx) + check.NoError(t, err) + check.Number(t, len(res), 1) + assertResult(t, res[0], provider.NewSource(provider.TypeGo, "", 2), "down") + check.Number(t, countUser(db), 0) + // Rollback migration 1 + res, err = p.Down(ctx) + check.NoError(t, err) + check.Number(t, len(res), 1) + assertResult(t, res[0], provider.NewSource(provider.TypeGo, "00001_users_table.go", 1), "down") + // Check table does not exist + check.Bool(t, tableExists(t, db, "users"), false) + }) + t.Run("with_db", func(t *testing.T) { + ctx := context.Background() + register := []*provider.Migration{ + { + Version: 1, Source: "00001_users_table.go", Registered: true, UseTx: true, + UpFnNoTxContext: newDBFn("CREATE TABLE users (id INTEGER PRIMARY KEY)"), + DownFnNoTxContext: newDBFn("DROP TABLE users"), + }, + } + err := provider.SetGlobalGoMigrations(register) + check.NoError(t, err) + t.Cleanup(provider.ResetGlobalGoMigrations) + + db := newDB(t) + p, err := provider.NewProvider(provider.DialectSQLite3, db, nil, + provider.WithGoMigration( + 2, + &provider.GoMigration{RunNoTx: newDBFn("INSERT INTO users (id) VALUES (1), (2), (3)")}, + &provider.GoMigration{RunNoTx: newDBFn("DELETE FROM users")}, 
+			),
+		)
+		check.NoError(t, err)
+		sources := p.ListSources()
+		check.Number(t, len(p.ListSources()), 2)
+		assertSource(t, sources[0], provider.TypeGo, "00001_users_table.go", 1)
+		assertSource(t, sources[1], provider.TypeGo, "", 2)
+		// Apply migration 1
+		res, err := p.UpByOne(ctx)
+		check.NoError(t, err)
+		check.Number(t, len(res), 1)
+		assertResult(t, res[0], provider.NewSource(provider.TypeGo, "00001_users_table.go", 1), "up")
+		check.Number(t, countUser(db), 0)
+		check.Bool(t, tableExists(t, db, "users"), true)
+		// Apply migration 2
+		res, err = p.UpByOne(ctx)
+		check.NoError(t, err)
+		check.Number(t, len(res), 1)
+		assertResult(t, res[0], provider.NewSource(provider.TypeGo, "", 2), "up")
+		check.Number(t, countUser(db), 3)
+		// Rollback migration 2
+		res, err = p.Down(ctx)
+		check.NoError(t, err)
+		check.Number(t, len(res), 1)
+		assertResult(t, res[0], provider.NewSource(provider.TypeGo, "", 2), "down")
+		check.Number(t, countUser(db), 0)
+		// Rollback migration 1
+		res, err = p.Down(ctx)
+		check.NoError(t, err)
+		check.Number(t, len(res), 1)
+		assertResult(t, res[0], provider.NewSource(provider.TypeGo, "00001_users_table.go", 1), "down")
+		// Check table does not exist
+		check.Bool(t, tableExists(t, db, "users"), false)
+	})
+}
+
+func TestLockModeAdvisorySession(t *testing.T) {
+	t.Parallel()
+	if testing.Short() {
+		t.Skip("skip long running test")
+	}
+
+	// The migrations are written in such a way that they cannot be applied concurrently; they will
+	// fail 99.9999% of the time. This test ensures that the advisory session lock mode works as
+	// expected.
+
+	// TODO(mf): a small improvement here would be to use the SAME postgres instance but different
+	// databases created from a template. This would speed up the test.
+
+	db, cleanup, err := testdb.NewPostgres()
+	check.NoError(t, err)
+	t.Cleanup(cleanup)
+
+	newProvider := func() *provider.Provider {
+		sessionLocker, err := lock.NewPostgresSessionLocker()
+		check.NoError(t, err)
+		p, err := provider.NewProvider(provider.DialectPostgres, db, os.DirFS("../../testdata/migrations"),
+			provider.WithSessionLocker(sessionLocker), // Use advisory session lock mode.
+			provider.WithVerbose(testing.Verbose()),
+		)
+		check.NoError(t, err)
+		return p
+	}
+	provider1 := newProvider()
+	provider2 := newProvider()
+
+	sources := provider1.ListSources()
+	maxVersion := sources[len(sources)-1].Version
+
+	// Since the lock mode is advisory session, only one of these providers is expected to apply ALL
+	// the migrations. The other provider should apply NO migrations. The test MUST fail if both
+	// providers apply migrations.
+
+	t.Run("up", func(t *testing.T) {
+		var g errgroup.Group
+		var res1, res2 int
+		g.Go(func() error {
+			ctx := context.Background()
+			results, err := provider1.Up(ctx)
+			check.NoError(t, err)
+			res1 = len(results)
+			currentVersion, err := provider1.GetDBVersion(ctx)
+			check.NoError(t, err)
+			check.Number(t, currentVersion, maxVersion)
+			return nil
+		})
+		g.Go(func() error {
+			ctx := context.Background()
+			results, err := provider2.Up(ctx)
+			check.NoError(t, err)
+			res2 = len(results)
+			currentVersion, err := provider2.GetDBVersion(ctx)
+			check.NoError(t, err)
+			check.Number(t, currentVersion, maxVersion)
+			return nil
+		})
+		check.NoError(t, g.Wait())
+		// One of the providers should have applied all migrations and the other should have applied
+		// no migrations, but with no error.
+ if res1 == 0 && res2 == 0 { + t.Fatal("both providers applied no migrations") + } + if res1 > 0 && res2 > 0 { + t.Fatal("both providers applied migrations") + } + }) + + // Reset the database and run the same test with the advisory lock mode, but apply migrations + // one-by-one. + { + _, err := provider1.DownTo(context.Background(), 0) + check.NoError(t, err) + currentVersion, err := provider1.GetDBVersion(context.Background()) + check.NoError(t, err) + check.Number(t, currentVersion, 0) + } + t.Run("up_by_one", func(t *testing.T) { + var g errgroup.Group + var ( + mu sync.Mutex + applied []int64 + ) + g.Go(func() error { + for { + results, err := provider1.UpByOne(context.Background()) + if err != nil { + if errors.Is(err, provider.ErrNoNextVersion) { + return nil + } + return err + } + check.NoError(t, err) + if len(results) != 1 { + return fmt.Errorf("expected 1 result, got %d", len(results)) + } + mu.Lock() + applied = append(applied, results[0].Source.Version) + mu.Unlock() + } + }) + g.Go(func() error { + for { + results, err := provider2.UpByOne(context.Background()) + if err != nil { + if errors.Is(err, provider.ErrNoNextVersion) { + return nil + } + return err + } + check.NoError(t, err) + if len(results) != 1 { + return fmt.Errorf("expected 1 result, got %d", len(results)) + } + mu.Lock() + applied = append(applied, results[0].Source.Version) + mu.Unlock() + } + }) + check.NoError(t, g.Wait()) + check.Number(t, len(applied), len(sources)) + sort.Slice(applied, func(i, j int) bool { + return applied[i] < applied[j] + }) + // Each migration should have been applied up exactly once. + for i := 0; i < len(sources); i++ { + check.Number(t, applied[i], sources[i].Version) + } + }) + + // Restore the database state by applying all migrations and run the same test with the advisory + // lock mode, but apply down migrations in parallel. + { + _, err := provider1.Up(context.Background()) + check.NoError(t, err) + currentVersion, err := provider1.GetDBVersion(context.Background()) + check.NoError(t, err) + check.Number(t, currentVersion, maxVersion) + } + + t.Run("down_to", func(t *testing.T) { + var g errgroup.Group + var res1, res2 int + g.Go(func() error { + ctx := context.Background() + results, err := provider1.DownTo(ctx, 0) + check.NoError(t, err) + res1 = len(results) + currentVersion, err := provider1.GetDBVersion(ctx) + check.NoError(t, err) + check.Number(t, currentVersion, 0) + return nil + }) + g.Go(func() error { + ctx := context.Background() + results, err := provider2.DownTo(ctx, 0) + check.NoError(t, err) + res2 = len(results) + currentVersion, err := provider2.GetDBVersion(ctx) + check.NoError(t, err) + check.Number(t, currentVersion, 0) + return nil + }) + check.NoError(t, g.Wait()) + + if res1 == 0 && res2 == 0 { + t.Fatal("both providers applied no migrations") + } + if res1 > 0 && res2 > 0 { + t.Fatal("both providers applied migrations") + } + }) + + // Restore the database state by applying all migrations and run the same test with the advisory + // lock mode, but apply down migrations one-by-one. 
+ { + _, err := provider1.Up(context.Background()) + check.NoError(t, err) + currentVersion, err := provider1.GetDBVersion(context.Background()) + check.NoError(t, err) + check.Number(t, currentVersion, maxVersion) + } + + t.Run("down_by_one", func(t *testing.T) { + var g errgroup.Group + var ( + mu sync.Mutex + applied []int64 + ) + g.Go(func() error { + for { + results, err := provider1.Down(context.Background()) + if err != nil { + if errors.Is(err, provider.ErrNoNextVersion) { + return nil + } + return err + } + if len(results) != 1 { + return fmt.Errorf("expected 1 result, got %d", len(results)) + } + check.NoError(t, err) + mu.Lock() + applied = append(applied, results[0].Source.Version) + mu.Unlock() + } + }) + g.Go(func() error { + for { + results, err := provider2.Down(context.Background()) + if err != nil { + if errors.Is(err, provider.ErrNoNextVersion) { + return nil + } + return err + } + if len(results) != 1 { + return fmt.Errorf("expected 1 result, got %d", len(results)) + } + check.NoError(t, err) + mu.Lock() + applied = append(applied, results[0].Source.Version) + mu.Unlock() + } + }) + check.NoError(t, g.Wait()) + check.Number(t, len(applied), len(sources)) + sort.Slice(applied, func(i, j int) bool { + return applied[i] < applied[j] + }) + // Each migration should have been applied down exactly once. Since this is sequential the + // applied down migrations should be in reverse order. + for i := len(sources) - 1; i >= 0; i-- { + check.Number(t, applied[i], sources[i].Version) + } + }) +} + +func newDBFn(query string) func(context.Context, *sql.DB) error { + return func(ctx context.Context, db *sql.DB) error { + _, err := db.ExecContext(ctx, query) + return err + } +} + +func newTxFn(query string) func(context.Context, *sql.Tx) error { + return func(ctx context.Context, tx *sql.Tx) error { + _, err := tx.ExecContext(ctx, query) + return err + } +} + +func tableExists(t *testing.T, db *sql.DB, table string) bool { + q := fmt.Sprintf(`SELECT CASE WHEN COUNT(*) > 0 THEN 1 ELSE 0 END AS table_exists FROM sqlite_master WHERE type = 'table' AND name = '%s'`, table) + var b string + err := db.QueryRow(q).Scan(&b) + check.NoError(t, err) + return b == "1" +} + +const ( + charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" +) + +func randomAlphaNumeric(length int) string { + b := make([]byte, length) + for i := range b { + b[i] = charset[rand.Intn(len(charset))] + } + return string(b) +} + +func newProviderWithDB(t *testing.T, opts ...provider.ProviderOption) (*provider.Provider, *sql.DB) { + t.Helper() + db := newDB(t) + opts = append( + opts, + provider.WithVerbose(testing.Verbose()), + ) + p, err := provider.NewProvider(provider.DialectSQLite3, db, newFsys(), opts...) 
+ check.NoError(t, err) + return p, db +} + +func newDB(t *testing.T) *sql.DB { + t.Helper() + dbName := fmt.Sprintf("test_%s.db", randomAlphaNumeric(8)) + db, err := sql.Open("sqlite", filepath.Join(t.TempDir(), dbName)) + check.NoError(t, err) + return db +} + +func getMaxVersionID(db *sql.DB, gooseTable string) (int64, error) { + var gotVersion int64 + if err := db.QueryRow( + fmt.Sprintf("select max(version_id) from %s", gooseTable), + ).Scan(&gotVersion); err != nil { + return 0, err + } + return gotVersion, nil +} + +func getTableNames(db *sql.DB) ([]string, error) { + rows, err := db.Query(`SELECT name FROM sqlite_master WHERE type='table' ORDER BY name`) + if err != nil { + return nil, err + } + defer rows.Close() + var tables []string + for rows.Next() { + var name string + if err := rows.Scan(&name); err != nil { + return nil, err + } + tables = append(tables, name) + } + if err := rows.Err(); err != nil { + return nil, err + } + return tables, nil +} + +func assertStatus(t *testing.T, got *provider.MigrationStatus, state provider.State, source provider.Source, appliedIsZero bool) { + t.Helper() + check.Equal(t, got.State, state) + check.Equal(t, got.Source, source) + check.Bool(t, got.AppliedAt.IsZero(), appliedIsZero) +} + +func assertResult(t *testing.T, got *provider.MigrationResult, source provider.Source, direction string) { + t.Helper() + check.Equal(t, got.Source, source) + check.Equal(t, got.Direction, direction) + check.Equal(t, got.Empty, false) + check.Bool(t, got.Error == nil, true) + check.Bool(t, got.Duration > 0, true) +} + +func assertSource(t *testing.T, got provider.Source, typ provider.MigrationType, name string, version int64) { + t.Helper() + check.Equal(t, got.Type, typ) + check.Equal(t, got.Fullpath, name) + check.Equal(t, got.Version, version) + switch got.Type { + case provider.TypeGo: + check.Equal(t, got.Type.String(), "go") + case provider.TypeSQL: + check.Equal(t, got.Type.String(), "sql") + } +} + +func newMapFile(data string) *fstest.MapFile { + return &fstest.MapFile{ + Data: []byte(data), + } +} + +func newFsys() fs.FS { + return fstest.MapFS{ + "00001_users_table.sql": newMapFile(runMigration1), + "00002_posts_table.sql": newMapFile(runMigration2), + "00003_comments_table.sql": newMapFile(runMigration3), + "00004_insert_data.sql": newMapFile(runMigration4), + "00005_posts_view.sql": newMapFile(runMigration5), + "00006_empty_up.sql": newMapFile(runMigration6), + "00007_empty_up_down.sql": newMapFile(runMigration7), + } +} + +var ( + + // known tables are the tables (including goose table) created by running all migration files. + // If you add a table, make sure to add to this list and keep it in order. 
+ knownTables = []string{ + "comments", + "goose_db_version", + "posts", + "sqlite_sequence", + "users", + } + + runMigration1 = ` +-- +goose Up +CREATE TABLE users ( + id INTEGER PRIMARY KEY, + username TEXT NOT NULL, + email TEXT NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP +); + +-- +goose Down +DROP TABLE users; +` + + runMigration2 = ` +-- +goose Up +-- +goose StatementBegin +CREATE TABLE posts ( + id INTEGER PRIMARY KEY, + title TEXT NOT NULL, + content TEXT NOT NULL, + author_id INTEGER NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (author_id) REFERENCES users(id) +); +-- +goose StatementEnd +SELECT 1; +SELECT 2; + +-- +goose Down +DROP TABLE posts; +` + + runMigration3 = ` +-- +goose Up +CREATE TABLE comments ( + id INTEGER PRIMARY KEY, + post_id INTEGER NOT NULL, + user_id INTEGER NOT NULL, + content TEXT NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (post_id) REFERENCES posts(id), + FOREIGN KEY (user_id) REFERENCES users(id) +); + +-- +goose Down +DROP TABLE comments; +SELECT 1; +SELECT 2; +SELECT 3; +` + + runMigration4 = ` +-- +goose Up +INSERT INTO users (id, username, email) +VALUES + (1, 'john_doe', 'john@example.com'), + (2, 'jane_smith', 'jane@example.com'), + (3, 'alice_wonderland', 'alice@example.com'); + +INSERT INTO posts (id, title, content, author_id) +VALUES + (1, 'Introduction to SQL', 'SQL is a powerful language for managing databases...', 1), + (2, 'Data Modeling Techniques', 'Choosing the right data model is crucial...', 2), + (3, 'Advanced Query Optimization', 'Optimizing queries can greatly improve...', 1); + +INSERT INTO comments (id, post_id, user_id, content) +VALUES + (1, 1, 3, 'Great introduction! Looking forward to more.'), + (2, 1, 2, 'SQL can be a bit tricky at first, but practice helps.'), + (3, 2, 1, 'You covered normalization really well in this post.'); + +-- +goose Down +DELETE FROM comments; +DELETE FROM posts; +DELETE FROM users; +` + + runMigration5 = ` +-- +goose NO TRANSACTION + +-- +goose Up +CREATE VIEW posts_view AS + SELECT + p.id, + p.title, + p.content, + p.created_at, + u.username AS author + FROM posts p + JOIN users u ON p.author_id = u.id; + +-- +goose Down +DROP VIEW posts_view; +` + + runMigration6 = ` +-- +goose Up +` + + runMigration7 = ` +-- +goose Up +-- +goose Down +` +) diff --git a/internal/provider/run_up.go b/internal/provider/run_up.go new file mode 100644 index 000000000..7ee9c6c4f --- /dev/null +++ b/internal/provider/run_up.go @@ -0,0 +1,96 @@ +package provider + +import ( + "context" + "errors" + "fmt" + "strings" + + "github.com/pressly/goose/v3/internal/sqlparser" + "go.uber.org/multierr" +) + +func (p *Provider) up(ctx context.Context, upByOne bool, version int64) (_ []*MigrationResult, retErr error) { + if version < 1 { + return nil, errors.New("version must be greater than zero") + } + + conn, cleanup, err := p.initialize(ctx) + if err != nil { + return nil, err + } + defer func() { + retErr = multierr.Append(retErr, cleanup()) + }() + + if len(p.migrations) == 0 { + return nil, nil + } + if p.cfg.noVersioning { + // Short circuit if versioning is disabled and apply all migrations. + return p.runMigrations(ctx, conn, p.migrations, sqlparser.DirectionUp, upByOne) + } + + // optimize(mf): Listing all migrations from the database isn't great. This is only required to + // support the out-of-order (allow missing) feature. 
+	// support the out-of-order (allow missing) feature. For users who don't use this feature, we
+	// could just query the database for the current version and then apply migrations that are
+	// greater than that version.
+	dbMigrations, err := p.store.ListMigrations(ctx, conn)
+	if err != nil {
+		return nil, err
+	}
+	dbMaxVersion := dbMigrations[0].Version
+	// lookupAppliedInDB is a map of all applied migrations in the database.
+	lookupAppliedInDB := make(map[int64]bool)
+	for _, m := range dbMigrations {
+		lookupAppliedInDB[m.Version] = true
+	}
+
+	missingMigrations := findMissingMigrations(dbMigrations, p.migrations, dbMaxVersion)
+
+	// feature(mf): It is very possible someone may want to apply ONLY new migrations and skip
+	// missing migrations entirely. At the moment this is not supported, but leaving this comment
+	// because that's where that logic will be handled.
+	if len(missingMigrations) > 0 && !p.cfg.allowMissing {
+		var collected []string
+		for _, v := range missingMigrations {
+			collected = append(collected, v.filename)
+		}
+		msg := "migration"
+		if len(collected) > 1 {
+			msg += "s"
+		}
+		return nil, fmt.Errorf("found %d missing (out-of-order) %s: [%s]",
+			len(missingMigrations), msg, strings.Join(collected, ","))
+	}
+
+	var migrationsToApply []*migration
+	if p.cfg.allowMissing {
+		for _, v := range missingMigrations {
+			m, err := p.getMigration(v.versionID)
+			if err != nil {
+				return nil, err
+			}
+			migrationsToApply = append(migrationsToApply, m)
+		}
+	}
+	// Skip migrations that are already applied and keep those with a version greater than the max
+	// version in the database (lower bound) and less than or equal to the requested version (upper
+	// bound).
+	for _, m := range p.migrations {
+		if lookupAppliedInDB[m.Source.Version] {
+			continue
+		}
+		if m.Source.Version > dbMaxVersion && m.Source.Version <= version {
+			migrationsToApply = append(migrationsToApply, m)
+		}
+	}
+
+	// feat(mf): this is where we can (optionally) group multiple migrations to be run in a single
+	// transaction. The default is to apply each migration sequentially on its own.
+	// https://github.com/pressly/goose/issues/222
+	//
+	// Note, we can't use a single transaction for all migrations because some may have to be run in
+	// their own transaction.
+ + return p.runMigrations(ctx, conn, migrationsToApply, sqlparser.DirectionUp, upByOne) +} diff --git a/internal/provider/testdata/no-versioning/migrations/00001_a.sql b/internal/provider/testdata/no-versioning/migrations/00001_a.sql new file mode 100644 index 000000000..839cb7a7b --- /dev/null +++ b/internal/provider/testdata/no-versioning/migrations/00001_a.sql @@ -0,0 +1,8 @@ +-- +goose Up +CREATE TABLE owners ( + owner_id INTEGER PRIMARY KEY AUTOINCREMENT, + owner_name TEXT NOT NULL +); + +-- +goose Down +DROP TABLE IF EXISTS owners; diff --git a/internal/provider/testdata/no-versioning/migrations/00002_b.sql b/internal/provider/testdata/no-versioning/migrations/00002_b.sql new file mode 100644 index 000000000..bd15ef51c --- /dev/null +++ b/internal/provider/testdata/no-versioning/migrations/00002_b.sql @@ -0,0 +1,9 @@ +-- +goose Up +-- +goose StatementBegin +INSERT INTO owners(owner_name) VALUES ('lucas'), ('ocean'); +-- +goose StatementEnd + +-- +goose Down +-- +goose StatementBegin +DELETE FROM owners; +-- +goose StatementEnd diff --git a/internal/provider/testdata/no-versioning/migrations/00003_c.sql b/internal/provider/testdata/no-versioning/migrations/00003_c.sql new file mode 100644 index 000000000..422fb3068 --- /dev/null +++ b/internal/provider/testdata/no-versioning/migrations/00003_c.sql @@ -0,0 +1,9 @@ +-- +goose Up +-- +goose StatementBegin +INSERT INTO owners(owner_name) VALUES ('james'), ('space'); +-- +goose StatementEnd + +-- +goose Down +-- +goose StatementBegin +DELETE FROM owners WHERE owner_name IN ('james', 'space'); +-- +goose StatementEnd diff --git a/internal/provider/testdata/no-versioning/seed/00001_a.sql b/internal/provider/testdata/no-versioning/seed/00001_a.sql new file mode 100644 index 000000000..64f9ff03c --- /dev/null +++ b/internal/provider/testdata/no-versioning/seed/00001_a.sql @@ -0,0 +1,17 @@ +-- +goose Up +-- +goose StatementBegin +-- Insert 100 owners. +INSERT INTO owners (owner_name) +WITH numbers AS ( + SELECT 1 AS n + UNION ALL + SELECT n + 1 FROM numbers WHERE n < 100 +) +SELECT 'seed-user-' || n FROM numbers; +-- +goose StatementEnd + +-- +goose Down +-- +goose StatementBegin +-- Delete the previously inserted data. +DELETE FROM owners WHERE owner_name LIKE 'seed-user-%'; +-- +goose StatementEnd diff --git a/internal/provider/testdata/no-versioning/seed/00002_b.sql b/internal/provider/testdata/no-versioning/seed/00002_b.sql new file mode 100644 index 000000000..aafe82752 --- /dev/null +++ b/internal/provider/testdata/no-versioning/seed/00002_b.sql @@ -0,0 +1,15 @@ +-- +goose Up + +-- Insert 150 more owners. +INSERT INTO owners (owner_name) +WITH numbers AS ( + SELECT 101 AS n + UNION ALL + SELECT n + 1 FROM numbers WHERE n < 250 +) +SELECT 'seed-user-' || n FROM numbers; + +-- +goose Down + +-- NOTE: there are 4 migration owners and 100 seed owners, that's why owner_id starts at 105 +DELETE FROM owners WHERE owner_name LIKE 'seed-user-%' AND owner_id BETWEEN 105 AND 254; diff --git a/internal/provider/types.go b/internal/provider/types.go new file mode 100644 index 000000000..21bb18beb --- /dev/null +++ b/internal/provider/types.go @@ -0,0 +1,99 @@ +package provider + +import ( + "fmt" + "time" +) + +// Dialect is the type of database dialect. 
+type Dialect string
+
+const (
+	DialectClickHouse Dialect = "clickhouse"
+	DialectMSSQL      Dialect = "mssql"
+	DialectMySQL      Dialect = "mysql"
+	DialectPostgres   Dialect = "postgres"
+	DialectRedshift   Dialect = "redshift"
+	DialectSQLite3    Dialect = "sqlite3"
+	DialectTiDB       Dialect = "tidb"
+	DialectVertica    Dialect = "vertica"
+)
+
+// MigrationType is the type of migration.
+type MigrationType int
+
+const (
+	TypeGo MigrationType = iota + 1
+	TypeSQL
+)
+
+func (t MigrationType) String() string {
+	switch t {
+	case TypeGo:
+		return "go"
+	case TypeSQL:
+		return "sql"
+	default:
+		// This should never happen.
+		return fmt.Sprintf("unknown (%d)", t)
+	}
+}
+
+// Source represents a single migration source.
+//
+// For SQL migrations, Fullpath will always be set. For Go migrations, Fullpath will be set if the
+// migration has a corresponding file on disk. It will be empty if the migration was registered
+// manually.
+type Source struct {
+	// Type is the type of migration.
+	Type MigrationType
+	// Full path to the migration file.
+	//
+	// Example: /path/to/migrations/001_create_users_table.sql
+	Fullpath string
+	// Version is the version of the migration.
+	Version int64
+}
+
+// MigrationResult is the result of a single migration operation.
+//
+// Note, the caller is responsible for checking the Error field for any errors that occurred while
+// running the migration. If the Error field is not nil, the migration failed.
+type MigrationResult struct {
+	Source    Source
+	Duration  time.Duration
+	Direction string
+	// Empty is true if the file was valid, but there were no statements to apply. These are still
+	// versioned migrations, but typically have no effect on the database.
+	//
+	// For SQL migrations, this means there was a valid .sql file, but it contained no statements.
+	// For Go migrations, this means the function was nil.
+	Empty bool
+
+	// Error is any error that occurred while running the migration.
+	Error error
+}
+
+// State represents the state of a migration.
+type State string
+
+const (
+	// StatePending represents a migration that is on the filesystem, but not in the database.
+	StatePending State = "pending"
+	// StateApplied represents a migration that is in BOTH the database and on the filesystem.
+	StateApplied State = "applied"
+
+	// StateUntracked represents a migration that is in the database, but not on the filesystem.
+	// StateUntracked State = "untracked"
+)
+
+// MigrationStatus represents the status of a single migration.
+type MigrationStatus struct {
+	// State is the state of the migration.
+	State State
+	// AppliedAt is the time the migration was applied. Only set if state is [StateApplied] or
+	// [StateUntracked].
+	AppliedAt time.Time
+	// Source is the migration source. Only set if the state is [StatePending] or [StateApplied].
+ Source Source +} diff --git a/internal/sqlparser/parse.go b/internal/sqlparser/parse.go index e993587a6..b42fdde14 100644 --- a/internal/sqlparser/parse.go +++ b/internal/sqlparser/parse.go @@ -1,6 +1,7 @@ package sqlparser import ( + "fmt" "io/fs" "go.uber.org/multierr" @@ -50,5 +51,9 @@ func parse(fsys fs.FS, filename string, direction Direction, debug bool) (_ []st defer func() { retErr = multierr.Append(retErr, r.Close()) }() - return ParseSQLMigration(r, direction, debug) + stmts, useTx, err := ParseSQLMigration(r, direction, debug) + if err != nil { + return nil, false, fmt.Errorf("failed to parse %s: %w", filename, err) + } + return stmts, useTx, nil } diff --git a/testdata/migrations/00002_posts_table.sql b/testdata/migrations/00002_posts_table.sql index 25648ed42..be70a2348 100644 --- a/testdata/migrations/00002_posts_table.sql +++ b/testdata/migrations/00002_posts_table.sql @@ -1,4 +1,5 @@ -- +goose Up +-- +goose StatementBegin CREATE TABLE posts ( id INTEGER PRIMARY KEY, title TEXT NOT NULL, @@ -7,6 +8,7 @@ CREATE TABLE posts ( created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, FOREIGN KEY (author_id) REFERENCES users(id) ); +-- +goose StatementEnd -- +goose Down DROP TABLE posts;
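For orientation, a minimal sketch of how the provider API exercised by the tests above might be used from within the goose module. The package lives under internal/, so it is not importable by external programs; the migrations directory, database filename, and the modernc.org/sqlite driver import are assumptions for illustration, not part of this change:

package main

import (
	"context"
	"database/sql"
	"fmt"
	"log"
	"os"

	"github.com/pressly/goose/v3/internal/provider"
	_ "modernc.org/sqlite" // assumed driver; the tests open connections with sql.Open("sqlite", ...)
)

func main() {
	ctx := context.Background()

	// Hypothetical database file; the tests create one per test in t.TempDir().
	db, err := sql.Open("sqlite", "app.db")
	if err != nil {
		log.Fatal(err)
	}
	defer db.Close()

	// Build a provider from a directory of .sql migration files.
	p, err := provider.NewProvider(provider.DialectSQLite3, db, os.DirFS("migrations"),
		provider.WithVerbose(true),
	)
	if err != nil {
		log.Fatal(err)
	}

	// Apply all pending migrations and report each result.
	results, err := p.Up(ctx)
	if err != nil {
		log.Fatal(err)
	}
	for _, r := range results {
		fmt.Printf("%s %s (version %d) in %s\n", r.Direction, r.Source.Fullpath, r.Source.Version, r.Duration)
	}

	version, err := p.GetDBVersion(ctx)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println("current database version:", version)
}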