@@ -53,6 +53,12 @@ class SparkSessionExtensionSuite extends SparkFunSuite {
5353 }
5454 }
5555
56+ test(" inject post hoc resolution analyzer rule" ) {
57+ withSession(Seq (_.injectPostHocResolutionRule(MyRule ))) { session =>
58+ assert(session.sessionState.analyzer.postHocResolutionRules.contains(MyRule (session)))
59+ }
60+ }
61+
5662 test(" inject check analysis rule" ) {
5763 withSession(Seq (_.injectCheckRule(MyCheckRule ))) { session =>
5864 assert(session.sessionState.analyzer.extendedCheckRules.contains(MyCheckRule (session)))
@@ -118,23 +124,36 @@ class SparkSessionExtensionSuite extends SparkFunSuite {
118124 try {
119125 assert(session.sessionState.planner.strategies.contains(MySparkStrategy (session)))
120126 assert(session.sessionState.analyzer.extendedResolutionRules.contains(MyRule (session)))
127+ assert(session.sessionState.analyzer.postHocResolutionRules.contains(MyRule (session)))
128+ assert(session.sessionState.analyzer.extendedCheckRules.contains(MyCheckRule (session)))
129+ assert(session.sessionState.optimizer.batches.flatMap(_.rules).contains(MyRule (session)))
130+ assert(session.sessionState.sqlParser.isInstanceOf [MyParser ])
121131 assert(session.sessionState.functionRegistry
122132 .lookupFunction(MyExtensions .myFunction._1).isDefined)
123133 } finally {
124134 stop(session)
125135 }
126136 }
127137
128- test(" use multiple custom class for extensions" ) {
138+ test(" use multiple custom class for extensions in the specified order " ) {
129139 val session = SparkSession .builder()
130140 .master(" local[1]" )
131141 .config(" spark.sql.extensions" , Seq (
132- classOf [MyExtensions ].getCanonicalName,
133- classOf [MyExtensions2 ].getCanonicalName).mkString(" ," ))
142+ classOf [MyExtensions2 ].getCanonicalName,
143+ classOf [MyExtensions ].getCanonicalName).mkString(" ," ))
134144 .getOrCreate()
135145 try {
136- assert(session.sessionState.planner.strategies.contains(MySparkStrategy (session)))
137- assert(session.sessionState.analyzer.extendedResolutionRules.contains(MyRule (session)))
146+ assert(session.sessionState.planner.strategies.containsSlice(
147+ Seq (MySparkStrategy2 (session), MySparkStrategy (session))))
148+ val orderedRules = Seq (MyRule2 (session), MyRule (session))
149+ val orderedCheckRules = Seq (MyCheckRule2 (session), MyCheckRule (session))
150+ val parser = MyParser (session, CatalystSqlParser )
151+ assert(session.sessionState.analyzer.extendedResolutionRules.containsSlice(orderedRules))
152+ assert(session.sessionState.analyzer.postHocResolutionRules.containsSlice(orderedRules))
153+ assert(session.sessionState.analyzer.extendedCheckRules.containsSlice(orderedCheckRules))
154+ assert(session.sessionState.optimizer.batches.flatMap(_.rules).filter(orderedRules.contains)
155+ .containsSlice(orderedRules ++ orderedRules)) // The optimizer rules are duplicated
156+ assert(session.sessionState.sqlParser == parser)
138157 assert(session.sessionState.functionRegistry
139158 .lookupFunction(MyExtensions .myFunction._1).isDefined)
140159 assert(session.sessionState.functionRegistry
@@ -143,6 +162,48 @@ class SparkSessionExtensionSuite extends SparkFunSuite {
143162 stop(session)
144163 }
145164 }
165+
166+ test(" allow an extension to be duplicated" ) {
167+ val session = SparkSession .builder()
168+ .master(" local[1]" )
169+ .config(" spark.sql.extensions" , Seq (
170+ classOf [MyExtensions ].getCanonicalName,
171+ classOf [MyExtensions ].getCanonicalName).mkString(" ," ))
172+ .getOrCreate()
173+ try {
174+ assert(session.sessionState.planner.strategies.count(_ == MySparkStrategy (session)) == 2 )
175+ assert(session.sessionState.analyzer.extendedResolutionRules.count(_ == MyRule (session)) == 2 )
176+ assert(session.sessionState.analyzer.postHocResolutionRules.count(_ == MyRule (session)) == 2 )
177+ assert(session.sessionState.analyzer.extendedCheckRules.count(_ == MyCheckRule (session)) == 2 )
178+ assert(session.sessionState.optimizer.batches.flatMap(_.rules)
179+ .count(_ == MyRule (session)) == 4 ) // The optimizer rules are duplicated
180+ val outerParser = session.sessionState.sqlParser
181+ assert(outerParser.isInstanceOf [MyParser ])
182+ assert(outerParser.asInstanceOf [MyParser ].delegate.isInstanceOf [MyParser ])
183+ assert(session.sessionState.functionRegistry
184+ .lookupFunction(MyExtensions .myFunction._1).isDefined)
185+ } finally {
186+ stop(session)
187+ }
188+ }
189+
190+ test(" use the last registered function name when there are duplicates" ) {
191+ val session = SparkSession .builder()
192+ .master(" local[1]" )
193+ .config(" spark.sql.extensions" , Seq (
194+ classOf [MyExtensions2 ].getCanonicalName,
195+ classOf [MyExtensions2Duplicate ].getCanonicalName).mkString(" ," ))
196+ .getOrCreate()
197+ try {
198+ val lastRegistered = session.sessionState.functionRegistry
199+ .lookupFunction(FunctionIdentifier (" myFunction2" ))
200+ assert(lastRegistered.isDefined)
201+ assert(lastRegistered.get.getExtended != MyExtensions2 .myFunction._2.getExtended)
202+ assert(lastRegistered.get.getExtended == MyExtensions2Duplicate .myFunction._2.getExtended)
203+ } finally {
204+ stop(session)
205+ }
206+ }
146207}
147208
148209case class MyRule (spark : SparkSession ) extends Rule [LogicalPlan ] {
@@ -189,19 +250,54 @@ class MyExtensions extends (SparkSessionExtensions => Unit) {
189250 def apply (e : SparkSessionExtensions ): Unit = {
190251 e.injectPlannerStrategy(MySparkStrategy )
191252 e.injectResolutionRule(MyRule )
253+ e.injectPostHocResolutionRule(MyRule )
254+ e.injectCheckRule(MyCheckRule )
255+ e.injectOptimizerRule(MyRule )
256+ e.injectParser(MyParser )
192257 e.injectFunction(MyExtensions .myFunction)
193258 }
194259}
195260
261+ case class MyRule2 (spark : SparkSession ) extends Rule [LogicalPlan ] {
262+ override def apply (plan : LogicalPlan ): LogicalPlan = plan
263+ }
264+
265+ case class MyCheckRule2 (spark : SparkSession ) extends (LogicalPlan => Unit ) {
266+ override def apply (plan : LogicalPlan ): Unit = { }
267+ }
268+
269+ case class MySparkStrategy2 (spark : SparkSession ) extends SparkStrategy {
270+ override def apply (plan : LogicalPlan ): Seq [SparkPlan ] = Seq .empty
271+ }
272+
196273object MyExtensions2 {
197274
198275 val myFunction = (FunctionIdentifier (" myFunction2" ),
199276 new ExpressionInfo (" noClass" , " myDb" , " myFunction2" , " usage" , " extended usage" ),
200- (myArgs : Seq [Expression ]) => Literal (5 , IntegerType ))
277+ (_ : Seq [Expression ]) => Literal (5 , IntegerType ))
201278}
202279
203280class MyExtensions2 extends (SparkSessionExtensions => Unit ) {
204281 def apply (e : SparkSessionExtensions ): Unit = {
282+ e.injectPlannerStrategy(MySparkStrategy2 )
283+ e.injectResolutionRule(MyRule2 )
284+ e.injectPostHocResolutionRule(MyRule2 )
285+ e.injectCheckRule(MyCheckRule2 )
286+ e.injectOptimizerRule(MyRule2 )
287+ e.injectParser((_, _) => CatalystSqlParser )
205288 e.injectFunction(MyExtensions2 .myFunction)
206289 }
207290}
291+
292+ object MyExtensions2Duplicate {
293+
294+ val myFunction = (FunctionIdentifier (" myFunction2" ),
295+ new ExpressionInfo (" noClass" , " myDb" , " myFunction2" , " usage" , " last wins" ),
296+ (_ : Seq [Expression ]) => Literal (5 , IntegerType ))
297+ }
298+
299+ class MyExtensions2Duplicate extends (SparkSessionExtensions => Unit ) {
300+ def apply (e : SparkSessionExtensions ): Unit = {
301+ e.injectFunction(MyExtensions2Duplicate .myFunction)
302+ }
303+ }
0 commit comments