[SPARK-26493][SQL] Allow multiple spark.sql.extensions #23398
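This change lets `spark.sql.extensions` carry a comma-separated list of extension class names, applied in the order listed (the tests below exercise exactly this). A minimal usage sketch — the `com.example.*` class names are illustrative placeholders, not part of this PR:

```scala
import org.apache.spark.sql.SparkSession

// Each class name in the comma-separated list must implement
// (SparkSessionExtensions => Unit); the classes are instantiated and
// applied in the order they appear in the list.
val spark = SparkSession.builder()
  .master("local[*]")
  .config("spark.sql.extensions",
    "com.example.ExtensionsA,com.example.ExtensionsB")
  .getOrCreate()
```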
```diff
@@ -30,41 +30,49 @@ import org.apache.spark.sql.types.{DataType, IntegerType, StructType}
  */
 class SparkSessionExtensionSuite extends SparkFunSuite {
   type ExtensionsBuilder = SparkSessionExtensions => Unit
-  private def create(builder: ExtensionsBuilder): ExtensionsBuilder = builder
+  private def create(builder: ExtensionsBuilder): Seq[ExtensionsBuilder] = Seq(builder)
 
   private def stop(spark: SparkSession): Unit = {
     spark.stop()
     SparkSession.clearActiveSession()
     SparkSession.clearDefaultSession()
   }
 
-  private def withSession(builder: ExtensionsBuilder)(f: SparkSession => Unit): Unit = {
-    val spark = SparkSession.builder().master("local[1]").withExtensions(builder).getOrCreate()
+  private def withSession(builders: Seq[ExtensionsBuilder])(f: SparkSession => Unit): Unit = {
+    val builder = SparkSession.builder().master("local[1]")
+    builders.foreach(builder.withExtensions)
+    val spark = builder.getOrCreate()
     try f(spark) finally {
       stop(spark)
     }
   }
 
   test("inject analyzer rule") {
-    withSession(_.injectResolutionRule(MyRule)) { session =>
+    withSession(Seq(_.injectResolutionRule(MyRule))) { session =>
       assert(session.sessionState.analyzer.extendedResolutionRules.contains(MyRule(session)))
     }
   }
 
+  test("inject post hoc resolution analyzer rule") {
+    withSession(Seq(_.injectPostHocResolutionRule(MyRule))) { session =>
+      assert(session.sessionState.analyzer.postHocResolutionRules.contains(MyRule(session)))
+    }
+  }
+
   test("inject check analysis rule") {
-    withSession(_.injectCheckRule(MyCheckRule)) { session =>
+    withSession(Seq(_.injectCheckRule(MyCheckRule))) { session =>
       assert(session.sessionState.analyzer.extendedCheckRules.contains(MyCheckRule(session)))
     }
   }
 
   test("inject optimizer rule") {
-    withSession(_.injectOptimizerRule(MyRule)) { session =>
+    withSession(Seq(_.injectOptimizerRule(MyRule))) { session =>
       assert(session.sessionState.optimizer.batches.flatMap(_.rules).contains(MyRule(session)))
     }
   }
 
   test("inject spark planner strategy") {
-    withSession(_.injectPlannerStrategy(MySparkStrategy)) { session =>
+    withSession(Seq(_.injectPlannerStrategy(MySparkStrategy))) { session =>
       assert(session.sessionState.planner.strategies.contains(MySparkStrategy(session)))
     }
   }
```
```diff
@@ -78,6 +86,14 @@ class SparkSessionExtensionSuite extends SparkFunSuite {
     }
   }
 
+  test("inject multiple rules") {
+    withSession(Seq(_.injectOptimizerRule(MyRule),
+        _.injectPlannerStrategy(MySparkStrategy))) { session =>
+      assert(session.sessionState.optimizer.batches.flatMap(_.rules).contains(MyRule(session)))
+      assert(session.sessionState.planner.strategies.contains(MySparkStrategy(session)))
+    }
+  }
+
   test("inject stacked parsers") {
     val extension = create { extensions =>
       extensions.injectParser((_, _) => CatalystSqlParser)
```
```diff
@@ -108,12 +124,86 @@ class SparkSessionExtensionSuite extends SparkFunSuite {
     try {
       assert(session.sessionState.planner.strategies.contains(MySparkStrategy(session)))
       assert(session.sessionState.analyzer.extendedResolutionRules.contains(MyRule(session)))
+      assert(session.sessionState.analyzer.postHocResolutionRules.contains(MyRule(session)))
       assert(session.sessionState.analyzer.extendedCheckRules.contains(MyCheckRule(session)))
       assert(session.sessionState.optimizer.batches.flatMap(_.rules).contains(MyRule(session)))
       assert(session.sessionState.sqlParser.isInstanceOf[MyParser])
       assert(session.sessionState.functionRegistry
         .lookupFunction(MyExtensions.myFunction._1).isDefined)
     } finally {
       stop(session)
     }
   }
+
+  test("use multiple custom class for extensions in the specified order") {
+    val session = SparkSession.builder()
+      .master("local[1]")
+      .config("spark.sql.extensions", Seq(
+        classOf[MyExtensions2].getCanonicalName,
+        classOf[MyExtensions].getCanonicalName).mkString(","))
+      .getOrCreate()
+    try {
+      assert(session.sessionState.planner.strategies.containsSlice(
+        Seq(MySparkStrategy2(session), MySparkStrategy(session))))
+      val orderedRules = Seq(MyRule2(session), MyRule(session))
+      val orderedCheckRules = Seq(MyCheckRule2(session), MyCheckRule(session))
+      val parser = MyParser(session, CatalystSqlParser)
+      assert(session.sessionState.analyzer.extendedResolutionRules.containsSlice(orderedRules))
+      assert(session.sessionState.analyzer.postHocResolutionRules.containsSlice(orderedRules))
+      assert(session.sessionState.analyzer.extendedCheckRules.containsSlice(orderedCheckRules))
+      assert(session.sessionState.optimizer.batches.flatMap(_.rules).filter(orderedRules.contains)
+        .containsSlice(orderedRules ++ orderedRules)) // The optimizer rules are duplicated
+      assert(session.sessionState.sqlParser == parser)
+
+      assert(session.sessionState.functionRegistry
+        .lookupFunction(MyExtensions.myFunction._1).isDefined)
+      assert(session.sessionState.functionRegistry
+        .lookupFunction(MyExtensions2.myFunction._1).isDefined)
```
Member: So, now we have multiple extension registrations. The order of the extension names might have side effects.

Member: I think the order matters, but we need to discuss and document the behavior when we have name conflicts. For example, the same rule will be added twice with:

```scala
class MyExtensions extends (SparkSessionExtensions => Unit) {
  def apply(e: SparkSessionExtensions): Unit = {
    e.injectResolutionRule(MyRule)
  }
}

class MyExtensions2 extends (SparkSessionExtensions => Unit) {
  def apply(e: SparkSessionExtensions): Unit = {
    e.injectResolutionRule(MyRule)
  }
}
```
Member: Yep. If there is no reason to allow that, we had better disallow it by design before this PR.

Contributor: There are use cases where you want to execute rules in a certain order, so I think it is reasonable to add the same rule multiple times. If you want more control, you could even create 'micro' optimizer batches by calling multiple rules from one rule. I think this is more a matter of proper documentation than something we should explicitly block. Also note that this is a pretty advanced feature; by this stage, users are expected to know what they are doing.
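A minimal sketch of the 'micro batch' idea mentioned in the comment above: one injected rule that runs several rules in a fixed order. `RuleA` and `RuleB` are hypothetical placeholders, not names from this PR:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule

// Hypothetical rules; any Rule[LogicalPlan] implementations would do.
case class RuleA(spark: SparkSession) extends Rule[LogicalPlan] {
  override def apply(plan: LogicalPlan): LogicalPlan = plan
}
case class RuleB(spark: SparkSession) extends Rule[LogicalPlan] {
  override def apply(plan: LogicalPlan): LogicalPlan = plan
}

// A 'micro batch': the single rule Spark sees applies RuleA then RuleB,
// giving full control over ordering within one optimizer slot.
case class MyMicroBatch(spark: SparkSession) extends Rule[LogicalPlan] {
  private val rules = Seq(RuleA(spark), RuleB(spark))
  override def apply(plan: LogicalPlan): LogicalPlan =
    rules.foldLeft(plan)((p, rule) => rule(p))
}
```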
Contributor (Author): Prior to this change, it was possible to programmatically register multiple extensions, but it was not possible to do so through the spark.sql.extensions configuration (and the behavior was neither documented nor tested until this pull request). So conflicting function names are already possible today, just undocumented. As for the order, it looks to me like the last function registered under a conflicting name is the one that is retrieved. I will update this PR to document the order of operations and the behavior on conflicts. If we need to explicitly block duplicate function registrations, I can temporarily drop this PR and see about making those changes first.
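For reference, a minimal sketch of the programmatic path that already worked before this change — `withExtensions` can be chained, and each call is applied to the same underlying `SparkSessionExtensions`. It reuses the `MyExtensions` and `MyExtensions2` helpers defined later in this suite:

```scala
import org.apache.spark.sql.SparkSession

// Registering two extension objects programmatically; both builders run,
// so their rules, strategies, parsers, and functions all get injected.
val spark = SparkSession.builder()
  .master("local[1]")
  .withExtensions(new MyExtensions)
  .withExtensions(new MyExtensions2)
  .getOrCreate()
```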
Member: Thanks for explaining it. We do not need to block it, but we might need to detect it and at least emit a warning. More importantly, we need to document the current behavior and add a test case to ensure future changes will not break it. In the future, we can revisit the current behavior and change it if needed.

Contributor (Author): > Can we have a test case for duplicated extension names?

Done. I also added documentation for the behavior.
```diff
+    } finally {
+      stop(session)
+    }
+  }
+
+  test("allow an extension to be duplicated") {
+    val session = SparkSession.builder()
+      .master("local[1]")
+      .config("spark.sql.extensions", Seq(
+        classOf[MyExtensions].getCanonicalName,
+        classOf[MyExtensions].getCanonicalName).mkString(","))
+      .getOrCreate()
+    try {
+      assert(session.sessionState.planner.strategies.count(_ == MySparkStrategy(session)) == 2)
+      assert(session.sessionState.analyzer.extendedResolutionRules.count(_ == MyRule(session)) == 2)
+      assert(session.sessionState.analyzer.postHocResolutionRules.count(_ == MyRule(session)) == 2)
+      assert(session.sessionState.analyzer.extendedCheckRules.count(_ == MyCheckRule(session)) == 2)
+      assert(session.sessionState.optimizer.batches.flatMap(_.rules)
+        .count(_ == MyRule(session)) == 4) // The optimizer rules are duplicated
+      val outerParser = session.sessionState.sqlParser
+      assert(outerParser.isInstanceOf[MyParser])
+      assert(outerParser.asInstanceOf[MyParser].delegate.isInstanceOf[MyParser])
+      assert(session.sessionState.functionRegistry
+        .lookupFunction(MyExtensions.myFunction._1).isDefined)
+    } finally {
+      stop(session)
+    }
+  }
+
+  test("use the last registered function name when there are duplicates") {
+    val session = SparkSession.builder()
+      .master("local[1]")
+      .config("spark.sql.extensions", Seq(
+        classOf[MyExtensions2].getCanonicalName,
+        classOf[MyExtensions2Duplicate].getCanonicalName).mkString(","))
+      .getOrCreate()
+    try {
+      val lastRegistered = session.sessionState.functionRegistry
+        .lookupFunction(FunctionIdentifier("myFunction2"))
+      assert(lastRegistered.isDefined)
+      assert(lastRegistered.get.getExtended != MyExtensions2.myFunction._2.getExtended)
+      assert(lastRegistered.get.getExtended == MyExtensions2Duplicate.myFunction._2.getExtended)
+    } finally {
+      stop(session)
+    }
+  }
 }
 
 case class MyRule(spark: SparkSession) extends Rule[LogicalPlan] {

@@ -151,14 +241,63 @@ case class MyParser(spark: SparkSession, delegate: ParserInterface) extends Pars
 object MyExtensions {
 
   val myFunction = (FunctionIdentifier("myFunction"),
-    new ExpressionInfo("noClass", "myDb", "myFunction", "usage", "extended usage" ),
+    new ExpressionInfo("noClass", "myDb", "myFunction", "usage", "extended usage"),
     (myArgs: Seq[Expression]) => Literal(5, IntegerType))
 }
 
 class MyExtensions extends (SparkSessionExtensions => Unit) {
   def apply(e: SparkSessionExtensions): Unit = {
     e.injectPlannerStrategy(MySparkStrategy)
     e.injectResolutionRule(MyRule)
+    e.injectPostHocResolutionRule(MyRule)
     e.injectCheckRule(MyCheckRule)
     e.injectOptimizerRule(MyRule)
     e.injectParser(MyParser)
     e.injectFunction(MyExtensions.myFunction)
   }
 }
 
+case class MyRule2(spark: SparkSession) extends Rule[LogicalPlan] {
+  override def apply(plan: LogicalPlan): LogicalPlan = plan
+}
+
+case class MyCheckRule2(spark: SparkSession) extends (LogicalPlan => Unit) {
+  override def apply(plan: LogicalPlan): Unit = { }
+}
+
+case class MySparkStrategy2(spark: SparkSession) extends SparkStrategy {
+  override def apply(plan: LogicalPlan): Seq[SparkPlan] = Seq.empty
+}
+
+object MyExtensions2 {
+
+  val myFunction = (FunctionIdentifier("myFunction2"),
+    new ExpressionInfo("noClass", "myDb", "myFunction2", "usage", "extended usage"),
+    (_: Seq[Expression]) => Literal(5, IntegerType))
+}
+
+class MyExtensions2 extends (SparkSessionExtensions => Unit) {
+  def apply(e: SparkSessionExtensions): Unit = {
+    e.injectPlannerStrategy(MySparkStrategy2)
+    e.injectResolutionRule(MyRule2)
+    e.injectPostHocResolutionRule(MyRule2)
+    e.injectCheckRule(MyCheckRule2)
+    e.injectOptimizerRule(MyRule2)
+    e.injectParser((_, _) => CatalystSqlParser)
+    e.injectFunction(MyExtensions2.myFunction)
+  }
+}
+
+object MyExtensions2Duplicate {
+
+  val myFunction = (FunctionIdentifier("myFunction2"),
+    new ExpressionInfo("noClass", "myDb", "myFunction2", "usage", "last wins"),
+    (_: Seq[Expression]) => Literal(5, IntegerType))
+}
+
+class MyExtensions2Duplicate extends (SparkSessionExtensions => Unit) {
+  def apply(e: SparkSessionExtensions): Unit = {
+    e.injectFunction(MyExtensions2Duplicate.myFunction)
+  }
+}
```
This is returning a Unit type object. It can simply be removed.

I made a change here which I think is what you were recommending. I didn't want to remove the `case _ =>`, since otherwise I think it may result in a non-exhaustive match error at runtime. If you wanted me to remove the entire match statement, just let me know.
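For context on keeping the wildcard arm, a tiny standalone sketch (hypothetical, not code from this diff): a Scala match with no `case _` arm throws `scala.MatchError` when no case applies.

```scala
// Keeping the wildcard arm makes the match total; without it,
// describe(42) would throw scala.MatchError at runtime.
def describe(x: Any): Unit = x match {
  case s: String => println(s"string: $s")
  case _ => // intentionally do nothing for every other input
}
```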