Skip to content

Commit 5181543

Browse files
committed
[SPARK-35380][SQL] Loading SparkSessionExtensions from ServiceLoader
### What changes were proposed in this pull request? In yaooqinn/itachi#8, we had a discussion about the current extension injection for the Spark session. We agreed that the current way is not that convenient for either third-party developers or end-users. It's much simpler if third-party developers can provide a resource file that contains default extensions for Spark to load ahead of time. ### Why are the changes needed? Better user experience. ### Does this PR introduce _any_ user-facing change? No, dev change only. ### How was this patch tested? New tests. Closes #32515 from yaooqinn/SPARK-35380. Authored-by: Kent Yao <[email protected]> Signed-off-by: Kent Yao <[email protected]>
1 parent dd54649 commit 5181543

File tree

10 files changed

+276
-5
lines changed

10 files changed

+276
-5
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
org.apache.spark.examples.extensions.SessionExtensionsWithLoader
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.examples.extensions
19+
20+
import org.apache.spark.sql.catalyst.expressions.{CurrentDate, Expression, RuntimeReplaceable, SubtractDates}
21+
22+
/**
 * How old are you in days?
 *
 * As a [[RuntimeReplaceable]], this expression is never evaluated directly: the analyzer
 * swaps it for its `child` replacement, here `SubtractDates(CurrentDate(), birthday)`.
 *
 * @param birthday the date-of-birth expression, kept for SQL/string representation
 * @param child    the replacement expression that is actually evaluated at runtime
 */
case class AgeExample(birthday: Expression, child: Expression) extends RuntimeReplaceable {

  // Convenience constructor that derives the runtime replacement from the birthday.
  def this(birthday: Expression) = this(birthday, SubtractDates(CurrentDate(), birthday))
  override def exprsReplaced: Seq[Expression] = Seq(birthday)

  // `newChild` is the transformed replacement expression, so it must be written into the
  // `child` slot. A positional `copy(newChild)` would wrongly overwrite `birthday` instead.
  override protected def withNewChildInternal(newChild: Expression): Expression =
    copy(child = newChild)
}
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.examples.extensions
19+
20+
import org.apache.spark.sql.{SparkSessionExtensions, SparkSessionExtensionsProvider}
21+
import org.apache.spark.sql.catalyst.FunctionIdentifier
22+
import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo}
23+
24+
/**
 * Example extension registering the `age_two` function. This provider is picked up via the
 * JDK service-loader mechanism (it is listed in the accompanying
 * META-INF/services/org.apache.spark.sql.SparkSessionExtensionsProvider resource file).
 */
class SessionExtensionsWithLoader extends SparkSessionExtensionsProvider {
  override def apply(v1: SparkSessionExtensions): Unit = {
    // Build the (identifier, info, builder) triple expected by injectFunction.
    val identifier = new FunctionIdentifier("age_two")
    val info = new ExpressionInfo(classOf[AgeExample].getName, "age_two")
    val builder = (children: Seq[Expression]) => new AgeExample(children.head)
    v1.injectFunction((identifier, info, builder))
  }
}
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.examples.extensions
19+
20+
import org.apache.spark.sql.{SparkSessionExtensions, SparkSessionExtensionsProvider}
21+
import org.apache.spark.sql.catalyst.FunctionIdentifier
22+
import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo}
23+
24+
/**
 * Example extension registering the `age_one` function. Unlike its service-loader sibling,
 * this provider is activated through the `spark.sql.extensions` configuration.
 */
class SessionExtensionsWithoutLoader extends SparkSessionExtensionsProvider {
  override def apply(v1: SparkSessionExtensions): Unit = {
    // Build the (identifier, info, builder) triple expected by injectFunction.
    val identifier = new FunctionIdentifier("age_one")
    val info = new ExpressionInfo(classOf[AgeExample].getName, "age_one")
    val builder = (children: Seq[Expression]) => new AgeExample(children.head)
    v1.injectFunction((identifier, info, builder))
  }
}
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.examples.extensions
19+
20+
import org.apache.spark.sql.SparkSession
21+
22+
/**
 * End-to-end example exercising both extension registration mechanisms:
 *
 * [[SessionExtensionsWithLoader]] is registered in
 * src/main/resources/META-INF/services/org.apache.spark.sql.SparkSessionExtensionsProvider
 *
 * [[SessionExtensionsWithoutLoader]] is registered via spark.sql.extensions
 */
object SparkSessionExtensionsTest {

  def main(args: Array[String]): Unit = {
    val spark = SparkSession
      .builder()
      .appName("SparkSessionExtensionsTest")
      .config("spark.sql.extensions", classOf[SessionExtensionsWithoutLoader].getName)
      .getOrCreate()
    try {
      // Both functions should resolve: age_one from the config-registered extension,
      // age_two from the service-loader-registered one.
      spark.sql("SELECT age_one('2018-11-17'), age_two('2018-11-17')").show()
    } finally {
      // Always release the SparkContext and its resources, even if the query fails.
      spark.stop()
    }
  }
}

sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
package org.apache.spark.sql
1919

2020
import java.io.Closeable
21-
import java.util.UUID
21+
import java.util.{ServiceLoader, UUID}
2222
import java.util.concurrent.TimeUnit._
2323
import java.util.concurrent.atomic.{AtomicBoolean, AtomicReference}
2424

@@ -949,6 +949,7 @@ object SparkSession extends Logging {
949949
// Do not update `SparkConf` for existing `SparkContext`, as it's shared by all sessions.
950950
}
951951

952+
loadExtensions(extensions)
952953
applyExtensions(
953954
sparkContext.getConf.get(StaticSQLConf.SPARK_SESSION_EXTENSIONS).getOrElse(Seq.empty),
954955
extensions)
@@ -1203,4 +1204,22 @@ object SparkSession extends Logging {
12031204
}
12041205
extensions
12051206
}
1207+
1208+
/**
 * Discover [[SparkSessionExtensionsProvider]] implementations via [[ServiceLoader]]
 * (using the context or Spark class loader) and apply each one to the given
 * [[SparkSessionExtensions]]. A provider that fails to load or apply is logged
 * and skipped, so one broken extension cannot prevent session creation.
 */
private def loadExtensions(extensions: SparkSessionExtensions): Unit = {
  val loader = ServiceLoader.load(classOf[SparkSessionExtensionsProvider],
    Utils.getContextOrSparkClassLoader)
  val loadedExts = loader.iterator()

  while (loadedExts.hasNext) {
    try {
      // next() itself may throw (e.g. ServiceConfigurationError for a bad registration),
      // so it must stay inside the try.
      val ext = loadedExts.next()
      ext(extensions)
    } catch {
      // NonFatal keeps the best-effort behavior for broken providers while letting
      // truly fatal errors (OutOfMemoryError, InterruptedException, ...) propagate.
      case e if scala.util.control.NonFatal(e) =>
        logWarning("Failed to load session extension", e)
    }
  }
}
12061225
}

sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ import org.apache.spark.sql.execution.{ColumnarRule, SparkPlan}
7171
* {{{
7272
* SparkSession.builder()
7373
* .master("...")
74-
* .config("spark.sql.extensions", "org.example.MyExtensions")
74+
* .config("spark.sql.extensions", "org.example.MyExtensions,org.example.YourExtensions")
7575
* .getOrCreate()
7676
*
7777
* class MyExtensions extends Function1[SparkSessionExtensions, Unit] {
@@ -84,6 +84,15 @@ import org.apache.spark.sql.execution.{ColumnarRule, SparkPlan}
8484
* }
8585
* }
8686
* }
87+
*
88+
* class YourExtensions extends SparkSessionExtensionsProvider {
89+
* override def apply(extensions: SparkSessionExtensions): Unit = {
90+
* extensions.injectResolutionRule { session =>
91+
* ...
92+
* }
93+
* extensions.injectFunction(...)
94+
* }
95+
* }
8796
* }}}
8897
*
8998
* Note that none of the injected builders should assume that the [[SparkSession]] is fully
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql

import org.apache.spark.annotation.{DeveloperApi, Since, Unstable}

// scalastyle:off line.size.limit
/**
 * :: Unstable ::
 *
 * Base trait for implementations used by [[SparkSessionExtensions]].
 *
 * For example, suppose we have an external function named `Age` that we want to register
 * as an extension for SparkSession:
 *
 * {{{
 *   package org.apache.spark.examples.extensions
 *
 *   import org.apache.spark.sql.catalyst.expressions.{CurrentDate, Expression, RuntimeReplaceable, SubtractDates}
 *
 *   case class Age(birthday: Expression, child: Expression) extends RuntimeReplaceable {
 *
 *     def this(birthday: Expression) = this(birthday, SubtractDates(CurrentDate(), birthday))
 *     override def exprsReplaced: Seq[Expression] = Seq(birthday)
 *     override protected def withNewChildInternal(newChild: Expression): Expression = copy(newChild)
 *   }
 * }}}
 *
 * We then create our extension by implementing [[SparkSessionExtensionsProvider]].
 * Example:
 *
 * {{{
 *   package org.apache.spark.examples.extensions
 *
 *   import org.apache.spark.sql.{SparkSessionExtensions, SparkSessionExtensionsProvider}
 *   import org.apache.spark.sql.catalyst.FunctionIdentifier
 *   import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo}
 *
 *   class MyExtensions extends SparkSessionExtensionsProvider {
 *     override def apply(v1: SparkSessionExtensions): Unit = {
 *       v1.injectFunction(
 *         (new FunctionIdentifier("age"),
 *           new ExpressionInfo(classOf[Age].getName, "age"),
 *           (children: Seq[Expression]) => new Age(children.head)))
 *     }
 *   }
 * }}}
 *
 * Finally, `MyExtensions` can be injected in any of three ways:
 * <ul>
 *   <li>withExtensions of [[SparkSession.Builder]]</li>
 *   <li>Config - spark.sql.extensions</li>
 *   <li>[[java.util.ServiceLoader]] - Add to src/main/resources/META-INF/services/org.apache.spark.sql.SparkSessionExtensionsProvider</li>
 * </ul>
 *
 * @see [[SparkSessionExtensions]]
 * @see [[SparkSession.Builder]]
 * @see [[java.util.ServiceLoader]]
 *
 * @since 3.2.0
 */
@DeveloperApi
@Unstable
@Since("3.2.0")
trait SparkSessionExtensionsProvider extends Function1[SparkSessionExtensions, Unit]
// scalastyle:on line.size.limit
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
org.apache.spark.sql.YourExtensions

sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,16 +46,17 @@ import org.apache.spark.unsafe.types.UTF8String
4646
* Test cases for the [[SparkSessionExtensions]].
4747
*/
4848
class SparkSessionExtensionSuite extends SparkFunSuite {
49-
type ExtensionsBuilder = SparkSessionExtensions => Unit
50-
private def create(builder: ExtensionsBuilder): Seq[ExtensionsBuilder] = Seq(builder)
49+
private def create(
50+
builder: SparkSessionExtensionsProvider): Seq[SparkSessionExtensionsProvider] = Seq(builder)
5151

5252
private def stop(spark: SparkSession): Unit = {
5353
spark.stop()
5454
SparkSession.clearActiveSession()
5555
SparkSession.clearDefaultSession()
5656
}
5757

58-
private def withSession(builders: Seq[ExtensionsBuilder])(f: SparkSession => Unit): Unit = {
58+
private def withSession(
59+
builders: Seq[SparkSessionExtensionsProvider])(f: SparkSession => Unit): Unit = {
5960
val builder = SparkSession.builder().master("local[1]")
6061
builders.foreach(builder.withExtensions)
6162
val spark = builder.getOrCreate()
@@ -355,6 +356,20 @@ class SparkSessionExtensionSuite extends SparkFunSuite {
355356
stop(session)
356357
}
357358
}
359+
360+
test("SPARK-35380: Loading extensions from ServiceLoader") {
  val builder = SparkSession.builder().master("local[1]")

  // First iteration (None): `YourExtensions` must be found purely via the
  // service-loader registration (see the accompanying META-INF/services resource file).
  // Second iteration: also naming it in spark.sql.extensions must still work —
  // registering the same function through both mechanisms is harmless.
  Seq(None, Some(classOf[YourExtensions].getName)).foreach { ext =>
    ext.foreach(builder.config(SPARK_SESSION_EXTENSIONS.key, _))
    val session = builder.getOrCreate()
    try {
      // get_fake_app_name() is the constant function injected by YourExtensions.
      assert(session.sql("select get_fake_app_name()").head().getString(0) === "Fake App Name")
    } finally {
      // Stop and clear the session so the next iteration builds a fresh one.
      stop(session)
    }
  }
}
358373
}
359374

360375
case class MyRule(spark: SparkSession) extends Rule[LogicalPlan] {
@@ -959,3 +974,16 @@ class MyExtensions2Duplicate extends (SparkSessionExtensions => Unit) {
959974
e.injectFunction(MyExtensions2Duplicate.myFunction)
960975
}
961976
}
977+
978+
/**
 * Test fixture provider: injects a zero-argument `get_fake_app_name` function that
 * always returns the literal string "Fake App Name".
 */
class YourExtensions extends SparkSessionExtensionsProvider {
  // The (identifier, info, builder) triple handed to injectFunction.
  val getAppName = {
    val identifier = FunctionIdentifier("get_fake_app_name")
    val info = new ExpressionInfo(
      "zzz.zzz.zzz",
      "",
      "get_fake_app_name")
    val builder = (_: Seq[Expression]) => Literal("Fake App Name")
    (identifier, info, builder)
  }

  override def apply(v1: SparkSessionExtensions): Unit = v1.injectFunction(getAppName)
}

0 commit comments

Comments
 (0)