From 6fd60ccfcba0d5246a6048941fb0b8eb3fdcf113 Mon Sep 17 00:00:00 2001 From: Michael Butler Date: Tue, 24 Dec 2024 08:37:11 -0500 Subject: [PATCH] crosscluster/logical: properly handle incorrect targets in CREATE stmt Previously, if the schema or database were missing from a destination target table in CREATE LOGICALLY REPLICATED TABLE, the planHook would panic. This patch also beefs up testing around user specified replication targets. Fixes #137745 Release note: none --- pkg/crosscluster/logical/BUILD.bazel | 3 + .../create_logical_replication_stmt.go | 26 ++- .../create_logical_replication_stmt_test.go | 198 ++++++++++++++++++ 3 files changed, 216 insertions(+), 11 deletions(-) create mode 100644 pkg/crosscluster/logical/create_logical_replication_stmt_test.go diff --git a/pkg/crosscluster/logical/BUILD.bazel b/pkg/crosscluster/logical/BUILD.bazel index 732832a9946b..e37c064d3b2c 100644 --- a/pkg/crosscluster/logical/BUILD.bazel +++ b/pkg/crosscluster/logical/BUILD.bazel @@ -100,6 +100,7 @@ go_library( go_test( name = "logical_test", srcs = [ + "create_logical_replication_stmt_test.go", "dead_letter_queue_test.go", "logical_replication_job_test.go", "lww_kv_processor_test.go", @@ -138,6 +139,7 @@ go_test( "//pkg/sql/catalog/descpb", "//pkg/sql/catalog/descs", "//pkg/sql/catalog/desctestutils", + "//pkg/sql/catalog/resolver", "//pkg/sql/execinfra", "//pkg/sql/execinfrapb", "//pkg/sql/isql", @@ -166,6 +168,7 @@ go_test( "//pkg/util/uuid", "@com_github_cockroachdb_cockroach_go_v2//crdb", "@com_github_cockroachdb_errors//:errors", + "@com_github_cockroachdb_redact//:redact", "@com_github_lib_pq//:pq", "@com_github_stretchr_testify//require", ], diff --git a/pkg/crosscluster/logical/create_logical_replication_stmt.go b/pkg/crosscluster/logical/create_logical_replication_stmt.go index 1236ae335843..a3846999a0aa 100644 --- a/pkg/crosscluster/logical/create_logical_replication_stmt.go +++ b/pkg/crosscluster/logical/create_logical_replication_stmt.go @@ -38,6 +38,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/sql/sem/asof" "github.com/cockroachdb/cockroach/pkg/sql/sem/catid" "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" + "github.com/cockroachdb/cockroach/pkg/sql/sessiondata" "github.com/cockroachdb/cockroach/pkg/sql/syntheticprivilege" "github.com/cockroachdb/cockroach/pkg/sql/types" "github.com/cockroachdb/cockroach/pkg/util/buildutil" @@ -135,7 +136,7 @@ func createLogicalReplicationStreamPlanHook( return pgerror.Newf(pgcode.InvalidParameterValue, "unknown discard option %q", m) } } - resolvedDestObjects, err := resolveDestinationObjects(ctx, p, stmt.Into, stmt.CreateTable) + resolvedDestObjects, err := resolveDestinationObjects(ctx, p, p.SessionData(), stmt.Into, stmt.CreateTable) if err != nil { return err } @@ -305,7 +306,8 @@ func (r *ResolvedDestObjects) TargetDescription() string { func resolveDestinationObjects( ctx context.Context, - p sql.PlanHookState, + r resolver.SchemaResolver, + sessionData *sessiondata.SessionData, destResources tree.LogicalReplicationResources, createTable bool, ) (ResolvedDestObjects, error) { @@ -315,15 +317,16 @@ func resolveDestinationObjects( if err != nil { return resolved, err } - dstObjName.HasExplicitSchema() - + dstTableName := dstObjName.ToTableName() if createTable { - _, _, resPrefix, err := resolver.ResolveTarget(ctx, - &dstObjName, p, p.SessionData().Database, p.SessionData().SearchPath) + found, _, resPrefix, err := resolver.ResolveTarget(ctx, + &dstObjName, r, sessionData.Database, sessionData.SearchPath) if err != nil { return resolved, errors.Newf("resolving target import name") } - + if !found { + return resolved, errors.Newf("database or schema not found for destination table %s", destResources.Tables[i]) + } if resolved.ParentDatabaseID == 0 { resolved.ParentDatabaseID = resPrefix.Database.GetID() resolved.ParentSchemaID = resPrefix.Schema.GetID() @@ -332,7 +335,9 @@ func resolveDestinationObjects( } else if resolved.ParentSchemaID != resPrefix.Schema.GetID() { return resolved, errors.Newf("destination tables must all be in the same schema") } - + if _, _, err := resolver.ResolveMutableExistingTableObject(ctx, r, &dstTableName, true, tree.ResolveRequireTableDesc); err == nil { + return resolved, errors.Newf("destination table %s already exists", destResources.Tables[i]) + } tbNameWithSchema := tree.MakeTableNameWithSchema( tree.Name(resPrefix.Database.GetName()), tree.Name(resPrefix.Schema.GetName()), @@ -340,10 +345,9 @@ func resolveDestinationObjects( ) resolved.TableNames = append(resolved.TableNames, tbNameWithSchema) } else { - dstTableName := dstObjName.ToTableName() - prefix, td, err := resolver.ResolveMutableExistingTableObject(ctx, p, &dstTableName, true, tree.ResolveRequireTableDesc) + prefix, td, err := resolver.ResolveMutableExistingTableObject(ctx, r, &dstTableName, true, tree.ResolveRequireTableDesc) if err != nil { - return resolved, err + return resolved, errors.Wrapf(err, "failed to find existing destination table %s", destResources.Tables[i]) } tbNameWithSchema := tree.MakeTableNameWithSchema( diff --git a/pkg/crosscluster/logical/create_logical_replication_stmt_test.go b/pkg/crosscluster/logical/create_logical_replication_stmt_test.go new file mode 100644 index 000000000000..9fc3f8e58080 --- /dev/null +++ b/pkg/crosscluster/logical/create_logical_replication_stmt_test.go @@ -0,0 +1,198 @@ +// Copyright 2024 The Cockroach Authors. +// +// Use of this software is governed by the CockroachDB Software License +// included in the /LICENSE file. + +package logical + +import ( + "context" + "strings" + "testing" + + "github.com/cockroachdb/cockroach/pkg/base" + "github.com/cockroachdb/cockroach/pkg/security/username" + "github.com/cockroachdb/cockroach/pkg/sql" + "github.com/cockroachdb/cockroach/pkg/sql/catalog/descs" + "github.com/cockroachdb/cockroach/pkg/sql/catalog/resolver" + "github.com/cockroachdb/cockroach/pkg/sql/isql" + "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" + "github.com/cockroachdb/cockroach/pkg/testutils/serverutils" + "github.com/cockroachdb/cockroach/pkg/testutils/skip" + "github.com/cockroachdb/cockroach/pkg/testutils/sqlutils" + "github.com/cockroachdb/cockroach/pkg/util/leaktest" + "github.com/cockroachdb/cockroach/pkg/util/log" + "github.com/cockroachdb/redact" + "github.com/stretchr/testify/require" +) + +func TestResolveDestinationObjects(t *testing.T) { + defer leaktest.AfterTest(t)() + skip.UnderDeadlock(t) + defer log.Scope(t).Close(t) + + ctx := context.Background() + + srv, conn, _ := serverutils.StartServer(t, base.TestServerArgs{}) + defer srv.Stopper().Stop(ctx) + s := srv.ApplicationLayer() + + execCfg := s.ExecutorConfig().(sql.ExecutorConfig) + + sqlDB := sqlutils.MakeSQLRunner(conn) + sqlDB.Exec(t, "CREATE DATABASE db1") + sqlDB.Exec(t, "CREATE SCHEMA db1.sc1") + sqlDB.Exec(t, "CREATE SCHEMA db1.sc2") + sqlDB.Exec(t, "CREATE TABLE db1.sc1.t1 (a INT PRIMARY KEY)") + sqlDB.Exec(t, "CREATE TABLE db1.sc1.t2 (a INT PRIMARY KEY)") + sqlDB.Exec(t, "CREATE TABLE t3 (a INT PRIMARY KEY)") + + sqlUser, err := username.MakeSQLUsernameFromUserInput("root", username.PurposeValidation) + require.NoError(t, err) + + resolveObjects := func(destResources tree.LogicalReplicationResources, createTables bool) (resolved ResolvedDestObjects, err error) { + err = sql.TestingDescsTxn(ctx, s, func(ctx context.Context, txn isql.Txn, _ *descs.Collection) error { + opName := redact.SafeString("resolve") + sessionData := sql.NewInternalSessionData(ctx, execCfg.Settings, opName) + sessionData.Database = "defaultdb" + planner, close := sql.NewInternalPlanner( + opName, + txn.KV(), + sqlUser, + &sql.MemoryMetrics{}, + &execCfg, + sessionData, + ) + defer close() + resolved, err = resolveDestinationObjects(ctx, planner.(resolver.SchemaResolver), sessionData, destResources, createTables) + return err + }) + return resolved, err + } + + res := func(db string, tables ...string) tree.LogicalReplicationResources { + resources := tree.LogicalReplicationResources{ + Database: tree.Name(db), + } + for _, table := range tables { + t := strings.Split(table, ".") + resources.Tables = append(resources.Tables, tree.NewUnresolvedName(t...)) + } + return resources + } + + type testCase struct { + name string + resources tree.LogicalReplicationResources + create bool + expectedDesc string + expectedErr string + } + + for _, tc := range []testCase{ + { + name: "single", + resources: res("", "db1.sc1.t1"), + expectedDesc: "db1.sc1.t1", + }, + { + name: "single/create", + resources: res("", "db1.sc1.t1_c"), + create: true, + expectedDesc: "db1.sc1.t1_c", + }, + { + name: "implicit_schema", + resources: res("", "defaultdb.t3"), + expectedDesc: "defaultdb.public.t3", + }, + { + name: "implicit_schema/create", + resources: res("", "defaultdb.t3_c"), + create: true, + expectedDesc: "defaultdb.public.t3_c", + }, + { + name: "implicit_db", + resources: res("", "t3"), + expectedDesc: "defaultdb.public.t3", + }, + { + name: "implicit_db/create", + resources: res("", "t3_c"), + create: true, + expectedDesc: "defaultdb.public.t3_c", + }, + { + name: "multi", + resources: res("", "db1.sc1.t1", "db1.sc1.t2"), + expectedDesc: "db1.sc1.t1, db1.sc1.t2", + }, + { + name: "multi/create", + resources: res("", "db1.sc1.t1_c", "db1.sc1.t2_c"), + create: true, + expectedDesc: "db1.sc1.t1_c, db1.sc1.t2_c", + }, + { + name: "missing_schema", + resources: res("", "db1.s2.t1"), + expectedErr: "failed to find existing destination table db1.s2.t1", + }, + { + name: "missing_schema/create", + resources: res("", "db1.s2.t1"), + create: true, + expectedErr: "database or schema not found for destination table db1.s2.t1", + }, + { + name: "missing_db/create", + resources: res("", "db2.t1"), + create: true, + expectedErr: "database or schema not found for destination table db2.t1", + }, + { + name: "multiple_db_target", + resources: res("", "t3", "db1.sc1.t1"), + expectedDesc: "defaultdb.public.t3, db1.sc1.t1", + }, + { + name: "multiple_db_target/create", + resources: res("", "t3_c", "db1.sc1.t1_c"), + create: true, + expectedErr: "destination tables must all be in the same database", + }, + { + name: "multiple_schema_target/create", + resources: res("", "db1.sc2.t2_c", "db1.sc1.t1_c"), + create: true, + expectedErr: "destination tables must all be in the same schema", + }, + { + name: "existing_table/create", + resources: res("", "db1.sc1.t1"), + create: true, + expectedErr: "destination table db1.sc1.t1 already exists", + }, + } { + t.Run(tc.name, func(t *testing.T) { + resolved, err := resolveObjects(tc.resources, tc.create) + if tc.expectedErr != "" { + require.ErrorContains(t, err, tc.expectedErr) + return + } + require.NoError(t, err) + require.Equal(t, tc.expectedDesc, resolved.TargetDescription()) + if tc.create { + require.NotZero(t, resolved.ParentDatabaseID) + require.NotZero(t, resolved.ParentSchemaID) + } else { + require.Zero(t, resolved.ParentDatabaseID) + require.Zero(t, resolved.ParentSchemaID) + + } + }) + + } + +}