Address comments from dongjoon-hyun, hvanhovell, beliefer, gatorsmile
This addresses the comments for #23398
jamisonbennett committed Dec 29, 2018
commit 689a4d203dc2dec68c6b97c7d6529e90b094d31e
@@ -25,6 +25,7 @@ import scala.language.existentials
import scala.reflect.ClassTag
import scala.util.{Failure, Success, Try}

+import org.apache.spark.internal.Logging
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.FunctionIdentifier
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
@@ -87,7 +88,7 @@ trait FunctionRegistry {
override def clone(): FunctionRegistry = throw new CloneNotSupportedException()
}

-class SimpleFunctionRegistry extends FunctionRegistry {
+class SimpleFunctionRegistry extends FunctionRegistry with Logging {

@GuardedBy("this")
private val functionBuilders =
@@ -103,7 +104,10 @@ class SimpleFunctionRegistry extends FunctionRegistry {
name: FunctionIdentifier,
info: ExpressionInfo,
builder: FunctionBuilder): Unit = synchronized {
-    functionBuilders.put(normalizeFuncName(name), (info, builder))
+    val normalizedName = normalizeFuncName(name)
+    if (functionBuilders.put(normalizedName, (info, builder)).isDefined) {
Contributor:
It would be great if we can check whether the new function and the old function are different. This will help increase the signal of the error message.

Contributor Author:
I added a check which will only log if different function objects are registered. The "allow an extension to be duplicated" unit test that I previously added registers the same object twice; it no longer prints the warning. The "use the last registered function name when there are duplicates" unit test that I previously added registers different functions with the same name; it does print the warning.

logWarning(s"The function $normalizedName replaced a previously registered function.")
}
}

override def lookupFunction(name: FunctionIdentifier, children: Seq[Expression]): Expression = {
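As background for the new check: mutable.HashMap.put returns the previously bound value as an Option, and the isDefined test above keys on exactly that. A minimal standalone sketch of the contract (illustrative only, not part of the diff):

import scala.collection.mutable

object PutContractSketch extends App {
  val builders = mutable.HashMap[String, Int]()
  // First registration: no prior entry, put returns None, so no warning fires.
  assert(builders.put("myFunction", 1).isEmpty)
  // Re-registration under the same key: put returns Some(1), i.e. isDefined,
  // which is the condition under which registerFunction now logs its warning.
  assert(builders.put("myFunction", 2).contains(1))
}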
@@ -99,8 +99,13 @@ object StaticSQLConf {
.createWithDefault(false)

val SPARK_SESSION_EXTENSIONS = buildStaticConf("spark.sql.extensions")
.doc("List of the class names used to configure Spark Session extensions. The classes should " +
"implement Function1[SparkSessionExtension, Unit], and must have a no-args constructor.")
.doc("A comma-separated list of classes that implement " +
"Function1[SparkSessionExtension, Unit] used to configure Spark Session extensions. The " +
"classes must have a no-args constructor. If multiple extensions are specified, they are " +
"applied in the specified order. For the case of rules and planner strategies, they are " +
"applied in the specified order. For the case of parsers, the last parser is used and each " +
"parser can delegate to its predecessor. For the case of function name conflicts, the last " +
"registered function name is used.")
.stringConf
.toSequence
.createOptional
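A minimal sketch of how the documented ordering plays out from user code; the two extension class names are hypothetical placeholders, not real classes:

import org.apache.spark.sql.SparkSession

// com.example.FirstExtensions and com.example.SecondExtensions are
// hypothetical implementations of Function1[SparkSessionExtensions, Unit].
val spark = SparkSession.builder()
  .master("local[1]")
  .config("spark.sql.extensions",
    "com.example.FirstExtensions,com.example.SecondExtensions")
  .getOrCreate()
// Per the doc above: rules and planner strategies from FirstExtensions apply
// before those from SecondExtensions; if both inject a parser, the parser from
// SecondExtensions is used and may delegate to FirstExtensions' parser; if
// both register a function under one name, the later registration wins.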
@@ -53,6 +53,12 @@ class SparkSessionExtensionSuite extends SparkFunSuite {
}
}

test("inject post hoc resolution analyzer rule") {
withSession(Seq(_.injectPostHocResolutionRule(MyRule))) { session =>
assert(session.sessionState.analyzer.postHocResolutionRules.contains(MyRule(session)))
}
}

test("inject check analysis rule") {
withSession(Seq(_.injectCheckRule(MyCheckRule))) { session =>
assert(session.sessionState.analyzer.extendedCheckRules.contains(MyCheckRule(session)))
@@ -118,23 +124,36 @@ class SparkSessionExtensionSuite extends SparkFunSuite {
try {
assert(session.sessionState.planner.strategies.contains(MySparkStrategy(session)))
assert(session.sessionState.analyzer.extendedResolutionRules.contains(MyRule(session)))
+      assert(session.sessionState.analyzer.postHocResolutionRules.contains(MyRule(session)))
assert(session.sessionState.analyzer.extendedCheckRules.contains(MyCheckRule(session)))
assert(session.sessionState.optimizer.batches.flatMap(_.rules).contains(MyRule(session)))
assert(session.sessionState.sqlParser.isInstanceOf[MyParser])
assert(session.sessionState.functionRegistry
.lookupFunction(MyExtensions.myFunction._1).isDefined)
} finally {
stop(session)
}
}

test("use multiple custom class for extensions") {
test("use multiple custom class for extensions in the specified order") {
val session = SparkSession.builder()
.master("local[1]")
.config("spark.sql.extensions", Seq(
-      classOf[MyExtensions].getCanonicalName,
-      classOf[MyExtensions2].getCanonicalName).mkString(","))
+      classOf[MyExtensions2].getCanonicalName,
+      classOf[MyExtensions].getCanonicalName).mkString(","))
.getOrCreate()
try {
-    assert(session.sessionState.planner.strategies.contains(MySparkStrategy(session)))
-    assert(session.sessionState.analyzer.extendedResolutionRules.contains(MyRule(session)))
+    assert(session.sessionState.planner.strategies.containsSlice(
+      Seq(MySparkStrategy2(session), MySparkStrategy(session))))
+    val orderedRules = Seq(MyRule2(session), MyRule(session))
+    val orderedCheckRules = Seq(MyCheckRule2(session), MyCheckRule(session))
+    val parser = MyParser(session, CatalystSqlParser)
+    assert(session.sessionState.analyzer.extendedResolutionRules.containsSlice(orderedRules))
+    assert(session.sessionState.analyzer.postHocResolutionRules.containsSlice(orderedRules))
+    assert(session.sessionState.analyzer.extendedCheckRules.containsSlice(orderedCheckRules))
+    assert(session.sessionState.optimizer.batches.flatMap(_.rules).filter(orderedRules.contains)
+      .containsSlice(orderedRules ++ orderedRules)) // The optimizer rules are duplicated
+    assert(session.sessionState.sqlParser == parser)
Contributor:
In all these asserts, use === and !==.

Member:
That's actually arguable, Vanzin. Some people prefer === whereas some prefer ==, and === doesn't always report a better error message, given my tests. See also databricks/scala-style-guide#36.

Contributor Author:
Based on databricks/scala-style-guide#36, it looks like == might now be preferred over ===. For what it's worth, in the cases in this test == produces reasonable error messages, such as "MyParser(org.apache.spark.sql.SparkSession@6e8a9c30,org.apache.spark.sql.catalyst.parser.CatalystSqlParser$@5d01ea21) did not equal IntentionalErrorThatIInsertedHere" and "2 did not equal 3". So please let me know if there is newer guidance to use === and I can make the changes.

Contributor:
== and === are not equivalent, and we use === in tests. The latter, for example, handles arrays correctly, which the former does not.

Contributor Author:
Thanks for the follow-up; I was not aware of that difference. I updated the tests to use === and !== as originally recommended.

Member:
@vanzin, where does it say we use ===? If in practice the answer is yes, let's document it.
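To make the array point concrete, a minimal ScalaTest sketch (illustrative, not part of this diff): === goes through Scalactic's default Equality, which compares array contents, while Scala's == on arrays is reference equality.

import org.scalatest.FunSuite

class ArrayEqualitySuite extends FunSuite {
  test("=== compares array contents, == compares references") {
    // Passes: ScalaTest's === inspects the elements of both arrays.
    assert(Array(1, 2, 3) === Array(1, 2, 3))
    // Passes: == on two distinct array instances is reference equality.
    assert(!(Array(1, 2, 3) == Array(1, 2, 3)))
  }
}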

assert(session.sessionState.functionRegistry
.lookupFunction(MyExtensions.myFunction._1).isDefined)
assert(session.sessionState.functionRegistry
@@ -143,6 +162,48 @@ class SparkSessionExtensionSuite extends SparkFunSuite {
stop(session)
}
}

test("allow an extension to be duplicated") {
val session = SparkSession.builder()
.master("local[1]")
.config("spark.sql.extensions", Seq(
classOf[MyExtensions].getCanonicalName,
classOf[MyExtensions].getCanonicalName).mkString(","))
.getOrCreate()
try {
assert(session.sessionState.planner.strategies.count(_ == MySparkStrategy(session)) == 2)
assert(session.sessionState.analyzer.extendedResolutionRules.count(_ == MyRule(session)) == 2)
assert(session.sessionState.analyzer.postHocResolutionRules.count(_ == MyRule(session)) == 2)
assert(session.sessionState.analyzer.extendedCheckRules.count(_ == MyCheckRule(session)) == 2)
assert(session.sessionState.optimizer.batches.flatMap(_.rules)
.count(_ == MyRule(session)) == 4) // The optimizer rules are duplicated
val outerParser = session.sessionState.sqlParser
assert(outerParser.isInstanceOf[MyParser])
assert(outerParser.asInstanceOf[MyParser].delegate.isInstanceOf[MyParser])
assert(session.sessionState.functionRegistry
.lookupFunction(MyExtensions.myFunction._1).isDefined)
} finally {
stop(session)
}
}

test("use the last registered function name when there are duplicates") {
val session = SparkSession.builder()
.master("local[1]")
.config("spark.sql.extensions", Seq(
classOf[MyExtensions2].getCanonicalName,
classOf[MyExtensions2Duplicate].getCanonicalName).mkString(","))
.getOrCreate()
try {
val lastRegistered = session.sessionState.functionRegistry
.lookupFunction(FunctionIdentifier("myFunction2"))
assert(lastRegistered.isDefined)
assert(lastRegistered.get.getExtended != MyExtensions2.myFunction._2.getExtended)
assert(lastRegistered.get.getExtended == MyExtensions2Duplicate.myFunction._2.getExtended)
} finally {
stop(session)
}
}
}

case class MyRule(spark: SparkSession) extends Rule[LogicalPlan] {
@@ -189,19 +250,54 @@ class MyExtensions extends (SparkSessionExtensions => Unit) {
def apply(e: SparkSessionExtensions): Unit = {
e.injectPlannerStrategy(MySparkStrategy)
e.injectResolutionRule(MyRule)
+    e.injectPostHocResolutionRule(MyRule)
e.injectCheckRule(MyCheckRule)
e.injectOptimizerRule(MyRule)
e.injectParser(MyParser)
e.injectFunction(MyExtensions.myFunction)
}
}

case class MyRule2(spark: SparkSession) extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = plan
}

case class MyCheckRule2(spark: SparkSession) extends (LogicalPlan => Unit) {
override def apply(plan: LogicalPlan): Unit = { }
}

case class MySparkStrategy2(spark: SparkSession) extends SparkStrategy {
override def apply(plan: LogicalPlan): Seq[SparkPlan] = Seq.empty
}

object MyExtensions2 {

val myFunction = (FunctionIdentifier("myFunction2"),
new ExpressionInfo("noClass", "myDb", "myFunction2", "usage", "extended usage" ),
Member:
nit: " ) -> ")

Contributor Author:
Done

-    (myArgs: Seq[Expression]) => Literal(5, IntegerType))
+    (_: Seq[Expression]) => Literal(5, IntegerType))
}

class MyExtensions2 extends (SparkSessionExtensions => Unit) {
def apply(e: SparkSessionExtensions): Unit = {
e.injectPlannerStrategy(MySparkStrategy2)
e.injectResolutionRule(MyRule2)
e.injectPostHocResolutionRule(MyRule2)
e.injectCheckRule(MyCheckRule2)
e.injectOptimizerRule(MyRule2)
e.injectParser((_, _) => CatalystSqlParser)
Member:
nit: e.injectParser((_: SparkSession, _: ParserInterface) => CatalystSqlParser)

Contributor Author:
I also made the suggested change in 2 other places in this file so that it is consistent.

e.injectFunction(MyExtensions2.myFunction)
}
}

object MyExtensions2Duplicate {

val myFunction = (FunctionIdentifier("myFunction2"),
new ExpressionInfo("noClass", "myDb", "myFunction2", "usage", "last wins" ),
Member:
nit: "last wins" -> "extended usage"

Member:
nit: " ) -> ")

Contributor Author:
I made both changes and also updated one of the tests to validate the ExpressionInfo object rather than the extended usage text.

(_: Seq[Expression]) => Literal(5, IntegerType))
}

class MyExtensions2Duplicate extends (SparkSessionExtensions => Unit) {
def apply(e: SparkSessionExtensions): Unit = {
e.injectFunction(MyExtensions2Duplicate.myFunction)
}
}