Changes from 1 commit
80 commits
40ffdef
[SPARK-50250][SQL] Assign appropriate error condition for `_LEGACY_ER…
itholic Nov 13, 2024
ede05fa
[SPARK-50248][SQL] Assign appropriate error condition for `_LEGACY_ER…
itholic Nov 13, 2024
6fb1d43
[SPARK-50246][SQL] Assign appropriate error condition for `_LEGACY_ER…
itholic Nov 13, 2024
898bff2
[SPARK-50245][SQL][TESTS] Extended CollationSuite and added tests whe…
vladanvasi-db Nov 13, 2024
bd94419
[SPARK-50226][SQL] Correct MakeDTInterval and MakeYMInterval to catch…
gotocoding-DB Nov 13, 2024
bc9b259
[SPARK-50066][SQL] Codegen Support for `SchemaOfXml` (by Invoke & Run…
panbingkun Nov 13, 2024
558fc89
[SPARK-49611][SQL][FOLLOW-UP] Make collations TVF consistent and retu…
mihailomilosevic2001 Nov 13, 2024
7b1b450
Revert [SPARK-50215][SQL] Refactored StringType pattern matching in j…
vladanvasi-db Nov 13, 2024
87ad4b4
[SPARK-50139][INFRA][SS][PYTHON] Introduce scripts to re-generate and…
LuciferYang Nov 13, 2024
05508cf
[SPARK-42838][SQL] Assign a name to the error class _LEGACY_ERROR_TEM…
mihailomilosevic2001 Nov 13, 2024
5cc60f4
[SPARK-50300][BUILD] Use mirror host instead of `archive.apache.org`
dongjoon-hyun Nov 13, 2024
33378a6
[SPARK-50304][INFRA] Remove `(any|empty).proto` from RAT exclusion
dongjoon-hyun Nov 14, 2024
891f694
[SPARK-50306][PYTHON][CONNECT] Support Python 3.13 in Spark Connect
HyukjinKwon Nov 14, 2024
2fd4702
[SPARK-49913][SQL] Add check for unique label names in nested labeled…
miland-db Nov 14, 2024
6bee268
[SPARK-50299][BUILD] Upgrade jupiter-interface to 0.13.1 and Junit5 t…
LuciferYang Nov 14, 2024
09d6b32
[SPARK-48755][DOCS][PYTHON][FOLLOWUP] Add PySpark doc for `transformW…
itholic Nov 14, 2024
0b1b676
[SPARK-50092][SQL] Fix PostgreSQL connector behaviour for multidimens…
PetarVasiljevic-DB Nov 14, 2024
aea9e87
[SPARK-50291][PYTHON] Standardize verifySchema parameter of createDat…
xinrong-meng Nov 14, 2024
c1968a1
[SPARK-50216][SQL][TESTS] Update `CollationBenchmark` to invoke `coll…
stevomitric Nov 14, 2024
0aee601
[SPARK-50153][SQL] Add `name` to `RuleExecutor` to make printing `Que…
panbingkun Nov 14, 2024
c2343f7
[SPARK-45265][SQL] Support Hive 4.0 metastore
yaooqinn Nov 14, 2024
e0a83f6
[SPARK-50317][BUILD] Upgrade ORC to 2.0.3
dongjoon-hyun Nov 14, 2024
c90efae
[SPARK-50318][SQL] Add IntervalUtils.makeYearMonthInterval to dedupli…
gotocoding-DB Nov 15, 2024
3237885
[SPARK-50312][SQL] SparkThriftServer createServer parameter passing e…
CuiYanxiang Nov 15, 2024
e615e3f
[SPARK-50049][SQL] Support custom driver metrics in writing to v2 table
cloud-fan Nov 15, 2024
3f5e846
[SPARK-50237][SQL] Assign appropriate error condition for `_LEGACY_ER…
itholic Nov 15, 2024
cf90271
[MINOR] Fix code style for if/for/while statements
exmy Nov 15, 2024
cc81ed0
[SPARK-50325][SQL] Factor out alias resolution to be reused in the si…
vladimirg-db Nov 15, 2024
d317002
[SPARK-50322][SQL] Fix parameterized identifier in a sub-query
MaxGekk Nov 15, 2024
77e006f
[SPARK-50327][SQL] Factor out function resolution to be reused in the…
vladimirg-db Nov 15, 2024
11e4706
[SPARK-50320][CORE] Make `--remote` an official option by removing `e…
dongjoon-hyun Nov 15, 2024
007c31d
[SPARK-50236][SQL] Assign appropriate error condition for `_LEGACY_ER…
itholic Nov 15, 2024
281a8e1
[SPARK-50309][DOCS] Document `SQL Pipe` Syntax
dtenedor Nov 15, 2024
b626528
[SPARK-50313][SQL][TESTS] Enable ANSI in SQL *SQLQueryTestSuite by de…
yaooqinn Nov 18, 2024
a01856d
[SPARK-50330][SQL] Add hints to Sort and Window nodes
agubichev Nov 18, 2024
8b2d032
[SPARK-45265][SQL][BUILD][FOLLOWUP] Add `-Xss64m` for Maven testing o…
LuciferYang Nov 18, 2024
05750de
[MINOR][PYTHON][DOCS] Fix the type hint of `histogram_numeric`
zhengruifeng Nov 18, 2024
400a8d3
Revert "[SPARK-49787][SQL] Cast between UDT and other types"
cloud-fan Nov 18, 2024
fa36e8b
[SPARK-50335][PYTHON][DOCS] Refine docstrings for window/aggregation …
zhengruifeng Nov 19, 2024
b61411d
[SPARK-50328][INFRA] Add a separate docker file for SparkR
zhengruifeng Nov 19, 2024
e1477a3
[SPARK-50298][PYTHON][CONNECT] Implement verifySchema parameter of cr…
xinrong-meng Nov 19, 2024
6d47981
[SPARK-50331][INFRA] Add a daily test for PySpark on MacOS
LuciferYang Nov 19, 2024
5a57efd
[SPARK-50313][SQL][TESTS][FOLLOWUP] Restore some tests in *SQLQueryTe…
yaooqinn Nov 19, 2024
b74aa8c
[SPARK-50340][SQL] Unwrap UDT in INSERT input query
cloud-fan Nov 19, 2024
87a5b37
[SPARK-50313][SQL][TESTS][FOLLOWUP] Regenerate golden files for Java 21
LuciferYang Nov 19, 2024
f1b68d8
[SPARK-50315][SQL] Support custom metrics for V1Fallback writes
olaky Nov 19, 2024
19509d0
Revert "[SPARK-49002][SQL] Consistently handle invalid locations in W…
cloud-fan Nov 19, 2024
37497e6
[SPARK-50335][PYTHON][DOCS][FOLLOW-UP] Make percentile doctests more …
zhengruifeng Nov 20, 2024
c149dcb
[SPARK-50352][PYTHON][DOCS] Refine docstrings for window/aggregation …
zhengruifeng Nov 20, 2024
8791767
[SPARK-48344][SQL] Prepare SQL Scripting for addition of Execution Fr…
miland-db Nov 20, 2024
b7cf448
[SPARK-49550][FOLLOWUP][SQL][DOC] Switch Hadoop to 3.4.1 in IsolatedC…
pan3793 Nov 20, 2024
2185f3c
[SPARK-50359][PYTHON] Upgrade PyArrow to 18.0
zhengruifeng Nov 20, 2024
0157778
[SPARK-50358][SQL][TESTS] Update postgres docker image to 17.1
panbingkun Nov 20, 2024
b582dac
[MINOR][DOCS] Fix a HTML/Markdown syntax error in sql-migration-guide.md
yaooqinn Nov 20, 2024
19b8250
[SPARK-50331][INFRA][FOLLOW-UP] Skip Torch/DeepSpeed tests in MacOS P…
zhengruifeng Nov 20, 2024
7a4f3c4
[SPARK-50345][BUILD] Upgrade Kafka to 3.9.0
panbingkun Nov 20, 2024
3151d97
[SPARK-49801][INFRA][FOLLOWUP] Sync pandas version in release environ…
yaooqinn Nov 20, 2024
23f276f
[SPARK-50353][SQL] Refactor ResolveSQLOnFile
mihailoale-db Nov 20, 2024
533b8ca
[SPARK-50363][PYTHON][DOCS] Refine the docstring for datetime functio…
zhengruifeng Nov 20, 2024
81a56df
[SPARK-50362][PYTHON][ML] Skip `CrossValidatorTests` if `torch/torche…
LuciferYang Nov 20, 2024
6ee53da
[SPARK-50258][SQL] Fix output column order changed issue after AQE op…
wangyum Nov 20, 2024
30d0b01
[SPARK-50364][SQL] Implement serialization for LocalDateTime type in …
krm95 Nov 20, 2024
ad46db4
[SPARK-50130][SQL][FOLLOWUP] Make Encoder generation lazy
ueshin Nov 20, 2024
a409199
[SPARK-50376][PYTHON][ML][TESTS] Centralize the dependency check in M…
zhengruifeng Nov 21, 2024
3bc374d
[SPARK-50333][SQL] Codegen Support for `CsvToStructs` (by Invoke & Ru…
panbingkun Nov 21, 2024
95faa02
[SPARK-49490][SQL] Add benchmarks for initCap
mrk-andreev Nov 21, 2024
ee21e6b
[SPARK-50113][CONNECT][PYTHON][TESTS] Add `@remote_only` to check the…
itholic Nov 21, 2024
0f1e410
[SPARK-50016][SQL] Assign appropriate error condition for `_LEGACY_ER…
itholic Nov 21, 2024
b05ef45
[SPARK-50175][SQL] Change collation precedence calculation
stefankandic Nov 21, 2024
fbf255e
[SPARK-50379][SQL] Fix DayTimeIntevalType handling in WindowExecBase
mihailomilosevic2001 Nov 21, 2024
cbb16b9
[MINOR][DOCS] Fix miss semicolon on create table example sql
camilesing Nov 21, 2024
f2de888
[MINOR][DOCS] Remove wrong and ambiguous default statement in datetim…
yaooqinn Nov 21, 2024
229b1b8
[SPARK-50375][BUILD] Upgrade `commons-io` to 2.18.0
panbingkun Nov 21, 2024
136c722
[SPARK-50334][SQL] Extract common logic for reading the descriptor of…
panbingkun Nov 21, 2024
2e1c3dc
[SPARK-50087] Robust handling of boolean expressions in CASE WHEN for…
cloud-fan Nov 21, 2024
2d09ef2
[SPARK-50381][CORE] Support `spark.master.rest.maxThreads`
dongjoon-hyun Nov 21, 2024
69324bd
Merge branch 'master' into pr48820
ueshin Nov 21, 2024
349df78
Fix.
ueshin Nov 21, 2024
1079339
Fix.
ueshin Nov 21, 2024
c6b0651
Fix.
ueshin Nov 22, 2024
[SPARK-49913][SQL] Add check for unique label names in nested labeled scopes

### What changes were proposed in this pull request?
We are introducing checks for unique label names. The new rules for label names are:
- A label must not have the same name as any label in an enclosing (surrounding) scope.
- A label may have the same name as labels in sibling scopes at the same nesting level.

**Valid** code:
```
BEGIN
  lbl: BEGIN
    SELECT 1;
  END;

  lbl: BEGIN
    SELECT 2;
  END;

  BEGIN
    lbl: WHILE 1=1 DO
      LEAVE lbl;
    END WHILE;
  END;
END
```

**Invalid** code:
```
BEGIN
  lbl: BEGIN
    lbl: BEGIN
      SELECT 1;
    END;
  END;
END
```

#### Design explanation:

Even though ANTLR _Listeners_ provide `enterRule` and `exitRule` methods that could check labels before visiting a node and remove them from `seenLabels` afterwards, we favor the explicit visitor approach: it requires minimal changes, keeps the code more compact, and avoids dependency issues.

Additionally, with listeners the label text would have to be generated in two places, and we wanted to avoid that duplicated logic:
- `enterRule`
- `visitRule`
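
As intuition for why sibling labels are legal but shadowing is not, here is a minimal self-contained sketch of the enter/exit discipline that `SqlScriptingLabelContext` applies to `seenLabels` (class and method names below are illustrative only, not Spark's actual API):

```scala
import scala.collection.mutable

object LabelScopeDemo extends App {
  // Hypothetical stand-in for the seenLabels bookkeeping in SqlScriptingLabelContext.
  class LabelScope {
    private val seen = mutable.Set.empty[String]

    def enter(label: String): Unit = {
      val normalized = label.toLowerCase
      // A nested duplicate fails: the enclosing scope has not exited yet.
      require(!seen.contains(normalized), s"Label '$normalized' already exists")
      seen += normalized
    }

    def exit(label: String): Unit = seen -= label.toLowerCase
  }

  val ctx = new LabelScope
  ctx.enter("lbl"); ctx.exit("lbl") // first sibling scope
  ctx.enter("lbl"); ctx.exit("lbl") // second sibling: legal, the name was released on exit
  ctx.enter("lbl")                  // outer scope still open...
  try ctx.enter("lbl")              // ...so a nested duplicate throws
  catch { case e: IllegalArgumentException => println(e.getMessage) }
}
```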

### Why are the changes needed?
These checks will be needed when we release Local Scoped Variables for SQL Scripting, so that users can target variables from outer scopes when those variables are shadowed.

### How was this patch tested?
New unit tests in `SqlScriptingParserSuite.scala`.
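
For illustration, a sketch of what such a test could look like (helper names like `parseScript` and `checkError`, and the `condition` parameter, are assumed from the suite's conventions, not copied from this patch):

```scala
test("nested begin-end blocks with duplicate labels fail to parse") {
  val sqlScriptText =
    """BEGIN
      |  lbl: BEGIN
      |    lbl: BEGIN
      |      SELECT 1;
      |    END;
      |  END;
      |END""".stripMargin
  checkError(
    // Parsing should fail as soon as the inner block reuses the enclosing label.
    exception = intercept[SqlScriptingException] {
      parseScript(sqlScriptText)
    },
    condition = "LABEL_ALREADY_EXISTS",
    parameters = Map("label" -> "`lbl`"))
}
```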

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes apache#48795 from miland-db/milan-dankovic_data/unique_labels_scripting.

Authored-by: Milan Dankovic <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
miland-db authored and cloud-fan committed Nov 14, 2024
commit 2fd47026371488b9409750cba6b697cc61ea7371
6 changes: 6 additions & 0 deletions common/utils/src/main/resources/error/error-conditions.json
@@ -3411,6 +3411,12 @@
],
"sqlState" : "42K0L"
},
"LABEL_ALREADY_EXISTS" : {
"message" : [
"The label <label> already exists. Choose another name or rename the existing label."
],
"sqlState" : "42K0L"
},
"LOAD_DATA_PATH_NOT_EXISTS" : {
"message" : [
"LOAD DATA input path does not exist: <path>."
@@ -142,17 +142,18 @@ class AstBuilder extends DataTypeAstBuilder
}

override def visitSingleCompoundStatement(ctx: SingleCompoundStatementContext): CompoundBody = {
visit(ctx.beginEndCompoundBlock()).asInstanceOf[CompoundBody]
val labelCtx = new SqlScriptingLabelContext()
visitBeginEndCompoundBlockImpl(ctx.beginEndCompoundBlock(), labelCtx)
}

private def visitCompoundBodyImpl(
ctx: CompoundBodyContext,
label: Option[String],
allowVarDeclare: Boolean): CompoundBody = {
allowVarDeclare: Boolean,
labelCtx: SqlScriptingLabelContext): CompoundBody = {
val buff = ListBuffer[CompoundPlanStatement]()
ctx.compoundStatements.forEach(compoundStatement => {
buff += visit(compoundStatement).asInstanceOf[CompoundPlanStatement]
})
ctx.compoundStatements.forEach(
compoundStatement => buff += visitCompoundStatementImpl(compoundStatement, labelCtx))

val compoundStatements = buff.toList

@@ -184,90 +185,104 @@ class AstBuilder extends DataTypeAstBuilder
CompoundBody(buff.toSeq, label)
}


private def generateLabelText(
beginLabelCtx: Option[BeginLabelContext],
endLabelCtx: Option[EndLabelContext]): String = {

(beginLabelCtx, endLabelCtx) match {
case (Some(bl: BeginLabelContext), Some(el: EndLabelContext))
if bl.multipartIdentifier().getText.nonEmpty &&
bl.multipartIdentifier().getText.toLowerCase(Locale.ROOT) !=
el.multipartIdentifier().getText.toLowerCase(Locale.ROOT) =>
withOrigin(bl) {
throw SqlScriptingErrors.labelsMismatch(
CurrentOrigin.get,
bl.multipartIdentifier().getText,
el.multipartIdentifier().getText)
}
case (None, Some(el: EndLabelContext)) =>
withOrigin(el) {
throw SqlScriptingErrors.endLabelWithoutBeginLabel(
CurrentOrigin.get, el.multipartIdentifier().getText)
}
case _ =>
}

beginLabelCtx.map(_.multipartIdentifier().getText)
.getOrElse(java.util.UUID.randomUUID.toString).toLowerCase(Locale.ROOT)
}

override def visitBeginEndCompoundBlock(ctx: BeginEndCompoundBlockContext): CompoundBody = {
val labelText = generateLabelText(Option(ctx.beginLabel()), Option(ctx.endLabel()))
visitCompoundBodyImpl(ctx.compoundBody(), Some(labelText), allowVarDeclare = true)
}

override def visitCompoundBody(ctx: CompoundBodyContext): CompoundBody = {
visitCompoundBodyImpl(ctx, None, allowVarDeclare = false)
private def visitBeginEndCompoundBlockImpl(
ctx: BeginEndCompoundBlockContext,
labelCtx: SqlScriptingLabelContext): CompoundBody = {
val labelText =
labelCtx.enterLabeledScope(Option(ctx.beginLabel()), Option(ctx.endLabel()))
val body = visitCompoundBodyImpl(
ctx.compoundBody(),
Some(labelText),
allowVarDeclare = true,
labelCtx
)
labelCtx.exitLabeledScope(Option(ctx.beginLabel()))
body
}

override def visitCompoundStatement(ctx: CompoundStatementContext): CompoundPlanStatement =
private def visitCompoundStatementImpl(
ctx: CompoundStatementContext,
labelCtx: SqlScriptingLabelContext): CompoundPlanStatement =
withOrigin(ctx) {
Option(ctx.statement().asInstanceOf[ParserRuleContext])
.orElse(Option(ctx.setStatementWithOptionalVarKeyword().asInstanceOf[ParserRuleContext]))
.map { s =>
SingleStatement(parsedPlan = visit(s).asInstanceOf[LogicalPlan])
}.getOrElse {
visitChildren(ctx).asInstanceOf[CompoundPlanStatement]
if (ctx.getChildCount == 1) {
ctx.getChild(0) match {
case compoundBodyContext: BeginEndCompoundBlockContext =>
visitBeginEndCompoundBlockImpl(compoundBodyContext, labelCtx)
case whileStmtContext: WhileStatementContext =>
visitWhileStatementImpl(whileStmtContext, labelCtx)
case repeatStmtContext: RepeatStatementContext =>
visitRepeatStatementImpl(repeatStmtContext, labelCtx)
case loopStatementContext: LoopStatementContext =>
visitLoopStatementImpl(loopStatementContext, labelCtx)
case ifElseStmtContext: IfElseStatementContext =>
visitIfElseStatementImpl(ifElseStmtContext, labelCtx)
case searchedCaseContext: SearchedCaseStatementContext =>
visitSearchedCaseStatementImpl(searchedCaseContext, labelCtx)
case simpleCaseContext: SimpleCaseStatementContext =>
visitSimpleCaseStatementImpl(simpleCaseContext, labelCtx)
case stmt => visit(stmt).asInstanceOf[CompoundPlanStatement]
}
} else {
null
}
}
}

override def visitIfElseStatement(ctx: IfElseStatementContext): IfElseStatement = {
private def visitIfElseStatementImpl(
ctx: IfElseStatementContext,
labelCtx: SqlScriptingLabelContext): IfElseStatement = {
IfElseStatement(
conditions = ctx.booleanExpression().asScala.toList.map(boolExpr => withOrigin(boolExpr) {
SingleStatement(
Project(
Seq(Alias(expression(boolExpr), "condition")()),
OneRowRelation()))
}),
conditionalBodies = ctx.conditionalBodies.asScala.toList.map(body => visitCompoundBody(body)),
elseBody = Option(ctx.elseBody).map(body => visitCompoundBody(body))
conditionalBodies = ctx.conditionalBodies.asScala.toList.map(
body => visitCompoundBodyImpl(body, None, allowVarDeclare = false, labelCtx)
),
elseBody = Option(ctx.elseBody).map(
body => visitCompoundBodyImpl(body, None, allowVarDeclare = false, labelCtx)
)
)
}

override def visitWhileStatement(ctx: WhileStatementContext): WhileStatement = {
val labelText = generateLabelText(Option(ctx.beginLabel()), Option(ctx.endLabel()))
private def visitWhileStatementImpl(
ctx: WhileStatementContext,
labelCtx: SqlScriptingLabelContext): WhileStatement = {
val labelText =
labelCtx.enterLabeledScope(Option(ctx.beginLabel()), Option(ctx.endLabel()))
val boolExpr = ctx.booleanExpression()

val condition = withOrigin(boolExpr) {
SingleStatement(
Project(
Seq(Alias(expression(boolExpr), "condition")()),
OneRowRelation()))}
val body = visitCompoundBody(ctx.compoundBody())
val body = visitCompoundBodyImpl(ctx.compoundBody(), None, allowVarDeclare = false, labelCtx)
labelCtx.exitLabeledScope(Option(ctx.beginLabel()))

WhileStatement(condition, body, Some(labelText))
}

override def visitSearchedCaseStatement(ctx: SearchedCaseStatementContext): CaseStatement = {
private def visitSearchedCaseStatementImpl(
ctx: SearchedCaseStatementContext,
labelCtx: SqlScriptingLabelContext): CaseStatement = {
val conditions = ctx.conditions.asScala.toList.map(boolExpr => withOrigin(boolExpr) {
SingleStatement(
Project(
Seq(Alias(expression(boolExpr), "condition")()),
OneRowRelation()))
})
val conditionalBodies =
ctx.conditionalBodies.asScala.toList.map(body => visitCompoundBody(body))
ctx.conditionalBodies.asScala.toList.map(
body => visitCompoundBodyImpl(body, None, allowVarDeclare = false, labelCtx)
)

if (conditions.length != conditionalBodies.length) {
throw SparkException.internalError(
@@ -278,10 +293,14 @@ class AstBuilder extends DataTypeAstBuilder
CaseStatement(
conditions = conditions,
conditionalBodies = conditionalBodies,
elseBody = Option(ctx.elseBody).map(body => visitCompoundBody(body)))
elseBody = Option(ctx.elseBody).map(
body => visitCompoundBodyImpl(body, None, allowVarDeclare = false, labelCtx)
))
}

override def visitSimpleCaseStatement(ctx: SimpleCaseStatementContext): CaseStatement = {
private def visitSimpleCaseStatementImpl(
ctx: SimpleCaseStatementContext,
labelCtx: SqlScriptingLabelContext): CaseStatement = {
// uses EqualTo to compare the case variable(the main case expression)
// to the WHEN clause expressions
val conditions = ctx.conditionExpressions.asScala.toList.map(expr => withOrigin(expr) {
@@ -291,7 +310,9 @@ class AstBuilder extends DataTypeAstBuilder
OneRowRelation()))
})
val conditionalBodies =
ctx.conditionalBodies.asScala.toList.map(body => visitCompoundBody(body))
ctx.conditionalBodies.asScala.toList.map(
body => visitCompoundBodyImpl(body, None, allowVarDeclare = false, labelCtx)
)

if (conditions.length != conditionalBodies.length) {
throw SparkException.internalError(
@@ -302,19 +323,25 @@ class AstBuilder extends DataTypeAstBuilder
CaseStatement(
conditions = conditions,
conditionalBodies = conditionalBodies,
elseBody = Option(ctx.elseBody).map(body => visitCompoundBody(body)))
elseBody = Option(ctx.elseBody).map(
body => visitCompoundBodyImpl(body, None, allowVarDeclare = false, labelCtx)
))
}

override def visitRepeatStatement(ctx: RepeatStatementContext): RepeatStatement = {
val labelText = generateLabelText(Option(ctx.beginLabel()), Option(ctx.endLabel()))
private def visitRepeatStatementImpl(
ctx: RepeatStatementContext,
labelCtx: SqlScriptingLabelContext): RepeatStatement = {
val labelText =
labelCtx.enterLabeledScope(Option(ctx.beginLabel()), Option(ctx.endLabel()))
val boolExpr = ctx.booleanExpression()

val condition = withOrigin(boolExpr) {
SingleStatement(
Project(
Seq(Alias(expression(boolExpr), "condition")()),
OneRowRelation()))}
val body = visitCompoundBody(ctx.compoundBody())
val body = visitCompoundBodyImpl(ctx.compoundBody(), None, allowVarDeclare = false, labelCtx)
labelCtx.exitLabeledScope(Option(ctx.beginLabel()))

RepeatStatement(condition, body, Some(labelText))
}
@@ -377,9 +404,13 @@ class AstBuilder extends DataTypeAstBuilder
CurrentOrigin.get, labelText, "ITERATE")
}

override def visitLoopStatement(ctx: LoopStatementContext): LoopStatement = {
val labelText = generateLabelText(Option(ctx.beginLabel()), Option(ctx.endLabel()))
val body = visitCompoundBody(ctx.compoundBody())
private def visitLoopStatementImpl(
ctx: LoopStatementContext,
labelCtx: SqlScriptingLabelContext): LoopStatement = {
val labelText =
labelCtx.enterLabeledScope(Option(ctx.beginLabel()), Option(ctx.endLabel()))
val body = visitCompoundBodyImpl(ctx.compoundBody(), None, allowVarDeclare = false, labelCtx)
labelCtx.exitLabeledScope(Option(ctx.beginLabel()))

LoopStatement(body, Some(labelText))
}
@@ -19,12 +19,17 @@ package org.apache.spark.sql.catalyst.parser
import java.util
import java.util.Locale

import scala.collection.mutable.Set

import org.antlr.v4.runtime.{ParserRuleContext, Token}
import org.antlr.v4.runtime.misc.Interval
import org.antlr.v4.runtime.tree.{ParseTree, TerminalNodeImpl}

import org.apache.spark.sql.catalyst.parser.SqlBaseParser.{BeginLabelContext, EndLabelContext}
import org.apache.spark.sql.catalyst.trees.CurrentOrigin
import org.apache.spark.sql.catalyst.util.SparkParserUtils
import org.apache.spark.sql.errors.QueryParsingErrors
import org.apache.spark.sql.catalyst.util.SparkParserUtils.withOrigin
import org.apache.spark.sql.errors.{QueryParsingErrors, SqlScriptingErrors}

/**
* A collection of utility methods for use during the parsing process.
@@ -134,3 +139,80 @@ object ParserUtils extends SparkParserUtils
sb.toString()
}
}

class SqlScriptingLabelContext {
/** Set to keep track of labels seen so far */
private val seenLabels = Set[String]()

/**
* Check if the beginLabelCtx and endLabelCtx match.
* If the labels are defined, they must follow rules:
* - If both labels exist, they must match.
* - Begin label must exist if end label exists.
*/
private def checkLabels(
beginLabelCtx: Option[BeginLabelContext],
endLabelCtx: Option[EndLabelContext]) : Unit = {
(beginLabelCtx, endLabelCtx) match {
case (Some(bl: BeginLabelContext), Some(el: EndLabelContext))
if bl.multipartIdentifier().getText.toLowerCase(Locale.ROOT) !=
el.multipartIdentifier().getText.toLowerCase(Locale.ROOT) =>
withOrigin(bl) {
throw SqlScriptingErrors.labelsMismatch(
CurrentOrigin.get,
bl.multipartIdentifier().getText,
el.multipartIdentifier().getText)
}
case (None, Some(el: EndLabelContext)) =>
withOrigin(el) {
throw SqlScriptingErrors.endLabelWithoutBeginLabel(
CurrentOrigin.get, el.multipartIdentifier().getText)
}
case _ =>
}
}

/** Check if the label is defined. */
private def isLabelDefined(beginLabelCtx: Option[BeginLabelContext]): Boolean = {
beginLabelCtx.map(_.multipartIdentifier().getText).isDefined
}

/**
* Enter a labeled scope and return the label text.
* If the label is defined, it will be returned and added to seenLabels.
* If the label is not defined, a random UUID will be returned.
*/
def enterLabeledScope(
beginLabelCtx: Option[BeginLabelContext],
endLabelCtx: Option[EndLabelContext]): String = {

// Check if this label already exists in parent scopes.
checkLabels(beginLabelCtx, endLabelCtx)

// Get label text and add it to seenLabels.
val labelText = if (isLabelDefined(beginLabelCtx)) {
val txt = beginLabelCtx.get.multipartIdentifier().getText.toLowerCase(Locale.ROOT)
if (seenLabels.contains(txt)) {
withOrigin(beginLabelCtx.get) {
throw SqlScriptingErrors.duplicateLabels(CurrentOrigin.get, txt)
}
}
seenLabels.add(beginLabelCtx.get.multipartIdentifier().getText)
txt
} else {
// Do not add the label to the seenLabels set if it is not defined.
java.util.UUID.randomUUID.toString.toLowerCase(Locale.ROOT)
}
labelText
}

/**
* Exit a labeled scope.
* If the label is defined, it will be removed from seenLabels.
*/
def exitLabeledScope(beginLabelCtx: Option[BeginLabelContext]): Unit = {
if (isLabelDefined(beginLabelCtx)) {
seenLabels.remove(beginLabelCtx.get.multipartIdentifier().getText.toLowerCase(Locale.ROOT))
}
}
}
@@ -28,6 +28,14 @@ import org.apache.spark.sql.exceptions.SqlScriptingException
*/
private[sql] object SqlScriptingErrors {

def duplicateLabels(origin: Origin, label: String): Throwable = {
new SqlScriptingException(
origin = origin,
errorClass = "LABEL_ALREADY_EXISTS",
cause = null,
messageParameters = Map("label" -> toSQLId(label)))
}

def labelsMismatch(origin: Origin, beginLabel: String, endLabel: String): Throwable = {
new SqlScriptingException(
origin = origin,