Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@

package org.apache.spark.sql.catalyst.analysis

import java.util.Locale

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import scala.util.Random

Expand Down Expand Up @@ -1204,16 +1207,46 @@ class Analyzer(
* only performs simple existence check according to the function identifier to quickly identify
* undefined functions without triggering relation resolution, which may incur potentially
* expensive partition/schema discovery process in some cases.
*
* In order to avoid duplicate external functions lookup, the external function identifier will
* store in the local hash set externalFunctionNameSet.
* @see [[ResolveFunctions]]
* @see https://issues.apache.org/jira/browse/SPARK-19737
*/
object LookupFunctions extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = plan.transformAllExpressions {
case f: UnresolvedFunction if !catalog.functionExists(f.name) =>
withPosition(f) {
throw new NoSuchFunctionException(f.name.database.getOrElse("default"), f.name.funcName)
}
override def apply(plan: LogicalPlan): LogicalPlan = {
val externalFunctionNameSet = new mutable.HashSet[FunctionIdentifier]()
plan.transformAllExpressions {
case f: UnresolvedFunction
if externalFunctionNameSet.contains(normalizeFuncName(f.name)) => f
case f: UnresolvedFunction if catalog.isRegisteredFunction(f.name) => f
case f: UnresolvedFunction if catalog.isPersistentFunction(f.name) =>
externalFunctionNameSet.add(normalizeFuncName(f.name))
f
case f: UnresolvedFunction =>
withPosition(f) {
throw new NoSuchFunctionException(f.name.database.getOrElse(catalog.getCurrentDatabase),
f.name.funcName)
}
}
}

def normalizeFuncName(name: FunctionIdentifier): FunctionIdentifier = {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a common utility function. We can refactor the code later.

val funcName = if (conf.caseSensitiveAnalysis) {
name.funcName
} else {
name.funcName.toLowerCase(Locale.ROOT)
}

val databaseName = name.database match {
case Some(a) => formatDatabaseName(a)
case None => catalog.getCurrentDatabase
}

FunctionIdentifier(funcName, Some(databaseName))
}

protected def formatDatabaseName(name: String): String = {
if (conf.caseSensitiveAnalysis) name else name.toLowerCase(Locale.ROOT)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1193,6 +1193,22 @@ class SessionCatalog(
!hiveFunctions.contains(name.funcName.toLowerCase(Locale.ROOT))
}

/**
* Return whether this function has been registered in the function registry of the current
* session. If not existed, return false.
*/
def isRegisteredFunction(name: FunctionIdentifier): Boolean = {
functionRegistry.functionExists(name)
}

/**
* Returns whether it is a persistent function. If not existed, returns false.
*/
def isPersistentFunction(name: FunctionIdentifier): Boolean = {
val db = formatDatabaseName(name.database.getOrElse(getCurrentDatabase))
databaseExists(db) && externalCatalog.functionExists(db, name.funcName)
}

protected def failFunctionLookup(name: FunctionIdentifier): Nothing = {
throw new NoSuchFunctionException(
db = name.database.getOrElse(getCurrentDatabase), func = name.funcName)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.analysis

import java.net.URI

import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.expressions.Alias
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.internal.SQLConf

class LookupFunctionsSuite extends PlanTest {

test("SPARK-23486: the functionExists for the Persistent function check") {
val externalCatalog = new CustomInMemoryCatalog
val conf = new SQLConf()
val catalog = new SessionCatalog(externalCatalog, FunctionRegistry.builtin, conf)
val analyzer = {
catalog.createDatabase(
CatalogDatabase("default", "", new URI("loc"), Map.empty),
ignoreIfExists = false)
new Analyzer(catalog, conf)
}

def table(ref: String): LogicalPlan = UnresolvedRelation(TableIdentifier(ref))
val unresolvedPersistentFunc = UnresolvedFunction("func", Seq.empty, false)
val unresolvedRegisteredFunc = UnresolvedFunction("max", Seq.empty, false)
val plan = Project(
Seq(Alias(unresolvedPersistentFunc, "call1")(), Alias(unresolvedPersistentFunc, "call2")(),
Alias(unresolvedPersistentFunc, "call3")(), Alias(unresolvedRegisteredFunc, "call4")(),
Alias(unresolvedRegisteredFunc, "call5")()),
table("TaBlE"))
analyzer.LookupFunctions.apply(plan)

assert(externalCatalog.getFunctionExistsCalledTimes == 1)
assert(analyzer.LookupFunctions.normalizeFuncName
(unresolvedPersistentFunc.name).database == Some("default"))
}

test("SPARK-23486: the functionExists for the Registered function check") {
val externalCatalog = new InMemoryCatalog
val conf = new SQLConf()
val customerFunctionReg = new CustomerFunctionRegistry
val catalog = new SessionCatalog(externalCatalog, customerFunctionReg, conf)
val analyzer = {
catalog.createDatabase(
CatalogDatabase("default", "", new URI("loc"), Map.empty),
ignoreIfExists = false)
new Analyzer(catalog, conf)
}

def table(ref: String): LogicalPlan = UnresolvedRelation(TableIdentifier(ref))
val unresolvedRegisteredFunc = UnresolvedFunction("max", Seq.empty, false)
val plan = Project(
Seq(Alias(unresolvedRegisteredFunc, "call1")(), Alias(unresolvedRegisteredFunc, "call2")()),
table("TaBlE"))
analyzer.LookupFunctions.apply(plan)

assert(customerFunctionReg.getIsRegisteredFunctionCalledTimes == 2)
assert(analyzer.LookupFunctions.normalizeFuncName
(unresolvedRegisteredFunc.name).database == Some("default"))
}
}

class CustomerFunctionRegistry extends SimpleFunctionRegistry {

private var isRegisteredFunctionCalledTimes: Int = 0;

override def functionExists(funcN: FunctionIdentifier): Boolean = synchronized {
isRegisteredFunctionCalledTimes = isRegisteredFunctionCalledTimes + 1
true
}

def getIsRegisteredFunctionCalledTimes: Int = isRegisteredFunctionCalledTimes
}

class CustomInMemoryCatalog extends InMemoryCatalog {

private var functionExistsCalledTimes: Int = 0

override def functionExists(db: String, funcName: String): Boolean = synchronized {
functionExistsCalledTimes = functionExistsCalledTimes + 1
true
}

def getFunctionExistsCalledTimes: Int = functionExistsCalledTimes
}
Original file line number Diff line number Diff line change
Expand Up @@ -1217,6 +1217,42 @@ abstract class SessionCatalogSuite extends AnalysisTest {
}
}

test("isRegisteredFunction") {
withBasicCatalog { catalog =>
// Returns false when the function does not register
assert(!catalog.isRegisteredFunction(FunctionIdentifier("temp1")))

// Returns true when the function does register
val tempFunc1 = (e: Seq[Expression]) => e.head
catalog.registerFunction(newFunc("iff", None), overrideIfExists = false,
functionBuilder = Some(tempFunc1) )
assert(catalog.isRegisteredFunction(FunctionIdentifier("iff")))

// Returns false when using the createFunction
catalog.createFunction(newFunc("sum", Some("db2")), ignoreIfExists = false)
assert(!catalog.isRegisteredFunction(FunctionIdentifier("sum")))
assert(!catalog.isRegisteredFunction(FunctionIdentifier("sum", Some("db2"))))
}
}

test("isPersistentFunction") {
withBasicCatalog { catalog =>
// Returns false when the function does not register
assert(!catalog.isPersistentFunction(FunctionIdentifier("temp2")))

// Returns false when the function does register
val tempFunc2 = (e: Seq[Expression]) => e.head
catalog.registerFunction(newFunc("iff", None), overrideIfExists = false,
functionBuilder = Some(tempFunc2))
assert(!catalog.isPersistentFunction(FunctionIdentifier("iff")))

// Return true when using the createFunction
catalog.createFunction(newFunc("sum", Some("db2")), ignoreIfExists = false)
assert(catalog.isPersistentFunction(FunctionIdentifier("sum", Some("db2"))))
assert(!catalog.isPersistentFunction(FunctionIdentifier("db2.sum")))
}
}

test("drop function") {
withBasicCatalog { catalog =>
assert(catalog.externalCatalog.listFunctions("db2", "*").toSet == Set("func1"))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,10 @@ private[sql] class HiveSessionCatalog(
super.functionExists(name) || hiveFunctions.contains(name.funcName)
}

override def isPersistentFunction(name: FunctionIdentifier): Boolean = {
super.isPersistentFunction(name) || hiveFunctions.contains(name.funcName)
}

/** List of functions we pass over to Hive. Note that over time this list should go to 0. */
// We have a list of Hive built-in functions that we do not support. So, we will check
// Hive's function registry and lazily load needed functions into our own function registry.
Expand Down