Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.analysis

import java.util
import java.util.Locale
import java.util.concurrent.atomic.AtomicBoolean

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
Expand Down Expand Up @@ -48,7 +47,7 @@ import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns._
import org.apache.spark.sql.connector.catalog.{View => _, _}
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
import org.apache.spark.sql.connector.catalog.TableChange.{After, ColumnPosition}
import org.apache.spark.sql.connector.catalog.functions.{AggregateFunction => V2AggregateFunction, ScalarFunction, UnboundFunction}
import org.apache.spark.sql.connector.catalog.functions.UnboundFunction
import org.apache.spark.sql.connector.catalog.procedures.{BoundProcedure, ProcedureParameter, UnboundProcedure}
import org.apache.spark.sql.connector.expressions.{FieldReference, IdentityTransform}
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
Expand Down Expand Up @@ -202,6 +201,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor

private val v1SessionCatalog: SessionCatalog = catalogManager.v1SessionCatalog
private val relationResolution = new RelationResolution(catalogManager)
private val functionResolution = new FunctionResolution(catalogManager, relationResolution)

override protected def validatePlanChanges(
previousPlan: LogicalPlan,
Expand Down Expand Up @@ -1915,7 +1915,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor

plan.resolveExpressionsWithPruning(_.containsAnyPattern(UNRESOLVED_FUNCTION)) {
case f @ UnresolvedFunction(nameParts, _, _, _, _, _, _) =>
if (ResolveFunctions.lookupBuiltinOrTempFunction(nameParts, Some(f)).isDefined) {
if (functionResolution.lookupBuiltinOrTempFunction(nameParts, Some(f)).isDefined) {
f
} else {
val CatalogAndIdentifier(catalog, ident) =
Expand Down Expand Up @@ -1954,15 +1954,13 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
* Replaces [[UnresolvedTableValuedFunction]]s with concrete [[LogicalPlan]]s.
*/
object ResolveFunctions extends Rule[LogicalPlan] {
val trimWarningEnabled = new AtomicBoolean(true)

def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUpWithPruning(
_.containsAnyPattern(UNRESOLVED_FUNC, UNRESOLVED_FUNCTION, GENERATOR,
UNRESOLVED_TABLE_VALUED_FUNCTION, UNRESOLVED_TVF_ALIASES), ruleId) {
// Resolve functions with concrete relations from v2 catalog.
case u @ UnresolvedFunctionName(nameParts, cmd, requirePersistentFunc, mismatchHint, _) =>
lookupBuiltinOrTempFunction(nameParts, None)
.orElse(lookupBuiltinOrTempTableFunction(nameParts)).map { info =>
functionResolution.lookupBuiltinOrTempFunction(nameParts, None)
.orElse(functionResolution.lookupBuiltinOrTempTableFunction(nameParts)).map { info =>
if (requirePersistentFunc) {
throw QueryCompilationErrors.expectPersistentFuncError(
nameParts.head, cmd, mismatchHint, u)
Expand All @@ -1982,7 +1980,8 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
case u: UnresolvedTableValuedFunction if u.functionArgs.forall(_.resolved) =>
withPosition(u) {
try {
val resolvedFunc = resolveBuiltinOrTempTableFunction(u.name, u.functionArgs).getOrElse {
val resolvedFunc = functionResolution.resolveBuiltinOrTempTableFunction(
u.name, u.functionArgs).getOrElse {
val CatalogAndIdentifier(catalog, ident) =
relationResolution.expandIdentifier(u.name)
if (CatalogV2Util.isSessionCatalog(catalog)) {
Expand Down Expand Up @@ -2086,8 +2085,8 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
_.containsAnyPattern(UNRESOLVED_FUNCTION, GENERATOR),
ruleId) {
case u @ UnresolvedFunction(nameParts, arguments, _, _, _, _, _)
if hasLambdaAndResolvedArguments(arguments) => withPosition(u) {
resolveBuiltinOrTempFunction(nameParts, arguments, u).map {
if functionResolution.hasLambdaAndResolvedArguments(arguments) => withPosition(u) {
functionResolution.resolveBuiltinOrTempFunction(nameParts, arguments, u).map {
case func: HigherOrderFunction => func
case other => other.failAnalysis(
errorClass = "INVALID_LAMBDA_FUNCTION_CALL.NON_HIGHER_ORDER_FUNCTION",
Expand All @@ -2114,7 +2113,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
}
}

case u: UnresolvedFunction => resolveFunction(u)
case u: UnresolvedFunction => functionResolution.resolveFunction(u)

case u: UnresolvedPolymorphicPythonUDTF => withPosition(u) {
// Check if this is a call to a Python user-defined table function whose polymorphic
Expand All @@ -2137,277 +2136,6 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
}
}
}

private[analysis] def resolveFunction(u: UnresolvedFunction): Expression = {
withPosition(u) {
resolveBuiltinOrTempFunction(u.nameParts, u.arguments, u).getOrElse {
val CatalogAndIdentifier(catalog, ident) =
relationResolution.expandIdentifier(u.nameParts)
if (CatalogV2Util.isSessionCatalog(catalog)) {
resolveV1Function(ident.asFunctionIdentifier, u.arguments, u)
} else {
resolveV2Function(catalog.asFunctionCatalog, ident, u.arguments, u)
}
}
}
}

/**
* Check if the arguments of a function are either resolved or a lambda function.
*/
private def hasLambdaAndResolvedArguments(expressions: Seq[Expression]): Boolean = {
val (lambdas, others) = expressions.partition(_.isInstanceOf[LambdaFunction])
lambdas.nonEmpty && others.forall(_.resolved)
}

def lookupBuiltinOrTempFunction(
name: Seq[String],
u: Option[UnresolvedFunction]): Option[ExpressionInfo] = {
if (name.size == 1 && u.exists(_.isInternal)) {
FunctionRegistry.internal.lookupFunction(FunctionIdentifier(name.head))
} else if (name.size == 1) {
v1SessionCatalog.lookupBuiltinOrTempFunction(name.head)
} else {
None
}
}

def lookupBuiltinOrTempTableFunction(name: Seq[String]): Option[ExpressionInfo] = {
if (name.length == 1) {
v1SessionCatalog.lookupBuiltinOrTempTableFunction(name.head)
} else {
None
}
}

private def resolveBuiltinOrTempFunction(
name: Seq[String],
arguments: Seq[Expression],
u: UnresolvedFunction): Option[Expression] = {
val expression = if (name.size == 1 && u.isInternal) {
Option(FunctionRegistry.internal.lookupFunction(FunctionIdentifier(name.head), arguments))
} else if (name.size == 1) {
v1SessionCatalog.resolveBuiltinOrTempFunction(name.head, arguments)
} else {
None
}
expression.map { func =>
validateFunction(func, arguments.length, u)
}
}

private def resolveBuiltinOrTempTableFunction(
name: Seq[String],
arguments: Seq[Expression]): Option[LogicalPlan] = {
if (name.length == 1) {
v1SessionCatalog.resolveBuiltinOrTempTableFunction(name.head, arguments)
} else {
None
}
}

private def resolveV1Function(
ident: FunctionIdentifier,
arguments: Seq[Expression],
u: UnresolvedFunction): Expression = {
val func = v1SessionCatalog.resolvePersistentFunction(ident, arguments)
validateFunction(func, arguments.length, u)
}

private def validateFunction(
func: Expression,
numArgs: Int,
u: UnresolvedFunction): Expression = {
func match {
case owg: SupportsOrderingWithinGroup if u.isDistinct =>
throw QueryCompilationErrors.distinctInverseDistributionFunctionUnsupportedError(
owg.prettyName)
case owg: SupportsOrderingWithinGroup
if !owg.orderingFilled && u.orderingWithinGroup.isEmpty =>
throw QueryCompilationErrors.inverseDistributionFunctionMissingWithinGroupError(
owg.prettyName)
case owg: SupportsOrderingWithinGroup
if owg.orderingFilled && u.orderingWithinGroup.nonEmpty =>
throw QueryCompilationErrors.wrongNumOrderingsForInverseDistributionFunctionError(
owg.prettyName, 0, u.orderingWithinGroup.length)
case f
if !f.isInstanceOf[SupportsOrderingWithinGroup] && u.orderingWithinGroup.nonEmpty =>
throw QueryCompilationErrors.functionWithUnsupportedSyntaxError(
func.prettyName, "WITHIN GROUP (ORDER BY ...)")
// AggregateWindowFunctions are AggregateFunctions that can only be evaluated within
// the context of a Window clause. They do not need to be wrapped in an
// AggregateExpression.
case wf: AggregateWindowFunction =>
if (u.isDistinct) {
throw QueryCompilationErrors.functionWithUnsupportedSyntaxError(
wf.prettyName, "DISTINCT")
} else if (u.filter.isDefined) {
throw QueryCompilationErrors.functionWithUnsupportedSyntaxError(
wf.prettyName, "FILTER clause")
} else if (u.ignoreNulls) {
wf match {
case nthValue: NthValue =>
nthValue.copy(ignoreNulls = u.ignoreNulls)
case _ =>
throw QueryCompilationErrors.functionWithUnsupportedSyntaxError(
wf.prettyName, "IGNORE NULLS")
}
} else {
wf
}
case owf: FrameLessOffsetWindowFunction =>
if (u.isDistinct) {
throw QueryCompilationErrors.functionWithUnsupportedSyntaxError(
owf.prettyName, "DISTINCT")
} else if (u.filter.isDefined) {
throw QueryCompilationErrors.functionWithUnsupportedSyntaxError(
owf.prettyName, "FILTER clause")
} else if (u.ignoreNulls) {
owf match {
case lead: Lead =>
lead.copy(ignoreNulls = u.ignoreNulls)
case lag: Lag =>
lag.copy(ignoreNulls = u.ignoreNulls)
}
} else {
owf
}
// We get an aggregate function, we need to wrap it in an AggregateExpression.
case agg: AggregateFunction =>
// Note: PythonUDAF does not support these advanced clauses.
if (agg.isInstanceOf[PythonUDAF]) checkUnsupportedAggregateClause(agg, u)
// After parse, the inverse distribution functions not set the ordering within group yet.
val newAgg = agg match {
case owg: SupportsOrderingWithinGroup
if !owg.orderingFilled && u.orderingWithinGroup.nonEmpty =>
owg.withOrderingWithinGroup(u.orderingWithinGroup)
case _ =>
agg
}

u.filter match {
case Some(filter) if !filter.deterministic =>
throw QueryCompilationErrors.nonDeterministicFilterInAggregateError(
filterExpr = filter)
case Some(filter) if filter.dataType != BooleanType =>
throw QueryCompilationErrors.nonBooleanFilterInAggregateError(
filterExpr = filter)
case Some(filter) if filter.exists(_.isInstanceOf[AggregateExpression]) =>
throw QueryCompilationErrors.aggregateInAggregateFilterError(
filterExpr = filter,
aggExpr = filter.find(_.isInstanceOf[AggregateExpression]).get)
case Some(filter) if filter.exists(_.isInstanceOf[WindowExpression]) =>
throw QueryCompilationErrors.windowFunctionInAggregateFilterError(
filterExpr = filter,
windowExpr = filter.find(_.isInstanceOf[WindowExpression]).get)
case _ =>
}
if (u.ignoreNulls) {
val aggFunc = newAgg match {
case first: First => first.copy(ignoreNulls = u.ignoreNulls)
case last: Last => last.copy(ignoreNulls = u.ignoreNulls)
case any_value: AnyValue => any_value.copy(ignoreNulls = u.ignoreNulls)
case _ =>
throw QueryCompilationErrors.functionWithUnsupportedSyntaxError(
newAgg.prettyName, "IGNORE NULLS")
}
aggFunc.toAggregateExpression(u.isDistinct, u.filter)
} else {
newAgg.toAggregateExpression(u.isDistinct, u.filter)
}
// This function is not an aggregate function, just return the resolved one.
case other =>
checkUnsupportedAggregateClause(other, u)
if (other.isInstanceOf[String2TrimExpression] && numArgs == 2) {
if (trimWarningEnabled.get) {
log.warn("Two-parameter TRIM/LTRIM/RTRIM function signatures are deprecated." +
" Use SQL syntax `TRIM((BOTH | LEADING | TRAILING)? trimStr FROM str)`" +
" instead.")
trimWarningEnabled.set(false)
}
}
other
}
}

private def checkUnsupportedAggregateClause(func: Expression, u: UnresolvedFunction): Unit = {
if (u.isDistinct) {
throw QueryCompilationErrors.functionWithUnsupportedSyntaxError(
func.prettyName, "DISTINCT")
}
if (u.filter.isDefined) {
throw QueryCompilationErrors.functionWithUnsupportedSyntaxError(
func.prettyName, "FILTER clause")
}
if (u.ignoreNulls) {
throw QueryCompilationErrors.functionWithUnsupportedSyntaxError(
func.prettyName, "IGNORE NULLS")
}
}

private def resolveV2Function(
catalog: FunctionCatalog,
ident: Identifier,
arguments: Seq[Expression],
u: UnresolvedFunction): Expression = {
val unbound = catalog.loadFunction(ident)
val inputType = StructType(arguments.zipWithIndex.map {
case (exp, pos) => StructField(s"_$pos", exp.dataType, exp.nullable)
})
val bound = try {
unbound.bind(inputType)
} catch {
case unsupported: UnsupportedOperationException =>
throw QueryCompilationErrors.functionCannotProcessInputError(
unbound, arguments, unsupported)
}

if (bound.inputTypes().length != arguments.length) {
throw QueryCompilationErrors.v2FunctionInvalidInputTypeLengthError(
bound, arguments)
}

bound match {
case scalarFunc: ScalarFunction[_] =>
processV2ScalarFunction(scalarFunc, arguments, u)
case aggFunc: V2AggregateFunction[_, _] =>
processV2AggregateFunction(aggFunc, arguments, u)
case _ =>
failAnalysis(
errorClass = "INVALID_UDF_IMPLEMENTATION",
messageParameters = Map("funcName" -> toSQLId(bound.name())))
}
}

private def processV2ScalarFunction(
scalarFunc: ScalarFunction[_],
arguments: Seq[Expression],
u: UnresolvedFunction): Expression = {
if (u.isDistinct) {
throw QueryCompilationErrors.functionWithUnsupportedSyntaxError(
scalarFunc.name(), "DISTINCT")
} else if (u.filter.isDefined) {
throw QueryCompilationErrors.functionWithUnsupportedSyntaxError(
scalarFunc.name(), "FILTER clause")
} else if (u.ignoreNulls) {
throw QueryCompilationErrors.functionWithUnsupportedSyntaxError(
scalarFunc.name(), "IGNORE NULLS")
} else {
V2ExpressionUtils.resolveScalarFunction(scalarFunc, arguments)
}
}

private def processV2AggregateFunction(
aggFunc: V2AggregateFunction[_, _],
arguments: Seq[Expression],
u: UnresolvedFunction): Expression = {
if (u.ignoreNulls) {
throw QueryCompilationErrors.functionWithUnsupportedSyntaxError(
aggFunc.name(), "IGNORE NULLS")
}
val aggregator = V2Aggregator(aggFunc, arguments)
aggregator.toAggregateExpression(u.isDistinct, u.filter)
}
}

/**
Expand Down
Loading