Skip to content
Closed
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions common/utils/src/main/resources/error/error-conditions.json
Original file line number Diff line number Diff line change
Expand Up @@ -3379,6 +3379,12 @@
],
"sqlState" : "42K0L"
},
"LABEL_ALREADY_EXISTS" : {
"message" : [
"The label <label> already exists. Choose another name or rename the existing label."
],
"sqlState" : "42K0L"
},
"LOAD_DATA_PATH_NOT_EXISTS" : {
"message" : [
"LOAD DATA input path does not exist: <path>."
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,17 +142,18 @@ class AstBuilder extends DataTypeAstBuilder
}

override def visitSingleCompoundStatement(ctx: SingleCompoundStatementContext): CompoundBody = {
visit(ctx.beginEndCompoundBlock()).asInstanceOf[CompoundBody]
val seenLabels = Set[String]()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this is the implementation I was expecting. Can we do a bit more encapsulation? e.g.

```
val labelCtx = new ScriptingLabelContext()
visitBeginEndCompoundBlockImpl...
...
class ScriptingLabelContext {
  private val seenLabels = Set[String]()
  
  def newLabel(label: String)...
  ...
}
```

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@cloud-fan done.

visitBeginEndCompoundBlockImpl(ctx.beginEndCompoundBlock(), seenLabels)
}

private def visitCompoundBodyImpl(
ctx: CompoundBodyContext,
label: Option[String],
allowVarDeclare: Boolean): CompoundBody = {
allowVarDeclare: Boolean,
seenLabels: Set[String]): CompoundBody = {
val buff = ListBuffer[CompoundPlanStatement]()
ctx.compoundStatements.forEach(compoundStatement => {
buff += visit(compoundStatement).asInstanceOf[CompoundPlanStatement]
})
ctx.compoundStatements.forEach(
compoundStatement => buff += visitCompoundStatementImpl(compoundStatement, seenLabels))

val compoundStatements = buff.toList

Expand Down Expand Up @@ -184,90 +185,104 @@ class AstBuilder extends DataTypeAstBuilder
CompoundBody(buff.toSeq, label)
}


private def generateLabelText(
beginLabelCtx: Option[BeginLabelContext],
endLabelCtx: Option[EndLabelContext]): String = {

(beginLabelCtx, endLabelCtx) match {
case (Some(bl: BeginLabelContext), Some(el: EndLabelContext))
if bl.multipartIdentifier().getText.nonEmpty &&
bl.multipartIdentifier().getText.toLowerCase(Locale.ROOT) !=
el.multipartIdentifier().getText.toLowerCase(Locale.ROOT) =>
withOrigin(bl) {
throw SqlScriptingErrors.labelsMismatch(
CurrentOrigin.get,
bl.multipartIdentifier().getText,
el.multipartIdentifier().getText)
}
case (None, Some(el: EndLabelContext)) =>
withOrigin(el) {
throw SqlScriptingErrors.endLabelWithoutBeginLabel(
CurrentOrigin.get, el.multipartIdentifier().getText)
}
case _ =>
}

beginLabelCtx.map(_.multipartIdentifier().getText)
.getOrElse(java.util.UUID.randomUUID.toString).toLowerCase(Locale.ROOT)
}

override def visitBeginEndCompoundBlock(ctx: BeginEndCompoundBlockContext): CompoundBody = {
val labelText = generateLabelText(Option(ctx.beginLabel()), Option(ctx.endLabel()))
visitCompoundBodyImpl(ctx.compoundBody(), Some(labelText), allowVarDeclare = true)
}

override def visitCompoundBody(ctx: CompoundBodyContext): CompoundBody = {
visitCompoundBodyImpl(ctx, None, allowVarDeclare = false)
private def visitBeginEndCompoundBlockImpl(
ctx: BeginEndCompoundBlockContext,
seenLabels: Set[String]): CompoundBody = {
val labelText =
LabelUtils.enterLabeledScope(Option(ctx.beginLabel()), Option(ctx.endLabel()), seenLabels)
val body = visitCompoundBodyImpl(
ctx.compoundBody(),
Some(labelText),
allowVarDeclare = true,
seenLabels
)
LabelUtils.exitLabeledScope(Option(ctx.beginLabel()), seenLabels)
body
}

override def visitCompoundStatement(ctx: CompoundStatementContext): CompoundPlanStatement =
private def visitCompoundStatementImpl(
ctx: CompoundStatementContext,
seenLabels: Set[String]): CompoundPlanStatement =
withOrigin(ctx) {
Option(ctx.statement().asInstanceOf[ParserRuleContext])
.orElse(Option(ctx.setStatementWithOptionalVarKeyword().asInstanceOf[ParserRuleContext]))
.map { s =>
SingleStatement(parsedPlan = visit(s).asInstanceOf[LogicalPlan])
}.getOrElse {
visitChildren(ctx).asInstanceOf[CompoundPlanStatement]
if (ctx.getChildCount == 1) {
ctx.getChild(0) match {
case compoundBodyContext: BeginEndCompoundBlockContext =>
visitBeginEndCompoundBlockImpl(compoundBodyContext, seenLabels)
case whileStmtContext: WhileStatementContext =>
visitWhileStatementImpl(whileStmtContext, seenLabels)
case repeatStmtContext: RepeatStatementContext =>
visitRepeatStatementImpl(repeatStmtContext, seenLabels)
case loopStatementContext: LoopStatementContext =>
visitLoopStatementImpl(loopStatementContext, seenLabels)
case ifElseStmtContext: IfElseStatementContext =>
visitIfElseStatementImpl(ifElseStmtContext, seenLabels)
case searchedCaseContext: SearchedCaseStatementContext =>
visitSearchedCaseStatementImpl(searchedCaseContext, seenLabels)
case simpleCaseContext: SimpleCaseStatementContext =>
visitSimpleCaseStatementImpl(simpleCaseContext, seenLabels)
case stmt => visit(stmt).asInstanceOf[CompoundPlanStatement]
}
} else {
null
}
}
}

override def visitIfElseStatement(ctx: IfElseStatementContext): IfElseStatement = {
private def visitIfElseStatementImpl(
ctx: IfElseStatementContext,
seenLabels: Set[String]): IfElseStatement = {
IfElseStatement(
conditions = ctx.booleanExpression().asScala.toList.map(boolExpr => withOrigin(boolExpr) {
SingleStatement(
Project(
Seq(Alias(expression(boolExpr), "condition")()),
OneRowRelation()))
}),
conditionalBodies = ctx.conditionalBodies.asScala.toList.map(body => visitCompoundBody(body)),
elseBody = Option(ctx.elseBody).map(body => visitCompoundBody(body))
conditionalBodies = ctx.conditionalBodies.asScala.toList.map(
body => visitCompoundBodyImpl(body, None, allowVarDeclare = false, seenLabels)
),
elseBody = Option(ctx.elseBody).map(
body => visitCompoundBodyImpl(body, None, allowVarDeclare = false, seenLabels)
)
)
}

override def visitWhileStatement(ctx: WhileStatementContext): WhileStatement = {
val labelText = generateLabelText(Option(ctx.beginLabel()), Option(ctx.endLabel()))
private def visitWhileStatementImpl(
ctx: WhileStatementContext,
seenLabels: Set[String]): WhileStatement = {
val labelText =
LabelUtils.enterLabeledScope(Option(ctx.beginLabel()), Option(ctx.endLabel()), seenLabels)
val boolExpr = ctx.booleanExpression()

val condition = withOrigin(boolExpr) {
SingleStatement(
Project(
Seq(Alias(expression(boolExpr), "condition")()),
OneRowRelation()))}
val body = visitCompoundBody(ctx.compoundBody())
val body = visitCompoundBodyImpl(ctx.compoundBody(), None, allowVarDeclare = false, seenLabels)
LabelUtils.exitLabeledScope(Option(ctx.beginLabel()), seenLabels)

WhileStatement(condition, body, Some(labelText))
}

override def visitSearchedCaseStatement(ctx: SearchedCaseStatementContext): CaseStatement = {
private def visitSearchedCaseStatementImpl(
ctx: SearchedCaseStatementContext,
seenLabels: Set[String]): CaseStatement = {
val conditions = ctx.conditions.asScala.toList.map(boolExpr => withOrigin(boolExpr) {
SingleStatement(
Project(
Seq(Alias(expression(boolExpr), "condition")()),
OneRowRelation()))
})
val conditionalBodies =
ctx.conditionalBodies.asScala.toList.map(body => visitCompoundBody(body))
ctx.conditionalBodies.asScala.toList.map(
body => visitCompoundBodyImpl(body, None, allowVarDeclare = false, seenLabels)
)

if (conditions.length != conditionalBodies.length) {
throw SparkException.internalError(
Expand All @@ -278,10 +293,14 @@ class AstBuilder extends DataTypeAstBuilder
CaseStatement(
conditions = conditions,
conditionalBodies = conditionalBodies,
elseBody = Option(ctx.elseBody).map(body => visitCompoundBody(body)))
elseBody = Option(ctx.elseBody).map(
body => visitCompoundBodyImpl(body, None, allowVarDeclare = false, seenLabels)
))
}

override def visitSimpleCaseStatement(ctx: SimpleCaseStatementContext): CaseStatement = {
private def visitSimpleCaseStatementImpl(
ctx: SimpleCaseStatementContext,
seenLabels: Set[String]): CaseStatement = {
// uses EqualTo to compare the case variable(the main case expression)
// to the WHEN clause expressions
val conditions = ctx.conditionExpressions.asScala.toList.map(expr => withOrigin(expr) {
Expand All @@ -291,7 +310,9 @@ class AstBuilder extends DataTypeAstBuilder
OneRowRelation()))
})
val conditionalBodies =
ctx.conditionalBodies.asScala.toList.map(body => visitCompoundBody(body))
ctx.conditionalBodies.asScala.toList.map(
body => visitCompoundBodyImpl(body, None, allowVarDeclare = false, seenLabels)
)

if (conditions.length != conditionalBodies.length) {
throw SparkException.internalError(
Expand All @@ -302,19 +323,25 @@ class AstBuilder extends DataTypeAstBuilder
CaseStatement(
conditions = conditions,
conditionalBodies = conditionalBodies,
elseBody = Option(ctx.elseBody).map(body => visitCompoundBody(body)))
elseBody = Option(ctx.elseBody).map(
body => visitCompoundBodyImpl(body, None, allowVarDeclare = false, seenLabels)
))
}

override def visitRepeatStatement(ctx: RepeatStatementContext): RepeatStatement = {
val labelText = generateLabelText(Option(ctx.beginLabel()), Option(ctx.endLabel()))
private def visitRepeatStatementImpl(
ctx: RepeatStatementContext,
seenLabels: Set[String]): RepeatStatement = {
val labelText =
LabelUtils.enterLabeledScope(Option(ctx.beginLabel()), Option(ctx.endLabel()), seenLabels)
val boolExpr = ctx.booleanExpression()

val condition = withOrigin(boolExpr) {
SingleStatement(
Project(
Seq(Alias(expression(boolExpr), "condition")()),
OneRowRelation()))}
val body = visitCompoundBody(ctx.compoundBody())
val body = visitCompoundBodyImpl(ctx.compoundBody(), None, allowVarDeclare = false, seenLabels)
LabelUtils.exitLabeledScope(Option(ctx.beginLabel()), seenLabels)

RepeatStatement(condition, body, Some(labelText))
}
Expand Down Expand Up @@ -377,9 +404,13 @@ class AstBuilder extends DataTypeAstBuilder
CurrentOrigin.get, labelText, "ITERATE")
}

override def visitLoopStatement(ctx: LoopStatementContext): LoopStatement = {
val labelText = generateLabelText(Option(ctx.beginLabel()), Option(ctx.endLabel()))
val body = visitCompoundBody(ctx.compoundBody())
private def visitLoopStatementImpl(
ctx: LoopStatementContext,
seenLabels: Set[String]): LoopStatement = {
val labelText =
LabelUtils.enterLabeledScope(Option(ctx.beginLabel()), Option(ctx.endLabel()), seenLabels)
val body = visitCompoundBodyImpl(ctx.compoundBody(), None, allowVarDeclare = false, seenLabels)
LabelUtils.exitLabeledScope(Option(ctx.beginLabel()), seenLabels)

LoopStatement(body, Some(labelText))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,17 @@ package org.apache.spark.sql.catalyst.parser
import java.util
import java.util.Locale

import scala.collection.mutable

import org.antlr.v4.runtime.{ParserRuleContext, Token}
import org.antlr.v4.runtime.misc.Interval
import org.antlr.v4.runtime.tree.{ParseTree, TerminalNodeImpl}

import org.apache.spark.sql.catalyst.parser.SqlBaseParser.{BeginLabelContext, EndLabelContext}
import org.apache.spark.sql.catalyst.trees.CurrentOrigin
import org.apache.spark.sql.catalyst.util.SparkParserUtils
import org.apache.spark.sql.errors.QueryParsingErrors
import org.apache.spark.sql.catalyst.util.SparkParserUtils.withOrigin
import org.apache.spark.sql.errors.{QueryParsingErrors, SqlScriptingErrors}

/**
* A collection of utility methods for use during the parsing process.
Expand Down Expand Up @@ -134,3 +139,79 @@ object ParserUtils extends SparkParserUtils {
sb.toString()
}
}

object LabelUtils {
  /**
   * Text of a begin label in the normalized form used everywhere in this object:
   * lower-cased with [[Locale.ROOT]] so that label comparison is case-insensitive and
   * independent of the JVM default locale.
   */
  private def normalizedLabel(beginLabelCtx: BeginLabelContext): String = {
    beginLabelCtx.multipartIdentifier().getText.toLowerCase(Locale.ROOT)
  }

  /**
   * Check if the beginLabelCtx and endLabelCtx match.
   * If the labels are defined, they must follow rules:
   *  - If both labels exist, they must match (compared case-insensitively).
   *  - Begin label must exist if end label exists.
   *
   * Throws the error built by `SqlScriptingErrors.labelsMismatch` or
   * `SqlScriptingErrors.endLabelWithoutBeginLabel` when a rule is violated.
   */
  private def checkLabels(
      beginLabelCtx: Option[BeginLabelContext],
      endLabelCtx: Option[EndLabelContext]): Unit = {
    (beginLabelCtx, endLabelCtx) match {
      case (Some(bl), Some(el))
          if normalizedLabel(bl) != el.multipartIdentifier().getText.toLowerCase(Locale.ROOT) =>
        withOrigin(bl) {
          // The error reports the labels as written, not the normalized form.
          throw SqlScriptingErrors.labelsMismatch(
            CurrentOrigin.get,
            bl.multipartIdentifier().getText,
            el.multipartIdentifier().getText)
        }
      case (None, Some(el)) =>
        withOrigin(el) {
          throw SqlScriptingErrors.endLabelWithoutBeginLabel(
            CurrentOrigin.get, el.multipartIdentifier().getText)
        }
      case _ =>
    }
  }

  /**
   * Enter a labeled scope and return the label text.
   * If the label is defined, its normalized (lower-cased) text is returned and added to
   * seenLabels. If the label is not defined, a random UUID is returned and seenLabels is
   * left untouched.
   *
   * @param beginLabelCtx begin label of the scope, if any
   * @param endLabelCtx end label of the scope, if any
   * @param seenLabels labels of the scopes that are currently open; mutated in place
   * @return the normalized label text, or a random UUID string for an unlabeled scope
   */
  def enterLabeledScope(
      beginLabelCtx: Option[BeginLabelContext],
      endLabelCtx: Option[EndLabelContext],
      seenLabels: mutable.Set[String]): String = {

    checkLabels(beginLabelCtx, endLabelCtx)

    beginLabelCtx match {
      case Some(bl) =>
        val txt = normalizedLabel(bl)
        if (seenLabels.contains(txt)) {
          withOrigin(bl) {
            throw SqlScriptingErrors.duplicateLabels(CurrentOrigin.get, txt)
          }
        }
        // Store exactly the text that is checked above and removed in exitLabeledScope.
        // Storing the raw-case text would let a case-insensitive duplicate slip through
        // and would leave a stale entry behind when the scope is exited.
        seenLabels.add(txt)
        txt
      case None =>
        // Do not add the label to the seenLabels set if it is not defined.
        java.util.UUID.randomUUID.toString.toLowerCase(Locale.ROOT)
    }
  }

  /**
   * Exit a labeled scope.
   * If the label is defined, its normalized text is removed from seenLabels.
   *
   * @param beginLabelCtx begin label of the scope being exited, if any
   * @param seenLabels labels of the scopes that are currently open; mutated in place
   */
  def exitLabeledScope(
      beginLabelCtx: Option[BeginLabelContext],
      seenLabels: mutable.Set[String]): Unit = {
    beginLabelCtx.foreach(bl => seenLabels.remove(normalizedLabel(bl)))
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,14 @@ import org.apache.spark.sql.exceptions.SqlScriptingException
*/
private[sql] object SqlScriptingErrors {

/**
 * Builds the exception raised when a label is declared while a label with the same
 * text is already in use (error class LABEL_ALREADY_EXISTS).
 *
 * @param origin position information attached to the exception
 * @param label text of the offending label; passed through `toSQLId` before being
 *              placed in the `label` message parameter
 * @return the exception to throw; this method itself does not throw
 */
def duplicateLabels(origin: Origin, label: String): Throwable = {
  val params = Map("label" -> toSQLId(label))
  new SqlScriptingException(
    errorClass = "LABEL_ALREADY_EXISTS",
    cause = null,
    origin = origin,
    messageParameters = params)
}

def labelsMismatch(origin: Origin, beginLabel: String, endLabel: String): Throwable = {
new SqlScriptingException(
origin = origin,
Expand Down
Loading