Skip to content
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions common/utils/src/main/resources/error/error-conditions.json
Original file line number Diff line number Diff line change
Expand Up @@ -3379,6 +3379,12 @@
],
"sqlState" : "42K0L"
},
"LABEL_ALREADY_EXISTS" : {
"message" : [
"The label <label> already exists. Choose another name or rename the existing label."
],
"sqlState" : "42K0L"
},
"LOAD_DATA_PATH_NOT_EXISTS" : {
"message" : [
"LOAD DATA input path does not exist: <path>."
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ class AstBuilder extends DataTypeAstBuilder
override def visitCompoundOrSingleStatement(
ctx: CompoundOrSingleStatementContext): CompoundBody = withOrigin(ctx) {
Option(ctx.singleCompoundStatement()).map { s =>
LabelUtils.init()
visit(s).asInstanceOf[CompoundBody]
}.getOrElse {
val logicalPlan = visitSingleStatement(ctx.singleStatement())
Expand Down Expand Up @@ -184,37 +185,11 @@ class AstBuilder extends DataTypeAstBuilder
CompoundBody(buff.toSeq, label)
}


private def generateLabelText(
beginLabelCtx: Option[BeginLabelContext],
endLabelCtx: Option[EndLabelContext]): String = {

(beginLabelCtx, endLabelCtx) match {
case (Some(bl: BeginLabelContext), Some(el: EndLabelContext))
if bl.multipartIdentifier().getText.nonEmpty &&
bl.multipartIdentifier().getText.toLowerCase(Locale.ROOT) !=
el.multipartIdentifier().getText.toLowerCase(Locale.ROOT) =>
withOrigin(bl) {
throw SqlScriptingErrors.labelsMismatch(
CurrentOrigin.get,
bl.multipartIdentifier().getText,
el.multipartIdentifier().getText)
}
case (None, Some(el: EndLabelContext)) =>
withOrigin(el) {
throw SqlScriptingErrors.endLabelWithoutBeginLabel(
CurrentOrigin.get, el.multipartIdentifier().getText)
}
case _ =>
}

beginLabelCtx.map(_.multipartIdentifier().getText)
.getOrElse(java.util.UUID.randomUUID.toString).toLowerCase(Locale.ROOT)
}

override def visitBeginEndCompoundBlock(ctx: BeginEndCompoundBlockContext): CompoundBody = {
val labelText = generateLabelText(Option(ctx.beginLabel()), Option(ctx.endLabel()))
visitCompoundBodyImpl(ctx.compoundBody(), Some(labelText), allowVarDeclare = true)
val labelText = LabelUtils.enterLabeledScope(Option(ctx.beginLabel()), Option(ctx.endLabel()))
val body = visitCompoundBodyImpl(ctx.compoundBody(), Some(labelText), allowVarDeclare = true)
LabelUtils.exitLabeledScope(Option(ctx.beginLabel()))
body
}

override def visitCompoundBody(ctx: CompoundBodyContext): CompoundBody = {
Expand Down Expand Up @@ -246,7 +221,7 @@ class AstBuilder extends DataTypeAstBuilder
}

override def visitWhileStatement(ctx: WhileStatementContext): WhileStatement = {
val labelText = generateLabelText(Option(ctx.beginLabel()), Option(ctx.endLabel()))
val labelText = LabelUtils.enterLabeledScope(Option(ctx.beginLabel()), Option(ctx.endLabel()))
val boolExpr = ctx.booleanExpression()

val condition = withOrigin(boolExpr) {
Expand All @@ -255,6 +230,7 @@ class AstBuilder extends DataTypeAstBuilder
Seq(Alias(expression(boolExpr), "condition")()),
OneRowRelation()))}
val body = visitCompoundBody(ctx.compoundBody())
LabelUtils.exitLabeledScope(Option(ctx.beginLabel()))

WhileStatement(condition, body, Some(labelText))
}
Expand Down Expand Up @@ -306,7 +282,7 @@ class AstBuilder extends DataTypeAstBuilder
}

override def visitRepeatStatement(ctx: RepeatStatementContext): RepeatStatement = {
val labelText = generateLabelText(Option(ctx.beginLabel()), Option(ctx.endLabel()))
val labelText = LabelUtils.enterLabeledScope(Option(ctx.beginLabel()), Option(ctx.endLabel()))
val boolExpr = ctx.booleanExpression()

val condition = withOrigin(boolExpr) {
Expand All @@ -315,6 +291,7 @@ class AstBuilder extends DataTypeAstBuilder
Seq(Alias(expression(boolExpr), "condition")()),
OneRowRelation()))}
val body = visitCompoundBody(ctx.compoundBody())
LabelUtils.exitLabeledScope(Option(ctx.beginLabel()))

RepeatStatement(condition, body, Some(labelText))
}
Expand Down Expand Up @@ -378,8 +355,9 @@ class AstBuilder extends DataTypeAstBuilder
}

override def visitLoopStatement(ctx: LoopStatementContext): LoopStatement = {
val labelText = generateLabelText(Option(ctx.beginLabel()), Option(ctx.endLabel()))
val labelText = LabelUtils.enterLabeledScope(Option(ctx.beginLabel()), Option(ctx.endLabel()))
val body = visitCompoundBody(ctx.compoundBody())
LabelUtils.exitLabeledScope(Option(ctx.beginLabel()))

LoopStatement(body, Some(labelText))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,17 @@ package org.apache.spark.sql.catalyst.parser
import java.util
import java.util.Locale

import scala.collection.mutable

import org.antlr.v4.runtime.{ParserRuleContext, Token}
import org.antlr.v4.runtime.misc.Interval
import org.antlr.v4.runtime.tree.{ParseTree, TerminalNodeImpl}

import org.apache.spark.sql.catalyst.parser.SqlBaseParser.{BeginLabelContext, EndLabelContext}
import org.apache.spark.sql.catalyst.trees.CurrentOrigin
import org.apache.spark.sql.catalyst.util.SparkParserUtils
import org.apache.spark.sql.errors.QueryParsingErrors
import org.apache.spark.sql.catalyst.util.SparkParserUtils.withOrigin
import org.apache.spark.sql.errors.{QueryParsingErrors, SqlScriptingErrors}

/**
* A collection of utility methods for use during the parsing process.
Expand Down Expand Up @@ -134,3 +139,87 @@ object ParserUtils extends SparkParserUtils {
sb.toString()
}
}

object LabelUtils {
/** A set to keep track of seen labels. */
private val seenLabels: mutable.Set[String] = mutable.Set.empty

/** Clear the seenLabels set to start fresh analysis. */
def init(): Unit = {
seenLabels.clear()
}

/**
* Check if the beginLabelCtx and endLabelCtx match.
* If the labels are defined, they must follow rules:
* - If both labels exist, they must match.
* - Begin label must exist if end label exists.
*/
private def checkLabels(
beginLabelCtx: Option[BeginLabelContext],
endLabelCtx: Option[EndLabelContext]) : Unit = {
(beginLabelCtx, endLabelCtx) match {
case (Some(bl: BeginLabelContext), Some(el: EndLabelContext))
if bl.multipartIdentifier().getText.toLowerCase(Locale.ROOT) !=
el.multipartIdentifier().getText.toLowerCase(Locale.ROOT) =>
withOrigin(bl) {
seenLabels.clear()
throw SqlScriptingErrors.labelsMismatch(
CurrentOrigin.get,
bl.multipartIdentifier().getText,
el.multipartIdentifier().getText)
}
case (None, Some(el: EndLabelContext)) =>
withOrigin(el) {
seenLabels.clear()
throw SqlScriptingErrors.endLabelWithoutBeginLabel(
CurrentOrigin.get, el.multipartIdentifier().getText)
}
case _ =>
}
}

/** Check if the label is defined. */
private def isLabelDefined(beginLabelCtx: Option[BeginLabelContext]): Boolean = {
beginLabelCtx.map(_.multipartIdentifier().getText).isDefined
}

/**
* Enter a labeled scope and return the label text.
* If the label is defined, it will be returned and added to seenLabels.
* If the label is not defined, a random UUID will be returned.
*/
def enterLabeledScope(
beginLabelCtx: Option[BeginLabelContext],
endLabelCtx: Option[EndLabelContext]): String = {

checkLabels(beginLabelCtx, endLabelCtx)

val labelText = if (isLabelDefined(beginLabelCtx)) {
val txt = beginLabelCtx.get.multipartIdentifier().getText.toLowerCase(Locale.ROOT)
if (seenLabels.contains(txt)) {
withOrigin(beginLabelCtx.get) {
seenLabels.clear()
throw SqlScriptingErrors.duplicateLabels(CurrentOrigin.get, txt)
}
}
seenLabels.add(beginLabelCtx.get.multipartIdentifier().getText)
txt
} else {
// Do not add the label to the seenLabels set if it is not defined.
java.util.UUID.randomUUID.toString.toLowerCase(Locale.ROOT)
}

labelText
}

/**
* Exit a labeled scope.
* If the label is defined, it will be removed from seenLabels.
*/
def exitLabeledScope(beginLabelCtx: Option[BeginLabelContext]): Unit = {
if (isLabelDefined(beginLabelCtx)) {
seenLabels.remove(beginLabelCtx.get.multipartIdentifier().getText.toLowerCase(Locale.ROOT))
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,14 @@ import org.apache.spark.sql.exceptions.SqlScriptingException
*/
private[sql] object SqlScriptingErrors {

def duplicateLabels(origin: Origin, label: String): Throwable = {
new SqlScriptingException(
origin = origin,
errorClass = "LABEL_ALREADY_EXISTS",
cause = null,
messageParameters = Map("label" -> toSQLId(label)))
}

def labelsMismatch(origin: Origin, beginLabel: String, endLabel: String): Throwable = {
new SqlScriptingException(
origin = origin,
Expand Down
Loading