Skip to content

Commit

Permalink
crosscluster/logical: properly handle incorrect targets in CREATE stmt
Browse files Browse the repository at this point in the history
Previously, if the schema or database were missing from a destination target
table in CREATE LOGICALLY REPLICATED TABLE, the planHook would panic.

This patch also beefs up testing around user specified replication targets.

Fixes #137745

Release note: none
  • Loading branch information
msbutler committed Dec 25, 2024
1 parent 6eb6d49 commit 6fd60cc
Show file tree
Hide file tree
Showing 3 changed files with 216 additions and 11 deletions.
3 changes: 3 additions & 0 deletions pkg/crosscluster/logical/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ go_library(
go_test(
name = "logical_test",
srcs = [
"create_logical_replication_stmt_test.go",
"dead_letter_queue_test.go",
"logical_replication_job_test.go",
"lww_kv_processor_test.go",
Expand Down Expand Up @@ -138,6 +139,7 @@ go_test(
"//pkg/sql/catalog/descpb",
"//pkg/sql/catalog/descs",
"//pkg/sql/catalog/desctestutils",
"//pkg/sql/catalog/resolver",
"//pkg/sql/execinfra",
"//pkg/sql/execinfrapb",
"//pkg/sql/isql",
Expand Down Expand Up @@ -166,6 +168,7 @@ go_test(
"//pkg/util/uuid",
"@com_github_cockroachdb_cockroach_go_v2//crdb",
"@com_github_cockroachdb_errors//:errors",
"@com_github_cockroachdb_redact//:redact",
"@com_github_lib_pq//:pq",
"@com_github_stretchr_testify//require",
],
Expand Down
26 changes: 15 additions & 11 deletions pkg/crosscluster/logical/create_logical_replication_stmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ import (
"github.com/cockroachdb/cockroach/pkg/sql/sem/asof"
"github.com/cockroachdb/cockroach/pkg/sql/sem/catid"
"github.com/cockroachdb/cockroach/pkg/sql/sem/tree"
"github.com/cockroachdb/cockroach/pkg/sql/sessiondata"
"github.com/cockroachdb/cockroach/pkg/sql/syntheticprivilege"
"github.com/cockroachdb/cockroach/pkg/sql/types"
"github.com/cockroachdb/cockroach/pkg/util/buildutil"
Expand Down Expand Up @@ -135,7 +136,7 @@ func createLogicalReplicationStreamPlanHook(
return pgerror.Newf(pgcode.InvalidParameterValue, "unknown discard option %q", m)
}
}
resolvedDestObjects, err := resolveDestinationObjects(ctx, p, stmt.Into, stmt.CreateTable)
resolvedDestObjects, err := resolveDestinationObjects(ctx, p, p.SessionData(), stmt.Into, stmt.CreateTable)
if err != nil {
return err
}
Expand Down Expand Up @@ -305,7 +306,8 @@ func (r *ResolvedDestObjects) TargetDescription() string {

func resolveDestinationObjects(
ctx context.Context,
p sql.PlanHookState,
r resolver.SchemaResolver,
sessionData *sessiondata.SessionData,
destResources tree.LogicalReplicationResources,
createTable bool,
) (ResolvedDestObjects, error) {
Expand All @@ -315,15 +317,16 @@ func resolveDestinationObjects(
if err != nil {
return resolved, err
}
dstObjName.HasExplicitSchema()

dstTableName := dstObjName.ToTableName()
if createTable {
_, _, resPrefix, err := resolver.ResolveTarget(ctx,
&dstObjName, p, p.SessionData().Database, p.SessionData().SearchPath)
found, _, resPrefix, err := resolver.ResolveTarget(ctx,
&dstObjName, r, sessionData.Database, sessionData.SearchPath)
if err != nil {
return resolved, errors.Newf("resolving target import name")
}

if !found {
return resolved, errors.Newf("database or schema not found for destination table %s", destResources.Tables[i])
}
if resolved.ParentDatabaseID == 0 {
resolved.ParentDatabaseID = resPrefix.Database.GetID()
resolved.ParentSchemaID = resPrefix.Schema.GetID()
Expand All @@ -332,18 +335,19 @@ func resolveDestinationObjects(
} else if resolved.ParentSchemaID != resPrefix.Schema.GetID() {
return resolved, errors.Newf("destination tables must all be in the same schema")
}

if _, _, err := resolver.ResolveMutableExistingTableObject(ctx, r, &dstTableName, true, tree.ResolveRequireTableDesc); err == nil {
return resolved, errors.Newf("destination table %s already exists", destResources.Tables[i])
}
tbNameWithSchema := tree.MakeTableNameWithSchema(
tree.Name(resPrefix.Database.GetName()),
tree.Name(resPrefix.Schema.GetName()),
tree.Name(dstObjName.Object()),
)
resolved.TableNames = append(resolved.TableNames, tbNameWithSchema)
} else {
dstTableName := dstObjName.ToTableName()
prefix, td, err := resolver.ResolveMutableExistingTableObject(ctx, p, &dstTableName, true, tree.ResolveRequireTableDesc)
prefix, td, err := resolver.ResolveMutableExistingTableObject(ctx, r, &dstTableName, true, tree.ResolveRequireTableDesc)
if err != nil {
return resolved, err
return resolved, errors.Wrapf(err, "failed to find existing destination table %s", destResources.Tables[i])
}

tbNameWithSchema := tree.MakeTableNameWithSchema(
Expand Down
198 changes: 198 additions & 0 deletions pkg/crosscluster/logical/create_logical_replication_stmt_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
// Copyright 2024 The Cockroach Authors.
//
// Use of this software is governed by the CockroachDB Software License
// included in the /LICENSE file.

package logical

import (
"context"
"strings"
"testing"

"github.com/cockroachdb/cockroach/pkg/base"
"github.com/cockroachdb/cockroach/pkg/security/username"
"github.com/cockroachdb/cockroach/pkg/sql"
"github.com/cockroachdb/cockroach/pkg/sql/catalog/descs"
"github.com/cockroachdb/cockroach/pkg/sql/catalog/resolver"
"github.com/cockroachdb/cockroach/pkg/sql/isql"
"github.com/cockroachdb/cockroach/pkg/sql/sem/tree"
"github.com/cockroachdb/cockroach/pkg/testutils/serverutils"
"github.com/cockroachdb/cockroach/pkg/testutils/skip"
"github.com/cockroachdb/cockroach/pkg/testutils/sqlutils"
"github.com/cockroachdb/cockroach/pkg/util/leaktest"
"github.com/cockroachdb/cockroach/pkg/util/log"
"github.com/cockroachdb/redact"
"github.com/stretchr/testify/require"
)

func TestResolveDestinationObjects(t *testing.T) {
defer leaktest.AfterTest(t)()
skip.UnderDeadlock(t)
defer log.Scope(t).Close(t)

ctx := context.Background()

srv, conn, _ := serverutils.StartServer(t, base.TestServerArgs{})
defer srv.Stopper().Stop(ctx)
s := srv.ApplicationLayer()

execCfg := s.ExecutorConfig().(sql.ExecutorConfig)

sqlDB := sqlutils.MakeSQLRunner(conn)
sqlDB.Exec(t, "CREATE DATABASE db1")
sqlDB.Exec(t, "CREATE SCHEMA db1.sc1")
sqlDB.Exec(t, "CREATE SCHEMA db1.sc2")
sqlDB.Exec(t, "CREATE TABLE db1.sc1.t1 (a INT PRIMARY KEY)")
sqlDB.Exec(t, "CREATE TABLE db1.sc1.t2 (a INT PRIMARY KEY)")
sqlDB.Exec(t, "CREATE TABLE t3 (a INT PRIMARY KEY)")

sqlUser, err := username.MakeSQLUsernameFromUserInput("root", username.PurposeValidation)
require.NoError(t, err)

resolveObjects := func(destResources tree.LogicalReplicationResources, createTables bool) (resolved ResolvedDestObjects, err error) {
err = sql.TestingDescsTxn(ctx, s, func(ctx context.Context, txn isql.Txn, _ *descs.Collection) error {
opName := redact.SafeString("resolve")
sessionData := sql.NewInternalSessionData(ctx, execCfg.Settings, opName)
sessionData.Database = "defaultdb"
planner, close := sql.NewInternalPlanner(
opName,
txn.KV(),
sqlUser,
&sql.MemoryMetrics{},
&execCfg,
sessionData,
)
defer close()
resolved, err = resolveDestinationObjects(ctx, planner.(resolver.SchemaResolver), sessionData, destResources, createTables)
return err
})
return resolved, err
}

res := func(db string, tables ...string) tree.LogicalReplicationResources {
resources := tree.LogicalReplicationResources{
Database: tree.Name(db),
}
for _, table := range tables {
t := strings.Split(table, ".")
resources.Tables = append(resources.Tables, tree.NewUnresolvedName(t...))
}
return resources
}

type testCase struct {
name string
resources tree.LogicalReplicationResources
create bool
expectedDesc string
expectedErr string
}

for _, tc := range []testCase{
{
name: "single",
resources: res("", "db1.sc1.t1"),
expectedDesc: "db1.sc1.t1",
},
{
name: "single/create",
resources: res("", "db1.sc1.t1_c"),
create: true,
expectedDesc: "db1.sc1.t1_c",
},
{
name: "implicit_schema",
resources: res("", "defaultdb.t3"),
expectedDesc: "defaultdb.public.t3",
},
{
name: "implicit_schema/create",
resources: res("", "defaultdb.t3_c"),
create: true,
expectedDesc: "defaultdb.public.t3_c",
},
{
name: "implicit_db",
resources: res("", "t3"),
expectedDesc: "defaultdb.public.t3",
},
{
name: "implicit_db/create",
resources: res("", "t3_c"),
create: true,
expectedDesc: "defaultdb.public.t3_c",
},
{
name: "multi",
resources: res("", "db1.sc1.t1", "db1.sc1.t2"),
expectedDesc: "db1.sc1.t1, db1.sc1.t2",
},
{
name: "multi/create",
resources: res("", "db1.sc1.t1_c", "db1.sc1.t2_c"),
create: true,
expectedDesc: "db1.sc1.t1_c, db1.sc1.t2_c",
},
{
name: "missing_schema",
resources: res("", "db1.s2.t1"),
expectedErr: "failed to find existing destination table db1.s2.t1",
},
{
name: "missing_schema/create",
resources: res("", "db1.s2.t1"),
create: true,
expectedErr: "database or schema not found for destination table db1.s2.t1",
},
{
name: "missing_db/create",
resources: res("", "db2.t1"),
create: true,
expectedErr: "database or schema not found for destination table db2.t1",
},
{
name: "multiple_db_target",
resources: res("", "t3", "db1.sc1.t1"),
expectedDesc: "defaultdb.public.t3, db1.sc1.t1",
},
{
name: "multiple_db_target/create",
resources: res("", "t3_c", "db1.sc1.t1_c"),
create: true,
expectedErr: "destination tables must all be in the same database",
},
{
name: "multiple_schema_target/create",
resources: res("", "db1.sc2.t2_c", "db1.sc1.t1_c"),
create: true,
expectedErr: "destination tables must all be in the same schema",
},
{
name: "existing_table/create",
resources: res("", "db1.sc1.t1"),
create: true,
expectedErr: "destination table db1.sc1.t1 already exists",
},
} {
t.Run(tc.name, func(t *testing.T) {
resolved, err := resolveObjects(tc.resources, tc.create)
if tc.expectedErr != "" {
require.ErrorContains(t, err, tc.expectedErr)
return
}
require.NoError(t, err)
require.Equal(t, tc.expectedDesc, resolved.TargetDescription())
if tc.create {
require.NotZero(t, resolved.ParentDatabaseID)
require.NotZero(t, resolved.ParentSchemaID)
} else {
require.Zero(t, resolved.ParentDatabaseID)
require.Zero(t, resolved.ParentSchemaID)

}
})

}

}

0 comments on commit 6fd60cc

Please sign in to comment.