Skip to content
Prev Previous commit
Next Next commit
add isRegisteredFunction check
  • Loading branch information
kevinyu98 committed Jul 12, 2018
commit 0db2826e4a7fb6a9e3f435daa33d8cd2080c7929
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import org.apache.spark.sql.internal.SQLConf

class LookupFunctionsSuite extends PlanTest {

test("SPARK-23486: LookupFunctions should not check the same function name more than once") {
test("SPARK-23486: the functionExists for the Persistent function check") {
val externalCatalog = new CustomInMemoryCatalog
val conf = new SQLConf()
val catalog = new SessionCatalog(externalCatalog, FunctionRegistry.builtin, conf)
Expand All @@ -40,20 +40,57 @@ class LookupFunctionsSuite extends PlanTest {
}

def table(ref: String): LogicalPlan = UnresolvedRelation(TableIdentifier(ref))
val unresolvedFunc = UnresolvedFunction("func", Seq.empty, false)
val unresolvedPersistentFunc = UnresolvedFunction("func", Seq.empty, false)
val unresolvedRegisteredFunc = UnresolvedFunction("max", Seq.empty, false)
val plan = Project(
Seq(Alias(unresolvedFunc, "call1")(), Alias(unresolvedFunc, "call2")(),
Alias(unresolvedFunc, "call1")()),
Seq(Alias(unresolvedPersistentFunc, "call1")(), Alias(unresolvedPersistentFunc, "call2")(),
Alias(unresolvedPersistentFunc, "call3")(), Alias(unresolvedRegisteredFunc, "call4")(),
Alias(unresolvedRegisteredFunc, "call5")()),
table("TaBlE"))
analyzer.LookupFunctions.apply(plan)
assert(externalCatalog.getFunctionExistsCalledTimes == 1)
assert(externalCatalog.getFunctionExistsCalledTimes == 1)

assert(analyzer.LookupFunctions.normalizeFuncName
(unresolvedFunc.name).database == Some("default"))
assert(catalog.isRegisteredFunction(unresolvedFunc.name) == false)
assert(catalog.isRegisteredFunction(FunctionIdentifier("max")) == true)
(unresolvedPersistentFunc.name).database == Some("default"))
}

// Runs LookupFunctions over a plan that references the same "max" function twice
// and checks how many times the counting function registry was consulted.
test("SPARK-23486: the functionExists for the Registered function check") {
  // Session catalog wired to a plain in-memory external catalog and a registry
  // that counts functionExists calls.
  val inMemoryCatalog = new InMemoryCatalog
  val sqlConf = new SQLConf()
  val countingRegistry = new CustomerFunctionRegistry
  val sessionCatalog = new SessionCatalog(inMemoryCatalog, countingRegistry, sqlConf)

  // Create the "default" database, then build the analyzer on top of the catalog.
  sessionCatalog.createDatabase(
    CatalogDatabase("default", "", new URI("loc"), Map.empty),
    ignoreIfExists = false)
  val analyzer = new Analyzer(sessionCatalog, sqlConf)

  // Two aliases of one unresolved call, projected over an unresolved relation.
  val maxCall = UnresolvedFunction("max", Seq.empty, false)
  val projectList = Seq(
    Alias(maxCall, "call1")(),
    Alias(maxCall, "call2")())
  val plan = Project(projectList, UnresolvedRelation(TableIdentifier("TaBlE")))

  analyzer.LookupFunctions.apply(plan)

  // NOTE(review): the expected count of 2 presumably reflects one registry check per
  // occurrence of the function in the plan -- confirm against LookupFunctions.
  assert(countingRegistry.getIsRegisteredFunctionCalledTimes == 2)

  // The normalized identifier must carry the "default" database.
  val normalizedName = analyzer.LookupFunctions.normalizeFuncName(maxCall.name)
  assert(normalizedName.database == Some("default"))
}
}

class CustomerFunctionRegistry extends SimpleFunctionRegistry {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@kevinyu98 Instead of extending FunctionRegistry and Catalog, what do you think of extending SessionCatalog and overriding isRegisteredFunction and isPersistentFunction? That way, after an invocation of LookupFunctions, we get a count of how many times isRegisteredFunction was called and how many times isPersistentFunction was called. We could create a single analyzer instance with an extended SessionCatalog and use it in more than one test. Would that be simpler?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Either is fine with me. The major goal of these test cases is to count the number of invocations of functionExists. That is why the current way is more straightforward for reviewers.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@gatorsmile Sure Sean.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks


// Running total of functionExists invocations on this registry.
private var isRegisteredFunctionCalledTimes: Int = 0

/**
 * Records the invocation and reports the function as existing.
 *
 * @param funcN identifier being looked up; not inspected by this override
 * @return always `true`
 */
override def functionExists(funcN: FunctionIdentifier): Boolean = synchronized {
  isRegisteredFunctionCalledTimes += 1
  true
}

/** Returns how many times `functionExists` has been invoked so far. */
def getIsRegisteredFunctionCalledTimes: Int = isRegisteredFunctionCalledTimes
}

class CustomInMemoryCatalog extends InMemoryCatalog {
Expand Down