-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-26493][SQL] Allow multiple spark.sql.extensions #23398
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
cef8eb0
689a4d2
d89cfd9
fb4ad34
65a5f3f
9c0181d
deaf73e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
This addresses the comments for #23398
- Loading branch information
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -53,6 +53,12 @@ class SparkSessionExtensionSuite extends SparkFunSuite { | |
| } | ||
| } | ||
|
|
||
| test("inject post hoc resolution analyzer rule") { | ||
| withSession(Seq(_.injectPostHocResolutionRule(MyRule))) { session => | ||
| assert(session.sessionState.analyzer.postHocResolutionRules.contains(MyRule(session))) | ||
| } | ||
| } | ||
|
|
||
| test("inject check analysis rule") { | ||
| withSession(Seq(_.injectCheckRule(MyCheckRule))) { session => | ||
| assert(session.sessionState.analyzer.extendedCheckRules.contains(MyCheckRule(session))) | ||
|
|
@@ -118,23 +124,36 @@ class SparkSessionExtensionSuite extends SparkFunSuite { | |
| try { | ||
| assert(session.sessionState.planner.strategies.contains(MySparkStrategy(session))) | ||
| assert(session.sessionState.analyzer.extendedResolutionRules.contains(MyRule(session))) | ||
| assert(session.sessionState.analyzer.postHocResolutionRules.contains(MyRule(session))) | ||
| assert(session.sessionState.analyzer.extendedCheckRules.contains(MyCheckRule(session))) | ||
| assert(session.sessionState.optimizer.batches.flatMap(_.rules).contains(MyRule(session))) | ||
| assert(session.sessionState.sqlParser.isInstanceOf[MyParser]) | ||
| assert(session.sessionState.functionRegistry | ||
| .lookupFunction(MyExtensions.myFunction._1).isDefined) | ||
| } finally { | ||
| stop(session) | ||
| } | ||
| } | ||
|
|
||
| test("use multiple custom class for extensions") { | ||
| test("use multiple custom class for extensions in the specified order") { | ||
| val session = SparkSession.builder() | ||
| .master("local[1]") | ||
| .config("spark.sql.extensions", Seq( | ||
| classOf[MyExtensions].getCanonicalName, | ||
| classOf[MyExtensions2].getCanonicalName).mkString(",")) | ||
| classOf[MyExtensions2].getCanonicalName, | ||
| classOf[MyExtensions].getCanonicalName).mkString(",")) | ||
| .getOrCreate() | ||
| try { | ||
| assert(session.sessionState.planner.strategies.contains(MySparkStrategy(session))) | ||
| assert(session.sessionState.analyzer.extendedResolutionRules.contains(MyRule(session))) | ||
| assert(session.sessionState.planner.strategies.containsSlice( | ||
| Seq(MySparkStrategy2(session), MySparkStrategy(session)))) | ||
| val orderedRules = Seq(MyRule2(session), MyRule(session)) | ||
| val orderedCheckRules = Seq(MyCheckRule2(session), MyCheckRule(session)) | ||
| val parser = MyParser(session, CatalystSqlParser) | ||
| assert(session.sessionState.analyzer.extendedResolutionRules.containsSlice(orderedRules)) | ||
| assert(session.sessionState.analyzer.postHocResolutionRules.containsSlice(orderedRules)) | ||
| assert(session.sessionState.analyzer.extendedCheckRules.containsSlice(orderedCheckRules)) | ||
| assert(session.sessionState.optimizer.batches.flatMap(_.rules).filter(orderedRules.contains) | ||
| .containsSlice(orderedRules ++ orderedRules)) // The optimizer rules are duplicated | ||
| assert(session.sessionState.sqlParser == parser) | ||
|
||
| assert(session.sessionState.functionRegistry | ||
| .lookupFunction(MyExtensions.myFunction._1).isDefined) | ||
| assert(session.sessionState.functionRegistry | ||
|
|
@@ -143,6 +162,48 @@ class SparkSessionExtensionSuite extends SparkFunSuite { | |
| stop(session) | ||
| } | ||
| } | ||
|
|
||
| test("allow an extension to be duplicated") { | ||
| val session = SparkSession.builder() | ||
| .master("local[1]") | ||
| .config("spark.sql.extensions", Seq( | ||
| classOf[MyExtensions].getCanonicalName, | ||
| classOf[MyExtensions].getCanonicalName).mkString(",")) | ||
| .getOrCreate() | ||
| try { | ||
| assert(session.sessionState.planner.strategies.count(_ == MySparkStrategy(session)) == 2) | ||
| assert(session.sessionState.analyzer.extendedResolutionRules.count(_ == MyRule(session)) == 2) | ||
| assert(session.sessionState.analyzer.postHocResolutionRules.count(_ == MyRule(session)) == 2) | ||
| assert(session.sessionState.analyzer.extendedCheckRules.count(_ == MyCheckRule(session)) == 2) | ||
| assert(session.sessionState.optimizer.batches.flatMap(_.rules) | ||
| .count(_ == MyRule(session)) == 4) // The optimizer rules are duplicated | ||
| val outerParser = session.sessionState.sqlParser | ||
| assert(outerParser.isInstanceOf[MyParser]) | ||
| assert(outerParser.asInstanceOf[MyParser].delegate.isInstanceOf[MyParser]) | ||
| assert(session.sessionState.functionRegistry | ||
| .lookupFunction(MyExtensions.myFunction._1).isDefined) | ||
| } finally { | ||
| stop(session) | ||
| } | ||
| } | ||
|
|
||
| test("use the last registered function name when there are duplicates") { | ||
| val session = SparkSession.builder() | ||
| .master("local[1]") | ||
| .config("spark.sql.extensions", Seq( | ||
| classOf[MyExtensions2].getCanonicalName, | ||
| classOf[MyExtensions2Duplicate].getCanonicalName).mkString(",")) | ||
| .getOrCreate() | ||
| try { | ||
| val lastRegistered = session.sessionState.functionRegistry | ||
| .lookupFunction(FunctionIdentifier("myFunction2")) | ||
| assert(lastRegistered.isDefined) | ||
| assert(lastRegistered.get.getExtended != MyExtensions2.myFunction._2.getExtended) | ||
| assert(lastRegistered.get.getExtended == MyExtensions2Duplicate.myFunction._2.getExtended) | ||
| } finally { | ||
| stop(session) | ||
| } | ||
| } | ||
| } | ||
|
|
||
| case class MyRule(spark: SparkSession) extends Rule[LogicalPlan] { | ||
|
|
@@ -189,19 +250,54 @@ class MyExtensions extends (SparkSessionExtensions => Unit) { | |
| def apply(e: SparkSessionExtensions): Unit = { | ||
| e.injectPlannerStrategy(MySparkStrategy) | ||
| e.injectResolutionRule(MyRule) | ||
| e.injectPostHocResolutionRule(MyRule) | ||
| e.injectCheckRule(MyCheckRule) | ||
| e.injectOptimizerRule(MyRule) | ||
| e.injectParser(MyParser) | ||
| e.injectFunction(MyExtensions.myFunction) | ||
| } | ||
| } | ||
|
|
||
| case class MyRule2(spark: SparkSession) extends Rule[LogicalPlan] { | ||
| override def apply(plan: LogicalPlan): LogicalPlan = plan | ||
| } | ||
|
|
||
| case class MyCheckRule2(spark: SparkSession) extends (LogicalPlan => Unit) { | ||
| override def apply(plan: LogicalPlan): Unit = { } | ||
| } | ||
|
|
||
| case class MySparkStrategy2(spark: SparkSession) extends SparkStrategy { | ||
| override def apply(plan: LogicalPlan): Seq[SparkPlan] = Seq.empty | ||
| } | ||
|
|
||
| object MyExtensions2 { | ||
|
|
||
| val myFunction = (FunctionIdentifier("myFunction2"), | ||
| new ExpressionInfo("noClass", "myDb", "myFunction2", "usage", "extended usage" ), | ||
|
||
| (myArgs: Seq[Expression]) => Literal(5, IntegerType)) | ||
| (_: Seq[Expression]) => Literal(5, IntegerType)) | ||
| } | ||
|
|
||
| class MyExtensions2 extends (SparkSessionExtensions => Unit) { | ||
| def apply(e: SparkSessionExtensions): Unit = { | ||
| e.injectPlannerStrategy(MySparkStrategy2) | ||
| e.injectResolutionRule(MyRule2) | ||
| e.injectPostHocResolutionRule(MyRule2) | ||
| e.injectCheckRule(MyCheckRule2) | ||
| e.injectOptimizerRule(MyRule2) | ||
| e.injectParser((_, _) => CatalystSqlParser) | ||
|
||
| e.injectFunction(MyExtensions2.myFunction) | ||
| } | ||
| } | ||
|
|
||
| object MyExtensions2Duplicate { | ||
|
|
||
| val myFunction = (FunctionIdentifier("myFunction2"), | ||
| new ExpressionInfo("noClass", "myDb", "myFunction2", "usage", "last wins" ), | ||
|
||
| (_: Seq[Expression]) => Literal(5, IntegerType)) | ||
| } | ||
|
|
||
| class MyExtensions2Duplicate extends (SparkSessionExtensions => Unit) { | ||
| def apply(e: SparkSessionExtensions): Unit = { | ||
| e.injectFunction(MyExtensions2Duplicate.myFunction) | ||
| } | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It would be great if we could check whether the new function and the old function are different. This would help to increase the signal of the error message.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I added a check which will only log if different function objects are registered. The "allow an extension to be duplicated" unit test that I previously added registers the same object twice, so it no longer prints the warning. The "use the last registered function name when there are duplicates" unit test that I previously added registers different functions with the same name, so it still prints the warning.