Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions common/utils/src/main/resources/error/error-conditions.json
Original file line number Diff line number Diff line change
Expand Up @@ -3379,6 +3379,12 @@
],
"sqlState" : "42K0L"
},
"LABEL_ALREADY_EXISTS" : {
"message" : [
"The label <label> already exists. Choose another name or rename the existing label."
],
"sqlState" : "42K0L"
},
"LOAD_DATA_PATH_NOT_EXISTS" : {
"message" : [
"LOAD DATA input path does not exist: <path>."
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,17 +142,18 @@ class AstBuilder extends DataTypeAstBuilder
}

override def visitSingleCompoundStatement(ctx: SingleCompoundStatementContext): CompoundBody = {
visit(ctx.beginEndCompoundBlock()).asInstanceOf[CompoundBody]
val labelCtx = new SqlScriptingLabelContext()
visitBeginEndCompoundBlockImpl(ctx.beginEndCompoundBlock(), labelCtx)
}

private def visitCompoundBodyImpl(
ctx: CompoundBodyContext,
label: Option[String],
allowVarDeclare: Boolean): CompoundBody = {
allowVarDeclare: Boolean,
labelCtx: SqlScriptingLabelContext): CompoundBody = {
val buff = ListBuffer[CompoundPlanStatement]()
ctx.compoundStatements.forEach(compoundStatement => {
buff += visit(compoundStatement).asInstanceOf[CompoundPlanStatement]
})
ctx.compoundStatements.forEach(
compoundStatement => buff += visitCompoundStatementImpl(compoundStatement, labelCtx))

val compoundStatements = buff.toList

Expand Down Expand Up @@ -184,90 +185,104 @@ class AstBuilder extends DataTypeAstBuilder
CompoundBody(buff.toSeq, label)
}


private def generateLabelText(
beginLabelCtx: Option[BeginLabelContext],
endLabelCtx: Option[EndLabelContext]): String = {

(beginLabelCtx, endLabelCtx) match {
case (Some(bl: BeginLabelContext), Some(el: EndLabelContext))
if bl.multipartIdentifier().getText.nonEmpty &&
bl.multipartIdentifier().getText.toLowerCase(Locale.ROOT) !=
el.multipartIdentifier().getText.toLowerCase(Locale.ROOT) =>
withOrigin(bl) {
throw SqlScriptingErrors.labelsMismatch(
CurrentOrigin.get,
bl.multipartIdentifier().getText,
el.multipartIdentifier().getText)
}
case (None, Some(el: EndLabelContext)) =>
withOrigin(el) {
throw SqlScriptingErrors.endLabelWithoutBeginLabel(
CurrentOrigin.get, el.multipartIdentifier().getText)
}
case _ =>
}

beginLabelCtx.map(_.multipartIdentifier().getText)
.getOrElse(java.util.UUID.randomUUID.toString).toLowerCase(Locale.ROOT)
}

override def visitBeginEndCompoundBlock(ctx: BeginEndCompoundBlockContext): CompoundBody = {
val labelText = generateLabelText(Option(ctx.beginLabel()), Option(ctx.endLabel()))
visitCompoundBodyImpl(ctx.compoundBody(), Some(labelText), allowVarDeclare = true)
}

override def visitCompoundBody(ctx: CompoundBodyContext): CompoundBody = {
visitCompoundBodyImpl(ctx, None, allowVarDeclare = false)
private def visitBeginEndCompoundBlockImpl(
ctx: BeginEndCompoundBlockContext,
labelCtx: SqlScriptingLabelContext): CompoundBody = {
val labelText =
labelCtx.enterLabeledScope(Option(ctx.beginLabel()), Option(ctx.endLabel()))
val body = visitCompoundBodyImpl(
ctx.compoundBody(),
Some(labelText),
allowVarDeclare = true,
labelCtx
)
labelCtx.exitLabeledScope(Option(ctx.beginLabel()))
body
}

override def visitCompoundStatement(ctx: CompoundStatementContext): CompoundPlanStatement =
private def visitCompoundStatementImpl(
ctx: CompoundStatementContext,
labelCtx: SqlScriptingLabelContext): CompoundPlanStatement =
withOrigin(ctx) {
Option(ctx.statement().asInstanceOf[ParserRuleContext])
.orElse(Option(ctx.setStatementWithOptionalVarKeyword().asInstanceOf[ParserRuleContext]))
.map { s =>
SingleStatement(parsedPlan = visit(s).asInstanceOf[LogicalPlan])
}.getOrElse {
visitChildren(ctx).asInstanceOf[CompoundPlanStatement]
if (ctx.getChildCount == 1) {
ctx.getChild(0) match {
case compoundBodyContext: BeginEndCompoundBlockContext =>
visitBeginEndCompoundBlockImpl(compoundBodyContext, labelCtx)
case whileStmtContext: WhileStatementContext =>
visitWhileStatementImpl(whileStmtContext, labelCtx)
case repeatStmtContext: RepeatStatementContext =>
visitRepeatStatementImpl(repeatStmtContext, labelCtx)
case loopStatementContext: LoopStatementContext =>
visitLoopStatementImpl(loopStatementContext, labelCtx)
case ifElseStmtContext: IfElseStatementContext =>
visitIfElseStatementImpl(ifElseStmtContext, labelCtx)
case searchedCaseContext: SearchedCaseStatementContext =>
visitSearchedCaseStatementImpl(searchedCaseContext, labelCtx)
case simpleCaseContext: SimpleCaseStatementContext =>
visitSimpleCaseStatementImpl(simpleCaseContext, labelCtx)
case stmt => visit(stmt).asInstanceOf[CompoundPlanStatement]
}
} else {
null
}
}
}

override def visitIfElseStatement(ctx: IfElseStatementContext): IfElseStatement = {
private def visitIfElseStatementImpl(
ctx: IfElseStatementContext,
labelCtx: SqlScriptingLabelContext): IfElseStatement = {
IfElseStatement(
conditions = ctx.booleanExpression().asScala.toList.map(boolExpr => withOrigin(boolExpr) {
SingleStatement(
Project(
Seq(Alias(expression(boolExpr), "condition")()),
OneRowRelation()))
}),
conditionalBodies = ctx.conditionalBodies.asScala.toList.map(body => visitCompoundBody(body)),
elseBody = Option(ctx.elseBody).map(body => visitCompoundBody(body))
conditionalBodies = ctx.conditionalBodies.asScala.toList.map(
body => visitCompoundBodyImpl(body, None, allowVarDeclare = false, labelCtx)
),
elseBody = Option(ctx.elseBody).map(
body => visitCompoundBodyImpl(body, None, allowVarDeclare = false, labelCtx)
)
)
}

override def visitWhileStatement(ctx: WhileStatementContext): WhileStatement = {
val labelText = generateLabelText(Option(ctx.beginLabel()), Option(ctx.endLabel()))
private def visitWhileStatementImpl(
ctx: WhileStatementContext,
labelCtx: SqlScriptingLabelContext): WhileStatement = {
val labelText =
labelCtx.enterLabeledScope(Option(ctx.beginLabel()), Option(ctx.endLabel()))
val boolExpr = ctx.booleanExpression()

val condition = withOrigin(boolExpr) {
SingleStatement(
Project(
Seq(Alias(expression(boolExpr), "condition")()),
OneRowRelation()))}
val body = visitCompoundBody(ctx.compoundBody())
val body = visitCompoundBodyImpl(ctx.compoundBody(), None, allowVarDeclare = false, labelCtx)
labelCtx.exitLabeledScope(Option(ctx.beginLabel()))

WhileStatement(condition, body, Some(labelText))
}

override def visitSearchedCaseStatement(ctx: SearchedCaseStatementContext): CaseStatement = {
private def visitSearchedCaseStatementImpl(
ctx: SearchedCaseStatementContext,
labelCtx: SqlScriptingLabelContext): CaseStatement = {
val conditions = ctx.conditions.asScala.toList.map(boolExpr => withOrigin(boolExpr) {
SingleStatement(
Project(
Seq(Alias(expression(boolExpr), "condition")()),
OneRowRelation()))
})
val conditionalBodies =
ctx.conditionalBodies.asScala.toList.map(body => visitCompoundBody(body))
ctx.conditionalBodies.asScala.toList.map(
body => visitCompoundBodyImpl(body, None, allowVarDeclare = false, labelCtx)
)

if (conditions.length != conditionalBodies.length) {
throw SparkException.internalError(
Expand All @@ -278,10 +293,14 @@ class AstBuilder extends DataTypeAstBuilder
CaseStatement(
conditions = conditions,
conditionalBodies = conditionalBodies,
elseBody = Option(ctx.elseBody).map(body => visitCompoundBody(body)))
elseBody = Option(ctx.elseBody).map(
body => visitCompoundBodyImpl(body, None, allowVarDeclare = false, labelCtx)
))
}

override def visitSimpleCaseStatement(ctx: SimpleCaseStatementContext): CaseStatement = {
private def visitSimpleCaseStatementImpl(
ctx: SimpleCaseStatementContext,
labelCtx: SqlScriptingLabelContext): CaseStatement = {
// uses EqualTo to compare the case variable(the main case expression)
// to the WHEN clause expressions
val conditions = ctx.conditionExpressions.asScala.toList.map(expr => withOrigin(expr) {
Expand All @@ -291,7 +310,9 @@ class AstBuilder extends DataTypeAstBuilder
OneRowRelation()))
})
val conditionalBodies =
ctx.conditionalBodies.asScala.toList.map(body => visitCompoundBody(body))
ctx.conditionalBodies.asScala.toList.map(
body => visitCompoundBodyImpl(body, None, allowVarDeclare = false, labelCtx)
)

if (conditions.length != conditionalBodies.length) {
throw SparkException.internalError(
Expand All @@ -302,19 +323,25 @@ class AstBuilder extends DataTypeAstBuilder
CaseStatement(
conditions = conditions,
conditionalBodies = conditionalBodies,
elseBody = Option(ctx.elseBody).map(body => visitCompoundBody(body)))
elseBody = Option(ctx.elseBody).map(
body => visitCompoundBodyImpl(body, None, allowVarDeclare = false, labelCtx)
))
}

override def visitRepeatStatement(ctx: RepeatStatementContext): RepeatStatement = {
val labelText = generateLabelText(Option(ctx.beginLabel()), Option(ctx.endLabel()))
private def visitRepeatStatementImpl(
ctx: RepeatStatementContext,
labelCtx: SqlScriptingLabelContext): RepeatStatement = {
val labelText =
labelCtx.enterLabeledScope(Option(ctx.beginLabel()), Option(ctx.endLabel()))
val boolExpr = ctx.booleanExpression()

val condition = withOrigin(boolExpr) {
SingleStatement(
Project(
Seq(Alias(expression(boolExpr), "condition")()),
OneRowRelation()))}
val body = visitCompoundBody(ctx.compoundBody())
val body = visitCompoundBodyImpl(ctx.compoundBody(), None, allowVarDeclare = false, labelCtx)
labelCtx.exitLabeledScope(Option(ctx.beginLabel()))

RepeatStatement(condition, body, Some(labelText))
}
Expand Down Expand Up @@ -377,9 +404,13 @@ class AstBuilder extends DataTypeAstBuilder
CurrentOrigin.get, labelText, "ITERATE")
}

override def visitLoopStatement(ctx: LoopStatementContext): LoopStatement = {
val labelText = generateLabelText(Option(ctx.beginLabel()), Option(ctx.endLabel()))
val body = visitCompoundBody(ctx.compoundBody())
private def visitLoopStatementImpl(
ctx: LoopStatementContext,
labelCtx: SqlScriptingLabelContext): LoopStatement = {
val labelText =
labelCtx.enterLabeledScope(Option(ctx.beginLabel()), Option(ctx.endLabel()))
val body = visitCompoundBodyImpl(ctx.compoundBody(), None, allowVarDeclare = false, labelCtx)
labelCtx.exitLabeledScope(Option(ctx.beginLabel()))

LoopStatement(body, Some(labelText))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,17 @@ package org.apache.spark.sql.catalyst.parser
import java.util
import java.util.Locale

import scala.collection.mutable.Set

import org.antlr.v4.runtime.{ParserRuleContext, Token}
import org.antlr.v4.runtime.misc.Interval
import org.antlr.v4.runtime.tree.{ParseTree, TerminalNodeImpl}

import org.apache.spark.sql.catalyst.parser.SqlBaseParser.{BeginLabelContext, EndLabelContext}
import org.apache.spark.sql.catalyst.trees.CurrentOrigin
import org.apache.spark.sql.catalyst.util.SparkParserUtils
import org.apache.spark.sql.errors.QueryParsingErrors
import org.apache.spark.sql.catalyst.util.SparkParserUtils.withOrigin
import org.apache.spark.sql.errors.{QueryParsingErrors, SqlScriptingErrors}

/**
* A collection of utility methods for use during the parsing process.
Expand Down Expand Up @@ -134,3 +139,80 @@ object ParserUtils extends SparkParserUtils {
sb.toString()
}
}

class SqlScriptingLabelContext {
/** Set to keep track of labels seen so far */
private val seenLabels = Set[String]()

/**
* Check if the beginLabelCtx and endLabelCtx match.
* If the labels are defined, they must follow rules:
* - If both labels exist, they must match.
* - Begin label must exist if end label exists.
*/
private def checkLabels(
beginLabelCtx: Option[BeginLabelContext],
endLabelCtx: Option[EndLabelContext]) : Unit = {
(beginLabelCtx, endLabelCtx) match {
case (Some(bl: BeginLabelContext), Some(el: EndLabelContext))
if bl.multipartIdentifier().getText.toLowerCase(Locale.ROOT) !=
el.multipartIdentifier().getText.toLowerCase(Locale.ROOT) =>
withOrigin(bl) {
throw SqlScriptingErrors.labelsMismatch(
CurrentOrigin.get,
bl.multipartIdentifier().getText,
el.multipartIdentifier().getText)
}
case (None, Some(el: EndLabelContext)) =>
withOrigin(el) {
throw SqlScriptingErrors.endLabelWithoutBeginLabel(
CurrentOrigin.get, el.multipartIdentifier().getText)
}
case _ =>
}
}

/** Check if the label is defined. */
private def isLabelDefined(beginLabelCtx: Option[BeginLabelContext]): Boolean = {
beginLabelCtx.map(_.multipartIdentifier().getText).isDefined
}

/**
* Enter a labeled scope and return the label text.
* If the label is defined, it will be returned and added to seenLabels.
* If the label is not defined, a random UUID will be returned.
*/
def enterLabeledScope(
beginLabelCtx: Option[BeginLabelContext],
endLabelCtx: Option[EndLabelContext]): String = {

// Check if this label already exists in parent scopes.
checkLabels(beginLabelCtx, endLabelCtx)

// Get label text and add it to seenLabels.
val labelText = if (isLabelDefined(beginLabelCtx)) {
val txt = beginLabelCtx.get.multipartIdentifier().getText.toLowerCase(Locale.ROOT)
if (seenLabels.contains(txt)) {
withOrigin(beginLabelCtx.get) {
throw SqlScriptingErrors.duplicateLabels(CurrentOrigin.get, txt)
}
}
seenLabels.add(beginLabelCtx.get.multipartIdentifier().getText)
txt
} else {
// Do not add the label to the seenLabels set if it is not defined.
java.util.UUID.randomUUID.toString.toLowerCase(Locale.ROOT)
}
labelText
}

/**
* Exit a labeled scope.
* If the label is defined, it will be removed from seenLabels.
*/
def exitLabeledScope(beginLabelCtx: Option[BeginLabelContext]): Unit = {
if (isLabelDefined(beginLabelCtx)) {
seenLabels.remove(beginLabelCtx.get.multipartIdentifier().getText.toLowerCase(Locale.ROOT))
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,14 @@ import org.apache.spark.sql.exceptions.SqlScriptingException
*/
private[sql] object SqlScriptingErrors {

def duplicateLabels(origin: Origin, label: String): Throwable = {
new SqlScriptingException(
origin = origin,
errorClass = "LABEL_ALREADY_EXISTS",
cause = null,
messageParameters = Map("label" -> toSQLId(label)))
}

def labelsMismatch(origin: Origin, beginLabel: String, endLabel: String): Throwable = {
new SqlScriptingException(
origin = origin,
Expand Down
Loading