@@ -25,6 +25,7 @@ import scala.language.existentials
import scala.reflect.ClassTag
import scala.util.{Failure, Success, Try}

import org.apache.spark.internal.Logging
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.FunctionIdentifier
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
@@ -87,7 +88,7 @@ trait FunctionRegistry {
override def clone(): FunctionRegistry = throw new CloneNotSupportedException()
}

class SimpleFunctionRegistry extends FunctionRegistry {
class SimpleFunctionRegistry extends FunctionRegistry with Logging {

@GuardedBy("this")
private val functionBuilders =
@@ -103,7 +104,13 @@ class SimpleFunctionRegistry extends FunctionRegistry {
name: FunctionIdentifier,
info: ExpressionInfo,
builder: FunctionBuilder): Unit = synchronized {
functionBuilders.put(normalizeFuncName(name), (info, builder))
val normalizedName = normalizeFuncName(name)
val newFunction = (info, builder)
functionBuilders.put(normalizedName, newFunction) match {
case Some(previousFunction) if previousFunction != newFunction =>
logWarning(s"The function $normalizedName replaced a previously registered function.")
case _ => Unit
Member:

This is returning the Unit object. It can just be removed.

Contributor Author:

I made a change here which I think is what you were recommending. I didn't want to remove the case _ => entirely, since I think that could result in a non-exhaustive match. If you wanted me to remove the whole match statement, just let me know.
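For context, a minimal standalone sketch (not part of this PR) of the two mechanics under discussion: mutable.HashMap.put returning the previous binding, which is what lets registerFunction detect a replacement, and the case _ => Unit nit:

import scala.collection.mutable

object RegisterSketch extends App {
  // mutable.HashMap.put returns the value previously bound to the key, if any;
  // this is what lets registerFunction above detect that it replaced a builder.
  val m = mutable.HashMap[String, Int]()
  assert(m.put("f", 1).isEmpty)      // first registration: no prior value
  assert(m.put("f", 2).contains(1))  // replacement: previous value is returned
  assert(m("f") == 2)                // lookups see the last registration

  // The nit above: `case _ => Unit` yields the Unit companion object, while a
  // bare `case _ =>` yields the unit value (); either way the result is unused.
  m.put("f", 3) match {
    case Some(prev) if prev != 3 => println(s"replaced $prev")
    case _ => // no-op; no `Unit` needed
  }
}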

}
}

override def lookupFunction(name: FunctionIdentifier, children: Seq[Expression]): Expression = {
@@ -99,9 +99,15 @@ object StaticSQLConf {
.createWithDefault(false)

val SPARK_SESSION_EXTENSIONS = buildStaticConf("spark.sql.extensions")
.doc("Name of the class used to configure Spark Session extensions. The class should " +
"implement Function1[SparkSessionExtension, Unit], and must have a no-args constructor.")
.doc("A comma-separated list of classes that implement " +
"Function1[SparkSessionExtension, Unit] used to configure Spark Session extensions. The " +
"classes must have a no-args constructor. If multiple extensions are specified, they are " +
"applied in the specified order. For the case of rules and planner strategies, they are " +
"applied in the specified order. For the case of parsers, the last parser is used and each " +
"parser can delegate to its predecessor. For the case of function name conflicts, the last " +
"registered function name is used.")
.stringConf
.toSequence
.createOptional
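A hypothetical usage sketch of the updated config (the class names below are invented placeholders for classes that implement Function1[SparkSessionExtensions, Unit] and have a no-args constructor); the extensions are applied left to right:

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder()
  .master("local[1]")
  .config("spark.sql.extensions",
    "com.example.ExtensionsA,com.example.ExtensionsB")
  .getOrCreate()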

val QUERY_EXECUTION_LISTENERS = buildStaticConf("spark.sql.queryExecutionListeners")
11 changes: 5 additions & 6 deletions sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -93,7 +93,7 @@ class SparkSession private(
private[sql] def this(sc: SparkContext) {
this(sc, None, None,
SparkSession.applyExtensions(
sc.getConf.get(StaticSQLConf.SPARK_SESSION_EXTENSIONS),
sc.getConf.get(StaticSQLConf.SPARK_SESSION_EXTENSIONS).getOrElse(Seq.empty),
new SparkSessionExtensions))
}

@@ -950,7 +950,7 @@ object SparkSession extends Logging {
}

applyExtensions(
sparkContext.getConf.get(StaticSQLConf.SPARK_SESSION_EXTENSIONS),
sparkContext.getConf.get(StaticSQLConf.SPARK_SESSION_EXTENSIONS).getOrElse(Seq.empty),
extensions)

session = new SparkSession(sparkContext, None, None, extensions)
@@ -1138,14 +1138,13 @@
}

/**
* Initialize extensions for given extension classname. This class will be applied to the
* Initialize extensions for given extension classnames. The classes will be applied to the
* extensions passed into this function.
*/
private def applyExtensions(
extensionOption: Option[String],
extensionConfClassNames: Seq[String],
extensions: SparkSessionExtensions): SparkSessionExtensions = {
if (extensionOption.isDefined) {
val extensionConfClassName = extensionOption.get
extensionConfClassNames.foreach { extensionConfClassName =>
try {
val extensionConfClass = Utils.classForName(extensionConfClassName)
val extensionConf = extensionConfClass.getConstructor().newInstance()
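The remainder of this loop is collapsed in the diff; for context, a sketch of how the reflective pattern plausibly continues (the exception handling below is an assumption, not taken from this PR):

import org.apache.spark.sql.SparkSessionExtensions

def applyExtensionsSketch(
    classNames: Seq[String],
    extensions: SparkSessionExtensions): SparkSessionExtensions = {
  classNames.foreach { className =>
    try {
      // Class.forName stands in for Spark's Utils.classForName here.
      val instance = Class.forName(className).getConstructor().newInstance()
        .asInstanceOf[SparkSessionExtensions => Unit]
      instance(extensions) // let the extension register its injections
    } catch {
      case _: ClassNotFoundException | _: NoSuchMethodException |
           _: ClassCastException =>
        // Assumed behavior: skip (and presumably log) bad class names rather
        // than failing session construction.
    }
  }
  extensions
}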
@@ -30,41 +30,49 @@ import org.apache.spark.sql.types.{DataType, IntegerType, StructType}
*/
class SparkSessionExtensionSuite extends SparkFunSuite {
type ExtensionsBuilder = SparkSessionExtensions => Unit
private def create(builder: ExtensionsBuilder): ExtensionsBuilder = builder
private def create(builder: ExtensionsBuilder): Seq[ExtensionsBuilder] = Seq(builder)

private def stop(spark: SparkSession): Unit = {
spark.stop()
SparkSession.clearActiveSession()
SparkSession.clearDefaultSession()
}

private def withSession(builder: ExtensionsBuilder)(f: SparkSession => Unit): Unit = {
val spark = SparkSession.builder().master("local[1]").withExtensions(builder).getOrCreate()
private def withSession(builders: Seq[ExtensionsBuilder])(f: SparkSession => Unit): Unit = {
val builder = SparkSession.builder().master("local[1]")
builders.foreach(builder.withExtensions)
val spark = builder.getOrCreate()
try f(spark) finally {
stop(spark)
}
}

test("inject analyzer rule") {
withSession(_.injectResolutionRule(MyRule)) { session =>
withSession(Seq(_.injectResolutionRule(MyRule))) { session =>
assert(session.sessionState.analyzer.extendedResolutionRules.contains(MyRule(session)))
}
}

test("inject post hoc resolution analyzer rule") {
withSession(Seq(_.injectPostHocResolutionRule(MyRule))) { session =>
assert(session.sessionState.analyzer.postHocResolutionRules.contains(MyRule(session)))
}
}

test("inject check analysis rule") {
withSession(_.injectCheckRule(MyCheckRule)) { session =>
withSession(Seq(_.injectCheckRule(MyCheckRule))) { session =>
assert(session.sessionState.analyzer.extendedCheckRules.contains(MyCheckRule(session)))
}
}

test("inject optimizer rule") {
withSession(_.injectOptimizerRule(MyRule)) { session =>
withSession(Seq(_.injectOptimizerRule(MyRule))) { session =>
assert(session.sessionState.optimizer.batches.flatMap(_.rules).contains(MyRule(session)))
}
}

test("inject spark planner strategy") {
withSession(_.injectPlannerStrategy(MySparkStrategy)) { session =>
withSession(Seq(_.injectPlannerStrategy(MySparkStrategy))) { session =>
assert(session.sessionState.planner.strategies.contains(MySparkStrategy(session)))
}
}
@@ -78,6 +86,14 @@ class SparkSessionExtensionSuite extends SparkFunSuite {
}
}

test("inject multiple rules") {
withSession(Seq(_.injectOptimizerRule(MyRule),
_.injectPlannerStrategy(MySparkStrategy))) { session =>
assert(session.sessionState.optimizer.batches.flatMap(_.rules).contains(MyRule(session)))
assert(session.sessionState.planner.strategies.contains(MySparkStrategy(session)))
}
}

test("inject stacked parsers") {
val extension = create { extensions =>
extensions.injectParser((_, _) => CatalystSqlParser)
@@ -108,12 +124,86 @@
try {
assert(session.sessionState.planner.strategies.contains(MySparkStrategy(session)))
assert(session.sessionState.analyzer.extendedResolutionRules.contains(MyRule(session)))
assert(session.sessionState.analyzer.postHocResolutionRules.contains(MyRule(session)))
assert(session.sessionState.analyzer.extendedCheckRules.contains(MyCheckRule(session)))
assert(session.sessionState.optimizer.batches.flatMap(_.rules).contains(MyRule(session)))
assert(session.sessionState.sqlParser.isInstanceOf[MyParser])
assert(session.sessionState.functionRegistry
.lookupFunction(MyExtensions.myFunction._1).isDefined)
} finally {
stop(session)
}
}

test("use multiple custom class for extensions in the specified order") {
val session = SparkSession.builder()
.master("local[1]")
.config("spark.sql.extensions", Seq(
classOf[MyExtensions2].getCanonicalName,
classOf[MyExtensions].getCanonicalName).mkString(","))
.getOrCreate()
try {
assert(session.sessionState.planner.strategies.containsSlice(
Seq(MySparkStrategy2(session), MySparkStrategy(session))))
val orderedRules = Seq(MyRule2(session), MyRule(session))
val orderedCheckRules = Seq(MyCheckRule2(session), MyCheckRule(session))
val parser = MyParser(session, CatalystSqlParser)
assert(session.sessionState.analyzer.extendedResolutionRules.containsSlice(orderedRules))
assert(session.sessionState.analyzer.postHocResolutionRules.containsSlice(orderedRules))
assert(session.sessionState.analyzer.extendedCheckRules.containsSlice(orderedCheckRules))
assert(session.sessionState.optimizer.batches.flatMap(_.rules).filter(orderedRules.contains)
.containsSlice(orderedRules ++ orderedRules)) // The optimizer rules are duplicated
assert(session.sessionState.sqlParser == parser)
Contributor:

In all these asserts, use === and !==.

Member:

That's actually arguable, Vanzin. Some people prefer === whereas some prefer ==, and === doesn't always report a better error message in my tests. See also databricks/scala-style-guide#36.

Contributor Author:

Based on databricks/scala-style-guide#36, it looks like == might now be preferred over ===. For what it's worth, in the cases in this test == produces reasonable error messages, such as MyParser(org.apache.spark.sql.SparkSession@6e8a9c30,org.apache.spark.sql.catalyst.parser.CatalystSqlParser$@5d01ea21) did not equal IntentionalErrorThatIInsertedHere and 2 did not equal 3. So please let me know if there is newer guidance to use === and I can make the changes.

Contributor:

== and === are not equivalent, and we use === in tests. The latter, for example, handles arrays correctly, which the former does not.

Contributor Author:

Thanks for the follow-up; I was not aware of that difference. I updated the tests to use === and !== as originally recommended.

Member:

@vanzin, where does it say we use ===? If that is the practice, let's document it.
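A minimal illustration of the distinction raised above (a sketch assuming ScalaTest 3.x, not part of the PR): === consults ScalaTest's Equality, which compares array contents, while == on arrays compares references.

import org.scalatest.funsuite.AnyFunSuite

class TripleEqualsSketchSuite extends AnyFunSuite {
  test("=== compares array contents; == compares references") {
    assert(!(Array(1, 2) == Array(1, 2))) // distinct instances: == is false
    assert(Array(1, 2) === Array(1, 2))   // equal contents: === passes
  }
}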

assert(session.sessionState.functionRegistry
.lookupFunction(MyExtensions.myFunction._1).isDefined)
assert(session.sessionState.functionRegistry
.lookupFunction(MyExtensions2.myFunction._1).isDefined)
Member:

So, now we have multiple extension registrations. The order of extension names might have side effects.

  • Can we have a test case for duplicated extension names? MyExtension2 and MyExtension2?
  • Can we have a negative test case for function name conflicts? MyExtension2.myFunction and MyExtension3.myFunction?

Member:

I think the order matters, but we need to discuss and document the behavior when we have name conflicts.

For example, the same rule will be added twice in extendedResolutionRules. Is that desired?

class MyExtensions extends (SparkSessionExtensions => Unit) {
  def apply(e: SparkSessionExtensions): Unit = {
    e.injectResolutionRule(MyRule)
  }
}

class MyExtensions2 extends (SparkSessionExtensions => Unit) {
  def apply(e: SparkSessionExtensions): Unit = {
    e.injectResolutionRule(MyRule)
  }
}

Member:

Yep. If there is no reason to allow that, we had better disallow it by design before this PR.

Contributor:

There are use cases where you want to execute rules in a certain order, so I think it is reasonable to add the same rule multiple times. If you want more control you could even create 'micro' optimizer batches by calling multiple rules from one rule.

I think this is more a matter of proper documentation than one where we should explicitly block things. Also note that this is a pretty advanced feature, and by this stage users are expected to know what they are doing.

Contributor Author:

Prior to this change, it was possible to programmatically register multiple extensions, but it was not possible to do so through the spark.sql.extensions configuration (and it wasn't documented or tested until this pull request). For example, the following works without this pull request:

SparkSession.builder()
  .master("..")
  .withExtensions(sparkSessionExtensions1)
  .withExtensions(sparkSessionExtensions2)
  .getOrCreate()

So I think conflicting function names are already possible (but not documented) in the following cases:

  1. Conflicting function names are registered by calling .withExtensions() multiple times.
  2. An extension accidentally registers a function that was already registered as a builtin function.
  3. An extension accidentally registers a function multiple times by calling injectFunction(myFunction).

As for the order, it looks to me like the last function stored under a conflicting name is the one that is retrieved:

class SimpleFunctionRegistry extends FunctionRegistry {

  @GuardedBy("this")
  private val functionBuilders =
    new mutable.HashMap[FunctionIdentifier, (ExpressionInfo, FunctionBuilder)]

  override def registerFunction(
      name: FunctionIdentifier,
      info: ExpressionInfo,
      builder: FunctionBuilder): Unit = synchronized {
    functionBuilders.put(normalizeFuncName(name), (info, builder))
  }

I will update this PR to document the order of operations and conflict behavior. If we need to explicitly block duplicate functions from being registered, I can temporarily drop this PR and see about making those changes first.

Member:

Thanks for explaining it. We do not need to block it, but we should at least detect it and log a warning.

More importantly, we need to document the current behavior and also add a test case to ensure that future changes will not break it. In the future, we can revisit the current behavior and make a change if needed.

Contributor Author:

Can we have a test case for duplicated extension names? — Done.
Can we have a negative test case for function name conflicts? MyExtension2.myFunction and MyExtension3.myFunction? — Done.

I added documentation for the behavior, a warning message when a registered function is replaced, and a test case for the ordering.

} finally {
stop(session)
}
}

test("allow an extension to be duplicated") {
val session = SparkSession.builder()
.master("local[1]")
.config("spark.sql.extensions", Seq(
classOf[MyExtensions].getCanonicalName,
classOf[MyExtensions].getCanonicalName).mkString(","))
.getOrCreate()
try {
assert(session.sessionState.planner.strategies.count(_ == MySparkStrategy(session)) == 2)
assert(session.sessionState.analyzer.extendedResolutionRules.count(_ == MyRule(session)) == 2)
assert(session.sessionState.analyzer.postHocResolutionRules.count(_ == MyRule(session)) == 2)
assert(session.sessionState.analyzer.extendedCheckRules.count(_ == MyCheckRule(session)) == 2)
assert(session.sessionState.optimizer.batches.flatMap(_.rules)
.count(_ == MyRule(session)) == 4) // The optimizer rules are duplicated
val outerParser = session.sessionState.sqlParser
assert(outerParser.isInstanceOf[MyParser])
assert(outerParser.asInstanceOf[MyParser].delegate.isInstanceOf[MyParser])
assert(session.sessionState.functionRegistry
.lookupFunction(MyExtensions.myFunction._1).isDefined)
} finally {
stop(session)
}
}
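The delegate assertion above reflects the stacking rule documented in spark.sql.extensions: each injected parser builder receives the previously built parser and may delegate to it. A dependency-free sketch of that pattern (all names here are invented stand-ins, not Spark APIs):

object ParserStackSketch extends App {
  trait Parser { def parse(sql: String): String }
  object BaseParser extends Parser { def parse(sql: String) = s"base($sql)" }
  case class Wrapping(delegate: Parser) extends Parser {
    def parse(sql: String) = s"wrap(${delegate.parse(sql)})"
  }

  // Builders are applied in order; each sees the parser built so far, and the
  // last one built becomes the parser actually used.
  val builders: Seq[Parser => Parser] =
    Seq(_ => BaseParser, prev => Wrapping(prev))
  val parser = builders.foldLeft(BaseParser: Parser)((prev, b) => b(prev))
  assert(parser.parse("SELECT 1") == "wrap(base(SELECT 1))")
}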

test("use the last registered function name when there are duplicates") {
val session = SparkSession.builder()
.master("local[1]")
.config("spark.sql.extensions", Seq(
classOf[MyExtensions2].getCanonicalName,
classOf[MyExtensions2Duplicate].getCanonicalName).mkString(","))
.getOrCreate()
try {
val lastRegistered = session.sessionState.functionRegistry
.lookupFunction(FunctionIdentifier("myFunction2"))
assert(lastRegistered.isDefined)
assert(lastRegistered.get.getExtended != MyExtensions2.myFunction._2.getExtended)
assert(lastRegistered.get.getExtended == MyExtensions2Duplicate.myFunction._2.getExtended)
} finally {
stop(session)
}
}
}

case class MyRule(spark: SparkSession) extends Rule[LogicalPlan] {
@@ -151,14 +241,63 @@ case class MyParser(spark: SparkSession, delegate: ParserInterface) extends ParserInterface {
object MyExtensions {

val myFunction = (FunctionIdentifier("myFunction"),
new ExpressionInfo("noClass", "myDb", "myFunction", "usage", "extended usage" ),
new ExpressionInfo("noClass", "myDb", "myFunction", "usage", "extended usage"),
(myArgs: Seq[Expression]) => Literal(5, IntegerType))
Member:

nit: (_: Seq[Expression]) => Literal(5, IntegerType))

Contributor Author:

Done

Member:

I would remove this newline. Looks unrelated.

Contributor Author:

Done
}

class MyExtensions extends (SparkSessionExtensions => Unit) {
def apply(e: SparkSessionExtensions): Unit = {
e.injectPlannerStrategy(MySparkStrategy)
e.injectResolutionRule(MyRule)
e.injectPostHocResolutionRule(MyRule)
e.injectCheckRule(MyCheckRule)
e.injectOptimizerRule(MyRule)
e.injectParser(MyParser)
e.injectFunction(MyExtensions.myFunction)
}
}

case class MyRule2(spark: SparkSession) extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = plan
}

case class MyCheckRule2(spark: SparkSession) extends (LogicalPlan => Unit) {
override def apply(plan: LogicalPlan): Unit = { }
}

case class MySparkStrategy2(spark: SparkSession) extends SparkStrategy {
override def apply(plan: LogicalPlan): Seq[SparkPlan] = Seq.empty
}

object MyExtensions2 {

val myFunction = (FunctionIdentifier("myFunction2"),
new ExpressionInfo("noClass", "myDb", "myFunction2", "usage", "extended usage" ),
Member:

nit: " ) -> ")

Contributor Author:

Done
(_: Seq[Expression]) => Literal(5, IntegerType))
}

class MyExtensions2 extends (SparkSessionExtensions => Unit) {
def apply(e: SparkSessionExtensions): Unit = {
e.injectPlannerStrategy(MySparkStrategy2)
e.injectResolutionRule(MyRule2)
e.injectPostHocResolutionRule(MyRule2)
e.injectCheckRule(MyCheckRule2)
e.injectOptimizerRule(MyRule2)
e.injectParser((_, _) => CatalystSqlParser)
Member:

nit: e.injectParser((_: SparkSession, _: ParserInterface) => CatalystSqlParser)

Contributor Author:

I also made the suggested change in 2 other places in this file so that it is consistent.
e.injectFunction(MyExtensions2.myFunction)
}
}

object MyExtensions2Duplicate {

val myFunction = (FunctionIdentifier("myFunction2"),
new ExpressionInfo("noClass", "myDb", "myFunction2", "usage", "last wins" ),
Member:

nit: "last wins" -> "extended usage"

Member:

nit: " ) -> ")

Contributor Author:

I made both changes and also updated one of the tests to validate the ExpressionInfo object rather than the extended usage text.

(_: Seq[Expression]) => Literal(5, IntegerType))
}

class MyExtensions2Duplicate extends (SparkSessionExtensions => Unit) {
def apply(e: SparkSessionExtensions): Unit = {
e.injectFunction(MyExtensions2Duplicate.myFunction)
}
}