Skip to content

Commit 0db2826

Browse files
committed
add isRegisteredFunction check
1 parent 8dceda9 commit 0db2826

File tree

1 file changed

+45
-8
lines changed

1 file changed

+45
-8
lines changed

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/LookupFunctionsSuite.scala

Lines changed: 45 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ import org.apache.spark.sql.internal.SQLConf
2828

2929
class LookupFunctionsSuite extends PlanTest {
3030

31-
test("SPARK-23486: LookupFunctions should not check the same function name more than once") {
31+
test("SPARK-23486: the functionExists for the Persistent function check") {
3232
val externalCatalog = new CustomInMemoryCatalog
3333
val conf = new SQLConf()
3434
val catalog = new SessionCatalog(externalCatalog, FunctionRegistry.builtin, conf)
@@ -40,20 +40,57 @@ class LookupFunctionsSuite extends PlanTest {
4040
}
4141

4242
def table(ref: String): LogicalPlan = UnresolvedRelation(TableIdentifier(ref))
43-
val unresolvedFunc = UnresolvedFunction("func", Seq.empty, false)
43+
val unresolvedPersistentFunc = UnresolvedFunction("func", Seq.empty, false)
44+
val unresolvedRegisteredFunc = UnresolvedFunction("max", Seq.empty, false)
4445
val plan = Project(
45-
Seq(Alias(unresolvedFunc, "call1")(), Alias(unresolvedFunc, "call2")(),
46-
Alias(unresolvedFunc, "call1")()),
46+
Seq(Alias(unresolvedPersistentFunc, "call1")(), Alias(unresolvedPersistentFunc, "call2")(),
47+
Alias(unresolvedPersistentFunc, "call3")(), Alias(unresolvedRegisteredFunc, "call4")(),
48+
Alias(unresolvedRegisteredFunc, "call5")()),
4749
table("TaBlE"))
4850
analyzer.LookupFunctions.apply(plan)
49-
assert(externalCatalog.getFunctionExistsCalledTimes == 1)
51+
assert(externalCatalog.getFunctionExistsCalledTimes == 1)
5052

5153
assert(analyzer.LookupFunctions.normalizeFuncName
52-
(unresolvedFunc.name).database == Some("default"))
53-
assert(catalog.isRegisteredFunction(unresolvedFunc.name) == false)
54-
assert(catalog.isRegisteredFunction(FunctionIdentifier("max")) == true)
54+
(unresolvedPersistentFunc.name).database == Some("default"))
55+
}
56+
57+
test("SPARK-23486: the functionExists for the Registered function check") {
58+
59+
val externalCatalog = new InMemoryCatalog
60+
val conf = new SQLConf()
61+
val customerFunctionReg = new CustomerFunctionRegistry
62+
val catalog = new SessionCatalog(externalCatalog, customerFunctionReg, conf)
63+
val analyzer = {
64+
catalog.createDatabase(
65+
CatalogDatabase("default", "", new URI("loc"), Map.empty),
66+
ignoreIfExists = false)
67+
new Analyzer(catalog, conf)
68+
}
69+
70+
def table(ref: String): LogicalPlan = UnresolvedRelation(TableIdentifier(ref))
71+
val unresolvedRegisteredFunc = UnresolvedFunction("max", Seq.empty, false)
72+
val plan = Project(
73+
Seq(Alias(unresolvedRegisteredFunc, "call1")(), Alias(unresolvedRegisteredFunc, "call2")()),
74+
table("TaBlE"))
75+
analyzer.LookupFunctions.apply(plan)
76+
assert(customerFunctionReg.getIsRegisteredFunctionCalledTimes == 2)
77+
78+
assert(analyzer.LookupFunctions.normalizeFuncName
79+
(unresolvedRegisteredFunc.name).database == Some("default"))
80+
81+
}
82+
}
83+
84+
class CustomerFunctionRegistry extends SimpleFunctionRegistry {
5585

86+
private var isRegisteredFunctionCalledTimes: Int = 0;
87+
88+
override def functionExists(funcN: FunctionIdentifier): Boolean = synchronized {
89+
isRegisteredFunctionCalledTimes = isRegisteredFunctionCalledTimes + 1
90+
true
5691
}
92+
93+
def getIsRegisteredFunctionCalledTimes: Int = isRegisteredFunctionCalledTimes
5794
}
5895

5996
class CustomInMemoryCatalog extends InMemoryCatalog {

0 commit comments

Comments
 (0)