diff --git a/internal/migration_acceptance_tests/acceptance_test.go b/internal/migration_acceptance_tests/acceptance_test.go index 18ecdfc..a12ed04 100644 --- a/internal/migration_acceptance_tests/acceptance_test.go +++ b/internal/migration_acceptance_tests/acceptance_test.go @@ -4,6 +4,7 @@ import ( "context" "database/sql" "fmt" + "reflect" "testing" "github.com/google/uuid" @@ -42,7 +43,9 @@ type ( planFactory func(ctx context.Context, connPool sqldb.Queryable, tempDbFactory tempdb.Factory, newSchemaDDL []string, opts ...diff.PlanOpt) (diff.Plan, error) acceptanceTestCase struct { - name string + name string + // roles is a list of roles that should be created before the DDL is applied + roles []string oldSchemaDDL []string newSchemaDDL []string @@ -88,12 +91,16 @@ func (suite *acceptanceTestSuite) runTestCases(acceptanceTestCases []acceptanceT suite.Run("vanilla", func() { suite.runSubtest(tc, tc.vanillaExpectations, nil) }) - suite.Run("with data packing (and ignoring column order)", func() { - suite.runSubtest(tc, tc.dataPackingExpectations, []diff.PlanOpt{ - diff.WithDataPackNewTables(), - diff.WithLogger(log.SimpleLogger()), + // Only run the data packing test if there are expectations for it. We should strip out the embedded + // data packing tests and make them their own tests to simplify the test suite. + if !reflect.DeepEqual(tc.dataPackingExpectations, expectations{}) { + suite.Run("with data packing (and ignoring column order)", func() { + suite.runSubtest(tc, tc.dataPackingExpectations, []diff.PlanOpt{ + diff.WithDataPackNewTables(), + diff.WithLogger(log.SimpleLogger()), + }) }) - }) + } }) } } @@ -106,6 +113,19 @@ func (suite *acceptanceTestSuite) runSubtest(tc acceptanceTestCase, expects expe expects.outputState = tc.newSchemaDDL } + // Create roles since they are global + rootDb, err := sql.Open("pgx", suite.pgEngine.GetPostgresDatabaseDSN()) + suite.Require().NoError(err) + defer rootDb.Close() + for _, r := range tc.roles { + _, err := rootDb.Exec(fmt.Sprintf("CREATE ROLE %s", r)) + suite.Require().NoError(err) + } + defer func() { + // This will drop the roles (and attempt to reset other cluster-level state) + suite.Require().NoError(pgengine.ResetInstance(context.Background(), rootDb)) + }() + // Apply old schema DDL to old DB oldDb, err := suite.pgEngine.CreateDatabase() suite.Require().NoError(err) diff --git a/internal/migration_acceptance_tests/backwards_compat_cases_test.go b/internal/migration_acceptance_tests/backwards_compat_cases_test.go index b5d8511..74f4cbd 100644 --- a/internal/migration_acceptance_tests/backwards_compat_cases_test.go +++ b/internal/migration_acceptance_tests/backwards_compat_cases_test.go @@ -177,6 +177,6 @@ var backCompatAcceptanceTestCases = []acceptanceTestCase{ }, } -func (suite *acceptanceTestSuite) TestBackCompatAcceptanceTestCases() { +func (suite *acceptanceTestSuite) TestBackCompatTestCases() { suite.runTestCases(backCompatAcceptanceTestCases) } diff --git a/internal/migration_acceptance_tests/check_constraint_cases_test.go b/internal/migration_acceptance_tests/check_constraint_cases_test.go index 2a2efdb..01f769b 100644 --- a/internal/migration_acceptance_tests/check_constraint_cases_test.go +++ b/internal/migration_acceptance_tests/check_constraint_cases_test.go @@ -568,6 +568,6 @@ var checkConstraintCases = []acceptanceTestCase{ }, } -func (suite *acceptanceTestSuite) TestCheckConstraintAcceptanceTestCases() { +func (suite *acceptanceTestSuite) TestCheckConstraintTestCases() { suite.runTestCases(checkConstraintCases) } diff 
--git a/internal/migration_acceptance_tests/column_cases_test.go b/internal/migration_acceptance_tests/column_cases_test.go index aa5464e..0334ef5 100644 --- a/internal/migration_acceptance_tests/column_cases_test.go +++ b/internal/migration_acceptance_tests/column_cases_test.go @@ -945,6 +945,6 @@ var columnAcceptanceTestCases = []acceptanceTestCase{ }, } -func (suite *acceptanceTestSuite) TestColumnAcceptanceTestCases() { +func (suite *acceptanceTestSuite) TestColumnTestCases() { suite.runTestCases(columnAcceptanceTestCases) } diff --git a/internal/migration_acceptance_tests/enum_cases_test.go b/internal/migration_acceptance_tests/enum_cases_test.go index 31b38ed..5c04707 100644 --- a/internal/migration_acceptance_tests/enum_cases_test.go +++ b/internal/migration_acceptance_tests/enum_cases_test.go @@ -126,6 +126,6 @@ var enumAcceptanceTestCases = []acceptanceTestCase{ }, } -func (s *acceptanceTestSuite) TestEnumTestCases() { - s.runTestCases(enumAcceptanceTestCases) +func (suite *acceptanceTestSuite) TestEnumTestCases() { + suite.runTestCases(enumAcceptanceTestCases) } diff --git a/internal/migration_acceptance_tests/extensions_cases_test.go b/internal/migration_acceptance_tests/extensions_cases_test.go index ac0d4ed..309a428 100644 --- a/internal/migration_acceptance_tests/extensions_cases_test.go +++ b/internal/migration_acceptance_tests/extensions_cases_test.go @@ -72,6 +72,6 @@ var extensionAcceptanceTestCases = []acceptanceTestCase{ }, } -func (suite *acceptanceTestSuite) TestExtensionAcceptanceTestCases() { +func (suite *acceptanceTestSuite) TestExtensionTestCases() { suite.runTestCases(extensionAcceptanceTestCases) } diff --git a/internal/migration_acceptance_tests/function_cases_test.go b/internal/migration_acceptance_tests/function_cases_test.go index 1648d20..2617fc2 100644 --- a/internal/migration_acceptance_tests/function_cases_test.go +++ b/internal/migration_acceptance_tests/function_cases_test.go @@ -577,6 +577,6 @@ var functionAcceptanceTestCases = []acceptanceTestCase{ }, } -func (suite *acceptanceTestSuite) TestFunctionAcceptanceTestCases() { +func (suite *acceptanceTestSuite) TestFunctionTestCases() { suite.runTestCases(functionAcceptanceTestCases) } diff --git a/internal/migration_acceptance_tests/index_cases_test.go b/internal/migration_acceptance_tests/index_cases_test.go index 79b701b..7b6ce3e 100644 --- a/internal/migration_acceptance_tests/index_cases_test.go +++ b/internal/migration_acceptance_tests/index_cases_test.go @@ -803,6 +803,6 @@ var indexAcceptanceTestCases = []acceptanceTestCase{ }, } -func (suite *acceptanceTestSuite) TestIndexAcceptanceTestCases() { +func (suite *acceptanceTestSuite) TestIndexTestCases() { suite.runTestCases(indexAcceptanceTestCases) } diff --git a/internal/migration_acceptance_tests/local_partition_index_cases_test.go b/internal/migration_acceptance_tests/local_partition_index_cases_test.go index 05ca11e..ef8c999 100644 --- a/internal/migration_acceptance_tests/local_partition_index_cases_test.go +++ b/internal/migration_acceptance_tests/local_partition_index_cases_test.go @@ -605,6 +605,6 @@ var localPartitionIndexAcceptanceTestCases = []acceptanceTestCase{ }, } -func (suite *acceptanceTestSuite) TestLocalPartitionIndexAcceptanceTestCases() { +func (suite *acceptanceTestSuite) TestLocalPartitionIndexTestCases() { suite.runTestCases(localPartitionIndexAcceptanceTestCases) } diff --git a/internal/migration_acceptance_tests/named_schema_cases_test.go b/internal/migration_acceptance_tests/named_schema_cases_test.go index 
0d20bed..958c200 100644
--- a/internal/migration_acceptance_tests/named_schema_cases_test.go
+++ b/internal/migration_acceptance_tests/named_schema_cases_test.go
@@ -40,6 +40,6 @@ var namedSchemaAcceptanceTestCases = []acceptanceTestCase{
 	},
 }
 
-func (suite *acceptanceTestSuite) TestNamedSchemaAcceptanceTestCases() {
+func (suite *acceptanceTestSuite) TestNamedSchemaTestCases() {
 	suite.runTestCases(namedSchemaAcceptanceTestCases)
 }
diff --git a/internal/migration_acceptance_tests/partitioned_index_cases_test.go b/internal/migration_acceptance_tests/partitioned_index_cases_test.go
index a5d94e7..8bf93d8 100644
--- a/internal/migration_acceptance_tests/partitioned_index_cases_test.go
+++ b/internal/migration_acceptance_tests/partitioned_index_cases_test.go
@@ -1245,6 +1245,6 @@ var partitionedIndexAcceptanceTestCases = []acceptanceTestCase{
 	},
 }
 
-func (suite *acceptanceTestSuite) TestPartitionedIndexAcceptanceTestCases() {
+func (suite *acceptanceTestSuite) TestPartitionedIndexTestCases() {
 	suite.runTestCases(partitionedIndexAcceptanceTestCases)
 }
diff --git a/internal/migration_acceptance_tests/partitioned_table_cases_test.go b/internal/migration_acceptance_tests/partitioned_table_cases_test.go
index a873589..ac6f112 100644
--- a/internal/migration_acceptance_tests/partitioned_table_cases_test.go
+++ b/internal/migration_acceptance_tests/partitioned_table_cases_test.go
@@ -19,9 +19,13 @@ var partitionedTableAcceptanceTestCases = []acceptanceTestCase{
 				UNIQUE (foo, bar)
 			) PARTITION BY LIST (foo);
 			ALTER TABLE foobar REPLICA IDENTITY FULL;
+			ALTER TABLE foobar ENABLE ROW LEVEL SECURITY;
+			ALTER TABLE foobar FORCE ROW LEVEL SECURITY;
 			CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1');
 			ALTER TABLE foobar_1 REPLICA IDENTITY DEFAULT ;
+			ALTER TABLE foobar_1 ENABLE ROW LEVEL SECURITY;
 			CREATE TABLE foobar_2 PARTITION OF foobar FOR VALUES IN ('foo_2');
+			ALTER TABLE foobar_2 FORCE ROW LEVEL SECURITY;
 			CREATE TABLE foobar_3 PARTITION OF foobar FOR VALUES IN ('foo_3');
 			-- partitioned indexes
 			CREATE UNIQUE INDEX foobar_unique_idx ON foobar(foo, bar);
@@ -59,10 +63,14 @@ var partitionedTableAcceptanceTestCases = []acceptanceTestCase{
 				UNIQUE (foo, bar)
 			) PARTITION BY LIST (foo);
 			ALTER TABLE foobar REPLICA IDENTITY FULL;
+			ALTER TABLE foobar ENABLE ROW LEVEL SECURITY;
+			ALTER TABLE foobar FORCE ROW LEVEL SECURITY;
 			-- partitions
 			CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1');
 			ALTER TABLE foobar_1 REPLICA IDENTITY DEFAULT ;
+			ALTER TABLE foobar_1 ENABLE ROW LEVEL SECURITY;
 			CREATE TABLE foobar_2 PARTITION OF foobar FOR VALUES IN ('foo_2');
+			ALTER TABLE foobar_2 FORCE ROW LEVEL SECURITY;
 			CREATE TABLE foobar_3 PARTITION OF foobar FOR VALUES IN ('foo_3');
 			-- partitioned indexes
 			CREATE UNIQUE INDEX foobar_unique_idx ON foobar(foo, bar);
@@ -98,7 +109,7 @@ var partitionedTableAcceptanceTestCases = []acceptanceTestCase{
 		},
 	},
 	{
-		name:         "Create partitioned table with shared primary key",
+		name:         "Create partitioned table with shared primary key and RLS enabled globally",
 		oldSchemaDDL: nil,
 		newSchemaDDL: []string{
 			`
@@ -113,6 +124,9 @@ var partitionedTableAcceptanceTestCases = []acceptanceTestCase{
 				UNIQUE (foo, bar)
 			) PARTITION BY LIST (foo);
 			ALTER TABLE schema_1."Foobar" REPLICA IDENTITY FULL;
+
+			ALTER TABLE schema_1."Foobar" ENABLE ROW LEVEL SECURITY;
+			ALTER TABLE schema_1."Foobar" FORCE ROW LEVEL SECURITY;
 			-- partitions
 			CREATE SCHEMA schema_2;
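Note on the per-partition statements in the test DDL above: enabling or forcing RLS on a partitioned parent does not cascade to its partitions, so each partition keeps its own relrowsecurity / relforcerowsecurity flags in pg_class (the same columns the GetTables query gains later in this patch). A minimal standalone sketch demonstrating this against a scratch database — the DSN and the pgx stdlib import path are placeholders, not part of the patch:

```go
package main

import (
	"database/sql"
	"fmt"
	"log"

	_ "github.com/jackc/pgx/v4/stdlib" // registers the "pgx" driver; import path assumed
)

func main() {
	db, err := sql.Open("pgx", "postgres://postgres:postgres@localhost:5432/postgres") // placeholder DSN
	if err != nil {
		log.Fatal(err)
	}
	defer db.Close()

	// Enabling RLS on the parent does not cascade to existing partitions,
	// which is why the test DDL issues a separate ALTER per partition.
	for _, stmt := range []string{
		`CREATE TABLE rls_parent(foo TEXT) PARTITION BY LIST (foo)`,
		`CREATE TABLE rls_child PARTITION OF rls_parent FOR VALUES IN ('a')`,
		`ALTER TABLE rls_parent ENABLE ROW LEVEL SECURITY`,
	} {
		if _, err := db.Exec(stmt); err != nil {
			log.Fatal(err)
		}
	}

	rows, err := db.Query(`SELECT relname, relrowsecurity FROM pg_catalog.pg_class WHERE relname IN ('rls_parent', 'rls_child')`)
	if err != nil {
		log.Fatal(err)
	}
	defer rows.Close()
	for rows.Next() {
		var name string
		var enabled bool
		if err := rows.Scan(&name, &enabled); err != nil {
			log.Fatal(err)
		}
		fmt.Printf("%s rls_enabled=%t\n", name, enabled) // expect rls_parent=true, rls_child=false
	}
	if err := rows.Err(); err != nil {
		log.Fatal(err)
	}
}
```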
@@ -145,6 +159,9 @@ var partitionedTableAcceptanceTestCases = []acceptanceTestCase{ ) PARTITION BY LIST (foo); ALTER TABLE schema_1."Foobar" REPLICA IDENTITY FULL; + ALTER TABLE schema_1."Foobar" ENABLE ROW LEVEL SECURITY; + ALTER TABLE schema_1."Foobar" FORCE ROW LEVEL SECURITY; + -- partitions CREATE SCHEMA schema_2; CREATE TABLE schema_2."FOOBAR_1" PARTITION OF schema_1."Foobar"( @@ -165,7 +182,7 @@ var partitionedTableAcceptanceTestCases = []acceptanceTestCase{ }, }, { - name: "Create partitioned table with local primary keys", + name: "Create partitioned table with local primary keys and RLS enabled locally", oldSchemaDDL: nil, newSchemaDDL: []string{ ` @@ -182,9 +199,11 @@ var partitionedTableAcceptanceTestCases = []acceptanceTestCase{ bar NOT NULL, PRIMARY KEY (foo, id) ) FOR VALUES IN ('foo_1'); + ALTER TABLE "FOOBAR_1" ENABLE ROW LEVEL SECURITY; CREATE TABLE foobar_2 PARTITION OF "Foobar"( PRIMARY KEY (foo, bar) ) FOR VALUES IN ('foo_2'); + ALTER TABLE foobar_2 FORCE ROW LEVEL SECURITY; CREATE TABLE foobar_3 PARTITION OF "Foobar"( PRIMARY KEY (foo, fizz), UNIQUE (foo, bar) @@ -334,6 +353,80 @@ var partitionedTableAcceptanceTestCases = []acceptanceTestCase{ diff.MigrationHazardTypeCorrectness, }, }, + { + name: "Enable RLS of parent and children", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT, + foo VARCHAR(255), + PRIMARY KEY (foo, id) + ) PARTITION BY LIST (foo); + + -- partitions + CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1'); + CREATE TABLE foobar_2 PARTITION OF foobar FOR VALUES IN ('foo_2'); + CREATE TABLE foobar_3 PARTITION OF foobar FOR VALUES IN ('foo_3'); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT, + foo VARCHAR(255), + PRIMARY KEY (foo, id) + ) PARTITION BY LIST (foo); + ALTER TABLE foobar ENABLE ROW LEVEL SECURITY; + ALTER TABLE foobar FORCE ROW LEVEL SECURITY; + -- partitions + CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1'); + ALTER TABLE foobar_1 ENABLE ROW LEVEL SECURITY; + CREATE TABLE foobar_2 PARTITION OF foobar FOR VALUES IN ('foo_2'); + ALTER TABLE foobar_2 FORCE ROW LEVEL SECURITY; + CREATE TABLE foobar_3 PARTITION OF foobar FOR VALUES IN ('foo_3'); + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeAuthzUpdate, + }, + }, + { + name: "Disable RLS of parent and children", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT, + foo VARCHAR(255), + PRIMARY KEY (foo, id) + ) PARTITION BY LIST (foo); + ALTER TABLE foobar ENABLE ROW LEVEL SECURITY; + ALTER TABLE foobar FORCE ROW LEVEL SECURITY; + -- partitions + CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1'); + ALTER TABLE foobar_1 ENABLE ROW LEVEL SECURITY; + CREATE TABLE foobar_2 PARTITION OF foobar FOR VALUES IN ('foo_2'); + ALTER TABLE foobar_2 FORCE ROW LEVEL SECURITY; + CREATE TABLE foobar_3 PARTITION OF foobar FOR VALUES IN ('foo_3'); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT, + foo VARCHAR(255), + PRIMARY KEY (foo, id) + ) PARTITION BY LIST (foo); + + -- partitions + CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('foo_1'); + CREATE TABLE foobar_2 PARTITION OF foobar FOR VALUES IN ('foo_2'); + CREATE TABLE foobar_3 PARTITION OF foobar FOR VALUES IN ('foo_3'); + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeAuthzUpdate, + }, + }, { name: "Alter table: New primary key, new unique constraint, dropped unique constraint, change column types, delete partitioned index, new partitioned index, 
delete local index, add local index, validate check constraint, validate FK, delete FK",
		oldSchemaDDL: []string{
@@ -1248,6 +1341,6 @@ var partitionedTableAcceptanceTestCases = []acceptanceTestCase{
 	},
 }
 
-func (suite *acceptanceTestSuite) TestPartitionedTableAcceptanceTestCases() {
+func (suite *acceptanceTestSuite) TestPartitionedTableTestCases() {
 	suite.runTestCases(partitionedTableAcceptanceTestCases)
 }
diff --git a/internal/migration_acceptance_tests/policy_cases_test.go b/internal/migration_acceptance_tests/policy_cases_test.go
new file mode 100644
index 0000000..f9d5884
--- /dev/null
+++ b/internal/migration_acceptance_tests/policy_cases_test.go
@@ -0,0 +1,671 @@
+package migration_acceptance_tests
+
+import "github.com/stripe/pg-schema-diff/pkg/diff"
+
+var policyAcceptanceTestCases = []acceptanceTestCase{
+	{
+		name: "no-op",
+		roles: []string{
+			"role_1",
+			"role_2",
+		},
+		oldSchemaDDL: []string{
+			`
+			CREATE TABLE foobar();
+			CREATE POLICY foobar_policy ON foobar
+				AS PERMISSIVE
+				FOR ALL
+				TO role_1, role_2
+				USING (true)
+				WITH CHECK (true);
+			`,
+		},
+		newSchemaDDL: []string{
+			`
+			CREATE TABLE foobar();
+			CREATE POLICY foobar_policy ON foobar
+				AS PERMISSIVE
+				FOR ALL
+				TO role_1, role_2
+				USING (true)
+				WITH CHECK (true);
+			`,
+		},
+		vanillaExpectations: expectations{
+			empty: true,
+		},
+	},
+	{
+		name: "Add permissive ALL policy targeting roles on non-public schema",
+		roles: []string{
+			"role_1",
+			"role_2",
+		},
+		oldSchemaDDL: []string{
+			`
+			CREATE SCHEMA schema_1;
+			CREATE TABLE schema_1.foobar();
+			`,
+		},
+		newSchemaDDL: []string{
+			`
+			CREATE SCHEMA schema_1;
+			CREATE TABLE schema_1.foobar();
+			CREATE POLICY foobar_policy ON schema_1.foobar
+				AS PERMISSIVE
+				FOR ALL
+				TO role_1, role_2
+				USING (true)
+				WITH CHECK (true);
+			`,
+		},
+		expectedHazardTypes: []diff.MigrationHazardType{
+			diff.MigrationHazardTypeAuthzUpdate,
+		},
+	},
+	{
+		name: "Create SELECT policy",
+		oldSchemaDDL: []string{
+			`
+			CREATE TABLE foobar();
+			`,
+		},
+		newSchemaDDL: []string{
+			`
+			CREATE TABLE foobar();
+			CREATE POLICY foobar_policy ON foobar
+				AS RESTRICTIVE
+				FOR SELECT
+				TO PUBLIC
+				USING (true);
+			`,
+		},
+		expectedHazardTypes: []diff.MigrationHazardType{
+			diff.MigrationHazardTypeAuthzUpdate,
+		},
+	},
+	{
+		name: "Create INSERT policy",
+		oldSchemaDDL: []string{
+			`
+			CREATE TABLE foobar();
+			`,
+		},
+		newSchemaDDL: []string{
+			`
+			CREATE TABLE foobar();
+			CREATE POLICY foobar_policy ON foobar
+				AS RESTRICTIVE
+				FOR INSERT
+				TO PUBLIC
+				WITH CHECK (true);
+			`,
+		},
+		expectedHazardTypes: []diff.MigrationHazardType{
+			diff.MigrationHazardTypeAuthzUpdate,
+		},
+	},
+	{
+		name: "Create UPDATE policy",
+		oldSchemaDDL: []string{
+			`
+			CREATE TABLE foobar();
+			`,
+		},
+		newSchemaDDL: []string{
+			`
+			CREATE TABLE foobar();
+			CREATE POLICY foobar_policy ON foobar
+				AS RESTRICTIVE
+				FOR UPDATE
+				TO PUBLIC
+				WITH CHECK (true);
+			`,
+		},
+		expectedHazardTypes: []diff.MigrationHazardType{
+			diff.MigrationHazardTypeAuthzUpdate,
+		},
+	},
+	{
+		name: "Create DELETE policy",
+		oldSchemaDDL: []string{
+			`
+			CREATE TABLE foobar();
+			`,
+		},
+		newSchemaDDL: []string{
+			`
+			CREATE TABLE foobar();
+			CREATE POLICY foobar_policy ON foobar
+				AS RESTRICTIVE
+				FOR DELETE
+				TO PUBLIC
+				USING (true);
+			`,
+		},
+		expectedHazardTypes: []diff.MigrationHazardType{
+			diff.MigrationHazardTypeAuthzUpdate,
+		},
+	},
+	{
+		name: "Add policy on new table",
+		newSchemaDDL: []string{
+			`
+			CREATE TABLE foobar();
+			CREATE POLICY foobar_policy ON foobar
+				AS RESTRICTIVE
+				FOR SELECT
+				TO PUBLIC
+				USING (true);
+			`,
+		},
+	},
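The cases above cover each CREATE POLICY command variant, and the USING/WITH CHECK split follows Postgres rules: SELECT and DELETE policies take only USING, INSERT takes only WITH CHECK, and UPDATE and ALL may take both. In the catalog, the variant is stored as a single polcmd character, which the schema package later in this patch models as PolicyCmd; a hedged sketch of the reverse mapping (forClause is a hypothetical helper, not part of the patch):

```go
package main

import "fmt"

// PolicyCmd mirrors the polcmd codes in pg_catalog.pg_policy, matching the
// schema.PolicyCmd constants added later in this patch.
type PolicyCmd string

const (
	SelectPolicyCmd PolicyCmd = "r"
	InsertPolicyCmd PolicyCmd = "a"
	UpdatePolicyCmd PolicyCmd = "w"
	DeletePolicyCmd PolicyCmd = "d"
	AllPolicyCmd    PolicyCmd = "*"
)

// forClause translates a polcmd code back into the FOR clause that would
// appear in CREATE POLICY, e.g. when regenerating DDL from the catalog.
func forClause(c PolicyCmd) (string, error) {
	switch c {
	case SelectPolicyCmd:
		return "FOR SELECT", nil
	case InsertPolicyCmd:
		return "FOR INSERT", nil
	case UpdatePolicyCmd:
		return "FOR UPDATE", nil
	case DeletePolicyCmd:
		return "FOR DELETE", nil
	case AllPolicyCmd:
		return "FOR ALL", nil
	default:
		return "", fmt.Errorf("unknown polcmd %q", c)
	}
}

func main() {
	for _, c := range []PolicyCmd{SelectPolicyCmd, InsertPolicyCmd, UpdatePolicyCmd, DeletePolicyCmd, AllPolicyCmd} {
		clause, err := forClause(c)
		if err != nil {
			panic(err)
		}
		fmt.Println(string(c), "=>", clause)
	}
}
```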
+ name: "Add policy then enable RLS", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar(); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar(); + ALTER TABLE foobar ENABLE ROW LEVEL SECURITY; + ALTER TABLE foobar FORCE ROW LEVEL SECURITY; + CREATE POLICY foobar_policy ON foobar + AS RESTRICTIVE + FOR SELECT + TO PUBLIC + USING (true); + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeAuthzUpdate, + }, + ddl: []string{ + // Ensure that the policy is created before enabling RLS. + "CREATE POLICY \"foobar_policy\" ON \"public\".\"foobar\"\n\tAS RESTRICTIVE\n\tFOR SELECT\n\tTO PUBLIC\n\tUSING (true)", + "ALTER TABLE \"public\".\"foobar\" ENABLE ROW LEVEL SECURITY", + "ALTER TABLE \"public\".\"foobar\" FORCE ROW LEVEL SECURITY", + }, + }, + { + name: "Drop non-public schema policy", + oldSchemaDDL: []string{ + ` + CREATE SCHEMA schema_1; + CREATE TABLE schema_1.foobar(); + CREATE POLICY foobar_policy ON schema_1.foobar + AS RESTRICTIVE + FOR SELECT + TO PUBLIC + USING (true); + `, + }, + newSchemaDDL: []string{ + ` + CREATE SCHEMA schema_1; + CREATE TABLE schema_1.foobar(); + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeAuthzUpdate, + }, + }, + { + name: "Drop policy and table", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar(); + CREATE POLICY foobar_policy ON foobar + AS RESTRICTIVE + FOR SELECT + TO PUBLIC + USING (true); + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeDeletesData, + }, + }, + { + name: "Disable RLS then drop policy", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar(); + CREATE POLICY foobar_policy ON foobar + AS RESTRICTIVE + FOR SELECT + TO PUBLIC + USING (true); + ALTER TABLE foobar ENABLE ROW LEVEL SECURITY; + ALTER TABLE foobar FORCE ROW LEVEL SECURITY; + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar(); + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeAuthzUpdate, + }, + ddl: []string{ + "ALTER TABLE \"public\".\"foobar\" DISABLE ROW LEVEL SECURITY", + "ALTER TABLE \"public\".\"foobar\" NO FORCE ROW LEVEL SECURITY", + "DROP POLICY \"foobar_policy\" ON \"public\".\"foobar\"", + }, + }, + { + name: "Drop policy and columns", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + category TEXT, + val TEXT + ); + CREATE POLICY foobar_policy ON foobar + AS RESTRICTIVE + FOR SELECT + TO PUBLIC + USING (category = 'category' AND val = 'value'); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar(); + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeDeletesData, + diff.MigrationHazardTypeAuthzUpdate, + }, + }, + { + name: "Restrictive to permissive policy", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar(); + CREATE POLICY foobar_policy ON foobar + AS PERMISSIVE + FOR SELECT + TO PUBLIC + USING (true); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar(); + CREATE POLICY foobar_policy ON foobar + AS RESTRICTIVE + FOR SELECT + TO PUBLIC + USING (true); + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeAuthzUpdate, + }, + }, + { + name: "Alter policy target", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar(); + CREATE POLICY foobar_policy ON foobar + AS PERMISSIVE + FOR SELECT + TO PUBLIC + USING (true); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar(); + CREATE POLICY foobar_policy ON foobar + AS PERMISSIVE + FOR DELETE + TO PUBLIC + USING (true); + `, + }, + 
+	{
+		name:  "Alter policy applies to",
+		roles: []string{"role_1", "role_2"},
+		oldSchemaDDL: []string{
+			`
+			CREATE TABLE foobar();
+			CREATE POLICY foobar_policy ON foobar
+				AS PERMISSIVE
+				FOR SELECT
+				TO PUBLIC
+				USING (true);
+			`,
+		},
+		newSchemaDDL: []string{
+			`
+			CREATE TABLE foobar();
+			CREATE POLICY foobar_policy ON foobar
+				AS PERMISSIVE
+				FOR SELECT
+				TO role_1, role_2
+				USING (true);
+			`,
+		},
+		expectedHazardTypes: []diff.MigrationHazardType{
+			diff.MigrationHazardTypeAuthzUpdate,
+		},
+	},
+	{
+		name: "Alter policy using",
+		oldSchemaDDL: []string{
+			`
+			CREATE TABLE foobar();
+			CREATE POLICY foobar_policy ON foobar
+				AS PERMISSIVE
+				FOR SELECT
+				TO PUBLIC
+				USING (true);
+			`,
+		},
+		newSchemaDDL: []string{
+			`
+			CREATE TABLE foobar();
+			CREATE POLICY foobar_policy ON foobar
+				AS PERMISSIVE
+				FOR SELECT
+				TO PUBLIC
+				USING (false);
+			`,
+		},
+		expectedHazardTypes: []diff.MigrationHazardType{
+			diff.MigrationHazardTypeAuthzUpdate,
+		},
+	},
+	{
+		name: "Alter policy check",
+		oldSchemaDDL: []string{
+			`
+			CREATE TABLE foobar();
+			CREATE POLICY foobar_policy ON foobar
+				AS PERMISSIVE
+				FOR INSERT
+				TO PUBLIC
+				WITH CHECK (true);
+			`,
+		},
+		newSchemaDDL: []string{
+			`
+			CREATE TABLE foobar();
+			CREATE POLICY foobar_policy ON foobar
+				AS PERMISSIVE
+				FOR INSERT
+				TO PUBLIC
+				WITH CHECK (false);
+			`,
+		},
+		expectedHazardTypes: []diff.MigrationHazardType{
+			diff.MigrationHazardTypeAuthzUpdate,
+		},
+	},
+	{
+		name: "Remove USING from ALL policy",
+		oldSchemaDDL: []string{
+			`
+			CREATE TABLE foobar();
+			CREATE POLICY foobar_policy ON foobar
+				AS PERMISSIVE
+				FOR ALL
+				TO PUBLIC
+				USING (true)
+				WITH CHECK (true);
+			`,
+		},
+		newSchemaDDL: []string{
+			`
+			CREATE TABLE foobar();
+			CREATE POLICY foobar_policy ON foobar
+				AS PERMISSIVE
+				FOR ALL
+				TO PUBLIC
+				WITH CHECK (true);
+			`,
+		},
+		expectedHazardTypes: []diff.MigrationHazardType{
+			diff.MigrationHazardTypeAuthzUpdate,
+		},
+	},
+	{
+		name: "Remove check for ALL policy",
+		oldSchemaDDL: []string{
+			`
+			CREATE TABLE foobar();
+			CREATE POLICY foobar_policy ON foobar
+				AS PERMISSIVE
+				FOR ALL
+				TO PUBLIC
+				USING (true)
+				WITH CHECK (true);
+			`,
+		},
+		newSchemaDDL: []string{
+			`
+			CREATE TABLE foobar();
+			CREATE POLICY foobar_policy ON foobar
+				AS PERMISSIVE
+				FOR ALL
+				TO PUBLIC
+				USING (true);
+			`,
+		},
+		expectedHazardTypes: []diff.MigrationHazardType{
+			diff.MigrationHazardTypeAuthzUpdate,
+		},
+	},
+	{
+		name: "Alter all alterable attributes on non-public schema",
+		roles: []string{
+			"role_1",
+			"role_2",
+		},
+		oldSchemaDDL: []string{
+			`
+			CREATE SCHEMA schema_1;
+			CREATE TABLE schema_1.foobar();
+			CREATE POLICY foobar_policy ON schema_1.foobar
+				AS PERMISSIVE
+				FOR ALL
+				TO role_1, role_2
+				USING (true)
+				WITH CHECK (true);
+			`,
+		},
+		newSchemaDDL: []string{
+			`
+			CREATE SCHEMA schema_1;
+			CREATE TABLE schema_1.foobar();
+			CREATE POLICY foobar_policy ON schema_1.foobar
+				AS PERMISSIVE
+				FOR ALL
+				TO PUBLIC
+				USING (false)
+				WITH CHECK (false);
+			`,
+		},
+		expectedHazardTypes: []diff.MigrationHazardType{
+			diff.MigrationHazardTypeAuthzUpdate,
+		},
+	},
+	{
+		name: "Alter policy that references deleted columns and new columns (non-public schema)",
+		oldSchemaDDL: []string{
+			`
+			CREATE SCHEMA schema_1;
+			CREATE TABLE schema_1.foobar(
+				category TEXT,
+				val TEXT
+			);
+			CREATE POLICY foobar_policy ON schema_1.foobar
+				AS PERMISSIVE
+				FOR ALL
+				TO PUBLIC
+				USING (val = 'value' AND category = 'category')
+				WITH CHECK (val =
'value'); + `, + }, + newSchemaDDL: []string{ + ` + CREATE SCHEMA schema_1; + CREATE TABLE schema_1.foobar( + category TEXT NOT NULL, + new_val TEXT + ); + CREATE POLICY foobar_policy ON schema_1.foobar + AS PERMISSIVE + FOR ALL + TO PUBLIC + USING (new_val = 'value' AND category = 'category') + WITH CHECK (new_val = 'value'); + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeAuthzUpdate, + diff.MigrationHazardTypeDeletesData, + }, + }, + { + name: "Re-create policy that references deleted columns and new columns (non-public schema)", + oldSchemaDDL: []string{ + ` + CREATE SCHEMA schema_1; + CREATE TABLE schema_1.foobar( + category TEXT, + val TEXT + ); + CREATE POLICY foobar_policy ON schema_1.foobar + AS PERMISSIVE + FOR ALL + TO PUBLIC + USING (val = 'value' AND category = 'category') + WITH CHECK (val = 'value'); + `, + }, + newSchemaDDL: []string{ + ` + CREATE SCHEMA schema_1; + CREATE TABLE schema_1.foobar( + category TEXT NOT NULL, + new_val TEXT + ); + CREATE POLICY foobar_policy ON schema_1.foobar + AS RESTRICTIVE -- force-recreate the policy + FOR ALL + TO PUBLIC + USING (new_val = 'value' AND category = 'category') + WITH CHECK (new_val = 'value'); + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeAuthzUpdate, + diff.MigrationHazardTypeDeletesData, + }, + }, + { + name: "Alter policy (table is re-created)", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + category TEXT + ) partition by list (category); + CREATE POLICY foobar_policy ON foobar + AS PERMISSIVE + FOR INSERT + TO PUBLIC + WITH CHECK (true); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + category TEXT, + some_new_column TEXT + ); -- Re-create by removing partitioning + CREATE POLICY foobar_policy ON foobar + AS PERMISSIVE + FOR INSERT + TO PUBLIC + WITH CHECK (category = 'category' AND some_new_column = 'value'); + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeDeletesData, + }, + }, + { + name: "Policy on new partition (not implemented)", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + category TEXT + ) partition by list (category); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + category TEXT + ) partition by list (category); + CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('category'); + CREATE POLICY foobar_1_policy ON foobar_1 + AS PERMISSIVE + FOR INSERT + TO PUBLIC + WITH CHECK (true); + `, + }, + vanillaExpectations: expectations{ + planErrorIs: diff.ErrNotImplemented, + }, + dataPackingExpectations: expectations{ + planErrorIs: diff.ErrNotImplemented, + }, + }, + { + name: "Add policy on existing partition (not implemented)", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + category TEXT + ) partition by list (category); + CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('category'); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + category TEXT + ) partition by list (category); + CREATE TABLE foobar_1 PARTITION OF foobar FOR VALUES IN ('category'); + CREATE POLICY foobar_1_policy ON foobar_1 + AS PERMISSIVE + FOR INSERT + TO PUBLIC + WITH CHECK (true); + `, + }, + vanillaExpectations: expectations{ + planErrorIs: diff.ErrNotImplemented, + }, + dataPackingExpectations: expectations{ + planErrorIs: diff.ErrNotImplemented, + }, + }, +} + +func (suite *acceptanceTestSuite) TestPolicyCases() { + suite.runTestCases(policyAcceptanceTestCases) +} diff --git a/internal/migration_acceptance_tests/schema_cases_test.go 
b/internal/migration_acceptance_tests/schema_cases_test.go index 7f5b4e5..efc275d 100644 --- a/internal/migration_acceptance_tests/schema_cases_test.go +++ b/internal/migration_acceptance_tests/schema_cases_test.go @@ -52,6 +52,8 @@ var schemaAcceptanceTests = []acceptanceTestCase{ PRIMARY KEY (foo, id), UNIQUE (foo, bar) ) PARTITION BY LIST(foo); + ALTER TABLE schema_1.foobar ENABLE ROW LEVEL SECURITY; + ALTER TABLE schema_1.foobar FORCE ROW LEVEL SECURITY; CREATE TABLE foobar_1 PARTITION of schema_1.foobar( fizz NOT NULL @@ -64,6 +66,8 @@ var schemaAcceptanceTests = []acceptanceTestCase{ -- local indexes CREATE INDEX foobar_1_local_idx ON foobar_1(foo, fizz); + CREATE POLICY foobar_foo_policy ON schema_1.foobar FOR SELECT TO PUBLIC USING (foo = current_user); + CREATE table bar( id INT PRIMARY KEY, foo VARCHAR(255), @@ -122,6 +126,8 @@ var schemaAcceptanceTests = []acceptanceTestCase{ PRIMARY KEY (foo, id), UNIQUE (foo, bar) ) PARTITION BY LIST(foo); + ALTER TABLE schema_1.foobar ENABLE ROW LEVEL SECURITY; + ALTER TABLE schema_1.foobar FORCE ROW LEVEL SECURITY; CREATE TABLE foobar_1 PARTITION of schema_1.foobar( fizz NOT NULL @@ -134,6 +140,8 @@ var schemaAcceptanceTests = []acceptanceTestCase{ -- local indexes CREATE INDEX foobar_1_local_idx ON foobar_1(foo, fizz); + CREATE POLICY foobar_foo_policy ON schema_1.foobar FOR SELECT TO PUBLIC USING (foo = current_user); + CREATE table bar( id INT PRIMARY KEY, foo VARCHAR(255), @@ -156,7 +164,8 @@ var schemaAcceptanceTests = []acceptanceTestCase{ }, }, { - name: "Add schema, drop schema, Add enum, Drop enum, Drop table, Add Table, Drop Seq, Add Seq, Drop Funcs, Add Funcs, Drop Triggers, Add Triggers, Create Extension, Drop Extension, Create Index Using Extension", + name: "Add schema, drop schema, Add enum, Drop enum, Drop table, Add Table, Drop Seq, Add Seq, Drop Funcs, Add Funcs, Drop Triggers, Add Triggers, Create Extension, Drop Extension, Create Index Using Extension, Add policies, Drop policies", + roles: []string{"role_1"}, oldSchemaDDL: []string{ ` CREATE SCHEMA schema_1; @@ -212,6 +221,8 @@ var schemaAcceptanceTests = []acceptanceTestCase{ CREATE INDEX foobar_normal_idx ON foobar USING hash (fizz); CREATE UNIQUE INDEX foobar_unique_idx ON foobar(foo, fizz DESC); + CREATE POLICY foobar_foo_policy ON foobar FOR SELECT TO PUBLIC USING (foo = current_user); + CREATE TRIGGER "some trigger" BEFORE UPDATE ON foobar FOR EACH ROW @@ -230,6 +241,9 @@ var schemaAcceptanceTests = []acceptanceTestCase{ CREATE INDEX bar_normal_idx ON schema_2.bar(bar); CREATE INDEX bar_another_normal_id ON schema_2.bar(bar DESC, fizz DESC); CREATE UNIQUE INDEX bar_unique_idx on schema_2.bar(fizz, buzz); + + CREATE POLICY bar_bar_policy ON schema_2.bar FOR INSERT TO role_1 WITH CHECK (bar > 5.1); + CREATE POLICY bar_foo_policy ON schema_2.bar FOR SELECT TO PUBLIC USING (foo = 'some_foo'); `, }, newSchemaDDL: []string{ @@ -288,6 +302,8 @@ var schemaAcceptanceTests = []acceptanceTestCase{ ALTER TABLE "New_table" ADD CONSTRAINT "new_fzz_check" CHECK ( new_fizz < CURRENT_TIMESTAMP - interval '1 month' ) NO INHERIT NOT VALID; CREATE UNIQUE INDEX foobar_unique_idx ON "New_table"(new_foo, new_fizz); + CREATE POLICY "New_table_foo_policy" ON "New_table" FOR DELETE TO PUBLIC USING (version > 0); + CREATE TRIGGER "some trigger" BEFORE UPDATE ON "New_table" FOR EACH ROW @@ -309,6 +325,8 @@ var schemaAcceptanceTests = []acceptanceTestCase{ CREATE UNIQUE INDEX bar_unique_idx ON schema_2.bar(fizz, buzz); CREATE INDEX gin_index ON schema_2.bar USING gin (quux gin_trgm_ops); + 
CREATE POLICY bar_foo_policy ON schema_2.bar FOR SELECT TO role_1 USING (foo = 'some_foo' AND quux = 'some_quux'); + CREATE FUNCTION check_content() RETURNS TRIGGER AS $$ BEGIN IF LENGTH(NEW.id) == 0 THEN @@ -325,6 +343,7 @@ var schemaAcceptanceTests = []acceptanceTestCase{ }, expectedHazardTypes: []diff.MigrationHazardType{ diff.MigrationHazardTypeAcquiresShareRowExclusiveLock, + diff.MigrationHazardTypeAuthzUpdate, diff.MigrationHazardTypeDeletesData, diff.MigrationHazardTypeHasUntrackableDependencies, diff.MigrationHazardTypeIndexBuild, @@ -414,6 +433,6 @@ var schemaAcceptanceTests = []acceptanceTestCase{ }, } -func (suite *acceptanceTestSuite) TestSchemaAcceptanceTestCases() { +func (suite *acceptanceTestSuite) TestSchemaTestCases() { suite.runTestCases(schemaAcceptanceTests) } diff --git a/internal/migration_acceptance_tests/sequence_cases_test.go b/internal/migration_acceptance_tests/sequence_cases_test.go index 6b7826a..9c099ef 100644 --- a/internal/migration_acceptance_tests/sequence_cases_test.go +++ b/internal/migration_acceptance_tests/sequence_cases_test.go @@ -760,6 +760,6 @@ var sequenceAcceptanceTests = []acceptanceTestCase{ }, } -func (suite *acceptanceTestSuite) TestSequenceAcceptanceTestCases() { +func (suite *acceptanceTestSuite) TestSequenceTestCases() { suite.runTestCases(sequenceAcceptanceTests) } diff --git a/internal/migration_acceptance_tests/table_cases_test.go b/internal/migration_acceptance_tests/table_cases_test.go index 726bdfb..8537472 100644 --- a/internal/migration_acceptance_tests/table_cases_test.go +++ b/internal/migration_acceptance_tests/table_cases_test.go @@ -17,6 +17,8 @@ var tableAcceptanceTestCases = []acceptanceTestCase{ buzz REAL CHECK (buzz IS NOT NULL) ); ALTER TABLE foobar REPLICA IDENTITY FULL; + ALTER TABLE foobar ENABLE ROW LEVEL SECURITY; + ALTER TABLE foobar FORCE ROW LEVEL SECURITY; CREATE INDEX normal_idx ON foobar(fizz); CREATE UNIQUE INDEX unique_idx ON foobar(foo, bar); @@ -40,6 +42,8 @@ var tableAcceptanceTestCases = []acceptanceTestCase{ buzz REAL CHECK (buzz IS NOT NULL) ); ALTER TABLE foobar REPLICA IDENTITY FULL; + ALTER TABLE foobar ENABLE ROW LEVEL SECURITY; + ALTER TABLE foobar FORCE ROW LEVEL SECURITY; CREATE INDEX normal_idx ON foobar(fizz); CREATE UNIQUE INDEX unique_idx ON foobar(foo, bar); @@ -73,6 +77,8 @@ var tableAcceptanceTestCases = []acceptanceTestCase{ buzz REAL CHECK (buzz IS NOT NULL) ); ALTER TABLE foobar REPLICA IDENTITY FULL; + ALTER TABLE foobar ENABLE ROW LEVEL SECURITY; + ALTER TABLE foobar FORCE ROW LEVEL SECURITY; CREATE INDEX normal_idx ON foobar(fizz); CREATE UNIQUE INDEX foobar_unique_idx ON foobar(foo, bar); @@ -97,6 +103,8 @@ var tableAcceptanceTestCases = []acceptanceTestCase{ foo VARCHAR(255) COLLATE "POSIX" DEFAULT '' NOT NULL ); ALTER TABLE foobar REPLICA IDENTITY FULL; + ALTER TABLE foobar ENABLE ROW LEVEL SECURITY; + ALTER TABLE foobar FORCE ROW LEVEL SECURITY; CREATE INDEX normal_idx ON foobar(fizz); CREATE UNIQUE INDEX foobar_unique_idx ON foobar(foo, bar); @@ -114,33 +122,37 @@ var tableAcceptanceTestCases = []acceptanceTestCase{ }, }, { - name: "Create table with quoted names", + name: "Create table with RLS enabled", oldSchemaDDL: nil, newSchemaDDL: []string{ ` CREATE TABLE "Foobar"( - id INT PRIMARY KEY, - "Foo" VARCHAR(255) COLLATE "POSIX" DEFAULT '' NOT NULL CHECK (LENGTH("Foo") > 0), - bar TIMESTAMP DEFAULT CURRENT_TIMESTAMP, - fizz SERIAL NOT NULL + bar TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + id INT PRIMARY KEY, + fizz SERIAL NOT NULL, + "Foo" VARCHAR(255) COLLATE "POSIX" DEFAULT '' 
NOT NULL CHECK (LENGTH("Foo") > 0) ); + ALTER TABLE "Foobar" ENABLE ROW LEVEL SECURITY; CREATE INDEX normal_idx ON "Foobar" USING hash (fizz); CREATE UNIQUE INDEX unique_idx ON "Foobar"("Foo" DESC, bar); `, }, - dataPackingExpectations: expectations{ - outputState: []string{ - ` - CREATE TABLE "Foobar"( - bar TIMESTAMP DEFAULT CURRENT_TIMESTAMP, - id INT PRIMARY KEY, - fizz SERIAL NOT NULL, - "Foo" VARCHAR(255) COLLATE "POSIX" DEFAULT '' NOT NULL CHECK (LENGTH("Foo") > 0) - ); - CREATE INDEX normal_idx ON "Foobar" USING hash (fizz); - CREATE UNIQUE INDEX unique_idx ON "Foobar"("Foo" DESC, bar); - `, - }, + }, + { + name: "Create table with force RLS enabled", + oldSchemaDDL: nil, + newSchemaDDL: []string{ + ` + CREATE TABLE "Foobar"( + bar TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + id INT PRIMARY KEY, + fizz SERIAL NOT NULL, + "Foo" VARCHAR(255) COLLATE "POSIX" DEFAULT '' NOT NULL CHECK (LENGTH("Foo") > 0) + ); + ALTER TABLE "Foobar" FORCE ROW LEVEL SECURITY; + CREATE INDEX normal_idx ON "Foobar" USING hash (fizz); + CREATE UNIQUE INDEX unique_idx ON "Foobar"("Foo" DESC, bar); + `, }, }, { @@ -288,6 +300,90 @@ var tableAcceptanceTestCases = []acceptanceTestCase{ planErrorIs: diff.ErrNotImplemented, }, }, + { + name: "Enable RLS", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY + ); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY + ); + ALTER TABLE foobar ENABLE ROW LEVEL SECURITY; + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeAuthzUpdate, + }, + }, + { + name: "Disable RLS", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY + ); + ALTER TABLE foobar ENABLE ROW LEVEL SECURITY; + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY + ); + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeAuthzUpdate, + }, + }, + { + name: "Force RLS", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY + ); + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY + ); + ALTER TABLE foobar FORCE ROW LEVEL SECURITY; + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeAuthzUpdate, + }, + }, + { + name: "Unforce RLS", + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY + ); + ALTER TABLE foobar FORCE ROW LEVEL SECURITY; + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar( + id INT PRIMARY KEY + ); + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeAuthzUpdate, + }, + }, { name: "Alter table: New primary key, drop unique constraint, new unique constraint, change column types, delete unique index, delete FK's, new index, validate check constraint", oldSchemaDDL: []string{ @@ -342,7 +438,7 @@ var tableAcceptanceTestCases = []acceptanceTestCase{ }, }, { - name: "Alter table: New column, new primary key, new FK, drop FK, alter column to nullable, alter column types, drop column, drop index, drop check constraints", + name: "Alter table: New column, new primary key, new FK, drop FK, alter column to nullable, alter column types, drop column, drop index, drop check constraints, alter policies", oldSchemaDDL: []string{ ` CREATE TABLE foobar( @@ -352,9 +448,23 @@ var tableAcceptanceTestCases = []acceptanceTestCase{ fizz SERIAL NOT NULL, buzz REAL CHECK (buzz IS NOT NULL) ); + CREATE INDEX normal_idx ON foobar(fizz); CREATE UNIQUE INDEX unique_idx ON foobar(foo DESC, bar); + CREATE POLICY 
foobar_policy ON foobar + AS PERMISSIVE + FOR INSERT + TO PUBLIC + WITH CHECK (fizz > 0); + + CREATE POLICY some_policy_to_drop ON foobar + AS RESTRICTIVE + FOR SELECT + TO PUBLIC + USING (bar = CURRENT_TIMESTAMP AND fizz * 2 > 0); + + CREATE TABLE foobar_fk( bar TIMESTAMP, foo VARCHAR(255) @@ -371,8 +481,15 @@ var tableAcceptanceTestCases = []acceptanceTestCase{ bar TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL, new_fizz DECIMAL(65, 10) DEFAULT 5.25 NOT NULL PRIMARY KEY UNIQUE ); + CREATE INDEX other_idx ON foobar(bar); + CREATE POLICY foobar_policy ON foobar + AS PERMISSIVE + FOR INSERT + TO PUBLIC + WITH CHECK (new_fizz = 5.25); + CREATE TABLE foobar_fk( bar TIMESTAMP, foo CHAR @@ -385,6 +502,7 @@ var tableAcceptanceTestCases = []acceptanceTestCase{ expectedHazardTypes: []diff.MigrationHazardType{ diff.MigrationHazardTypeAcquiresAccessExclusiveLock, diff.MigrationHazardTypeAcquiresShareRowExclusiveLock, + diff.MigrationHazardTypeAuthzUpdate, diff.MigrationHazardTypeImpactsDatabasePerformance, diff.MigrationHazardTypeDeletesData, diff.MigrationHazardTypeIndexDropped, @@ -405,6 +523,12 @@ var tableAcceptanceTestCases = []acceptanceTestCase{ CREATE INDEX normal_idx ON foobar USING hash (fizz); CREATE UNIQUE INDEX foobar_unique_idx ON foobar(foo, bar); + CREATE POLICY foobar_policy ON foobar + AS PERMISSIVE + FOR INSERT + TO PUBLIC + WITH CHECK (id > 1 AND foo = 'value'); + CREATE TABLE foobar_fk( bar TIMESTAMP, foo VARCHAR(255) @@ -427,6 +551,13 @@ var tableAcceptanceTestCases = []acceptanceTestCase{ CREATE INDEX normal_idx ON foobar USING hash (new_fizz); CREATE UNIQUE INDEX foobar_unique_idx ON foobar(new_foo, new_bar); + + CREATE POLICY foobar_policy ON foobar + AS RESTRICTIVE + FOR INSERT + TO PUBLIC + WITH CHECK (new_id > 0 AND new_foo = 'some_new_value'); + CREATE TABLE foobar_fk( bar TIMESTAMP, foo VARCHAR(255) @@ -440,6 +571,7 @@ var tableAcceptanceTestCases = []acceptanceTestCase{ expectedHazardTypes: []diff.MigrationHazardType{ diff.MigrationHazardTypeAcquiresAccessExclusiveLock, diff.MigrationHazardTypeAcquiresShareRowExclusiveLock, + diff.MigrationHazardTypeAuthzUpdate, diff.MigrationHazardTypeDeletesData, diff.MigrationHazardTypeIndexDropped, diff.MigrationHazardTypeIndexBuild, @@ -449,7 +581,7 @@ var tableAcceptanceTestCases = []acceptanceTestCase{ name: "Alter table: translate BIGINT type to TIMESTAMP, set to not null, set default", oldSchemaDDL: []string{ ` - CREATE TABLE alexrhee_testing( + CREATE TABLE foobar( id INT PRIMARY KEY, obj_attr__c_time BIGINT, obj_attr__m_time BIGINT @@ -458,7 +590,7 @@ var tableAcceptanceTestCases = []acceptanceTestCase{ }, newSchemaDDL: []string{ ` - CREATE TABLE alexrhee_testing( + CREATE TABLE foobar( id INT PRIMARY KEY, obj_attr__c_time TIMESTAMP NOT NULL, obj_attr__m_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL @@ -472,6 +604,6 @@ var tableAcceptanceTestCases = []acceptanceTestCase{ }, } -func (suite *acceptanceTestSuite) TestTableAcceptanceTestCases() { +func (suite *acceptanceTestSuite) TestTableTestCases() { suite.runTestCases(tableAcceptanceTestCases) } diff --git a/internal/migration_acceptance_tests/trigger_cases_test.go b/internal/migration_acceptance_tests/trigger_cases_test.go index 83eb2f2..f7ab43c 100644 --- a/internal/migration_acceptance_tests/trigger_cases_test.go +++ b/internal/migration_acceptance_tests/trigger_cases_test.go @@ -795,6 +795,6 @@ var triggerAcceptanceTestCases = []acceptanceTestCase{ }, } -func (suite *acceptanceTestSuite) TestTriggerAcceptanceTestCases() { +func (suite *acceptanceTestSuite) 
TestTriggerTestCases() {
	suite.runTestCases(triggerAcceptanceTestCases)
}
diff --git a/internal/pgengine/actions.go b/internal/pgengine/actions.go
new file mode 100644
index 0000000..2022097
--- /dev/null
+++ b/internal/pgengine/actions.go
@@ -0,0 +1,50 @@
+package pgengine
+
+import (
+	"context"
+	"database/sql"
+	"fmt"
+)
+
+// ResetInstance attempts to reset the cluster to a clean state.
+// It deletes all cluster level objects, i.e., roles, which are not deleted
+// by dropping database(s). This can be useful for re-using a cluster for multiple tests.
+func ResetInstance(ctx context.Context, db *sql.DB) error {
+	// Drop all roles except the current user and postgres internal roles
+	if err := dropRoles(ctx, db); err != nil {
+		return fmt.Errorf("dropping roles: %w", err)
+	}
+
+	return nil
+}
+
+// dropRoles drops all roles except the current user and postgres internal roles
+func dropRoles(ctx context.Context, db *sql.DB) error {
+	rows, err := db.QueryContext(ctx, `
+		SELECT rolname
+		FROM pg_catalog.pg_roles
+		WHERE rolname NOT LIKE 'pg_%'
+			AND rolname != current_user;
+		`,
+	)
+	if err != nil {
+		return fmt.Errorf("querying roles: %w", err)
+	}
+	defer rows.Close()
+
+	for rows.Next() {
+		var roleName string
+		if err := rows.Scan(&roleName); err != nil {
+			return fmt.Errorf("scanning role: %w", err)
+		}
+		if _, err := db.ExecContext(ctx, fmt.Sprintf("DROP ROLE %s", roleName)); err != nil {
+			return fmt.Errorf("dropping role %q: %w", roleName, err)
+		}
+	}
+
+	if err := rows.Err(); err != nil {
+		return fmt.Errorf("iterating over rows: %w", err)
+	}
+
+	return nil
+}
diff --git a/internal/pgengine/db.go b/internal/pgengine/db.go
index 2b24105..447a538 100644
--- a/internal/pgengine/db.go
+++ b/internal/pgengine/db.go
@@ -44,7 +44,7 @@ func (d *DB) DropDB() error {
 		return err
 	}
 
-	// Drop existing connections, so that we can drop the table
+	// Drop existing connections, so that we can drop the database
 	_, err = db.Exec("SELECT PG_TERMINATE_BACKEND(pid) FROM pg_stat_activity WHERE datname = $1", d.GetName())
 	if err != nil {
 		return err
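dropRoles above interpolates rolname directly into DROP ROLE, which is fine for the simple role_1-style names these tests create. If role names could carry uppercase letters, spaces, or quotes, quoting the identifier is safer; a hedged variant (dropRolesQuoted is hypothetical, not part of the patch) that uses pq.QuoteIdentifier — lib/pq is already a dependency of the generated queries — and collects the names before issuing DDL instead of executing inside the open cursor:

```go
package pgengine

import (
	"context"
	"database/sql"
	"fmt"

	"github.com/lib/pq"
)

// dropRolesQuoted is a variant of dropRoles that quotes each role name
// before interpolating it into DROP ROLE.
func dropRolesQuoted(ctx context.Context, db *sql.DB) error {
	rows, err := db.QueryContext(ctx, `
		SELECT rolname
		FROM pg_catalog.pg_roles
		WHERE rolname NOT LIKE 'pg_%'
			AND rolname != current_user
	`)
	if err != nil {
		return fmt.Errorf("querying roles: %w", err)
	}
	defer rows.Close()

	// Collect first, then drop: avoids holding the cursor open while
	// issuing DDL on the same pool.
	var roleNames []string
	for rows.Next() {
		var roleName string
		if err := rows.Scan(&roleName); err != nil {
			return fmt.Errorf("scanning role: %w", err)
		}
		roleNames = append(roleNames, roleName)
	}
	if err := rows.Err(); err != nil {
		return fmt.Errorf("iterating over rows: %w", err)
	}

	for _, roleName := range roleNames {
		if _, err := db.ExecContext(ctx, fmt.Sprintf("DROP ROLE %s", pq.QuoteIdentifier(roleName))); err != nil {
			return fmt.Errorf("dropping role %q: %w", roleName, err)
		}
	}
	return nil
}
```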
diff --git a/internal/queries/queries.sql b/internal/queries/queries.sql
index 7686077..f67e39f 100644
--- a/internal/queries/queries.sql
+++ b/internal/queries/queries.sql
@@ -21,6 +21,8 @@ SELECT
     c.relname::TEXT AS table_name,
     table_namespace.nspname::TEXT AS table_schema_name,
     c.relreplident::TEXT AS replica_identity,
+    c.relrowsecurity AS rls_enabled,
+    c.relforcerowsecurity AS rls_forced,
     COALESCE(parent_c.relname, '')::TEXT AS parent_table_name,
     COALESCE(parent_namespace.nspname, '')::TEXT AS parent_table_schema_name,
     (CASE
@@ -352,3 +354,56 @@ WHERE
         AND ext_depend.objid = pg_type.oid
         AND ext_depend.deptype = 'e'
     );
+
+
+-- name: GetPolicies :many
+WITH roles AS (
+    SELECT
+        oid,
+        rolname
+    FROM pg_catalog.pg_roles
+    UNION
+    (
+        SELECT
+            0 AS oid,
+            'PUBLIC' AS rolname
+    )
+)
+
+SELECT
+    pol.polname::TEXT AS policy_name,
+    table_c.relname::TEXT AS owning_table_name,
+    table_namespace.nspname::TEXT AS owning_table_schema_name,
+    pol.polpermissive AS is_permissive,
+    (
+        SELECT ARRAY_AGG(rolname)
+        FROM roles
+        WHERE roles.oid = ANY(pol.polroles)
+    )::TEXT [] AS applies_to,
+    pol.polcmd::TEXT AS cmd,
+    COALESCE(pg_catalog.pg_get_expr(
+        pol.polwithcheck, pol.polrelid
+    ), '')::TEXT AS check_expression,
+    COALESCE(
+        pg_catalog.pg_get_expr(pol.polqual, pol.polrelid), ''
+    )::TEXT AS using_expression,
+    (
+        SELECT ARRAY_AGG(a.attname)
+        FROM pg_catalog.pg_attribute AS a
+        INNER JOIN pg_catalog.pg_depend AS d ON a.attnum = d.refobjsubid
+        WHERE
+            d.objid = pol.oid
+            AND d.refobjid = table_c.oid
+            AND d.refclassid = 'pg_class'::REGCLASS
+            AND a.attrelid = table_c.oid
+            AND NOT a.attisdropped
+    )::TEXT [] AS column_names
+FROM pg_catalog.pg_policy AS pol
+INNER JOIN pg_catalog.pg_class AS table_c ON pol.polrelid = table_c.oid
+INNER JOIN
+    pg_catalog.pg_namespace AS table_namespace
+    ON table_c.relnamespace = table_namespace.oid
+WHERE
+    table_namespace.nspname NOT IN ('pg_catalog', 'information_schema')
+    AND table_namespace.nspname !~ '^pg_toast'
+    AND table_namespace.nspname !~ '^pg_temp';
diff --git a/internal/queries/queries.sql.go b/internal/queries/queries.sql.go
index 46714be..773d4fb 100644
--- a/internal/queries/queries.sql.go
+++ b/internal/queries/queries.sql.go
@@ -589,6 +589,104 @@ func (q *Queries) GetIndexes(ctx context.Context) ([]GetIndexesRow, error) {
 	return items, nil
 }
 
+const getPolicies = `-- name: GetPolicies :many
+WITH roles AS (
+    SELECT
+        oid,
+        rolname
+    FROM pg_catalog.pg_roles
+    UNION
+    (
+        SELECT
+            0 AS oid,
+            'PUBLIC' AS rolname
+    )
+)
+
+SELECT
+    pol.polname::TEXT AS policy_name,
+    table_c.relname::TEXT AS owning_table_name,
+    table_namespace.nspname::TEXT AS owning_table_schema_name,
+    pol.polpermissive AS is_permissive,
+    (
+        SELECT ARRAY_AGG(rolname)
+        FROM roles
+        WHERE roles.oid = ANY(pol.polroles)
+    )::TEXT [] AS applies_to,
+    pol.polcmd::TEXT AS cmd,
+    COALESCE(pg_catalog.pg_get_expr(
+        pol.polwithcheck, pol.polrelid
+    ), '')::TEXT AS check_expression,
+    COALESCE(
+        pg_catalog.pg_get_expr(pol.polqual, pol.polrelid), ''
+    )::TEXT AS using_expression,
+    (
+        SELECT ARRAY_AGG(a.attname)
+        FROM pg_catalog.pg_attribute AS a
+        INNER JOIN pg_catalog.pg_depend AS d ON a.attnum = d.refobjsubid
+        WHERE
+            d.objid = pol.oid
+            AND d.refobjid = table_c.oid
+            AND d.refclassid = 'pg_class'::REGCLASS
+            AND a.attrelid = table_c.oid
+            AND NOT a.attisdropped
+    )::TEXT [] AS column_names
+FROM pg_catalog.pg_policy AS pol
+INNER JOIN pg_catalog.pg_class AS table_c ON pol.polrelid = table_c.oid
+INNER JOIN
+    pg_catalog.pg_namespace AS table_namespace
+    ON table_c.relnamespace = table_namespace.oid
+WHERE
+    table_namespace.nspname NOT IN ('pg_catalog', 'information_schema')
+    AND table_namespace.nspname !~ '^pg_toast'
+    AND table_namespace.nspname !~ '^pg_temp'
+`
+
+type GetPoliciesRow struct {
+	PolicyName            string
+	OwningTableName       string
+	OwningTableSchemaName string
+	IsPermissive          bool
+	AppliesTo             []string
+	Cmd                   string
+	CheckExpression       string
+	UsingExpression       string
+	ColumnNames           []string
+}
+
+func (q *Queries) GetPolicies(ctx context.Context) ([]GetPoliciesRow, error) {
+	rows, err := q.db.QueryContext(ctx, getPolicies)
+	if err != nil {
+		return nil, err
+	}
+	defer rows.Close()
+	var items []GetPoliciesRow
+	for rows.Next() {
+		var i GetPoliciesRow
+		if err := rows.Scan(
+			&i.PolicyName,
+			&i.OwningTableName,
+			&i.OwningTableSchemaName,
+			&i.IsPermissive,
+			pq.Array(&i.AppliesTo),
+			&i.Cmd,
+			&i.CheckExpression,
+			&i.UsingExpression,
+			pq.Array(&i.ColumnNames),
+		); err != nil {
+			return nil, err
+		}
+		items = append(items, i)
+	}
+	if err := rows.Close(); err != nil {
+		return nil, err
+	}
+	if err := rows.Err(); err != nil {
+		return nil, err
+	}
+	return items, nil
+}
+
 const getSchemas = `-- name: GetSchemas :many
 SELECT nspname::TEXT AS schema_name
 FROM pg_catalog.pg_namespace
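A usage sketch for the generated accessor above. It assumes sqlc's standard New(DBTX) constructor (not shown in this diff), and since queries is an internal package it only compiles from within the pg-schema-diff module; the DSN and driver import path are placeholders:

```go
package main

import (
	"context"
	"database/sql"
	"fmt"
	"log"

	_ "github.com/jackc/pgx/v4/stdlib" // driver import path assumed

	"github.com/stripe/pg-schema-diff/internal/queries"
)

func main() {
	db, err := sql.Open("pgx", "postgres://localhost:5432/postgres") // placeholder DSN
	if err != nil {
		log.Fatal(err)
	}
	defer db.Close()

	// New is sqlc's standard constructor over the DBTX interface; *sql.DB satisfies it.
	q := queries.New(db)
	policies, err := q.GetPolicies(context.Background())
	if err != nil {
		log.Fatal(err)
	}
	for _, p := range policies {
		fmt.Printf("%s.%s: policy %s cmd=%s roles=%v columns=%v\n",
			p.OwningTableSchemaName, p.OwningTableName, p.PolicyName, p.Cmd, p.AppliesTo, p.ColumnNames)
	}
}
```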
@@ -733,6 +831,8 @@ SELECT
     c.relname::TEXT AS table_name,
     table_namespace.nspname::TEXT AS table_schema_name,
     c.relreplident::TEXT AS replica_identity,
+    c.relrowsecurity AS rls_enabled,
+    c.relforcerowsecurity AS rls_forced,
     COALESCE(parent_c.relname, '')::TEXT AS parent_table_name,
     COALESCE(parent_namespace.nspname, '')::TEXT AS parent_table_schema_name,
     (CASE
@@ -769,6 +869,8 @@ type GetTablesRow struct {
 	TableName             string
 	TableSchemaName       string
 	ReplicaIdentity       string
+	RlsEnabled            bool
+	RlsForced             bool
 	ParentTableName       string
 	ParentTableSchemaName string
 	PartitionKeyDef       string
@@ -789,6 +891,8 @@ func (q *Queries) GetTables(ctx context.Context) ([]GetTablesRow, error) {
 			&i.TableName,
 			&i.TableSchemaName,
 			&i.ReplicaIdentity,
+			&i.RlsEnabled,
+			&i.RlsForced,
 			&i.ParentTableName,
 			&i.ParentTableSchemaName,
 			&i.PartitionKeyDef,
diff --git a/internal/schema/schema.go b/internal/schema/schema.go
index c4e2b55..879a33c 100644
--- a/internal/schema/schema.go
+++ b/internal/schema/schema.go
@@ -68,19 +68,8 @@ func (s Schema) Normalize() Schema {
 	s.Enums = sortSchemaObjectsByName(s.Enums)
 
 	var normTables []Table
-	for _, table := range sortSchemaObjectsByName(s.Tables) {
-		// Don't normalize columns order. their order is derived from the postgres catalogs
-		// (relevant to data packing)
-		var normCheckConstraints []CheckConstraint
-		for _, checkConstraint := range sortSchemaObjectsByName(table.CheckConstraints) {
-			checkConstraint.DependsOnFunctions = sortSchemaObjectsByName(checkConstraint.DependsOnFunctions)
-			checkConstraint.KeyColumns = sortByKey(checkConstraint.KeyColumns, func(s string) string {
-				return s
-			})
-			normCheckConstraints = append(normCheckConstraints, checkConstraint)
-		}
-		table.CheckConstraints = normCheckConstraints
-		normTables = append(normTables, table)
+	for _, t := range sortSchemaObjectsByName(s.Tables) {
+		normTables = append(normTables, normalizeTable(t))
 	}
 	s.Tables = normTables
 
@@ -100,6 +89,33 @@
 	return s
 }
 
+func normalizeTable(t Table) Table {
+	// Don't normalize column order; their order is derived from the postgres catalogs
+	// (relevant to data packing)
+	var normCheckConstraints []CheckConstraint
+	for _, checkConstraint := range sortSchemaObjectsByName(t.CheckConstraints) {
+		checkConstraint.DependsOnFunctions = sortSchemaObjectsByName(checkConstraint.DependsOnFunctions)
+		checkConstraint.KeyColumns = sortByKey(checkConstraint.KeyColumns, func(s string) string {
+			return s
+		})
+		normCheckConstraints = append(normCheckConstraints, checkConstraint)
+	}
+	t.CheckConstraints = normCheckConstraints
+
+	var normPolicies []Policy
+	for _, p := range sortSchemaObjectsByName(t.Policies) {
+		p.AppliesTo = sortByKey(p.AppliesTo, func(s string) string {
+			return s
+		})
+		p.Columns = sortByKey(p.Columns, func(s string) string {
+			return s
+		})
+		normPolicies = append(normPolicies, p)
+	}
+	t.Policies = normPolicies
+	return t
+}
+
 // sortSchemaObjectsByName returns a (copied) sorted list of schema objects.
func sortSchemaObjectsByName[S Object](vals []S) []S { return sortByKey(vals, func(v S) string { @@ -158,7 +174,10 @@ type Table struct { SchemaQualifiedName Columns []Column CheckConstraints []CheckConstraint + Policies []Policy ReplicaIdentity ReplicaIdentity + RLSEnabled bool + RLSForced bool // PartitionKeyDef is the output of Pg function pg_get_partkeydef: // PARTITION BY $PartitionKeyDef @@ -299,7 +318,7 @@ type ForeignKeyConstraint struct { } func (f ForeignKeyConstraint) GetName() string { - return f.OwningTable.GetFQEscapedName() + "_" + f.EscapedName + return f.OwningTable.GetFQEscapedName() + "-" + f.EscapedName } type ( @@ -350,6 +369,33 @@ func (g GetTriggerDefStatement) ToCreateOrReplace() (string, error) { return triggerToOrReplaceRegex.ReplaceAllString(string(g), "${1}OR REPLACE ${2}"), nil } +// PolicyCmd represents the polcmd value in the pg_policy system catalog. +// See docs for possible values: https://www.postgresql.org/docs/current/catalog-pg-policy.html#CATALOG-PG-POLICY +type PolicyCmd string + +const ( + SelectPolicyCmd PolicyCmd = "r" + InsertPolicyCmd PolicyCmd = "a" + UpdatePolicyCmd PolicyCmd = "w" + DeletePolicyCmd PolicyCmd = "d" + AllPolicyCmd PolicyCmd = "*" +) + +type Policy struct { + EscapedName string + IsPermissive bool + AppliesTo []string + Cmd PolicyCmd + CheckExpression string + UsingExpression string + // Columns are the columns that the policy applies to. + Columns []string +} + +func (p Policy) GetName() string { + return p.EscapedName +} + type Trigger struct { EscapedName string OwningTable SchemaQualifiedName @@ -360,7 +406,7 @@ type Trigger struct { } func (t Trigger) GetName() string { - return t.OwningTable.GetFQEscapedName() + "_" + t.EscapedName + return t.OwningTable.GetFQEscapedName() + "-" + t.EscapedName } type ( @@ -710,9 +756,22 @@ func (s *schemaFetcher) fetchTables(ctx context.Context) ([]Table, error) { return nil, fmt.Errorf("GetTables(): %w", err) } - checkConsByTable, err := s.fetchCheckConsAndBuildTableToCheckConsMap(ctx) + checkCons, err := s.fetchCheckCons(ctx) if err != nil { - return nil, fmt.Errorf("fetchCheckConsAndBuildTableToCheckConsMap: %w", err) + return nil, fmt.Errorf("fetchCheckCons(): %w", err) + } + checkConsByTable := make(map[string][]CheckConstraint) + for _, cc := range checkCons { + checkConsByTable[cc.table.GetFQEscapedName()] = append(checkConsByTable[cc.table.GetFQEscapedName()], cc.checkConstraint) + } + + policies, err := s.fetchPolicies(ctx) + if err != nil { + return nil, fmt.Errorf("fetchPolicies(): %w", err) + } + policiesByTable := make(map[string][]Policy) + for _, p := range policies { + policiesByTable[p.table.GetFQEscapedName()] = append(policiesByTable[p.table.GetFQEscapedName()], p.policy) } goroutineRunner := s.goroutineRunnerFactory() @@ -720,7 +779,7 @@ func (s *schemaFetcher) fetchTables(ctx context.Context) ([]Table, error) { for _, _rawTable := range rawTables { rawTable := _rawTable // Capture loop variables for go routine tableFuture, err := concurrent.SubmitFuture(ctx, goroutineRunner, func() (Table, error) { - return s.buildTable(ctx, rawTable, checkConsByTable) + return s.buildTable(ctx, rawTable, checkConsByTable, policiesByTable) }) if err != nil { return nil, fmt.Errorf("starting table future: %w", err) @@ -743,7 +802,12 @@ func (s *schemaFetcher) fetchTables(ctx context.Context) ([]Table, error) { return tables, nil } -func (s *schemaFetcher) buildTable(ctx context.Context, table queries.GetTablesRow, checkConsByTable map[string][]CheckConstraint) (Table, error) { +func (s 
*schemaFetcher) buildTable( + ctx context.Context, + table queries.GetTablesRow, + checkConsByTable map[string][]CheckConstraint, + policiesByTable map[string][]Policy, +) (Table, error) { rawColumns, err := s.q.GetColumnsForTable(ctx, table.Oid) if err != nil { return Table{}, fmt.Errorf("GetColumnsForTable(%s): %w", table.Oid, err) @@ -788,7 +852,10 @@ func (s *schemaFetcher) buildTable(ctx context.Context, table queries.GetTablesR SchemaQualifiedName: schemaQualifiedName, Columns: columns, CheckConstraints: checkConsByTable[schemaQualifiedName.GetFQEscapedName()], + Policies: policiesByTable[schemaQualifiedName.GetFQEscapedName()], ReplicaIdentity: ReplicaIdentity(table.ReplicaIdentity), + RLSEnabled: table.RlsEnabled, + RLSForced: table.RlsForced, PartitionKeyDef: table.PartitionKeyDef, @@ -797,19 +864,18 @@ func (s *schemaFetcher) buildTable(ctx context.Context, table queries.GetTablesR }, nil } -// fetchCheckConsAndBuildTableToCheckConsMap fetches the check constraints and builds a map of table name to the check -// constraints within the table -func (s *schemaFetcher) fetchCheckConsAndBuildTableToCheckConsMap(ctx context.Context) (map[string][]CheckConstraint, error) { +type checkConstraintAndTable struct { + checkConstraint CheckConstraint + table SchemaQualifiedName +} + +// fetchCheckCons fetches the check constraints +func (s *schemaFetcher) fetchCheckCons(ctx context.Context) ([]checkConstraintAndTable, error) { rawCheckCons, err := s.q.GetCheckConstraints(ctx) if err != nil { return nil, fmt.Errorf("GetCheckConstraints: %w", err) } - type checkConstraintAndTable struct { - checkConstraint CheckConstraint - table SchemaQualifiedName - } - goroutineRunner := s.goroutineRunnerFactory() var ccFutures []concurrent.Future[checkConstraintAndTable] for _, _rawCC := range rawCheckCons { @@ -821,10 +887,7 @@ func (s *schemaFetcher) fetchCheckConsAndBuildTableToCheckConsMap(ctx context.Co } return checkConstraintAndTable{ checkConstraint: cc, - table: SchemaQualifiedName{ - SchemaName: rawCC.TableSchemaName, - EscapedName: EscapeIdentifier(rawCC.TableName), - }, + table: buildNameFromUnescaped(rawCC.TableName, rawCC.TableSchemaName), }, nil }) if err != nil { @@ -850,13 +913,7 @@ func (s *schemaFetcher) fetchCheckConsAndBuildTableToCheckConsMap(ctx context.Co s.nameFilter, ) - // Build a map of table name to check constraints - tablesToCheckConsMap := make(map[string][]CheckConstraint) - for _, cc := range ccs { - tablesToCheckConsMap[cc.table.GetFQEscapedName()] = append(tablesToCheckConsMap[cc.table.GetFQEscapedName()], cc.checkConstraint) - } - - return tablesToCheckConsMap, nil + return ccs, nil } func (s *schemaFetcher) buildCheckConstraint(ctx context.Context, cc queries.GetCheckConstraintsRow) (CheckConstraint, error) { @@ -1104,6 +1161,47 @@ func (s *schemaFetcher) fetchDependsOnFunctions(ctx context.Context, systemCatal return functionNames, nil } +type policyAndTable struct { + policy Policy + table SchemaQualifiedName +} + +func (s *schemaFetcher) fetchPolicies(ctx context.Context) ([]policyAndTable, error) { + rawPolicies, err := s.q.GetPolicies(ctx) + if err != nil { + return nil, fmt.Errorf("GetPolicies: %w", err) + } + + var policies []policyAndTable + for _, rp := range rawPolicies { + policies = append(policies, policyAndTable{ + policy: Policy{ + EscapedName: EscapeIdentifier(rp.PolicyName), + IsPermissive: rp.IsPermissive, + AppliesTo: rp.AppliesTo, + Cmd: PolicyCmd(rp.Cmd), + CheckExpression: rp.CheckExpression, + UsingExpression: rp.UsingExpression, + Columns: 
rp.ColumnNames, + }, + table: buildNameFromUnescaped(rp.OwningTableName, rp.OwningTableSchemaName), + }) + } + + policies = filterSliceByName( + policies, + func(p policyAndTable) SchemaQualifiedName { + return SchemaQualifiedName{ + SchemaName: p.table.SchemaName, + EscapedName: p.policy.EscapedName, + } + }, + s.nameFilter, + ) + + return policies, nil +} + func (s *schemaFetcher) fetchTriggers(ctx context.Context) ([]Trigger, error) { rawTriggers, err := s.q.GetTriggers(ctx) if err != nil { diff --git a/internal/schema/schema_test.go b/internal/schema/schema_test.go index 5d8db56..cd8280c 100644 --- a/internal/schema/schema_test.go +++ b/internal/schema/schema_test.go @@ -99,12 +99,28 @@ var ( PRIMARY KEY(id, version), CHECK ( function_with_dependencies(id, id) > 0) ); + ALTER TABLE schema_2.foo ENABLE ROW LEVEL SECURITY; ALTER TABLE schema_2.foo ADD CONSTRAINT author_content_check CHECK ( LENGTH(content) > 0 AND LENGTH(author) > 0 ) NO INHERIT NOT VALID; CREATE INDEX some_idx ON schema_2.foo (created_at DESC, author ASC); CREATE UNIQUE INDEX some_unique_idx ON schema_2.foo (content); CREATE INDEX some_gin_idx ON schema_2.foo USING GIN (author schema_1.gin_trgm_ops); ALTER TABLE schema_2.foo REPLICA IDENTITY USING INDEX some_unique_idx; + CREATE POLICY foo_policy_1 ON schema_2.foo + AS PERMISSIVE + FOR ALL + TO PUBLIC + USING (author = current_user) + WITH CHECK (version > 0); + CREATE ROLE some_role_1; + CREATE ROLE some_role_2; + CREATE POLICY foo_policy_2 ON schema_2.foo + AS RESTRICTIVE + FOR INSERT + TO some_role_2, some_role_1 + WITH CHECK (version > 0); + + CREATE FUNCTION increment_version() RETURNS TRIGGER AS $$ BEGIN NEW.version = OLD.version + 1; @@ -131,6 +147,13 @@ var ( id INT NOT NULL, CHECK (id > 0) ); + ALTER TABLE schema_1.foo ENABLE ROW LEVEL SECURITY; + ALTER TABLE schema_1.foo FORCE ROW LEVEL SECURITY; + CREATE POLICY foo_policy_1 ON schema_1.foo + AS RESTRICTIVE + FOR UPDATE + TO PUBLIC + WITH CHECK (id > 0); CREATE TABLE schema_1.foo_fk( id INT, @@ -161,8 +184,14 @@ var ( WHEN (OLD.* IS DISTINCT FROM NEW.*) -- Reference a function in a filtered out schema. The trigger should still be included. 
EXECUTE PROCEDURE public.increment_version(); + -- Validate policies are filtered out + CREATE POLICY foo_policy_1 ON schema_filtered_1.foo_fk + AS PERMISSIVE + FOR SELECT + TO PUBLIC + USING (version > 0); `}, - expectedHash: "44052eb962385897", + expectedHash: "ffcf26204e89f536", expectedSchema: Schema{ NamedSchemas: []NamedSchema{ {Name: "public"}, @@ -219,7 +248,26 @@ var ( KeyColumns: []string{"id"}, }, }, + Policies: []Policy{ + { + EscapedName: "\"foo_policy_1\"", + IsPermissive: true, + AppliesTo: []string{"PUBLIC"}, + Cmd: AllPolicyCmd, + UsingExpression: "(author = CURRENT_USER)", + CheckExpression: "(version > 0)", + Columns: []string{"author", "version"}, + }, + { + EscapedName: "\"foo_policy_2\"", + AppliesTo: []string{"some_role_1", "some_role_2"}, + Cmd: InsertPolicyCmd, + CheckExpression: "(version > 0)", + Columns: []string{"version"}, + }, + }, ReplicaIdentity: ReplicaIdentityIndex, + RLSEnabled: true, }, { SchemaQualifiedName: SchemaQualifiedName{SchemaName: "schema_1", EscapedName: "\"foo\""}, @@ -229,7 +277,19 @@ var ( CheckConstraints: []CheckConstraint{ {Name: "foo_id_check", Expression: "(id > 0)", IsValid: true, IsInheritable: true, KeyColumns: []string{"id"}}, }, + Policies: []Policy{ + { + EscapedName: "\"foo_policy_1\"", + IsPermissive: false, + AppliesTo: []string{"PUBLIC"}, + Cmd: UpdatePolicyCmd, + CheckExpression: "(id > 0)", + Columns: []string{"id"}, + }, + }, ReplicaIdentity: ReplicaIdentityDefault, + RLSEnabled: true, + RLSForced: true, }, { SchemaQualifiedName: SchemaQualifiedName{SchemaName: "schema_1", EscapedName: "\"foo_fk\""}, @@ -434,7 +494,7 @@ var ( ALTER TABLE foo_fk_1 ADD CONSTRAINT foo_fk_1_fk FOREIGN KEY (author, content) REFERENCES foo_1 (author, content) NOT VALID; `}, - expectedHash: "1609376865697f2d", + expectedHash: "481b62a68155716d", expectedSchema: Schema{ NamedSchemas: []NamedSchema{ {Name: "public"}, @@ -1066,10 +1126,24 @@ func TestSchemaTestCases(t *testing.T) { } func runTestCase(t *testing.T, engine *pgengine.Engine, testCase *testCase, getDBTX func(db *sql.DB) (queries.DBTX, io.Closer)) { + defer func() { + db, err := sql.Open("pgx", engine.GetPostgresDatabaseDSN()) + require.NoError(t, err) + defer db.Close() + require.NoError(t, pgengine.ResetInstance(context.Background(), db)) + }() + db, err := engine.CreateDatabase() require.NoError(t, err) + defer func() { + require.NoError(t, db.DropDB()) + }() + connPool, err := sql.Open("pgx", db.GetDSN()) require.NoError(t, err) + defer func() { + require.NoError(t, connPool.Close()) + }() for _, stmt := range testCase.ddl { _, err := connPool.Exec(stmt) @@ -1110,9 +1184,6 @@ func runTestCase(t *testing.T, engine *pgengine.Engine, testCase *testCase, getD // Optionally assert that the hash matches the expected hash assert.Equal(t, testCase.expectedHash, fetchedSchemaHash) } - - require.NoError(t, connPool.Close()) - require.NoError(t, db.DropDB()) } func TestIdxDefStmtToCreateIdxConcurrently(t *testing.T) { diff --git a/pkg/diff/diff.go b/pkg/diff/diff.go index 2958552..0f51f65 100644 --- a/pkg/diff/diff.go +++ b/pkg/diff/diff.go @@ -123,7 +123,7 @@ type ( ) func (ld listDiff[S, D]) isEmpty() bool { - return len(ld.adds) == 0 || len(ld.alters) == 0 || len(ld.deletes) == 0 + return len(ld.adds) == 0 && len(ld.alters) == 0 && len(ld.deletes) == 0 } func (ld listDiff[S, D]) resolveToSQLGroupedByEffect(sqlGenerator sqlGenerator[S, D]) (sqlGroupedByEffect[S, D], error) { diff --git a/pkg/diff/plan.go b/pkg/diff/plan.go index f3b8b9b..ebc215a 100644 --- a/pkg/diff/plan.go +++ 
b/pkg/diff/plan.go
@@ -20,6 +20,7 @@ const (
 	MigrationHazardTypeImpactsDatabasePerformance MigrationHazardType = "IMPACTS_DATABASE_PERFORMANCE"
 	MigrationHazardTypeIsUserGenerated            MigrationHazardType = "IS_USER_GENERATED"
 	MigrationHazardTypeExtensionVersionUpgrade    MigrationHazardType = "UPGRADING_EXTENSION_VERSION"
+	MigrationHazardTypeAuthzUpdate                MigrationHazardType = "AUTHZ_UPDATE"
 )

 // MigrationHazard represents a hazard that a statement poses to a database
diff --git a/pkg/diff/plan_generator.go b/pkg/diff/plan_generator.go
index 510b536..67372fa 100644
--- a/pkg/diff/plan_generator.go
+++ b/pkg/diff/plan_generator.go
@@ -282,7 +282,7 @@ func assertMigratedSchemaMatchesTarget(migratedSchema, targetSchema schema.Schem
 		for _, stmt := range toTargetSchemaStmts {
 			stmtsStrs = append(stmtsStrs, stmt.DDL)
 		}
-		return fmt.Errorf("diff detected:\n%s", strings.Join(stmtsStrs, "\n"))
+		return fmt.Errorf("validating plan failed; diff detected:\n%s", strings.Join(stmtsStrs, "\n"))
 	}

 	return nil
diff --git a/pkg/diff/policy_sql_generator.go b/pkg/diff/policy_sql_generator.go
new file mode 100644
index 0000000..b99e2e5
--- /dev/null
+++ b/pkg/diff/policy_sql_generator.go
@@ -0,0 +1,319 @@
+package diff
+
+import (
+	"errors"
+	"fmt"
+	"strings"
+
+	"github.com/google/go-cmp/cmp"
+	"github.com/stripe/pg-schema-diff/internal/schema"
+)
+
+var (
+	migrationHazardRLSEnabled = MigrationHazard{
+		Type:    MigrationHazardTypeAuthzUpdate,
+		Message: "Enabling RLS on a table could cause queries to fail if not correctly configured.",
+	}
+	migrationHazardRLSDisabled = MigrationHazard{
+		Type:    MigrationHazardTypeAuthzUpdate,
+		Message: "Disabling RLS on a table could allow unauthorized access to data.",
+	}
+	migrationHazardRLSForced = MigrationHazard{
+		Type:    MigrationHazardTypeAuthzUpdate,
+		Message: "Forcing RLS on a table could cause queries to fail if not correctly configured.",
+	}
+	migrationHazardRLSUnforced = MigrationHazard{
+		Type:    MigrationHazardTypeAuthzUpdate,
+		Message: "No longer forcing RLS on a table could allow unauthorized access to data.",
+	}
+
+	migrationHazardPermissivePolicyAdded = MigrationHazard{
+		Type:    MigrationHazardTypeAuthzUpdate,
+		Message: "Adding a permissive policy could allow unauthorized access to data.",
+	}
+	migrationHazardPermissivePolicyRemoved = MigrationHazard{
+		Type:    MigrationHazardTypeAuthzUpdate,
+		Message: "Removing a permissive policy could cause queries to fail if not correctly configured.",
+	}
+
+	migrationHazardRestrictivePolicyAdded = MigrationHazard{
+		Type:    MigrationHazardTypeAuthzUpdate,
+		Message: "Adding a restrictive policy could cause queries to fail if not correctly configured.",
+	}
+	migrationHazardRestrictivePolicyRemoved = MigrationHazard{
+		Type:    MigrationHazardTypeAuthzUpdate,
+		Message: "Removing a restrictive policy could allow unauthorized access to data.",
+	}
+
+	migrationHazardPolicyAltered = MigrationHazard{
+		Type:    MigrationHazardTypeAuthzUpdate,
+		Message: "Altering a policy could cause queries to fail if not correctly configured or allow unauthorized access to data.",
+	}
+)
+
+// When building/altering RLS policies, we must maintain the following order:
+//  1. Create/alter the table such that all necessary columns exist
+//  2. Create/alter the policies
+//  3. Enable RLS -- this MUST be done last
+//
+// If not done in this order, we may create an outage for a user's queries, where RLS rejects their
+// queries because the policy allowing them hasn't been created yet. The same is true for disabling
+// RLS, but in the reverse order: RLS must be disabled before policies are dropped.
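+//
+// For example (an illustrative ordering only; "tbl" and "p" are hypothetical names), enabling RLS
+// on a table alongside a new policy should be sequenced as:
+//
+//	CREATE POLICY p ON tbl ...;
+//	ALTER TABLE tbl ENABLE ROW LEVEL SECURITY;
+//
+// while disabling RLS reverses that order:
+//
+//	ALTER TABLE tbl DISABLE ROW LEVEL SECURITY;
+//	DROP POLICY p ON tbl;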
+//
+// Another quirk of policies: policies on partitions must be dropped before the base table is altered;
+// otherwise, the SQL could fail because, e.g., the policy references a column that no longer exists.
+
+func enableRLSForTable(t schema.Table) Statement {
+	return Statement{
+		DDL:         fmt.Sprintf("%s ENABLE ROW LEVEL SECURITY", alterTablePrefix(t.SchemaQualifiedName)),
+		Timeout:     statementTimeoutDefault,
+		LockTimeout: lockTimeoutDefault,
+		Hazards:     []MigrationHazard{migrationHazardRLSEnabled},
+	}
+}
+
+func disableRLSForTable(t schema.Table) Statement {
+	return Statement{
+		DDL:         fmt.Sprintf("%s DISABLE ROW LEVEL SECURITY", alterTablePrefix(t.SchemaQualifiedName)),
+		Timeout:     statementTimeoutDefault,
+		LockTimeout: lockTimeoutDefault,
+		Hazards:     []MigrationHazard{migrationHazardRLSDisabled},
+	}
+}
+
+func forceRLSForTable(t schema.Table) Statement {
+	return Statement{
+		DDL:         fmt.Sprintf("%s FORCE ROW LEVEL SECURITY", alterTablePrefix(t.SchemaQualifiedName)),
+		Timeout:     statementTimeoutDefault,
+		LockTimeout: lockTimeoutDefault,
+		Hazards:     []MigrationHazard{migrationHazardRLSForced},
+	}
+}
+
+func unforceRLSForTable(t schema.Table) Statement {
+	return Statement{
+		DDL:         fmt.Sprintf("%s NO FORCE ROW LEVEL SECURITY", alterTablePrefix(t.SchemaQualifiedName)),
+		Timeout:     statementTimeoutDefault,
+		LockTimeout: lockTimeoutDefault,
+		Hazards:     []MigrationHazard{migrationHazardRLSUnforced},
+	}
+}
+
+type policyDiff struct {
+	oldAndNew[schema.Policy]
+}
+
+func buildPolicyDiffs(psg *policySQLVertexGenerator, old, new []schema.Policy) (listDiff[schema.Policy, policyDiff], error) {
+	return diffLists(old, new, func(old, new schema.Policy, _, _ int) (_ policyDiff, requiresRecreate bool, _ error) {
+		diff := policyDiff{
+			oldAndNew: oldAndNew[schema.Policy]{
+				old: old, new: new,
+			},
+		}
+
+		if _, err := psg.Alter(diff); err != nil {
+			if errors.Is(err, ErrNotImplemented) {
+				// If we can't generate the alter SQL, we'll have to recreate the policy.
+				return diff, true, nil
+			}
+			return policyDiff{}, false, fmt.Errorf("generating alter SQL: %w", err)
+		}
+
+		return diff, false, nil
+	})
+}
+
+type policySQLVertexGenerator struct {
+	table                  schema.Table
+	oldTable               *schema.Table
+	newSchemaColumnsByName map[string]schema.Column
+	oldSchemaColumnsByName map[string]schema.Column
+}
+
+func newPolicySQLVertexGenerator(oldTable *schema.Table, table schema.Table) (*policySQLVertexGenerator, error) {
+	var oldSchemaColumnsByName map[string]schema.Column
+	if oldTable != nil {
+		if oldTable.SchemaQualifiedName != table.SchemaQualifiedName {
+			return nil, fmt.Errorf("old and new tables must have the same schema-qualified name. new=%s, old=%s", table.SchemaQualifiedName.GetFQEscapedName(), oldTable.SchemaQualifiedName.GetFQEscapedName())
+		}
+		oldSchemaColumnsByName = buildSchemaObjByNameMap(oldTable.Columns)
+	}
+
+	return &policySQLVertexGenerator{
+		table:                  table,
+		newSchemaColumnsByName: buildSchemaObjByNameMap(table.Columns),
+		oldTable:               oldTable,
+		oldSchemaColumnsByName: oldSchemaColumnsByName,
+	}, nil
+}
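+
+// Add generates the CREATE POLICY DDL for a new policy. For illustration (hypothetical table,
+// policy, and expressions; the real output depends on the fetched schema), a permissive ALL policy
+// renders roughly as:
+//
+//	CREATE POLICY p ON tbl
+//		AS PERMISSIVE
+//		FOR ALL
+//		TO PUBLIC
+//		USING (author = current_user)
+//		WITH CHECK (version > 0)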
+func (psg *policySQLVertexGenerator) Add(p schema.Policy) ([]Statement, error) {
+	sb := strings.Builder{}
+	sb.WriteString(fmt.Sprintf("CREATE POLICY %s ON %s", p.EscapedName, psg.table.GetFQEscapedName()))
+
+	typeModifier := "RESTRICTIVE"
+	if p.IsPermissive {
+		typeModifier = "PERMISSIVE"
+	}
+	sb.WriteString(fmt.Sprintf("\n\tAS %s", typeModifier))
+
+	cmdSQL, err := policyCharToSQL(p.Cmd)
+	if err != nil {
+		return nil, err
+	}
+	sb.WriteString(fmt.Sprintf("\n\tFOR %s", cmdSQL))
+
+	sb.WriteString(fmt.Sprintf("\n\tTO %s", strings.Join(p.AppliesTo, ", ")))
+
+	if p.UsingExpression != "" {
+		sb.WriteString(fmt.Sprintf("\n\tUSING (%s)", p.UsingExpression))
+	}
+	if p.CheckExpression != "" {
+		sb.WriteString(fmt.Sprintf("\n\tWITH CHECK (%s)", p.CheckExpression))
+	}
+
+	hazard := migrationHazardRestrictivePolicyAdded
+	if p.IsPermissive {
+		hazard = migrationHazardPermissivePolicyAdded
+	}
+
+	return []Statement{{
+		DDL:         sb.String(),
+		Timeout:     statementTimeoutDefault,
+		LockTimeout: lockTimeoutDefault,
+		Hazards:     []MigrationHazard{hazard},
+	}}, nil
+}
+
+func policyCharToSQL(c schema.PolicyCmd) (string, error) {
+	switch c {
+	case schema.SelectPolicyCmd:
+		return "SELECT", nil
+	case schema.InsertPolicyCmd:
+		return "INSERT", nil
+	case schema.UpdatePolicyCmd:
+		return "UPDATE", nil
+	case schema.DeletePolicyCmd:
+		return "DELETE", nil
+	case schema.AllPolicyCmd:
+		return "ALL", nil
+	default:
+		return "", fmt.Errorf("unknown policy command: %v", c)
+	}
+}
+
+func (psg *policySQLVertexGenerator) Delete(p schema.Policy) ([]Statement, error) {
+	hazard := migrationHazardRestrictivePolicyRemoved
+	if p.IsPermissive {
+		hazard = migrationHazardPermissivePolicyRemoved
+	}
+	return []Statement{{
+		DDL:         fmt.Sprintf("DROP POLICY %s ON %s", p.EscapedName, psg.table.GetFQEscapedName()),
+		Timeout:     statementTimeoutDefault,
+		LockTimeout: lockTimeoutDefault,
+		Hazards:     []MigrationHazard{hazard},
+	}}, nil
+}
+
+func (psg *policySQLVertexGenerator) Alter(diff policyDiff) ([]Statement, error) {
+	oldCopy := diff.old
+
+	// alterPolicyParts represents the set of clauses to include in the ALTER POLICY ... ON ... statement
+	var alterPolicyParts []string
+
+	if !cmp.Equal(oldCopy.AppliesTo, diff.new.AppliesTo) {
+		alterPolicyParts = append(alterPolicyParts, fmt.Sprintf("TO %s", strings.Join(diff.new.AppliesTo, ", ")))
+		oldCopy.AppliesTo = diff.new.AppliesTo
+	}
+
+	if oldCopy.UsingExpression != diff.new.UsingExpression && diff.new.UsingExpression != "" {
+		// Weirdly, you can't actually drop a USING expression from an ALL policy, even though you
+		// can have an ALL policy with only a check expression.
+		alterPolicyParts = append(alterPolicyParts, fmt.Sprintf("USING (%s)", diff.new.UsingExpression))
+		oldCopy.UsingExpression = diff.new.UsingExpression
+	}
+
+	if oldCopy.CheckExpression != diff.new.CheckExpression && diff.new.CheckExpression != "" {
+		// Same quirk as above with ALL policies.
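+		// For illustration (hypothetical names): `ALTER POLICY p ON tbl WITH CHECK (expr)` can swap in
+		// a new check expression, but no ALTER POLICY form drops the clause outright, so clearing the
+		// expression is reported as ErrNotImplemented below, which forces a policy recreate.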
+ alterPolicyParts = append(alterPolicyParts, fmt.Sprintf("WITH CHECK (%s)", diff.new.CheckExpression)) + oldCopy.CheckExpression = diff.new.CheckExpression + } + oldCopy.Columns = diff.new.Columns + + if diff := cmp.Diff(oldCopy, diff.new); diff != "" { + return nil, fmt.Errorf("unsupported diff %s: %w", diff, ErrNotImplemented) + } + + if len(alterPolicyParts) == 0 { + // There is no diff + return nil, nil + } + + sb := strings.Builder{} + sb.WriteString(fmt.Sprintf("ALTER POLICY %s ON %s\n\t", diff.new.EscapedName, psg.table.GetFQEscapedName())) + sb.WriteString(strings.Join(alterPolicyParts, "\n\t")) + + return []Statement{{ + DDL: sb.String(), + Timeout: statementTimeoutDefault, + LockTimeout: lockTimeoutDefault, + Hazards: []MigrationHazard{migrationHazardPolicyAltered}, + }}, nil +} + +func (psg *policySQLVertexGenerator) GetSQLVertexId(p schema.Policy) string { + return buildPolicyVertexId(psg.table.SchemaQualifiedName, p.EscapedName) +} + +func buildPolicyVertexId(owningTable schema.SchemaQualifiedName, policyEscapedName string) string { + return buildVertexId("policy", fmt.Sprintf("%s.%s", owningTable.GetFQEscapedName(), policyEscapedName)) +} + +func (psg *policySQLVertexGenerator) GetAddAlterDependencies(newPolicy, oldPolicy schema.Policy) ([]dependency, error) { + deps := []dependency{ + mustRun(psg.GetSQLVertexId(newPolicy), diffTypeDelete).before(psg.GetSQLVertexId(newPolicy), diffTypeAddAlter), + } + + newTargetColumns, err := getTargetColumns(newPolicy.Columns, psg.newSchemaColumnsByName) + if err != nil { + return nil, fmt.Errorf("getting target columns: %w", err) + } + + // Run after the new columns are added/altered + for _, tc := range newTargetColumns { + deps = append(deps, mustRun(psg.GetSQLVertexId(newPolicy), diffTypeAddAlter).after(buildColumnVertexId(tc.Name), diffTypeAddAlter)) + } + + if !cmp.Equal(oldPolicy, schema.Policy{}) { + // Run before the old columns are deleted (if they are deleted) + oldTargetColumns, err := getTargetColumns(oldPolicy.Columns, psg.oldSchemaColumnsByName) + if err != nil { + return nil, fmt.Errorf("getting target columns: %w", err) + } + for _, tc := range oldTargetColumns { + // It only needs to run before the delete if the column is actually being deleted + if _, stillExists := psg.newSchemaColumnsByName[tc.GetName()]; !stillExists { + deps = append(deps, mustRun(psg.GetSQLVertexId(newPolicy), diffTypeAddAlter).before(buildColumnVertexId(tc.Name), diffTypeDelete)) + } + } + } + + return deps, nil +} + +func (psg *policySQLVertexGenerator) GetDeleteDependencies(pol schema.Policy) ([]dependency, error) { + var deps []dependency + + columns, err := getTargetColumns(pol.Columns, psg.oldSchemaColumnsByName) + if err != nil { + return nil, fmt.Errorf("getting target columns: %w", err) + } + // The policy needs to be deleted before all the columns it references are deleted or add/altered + for _, c := range columns { + deps = append(deps, mustRun(psg.GetSQLVertexId(pol), diffTypeDelete).before(buildColumnVertexId(c.Name), diffTypeDelete)) + deps = append(deps, mustRun(psg.GetSQLVertexId(pol), diffTypeDelete).before(buildColumnVertexId(c.Name), diffTypeAddAlter)) + } + + return deps, nil +} diff --git a/pkg/diff/sql_generator.go b/pkg/diff/sql_generator.go index 3c49ba9..eeda9ad 100644 --- a/pkg/diff/sql_generator.go +++ b/pkg/diff/sql_generator.go @@ -109,6 +109,7 @@ type ( oldAndNew[schema.Table] columnsDiff listDiff[schema.Column, columnDiff] checkConstraintDiff listDiff[schema.CheckConstraint, checkConstraintDiff] + policiesDiff 
listDiff[schema.Policy, policyDiff] } indexDiff struct { @@ -244,9 +245,9 @@ func buildSchemaDiff(old, new schema.Schema) (schemaDiff, bool, error) { if err != nil { return schemaDiff{}, false, fmt.Errorf("diffing indexes: %w", err) } - foreignKeyConstraintDiffs, err := diffLists(old.ForeignKeyConstraints, new.ForeignKeyConstraints, func(old, new schema.ForeignKeyConstraint, _, _ int) (foreignKeyConstraintDiff, bool, error) { if _, isOnNewTable := addedTablesByName[new.OwningTable.GetName()]; isOnNewTable { + // If the owning table is new, then it must be re-created (this occurs if the base table has been // re-created). In other words, a foreign key constraint must be re-created if the owning table or referenced // table is re-created @@ -400,7 +401,21 @@ func buildTableDiff(oldTable, newTable schema.Table, _, _ int) (diff tableDiff, }, ) if err != nil { - return tableDiff{}, false, fmt.Errorf("diffing lists: %w", err) + return tableDiff{}, false, fmt.Errorf("diffing check cons: %w", err) + } + + var nilableOldTable *schema.Table + if !cmp.Equal(oldTable, schema.Table{}) { + nilableOldTable = &oldTable + } + psg, err := newPolicySQLVertexGenerator(nilableOldTable, newTable) + if err != nil { + return tableDiff{}, false, fmt.Errorf("creating policy sql vertex generator: %w", err) + } + policiesDiff, err := buildPolicyDiffs(psg, oldTable.Policies, newTable.Policies) + if err != nil { + return tableDiff{}, false, fmt.Errorf("diffing policies: %w", err) + } return tableDiff{ @@ -410,6 +425,7 @@ func buildTableDiff(oldTable, newTable schema.Table, _, _ int) (diff tableDiff, }, columnsDiff: columnsDiff, checkConstraintDiff: checkConsDiff, + policiesDiff: policiesDiff, }, false, nil } @@ -677,8 +693,6 @@ type tableSQLVertexGenerator struct { tableDiffsByName map[string]tableDiff } -var _ sqlVertexGenerator[schema.Table, tableDiff] = &tableSQLVertexGenerator{} - func (t *tableSQLVertexGenerator) Add(table schema.Table) ([]Statement, error) { if table.IsPartition() { if table.IsPartitioned() { @@ -687,6 +701,9 @@ func (t *tableSQLVertexGenerator) Add(table schema.Table) ([]Statement, error) { if len(table.CheckConstraints) > 0 { return nil, fmt.Errorf("check constraints on partitions: %w", ErrNotImplemented) } + if len(table.Policies) > 0 { + return nil, fmt.Errorf("policies on partitions: %w", ErrNotImplemented) + } // We attach the partitions separately. So the partition must have all the same check constraints // as the original table table.CheckConstraints = append(table.CheckConstraints, t.tablesInNewSchemaByName[table.ParentTable.GetName()].CheckConstraints...) @@ -722,7 +739,7 @@ func (t *tableSQLVertexGenerator) Add(table schema.Table) ([]Statement, error) { return nil, fmt.Errorf("generating add check constraint statements for check constraint %s: %w", checkCon.Name, err) } // Remove hazards from statements since the table is brand new - stmts = append(stmts, stripMigrationHazards(addConStmts)...) + stmts = append(stmts, stripMigrationHazards(addConStmts...)...) 
 	}
 
 	if table.ReplicaIdentity != schema.ReplicaIdentityDefault {
@@ -736,6 +753,26 @@
 		stmts = append(stmts, alterReplicaIdentityStmt)
 	}
 
+	psg, err := newPolicySQLVertexGenerator(nil, table)
+	if err != nil {
+		return nil, fmt.Errorf("creating policy sql vertex generator: %w", err)
+	}
+	for _, policy := range table.Policies {
+		addPolicyStmts, err := psg.Add(policy)
+		if err != nil {
+			return nil, fmt.Errorf("generating add policy statements for policy %s: %w", policy.EscapedName, err)
+		}
+		// Remove hazards from statements since the table is brand new
+		stmts = append(stmts, stripMigrationHazards(addPolicyStmts...)...)
+	}
+
+	if table.RLSEnabled {
+		stmts = append(stmts, stripMigrationHazards(enableRLSForTable(table))...)
+	}
+	if table.RLSForced {
+		stmts = append(stmts, stripMigrationHazards(forceRLSForTable(table))...)
+	}
+
 	return stmts, nil
 }
 
@@ -771,18 +808,28 @@ func (t *tableSQLVertexGenerator) Alter(diff tableDiff) ([]Statement, error) {
 	}
 
 	var stmts []Statement
+	// Only handle disabling RLS if it was previously enabled.
+	// We want to disable RLS before any other operations on the table, e.g., dropping policies, to avoid creating an
+	// outage while RLS is being disabled.
+	if !diff.new.RLSEnabled && diff.old.RLSEnabled {
+		stmts = append(stmts, disableRLSForTable(diff.new))
+	}
+	if !diff.new.RLSForced && diff.old.RLSForced {
+		stmts = append(stmts, unforceRLSForTable(diff.new))
+	}
+
 	if diff.new.IsPartition() {
 		alterPartitionStmts, err := t.alterPartition(diff)
 		if err != nil {
 			return nil, fmt.Errorf("altering partition: %w", err)
 		}
-		stmts = alterPartitionStmts
+		stmts = append(stmts, alterPartitionStmts...)
 	} else {
 		alterBaseTableStmts, err := t.alterBaseTable(diff)
 		if err != nil {
 			return nil, fmt.Errorf("altering base table: %w", err)
 		}
-		stmts = alterBaseTableStmts
+		stmts = append(stmts, alterBaseTableStmts...)
 	}
 
 	if diff.old.ReplicaIdentity != diff.new.ReplicaIdentity {
@@ -793,6 +840,15 @@
 		stmts = append(stmts, alterReplicaIdentityStmt)
 	}
 
+	// We want to enable RLS after any other operations on the table, e.g., creating policies, to avoid creating an
+	// outage while RLS is being enabled.
+	if diff.new.RLSEnabled && !diff.old.RLSEnabled {
+		stmts = append(stmts, enableRLSForTable(diff.new))
+	}
+	if diff.new.RLSForced && !diff.old.RLSForced {
+		stmts = append(stmts, forceRLSForTable(diff.new))
+	}
+
 	return stmts, nil
 }
 
@@ -812,7 +868,7 @@
 	}
 
 	columnSQLVertexGenerator := columnSQLVertexGenerator{tableName: diff.new.SchemaQualifiedName}
-	columnGraphs, err := diff.columnsDiff.resolveToSQLGraph(&columnSQLVertexGenerator)
+	columnGraph, err := diff.columnsDiff.resolveToSQLGraph(&columnSQLVertexGenerator)
 	if err != nil {
 		return nil, fmt.Errorf("resolving index diff: %w", err)
 	}
@@ -838,13 +894,33 @@
 		dropTempCCs = append(dropTempCCs, stmt...)
 	}
 
-	if err := columnGraphs.union(checkConGraphs); err != nil {
+	var nilableOldTable *schema.Table
+	if !cmp.Equal(diff.old, schema.Table{}) {
+		nilableOldTable = &diff.old
+	}
+	psg, err := newPolicySQLVertexGenerator(nilableOldTable, diff.new)
+	if err != nil {
+		return nil, fmt.Errorf("creating policy sql vertex generator: %w", err)
+	}
+	policyGraph, err := diff.policiesDiff.resolveToSQLGraph(psg)
+	if err != nil {
+		return nil, fmt.Errorf("resolving policy diff: %w", err)
+	}
+
+	if err := columnGraph.union(checkConGraphs); err != nil {
 		return nil, fmt.Errorf("unioning column and check constraint graphs: %w", err)
 	}
 
-	stmts, err := columnGraphs.toOrderedStatements()
+	if err := columnGraph.union(policyGraph); err != nil {
+		return nil, fmt.Errorf("unioning column and policy graphs: %w", err)
+	}
+
+	graphStmts, err := columnGraph.toOrderedStatements()
 	if err != nil {
 		return nil, fmt.Errorf("getting ordered statements from columnGraphs: %w", err)
 	}
+
+	var stmts []Statement
+	stmts = append(stmts, graphStmts...)
 	// Drop the temporary check constraints that were added to make changing columns to "NOT NULL" not require an
 	// extended table lock
 	stmts = append(stmts, dropTempCCs...)
@@ -859,6 +935,11 @@ func (t *tableSQLVertexGenerator) alterPartition(diff tableDiff) ([]Statement, e
 	if !diff.checkConstraintDiff.isEmpty() {
 		return nil, fmt.Errorf("check constraints on partitions: %w", ErrNotImplemented)
 	}
+	if !diff.policiesDiff.isEmpty() {
+		// Policy diffing on individual partitions cannot be supported until the SQL generated for a
+		// statement is _independent_ of the order in which the statements are generated.
+		return nil, fmt.Errorf("policies on partitions: %w", ErrNotImplemented)
+	}
 
 	var alteredParentColumnsByName map[string]columnDiff
 	if parentDiff, ok := t.tableDiffsByName[diff.new.ParentTable.GetName()]; ok {
@@ -1323,7 +1404,7 @@ func (isg *indexSQLVertexGenerator) Add(index schema.Index) ([]Statement, error)
 	}
 
 	if _, isNewTable := isg.addedTablesByName[index.OwningTable.GetName()]; isNewTable {
-		stmts = stripMigrationHazards(stmts)
+		stmts = stripMigrationHazards(stmts...)
} return stmts, nil } @@ -1729,7 +1810,7 @@ func (csg *checkConstraintSQLVertexGenerator) GetAddAlterDependencies(con, _ sch mustRun(csg.GetSQLVertexId(con), diffTypeDelete).before(csg.GetSQLVertexId(con), diffTypeAddAlter), } - targetColumns, err := getTargetColumns(con, csg.newSchemaColumnsByName) + targetColumns, err := getTargetColumns(con.KeyColumns, csg.newSchemaColumnsByName) if err != nil { return nil, fmt.Errorf("getting target columns: %w", err) } @@ -1757,7 +1838,7 @@ func (csg *checkConstraintSQLVertexGenerator) GetAddAlterDependencies(con, _ sch func (csg *checkConstraintSQLVertexGenerator) GetDeleteDependencies(con schema.CheckConstraint) ([]dependency, error) { var deps []dependency - targetColumns, err := getTargetColumns(con, csg.oldSchemaColumnsByName) + targetColumns, err := getTargetColumns(con.KeyColumns, csg.oldSchemaColumnsByName) if err != nil { return nil, fmt.Errorf("getting target columns: %w", err) } @@ -1799,9 +1880,9 @@ func (csg *checkConstraintSQLVertexGenerator) GetDeleteDependencies(con schema.C return deps, nil } -func getTargetColumns(con schema.CheckConstraint, columnsByName map[string]schema.Column) ([]schema.Column, error) { +func getTargetColumns(targetColumnNames []string, columnsByName map[string]schema.Column) ([]schema.Column, error) { var targetColumns []schema.Column - for _, name := range con.KeyColumns { + for _, name := range targetColumnNames { targetColumn, ok := columnsByName[name] if !ok { return nil, fmt.Errorf("could not find column with name %s", name) @@ -2428,7 +2509,7 @@ func buildVertexId(objType string, id string) string { return fmt.Sprintf("%s_%s", objType, id) } -func stripMigrationHazards(stmts []Statement) []Statement { +func stripMigrationHazards(stmts ...Statement) []Statement { var noHazardsStmts []Statement for _, stmt := range stmts { stmt.Hazards = nil