Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@

package org.apache.spark.sql.catalyst.analysis

import java.util.Locale

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import scala.util.Random

Expand Down Expand Up @@ -1204,16 +1207,32 @@ class Analyzer(
* only performs simple existence check according to the function identifier to quickly identify
* undefined functions without triggering relation resolution, which may incur potentially
* expensive partition/schema discovery process in some cases.
*
* In order to avoid duplicate external functions lookup, the external function identifier will
* store in the local hash set externalFunctionNameSet.
* @see [[ResolveFunctions]]
* @see https://issues.apache.org/jira/browse/SPARK-19737
*/
object LookupFunctions extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = plan.transformAllExpressions {
case f: UnresolvedFunction if !catalog.functionExists(f.name) =>
withPosition(f) {
throw new NoSuchFunctionException(f.name.database.getOrElse("default"), f.name.funcName)
}
override def apply(plan: LogicalPlan): LogicalPlan = {
val externalFunctionNameSet = new mutable.HashSet[FunctionIdentifier]()
plan.transformAllExpressions {
case f: UnresolvedFunction
if externalFunctionNameSet.contains(normalizeFuncName(f.name)) => f
case f: UnresolvedFunction if catalog.isRegisteredFunction(f.name) => f
case f: UnresolvedFunction if catalog.isPersistentFunction(f.name) =>
externalFunctionNameSet.add(normalizeFuncName(f.name))
f
case f: UnresolvedFunction =>
withPosition(f) {
throw new NoSuchFunctionException(f.name.database.getOrElse(catalog.getCurrentDatabase),
f.name.funcName)
}
}
}

def normalizeFuncName(name: FunctionIdentifier): FunctionIdentifier = {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a common utility function. We can refactor the code later.

FunctionIdentifier(name.funcName.toLowerCase(Locale.ROOT),
name.database.orElse(Some(catalog.getCurrentDatabase)))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@kevinyu98 I have a question. So we normalize the funcName here. How about name.database ? Is that normalized already by the time we are here ?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@kevinyu98 how about consideration of conf.caseSensitiveAnalysis ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, I will change the code for the name.database. Thanks.

}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1193,6 +1193,22 @@ class SessionCatalog(
!hiveFunctions.contains(name.funcName.toLowerCase(Locale.ROOT))
}

/**
* Return whether this function has been registered in the function registry of the current
* session. If not existed, return false.
*/
def isRegisteredFunction(name: FunctionIdentifier): Boolean = {
functionRegistry.functionExists(name)
}

/**
* Returns whether it is a persistent function. If not existed, returns false.
*/
def isPersistentFunction(name: FunctionIdentifier): Boolean = {
val db = formatDatabaseName(name.database.getOrElse(getCurrentDatabase))
databaseExists(db) && externalCatalog.functionExists(db, name.funcName)
}

protected def failFunctionLookup(name: FunctionIdentifier): Nothing = {
throw new NoSuchFunctionException(
db = name.database.getOrElse(getCurrentDatabase), func = name.funcName)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.analysis

import java.net.URI

import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.expressions.Alias
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.internal.SQLConf

class LookupFunctionsSuite extends PlanTest {

test("SPARK-23486: LookupFunctions should not check the same function name more than once") {
val externalCatalog = new CustomInMemoryCatalog
val conf = new SQLConf()
val catalog = new SessionCatalog(externalCatalog, FunctionRegistry.builtin, conf)
val analyzer = {
catalog.createDatabase(
CatalogDatabase("default", "", new URI("loc"), Map.empty),
ignoreIfExists = false)
new Analyzer(catalog, conf)
}

def table(ref: String): LogicalPlan = UnresolvedRelation(TableIdentifier(ref))
val unresolvedFunc = UnresolvedFunction("func", Seq.empty, false)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you also verify the logic of normalizeFuncName in this test case?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also add another one for the function that triggers isRegisteredFunction?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sorry for the delay. I will do that.

val plan = Project(
Seq(Alias(unresolvedFunc, "call1")(), Alias(unresolvedFunc, "call2")(),
Alias(unresolvedFunc, "call1")()),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: call3?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

changed, thanks.

table("TaBlE"))
analyzer.LookupFunctions.apply(plan)
assert(externalCatalog.getFunctionExistsCalledTimes == 1)

assert(analyzer.LookupFunctions.normalizeFuncName
(unresolvedFunc.name).database == Some("default"))
assert(catalog.isRegisteredFunction(unresolvedFunc.name) == false)
assert(catalog.isRegisteredFunction(FunctionIdentifier("max")) == true)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I mean adding another test case to check whether LookupFunctions does not resolve the registeredFunction more than once.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We do not need to add assert.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see, I add the test case, can you verify ? thanks a lot.


}
}

class CustomInMemoryCatalog extends InMemoryCatalog {

private var functionExistsCalledTimes: Int = 0

override def functionExists(db: String, funcName: String): Boolean = synchronized {
functionExistsCalledTimes = functionExistsCalledTimes + 1
true
}

def getFunctionExistsCalledTimes: Int = functionExistsCalledTimes

}
Original file line number Diff line number Diff line change
Expand Up @@ -1217,6 +1217,42 @@ abstract class SessionCatalogSuite extends AnalysisTest {
}
}

test("isRegisteredFunction") {
withBasicCatalog { catalog =>
// Returns false when the function does not register
assert(!catalog.isRegisteredFunction(FunctionIdentifier("temp1")))

// Returns true when the function does register
val tempFunc1 = (e: Seq[Expression]) => e.head
catalog.registerFunction(newFunc("iff", None), overrideIfExists = false,
functionBuilder = Some(tempFunc1) )
assert(catalog.isRegisteredFunction(FunctionIdentifier("iff")))

// Returns false when using the createFunction
catalog.createFunction(newFunc("sum", Some("db2")), ignoreIfExists = false)
assert(!catalog.isRegisteredFunction(FunctionIdentifier("sum")))
assert(!catalog.isRegisteredFunction(FunctionIdentifier("sum", Some("db2"))))
}
}

test("isPersistentFunction") {
withBasicCatalog { catalog =>
// Returns false when the function does not register
assert(!catalog.isPersistentFunction(FunctionIdentifier("temp2")))

// Returns false when the function does register
val tempFunc2 = (e: Seq[Expression]) => e.head
catalog.registerFunction(newFunc("iff", None), overrideIfExists = false,
functionBuilder = Some(tempFunc2))
assert(!catalog.isPersistentFunction(FunctionIdentifier("iff")))

// Return true when using the createFunction
catalog.createFunction(newFunc("sum", Some("db2")), ignoreIfExists = false)
assert(catalog.isPersistentFunction(FunctionIdentifier("sum", Some("db2"))))
assert(!catalog.isPersistentFunction(FunctionIdentifier("db2.sum")))
}
}

test("drop function") {
withBasicCatalog { catalog =>
assert(catalog.externalCatalog.listFunctions("db2", "*").toSet == Set("func1"))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,10 @@ private[sql] class HiveSessionCatalog(
super.functionExists(name) || hiveFunctions.contains(name.funcName)
}

override def isPersistentFunction(name: FunctionIdentifier): Boolean = {
super.isPersistentFunction(name) || hiveFunctions.contains(name.funcName)
}

/** List of functions we pass over to Hive. Note that over time this list should go to 0. */
// We have a list of Hive built-in functions that we do not support. So, we will check
// Hive's function registry and lazily load needed functions into our own function registry.
Expand Down