
Commit 689a4d2

Address comments from dongjoon-hyun, hvanhovell, beliefer, gatorsmile
This addresses the comments for apache#23398
Parent: cef8eb0

3 files changed: +115 −10 lines

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala (6 additions, 2 deletions)

@@ -25,6 +25,7 @@ import scala.language.existentials
 import scala.reflect.ClassTag
 import scala.util.{Failure, Success, Try}
 
+import org.apache.spark.internal.Logging
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.FunctionIdentifier
 import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
@@ -87,7 +88,7 @@ trait FunctionRegistry {
   override def clone(): FunctionRegistry = throw new CloneNotSupportedException()
 }
 
-class SimpleFunctionRegistry extends FunctionRegistry {
+class SimpleFunctionRegistry extends FunctionRegistry with Logging {
 
   @GuardedBy("this")
   private val functionBuilders =
@@ -103,7 +104,10 @@ class SimpleFunctionRegistry extends FunctionRegistry {
       name: FunctionIdentifier,
       info: ExpressionInfo,
       builder: FunctionBuilder): Unit = synchronized {
-    functionBuilders.put(normalizeFuncName(name), (info, builder))
+    val normalizedName = normalizeFuncName(name)
+    if (functionBuilders.put(normalizedName, (info, builder)).isDefined) {
+      logWarning(s"The function $normalizedName replaced a previously registered function.")
+    }
   }
 
   override def lookupFunction(name: FunctionIdentifier, children: Seq[Expression]): Expression = {
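
Note on the registerFunction change above: it works because scala.collection.mutable.HashMap.put returns the previous value for the key as an Option, so a defined result means an existing registration was overwritten. A minimal standalone sketch of the idiom (plain Scala, using an ordinary HashMap rather than Spark's registry):

import scala.collection.mutable

object PutReplacementDemo extends App {
  val functionBuilders = mutable.HashMap.empty[String, String]

  // First registration: put returns None because the key is new.
  assert(functionBuilders.put("myfunction", "v1").isEmpty)

  // Re-registration under the same (normalized) name: put returns Some("v1"),
  // which is exactly the signal the registry uses to emit its warning.
  if (functionBuilders.put("myfunction", "v2").isDefined) {
    println("The function myfunction replaced a previously registered function.")
  }
}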

sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala (7 additions, 2 deletions)

@@ -99,8 +99,13 @@ object StaticSQLConf {
     .createWithDefault(false)
 
   val SPARK_SESSION_EXTENSIONS = buildStaticConf("spark.sql.extensions")
-    .doc("List of the class names used to configure Spark Session extensions. The classes should " +
-      "implement Function1[SparkSessionExtension, Unit], and must have a no-args constructor.")
+    .doc("A comma-separated list of classes that implement " +
+      "Function1[SparkSessionExtension, Unit] used to configure Spark Session extensions. The " +
+      "classes must have a no-args constructor. If multiple extensions are specified, they are " +
+      "applied in the specified order. For the case of rules and planner strategies, they are " +
+      "applied in the specified order. For the case of parsers, the last parser is used and each " +
+      "parser can delegate to its predecessor. For the case of function name conflicts, the last " +
+      "registered function name is used.")
     .stringConf
     .toSequence
     .createOptional
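
As a hedged usage sketch of the semantics documented above (the extension class names com.example.ExtA and com.example.ExtB are hypothetical stand-ins), a session built with two extensions applies them in the order they are listed:

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder()
  .master("local[1]")
  // ExtA is applied first and ExtB second: ExtA's rules and strategies run
  // before ExtB's, ExtB's parser is installed last (and may delegate to
  // ExtA's), and on a function-name conflict ExtB's registration wins.
  .config("spark.sql.extensions", "com.example.ExtA,com.example.ExtB")
  .getOrCreate()

Because spark.sql.extensions is a static conf, it must be set before the session is created; it cannot be changed on a running session.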

sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala (102 additions, 6 deletions)

@@ -53,6 +53,12 @@ class SparkSessionExtensionSuite extends SparkFunSuite {
     }
   }
 
+  test("inject post hoc resolution analyzer rule") {
+    withSession(Seq(_.injectPostHocResolutionRule(MyRule))) { session =>
+      assert(session.sessionState.analyzer.postHocResolutionRules.contains(MyRule(session)))
+    }
+  }
+
   test("inject check analysis rule") {
     withSession(Seq(_.injectCheckRule(MyCheckRule))) { session =>
       assert(session.sessionState.analyzer.extendedCheckRules.contains(MyCheckRule(session)))
@@ -118,23 +124,36 @@ class SparkSessionExtensionSuite extends SparkFunSuite {
     try {
       assert(session.sessionState.planner.strategies.contains(MySparkStrategy(session)))
       assert(session.sessionState.analyzer.extendedResolutionRules.contains(MyRule(session)))
+      assert(session.sessionState.analyzer.postHocResolutionRules.contains(MyRule(session)))
+      assert(session.sessionState.analyzer.extendedCheckRules.contains(MyCheckRule(session)))
+      assert(session.sessionState.optimizer.batches.flatMap(_.rules).contains(MyRule(session)))
+      assert(session.sessionState.sqlParser.isInstanceOf[MyParser])
       assert(session.sessionState.functionRegistry
         .lookupFunction(MyExtensions.myFunction._1).isDefined)
     } finally {
       stop(session)
     }
   }
 
-  test("use multiple custom class for extensions") {
+  test("use multiple custom class for extensions in the specified order") {
     val session = SparkSession.builder()
       .master("local[1]")
       .config("spark.sql.extensions", Seq(
-        classOf[MyExtensions].getCanonicalName,
-        classOf[MyExtensions2].getCanonicalName).mkString(","))
+        classOf[MyExtensions2].getCanonicalName,
+        classOf[MyExtensions].getCanonicalName).mkString(","))
       .getOrCreate()
     try {
-      assert(session.sessionState.planner.strategies.contains(MySparkStrategy(session)))
-      assert(session.sessionState.analyzer.extendedResolutionRules.contains(MyRule(session)))
+      assert(session.sessionState.planner.strategies.containsSlice(
+        Seq(MySparkStrategy2(session), MySparkStrategy(session))))
+      val orderedRules = Seq(MyRule2(session), MyRule(session))
+      val orderedCheckRules = Seq(MyCheckRule2(session), MyCheckRule(session))
+      val parser = MyParser(session, CatalystSqlParser)
+      assert(session.sessionState.analyzer.extendedResolutionRules.containsSlice(orderedRules))
+      assert(session.sessionState.analyzer.postHocResolutionRules.containsSlice(orderedRules))
+      assert(session.sessionState.analyzer.extendedCheckRules.containsSlice(orderedCheckRules))
+      assert(session.sessionState.optimizer.batches.flatMap(_.rules).filter(orderedRules.contains)
+        .containsSlice(orderedRules ++ orderedRules)) // The optimizer rules are duplicated
+      assert(session.sessionState.sqlParser == parser)
       assert(session.sessionState.functionRegistry
         .lookupFunction(MyExtensions.myFunction._1).isDefined)
       assert(session.sessionState.functionRegistry
@@ -143,6 +162,48 @@ class SparkSessionExtensionSuite extends SparkFunSuite {
       stop(session)
     }
   }
+
+  test("allow an extension to be duplicated") {
+    val session = SparkSession.builder()
+      .master("local[1]")
+      .config("spark.sql.extensions", Seq(
+        classOf[MyExtensions].getCanonicalName,
+        classOf[MyExtensions].getCanonicalName).mkString(","))
+      .getOrCreate()
+    try {
+      assert(session.sessionState.planner.strategies.count(_ == MySparkStrategy(session)) == 2)
+      assert(session.sessionState.analyzer.extendedResolutionRules.count(_ == MyRule(session)) == 2)
+      assert(session.sessionState.analyzer.postHocResolutionRules.count(_ == MyRule(session)) == 2)
+      assert(session.sessionState.analyzer.extendedCheckRules.count(_ == MyCheckRule(session)) == 2)
+      assert(session.sessionState.optimizer.batches.flatMap(_.rules)
+        .count(_ == MyRule(session)) == 4) // The optimizer rules are duplicated
+      val outerParser = session.sessionState.sqlParser
+      assert(outerParser.isInstanceOf[MyParser])
+      assert(outerParser.asInstanceOf[MyParser].delegate.isInstanceOf[MyParser])
+      assert(session.sessionState.functionRegistry
+        .lookupFunction(MyExtensions.myFunction._1).isDefined)
+    } finally {
+      stop(session)
+    }
+  }
+
+  test("use the last registered function name when there are duplicates") {
+    val session = SparkSession.builder()
+      .master("local[1]")
+      .config("spark.sql.extensions", Seq(
+        classOf[MyExtensions2].getCanonicalName,
+        classOf[MyExtensions2Duplicate].getCanonicalName).mkString(","))
+      .getOrCreate()
+    try {
+      val lastRegistered = session.sessionState.functionRegistry
+        .lookupFunction(FunctionIdentifier("myFunction2"))
+      assert(lastRegistered.isDefined)
+      assert(lastRegistered.get.getExtended != MyExtensions2.myFunction._2.getExtended)
+      assert(lastRegistered.get.getExtended == MyExtensions2Duplicate.myFunction._2.getExtended)
+    } finally {
+      stop(session)
+    }
+  }
 }
 
 case class MyRule(spark: SparkSession) extends Rule[LogicalPlan] {
@@ -189,19 +250,54 @@ class MyExtensions extends (SparkSessionExtensions => Unit) {
   def apply(e: SparkSessionExtensions): Unit = {
     e.injectPlannerStrategy(MySparkStrategy)
     e.injectResolutionRule(MyRule)
+    e.injectPostHocResolutionRule(MyRule)
+    e.injectCheckRule(MyCheckRule)
+    e.injectOptimizerRule(MyRule)
+    e.injectParser(MyParser)
     e.injectFunction(MyExtensions.myFunction)
   }
 }
 
+case class MyRule2(spark: SparkSession) extends Rule[LogicalPlan] {
+  override def apply(plan: LogicalPlan): LogicalPlan = plan
+}
+
+case class MyCheckRule2(spark: SparkSession) extends (LogicalPlan => Unit) {
+  override def apply(plan: LogicalPlan): Unit = { }
+}
+
+case class MySparkStrategy2(spark: SparkSession) extends SparkStrategy {
+  override def apply(plan: LogicalPlan): Seq[SparkPlan] = Seq.empty
+}
+
 object MyExtensions2 {
 
   val myFunction = (FunctionIdentifier("myFunction2"),
     new ExpressionInfo("noClass", "myDb", "myFunction2", "usage", "extended usage" ),
-    (myArgs: Seq[Expression]) => Literal(5, IntegerType))
+    (_: Seq[Expression]) => Literal(5, IntegerType))
 }
 
 class MyExtensions2 extends (SparkSessionExtensions => Unit) {
   def apply(e: SparkSessionExtensions): Unit = {
+    e.injectPlannerStrategy(MySparkStrategy2)
+    e.injectResolutionRule(MyRule2)
+    e.injectPostHocResolutionRule(MyRule2)
+    e.injectCheckRule(MyCheckRule2)
+    e.injectOptimizerRule(MyRule2)
+    e.injectParser((_, _) => CatalystSqlParser)
     e.injectFunction(MyExtensions2.myFunction)
   }
 }
+
+object MyExtensions2Duplicate {
+
+  val myFunction = (FunctionIdentifier("myFunction2"),
+    new ExpressionInfo("noClass", "myDb", "myFunction2", "usage", "last wins" ),
+    (_: Seq[Expression]) => Literal(5, IntegerType))
+}
+
+class MyExtensions2Duplicate extends (SparkSessionExtensions => Unit) {
+  def apply(e: SparkSessionExtensions): Unit = {
+    e.injectFunction(MyExtensions2Duplicate.myFunction)
+  }
+}
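
The duplicated-extension test relies on the parser chaining documented in StaticSQLConf above: each injected parser builder receives the previously installed parser as its delegate, so the last injected parser ends up outermost. A minimal sketch of that pattern (DelegatingParserExtension is a hypothetical name; MyParser in the suite plays the same role with a real delegating implementation):

import org.apache.spark.sql.{SparkSession, SparkSessionExtensions}
import org.apache.spark.sql.catalyst.parser.ParserInterface

class DelegatingParserExtension extends (SparkSessionExtensions => Unit) {
  def apply(e: SparkSessionExtensions): Unit = {
    // The builder receives the previously installed parser. Returning a
    // wrapper around `delegate` chains parsers; returning `delegate` itself,
    // as done here, keeps the predecessor in place unchanged.
    e.injectParser { (_: SparkSession, delegate: ParserInterface) => delegate }
  }
}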
