Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Add check for unique label names in nested labeled scopes
  • Loading branch information
miland-db committed Nov 7, 2024
commit 67bab8440b3c8f1b85bafb95dc3b5dcc9e548451
6 changes: 6 additions & 0 deletions common/utils/src/main/resources/error/error-conditions.json
Original file line number Diff line number Diff line change
Expand Up @@ -1198,6 +1198,12 @@
],
"sqlState" : "23505"
},
"DUPLICATE_LABEL" : {
"message" : [
"Label <label> already exists."
],
"sqlState" : "42K0L"
},
"DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT" : {
"message" : [
"Call to routine <routineName> is invalid because it includes multiple argument assignments to the same parameter name <parameterName>."
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ class AstBuilder extends DataTypeAstBuilder
override def visitCompoundOrSingleStatement(
ctx: CompoundOrSingleStatementContext): CompoundBody = withOrigin(ctx) {
Option(ctx.singleCompoundStatement()).map { s =>
LabelUtils.init()
visit(s).asInstanceOf[CompoundBody]
}.getOrElse {
val logicalPlan = visitSingleStatement(ctx.singleStatement())
Expand Down Expand Up @@ -184,37 +185,11 @@ class AstBuilder extends DataTypeAstBuilder
CompoundBody(buff.toSeq, label)
}


private def generateLabelText(
beginLabelCtx: Option[BeginLabelContext],
endLabelCtx: Option[EndLabelContext]): String = {

(beginLabelCtx, endLabelCtx) match {
case (Some(bl: BeginLabelContext), Some(el: EndLabelContext))
if bl.multipartIdentifier().getText.nonEmpty &&
bl.multipartIdentifier().getText.toLowerCase(Locale.ROOT) !=
el.multipartIdentifier().getText.toLowerCase(Locale.ROOT) =>
withOrigin(bl) {
throw SqlScriptingErrors.labelsMismatch(
CurrentOrigin.get,
bl.multipartIdentifier().getText,
el.multipartIdentifier().getText)
}
case (None, Some(el: EndLabelContext)) =>
withOrigin(el) {
throw SqlScriptingErrors.endLabelWithoutBeginLabel(
CurrentOrigin.get, el.multipartIdentifier().getText)
}
case _ =>
}

beginLabelCtx.map(_.multipartIdentifier().getText)
.getOrElse(java.util.UUID.randomUUID.toString).toLowerCase(Locale.ROOT)
}

override def visitBeginEndCompoundBlock(ctx: BeginEndCompoundBlockContext): CompoundBody = {
val labelText = generateLabelText(Option(ctx.beginLabel()), Option(ctx.endLabel()))
visitCompoundBodyImpl(ctx.compoundBody(), Some(labelText), allowVarDeclare = true)
val labelText = LabelUtils.enterLabeledScope(Option(ctx.beginLabel()), Option(ctx.endLabel()))
val body = visitCompoundBodyImpl(ctx.compoundBody(), Some(labelText), allowVarDeclare = true)
LabelUtils.exitLabeledScope(Option(ctx.beginLabel()))
body
}

override def visitCompoundBody(ctx: CompoundBodyContext): CompoundBody = {
Expand Down Expand Up @@ -246,7 +221,7 @@ class AstBuilder extends DataTypeAstBuilder
}

override def visitWhileStatement(ctx: WhileStatementContext): WhileStatement = {
val labelText = generateLabelText(Option(ctx.beginLabel()), Option(ctx.endLabel()))
val labelText = LabelUtils.enterLabeledScope(Option(ctx.beginLabel()), Option(ctx.endLabel()))
val boolExpr = ctx.booleanExpression()

val condition = withOrigin(boolExpr) {
Expand All @@ -255,6 +230,7 @@ class AstBuilder extends DataTypeAstBuilder
Seq(Alias(expression(boolExpr), "condition")()),
OneRowRelation()))}
val body = visitCompoundBody(ctx.compoundBody())
LabelUtils.exitLabeledScope(Option(ctx.beginLabel()))

WhileStatement(condition, body, Some(labelText))
}
Expand Down Expand Up @@ -306,7 +282,7 @@ class AstBuilder extends DataTypeAstBuilder
}

override def visitRepeatStatement(ctx: RepeatStatementContext): RepeatStatement = {
val labelText = generateLabelText(Option(ctx.beginLabel()), Option(ctx.endLabel()))
val labelText = LabelUtils.enterLabeledScope(Option(ctx.beginLabel()), Option(ctx.endLabel()))
val boolExpr = ctx.booleanExpression()

val condition = withOrigin(boolExpr) {
Expand All @@ -315,6 +291,7 @@ class AstBuilder extends DataTypeAstBuilder
Seq(Alias(expression(boolExpr), "condition")()),
OneRowRelation()))}
val body = visitCompoundBody(ctx.compoundBody())
LabelUtils.exitLabeledScope(Option(ctx.beginLabel()))

RepeatStatement(condition, body, Some(labelText))
}
Expand Down Expand Up @@ -378,8 +355,9 @@ class AstBuilder extends DataTypeAstBuilder
}

override def visitLoopStatement(ctx: LoopStatementContext): LoopStatement = {
val labelText = generateLabelText(Option(ctx.beginLabel()), Option(ctx.endLabel()))
val labelText = LabelUtils.enterLabeledScope(Option(ctx.beginLabel()), Option(ctx.endLabel()))
val body = visitCompoundBody(ctx.compoundBody())
LabelUtils.exitLabeledScope(Option(ctx.beginLabel()))

LoopStatement(body, Some(labelText))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,13 @@ import org.antlr.v4.runtime.{ParserRuleContext, Token}
import org.antlr.v4.runtime.misc.Interval
import org.antlr.v4.runtime.tree.{ParseTree, TerminalNodeImpl}

import org.apache.spark.sql.catalyst.parser.SqlBaseParser.{BeginLabelContext, EndLabelContext}
import org.apache.spark.sql.catalyst.trees.CurrentOrigin
import org.apache.spark.sql.catalyst.util.SparkParserUtils
import org.apache.spark.sql.errors.QueryParsingErrors
import org.apache.spark.sql.catalyst.util.SparkParserUtils.withOrigin
import org.apache.spark.sql.errors.{QueryParsingErrors, SqlScriptingErrors}

import scala.collection.mutable

/**
* A collection of utility methods for use during the parsing process.
Expand Down Expand Up @@ -134,3 +139,70 @@ object ParserUtils extends SparkParserUtils {
sb.toString()
}
}

object LabelUtils {
private val seenLabels: mutable.Set[String] = mutable.Set.empty

def init(): Unit = {
seenLabels.clear()
}

private def checkLabels(
beginLabelCtx: Option[BeginLabelContext],
endLabelCtx: Option[EndLabelContext]) : Unit = {
(beginLabelCtx, endLabelCtx) match {
case (Some(bl: BeginLabelContext), Some(el: EndLabelContext))
if bl.multipartIdentifier().getText.nonEmpty &&
bl.multipartIdentifier().getText.toLowerCase(Locale.ROOT) !=
el.multipartIdentifier().getText.toLowerCase(Locale.ROOT) =>
withOrigin(bl) {
seenLabels.clear()
throw SqlScriptingErrors.labelsMismatch(
CurrentOrigin.get,
bl.multipartIdentifier().getText,
el.multipartIdentifier().getText)
}
case (None, Some(el: EndLabelContext)) =>
withOrigin(el) {
seenLabels.clear()
throw SqlScriptingErrors.endLabelWithoutBeginLabel(
CurrentOrigin.get, el.multipartIdentifier().getText)
}
case _ =>
}
}

private def isLabelDefined(beginLabelCtx: Option[BeginLabelContext]): Boolean = {
beginLabelCtx.map(_.multipartIdentifier().getText).isDefined
}

def enterLabeledScope(
beginLabelCtx: Option[BeginLabelContext],
endLabelCtx: Option[EndLabelContext]): String = {

checkLabels(beginLabelCtx, endLabelCtx)

val labelText = if (isLabelDefined(beginLabelCtx)) {
val txt = beginLabelCtx.get.multipartIdentifier().getText.toLowerCase(Locale.ROOT)
if (seenLabels.contains(txt)) {
withOrigin(beginLabelCtx.get) {
seenLabels.clear()
throw SqlScriptingErrors.duplicateLabels(CurrentOrigin.get, txt)
}
}
seenLabels.add(beginLabelCtx.get.multipartIdentifier().getText)
txt
} else {
// Do not add the label to the seenLabels set if it is not defined.
java.util.UUID.randomUUID.toString.toLowerCase(Locale.ROOT)
}

labelText
}

def exitLabeledScope(beginLabelCtx: Option[BeginLabelContext]): Unit = {
if (isLabelDefined(beginLabelCtx)) {
seenLabels.remove(beginLabelCtx.get.multipartIdentifier().getText.toLowerCase(Locale.ROOT))
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,14 @@ import org.apache.spark.sql.exceptions.SqlScriptingException
*/
private[sql] object SqlScriptingErrors {

def duplicateLabels(origin: Origin, label: String): Throwable = {
new SqlScriptingException(
origin = origin,
errorClass = "DUPLICATE_LABEL",
cause = null,
messageParameters = Map("label" -> toSQLId(label)))
}

def labelsMismatch(origin: Origin, beginLabel: String, endLabel: String): Throwable = {
new SqlScriptingException(
origin = origin,
Expand Down
Loading