From ae4bdd569b1bed7ee1df9dfb4148fb61895de0a8 Mon Sep 17 00:00:00 2001 From: Vladimir Golubev Date: Fri, 15 Nov 2024 09:31:03 +0000 Subject: [PATCH] Factor out function resolution code to a separate FunctionResolution class --- .../sql/catalyst/analysis/Analyzer.scala | 292 +-------------- .../analysis/FunctionResolution.scala | 354 ++++++++++++++++++ 2 files changed, 364 insertions(+), 282 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionResolution.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index d1d04d411726..d981b5b8ea0c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.analysis import java.util import java.util.Locale -import java.util.concurrent.atomic.AtomicBoolean import scala.collection.mutable import scala.collection.mutable.ArrayBuffer @@ -48,7 +47,7 @@ import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns._ import org.apache.spark.sql.connector.catalog.{View => _, _} import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ import org.apache.spark.sql.connector.catalog.TableChange.{After, ColumnPosition} -import org.apache.spark.sql.connector.catalog.functions.{AggregateFunction => V2AggregateFunction, ScalarFunction, UnboundFunction} +import org.apache.spark.sql.connector.catalog.functions.UnboundFunction import org.apache.spark.sql.connector.catalog.procedures.{BoundProcedure, ProcedureParameter, UnboundProcedure} import org.apache.spark.sql.connector.expressions.{FieldReference, IdentityTransform} import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} @@ -202,6 +201,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor private val v1SessionCatalog: SessionCatalog = catalogManager.v1SessionCatalog private val relationResolution = new RelationResolution(catalogManager) + private val functionResolution = new FunctionResolution(catalogManager, relationResolution) override protected def validatePlanChanges( previousPlan: LogicalPlan, @@ -1915,7 +1915,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor plan.resolveExpressionsWithPruning(_.containsAnyPattern(UNRESOLVED_FUNCTION)) { case f @ UnresolvedFunction(nameParts, _, _, _, _, _, _) => - if (ResolveFunctions.lookupBuiltinOrTempFunction(nameParts, Some(f)).isDefined) { + if (functionResolution.lookupBuiltinOrTempFunction(nameParts, Some(f)).isDefined) { f } else { val CatalogAndIdentifier(catalog, ident) = @@ -1954,15 +1954,13 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor * Replaces [[UnresolvedTableValuedFunction]]s with concrete [[LogicalPlan]]s. */ object ResolveFunctions extends Rule[LogicalPlan] { - val trimWarningEnabled = new AtomicBoolean(true) - def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUpWithPruning( _.containsAnyPattern(UNRESOLVED_FUNC, UNRESOLVED_FUNCTION, GENERATOR, UNRESOLVED_TABLE_VALUED_FUNCTION, UNRESOLVED_TVF_ALIASES), ruleId) { // Resolve functions with concrete relations from v2 catalog. case u @ UnresolvedFunctionName(nameParts, cmd, requirePersistentFunc, mismatchHint, _) => - lookupBuiltinOrTempFunction(nameParts, None) - .orElse(lookupBuiltinOrTempTableFunction(nameParts)).map { info => + functionResolution.lookupBuiltinOrTempFunction(nameParts, None) + .orElse(functionResolution.lookupBuiltinOrTempTableFunction(nameParts)).map { info => if (requirePersistentFunc) { throw QueryCompilationErrors.expectPersistentFuncError( nameParts.head, cmd, mismatchHint, u) @@ -1982,7 +1980,8 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor case u: UnresolvedTableValuedFunction if u.functionArgs.forall(_.resolved) => withPosition(u) { try { - val resolvedFunc = resolveBuiltinOrTempTableFunction(u.name, u.functionArgs).getOrElse { + val resolvedFunc = functionResolution.resolveBuiltinOrTempTableFunction( + u.name, u.functionArgs).getOrElse { val CatalogAndIdentifier(catalog, ident) = relationResolution.expandIdentifier(u.name) if (CatalogV2Util.isSessionCatalog(catalog)) { @@ -2086,8 +2085,8 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor _.containsAnyPattern(UNRESOLVED_FUNCTION, GENERATOR), ruleId) { case u @ UnresolvedFunction(nameParts, arguments, _, _, _, _, _) - if hasLambdaAndResolvedArguments(arguments) => withPosition(u) { - resolveBuiltinOrTempFunction(nameParts, arguments, u).map { + if functionResolution.hasLambdaAndResolvedArguments(arguments) => withPosition(u) { + functionResolution.resolveBuiltinOrTempFunction(nameParts, arguments, u).map { case func: HigherOrderFunction => func case other => other.failAnalysis( errorClass = "INVALID_LAMBDA_FUNCTION_CALL.NON_HIGHER_ORDER_FUNCTION", @@ -2114,7 +2113,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor } } - case u: UnresolvedFunction => resolveFunction(u) + case u: UnresolvedFunction => functionResolution.resolveFunction(u) case u: UnresolvedPolymorphicPythonUDTF => withPosition(u) { // Check if this is a call to a Python user-defined table function whose polymorphic @@ -2137,277 +2136,6 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor } } } - - private[analysis] def resolveFunction(u: UnresolvedFunction): Expression = { - withPosition(u) { - resolveBuiltinOrTempFunction(u.nameParts, u.arguments, u).getOrElse { - val CatalogAndIdentifier(catalog, ident) = - relationResolution.expandIdentifier(u.nameParts) - if (CatalogV2Util.isSessionCatalog(catalog)) { - resolveV1Function(ident.asFunctionIdentifier, u.arguments, u) - } else { - resolveV2Function(catalog.asFunctionCatalog, ident, u.arguments, u) - } - } - } - } - - /** - * Check if the arguments of a function are either resolved or a lambda function. - */ - private def hasLambdaAndResolvedArguments(expressions: Seq[Expression]): Boolean = { - val (lambdas, others) = expressions.partition(_.isInstanceOf[LambdaFunction]) - lambdas.nonEmpty && others.forall(_.resolved) - } - - def lookupBuiltinOrTempFunction( - name: Seq[String], - u: Option[UnresolvedFunction]): Option[ExpressionInfo] = { - if (name.size == 1 && u.exists(_.isInternal)) { - FunctionRegistry.internal.lookupFunction(FunctionIdentifier(name.head)) - } else if (name.size == 1) { - v1SessionCatalog.lookupBuiltinOrTempFunction(name.head) - } else { - None - } - } - - def lookupBuiltinOrTempTableFunction(name: Seq[String]): Option[ExpressionInfo] = { - if (name.length == 1) { - v1SessionCatalog.lookupBuiltinOrTempTableFunction(name.head) - } else { - None - } - } - - private def resolveBuiltinOrTempFunction( - name: Seq[String], - arguments: Seq[Expression], - u: UnresolvedFunction): Option[Expression] = { - val expression = if (name.size == 1 && u.isInternal) { - Option(FunctionRegistry.internal.lookupFunction(FunctionIdentifier(name.head), arguments)) - } else if (name.size == 1) { - v1SessionCatalog.resolveBuiltinOrTempFunction(name.head, arguments) - } else { - None - } - expression.map { func => - validateFunction(func, arguments.length, u) - } - } - - private def resolveBuiltinOrTempTableFunction( - name: Seq[String], - arguments: Seq[Expression]): Option[LogicalPlan] = { - if (name.length == 1) { - v1SessionCatalog.resolveBuiltinOrTempTableFunction(name.head, arguments) - } else { - None - } - } - - private def resolveV1Function( - ident: FunctionIdentifier, - arguments: Seq[Expression], - u: UnresolvedFunction): Expression = { - val func = v1SessionCatalog.resolvePersistentFunction(ident, arguments) - validateFunction(func, arguments.length, u) - } - - private def validateFunction( - func: Expression, - numArgs: Int, - u: UnresolvedFunction): Expression = { - func match { - case owg: SupportsOrderingWithinGroup if u.isDistinct => - throw QueryCompilationErrors.distinctInverseDistributionFunctionUnsupportedError( - owg.prettyName) - case owg: SupportsOrderingWithinGroup - if !owg.orderingFilled && u.orderingWithinGroup.isEmpty => - throw QueryCompilationErrors.inverseDistributionFunctionMissingWithinGroupError( - owg.prettyName) - case owg: SupportsOrderingWithinGroup - if owg.orderingFilled && u.orderingWithinGroup.nonEmpty => - throw QueryCompilationErrors.wrongNumOrderingsForInverseDistributionFunctionError( - owg.prettyName, 0, u.orderingWithinGroup.length) - case f - if !f.isInstanceOf[SupportsOrderingWithinGroup] && u.orderingWithinGroup.nonEmpty => - throw QueryCompilationErrors.functionWithUnsupportedSyntaxError( - func.prettyName, "WITHIN GROUP (ORDER BY ...)") - // AggregateWindowFunctions are AggregateFunctions that can only be evaluated within - // the context of a Window clause. They do not need to be wrapped in an - // AggregateExpression. - case wf: AggregateWindowFunction => - if (u.isDistinct) { - throw QueryCompilationErrors.functionWithUnsupportedSyntaxError( - wf.prettyName, "DISTINCT") - } else if (u.filter.isDefined) { - throw QueryCompilationErrors.functionWithUnsupportedSyntaxError( - wf.prettyName, "FILTER clause") - } else if (u.ignoreNulls) { - wf match { - case nthValue: NthValue => - nthValue.copy(ignoreNulls = u.ignoreNulls) - case _ => - throw QueryCompilationErrors.functionWithUnsupportedSyntaxError( - wf.prettyName, "IGNORE NULLS") - } - } else { - wf - } - case owf: FrameLessOffsetWindowFunction => - if (u.isDistinct) { - throw QueryCompilationErrors.functionWithUnsupportedSyntaxError( - owf.prettyName, "DISTINCT") - } else if (u.filter.isDefined) { - throw QueryCompilationErrors.functionWithUnsupportedSyntaxError( - owf.prettyName, "FILTER clause") - } else if (u.ignoreNulls) { - owf match { - case lead: Lead => - lead.copy(ignoreNulls = u.ignoreNulls) - case lag: Lag => - lag.copy(ignoreNulls = u.ignoreNulls) - } - } else { - owf - } - // We get an aggregate function, we need to wrap it in an AggregateExpression. - case agg: AggregateFunction => - // Note: PythonUDAF does not support these advanced clauses. - if (agg.isInstanceOf[PythonUDAF]) checkUnsupportedAggregateClause(agg, u) - // After parse, the inverse distribution functions not set the ordering within group yet. - val newAgg = agg match { - case owg: SupportsOrderingWithinGroup - if !owg.orderingFilled && u.orderingWithinGroup.nonEmpty => - owg.withOrderingWithinGroup(u.orderingWithinGroup) - case _ => - agg - } - - u.filter match { - case Some(filter) if !filter.deterministic => - throw QueryCompilationErrors.nonDeterministicFilterInAggregateError( - filterExpr = filter) - case Some(filter) if filter.dataType != BooleanType => - throw QueryCompilationErrors.nonBooleanFilterInAggregateError( - filterExpr = filter) - case Some(filter) if filter.exists(_.isInstanceOf[AggregateExpression]) => - throw QueryCompilationErrors.aggregateInAggregateFilterError( - filterExpr = filter, - aggExpr = filter.find(_.isInstanceOf[AggregateExpression]).get) - case Some(filter) if filter.exists(_.isInstanceOf[WindowExpression]) => - throw QueryCompilationErrors.windowFunctionInAggregateFilterError( - filterExpr = filter, - windowExpr = filter.find(_.isInstanceOf[WindowExpression]).get) - case _ => - } - if (u.ignoreNulls) { - val aggFunc = newAgg match { - case first: First => first.copy(ignoreNulls = u.ignoreNulls) - case last: Last => last.copy(ignoreNulls = u.ignoreNulls) - case any_value: AnyValue => any_value.copy(ignoreNulls = u.ignoreNulls) - case _ => - throw QueryCompilationErrors.functionWithUnsupportedSyntaxError( - newAgg.prettyName, "IGNORE NULLS") - } - aggFunc.toAggregateExpression(u.isDistinct, u.filter) - } else { - newAgg.toAggregateExpression(u.isDistinct, u.filter) - } - // This function is not an aggregate function, just return the resolved one. - case other => - checkUnsupportedAggregateClause(other, u) - if (other.isInstanceOf[String2TrimExpression] && numArgs == 2) { - if (trimWarningEnabled.get) { - log.warn("Two-parameter TRIM/LTRIM/RTRIM function signatures are deprecated." + - " Use SQL syntax `TRIM((BOTH | LEADING | TRAILING)? trimStr FROM str)`" + - " instead.") - trimWarningEnabled.set(false) - } - } - other - } - } - - private def checkUnsupportedAggregateClause(func: Expression, u: UnresolvedFunction): Unit = { - if (u.isDistinct) { - throw QueryCompilationErrors.functionWithUnsupportedSyntaxError( - func.prettyName, "DISTINCT") - } - if (u.filter.isDefined) { - throw QueryCompilationErrors.functionWithUnsupportedSyntaxError( - func.prettyName, "FILTER clause") - } - if (u.ignoreNulls) { - throw QueryCompilationErrors.functionWithUnsupportedSyntaxError( - func.prettyName, "IGNORE NULLS") - } - } - - private def resolveV2Function( - catalog: FunctionCatalog, - ident: Identifier, - arguments: Seq[Expression], - u: UnresolvedFunction): Expression = { - val unbound = catalog.loadFunction(ident) - val inputType = StructType(arguments.zipWithIndex.map { - case (exp, pos) => StructField(s"_$pos", exp.dataType, exp.nullable) - }) - val bound = try { - unbound.bind(inputType) - } catch { - case unsupported: UnsupportedOperationException => - throw QueryCompilationErrors.functionCannotProcessInputError( - unbound, arguments, unsupported) - } - - if (bound.inputTypes().length != arguments.length) { - throw QueryCompilationErrors.v2FunctionInvalidInputTypeLengthError( - bound, arguments) - } - - bound match { - case scalarFunc: ScalarFunction[_] => - processV2ScalarFunction(scalarFunc, arguments, u) - case aggFunc: V2AggregateFunction[_, _] => - processV2AggregateFunction(aggFunc, arguments, u) - case _ => - failAnalysis( - errorClass = "INVALID_UDF_IMPLEMENTATION", - messageParameters = Map("funcName" -> toSQLId(bound.name()))) - } - } - - private def processV2ScalarFunction( - scalarFunc: ScalarFunction[_], - arguments: Seq[Expression], - u: UnresolvedFunction): Expression = { - if (u.isDistinct) { - throw QueryCompilationErrors.functionWithUnsupportedSyntaxError( - scalarFunc.name(), "DISTINCT") - } else if (u.filter.isDefined) { - throw QueryCompilationErrors.functionWithUnsupportedSyntaxError( - scalarFunc.name(), "FILTER clause") - } else if (u.ignoreNulls) { - throw QueryCompilationErrors.functionWithUnsupportedSyntaxError( - scalarFunc.name(), "IGNORE NULLS") - } else { - V2ExpressionUtils.resolveScalarFunction(scalarFunc, arguments) - } - } - - private def processV2AggregateFunction( - aggFunc: V2AggregateFunction[_, _], - arguments: Seq[Expression], - u: UnresolvedFunction): Expression = { - if (u.ignoreNulls) { - throw QueryCompilationErrors.functionWithUnsupportedSyntaxError( - aggFunc.name(), "IGNORE NULLS") - } - val aggregator = V2Aggregator(aggFunc, arguments) - aggregator.toAggregateExpression(u.isDistinct, u.filter) - } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionResolution.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionResolution.scala new file mode 100644 index 000000000000..5a27a7219032 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionResolution.scala @@ -0,0 +1,354 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import java.util.concurrent.atomic.AtomicBoolean + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.FunctionIdentifier +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.connector.catalog.{ + CatalogManager, + CatalogV2Util, + FunctionCatalog, + Identifier, + LookupCatalog +} +import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ +import org.apache.spark.sql.connector.catalog.functions.{ + AggregateFunction => V2AggregateFunction, + ScalarFunction +} +import org.apache.spark.sql.errors.{DataTypeErrorsBase, QueryCompilationErrors} +import org.apache.spark.sql.types._ + +class FunctionResolution( + override val catalogManager: CatalogManager, + relationResolution: RelationResolution) + extends DataTypeErrorsBase with LookupCatalog { + private val v1SessionCatalog = catalogManager.v1SessionCatalog + + private val trimWarningEnabled = new AtomicBoolean(true) + + def resolveFunction(u: UnresolvedFunction): Expression = { + withPosition(u) { + resolveBuiltinOrTempFunction(u.nameParts, u.arguments, u).getOrElse { + val CatalogAndIdentifier(catalog, ident) = + relationResolution.expandIdentifier(u.nameParts) + if (CatalogV2Util.isSessionCatalog(catalog)) { + resolveV1Function(ident.asFunctionIdentifier, u.arguments, u) + } else { + resolveV2Function(catalog.asFunctionCatalog, ident, u.arguments, u) + } + } + } + } + + /** + * Check if the arguments of a function are either resolved or a lambda function. + */ + def hasLambdaAndResolvedArguments(expressions: Seq[Expression]): Boolean = { + val (lambdas, others) = expressions.partition(_.isInstanceOf[LambdaFunction]) + lambdas.nonEmpty && others.forall(_.resolved) + } + + def lookupBuiltinOrTempFunction( + name: Seq[String], + u: Option[UnresolvedFunction]): Option[ExpressionInfo] = { + if (name.size == 1 && u.exists(_.isInternal)) { + FunctionRegistry.internal.lookupFunction(FunctionIdentifier(name.head)) + } else if (name.size == 1) { + v1SessionCatalog.lookupBuiltinOrTempFunction(name.head) + } else { + None + } + } + + def lookupBuiltinOrTempTableFunction(name: Seq[String]): Option[ExpressionInfo] = { + if (name.length == 1) { + v1SessionCatalog.lookupBuiltinOrTempTableFunction(name.head) + } else { + None + } + } + + def resolveBuiltinOrTempFunction( + name: Seq[String], + arguments: Seq[Expression], + u: UnresolvedFunction): Option[Expression] = { + val expression = if (name.size == 1 && u.isInternal) { + Option(FunctionRegistry.internal.lookupFunction(FunctionIdentifier(name.head), arguments)) + } else if (name.size == 1) { + v1SessionCatalog.resolveBuiltinOrTempFunction(name.head, arguments) + } else { + None + } + expression.map { func => + validateFunction(func, arguments.length, u) + } + } + + def resolveBuiltinOrTempTableFunction( + name: Seq[String], + arguments: Seq[Expression]): Option[LogicalPlan] = { + if (name.length == 1) { + v1SessionCatalog.resolveBuiltinOrTempTableFunction(name.head, arguments) + } else { + None + } + } + + private def resolveV1Function( + ident: FunctionIdentifier, + arguments: Seq[Expression], + u: UnresolvedFunction): Expression = { + val func = v1SessionCatalog.resolvePersistentFunction(ident, arguments) + validateFunction(func, arguments.length, u) + } + + private def validateFunction( + func: Expression, + numArgs: Int, + u: UnresolvedFunction): Expression = { + func match { + case owg: SupportsOrderingWithinGroup if u.isDistinct => + throw QueryCompilationErrors.distinctInverseDistributionFunctionUnsupportedError( + owg.prettyName + ) + case owg: SupportsOrderingWithinGroup + if !owg.orderingFilled && u.orderingWithinGroup.isEmpty => + throw QueryCompilationErrors.inverseDistributionFunctionMissingWithinGroupError( + owg.prettyName + ) + case owg: SupportsOrderingWithinGroup + if owg.orderingFilled && u.orderingWithinGroup.nonEmpty => + throw QueryCompilationErrors.wrongNumOrderingsForInverseDistributionFunctionError( + owg.prettyName, + 0, + u.orderingWithinGroup.length + ) + case f if !f.isInstanceOf[SupportsOrderingWithinGroup] && u.orderingWithinGroup.nonEmpty => + throw QueryCompilationErrors.functionWithUnsupportedSyntaxError( + func.prettyName, + "WITHIN GROUP (ORDER BY ...)" + ) + // AggregateWindowFunctions are AggregateFunctions that can only be evaluated within + // the context of a Window clause. They do not need to be wrapped in an + // AggregateExpression. + case wf: AggregateWindowFunction => + if (u.isDistinct) { + throw QueryCompilationErrors.functionWithUnsupportedSyntaxError(wf.prettyName, "DISTINCT") + } else if (u.filter.isDefined) { + throw QueryCompilationErrors.functionWithUnsupportedSyntaxError( + wf.prettyName, + "FILTER clause" + ) + } else if (u.ignoreNulls) { + wf match { + case nthValue: NthValue => + nthValue.copy(ignoreNulls = u.ignoreNulls) + case _ => + throw QueryCompilationErrors.functionWithUnsupportedSyntaxError( + wf.prettyName, + "IGNORE NULLS" + ) + } + } else { + wf + } + case owf: FrameLessOffsetWindowFunction => + if (u.isDistinct) { + throw QueryCompilationErrors.functionWithUnsupportedSyntaxError( + owf.prettyName, + "DISTINCT" + ) + } else if (u.filter.isDefined) { + throw QueryCompilationErrors.functionWithUnsupportedSyntaxError( + owf.prettyName, + "FILTER clause" + ) + } else if (u.ignoreNulls) { + owf match { + case lead: Lead => + lead.copy(ignoreNulls = u.ignoreNulls) + case lag: Lag => + lag.copy(ignoreNulls = u.ignoreNulls) + } + } else { + owf + } + // We get an aggregate function, we need to wrap it in an AggregateExpression. + case agg: AggregateFunction => + // Note: PythonUDAF does not support these advanced clauses. + if (agg.isInstanceOf[PythonUDAF]) checkUnsupportedAggregateClause(agg, u) + // After parse, the inverse distribution functions not set the ordering within group yet. + val newAgg = agg match { + case owg: SupportsOrderingWithinGroup + if !owg.orderingFilled && u.orderingWithinGroup.nonEmpty => + owg.withOrderingWithinGroup(u.orderingWithinGroup) + case _ => + agg + } + + u.filter match { + case Some(filter) if !filter.deterministic => + throw QueryCompilationErrors.nonDeterministicFilterInAggregateError(filterExpr = filter) + case Some(filter) if filter.dataType != BooleanType => + throw QueryCompilationErrors.nonBooleanFilterInAggregateError(filterExpr = filter) + case Some(filter) if filter.exists(_.isInstanceOf[AggregateExpression]) => + throw QueryCompilationErrors.aggregateInAggregateFilterError( + filterExpr = filter, + aggExpr = filter.find(_.isInstanceOf[AggregateExpression]).get + ) + case Some(filter) if filter.exists(_.isInstanceOf[WindowExpression]) => + throw QueryCompilationErrors.windowFunctionInAggregateFilterError( + filterExpr = filter, + windowExpr = filter.find(_.isInstanceOf[WindowExpression]).get + ) + case _ => + } + if (u.ignoreNulls) { + val aggFunc = newAgg match { + case first: First => first.copy(ignoreNulls = u.ignoreNulls) + case last: Last => last.copy(ignoreNulls = u.ignoreNulls) + case any_value: AnyValue => any_value.copy(ignoreNulls = u.ignoreNulls) + case _ => + throw QueryCompilationErrors.functionWithUnsupportedSyntaxError( + newAgg.prettyName, + "IGNORE NULLS" + ) + } + aggFunc.toAggregateExpression(u.isDistinct, u.filter) + } else { + newAgg.toAggregateExpression(u.isDistinct, u.filter) + } + // This function is not an aggregate function, just return the resolved one. + case other => + checkUnsupportedAggregateClause(other, u) + if (other.isInstanceOf[String2TrimExpression] && numArgs == 2) { + if (trimWarningEnabled.get) { + log.warn( + "Two-parameter TRIM/LTRIM/RTRIM function signatures are deprecated." + + " Use SQL syntax `TRIM((BOTH | LEADING | TRAILING)? trimStr FROM str)`" + + " instead." + ) + trimWarningEnabled.set(false) + } + } + other + } + } + + private def checkUnsupportedAggregateClause(func: Expression, u: UnresolvedFunction): Unit = { + if (u.isDistinct) { + throw QueryCompilationErrors.functionWithUnsupportedSyntaxError(func.prettyName, "DISTINCT") + } + if (u.filter.isDefined) { + throw QueryCompilationErrors.functionWithUnsupportedSyntaxError( + func.prettyName, + "FILTER clause" + ) + } + if (u.ignoreNulls) { + throw QueryCompilationErrors.functionWithUnsupportedSyntaxError( + func.prettyName, + "IGNORE NULLS" + ) + } + } + + private def resolveV2Function( + catalog: FunctionCatalog, + ident: Identifier, + arguments: Seq[Expression], + u: UnresolvedFunction): Expression = { + val unbound = catalog.loadFunction(ident) + val inputType = StructType(arguments.zipWithIndex.map { + case (exp, pos) => StructField(s"_$pos", exp.dataType, exp.nullable) + }) + val bound = try { + unbound.bind(inputType) + } catch { + case unsupported: UnsupportedOperationException => + throw QueryCompilationErrors.functionCannotProcessInputError( + unbound, + arguments, + unsupported + ) + } + + if (bound.inputTypes().length != arguments.length) { + throw QueryCompilationErrors.v2FunctionInvalidInputTypeLengthError(bound, arguments) + } + + bound match { + case scalarFunc: ScalarFunction[_] => + processV2ScalarFunction(scalarFunc, arguments, u) + case aggFunc: V2AggregateFunction[_, _] => + processV2AggregateFunction(aggFunc, arguments, u) + case _ => + failAnalysis( + errorClass = "INVALID_UDF_IMPLEMENTATION", + messageParameters = Map("funcName" -> toSQLId(bound.name())) + ) + } + } + + private def processV2ScalarFunction( + scalarFunc: ScalarFunction[_], + arguments: Seq[Expression], + u: UnresolvedFunction): Expression = { + if (u.isDistinct) { + throw QueryCompilationErrors.functionWithUnsupportedSyntaxError(scalarFunc.name(), "DISTINCT") + } else if (u.filter.isDefined) { + throw QueryCompilationErrors.functionWithUnsupportedSyntaxError( + scalarFunc.name(), + "FILTER clause" + ) + } else if (u.ignoreNulls) { + throw QueryCompilationErrors.functionWithUnsupportedSyntaxError( + scalarFunc.name(), + "IGNORE NULLS" + ) + } else { + V2ExpressionUtils.resolveScalarFunction(scalarFunc, arguments) + } + } + + private def processV2AggregateFunction( + aggFunc: V2AggregateFunction[_, _], + arguments: Seq[Expression], + u: UnresolvedFunction): Expression = { + if (u.ignoreNulls) { + throw QueryCompilationErrors.functionWithUnsupportedSyntaxError( + aggFunc.name(), + "IGNORE NULLS" + ) + } + val aggregator = V2Aggregator(aggFunc, arguments) + aggregator.toAggregateExpression(u.isDistinct, u.filter) + } + + private def failAnalysis(errorClass: String, messageParameters: Map[String, String]): Nothing = { + throw new AnalysisException( + errorClass = errorClass, + messageParameters = messageParameters) + } +}