From fa43cbd0de01b5aa5c11e939d7042a55d654f2b9 Mon Sep 17 00:00:00 2001 From: Ryan Blue Date: Fri, 3 May 2019 15:02:00 -0700 Subject: [PATCH 1/4] Add FunctionCatalog API. --- .../connector/catalog/FunctionCatalog.java | 47 ++++++ .../catalog/functions/AggregateFunction.java | 96 +++++++++++ .../catalog/functions/BoundFunction.java | 100 ++++++++++++ .../connector/catalog/functions/Function.java | 33 ++++ .../catalog/functions/ScalarFunction.java | 49 ++++++ .../catalog/functions/UnboundFunction.java | 50 ++++++ .../analysis/NoSuchItemException.scala | 18 ++- .../functions/AggregateFunctionSuite.scala | 152 ++++++++++++++++++ 8 files changed, 541 insertions(+), 4 deletions(-) create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/FunctionCatalog.java create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/AggregateFunction.java create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/BoundFunction.java create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/Function.java create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ScalarFunction.java create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/UnboundFunction.java create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/functions/AggregateFunctionSuite.scala diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/FunctionCatalog.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/FunctionCatalog.java new file mode 100644 index 000000000000..a1a9a1dcb548 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/FunctionCatalog.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.catalog; + +import org.apache.spark.sql.catalyst.analysis.NoSuchFunctionException; +import org.apache.spark.sql.catalyst.analysis.NoSuchNamespaceException; +import org.apache.spark.sql.connector.catalog.functions.UnboundFunction; + +/** + * Catalog methods for working with Functions. + */ +public interface FunctionCatalog extends CatalogPlugin { + + /** + * List the functions in a namespace from the catalog. + * + * @param namespace a multi-part namespace + * @return an array of Identifiers for functions + * @throws NoSuchNamespaceException If the namespace does not exist (optional). + */ + Identifier[] listFunctions(String[] namespace) throws NoSuchNamespaceException; + + /** + * Load a function by {@link Identifier identifier} from the catalog. 
+ * + * @param ident a function identifier + * @return an unbound function instance + * @throws NoSuchFunctionException If the function doesn't exist + */ + UnboundFunction loadFunction(Identifier ident) throws NoSuchFunctionException; + +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/AggregateFunction.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/AggregateFunction.java new file mode 100644 index 000000000000..a0d3923fe9b4 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/AggregateFunction.java @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.catalog.functions; + +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.types.DataType; + +import java.io.Serializable; + +/** + * Interface for a function that produces a result value by aggregating over multiple input rows. + *
<p>
+ * For each input row, Spark will call an update method that corresponds to the + * {@link #inputTypes() input data types}. The expected JVM argument types must be the types used by + * Spark's InternalRow API. If no direct method is found or when not using codegen, Spark will call + * {@link #update(S, InternalRow)}. + *
<p>
+ * The JVM type of result values produced by this function must be the type used by Spark's + * InternalRow API for the {@link DataType SQL data type} returned by {@link #resultType()}. + *
<p>
+ * All implementations must support partial aggregation by implementing {@link #merge(S, S)} so + * that Spark can partially aggregate and shuffle intermediate results, instead of shuffling all + * rows for an aggregate. This reduces the impact of data skew and the amount of data shuffled to + * produce the result. + *
<p>
+ * Intermediate aggregation state must be {@link Serializable} so that state produced by parallel
+ * tasks can be sent to a single executor and merged to produce a final result.
+ *
+ * @param <S> the JVM type for the aggregation's intermediate state; must be {@link Serializable}
+ * @param <R> the JVM type of result values
+ */
+public interface AggregateFunction<S extends Serializable, R> extends BoundFunction {
+
+  /**
+   * Initialize state for an aggregation.
+   * <p>
+ * This method is called one or more times for every group of values to initialize intermediate + * aggregation state. More than one intermediate aggregation state variable may be used when the + * aggregation is run in parallel tasks. + *
<p>
+ * The object returned may passed to {@link #update(S, InternalRow)}, + * and {@link #produceResult(S)}. Implementations that return null must support null state + * passed into all other methods. + * + * @return a state instance or null + */ + S newAggregationState(); + + /** + * Update the aggregation state with a new row. + *
<p>
+ * This is called for each row in a group to update an intermediate aggregation state. + * + * @param state intermediate aggregation state + * @param input an input row + * @return updated aggregation state + */ + default S update(S state, InternalRow input) { + throw new UnsupportedOperationException("Cannot find a compatible AggregateFunction#update"); + } + + /** + * Merge two partial aggregation states. + *
<p>
+ * This is called to merge intermediate aggregation states that were produced by parallel tasks. + * + * @param leftState intermediate aggregation state + * @param rightState intermediate aggregation state + * @return combined aggregation state + */ + S merge(S leftState, S rightState); + + /** + * Produce the aggregation result based on intermediate state. + * + * @param state intermediate aggregation state + * @return a result value + */ + R produceResult(S state); + +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/BoundFunction.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/BoundFunction.java new file mode 100644 index 000000000000..fb359996b929 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/BoundFunction.java @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.catalog.functions; + +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.IntegerType; +import org.apache.spark.sql.types.StructType; + +import java.io.Serializable; +import java.util.UUID; + +/** + * Represents a function that is bound to an input type. + */ +public interface BoundFunction extends Function, Serializable { + + /** + * Returns the required {@link DataType data types} of the input values to this function. + *
<p>
+ * If the types returned differ from the types passed to {@link UnboundFunction#bind(StructType)}, + * Spark will cast input values to the required data types. This allows implementations to + * delegate input value casting to Spark. + * + * @return an array of input value data types + */ + DataType[] inputTypes(); + + /** + * Returns the {@link DataType data type} of values produced by this function. + *
<p>
+ * For example, a "plus" function may return {@link IntegerType} when it is bound to arguments + * that are also {@link IntegerType}. + * + * @return a data type for values produced by this function + */ + DataType resultType(); + + /** + * Returns the whether values produced by this function may be null. + *
<p>
+ * For example, a "plus" function may return false when it is bound to arguments that are always + * non-null, but true when either argument may be null. + * + * @return true if values produced by this function may be null, false otherwise + */ + default boolean isResultNullable() { + return true; + } + + /** + * Returns whether this function result is deterministic. + *
<p>
+ * By default, functions are assumed to be deterministic. Functions that are not deterministic + * should override this method so that Spark can ensure the function runs only once for a given + * input. + * + * @return true if this function is deterministic, false otherwise + */ + default boolean isDeterministic() { + return true; + } + + /** + * Returns the canonical name of this function, used to determine if functions are equivalent. + *
<p>
+ * The canonical name is used to determine whether two functions are the same when loaded by
+ * different catalogs. For example, the same catalog implementation may be used by two
+ * environments, "prod" and "test". Functions produced by the catalogs may be equivalent, but
+ * loaded using different names, like "test.func_name" and "prod.func_name".
+ * <p>
+ * Names returned by this function should be unique and unlikely to conflict with similar + * functions in other catalogs. For example, many catalogs may define a "bucket" function with a + * different implementation. Adding context, like "com.mycompany.bucket(string)", is recommended + * to avoid unintentional collisions. + * + * @return a canonical name for this function + */ + default String canonicalName() { + // by default, use a random UUID so a function is never equivalent to another, even itself. + // this method is not required so that generated implementations (or careless ones) are not + // added and forgotten. for example, returning "" as a place-holder could cause unnecessary + // bugs if not replaced before release. + return UUID.randomUUID().toString(); + } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/Function.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/Function.java new file mode 100644 index 000000000000..b7f14eb271ef --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/Function.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.catalog.functions; + +import java.io.Serializable; + +/** + * Base class for user-defined functions. + */ +public interface Function extends Serializable { + + /** + * A name to identify this function. Implementations should provide a meaningful name, like the + * database and function name from the catalog. + */ + String name(); + +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ScalarFunction.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ScalarFunction.java new file mode 100644 index 000000000000..c2106a21c4a8 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ScalarFunction.java @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.connector.catalog.functions; + +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.types.DataType; + +/** + * Interface for a function that produces a result value for each input row. + *
<p>
+ * For each input row, Spark will call a produceResult method that corresponds to the + * {@link #inputTypes() input data types}. The expected JVM argument types must be the types used by + * Spark's InternalRow API. If no direct method is found or when not using codegen, Spark will call + * {@link #produceResult(InternalRow)}. + *
<p>
+ * The JVM type of result values produced by this function must be the type used by Spark's
+ * InternalRow API for the {@link DataType SQL data type} returned by {@link #resultType()}.
+ *
+ * @param <R> the JVM type of result values
+ */
+public interface ScalarFunction<R> extends BoundFunction {
+
+  /**
+   * Applies the function to an input row to produce a value.
+   *
+   * @param input an input row
+   * @return a result value
+   */
+  default R produceResult(InternalRow input) {
+    throw new UnsupportedOperationException(
+        "Cannot find a compatible ScalarFunction#produceResult");
+  }
+
+}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/UnboundFunction.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/UnboundFunction.java
new file mode 100644
index 000000000000..c7dd4c2b881c
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/UnboundFunction.java
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.catalog.functions;
+
+import org.apache.spark.sql.types.StructType;
+
+/**
+ * Represents a user-defined function that is not bound to input types.
+ */
+public interface UnboundFunction extends Function {
+
+  /**
+   * Bind this function to an input type.
+   * <p>
+ * If the input type is not supported, implementations must throw + * {@link UnsupportedOperationException}. + *
<p>
+ * For example, a "length" function that only supports a single string argument should throw + * UnsupportedOperationException if the struct has more than one field or if that field is not a + * string, and it may optionally throw if the field is nullable. + * + * @param inputType a struct type for inputs that will be passed to the bound function + * @return a function that can process rows with the given input type + * @throws UnsupportedOperationException If the function cannot be applied to the input type + */ + BoundFunction bind(StructType inputType); + + /** + * Returns Function documentation. + * + * @return this function's documentation + */ + String description(); + +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NoSuchItemException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NoSuchItemException.scala index 8a1913b40b31..ba5a9c618c65 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NoSuchItemException.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NoSuchItemException.scala @@ -65,10 +65,20 @@ class NoSuchPartitionException(message: String) extends AnalysisException(messag class NoSuchPermanentFunctionException(db: String, func: String) extends AnalysisException(s"Function '$func' not found in database '$db'") -class NoSuchFunctionException(db: String, func: String, cause: Option[Throwable] = None) - extends AnalysisException( - s"Undefined function: '$func'. This function is neither a registered temporary function nor " + - s"a permanent function registered in the database '$db'.", cause = cause) +class NoSuchFunctionException( + msg: String, + cause: Option[Throwable]) extends AnalysisException(msg, cause = cause) { + + def this(db: String, func: String, cause: Option[Throwable] = None) = { + this(s"Undefined function: '$func'. " + + s"This function is neither a registered temporary function nor " + + s"a permanent function registered in the database '$db'.", cause = cause) + } + + def this(identifier: Identifier) = { + this(s"Undefined function: ${identifier.quoted}", cause = None) + } +} class NoSuchPartitionsException(message: String) extends AnalysisException(message) { def this(db: String, table: String, specs: Seq[TablePartitionSpec]) = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/functions/AggregateFunctionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/functions/AggregateFunctionSuite.scala new file mode 100644 index 000000000000..c926e60b35ac --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/functions/AggregateFunctionSuite.scala @@ -0,0 +1,152 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.connector.catalog.functions + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.types.{DataType, IntegerType, LongType, StructType} + +class AggregateFunctionSuite extends SparkFunSuite { + test("Test simple iavg(int)") { + val rows = Seq(InternalRow(2), InternalRow(2), InternalRow(2)) + + val bound = IntegralAverage.bind(new StructType().add("foo", IntegerType, nullable = false)) + assert(bound.isInstanceOf[AggregateFunction[_, _]]) + val udaf = bound.asInstanceOf[AggregateFunction[Serializable, _]] + + val finalState = rows.foldLeft(udaf.newAggregationState()) { (state, row) => + udaf.update(state, row) + } + + assert(udaf.produceResult(finalState) == 2) + } + + test("Test simple iavg(long)") { + val bigValue = 9762097370951020L + val rows = Seq(InternalRow(bigValue + 2), InternalRow(bigValue), InternalRow(bigValue - 2)) + + val bound = IntegralAverage.bind(new StructType().add("foo", LongType, nullable = false)) + assert(bound.isInstanceOf[AggregateFunction[_, _]]) + val udaf = bound.asInstanceOf[AggregateFunction[Serializable, _]] + + val finalState = rows.foldLeft(udaf.newAggregationState()) { (state, row) => + udaf.update(state, row) + } + + assert(udaf.produceResult(finalState) == bigValue) + } + + test("Test associative iavg(long)") { + val bigValue = 7620099737951020L + val rows = Seq(InternalRow(bigValue + 2), InternalRow(bigValue), InternalRow(bigValue - 2)) + + val bound = IntegralAverage.bind(new StructType().add("foo", LongType, nullable = false)) + assert(bound.isInstanceOf[AggregateFunction[_, _]]) + val udaf = bound.asInstanceOf[AggregateFunction[Serializable, _]] + + val state1 = rows.foldLeft(udaf.newAggregationState()) { (state, row) => + udaf.update(state, row) + } + val state2 = rows.foldLeft(udaf.newAggregationState()) { (state, row) => + udaf.update(state, row) + } + val finalState = udaf.merge(state1, state2) + + assert(udaf.produceResult(finalState) == bigValue) + } +} + +object IntegralAverage extends UnboundFunction { + override def name(): String = "iavg" + + override def bind(inputType: StructType): BoundFunction = { + if (inputType.fields.length > 1) { + throw new UnsupportedOperationException("Too many arguments") + } + + if (inputType.fields(0).nullable) { + throw new UnsupportedOperationException("Nullable values are not supported") + } + + inputType.fields(0).dataType match { + case _: IntegerType => IntAverage + case _: LongType => LongAverage + case dataType => + throw new UnsupportedOperationException(s"Unsupported non-integral type: $dataType") + } + } + + override def description(): String = + """iavg: produces an average using integer division + | iavg(int not null) -> int + | iavg(bigint not null) -> bigint""".stripMargin +} + +object IntAverage extends AggregateFunction[(Int, Int), Int] { + + override def inputTypes(): Array[DataType] = Array(IntegerType) + + override def name(): String = "iavg" + + override def newAggregationState(): (Int, Int) = (0, 0) + + override def update(state: (Int, Int), input: InternalRow): (Int, Int) = { + val i = input.getInt(0) + state match { + case (_, 0) => + (i, 1) + case (total, count) => + (total + i, count + 1) + } + } + + override def merge(leftState: (Int, Int), rightState: (Int, Int)): (Int, Int) = { + (leftState._1 + rightState._1, leftState._2 + rightState._2) + } + + override def produceResult(state: (Int, Int)): Int = state._1 / state._2 + + override def resultType(): DataType = IntegerType +} + 
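The suite above covers only the AggregateFunction path. For comparison, here is a minimal sketch of the equivalent ScalarFunction flow, bound through the same UnboundFunction.bind mechanism; the StrLen and BoundStrLen objects are hypothetical illustrations and are not part of this patch:

// Illustrative sketch (not part of this patch): the scalar analogue of
// IntegralAverage below, resolved through UnboundFunction.bind.
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.catalog.functions.{BoundFunction, ScalarFunction, UnboundFunction}
import org.apache.spark.sql.types.{DataType, IntegerType, StringType, StructType}

object StrLen extends UnboundFunction {
  override def name(): String = "strlen"

  override def bind(inputType: StructType): BoundFunction = {
    if (inputType.fields.length != 1) {
      throw new UnsupportedOperationException("strlen takes exactly one argument")
    }
    inputType.fields(0).dataType match {
      case StringType => BoundStrLen
      case dataType =>
        throw new UnsupportedOperationException(s"Unsupported non-string type: $dataType")
    }
  }

  override def description(): String =
    """strlen: returns the number of characters in a string
      |  strlen(string) -> int""".stripMargin
}

object BoundStrLen extends ScalarFunction[Int] {
  override def inputTypes(): Array[DataType] = Array(StringType)

  override def resultType(): DataType = IntegerType

  override def name(): String = "strlen"

  // Fallback entry point, used when no type-specific produceResult variant is found.
  override def produceResult(input: InternalRow): Int = input.getUTF8String(0).numChars()
}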
+object LongAverage extends AggregateFunction[(Long, Long), Long] {
+
+  override def inputTypes(): Array[DataType] = Array(LongType)
+
+  override def name(): String = "iavg"
+
+  override def newAggregationState(): (Long, Long) = (0L, 0L)
+
+  override def update(state: (Long, Long), input: InternalRow): (Long, Long) = {
+    val l = input.getLong(0)
+    state match {
+      case (_, 0L) =>
+        (l, 1L)
+      case (total, count) =>
+        (total + l, count + 1L)
+    }
+  }
+
+  override def merge(leftState: (Long, Long), rightState: (Long, Long)): (Long, Long) = {
+    (leftState._1 + rightState._1, leftState._2 + rightState._2)
+  }
+
+  override def produceResult(state: (Long, Long)): Long = state._1 / state._2
+
+  override def resultType(): DataType = LongType
+}

From b9f139a862427f60c4c7cf0abc66b2c6c952414f Mon Sep 17 00:00:00 2001
From: Ryan Blue
Date: Fri, 26 Mar 2021 14:07:10 -0700
Subject: [PATCH 2/4] Fix javadoc errors.

---
 .../connector/catalog/functions/AggregateFunction.java | 8 +++-----
 .../sql/connector/catalog/functions/BoundFunction.java | 2 +-
 .../catalog/functions/AggregateFunctionSuite.scala     | 8 ++------
 3 files changed, 6 insertions(+), 12 deletions(-)

diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/AggregateFunction.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/AggregateFunction.java
index a0d3923fe9b4..956b09abd785 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/AggregateFunction.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/AggregateFunction.java
@@ -28,7 +28,7 @@
  * For each input row, Spark will call an update method that corresponds to the
  * {@link #inputTypes() input data types}. The expected JVM argument types must be the types used by
  * Spark's InternalRow API. If no direct method is found or when not using codegen, Spark will call
- * {@link #update(S, InternalRow)}.
+ * update with {@link InternalRow}.
  * <p>
* The JVM type of result values produced by this function must be the type used by Spark's * InternalRow API for the {@link DataType SQL data type} returned by {@link #resultType()}. @@ -39,7 +39,7 @@ * produce the result. *
<p>
 * Intermediate aggregation state must be {@link Serializable} so that state produced by parallel
- * tasks can be sent to a single executor and merged to produce a final result.
+ * tasks can be serialized, shuffled, and then merged to produce a final result.
  *
  * @param <S> the JVM type for the aggregation's intermediate state; must be {@link Serializable}
  * @param <R> the JVM type of result values
  */
 public interface AggregateFunction<S extends Serializable, R> extends BoundFunction {

   /**
    * Initialize state for an aggregation.
    * <p>
- * The object returned may passed to {@link #update(S, InternalRow)}, - * and {@link #produceResult(S)}. Implementations that return null must support null state - * passed into all other methods. + * Implementations that return null must support null state passed into all other methods. * * @return a state instance or null */ diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/BoundFunction.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/BoundFunction.java index fb359996b929..a9cb036c3367 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/BoundFunction.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/BoundFunction.java @@ -51,7 +51,7 @@ public interface BoundFunction extends Function, Serializable { DataType resultType(); /** - * Returns the whether values produced by this function may be null. + * Returns whether the values produced by this function may be null. *
<p>
* For example, a "plus" function may return false when it is bound to arguments that are always * non-null, but true when either argument may be null. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/functions/AggregateFunctionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/functions/AggregateFunctionSuite.scala index c926e60b35ac..5e5a7d4f05ff 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/functions/AggregateFunctionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/functions/AggregateFunctionSuite.scala @@ -107,12 +107,8 @@ object IntAverage extends AggregateFunction[(Int, Int), Int] { override def update(state: (Int, Int), input: InternalRow): (Int, Int) = { val i = input.getInt(0) - state match { - case (_, 0) => - (i, 1) - case (total, count) => - (total + i, count + 1) - } + val (total, count) = state + (total + i, count + 1) } override def merge(leftState: (Int, Int), rightState: (Int, Int)): (Int, Int) = { From 70b384f035f4c31231740a9ba6401be4faff635d Mon Sep 17 00:00:00 2001 From: Ryan Blue Date: Wed, 31 Mar 2021 14:26:07 -0700 Subject: [PATCH 3/4] Fix more review nits. --- .../apache/spark/sql/connector/catalog/FunctionCatalog.java | 2 ++ .../spark/sql/connector/catalog/functions/BoundFunction.java | 3 +-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/FunctionCatalog.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/FunctionCatalog.java index a1a9a1dcb548..651c9148c470 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/FunctionCatalog.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/FunctionCatalog.java @@ -28,6 +28,8 @@ public interface FunctionCatalog extends CatalogPlugin { /** * List the functions in a namespace from the catalog. + *
<p>
+ * If there are no functions in the namespace, implementations should return an empty array. * * @param namespace a multi-part namespace * @return an array of Identifiers for functions diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/BoundFunction.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/BoundFunction.java index a9cb036c3367..c53f94a16893 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/BoundFunction.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/BoundFunction.java @@ -21,13 +21,12 @@ import org.apache.spark.sql.types.IntegerType; import org.apache.spark.sql.types.StructType; -import java.io.Serializable; import java.util.UUID; /** * Represents a function that is bound to an input type. */ -public interface BoundFunction extends Function, Serializable { +public interface BoundFunction extends Function { /** * Returns the required {@link DataType data types} of the input values to this function. From bb8f2aa2181ff3270b91d6f15f05e4197f13df32 Mon Sep 17 00:00:00 2001 From: Ryan Blue Date: Wed, 31 Mar 2021 14:28:10 -0700 Subject: [PATCH 4/4] Fix javadoc. --- .../connector/catalog/functions/AggregateFunction.java | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/AggregateFunction.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/AggregateFunction.java index 956b09abd785..6982ebb329ff 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/AggregateFunction.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/AggregateFunction.java @@ -33,10 +33,10 @@ * The JVM type of result values produced by this function must be the type used by Spark's * InternalRow API for the {@link DataType SQL data type} returned by {@link #resultType()}. *
<p>
- * All implementations must support partial aggregation by implementing {@link #merge(S, S)} so - * that Spark can partially aggregate and shuffle intermediate results, instead of shuffling all - * rows for an aggregate. This reduces the impact of data skew and the amount of data shuffled to - * produce the result. + * All implementations must support partial aggregation by implementing merge so that Spark can + * partially aggregate and shuffle intermediate results, instead of shuffling all rows for an + * aggregate. This reduces the impact of data skew and the amount of data shuffled to produce the + * result. *
<p>
* Intermediate aggregation state must be {@link Serializable} so that state produced by parallel * tasks can be serialized, shuffled, and then merged to produce a final result.
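End-to-end, the pieces in this series fit together as follows: Spark resolves a function identifier against a FunctionCatalog, gets back an UnboundFunction, and binds it to the argument struct type to obtain a BoundFunction. A minimal in-memory sketch of that contract is below; the InMemoryFunctionCatalog class and its createFunction helper are hypothetical names for illustration, not part of these patches:

import scala.collection.mutable

import org.apache.spark.sql.catalyst.analysis.NoSuchFunctionException
import org.apache.spark.sql.connector.catalog.{FunctionCatalog, Identifier}
import org.apache.spark.sql.connector.catalog.functions.UnboundFunction
import org.apache.spark.sql.util.CaseInsensitiveStringMap

// Hypothetical catalog, used only to illustrate the FunctionCatalog contract.
class InMemoryFunctionCatalog extends FunctionCatalog {
  private val functions = mutable.Map.empty[Identifier, UnboundFunction]
  private var catalogName: String = _

  // CatalogPlugin methods, called once when Spark instantiates the plugin.
  override def initialize(name: String, options: CaseInsensitiveStringMap): Unit = {
    catalogName = name
  }

  override def name(): String = catalogName

  // Helper for registering functions; not part of the FunctionCatalog API.
  def createFunction(ident: Identifier, fn: UnboundFunction): Unit = functions.put(ident, fn)

  // Per patch 3, a namespace with no functions yields an empty array, not an error.
  override def listFunctions(namespace: Array[String]): Array[Identifier] =
    functions.keys.filter(_.namespace().sameElements(namespace)).toArray

  // Throws the Identifier-based NoSuchFunctionException constructor added by patch 1.
  override def loadFunction(ident: Identifier): UnboundFunction =
    functions.getOrElse(ident, throw new NoSuchFunctionException(ident))
}

// Example wiring (hypothetical):
//   val catalog = new InMemoryFunctionCatalog
//   catalog.initialize("test_cat", new CaseInsensitiveStringMap(java.util.Collections.emptyMap()))
//   catalog.createFunction(Identifier.of(Array("db"), "iavg"), IntegralAverage)
//   val bound = catalog.loadFunction(Identifier.of(Array("db"), "iavg"))
//     .bind(new StructType().add("x", IntegerType, nullable = false))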