remove staticInvoke and only apply static for Java
sunchao committed May 7, 2021
commit 108123b4645c86b054710004db18b553c60a3989
@@ -23,30 +23,24 @@
/**
* Interface for a function that produces a result value for each input row.
* <p>
* To evaluate each input row, Spark will first try to lookup and use either a static or
* non-static "magic method" (described below) through Java reflection. If neither of the
* magic methods is not found, Spark will call {@link #produceResult(InternalRow)} as a fallback
* approach. In other words, the precedence is as follow:
* <ul>
* <li>static magic method</li>
* <li>non-static magic method</li>
* <li>{@link #produceResult(InternalRow)}</li>
* </ul>
* To evaluate each input row, Spark will first try to lookup and use a "magic method" (described
* below) through Java reflection. If the method is not found, Spark will call
* {@link #produceResult(InternalRow)} as a fallback approach.
* <p>
* The JVM type of result values produced by this function must be the type used by Spark's
* InternalRow API for the {@link DataType SQL data type} returned by {@link #resultType()}.
* The mapping between {@link DataType} and the corresponding JVM type is defined below.
* <p>
* <b>IMPORTANT</b>: the default implementation of {@link #produceResult} throws
* {@link UnsupportedOperationException}. Users can choose to override this method, or implement
* a static magic method with name {@link #STATIC_MAGIC_METHOD_NAME}, or non-static magic
* method with name {@link #MAGIC_METHOD_NAME}, both of which take individual parameters
* instead of a {@link InternalRow}. <b>The static magic method is recommended if the function is
* stateless</b> (i.e., don't need to maintain any intermediate state between the calls), as it
* provides better performance over the non-static version due to avoidance of certain costs such
* as Java dynamic method dispatch. Either of the magic method approach should provide better
* performance over the default {@link #produceResult}, due to optimizations such as codegen,
* removal of Java boxing, etc.
* {@link UnsupportedOperationException}. Users must choose to either override this method, or
* implement a magic method with name {@link #MAGIC_METHOD_NAME}, which takes individual parameters
* instead of a {@link InternalRow}. The magic method approach is generally recommended because it
* provides better performance over the default {@link #produceResult}, due to optimizations such
* as whole-stage codegen, elimination of Java boxing, etc.
* <p>
* In addition, for functions implemented in Java that are stateless, users can optionally define
* the {@link #MAGIC_METHOD_NAME} as a static method, which further avoids certain runtime costs
* such as nullness check on the method receiver, potential Java dynamic dispatch, etc.
* <p>
* For example, a scalar UDF for adding two integers can be defined as follows with the static magic
* method approach:
@@ -56,12 +50,12 @@
* public DataType[] inputTypes() {
* return new DataType[] { DataTypes.IntegerType, DataTypes.IntegerType };
* }
* public static int staticInvoke(int left, int right) {
* public int invoke(int left, int right) {
* return left + right;
* }
* }
* </pre>
* In the above, since {@link #STATIC_MAGIC_METHOD_NAME} is defined, and also that it has
* In the above, since {@link #MAGIC_METHOD_NAME} is defined, and also that it has
* matching parameter types and return type, Spark will use it to evaluate inputs.
* <p>
* As another example, in the following:
@@ -70,10 +64,7 @@
* public DataType[] inputTypes() {
* return new DataType[] { DataTypes.IntegerType, DataTypes.IntegerType };
* }
* public static int staticInvoke(int left, int right) {
* return left + right;
* }
* public int invoke(int left, int right) {
* public static int invoke(int left, int right) {
* return left + right;
* }
* public Integer produceResult(InternalRow input) {
@@ -82,15 +73,16 @@
* }
* </pre>
*
* Even though the class define both magic methods and the {@link #produceResult}, Spark will use
* {@link #STATIC_MAGIC_METHOD_NAME} over the others as it takes higher precedence.
* The class defines both the magic method and the {@link #produceResult}, and Spark will use
* {@link #MAGIC_METHOD_NAME} over the {@link #produceResult(InternalRow)} as it takes higher
* precedence. Also note that the magic method is annotated as a static method in this case.
* <p>
* The magic method resolution is done during query analysis, where Spark looks up the magic
* Resolution on magic method is done during query analysis, where Spark looks up the magic
* method by first converting the actual input SQL data types to their corresponding Java types
* following the mapping defined below, and then checking if there is a matching method from all the
* declared methods in the UDF class, using method name and the Java types.
* <p>
* The following are the mapping from {@link DataType SQL data type} to Java type which is used
* The following are the mapping from {@link DataType SQL data type} to Java type which is used
* by Spark to infer parameter types for the magic methods as well as return value type for
* {@link #produceResult}:
* <ul>
@@ -122,7 +114,6 @@
*/
public interface ScalarFunction<R> extends BoundFunction {
String MAGIC_METHOD_NAME = "invoke";
String STATIC_MAGIC_METHOD_NAME = "staticInvoke";

/**
* Applies the function to an input row to produce a value.
@@ -17,7 +17,7 @@

package org.apache.spark.sql.catalyst.analysis

import java.lang.reflect.Method
import java.lang.reflect.{Method, Modifier}
import java.util
import java.util.Locale
import java.util.concurrent.atomic.AtomicBoolean
@@ -48,7 +48,6 @@ import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
import org.apache.spark.sql.connector.catalog.TableChange.{AddColumn, After, ColumnChange, ColumnPosition, DeleteColumn, RenameColumn, UpdateColumnComment, UpdateColumnNullability, UpdateColumnPosition, UpdateColumnType}
import org.apache.spark.sql.connector.catalog.functions.{AggregateFunction => V2AggregateFunction, BoundFunction, ScalarFunction}
import org.apache.spark.sql.connector.catalog.functions.ScalarFunction.MAGIC_METHOD_NAME
import org.apache.spark.sql.connector.catalog.functions.ScalarFunction.STATIC_MAGIC_METHOD_NAME
import org.apache.spark.sql.connector.expressions.{FieldReference, IdentityTransform, Transform}
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
@@ -2181,30 +2180,27 @@ class Analyzer(override val catalogManager: CatalogManager)
// may also want to check if the parameter types from the magic method
// match the input type through `BoundFunction.inputTypes`.
val argClasses = inputType.fields.map(_.dataType)
findMethod(scalarFunc, STATIC_MAGIC_METHOD_NAME, argClasses) match {
case Some(_) =>
findMethod(scalarFunc, MAGIC_METHOD_NAME, argClasses) match {
case Some(m) if Modifier.isStatic(m.getModifiers) =>
StaticInvoke(scalarFunc.getClass, scalarFunc.resultType(),
Member Author

Ideally we'd want to check whether this method is actually static; otherwise there could be a runtime error. However, this only works for methods defined in Java; for Scala, there seems to be no easy way.

Contributor

Scala doesn't have the concept of truly static methods, right? The equivalent (object methods) are actually just instance methods on a singleton.

Member Author

@sunchao sunchao Apr 30, 2021

Yes, that's correct. StaticInvoke calls the static method on the non-anonymous class, which just forwards to the non-static method defined in the anonymous/singleton Java class (i.e., the class with $ at the end of its name).

For instance, for the LongAddWithStatic class, this is the method defined in LongAddWithStaticMagic.class:

  public static long staticInvoke(long, long);
    Code:
       0: getstatic     #16                 // Field org/apache/spark/sql/connector/functions/LongAddWithStaticMagic$.MODULE$:Lorg/apache/spark/sql/connector/functions/LongAddWithStaticMagic$;
       3: lload_0
       4: lload_2
       5: invokevirtual #51                 // Method org/apache/spark/sql/connector/functions/LongAddWithStaticMagic$.staticInvoke:(JJ)J
       8: lreturn

and the same method defined in the singleton class LongAddWithStaticMagic$:

  public long staticInvoke(long, long);
    Code:
       0: lload_1
       1: lload_3
       2: ladd
       3: lreturn

So I was expecting worse performance from Scala, since it calls invokevirtual underneath while Java uses invokestatic, but the results don't show that. It could be that the performance is dominated by other factors.

Contributor

Very interesting, thanks for the explanation!

STATIC_MAGIC_METHOD_NAME, arguments, returnNullable = scalarFunc.isResultNullable)
MAGIC_METHOD_NAME, arguments, returnNullable = scalarFunc.isResultNullable)
case Some(_) =>
val caller = Literal.create(scalarFunc, ObjectType(scalarFunc.getClass))
Invoke(caller, MAGIC_METHOD_NAME, scalarFunc.resultType(),
arguments, returnNullable = scalarFunc.isResultNullable)
case _ =>
findMethod(scalarFunc, MAGIC_METHOD_NAME, argClasses) match {
// TODO: handle functions defined in Scala too - in Scala, even if a
// subclass do not override the default method in parent interface
// defined in Java, the method can still be found from
// `getDeclaredMethod`.
// since `inputType` is a `StructType`, it is mapped to a `InternalRow`
// which we can use to lookup the `produceResult` method.
findMethod(scalarFunc, "produceResult", Seq(inputType)) match {
case Some(_) =>
val caller = Literal.create(scalarFunc, ObjectType(scalarFunc.getClass))
Invoke(caller, MAGIC_METHOD_NAME, scalarFunc.resultType(),
arguments, returnNullable = scalarFunc.isResultNullable)
case _ =>
// TODO: handle functions defined in Scala too - in Scala, even if a
// subclass do not override the default method in parent interface
// defined in Java, the method can still be found from
// `getDeclaredMethod`.
// since `inputType` is a `StructType`, it is mapped to a `InternalRow`
// which we can use to lookup the `produceResult` method.
findMethod(scalarFunc, "produceResult", Seq(inputType)) match {
case Some(_) =>
ApplyFunctionExpression(scalarFunc, arguments)
case None =>
failAnalysis(s"ScalarFunction '${scalarFunc.name()}' neither implement" +
s" magic method nor override 'produceResult'")
}
ApplyFunctionExpression(scalarFunc, arguments)
case None =>
failAnalysis(s"ScalarFunction '${scalarFunc.name()}' neither implement" +
s" magic method nor override 'produceResult'")
}
}
}
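
The dispatch rule in the hunk above can be illustrated outside Spark with plain Java reflection. The following is only a minimal sketch under stated assumptions: `InstanceAdd`, `StaticAdd`, `NoMagic`, `MagicDispatchSketch`, `findMethod` and `choose` are hypothetical stand-ins rather than Spark classes, `Class.getMethod` is used in place of whatever lookup the Analyzer's own `findMethod` performs, and the returned strings merely name the expression the Analyzer would build in each branch.

```java
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.util.Optional;

// Hypothetical stand-ins for UDF classes; these are NOT Spark classes.
class InstanceAdd {
    public long invoke(long left, long right) { return left + right; }
}

class StaticAdd {
    public static long invoke(long left, long right) { return left + right; }
}

class NoMagic {
    public Long produceResult(Object row) { return 0L; }
}

class MagicDispatchSketch {
    static final String MAGIC_METHOD_NAME = "invoke";

    // Looks up a public method by name and exact parameter types (assumption:
    // a simplified version of the lookup, not the Analyzer's implementation).
    static Optional<Method> findMethod(Class<?> cls, String name, Class<?>... argTypes) {
        try {
            return Optional.of(cls.getMethod(name, argTypes));
        } catch (NoSuchMethodException e) {
            return Optional.empty();
        }
    }

    // Returns the name of the expression that each branch would construct.
    static String choose(Class<?> udfClass, Class<?>... argTypes) {
        Optional<Method> magic = findMethod(udfClass, MAGIC_METHOD_NAME, argTypes);
        if (magic.isPresent()) {
            // A single method name; the static modifier decides the call style.
            return Modifier.isStatic(magic.get().getModifiers()) ? "StaticInvoke" : "Invoke";
        }
        // Fall back to produceResult; Object stands in for InternalRow here.
        return findMethod(udfClass, "produceResult", Object.class).isPresent()
            ? "ApplyFunctionExpression"
            : "analysis error";
    }

    public static void main(String[] args) {
        System.out.println(choose(StaticAdd.class, long.class, long.class));
        System.out.println(choose(InstanceAdd.class, long.class, long.class));
        System.out.println(choose(NoMagic.class, long.class, long.class));
        // Parameter types must match exactly: (int, int) does not find invoke(long, long).
        System.out.println(choose(StaticAdd.class, int.class, int.class));
    }
}
```

With one method name and a modifier check, a Java UDF opts into the static path simply by declaring `invoke` as `static`; no second constant such as the removed `STATIC_MAGIC_METHOD_NAME` is needed.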
@@ -29,11 +29,12 @@ case class ApplyFunctionExpression(
override def name: String = function.name()
override def dataType: DataType = function.resultType()

private lazy val childrenWithIndex = children.zipWithIndex
private lazy val reusedRow = new GenericInternalRow(children.size)

/** Returns the result of evaluating this expression on a given input Row */
override def eval(input: InternalRow): Any = {
children.zipWithIndex.foreach {
childrenWithIndex.foreach {
case (expr, pos) =>
reusedRow.update(pos, expr.eval(input))
}
32 changes: 0 additions & 32 deletions sql/core/benchmarks/FunctionBenchmark-jdk11-results.txt

This file was deleted.

@@ -0,0 +1,129 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package test.org.apache.spark.sql.connector.catalog.functions;

import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.connector.catalog.functions.BoundFunction;
import org.apache.spark.sql.connector.catalog.functions.ScalarFunction;
import org.apache.spark.sql.connector.catalog.functions.UnboundFunction;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.LongType;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

public class JavaLongAdd implements UnboundFunction {
private final ScalarFunction<Long> impl;

public JavaLongAdd(ScalarFunction<Long> impl) {
this.impl = impl;
}

@Override
public String name() {
return "long_add";
}

@Override
public BoundFunction bind(StructType inputType) {
if (inputType.fields().length != 2) {
throw new UnsupportedOperationException("Expect two arguments");
}
StructField[] fields = inputType.fields();
if (!(fields[0].dataType() instanceof LongType)) {
throw new UnsupportedOperationException("Expect first argument to be LongType");
}
if (!(fields[1].dataType() instanceof LongType)) {
throw new UnsupportedOperationException("Expect second argument to be LongType");
}
return impl;
}

@Override
public String description() {
return "long_add";
}

public static abstract class JavaLongAddBase implements ScalarFunction<Long> {
private final boolean isResultNullable;

public JavaLongAddBase(boolean isResultNullable) {
this.isResultNullable = isResultNullable;
}

@Override
public DataType[] inputTypes() {
return new DataType[] { DataTypes.LongType, DataTypes.LongType };
}

@Override
public DataType resultType() {
return DataTypes.LongType;
}

@Override
public boolean isResultNullable() {
return isResultNullable;
}
}

public static class JavaLongAddDefault extends JavaLongAddBase {
public JavaLongAddDefault(boolean isResultNullable) {
super(isResultNullable);
}
@Override
public String name() {
return "long_add_default";
}

@Override
public Long produceResult(InternalRow input) {
return input.getLong(0) + input.getLong(1);
}
}

public static class JavaLongAddMagic extends JavaLongAddBase {
public JavaLongAddMagic(boolean isResultNullable) {
super(isResultNullable);
}

@Override
public String name() {
return "add_long_magic";
}

public long invoke(long left, long right) {
return left + right;
}
}

public static class JavaLongAddStaticMagic extends JavaLongAddBase {
public JavaLongAddStaticMagic(boolean isResultNullable) {
super(isResultNullable);
}

@Override
public String name() {
return "long_add_static_magic";
}

public static long invoke(long left, long right) {
return left + right;
}
}
}
@@ -83,24 +83,31 @@ public Integer produceResult(InternalRow input) {
}
}

public static class JavaStrLenStaticMagic extends JavaStrLenBase {
public static int staticInvoke(UTF8String str) {
public static class JavaStrLenMagic extends JavaStrLenBase {
public int invoke(UTF8String str) {
return str.toString().length();
}
}

public static class JavaStrLenBothMagic extends JavaStrLenBase {
public static int staticInvoke(UTF8String str) {
return str.toString().length() + 100;
public static class JavaStrLenStaticMagic extends JavaStrLenBase {
public static int invoke(UTF8String str) {
return str.toString().length();
}
}

public static class JavaStrLenBoth extends JavaStrLenBase {
@Override
public Integer produceResult(InternalRow input) {
String str = input.getString(0);
return str.length();
}
public int invoke(UTF8String str) {
return str.toString().length();
}
}

public static class JavaStrLenBadStaticMagic extends JavaStrLenBase {
public static int staticInvoke(String str) {
public static int invoke(String str) {
return str.length() + 100;
}

@@ -109,12 +116,6 @@ public int invoke(UTF8String str) {
}
}

public static class JavaStrLenMagic extends JavaStrLenBase {
public int invoke(UTF8String str) {
return str.toString().length();
}
}

public static class JavaStrLenNoImpl extends JavaStrLenBase {
}
}
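
The `JavaStrLenBadStaticMagic` case above declares `invoke(String)`, whereas Spark represents `StringType` values as `UTF8String`, so a lookup by mapped Java types would not find it. The sketch below illustrates that mismatch in isolation and makes no claim about how the test suite exercises it. Everything in it is an assumption made for illustration: `Utf8Text` is a hypothetical stand-in for `UTF8String`, `SQL_TO_JAVA` is a three-entry toy mapping keyed by SQL type name, and `hasMagicMethod` uses `Class.getMethod` rather than Spark's own lookup.

```java
import java.util.Map;

// Hypothetical stand-in for Spark's UTF8String; NOT the real class.
final class Utf8Text {
    private final String value;
    Utf8Text(String value) { this.value = value; }
    @Override public String toString() { return value; }
}

class StrLenGood {
    public static int invoke(Utf8Text str) { return str.toString().length(); }
}

class StrLenBad {
    // Declares java.lang.String, which is not the type the lookup asks for.
    public static int invoke(String str) { return str.length(); }
}

class TypeMappingSketch {
    // Toy SQL-type-name to JVM-class mapping, for illustration only.
    static final Map<String, Class<?>> SQL_TO_JAVA = Map.of(
        "int", int.class,
        "bigint", long.class,
        "string", Utf8Text.class);

    // Converts SQL type names to Java classes, then looks up "invoke" by exact signature.
    static boolean hasMagicMethod(Class<?> udfClass, String... sqlTypes) {
        Class<?>[] argTypes = new Class<?>[sqlTypes.length];
        for (int i = 0; i < sqlTypes.length; i++) {
            Class<?> mapped = SQL_TO_JAVA.get(sqlTypes[i]);
            if (mapped == null) {
                throw new IllegalArgumentException("unmapped SQL type: " + sqlTypes[i]);
            }
            argTypes[i] = mapped;
        }
        try {
            udfClass.getMethod("invoke", argTypes);
            return true;
        } catch (NoSuchMethodException e) {
            return false;
        }
    }

    public static void main(String[] args) {
        System.out.println(hasMagicMethod(StrLenGood.class, "string"));
        System.out.println(hasMagicMethod(StrLenBad.class, "string"));
    }
}
```

Because reflection matches parameter types exactly, a magic method written against the wrong JVM type is simply not found, and resolution falls through to whatever other implementation the class provides.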