Skip to content

Commit

Permalink
[SPARK-49913][SQL] Add check for unique label names in nested labeled…
Browse files Browse the repository at this point in the history
… scopes

### What changes were proposed in this pull request?
We are introducing checks for unique label names.
New rules for label names:
- A label can't have the same name as any label of the scopes that enclose it
- Labels can have the same name as other labels in the same scope

**Valid** code:
```
BEGIN
  lbl: BEGIN
    SELECT 1;
  END;

  lbl: BEGIN
    SELECT 2;
  END;

  BEGIN
    lbl: WHILE 1=1 DO
      LEAVE lbl;
    END WHILE;
  END;
END
```

**Invalid** code:
```
BEGIN
  lbl: BEGIN
    lbl: BEGIN
      SELECT 1;
    END;
  END;
END
```

#### Design explanation:

ANTLR _Listeners_ offer `enterRule` and `exitRule` methods that could check a label before visiting a node and remove it from `seenLabels` afterwards. We nevertheless favor the approach in this PR because it requires minimal changes, keeps the code compact, and avoids dependency issues.

Additionally, generating label text would need to be done in 2 places and we wanted to avoid duplicated logic:
- `enterRule`
- `visitRule`

### Why are the changes needed?
Unique label names will be needed in the future, when we release Local Scoped Variables for SQL Scripting, so that users can use a label to target a variable from an outer scope when it is shadowed by an inner one.

### How was this patch tested?
New unit tests in 'SqlScriptingParserSuite.scala'.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes apache#48795 from miland-db/milan-dankovic_data/unique_labels_scripting.

Authored-by: Milan Dankovic <milan.dankovic@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
  • Loading branch information
miland-db authored and cloud-fan committed Nov 14, 2024
1 parent 891f694 commit 2fd4702
Show file tree
Hide file tree
Showing 5 changed files with 430 additions and 61 deletions.
6 changes: 6 additions & 0 deletions common/utils/src/main/resources/error/error-conditions.json
Original file line number Diff line number Diff line change
Expand Up @@ -3411,6 +3411,12 @@
],
"sqlState" : "42K0L"
},
"LABEL_ALREADY_EXISTS" : {
"message" : [
"The label <label> already exists. Choose another name or rename the existing label."
],
"sqlState" : "42K0L"
},
"LOAD_DATA_PATH_NOT_EXISTS" : {
"message" : [
"LOAD DATA input path does not exist: <path>."
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,17 +142,18 @@ class AstBuilder extends DataTypeAstBuilder
}

override def visitSingleCompoundStatement(ctx: SingleCompoundStatementContext): CompoundBody = {
visit(ctx.beginEndCompoundBlock()).asInstanceOf[CompoundBody]
val labelCtx = new SqlScriptingLabelContext()
visitBeginEndCompoundBlockImpl(ctx.beginEndCompoundBlock(), labelCtx)
}

private def visitCompoundBodyImpl(
ctx: CompoundBodyContext,
label: Option[String],
allowVarDeclare: Boolean): CompoundBody = {
allowVarDeclare: Boolean,
labelCtx: SqlScriptingLabelContext): CompoundBody = {
val buff = ListBuffer[CompoundPlanStatement]()
ctx.compoundStatements.forEach(compoundStatement => {
buff += visit(compoundStatement).asInstanceOf[CompoundPlanStatement]
})
ctx.compoundStatements.forEach(
compoundStatement => buff += visitCompoundStatementImpl(compoundStatement, labelCtx))

val compoundStatements = buff.toList

Expand Down Expand Up @@ -184,90 +185,104 @@ class AstBuilder extends DataTypeAstBuilder
CompoundBody(buff.toSeq, label)
}


private def generateLabelText(
beginLabelCtx: Option[BeginLabelContext],
endLabelCtx: Option[EndLabelContext]): String = {

(beginLabelCtx, endLabelCtx) match {
case (Some(bl: BeginLabelContext), Some(el: EndLabelContext))
if bl.multipartIdentifier().getText.nonEmpty &&
bl.multipartIdentifier().getText.toLowerCase(Locale.ROOT) !=
el.multipartIdentifier().getText.toLowerCase(Locale.ROOT) =>
withOrigin(bl) {
throw SqlScriptingErrors.labelsMismatch(
CurrentOrigin.get,
bl.multipartIdentifier().getText,
el.multipartIdentifier().getText)
}
case (None, Some(el: EndLabelContext)) =>
withOrigin(el) {
throw SqlScriptingErrors.endLabelWithoutBeginLabel(
CurrentOrigin.get, el.multipartIdentifier().getText)
}
case _ =>
}

beginLabelCtx.map(_.multipartIdentifier().getText)
.getOrElse(java.util.UUID.randomUUID.toString).toLowerCase(Locale.ROOT)
}

override def visitBeginEndCompoundBlock(ctx: BeginEndCompoundBlockContext): CompoundBody = {
val labelText = generateLabelText(Option(ctx.beginLabel()), Option(ctx.endLabel()))
visitCompoundBodyImpl(ctx.compoundBody(), Some(labelText), allowVarDeclare = true)
}

override def visitCompoundBody(ctx: CompoundBodyContext): CompoundBody = {
visitCompoundBodyImpl(ctx, None, allowVarDeclare = false)
private def visitBeginEndCompoundBlockImpl(
ctx: BeginEndCompoundBlockContext,
labelCtx: SqlScriptingLabelContext): CompoundBody = {
val labelText =
labelCtx.enterLabeledScope(Option(ctx.beginLabel()), Option(ctx.endLabel()))
val body = visitCompoundBodyImpl(
ctx.compoundBody(),
Some(labelText),
allowVarDeclare = true,
labelCtx
)
labelCtx.exitLabeledScope(Option(ctx.beginLabel()))
body
}

override def visitCompoundStatement(ctx: CompoundStatementContext): CompoundPlanStatement =
private def visitCompoundStatementImpl(
ctx: CompoundStatementContext,
labelCtx: SqlScriptingLabelContext): CompoundPlanStatement =
withOrigin(ctx) {
Option(ctx.statement().asInstanceOf[ParserRuleContext])
.orElse(Option(ctx.setStatementWithOptionalVarKeyword().asInstanceOf[ParserRuleContext]))
.map { s =>
SingleStatement(parsedPlan = visit(s).asInstanceOf[LogicalPlan])
}.getOrElse {
visitChildren(ctx).asInstanceOf[CompoundPlanStatement]
if (ctx.getChildCount == 1) {
ctx.getChild(0) match {
case compoundBodyContext: BeginEndCompoundBlockContext =>
visitBeginEndCompoundBlockImpl(compoundBodyContext, labelCtx)
case whileStmtContext: WhileStatementContext =>
visitWhileStatementImpl(whileStmtContext, labelCtx)
case repeatStmtContext: RepeatStatementContext =>
visitRepeatStatementImpl(repeatStmtContext, labelCtx)
case loopStatementContext: LoopStatementContext =>
visitLoopStatementImpl(loopStatementContext, labelCtx)
case ifElseStmtContext: IfElseStatementContext =>
visitIfElseStatementImpl(ifElseStmtContext, labelCtx)
case searchedCaseContext: SearchedCaseStatementContext =>
visitSearchedCaseStatementImpl(searchedCaseContext, labelCtx)
case simpleCaseContext: SimpleCaseStatementContext =>
visitSimpleCaseStatementImpl(simpleCaseContext, labelCtx)
case stmt => visit(stmt).asInstanceOf[CompoundPlanStatement]
}
} else {
null
}
}
}

override def visitIfElseStatement(ctx: IfElseStatementContext): IfElseStatement = {
private def visitIfElseStatementImpl(
ctx: IfElseStatementContext,
labelCtx: SqlScriptingLabelContext): IfElseStatement = {
IfElseStatement(
conditions = ctx.booleanExpression().asScala.toList.map(boolExpr => withOrigin(boolExpr) {
SingleStatement(
Project(
Seq(Alias(expression(boolExpr), "condition")()),
OneRowRelation()))
}),
conditionalBodies = ctx.conditionalBodies.asScala.toList.map(body => visitCompoundBody(body)),
elseBody = Option(ctx.elseBody).map(body => visitCompoundBody(body))
conditionalBodies = ctx.conditionalBodies.asScala.toList.map(
body => visitCompoundBodyImpl(body, None, allowVarDeclare = false, labelCtx)
),
elseBody = Option(ctx.elseBody).map(
body => visitCompoundBodyImpl(body, None, allowVarDeclare = false, labelCtx)
)
)
}

override def visitWhileStatement(ctx: WhileStatementContext): WhileStatement = {
val labelText = generateLabelText(Option(ctx.beginLabel()), Option(ctx.endLabel()))
private def visitWhileStatementImpl(
ctx: WhileStatementContext,
labelCtx: SqlScriptingLabelContext): WhileStatement = {
val labelText =
labelCtx.enterLabeledScope(Option(ctx.beginLabel()), Option(ctx.endLabel()))
val boolExpr = ctx.booleanExpression()

val condition = withOrigin(boolExpr) {
SingleStatement(
Project(
Seq(Alias(expression(boolExpr), "condition")()),
OneRowRelation()))}
val body = visitCompoundBody(ctx.compoundBody())
val body = visitCompoundBodyImpl(ctx.compoundBody(), None, allowVarDeclare = false, labelCtx)
labelCtx.exitLabeledScope(Option(ctx.beginLabel()))

WhileStatement(condition, body, Some(labelText))
}

override def visitSearchedCaseStatement(ctx: SearchedCaseStatementContext): CaseStatement = {
private def visitSearchedCaseStatementImpl(
ctx: SearchedCaseStatementContext,
labelCtx: SqlScriptingLabelContext): CaseStatement = {
val conditions = ctx.conditions.asScala.toList.map(boolExpr => withOrigin(boolExpr) {
SingleStatement(
Project(
Seq(Alias(expression(boolExpr), "condition")()),
OneRowRelation()))
})
val conditionalBodies =
ctx.conditionalBodies.asScala.toList.map(body => visitCompoundBody(body))
ctx.conditionalBodies.asScala.toList.map(
body => visitCompoundBodyImpl(body, None, allowVarDeclare = false, labelCtx)
)

if (conditions.length != conditionalBodies.length) {
throw SparkException.internalError(
Expand All @@ -278,10 +293,14 @@ class AstBuilder extends DataTypeAstBuilder
CaseStatement(
conditions = conditions,
conditionalBodies = conditionalBodies,
elseBody = Option(ctx.elseBody).map(body => visitCompoundBody(body)))
elseBody = Option(ctx.elseBody).map(
body => visitCompoundBodyImpl(body, None, allowVarDeclare = false, labelCtx)
))
}

override def visitSimpleCaseStatement(ctx: SimpleCaseStatementContext): CaseStatement = {
private def visitSimpleCaseStatementImpl(
ctx: SimpleCaseStatementContext,
labelCtx: SqlScriptingLabelContext): CaseStatement = {
// uses EqualTo to compare the case variable(the main case expression)
// to the WHEN clause expressions
val conditions = ctx.conditionExpressions.asScala.toList.map(expr => withOrigin(expr) {
Expand All @@ -291,7 +310,9 @@ class AstBuilder extends DataTypeAstBuilder
OneRowRelation()))
})
val conditionalBodies =
ctx.conditionalBodies.asScala.toList.map(body => visitCompoundBody(body))
ctx.conditionalBodies.asScala.toList.map(
body => visitCompoundBodyImpl(body, None, allowVarDeclare = false, labelCtx)
)

if (conditions.length != conditionalBodies.length) {
throw SparkException.internalError(
Expand All @@ -302,19 +323,25 @@ class AstBuilder extends DataTypeAstBuilder
CaseStatement(
conditions = conditions,
conditionalBodies = conditionalBodies,
elseBody = Option(ctx.elseBody).map(body => visitCompoundBody(body)))
elseBody = Option(ctx.elseBody).map(
body => visitCompoundBodyImpl(body, None, allowVarDeclare = false, labelCtx)
))
}

override def visitRepeatStatement(ctx: RepeatStatementContext): RepeatStatement = {
val labelText = generateLabelText(Option(ctx.beginLabel()), Option(ctx.endLabel()))
private def visitRepeatStatementImpl(
ctx: RepeatStatementContext,
labelCtx: SqlScriptingLabelContext): RepeatStatement = {
val labelText =
labelCtx.enterLabeledScope(Option(ctx.beginLabel()), Option(ctx.endLabel()))
val boolExpr = ctx.booleanExpression()

val condition = withOrigin(boolExpr) {
SingleStatement(
Project(
Seq(Alias(expression(boolExpr), "condition")()),
OneRowRelation()))}
val body = visitCompoundBody(ctx.compoundBody())
val body = visitCompoundBodyImpl(ctx.compoundBody(), None, allowVarDeclare = false, labelCtx)
labelCtx.exitLabeledScope(Option(ctx.beginLabel()))

RepeatStatement(condition, body, Some(labelText))
}
Expand Down Expand Up @@ -377,9 +404,13 @@ class AstBuilder extends DataTypeAstBuilder
CurrentOrigin.get, labelText, "ITERATE")
}

override def visitLoopStatement(ctx: LoopStatementContext): LoopStatement = {
val labelText = generateLabelText(Option(ctx.beginLabel()), Option(ctx.endLabel()))
val body = visitCompoundBody(ctx.compoundBody())
private def visitLoopStatementImpl(
ctx: LoopStatementContext,
labelCtx: SqlScriptingLabelContext): LoopStatement = {
val labelText =
labelCtx.enterLabeledScope(Option(ctx.beginLabel()), Option(ctx.endLabel()))
val body = visitCompoundBodyImpl(ctx.compoundBody(), None, allowVarDeclare = false, labelCtx)
labelCtx.exitLabeledScope(Option(ctx.beginLabel()))

LoopStatement(body, Some(labelText))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,17 @@ package org.apache.spark.sql.catalyst.parser
import java.util
import java.util.Locale

import scala.collection.mutable.Set

import org.antlr.v4.runtime.{ParserRuleContext, Token}
import org.antlr.v4.runtime.misc.Interval
import org.antlr.v4.runtime.tree.{ParseTree, TerminalNodeImpl}

import org.apache.spark.sql.catalyst.parser.SqlBaseParser.{BeginLabelContext, EndLabelContext}
import org.apache.spark.sql.catalyst.trees.CurrentOrigin
import org.apache.spark.sql.catalyst.util.SparkParserUtils
import org.apache.spark.sql.errors.QueryParsingErrors
import org.apache.spark.sql.catalyst.util.SparkParserUtils.withOrigin
import org.apache.spark.sql.errors.{QueryParsingErrors, SqlScriptingErrors}

/**
* A collection of utility methods for use during the parsing process.
Expand Down Expand Up @@ -134,3 +139,80 @@ object ParserUtils extends SparkParserUtils {
sb.toString()
}
}

/**
 * Tracks the user-defined labels of the labeled scopes that are currently open while a
 * SQL script is being parsed, so that a nested scope reusing the label of an enclosing
 * scope can be rejected.
 *
 * Callers are expected to pair every [[enterLabeledScope]] call with an
 * [[exitLabeledScope]] call for the same begin label once the scope body has been visited.
 */
class SqlScriptingLabelContext {
  /**
   * Labels of the scopes entered and not yet exited.
   * Every entry is stored lower-cased (Locale.ROOT) so that lookups, insertions and
   * removals all agree and label comparison is case-insensitive.
   */
  private val seenLabels = Set[String]()

  /** Returns the normalized (lower-cased) text of the given begin label. */
  private def labelText(beginLabel: BeginLabelContext): String = {
    beginLabel.multipartIdentifier().getText.toLowerCase(Locale.ROOT)
  }

  /**
   * Check that beginLabelCtx and endLabelCtx are consistent with each other.
   * If the labels are defined, they must follow these rules:
   *  - If both labels exist, they must match (case-insensitively).
   *  - Begin label must exist if end label exists.
   *
   * Throws the error built by `SqlScriptingErrors.labelsMismatch` or
   * `SqlScriptingErrors.endLabelWithoutBeginLabel` when a rule is violated.
   */
  private def checkLabels(
      beginLabelCtx: Option[BeginLabelContext],
      endLabelCtx: Option[EndLabelContext]): Unit = {
    (beginLabelCtx, endLabelCtx) match {
      case (Some(bl: BeginLabelContext), Some(el: EndLabelContext))
          if labelText(bl) != el.multipartIdentifier().getText.toLowerCase(Locale.ROOT) =>
        withOrigin(bl) {
          // The error message carries the labels exactly as the user wrote them.
          throw SqlScriptingErrors.labelsMismatch(
            CurrentOrigin.get,
            bl.multipartIdentifier().getText,
            el.multipartIdentifier().getText)
        }
      case (None, Some(el: EndLabelContext)) =>
        withOrigin(el) {
          throw SqlScriptingErrors.endLabelWithoutBeginLabel(
            CurrentOrigin.get, el.multipartIdentifier().getText)
        }
      case _ =>
    }
  }

  /**
   * Enter a labeled scope and return the label text.
   * If the begin label is defined, its lower-cased text is added to seenLabels and returned.
   * If the begin label is not defined, a random UUID is returned and nothing is tracked.
   *
   * Throws the error built by `SqlScriptingErrors.duplicateLabels` if a scope with the same
   * label (compared case-insensitively) has been entered and not yet exited.
   */
  def enterLabeledScope(
      beginLabelCtx: Option[BeginLabelContext],
      endLabelCtx: Option[EndLabelContext]): String = {

    // Validate that the begin and end labels are consistent with each other.
    checkLabels(beginLabelCtx, endLabelCtx)

    beginLabelCtx match {
      case Some(bl) =>
        val txt = labelText(bl)
        // Reject the label if one of the currently open scopes already uses it.
        if (seenLabels.contains(txt)) {
          withOrigin(bl) {
            throw SqlScriptingErrors.duplicateLabels(CurrentOrigin.get, txt)
          }
        }
        // Store the normalized text: the duplicate check above and the removal in
        // exitLabeledScope both use the lower-cased form, so the raw text must not be stored.
        seenLabels.add(txt)
        txt
      case None =>
        // Do not add the label to the seenLabels set if it is not defined.
        java.util.UUID.randomUUID.toString.toLowerCase(Locale.ROOT)
    }
  }

  /**
   * Exit a labeled scope.
   * If the begin label is defined, it will be removed from seenLabels.
   */
  def exitLabeledScope(beginLabelCtx: Option[BeginLabelContext]): Unit = {
    beginLabelCtx.foreach(bl => seenLabels.remove(labelText(bl)))
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,14 @@ import org.apache.spark.sql.exceptions.SqlScriptingException
*/
private[sql] object SqlScriptingErrors {

/**
 * Builds the exception for the `LABEL_ALREADY_EXISTS` error condition.
 *
 * @param origin origin attached to the resulting exception
 * @param label name of the offending label; it is passed through `toSQLId` before being
 *              placed in the message parameters
 * @return a `SqlScriptingException` with no cause
 */
def duplicateLabels(origin: Origin, label: String): Throwable = {
  val params = Map("label" -> toSQLId(label))
  new SqlScriptingException(
    errorClass = "LABEL_ALREADY_EXISTS",
    cause = null,
    origin = origin,
    messageParameters = params)
}

def labelsMismatch(origin: Origin, beginLabel: String, endLabel: String): Throwable = {
new SqlScriptingException(
origin = origin,
Expand Down
Loading

0 comments on commit 2fd4702

Please sign in to comment.